#!/usr/bin/env python3

# Copyright (c) 2021, AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: LGPL-2.1-only

#
# Yang get-state backend for ICMP rate limiting.
#

from argparse import ArgumentParser
from vplaned import Controller
from vyatta import configd
import sys
import json
import logging
import logging.handlers
from systemd.journal import JournalHandler

# Largest 'limit' value accepted in a dataplane reply.
MAX_LIMIT = 1000

# ICMP type names accepted in a dataplane reply, keyed by address family.
expected_types = {
    'v4': [
        'destination-unreachable',
        'redirect',
        'time-exceeded',
    ],
    'v6': [
        'destination-unreachable',
        'too-big',
        'parameter-problem',
        'time-exceeded',
    ],
}

# Bound at module level so valid() can log even when this file is imported
# rather than run as a script (the script entry point rebinds it).
logger = logging.getLogger(__name__)


def valid(af, json_string):
    """Check a dataplane "icmprl show" reply before it is published.

    Args:
        af: Address family, a key of ``expected_types`` ('v4' or 'v6').
        json_string: JSON text expected to decode to an object holding an
            'icmp-types' list; every entry must carry an 'icmp-type' name
            and a numeric 'limit'.

    Returns:
        True when every entry names an ICMP type expected for ``af`` and
        its limit does not exceed ``MAX_LIMIT``; False otherwise.  Each
        rejection is logged.  Malformed input (bad JSON, wrong container
        types, missing keys, unknown address family) is rejected instead
        of raising.
    """
    allowed = expected_types.get(af)
    if allowed is None:
        logger.error("Invalid address family {} specified".format(af))
        return False

    try:
        str_data = json.loads(json_string)
    except (TypeError, ValueError) as err:
        # json.JSONDecodeError is a ValueError; TypeError covers non-text
        # input such as None.
        logger.error("Invalid JSON ({}), af={}".format(err, af))
        return False

    data = str_data.get('icmp-types') if isinstance(str_data, dict) else None
    if data is None:
        logger.error("Missing icmp-types in {}, af={}".format(str_data, af))
        return False
    if not isinstance(data, list):
        logger.error("Malformed icmp-types in {}, af={}".format(str_data, af))
        return False

    for entry in data:
        if not isinstance(entry, dict):
            logger.error("Malformed entry {}, af={}".format(entry, af))
            return False

        icmptype = entry.get('icmp-type')
        if icmptype is None:
            logger.error("Missing icmp-type in {}, af={}".format(entry, af))
            return False

        if icmptype not in allowed:
            logger.error("Unexpected type {} in {}. af={}"
                         .format(icmptype, entry, af))
            return False

        limit = entry.get('limit')
        if limit is None:
            logger.error("Missing limit in {}, af={}".format(entry, af))
            return False
        if not isinstance(limit, (int, float)):
            logger.error("Invalid limit {} in {}, af={}"
                         .format(limit, entry, af))
            return False
        if limit > MAX_LIMIT:
            logger.error("Limit {} out of range in {}, af={}"
                         .format(limit, entry, af))
            return False

    return True

if __name__ == "__main__":
    # Send log records to the systemd journal under a fixed identifier.
    journal_handler = JournalHandler(
        SYSLOG_IDENTIFIER='vyatta-icmp-rate-limit-state')
    logger = logging.getLogger()
    logging.root.addHandler(journal_handler)

    # --af is mandatory and must name one of the two address families.
    parser = ArgumentParser()
    parser.add_argument('--af', required=True)
    family = parser.parse_args().af
    if family not in ('v4', 'v6'):
        logger.error("Invalid address family {} specified".format(family))
        sys.exit(1)

    # Ask each dataplane for its state; print only replies that validate.
    show_command = "icmprl show {}".format(family)
    with Controller() as controller:
        for dataplane in controller.get_dataplanes():
            with dataplane:
                reply = dataplane.string_command(show_command)
                if not valid(family, reply):
                    continue
                print(reply)
