#!/usr/bin/python3
#
# Copyright (c) 2019-2021 AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#
# Validate IP Packet Filter configuration.
#

""" Perform validation of the IP packet filter (ippf) configuration. Note that
    most of the validation is performed using YANG 'must' statements, but
    some cannot be done using them, so instead it is done by this script. """

import sys
import getopt
from vyatta import configd


# Configuration path of the IP packet filter subtree. It is passed to
# tree_get_dict() and quoted in the validation error messages.
BASE_ADDRESS_PATH = "security ip-packet-filter"

# Status values returned by the validation functions.
STATUS_SUCCESS = 0
STATUS_FAILED = 1

# Maximum number of statistics allowed per direction. Filled in from the
# --max-in-stats / --max-out-stats options; 0 means no limit is enforced.
STAT_MAX = {
    "in": 0,
    "out": 0,
}


def err(msg):
    """ Write msg, followed by a newline, to stderr. """
    sys.stderr.write(str(msg) + "\n")


def process_options():
    """ Process the command-line options.

    Recognised options are '--max-in-stats <num>' and
    '--max-out-stats <num>'. Each value is converted with int() and stored
    in STAT_MAX under 'in' or 'out' respectively.

    If an option is not recognised, or a value cannot be converted to an
    integer, a message is written to stderr and the script exits with
    status 2. """
    try:
        opts, _ = getopt.getopt(sys.argv[1:], "", ['max-in-stats=',
                                                   'max-out-stats='])
    except getopt.GetoptError as opt_error:
        err(opt_error)
        err(f"usage: {sys.argv[0]} [--max-in-stats <num>] "
            f"[--max-out-stats <num>]")
        sys.exit(2)

    try:
        for opt, arg in opts:
            # getopt reports each long option by its full name, so compare
            # for equality; 'opt in "<name>"' would be a substring test.
            if opt == '--max-in-stats':
                STAT_MAX['in'] = int(arg)
            elif opt == '--max-out-stats':
                STAT_MAX['out'] = int(arg)
    except ValueError:
        err("Ensure values of parameters are numbers")
        sys.exit(2)


def group_stats_count(grp):
    """ Count how many statistics entries are needed for the rules of a
    group.

    grp is the dictionary for one group; its 'counters' and 'rule' entries
    are examined.

    Returns:
      0 - no 'counters', no counter 'type', or a type that is neither
          'auto-per-action' nor 'auto-per-rule'
      1 - 'auto-per-action' with a non-empty value
      2 - 'auto-per-action' with an empty value (all action types)
      N - for 'auto-per-rule', the number of rules without a 'disable' key
    """

    ctrs = grp.get('counters')
    if not ctrs:
        return 0

    # Counters configured without a 'type' need no statistics. Without this
    # guard the membership tests below would raise TypeError on None.
    ctrtype = ctrs.get('type')
    if not ctrtype:
        return 0

    if 'auto-per-action' in ctrtype:
        # A non-empty value presumably restricts counting to a single
        # action (TODO confirm against the YANG model); an empty one means
        # all action types.
        return 1 if ctrtype['auto-per-action'] else 2

    if 'auto-per-rule' not in ctrtype:
        return 0

    # One statistics entry per rule that is not disabled.
    rules = grp.get('rule') or ()
    return sum(1 for rule in rules if 'disable' not in rule)


def validate_maximum_stats_rules(ifcfg, grpcfg):
    """ Check, separately for the 'in' and 'out' directions, that the number
    of statistics that must be allocated does not exceed the configured
    maximum in STAT_MAX.

    ifcfg is the list of interface entries and grpcfg the list of group
    entries. For every group name an interface lists under a direction, the
    statistics needed by each group of that name are added to that
    direction's total. A maximum of 0 means the direction is not checked.

    Returns STATUS_FAILED if any direction exceeds its maximum (a message
    is printed for each), otherwise STATUS_SUCCESS. """

    result = STATUS_SUCCESS

    for direction in ("in", "out"):
        limit = STAT_MAX[direction]
        if limit == 0:    # no limit
            continue

        required = 0
        for interface in ifcfg:
            bound_groups = interface.get(direction)
            if not bound_groups:
                continue

            for grpname in bound_groups:
                required += sum(group_stats_count(grp)
                                for grp in grpcfg
                                if grp['group-name'] == grpname)

        if required <= limit:
            continue

        result = STATUS_FAILED
        print(f"For direction '{direction}' the {required} "
              f"statistics required is more than maximum supported "
              f"({limit})\n")

    return result


def validate_ippf_cfg(client):
    """ Validate the IP packet filter configuration.

    client must provide tree_get_dict(); it is used to fetch the
    configuration under BASE_ADDRESS_PATH.

    Two checks are made:
      - the statistics limits, via validate_maximum_stats_rules() (only
        when groups are configured);
      - when an interface has exactly two groups in a direction, the two
        groups must have different 'ip-version' values.

    Returns STATUS_SUCCESS when there is nothing to validate or all checks
    pass, otherwise STATUS_FAILED. """

    try:
        cfg = client.tree_get_dict(BASE_ADDRESS_PATH)
    except configd.Exception:
        # no ip-packet-filter configuration
        return STATUS_SUCCESS

    ippf = cfg["ip-packet-filter"]

    ifcfg = ippf.get("interface")
    if not ifcfg:
        # no ip-packet-filter interface configuration
        return STATUS_SUCCESS

    status = STATUS_SUCCESS

    grpcfg = ippf.get("group")
    if grpcfg:
        status = validate_maximum_stats_rules(ifcfg, grpcfg)

    for interface in ifcfg:
        ifname = interface["interface-name"]
        for direction in ("in", "out"):
            bound = interface.get(direction)

            # Less than two groups is fine.
            # Yang prevents more than two groups.
            # So we only need to compare the address families
            # if there are exactly two groups.
            if not bound or len(bound) != 2:
                continue

            group_name_1, group_name_2 = bound

            af1 = af2 = None

            # Walk the group list to extract the AFs. The list may be
            # absent, so use the value fetched with .get() above rather
            # than indexing the configuration again.
            for group in grpcfg or ():
                if group["group-name"] == group_name_1:
                    af1 = group["ip-version"]

                elif group["group-name"] == group_name_2:
                    af2 = group["ip-version"]

            # Finally we can compare the AFs. Note that two names which
            # were not found both leave None and so are rejected as well.
            if af1 == af2:
                status = STATUS_FAILED
                print("\n[{} interface {} {}]\n\n"
                      "Configure only one group per address-family"
                      .format(BASE_ADDRESS_PATH, ifname, direction))

    return status


def validate_ippf_main():
    """ Open a configd client session and run the IP packet filter
    validation with it.

    Returns the status from validate_ippf_cfg(), or STATUS_FAILED (after
    printing the reason) if the session cannot be established. """

    try:
        session = configd.Client()
    except configd.FatalException as exc:
        reason = str(exc).strip()
        print("Cannot establish client session: '{}'".format(reason))
        return STATUS_FAILED
    else:
        return validate_ippf_cfg(session)


if __name__ == "__main__":
    # Parse the limits first so that validation sees them in STAT_MAX.
    process_options()
    RET = validate_ippf_main()
    # Use sys.exit() rather than the bare exit(): the latter is a helper
    # installed by the site module and is not guaranteed to be available
    # (for example when Python is started with -S).
    sys.exit(RET)
