#!/usr/bin/env python3

# Copyright (c) 2019 AT&T Intellectual Property. All Rights Reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#
# This script runs when the vlan-modify policy attached to an interface
# changes, or when a vlan-modify policy itself changes; in the latter case
# it looks for the interfaces that have that policy attached.
#

import argparse
import logging
import subprocess
import sys

from collections import defaultdict
from vyatta import configd
from vyatta.interfaces.interfaces import getInterfaceConfig as gid

# configd database selectors for the candidate and running configurations.
CONFIG_CANDIDATE = configd.Client.CANDIDATE
CONFIG_RUNNING = configd.Client.RUNNING
# tc filter preference values used in place of the configured rule number
# when a rule's selector vlan-id is 'any' (one per direction).
RULE_ANY_INGRESS = '10000'
RULE_ANY_EGRESS = '10001'
# 32 used for EGRESS handle chosen randomly after leaving some buffer
# 1 is used elsewhere by the system
EGRESS_HANDLE = '32'
# Handle used as the tc filter parent/chain for ingress rules.
INGRESS_HANDLE = '0xffff'
# Value passed to the tc u32 match in place of the vlan-id when the
# selector vlan-id is 'any'.
VLAN_ALL = '0x0fff'
# Shared configd client; assigned by policy_vlan_modify() and
# interface_vlan_modify() before any configuration lookup is made.
client = None
# Short aliases for the module-level logging debug and error functions.
dbg = logging.debug
err = logging.error

def nested_dict():
    """Return a defaultdict whose missing keys are created as nested_dict().

    This allows assignments such as ``d['a']['b'] = 1`` without creating
    the intermediate dictionaries first.
    """
    tree = defaultdict(nested_dict)
    return tree

def vlan_modify_update_rule(rule, rulenum, ifname, policy, operation):
    """Apply one vlan-modify rule to an interface through ``tc filter``.

    Args:
        rule: Rule configuration dict. Reads ``rule['direction']``,
            ``rule['selector']['vlan-id']`` and ``rule['action']``.
        rulenum: Configured rule number; used as the filter preference
            unless the selector vlan-id is ``any``.
        ifname: Device name passed to tc.
        policy: Policy name. Not used by this function.
        operation: tc filter operation verb; callers in this file pass
            ``add``, ``delete`` or ``replace``.

    Returns:
        0 when tc succeeds, 1 on an invalid rule or when tc fails.
    """
    direction = rule['direction']
    vlan_id = rule['selector']['vlan-id']

    if direction == 'egress':
        handle = EGRESS_HANDLE
    elif direction == 'ingress':
        handle = INGRESS_HANDLE
    else:
        # Without a known direction there is no parent handle to attach to.
        err("Invalid direction {}".format(direction))
        return 1

    if vlan_id == 'any':
        # 'any' rules use a fixed per-direction preference instead of the
        # configured rule number.
        rule_no = RULE_ANY_EGRESS if direction == 'egress' else RULE_ANY_INGRESS
        dbg("Using rule {} for selector vlan-id any".format(rule_no))
        vlan_id = VLAN_ALL
    else:
        rule_no = rulenum

    action = rule['action']
    dbg("{}:Rule {}, direction {}, selector vlan_id {}, action {}".format(operation, rule_no, direction, vlan_id, action))

    action_str = None
    action_param = []
    for act_key in action:
        if act_key in ('swap', 'push'):
            # A configured 'swap' is expressed as the tc vlan 'modify' action.
            action_str = 'modify' if act_key == 'swap' else 'push'
            params = action[act_key]
            if 'tag-protocol-id' in params:
                action_param += ['protocol', params['tag-protocol-id']]
            if 'pcp' in params:
                action_param += ['priority', str(params['pcp'])]
            action_param += ['id', str(params['vlan-id'])]
        elif act_key == 'pop':
            action_str = 'pop'
        else:
            err("Invalid action {}".format(action))
            return 1

    if action_str is None:
        # An empty action would otherwise leave the tc command incomplete.
        err("No action configured for rule {}".format(rule_no))
        return 1

    tc_filter = [
        'tc', 'filter', operation, 'dev', ifname,
        'parent', handle + ':', 'pref', str(rule_no),
        'protocol', '802.1ad', 'chain', handle,
        'u32', 'match', 'u16', str(vlan_id), '0x0FFF', 'at', '-4',
        'action', 'vlan', action_str,
    ] + action_param
    rc = subprocess.call(tc_filter)
    if rc != 0:
        print("Error:{} {} ".format(rc, tc_filter))
        return 1
    dbg("Success: {}".format(tc_filter))
    return 0

def tc_qdisc_cmd(ifname, direction, operation):
    """Run ``tc qdisc <operation>`` for the ingress or egress qdisc of a device.

    Args:
        ifname: Device name passed to tc.
        direction: ``"ingress"`` or ``"egress"``.
        operation: tc qdisc operation verb; callers in this file pass
            ``add`` or ``delete``.

    Returns:
        0 when tc succeeds, 1 on an invalid direction or when tc fails.
    """
    if direction == "ingress":
        direct_list = ['ingress']
    elif direction == "egress":
        # Same handle that the egress filters use as their parent.
        direct_list = ['handle', EGRESS_HANDLE + ':', 'root', 'prio']
    else:
        err("dev {} tc operation {} invalid direction {}".format(ifname, operation, direction))
        return 1

    cmd = ['tc', 'qdisc', operation, 'dev', ifname] + direct_list
    rc = subprocess.call(cmd)
    if rc != 0:
        err("dev {} tc qdisc {} direction  {} failed {}".format(ifname, operation, direction, rc))
        return 1
    dbg("dev {} tc qdisc {} direction {} success".format(ifname, operation, direction))
    return 0


def vlan_modify_update_interface_policy(ifname, policy, operation):
    """Apply a vlan-modify policy's rules to one interface via tc.

    Args:
        ifname: Device name passed to tc.
        policy: Name of the vlan-modify policy.
        operation: One of:
            ``"add"``    - install every rule from the candidate config,
                           creating each direction's qdisc before its
                           first rule;
            ``"delete"`` - remove every rule from the running config, then
                           the qdisc of each direction that had rules;
            ``"update"`` - remove rules deleted in the candidate, add new
                           ones, replace changed ones, and create/delete
                           qdiscs as directions gain their first rule or
                           lose their last.

    Returns:
        0 on success, 1 on the first failing tc command or on an unknown
        operation.
    """
    egress_deleted = 0
    ingress_deleted = 0
    egress_added = 0
    ingress_added = 0
    egress_count = 0
    ingress_count = 0
    policy_path = "policy vlan-modify {}".format(policy)
    rule_path = "policy vlan-modify {} rule".format(policy)

    dbg("Policy {} interface {} operation {}".format(policy, ifname, operation))
    if operation == "add":
        rules_cfg = client.tree_get_dict(rule_path, CONFIG_CANDIDATE, 'internal')['rule']
        for rulenum, rulec in rules_cfg.items():
            direction = rulec['direction']
            if direction == "egress":
                # Create qdisc for egress when 1st rule is created
                if egress_added == 0:
                    if tc_qdisc_cmd(ifname, direction, operation) != 0:
                        return 1
                egress_added += 1
            if direction == "ingress":
                # Create qdisc for ingress when 1st ingress rule is created
                if ingress_added == 0:
                    if tc_qdisc_cmd(ifname, direction, operation) != 0:
                        return 1
                ingress_added += 1
            if vlan_modify_update_rule(rulec, rulenum, ifname, policy, 'add'):
                return 1

    elif operation == "delete":
        rules_cfg = client.tree_get_dict(rule_path, CONFIG_RUNNING, 'internal')['rule']
        for rulenum, rulec in rules_cfg.items():
            direction = rulec['direction']
            if direction == "egress":
                egress_deleted += 1
            if direction == "ingress":
                ingress_deleted += 1
            if vlan_modify_update_rule(rulec, rulenum, ifname, policy, 'delete'):
                return 1
        # Delete qdisc for ingress if ingress rules deleted
        if ingress_deleted != 0:
            if tc_qdisc_cmd(ifname, "ingress", operation) != 0:
                return 1
        # Delete qdisc for egress if egress rules deleted
        if egress_deleted != 0:
            if tc_qdisc_cmd(ifname, "egress", operation) != 0:
                return 1

    elif operation == "update":
        # Running rules, fetched at most once; needed both for the deletion
        # pass and to remove the old filter when a rule changes direction.
        running_rules = None

        # Pass 1: count the running rules per direction and remove the
        # filters of rules that the candidate config deletes.
        if client.node_get(CONFIG_RUNNING, policy_path):
            running_rules = client.tree_get_dict(rule_path, CONFIG_RUNNING, 'internal')['rule']
            for rulenum, rulerc in running_rules.items():
                direction = rulerc['direction']
                if direction == "egress":
                    egress_count += 1
                if direction == 'ingress':
                    ingress_count += 1
                status = client.node_get_status(CONFIG_CANDIDATE, "policy vlan-modify {} rule {}".format(policy, rulenum))
                if status == client.DELETED:
                    dbg("intf {} policy {} rule {} deleted".format(ifname, policy, rulenum))
                    if direction == 'egress':
                        egress_deleted += 1
                    if direction == 'ingress':
                        ingress_deleted += 1
                    if vlan_modify_update_rule(rulerc, rulenum, ifname, policy, 'delete'):
                        return 1

        # Pass 2: install added rules and re-install or replace changed ones.
        if client.node_get(CONFIG_CANDIDATE, policy_path):
            rules_cfg = client.tree_get_dict(rule_path, CONFIG_CANDIDATE, 'internal')['rule']
            for rulenum, rulec in rules_cfg.items():
                status = client.node_get_status(CONFIG_CANDIDATE, "policy vlan-modify {} rule {}".format(policy, rulenum))
                if status == client.ADDED:
                    dbg("intf {} policy {} rule {} added".format(ifname, policy, rulenum))
                    direction = rulec['direction']
                    if direction == "egress":
                        egress_added += 1
                    if direction == "ingress":
                        ingress_added += 1
                    # Create a qdisc only for the direction of this rule, and
                    # only for the first rule of a direction that had none;
                    # otherwise a later rule of the other direction would
                    # re-issue the add for an already existing qdisc.
                    if direction == "egress" and egress_count == 0 and egress_added == 1:
                        if tc_qdisc_cmd(ifname, "egress", "add") != 0:
                            return 1
                    if direction == "ingress" and ingress_count == 0 and ingress_added == 1:
                        if tc_qdisc_cmd(ifname, "ingress", "add") != 0:
                            return 1
                    if vlan_modify_update_rule(rulec, rulenum, ifname, policy, 'add'):
                        return 1
                elif status == client.CHANGED:
                    dbg("intf {} policy {} rule {} changed".format(ifname, policy, rulenum))
                    dir_status = client.node_get_status(CONFIG_CANDIDATE, "policy vlan-modify {} rule {} direction".format(policy, rulenum))
                    direction = rulec['direction']
                    # Direction change requires add and a delete
                    if dir_status == client.CHANGED:
                        if direction == "egress":
                            egress_added += 1
                            ingress_deleted += 1
                        if direction == "ingress":
                            ingress_added += 1
                            egress_deleted += 1
                        if running_rules is None:
                            running_rules = client.tree_get_dict(rule_path, CONFIG_RUNNING, 'internal')['rule']
                        # Remove the filter from the old direction first.
                        if vlan_modify_update_rule(running_rules[rulenum], rulenum, ifname, policy, 'delete'):
                            return 1
                        # Same direction-specific first-rule check as for
                        # newly added rules.
                        if direction == "egress" and egress_count == 0 and egress_added == 1:
                            if tc_qdisc_cmd(ifname, "egress", "add") != 0:
                                return 1
                        if direction == "ingress" and ingress_count == 0 and ingress_added == 1:
                            if tc_qdisc_cmd(ifname, "ingress", "add") != 0:
                                return 1
                        if vlan_modify_update_rule(rulec, rulenum, ifname, policy, 'add'):
                            return 1
                    else:
                        if vlan_modify_update_rule(rulec, rulenum, ifname, policy, 'replace'):
                            return 1
                else:
                    dbg("Rule not changed {}".format(rulenum))

        dbg("Intf {}, policy {} egress: count = {}, added = {}, deleted = {}".format(ifname, policy, egress_count, egress_added, egress_deleted))
        dbg("Intf {}, policy {} ingress: count = {}, added = {}, deleted = {}".format(ifname, policy, ingress_count, ingress_added, ingress_deleted))
        # Remove a direction's qdisc once every rule it had is gone and
        # nothing was added for that direction.
        if egress_count != 0 and egress_count == egress_deleted and egress_added == 0:
            if tc_qdisc_cmd(ifname, "egress", "delete") != 0:
                return 1
        if ingress_count != 0 and ingress_count == ingress_deleted and ingress_added == 0:
            if tc_qdisc_cmd(ifname, "ingress", "delete") != 0:
                return 1
    else:
        err("Unknown operation {}".format(operation))
        return 1

    return 0

def policy_vlan_modify(policy):
    """Handle a change to a vlan-modify policy definition.

    Only a CHANGED policy is acted on here; added and deleted policies are
    ignored (the debug messages note they are handled through the interface
    attachment). For a changed policy, every running dataplane/vhost
    interface that has the policy attached, and whose attachment is itself
    unchanged in the candidate config, gets an "update".

    Args:
        policy: Name of the vlan-modify policy.

    Returns:
        1 if the configd session cannot be opened, otherwise 0. A failed
        interface update is logged but does not change the return value.
    """
    global client

    try:
        client = configd.Client()
    except Exception as exc:
        err("Cannot establish client session: '{}'".format(str(exc).strip()))
        return 1

    status = client.node_get_status(CONFIG_CANDIDATE, "policy vlan-modify "+policy)
    if status == client.ADDED:
        dbg("New policy added, Ignore -- executed as part of interface policy vlan-modify add")
        return 0
    if status == client.DELETED:
        dbg("Policy Deleted, Ignore -- executed as part of interface policy vlan-modify delete")
        return 0
    if status != client.CHANGED:
        dbg("Policy {} unchanged".format(policy))
        return 0

    dbg("Policy {} changed ".format(policy))
    intf_run_cfg = client.tree_get_dict("interfaces ", CONFIG_RUNNING, 'internal')['interfaces']
    for iftype in intf_run_cfg:
        if iftype not in ("dataplane", "vhost"):
            continue
        for ifname in intf_run_cfg[iftype]:
            dbg(" Check interface {} for Policy {} change".format(ifname, policy))
            ifpolicy = "interfaces {} {} switch-group port-parameters policy vlan-modify".format(iftype, ifname)
            intf_policy = client.node_get(CONFIG_RUNNING, ifpolicy)
            if not intf_policy or policy != intf_policy[0]:
                continue
            dbg("Changed policy for "+ifname)
            intf_status = client.node_get_status(CONFIG_CANDIDATE, ifpolicy)
            if intf_status == client.UNCHANGED:
                dbg("Policy {}  changed but not interface config ".format(policy))
                if vlan_modify_update_interface_policy(ifname, policy, "update"):
                    # Report the failure with its context; the remaining
                    # interfaces are still processed and the return value
                    # is left at 0.
                    err("Failed to update policy {} on interface {}".format(policy, ifname))
            else:
                dbg("Policy {} changed but interface {} changed too ".format(policy, ifname))

    return 0

def interface_vlan_modify(intf, intfpolicy):
    """Handle a change to the vlan-modify policy attached to an interface.

    The interface type is taken from getInterfaceConfig(); when the
    interface is not listed there, the running config is searched and the
    attachment is treated as deleted.

    Args:
        intf: Interface name.
        intfpolicy: Policy name the script was invoked with. Only consulted
            when the attachment status is CHANGED: the old policy is
            replaced only if this matches the candidate policy name.

    Returns:
        1 if the configd session cannot be opened, the interface type cannot
        be determined, or the attachment status is UNCHANGED or
        unrecognised; otherwise 0. Failures while applying tc changes are
        logged but do not change the return value.
    """
    global client

    try:
        client = configd.Client()
    except Exception as exc:
        err("Cannot establish client session: '{}'".format(str(exc).strip()))
        return 1

    type_found = False
    status = client.UNCHANGED # anything other than client.DELETED

    ifc = gid()
    if intf in ifc:
        type_found = True
        iftype = ifc[intf].type
    else:
        #
        # Interface may have been deleted, so try the
        # RUNNING config to find iftype.
        #
        intf_run_cfg = client.tree_get_dict("interfaces ",
                CONFIG_RUNNING, 'internal')['interfaces']
        for iftype in intf_run_cfg:
            if intf in intf_run_cfg[iftype]:
                type_found = True
                status = client.DELETED
                break

    if not type_found:
        err("Could not determine type for intf {}".format(intf))
        return 1

    ifpolicy = "interfaces {} {} switch-group port-parameters "\
               "policy vlan-modify".format(iftype, intf)

    try:
        if status != client.DELETED:
            status = client.node_get_status(CONFIG_CANDIDATE, ifpolicy)

        if status == client.UNCHANGED:
            err("No change - incorrect script invocation")
            return 1

        if status == client.ADDED:
            name = client.node_get(client.CANDIDATE, ifpolicy)
            dbg("Adding policy {} to intf {}".format(name[0], intf))
            if vlan_modify_update_interface_policy(intf, name[0], "add"):
                err("Failed to add vlan-modify policy {} to interface {}".format(name[0], intf))
            return 0

        if status == client.DELETED:
            name = client.node_get(client.RUNNING, ifpolicy)
            dbg("Deleting vlan-modify policy {} from interface {}".format(name[0], intf))
            if vlan_modify_update_interface_policy(intf, name[0], "delete"):
                err("Failed to delete vlan-modify policy {} from interface {}".format(name[0], intf))
            return 0

        if status == client.CHANGED:
            name = client.node_get(client.CANDIDATE, ifpolicy)
            if intfpolicy == name[0]:
                old_name = client.node_get(client.RUNNING, ifpolicy)
                dbg("Replacing vlan-modify policy {} with {} on interface {}".format(old_name[0], name[0], intf))
                # The new policy is installed even if removing the old one
                # reported a failure; each failure is logged separately.
                if vlan_modify_update_interface_policy(intf, old_name[0], "delete"):
                    err("Failed to delete vlan-modify policy {} from interface {}".format(old_name[0], intf))
                if vlan_modify_update_interface_policy(intf, name[0], "add"):
                    err("Failed to add vlan-modify policy {} to interface {}".format(name[0], intf))
            else:
                dbg("Ignoring vlan-modify policy {} change on interface {} when called for delete".format(intfpolicy, intf))
            return 0

        # Any other status is not expected for this script.
        name = client.node_get(client.CANDIDATE, ifpolicy)
        err("Interface {} Policy {} change - incorrect script invocation".format(intf, name[0]))
        return 1

    except configd.Exception:
        dbg("There is no configuration under {}".format(ifpolicy))

    return 0

if __name__ == "__main__":
    # Command-line entry point: with --interface the interface attachment
    # handler runs, otherwise the policy handler; the handler's return
    # value becomes the exit status.
    logging.basicConfig(level=logging.INFO, format='vlan-modify: %(message)s')

    cli = argparse.ArgumentParser(description='Vyatta vlan modify command')
    cli.add_argument('-i', '--interface', help='interface')
    cli.add_argument('-p', '--policy', help='vlan-modify policy', required=True)
    cli.add_argument('-d', '--debug', action='store_true', help='Enable debug output')
    opts = cli.parse_args()

    if opts.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    # An empty policy name is treated as a usage error.
    if not opts.policy:
        cli.print_help()
        sys.exit(1)

    if opts.interface:
        sys.exit(interface_vlan_modify(opts.interface, opts.policy))

    sys.exit(policy_vlan_modify(opts.policy))
