#!/usr/bin/python3
#
# Copyright (c) 2020-2021, AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#

""" IP Packet Filter RPC """

import json
import sys
import re
import vplaned
from vyatta import configd

# Root of the IP packet filter configuration; used as the prefix of the
# path requested from configd.
ROOTPATH = "security ip-packet-filter"
# Rule number reported for auto-per-action counters ("accept"/"drop"), which
# are not reported against an individual rule.
RULENUM_ALL = 0


def send(cmd):
    """
    Send the given command to every dataplane known to the controller

    Delivery is best-effort: a dataplane whose command raises is skipped,
    and a dataplane that returns an empty/falsy reply is left out of the
    result.

    @input cmd: The command to send to the dataplane

    @output: dictionary of the dataplane responses, keyed by dataplane id
    """

    responses = {}

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                try:
                    reply = dp.json_command(cmd)
                    if reply:
                        responses[dp.id] = reply
                except Exception:
                    # Best-effort: one failing dataplane must not prevent
                    # the others from being queried.  Only Exception is
                    # caught so that KeyboardInterrupt and SystemExit still
                    # propagate.
                    continue

    return responses


def get_action_counters(client):
    """
    Find the configured ACL groups that use auto-per-action counters

    Reads the group list from the running configuration.  Groups whose
    counter type contains "auto-per-rule", or does not contain
    "auto-per-action", are ignored.

    @input: client: configd client

    @output: Dictionary keyed by group name.  Each value is a dictionary
             whose keys are the counted actions ("accept", "drop", or both
             when the configuration names neither) and whose values are
             all None.
    """
    running = client.tree_get_full_dict(ROOTPATH + " group", client.RUNNING)

    per_action = {}
    for entry in running.get("group", {}):
        type_cfg = entry.get("counters", {}).get("type", {})

        # Only groups with auto-per-action (and no auto-per-rule) counters
        # are of interest here.
        if "auto-per-rule" in type_cfg or "auto-per-action" not in type_cfg:
            continue

        requested = type_cfg["auto-per-action"].get("action", {})
        if "accept" in requested:
            counted = ("accept",)
        elif "drop" in requested:
            counted = ("drop",)
        else:
            # No specific action named: count both.
            counted = ("accept", "drop")

        per_action[entry.get("group-name")] = dict.fromkeys(counted)

    return per_action


# Pulls the raw action token out of a rule's "config" string.  The final
# group accepts either whitespace or end-of-string, so an action that is
# the last token of the string is still captured.
_RULE_ACTION_RE = re.compile(r"(.*)action=(.*?)(?:\s|$)")


def _rule_action(rule_config):
    """Return the token following "action=" in rule_config, or None."""
    found = _RULE_ACTION_RE.match(rule_config)
    return found.group(2) if found else None


def ippf_show_rpc(filters, client):
    """
    Request IP Packet Filter statistics from dataplane

    Software (SW) statistics are collected first, then hardware (HW)
    counters are merged into the matching rows.  A HW counter with no
    matching SW row gets a row of its own, unless an action filter was
    given (the action of a HW-only counter is not known).

    @input:  filters: Dictionary of optional filters.  Recognised keys are
                      "interfaces", "directions", "groups", "rules" and
                      "actions"; each holds a collection of accepted
                      values.  A missing or empty filter matches anything.
             client:  configd client, used to find the groups that are
                      configured with auto-per-action counters

    @output: Prints the matching statistics in JSON. Returns 0.
    """

    match_interfaces = filters.get("interfaces", "")
    match_directions = filters.get("directions", "")
    match_groupnames = filters.get("groups", "")
    # Rule numbers and counter names are strings in the dataplane replies
    match_rulenums = list(map(str, filters.get("rules", "")))
    match_actions = filters.get("actions", "")

    rownum = 0
    # key = (interface, direction, group, rule, counter name)
    stats = {}

    # Process SW statistics

    sw_reply = send("npf-op show all: acl-in acl-out")

    auto_per_action_groups = get_action_counters(client)

    # iterate dataplane replies
    for dataplane in sw_reply:

        # iterate attach points
        for attach_point in sw_reply[dataplane]["config"]:

            # match interface
            if attach_point["attach_type"] != "interface":
                continue
            interface_name = attach_point["attach_point"]
            if match_interfaces and interface_name not in match_interfaces:
                continue

            # iterate rulesets
            for ruleset in attach_point["rulesets"]:

                # match direction. Ignore the initial "acl-".
                if match_directions and \
                        ruleset["ruleset_type"][4:] not in match_directions:
                    continue

                # iterate groups
                for group in ruleset["groups"]:

                    # check the group class
                    assert group["class"] == "acl", \
                        "wrong group class: {}".format(group["class"])

                    # double-check the group direction
                    if match_directions:
                        assert group["direction"] in match_directions, \
                            "wrong group direction: {}".format(group["direction"])

                    # match group name
                    group_name = group["name"]
                    if match_groupnames and group_name not in match_groupnames:
                        continue

                    direction = group["direction"]

                    # Groups with auto-per-action counters accumulate hits
                    # per action.  Use a fresh accumulator for every
                    # attached instance of the group so that the hits seen
                    # on one interface/direction are not added to the row
                    # of another.
                    configured = auto_per_action_groups.get(group_name)
                    if configured is None:
                        action_totals = None
                    else:
                        action_totals = dict.fromkeys(configured)

                    # iterate rules
                    for rulenum, rule in group["rules"].items():

                        # match rule number
                        if match_rulenums and rulenum not in match_rulenums:
                            continue

                        # match action
                        # Get the raw action from "config" rather than from "action"
                        action = _rule_action(rule["config"])
                        if match_actions and action not in match_actions:
                            continue

                        rp = rule.get("rprocs")
                        if not rp:
                            continue

                        ctr = rp.get("ctr")

                        if action_totals is not None:
                            # auto-per-action counter case.
                            # The counter name is an action name.
                            if ctr and action in action_totals:
                                action_totals[action] = \
                                    (action_totals[action] or 0) + ctr["hits"]
                            continue

                        # auto-per-rule counter case.
                        # The counter name is the rule number in string format.
                        row = {
                            'row': rownum,
                            'interface': interface_name,
                            'direction': direction,
                            'group': group_name,
                            'rule': int(rulenum),
                            'name': rulenum,
                            'action': action,
                            'hardware': {}
                        }

                        # Append SW packet counter, if any
                        if ctr:
                            row['software'] = {
                                'packets': ctr["hits"]
                            }

                        key = (interface_name, direction, group_name,
                               int(rulenum), rulenum)
                        stats[key] = row
                        rownum += 1

                    if action_totals is None:
                        continue

                    # Emit one row per action counter that saw a counted rule
                    for counter_name, packets in action_totals.items():
                        if packets is None:
                            continue
                        row = {
                            'row': rownum,
                            'interface': interface_name,
                            'direction': direction,
                            'group': group_name,
                            'rule': RULENUM_ALL,
                            'name': counter_name,
                            # The counter is named after the action it counts
                            'action': counter_name,
                            'hardware': {},
                            'software': {
                                'packets': packets
                            }
                        }

                        key = (interface_name, direction, group_name,
                               RULENUM_ALL, counter_name)
                        stats[key] = row
                        rownum += 1

    # Process HW statistics

    hw_reply = send("npf-op acl show counters")

    # iterate dataplane replies
    for dataplane in hw_reply:
        for ruleset in hw_reply[dataplane]["rulesets"]:
            interface_name = ruleset["interface"]
            direction = ruleset["direction"]

            # match interface name and direction
            if match_interfaces and interface_name not in match_interfaces:
                continue
            if match_directions and direction not in match_directions:
                continue

            for group in ruleset["groups"]:
                group_name = group["name"]

                # match group name
                if match_groupnames and group_name not in match_groupnames:
                    continue

                for counter in group["counters"]:
                    counter_name = counter["name"]

                    # match rule number
                    if match_rulenums and counter_name not in match_rulenums:
                        continue

                    hw_counter = counter.get("hw")
                    if not hw_counter:
                        continue

                    if counter_name in ("accept", "drop"):
                        # auto-per-action:
                        rule = RULENUM_ALL
                    else:
                        # auto-per-rule:
                        rule = int(counter_name)

                    key = (interface_name, direction, group_name,
                           rule, counter_name)
                    hw_stats = {'packets': hw_counter.get("pkts", "-")}

                    if key in stats:
                        # There are already SW stats for this key.
                        # Add the HW stats.
                        stats[key]['hardware'] = hw_stats
                    elif not match_actions:
                        # There are no SW stats for this key.
                        # They could have been filtered out by a "match action",
                        # so only proceed if no "match actions" were specified.
                        # Create a new row from scratch.
                        # NB we don't know the action.
                        stats[key] = {
                            'row': rownum,
                            'interface': interface_name,
                            'direction': direction,
                            'group': group_name,
                            'rule': rule,
                            'name': counter_name,
                            'action': "-",
                            'software': {},
                            'hardware': hw_stats
                        }
                        rownum += 1

    # Strip the keys from the dictionary; only the values are needed. Make a list.
    # Then wrap the value list in a dictionary because that's what the RPC expects.
    output_dict = {
        "statistics": list(stats.values())
    }

    # The RPC expects JSON
    print(json.dumps(output_dict))

    return 0


def ippf_clear_rpc(filters, client):
    """
    Send IP Packet Filter 'clear statistics' commands to dataplane

    The HW counters are always cleared using whatever filter values were
    supplied.  The SW counters are cleared selectively only when an
    interface is given; otherwise all SW ACL counters are cleared.

    @input filters: Dictionary of filters to send to the dataplane
                    eg {'interface': 'dp0p1s1', 'direction': 'in', 'group': 'G1', 'rule': 10}
           client:  configd client (not used by this command)

    @output: None - the clear RPC has no output. Returns 0.
    """

    interface = filters.get("interface", "")
    direction = filters.get("direction", "")
    group_name = filters.get("group", "")
    rule_number = filters.get("rule", "")

    # Clear HW ACL counters
    hw_cmd = "npf-op acl clear counters {} {} {} {}".format(
        interface, direction, group_name, rule_number)
    send(hw_cmd)

    if not interface:
        # Clear all SW ACL counters
        send("npf-op clear all: acl-in acl-out")
        return 0

    # Build the optional fragments of the SW clear command; a filter that
    # was not supplied contributes its (empty) value unchanged.
    interface_arg = " interface:" + interface
    direction_arg = " acl-" + direction if direction else direction
    group_arg = " -n acl:" + group_name if group_name else group_name
    rule_arg = " -r " + str(rule_number) if rule_number else rule_number

    # Clear specified SW ACL counters
    send("npf-op clear" + group_arg + rule_arg + interface_arg + direction_arg)

    return 0


def main():
    """
    Parse IP packet filter RPC commands

    The RPC input is read as JSON from stdin and handed, together with a
    configd client, to the handler selected by the action argument.

    @input: argv[1]: Required action: either '--show' or '--clear'.

    @output: The handler's return value, or 1 if the action is missing or
             unknown, the configd session cannot be established, or the
             input is not valid JSON.
    """

    usage = "usage: {} [--show | --clear]\n".format(sys.argv[0])

    # First argument specifies the action
    if len(sys.argv) < 2:
        print(usage)
        return 1

    commands = {
        "--show": ippf_show_rpc,
        "--clear": ippf_clear_rpc
    }

    cmd = commands.get(sys.argv[1])
    if cmd is None:
        print(usage)
        return 1

    try:
        client = configd.Client()
    except Exception as exc:
        print("Cannot establish client session: '{}'".format(str(exc).strip()))
        return 1

    # Get the RPC input
    try:
        rpc_input = json.load(sys.stdin)
    except ValueError as exc:
        # Report on stderr; stdout carries the RPC's JSON output
        print("Failed to parse input JSON: {}".format(exc), file=sys.stderr)
        return 1

    return cmd(rpc_input, client)


# When run as a script, use main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
