#!/usr/bin/python3
#
# Copyright (c) 2019-2020 AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#
# This is run after there has been a change under "interfaces", and
# it looks for npf assignments to interfaces and programs them in
# the dataplane. It also sends out traps if they are enabled.
#
# It can be called with the '--ruleset-warnings' option, which will cause
# it to just run the checks that give warnings for unchanged configuration
# and not send commands to the dataplane. This is intended to be called
# when firewall rulesets are changed in a way that could need a warning
# sent (currently if rules change "state" or "session").
#

import sys
import getopt

from collections import defaultdict
from vyatta import configd
from vyatta.npf.npf_debug import NpfDebug
from vyatta.npf.npf_store import store_cfg, dataplane_commit
from vyatta.npf.npf_traps import send_npf_snmp_traps
from vyatta.npf.npf_warning import npf_config_warning

# Option flags. process_options() updates FORCE, COMMIT and RULESET_WARNINGS
# from the command line.
#
# FORCE: carry on even when the candidate "interfaces" node is reported
# as unchanged.
FORCE = False
# COMMIT: store the attach/detach commands, ask the dataplane to commit
# and send the SNMP traps. Cleared by '--ruleset-warnings'.
COMMIT = True
# WARNINGS: gates the npf_config_warning() calls in run_validations().
# Nothing in this file changes it.
WARNINGS = True
# RULESET_WARNINGS: run the validations on attach-point entries that are
# the same in running and candidate, instead of on the added ones.
RULESET_WARNINGS = False

# Identifiers of the two configuration databases queried through configd.
CONFIG_CANDIDATE = configd.Client.CANDIDATE
CONFIG_RUNNING = configd.Client.RUNNING

# Configuration path whose status and tree are examined.
BASE_ADDRESS_PATH = "interfaces"

# Configuration path read by run_validations() to inspect firewall rules.
FW_RULESETS_PATH = "security firewall name"

# Debug printer used throughout this script; enabled by -d/--debug.
dbg = NpfDebug()

# Firewall rulesets from the candidate config, loaded on first use by
# run_validations().
rulesets_cfg = None
# configd client session, created by program_npf_interfaces().
client = None


def nested_dict():
    """Return a defaultdict whose missing keys are themselves nested dicts.

    This allows assignments such as ``d[a][b][c] = value`` without having
    to create the intermediate levels first.
    """
    tree = defaultdict(nested_dict)
    return tree


# Maps each direction keyword found under an interface's "firewall" node
# to the ruleset type name used when building the attach-point command.
# Directions not listed here are skipped by build_npf_interface_fw().
fw_ruleset_dict = {
    "in": "fw-in",
    "out": "fw-out",
    "l2": "bridge",
    "local": "local",
    "originate": "originate",
}


def err(msg):
    """Print *msg* followed by a newline on standard error."""
    stream = sys.stderr
    print(msg, file=stream)


def npf_config_error(msg, path):
    """Print a configuration error on standard output.

    The configuration *path* is shown in square brackets, followed by a
    blank line and then the error text *msg*.
    """
    report = "[{}]\n\n{}".format(path, msg)
    print(report)


def add_rules_to_attach_point(key, cmd):
    """Store an "npf-cfg attach" command for *cmd* under *key*.

    key: configuration path the command is stored against.
    cmd: attach-point arguments appended to "npf-cfg attach ".
    """
    attach_cmd = "npf-cfg attach " + cmd
    for line in ("SET: " + attach_cmd, "   " + key):
        dbg.pprint(line)
    store_cfg(key, attach_cmd, "SET", dbg)


def remove_rules_from_attach_point(key, cmd):
    """Store an "npf-cfg detach" command for *cmd* under *key*.

    key: configuration path the command is stored against.
    cmd: attach-point arguments appended to "npf-cfg detach ".
    """
    detach_cmd = "npf-cfg detach " + cmd
    for line in ("DELETE: " + detach_cmd, "   " + key):
        dbg.pprint(line)
    store_cfg(key, detach_cmd, "DELETE", dbg)


def process_options():
    """Parse the command line and set the module-level option flags.

    Recognised options:
      -f, --force          sets FORCE
      -d, --debug          enables debug output on ``dbg``
      --ruleset-warnings   sets RULESET_WARNINGS and FORCE, clears COMMIT

    If getopt rejects the command line, the getopt error and a usage line
    are written to stderr and the script exits with status 2.
    """
    global FORCE, RULESET_WARNINGS, COMMIT
    try:
        opts, _ = getopt.getopt(sys.argv[1:], "fd",
                                ['force', 'debug', 'ruleset-warnings'])
    except getopt.GetoptError as r:
        err(r)
        err("usage: {} [-f|--force] [-d|--debug] [--ruleset-warnings] "
            .format(sys.argv[0]))
        sys.exit(2)

    for opt, _ in opts:
        if opt in ('-f', '--force'):
            FORCE = True
        elif opt in ('-d', '--debug'):
            dbg.enable()
        # Compare for equality: "opt in ('--ruleset-warnings')" is a
        # substring test against a plain string, not a tuple membership.
        elif opt == '--ruleset-warnings':
            COMMIT = False
            RULESET_WARNINGS = True
            FORCE = True


def needs_validation(ruleset_type, ruleset):
    """Tell whether a ruleset attached with *ruleset_type* needs validating.

    Returns a ``(ruleset, message_suffix)`` tuple for the "local" and
    "bridge" ruleset types, where the suffix describes how the ruleset is
    being used, and None for every other type.
    """
    validated_types = (
        ("local", "locally"),
        ("bridge", "for layer 2"),
    )
    for kind, suffix in validated_types:
        if ruleset_type == kind:
            return (ruleset, suffix)
    return None


def run_validations(path, ruleset, msg_suffix):
    """Warn about rules of *ruleset* that use "state" or "session".

    path: configuration path used as the prefix of the warning location.
    ruleset: name of the firewall ruleset to inspect.
    msg_suffix: text describing how the ruleset is used, appended to the
        warning message.

    The firewall rulesets are read from the candidate configuration on
    the first call and kept in ``rulesets_cfg`` for later calls. If that
    read raises configd.Exception the function returns without checking.
    Warnings are only emitted while WARNINGS is set.
    """
    global rulesets_cfg

    dbg.pprint("run_validations: path {}; ruleset {}; suffix {}"
               .format(path, ruleset, msg_suffix))

    if rulesets_cfg is None:
        # first use: fetch the firewall rulesets from the candidate config
        try:
            fw_tree = client.tree_get_dict(FW_RULESETS_PATH,
                                           CONFIG_CANDIDATE, 'internal')
            rulesets_cfg = fw_tree['name']
        except configd.Exception:
            return

    if ruleset not in rulesets_cfg or 'rule' not in rulesets_cfg[ruleset]:
        return

    rules = rulesets_cfg[ruleset]['rule']
    for rulenum in rules:
        dbg.pprint("checking rule {}".format(rulenum))
        rcfg = rules[rulenum]

        stateful = 'state' in rcfg and rcfg['state'] == 'enable'
        if WARNINGS and stateful:
            npf_config_warning("firewall '{}' is not stateful when "
                               "configured {}".format(ruleset, msg_suffix),
                               path + ' rule ' + rulenum)

        if 'session' in rcfg and WARNINGS:
            npf_config_warning("firewall '{}' is sessionless when "
                               "configured {}".format(ruleset, msg_suffix),
                               path + ' rule ' + rulenum)


def build_npf_interface_attach_point(commands, cfg, key, ifname,
                                     rg_class, ruleset_type, tree):
    """Record the attach-point commands for one ruleset type of an interface.

    commands: nested dictionary filled in as
        ``commands[ifname][ruleset_type][tree]``, a list with one
        ``(config key, command, validation info)`` tuple per ruleset.
    cfg: iterable of ruleset names.
    key: configuration path prefix; the ruleset name is appended to it.
    ifname: interface name; 'lo' uses the "global:" attach point.
    rg_class: rule group class placed before the ruleset name.
    ruleset_type: ruleset type used in the command.
    tree: label of the configuration tree the entries came from.

    Any list already stored for this interface, type and tree is replaced.
    """
    dbg.pprint("build_npf_interface_attach_point: key {}; ifname {};"
               " ruleset_type {}".format(key, ifname, ruleset_type))

    attach_point = "global:" if ifname == 'lo' else "interface:" + ifname

    entries = []
    commands[ifname][ruleset_type][tree] = entries
    for ruleset in cfg:
        command = (attach_point + ' ' + ruleset_type + ' '
                   + rg_class + ':' + ruleset)
        validation = needs_validation(ruleset_type, ruleset)
        if validation is not None:
            dbg.pprint("marking for validation: type {}; ruleset {}".format(
                       ruleset_type, ruleset))
        entries.append((key + ' ' + ruleset, command, validation))
        dbg.pprint("ADDED COMMAND: " + command)


def build_npf_interface_fw(commands, cfg, key, ifname, tree):
    """Record the firewall attach-point commands of one interface.

    cfg is the interface's "firewall" configuration, keyed by direction.
    Each direction is translated through fw_ruleset_dict into a ruleset
    type; directions missing from that table are reported in the debug
    output and skipped.
    """
    dbg.pprint("build_npf_interface_fw: key {}; ifname {}".format(key, ifname))

    for direction in cfg:
        dbg.pprint("handling direction {}".format(direction))

        ruleset_type = fw_ruleset_dict.get(direction)
        if ruleset_type is None:
            dbg.pprint("ERROR: unexpected direction {}".format(direction))
            continue

        build_npf_interface_attach_point(
            commands, cfg[direction], key + ' ' + direction, ifname, 'fw',
            ruleset_type, tree)


def build_npf_interface(commands, cfg, key, ifname, tree):
    """Record every npf attach-point command configured on one interface.

    Handles the "firewall" node and the "policy route pbr" node of the
    interface configuration *cfg*, when present.
    """
    dbg.pprint("build_npf_interface: key {}; ifname {}".format(key, ifname))

    if 'firewall' in cfg:
        build_npf_interface_fw(commands, cfg['firewall'],
                               key + ' firewall', ifname, tree)

    # Walk down policy -> route -> pbr, stopping at the first missing level
    policy = cfg['policy'] if 'policy' in cfg else {}
    route = policy['route'] if 'route' in policy else {}
    if 'pbr' in route:
        build_npf_interface_attach_point(
            commands, route['pbr'],
            key + ' policy route pbr', ifname, 'pbr', 'pbr', tree)


def build_npf_config(commands, cfg, tree):
    """Record the attach-point commands of every interface in *cfg*.

    cfg is the "interfaces" configuration, keyed by interface type and
    then by interface name. Each vif of an interface is processed as an
    interface of its own, named "<ifname>.<vif>".
    """
    for iftype, interfaces in cfg.items():
        for ifname, if_cfg in interfaces.items():
            key = 'interfaces ' + iftype + ' ' + ifname
            dbg.pprint("Processing interface {}".format(ifname))
            build_npf_interface(commands, if_cfg, key, ifname, tree)

            if 'vif' not in if_cfg:
                continue

            # Handle vif interfaces
            for vifno, vif_cfg in if_cfg['vif'].items():
                vkey = key + ' vif ' + vifno
                vifname = ifname + '.' + vifno
                dbg.pprint("Processing interface {}".format(vifname))
                build_npf_interface(commands, vif_cfg, vkey, vifname, tree)


def program_npf_config(commands):
    """Turn the running attach-point configuration into the candidate one.

    commands: mapping built by build_npf_config(), laid out as
        ``commands[ifname][ruleset_type][tree]`` where tree is "running"
        or "cand" and each value is a list of
        ``(config key, command, validation info)`` tuples.

    For every interface and ruleset type, the entries at the start of the
    running and candidate lists whose commands match are left alone. The
    remaining running entries are detached and the remaining candidate
    entries are attached. Validations run on the added entries, or on the
    unchanged ones when RULESET_WARNINGS is set.

    Commands are only stored, committed to the dataplane and followed by
    SNMP traps when COMMIT is set. *commands* is not modified.
    """
    dbg.pprint("program_npf_config()")
    snmp_traps = set()

    for ifname in commands:
        for ruleset_type in commands[ifname]:
            dbg.pprint("Interface: {}, ruleset_type: {}".format(ifname,
                       ruleset_type))

            # Use .get(): indexing a nested defaultdict never raises
            # KeyError, it would insert an empty entry instead.
            per_tree = commands[ifname][ruleset_type]
            rc = per_tree.get("running", [])
            cc = per_tree.get("cand", [])

            # The entries in running and candidate that start the same
            # do not need to be programmed, as there is no change to them;
            # however we want to check for warnings on these unchanged rules
            unchanged = 0
            for r, c in zip(rc, cc):
                if r[1] != c[1]:
                    break
                if RULESET_WARNINGS and c[2] is not None:
                    dbg.pprint("validations for {}".format(c[2]))
                    # marked for validations, so check now
                    run_validations(c[0], c[2][0], c[2][1])
                dbg.pprint("Identical at start, so not sending to dp: {}".
                           format(c[1]))
                unchanged += 1

            # Remove entries which are left in the running config
            if COMMIT:
                for r in rc[unchanged:]:
                    remove_rules_from_attach_point(r[0], r[1])
                    snmp_traps.add(r[0])

            # Add entries which are left in the candidate config
            for c in cc[unchanged:]:
                if not RULESET_WARNINGS and c[2] is not None:
                    dbg.pprint("validations for {}".format(c[2]))
                    # marked for validations on adding, so check now
                    run_validations(c[0], c[2][0], c[2][1])

                if COMMIT:
                    add_rules_to_attach_point(c[0], c[1])
                    snmp_traps.add(c[0])

    if COMMIT:
        # request the dataplane rebuilds the rulesets
        dataplane_commit(dbg)

        dbg.pprint("Traps to send: {}".format(snmp_traps))
        send_npf_snmp_traps(list(snmp_traps), dbg)


def program_npf_interfaces():
    """Work out and apply the npf attach-point changes for the interfaces.

    Opens a configd client session (kept in the module-level ``client``),
    builds the attach-point commands from the candidate and the running
    "interfaces" trees, then hands them to program_npf_config().

    Nothing is done when the candidate "interfaces" node is unchanged,
    unless FORCE is set. A configd.Exception while reading either tree is
    only reported in the debug output; processing carries on with
    whatever was built.

    Returns:
        1 if the client session could not be established, otherwise 0.
    """
    global client
    dbg.pprint("program_npf_interfaces()")

    commands = nested_dict()

    try:
        client = configd.Client()
    except Exception as exc:
        err("Cannot establish client session: '{}'".format(str(exc).strip()))
        return 1

    try:
        status = client.node_get_status(CONFIG_CANDIDATE, BASE_ADDRESS_PATH)

        if status == client.UNCHANGED and not FORCE:
            dbg.pprint("unchanged: {} so no work to do".
                       format(BASE_ADDRESS_PATH))
            return 0

        try:
            cand_cfg = (client.tree_get_dict(BASE_ADDRESS_PATH,
                                             CONFIG_CANDIDATE, 'internal')
                        ['interfaces'])
            dbg.pprint("BUILD CANDIDATE")
            build_npf_config(commands, cand_cfg, "cand")
        except configd.Exception:
            dbg.pprint("failed getting candidtate tree for {}".
                       format(BASE_ADDRESS_PATH))

    except configd.Exception:
        dbg.pprint("there is no configuration under {}".format(
                   BASE_ADDRESS_PATH))

    try:
        running_cfg = client.tree_get_dict(BASE_ADDRESS_PATH, CONFIG_RUNNING,
                                           'internal')['interfaces']
        dbg.pprint("BUILD RUNNING")
        build_npf_config(commands, running_cfg, "running")
    except configd.Exception:
        dbg.pprint("failed getting running tree for {}".format(
                   BASE_ADDRESS_PATH))

    # send commands to the dataplane using cstore which will change
    # the running configuration into the candidate configuration
    program_npf_config(commands)

    # Return a status on every path rather than falling off the end.
    return 0


def program_npf_interfaces_main():
    """Run program_npf_interfaces() and return the script's exit status.

    Always returns 0.
    """
    # NOTE(review): the value returned by program_npf_interfaces() is
    # discarded, so a failure to open the client session still gives an
    # exit status of 0 - confirm this is intended.
    program_npf_interfaces()
    return 0


if __name__ == "__main__":
    process_options()
    # Use sys.exit() rather than the exit() helper, which is added by the
    # site module for interactive use and is not always available.
    sys.exit(program_npf_interfaces_main())
