#!/usr/bin/env python3
#
# Copyright (c) 2019-2020 AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#

"""Build and send dataplane configuration messages from the IP packet filter YANG"""

# Grammar:
# =======
#
# CLI command:                      Dataplane config:
# -----------                       ----------------
#
# counters                          counter config sent in rule 0: handle=ctr_def(...)
#     count bytes                     bytes=Y
#     count packets                   packets=Y
#     named NNNN                      named=name1,named=name2,...
#     sharing per-group               sharing=per-group
#     sharing per-interface           sharing=per-interface
#     type auto-per-action            type=named,named=accept,named=drop
#     type auto-per-rule              type=numbered
#     type named                      type=named
#
# ip-version ...                    sent in rule 0: family=inet | inet6
#
# description                       (not sent to dataplane)
# platform ...                      (not sent to dataplane)
#
# rule N ...
#     description                   (not sent to dataplane)
#     disable                       (rule is removed from dataplane)
#
#     match ...
#         source ipv4 host          src-addr=1.1.1.1
#         source ipv4 prefix        src-addr=1.1.1.1/nn
#         source ipv6 host          src-addr=1::1
#         source ipv6 prefix        src-addr=1::1/nn
#         source port equals        src-port=NN
#         destination ipv4 host     dst-addr=1.1.1.1
#         destination ipv4 prefix   dst-addr=1.1.1.1/nn
#         destination ipv6 host     dst-addr=1::1
#         destination ipv6 prefix   dst-addr=1::1/nn
#         destination port equals   dst-port=NN
#         fragment any              fragment=y           # NB "y" rather than "any"
#         fragment initial-only     fragment=initial
#         fragment subsequent-only  fragment=subsequent
#         dscp name                 dscp=N               # name converted to value
#         dscp value                dscp=N
#         ttl equals                ttl=NN
#         icmp name                 icmpv4=AAAA
#         icmp type N               icmpv4=TT            # matches any code
#         icmp type N code N        icmpv4=TT:VV
#         icmpv6 name               icmpv6=AAAA
#         icmpv6 type N             icmpv6=TT            # matches any code
#         icmpv6 type N code N      icmpv6=TT:VV
#         icmpv6 class              icmpv6-class=AAAA
#         protocol base name        proto-base=NN        # name converted to value
#         protocol base number      proto-base=NN
#         protocol final name       proto-final=NN       # name converted to value
#         protocol final number     proto-final=NN
#         protocol final unknown    proto-final=256      # NB special value
#
#     then ...
#         action accept             action=accept
#         action drop               action=drop
#         action punt               action=accept, rproc=punt
#         action reject             action=drop, rproc=reject
#         counter local NNNN        rproc=ctr_ref(local=NNNN);ctr(NNNN)
#         counter global NNNN       rproc=ctr_ref(global=NNNN);ctr(NNNN)
#         log                       rproc=log
#
# Example output:
#   npf-cfg add acl:PF1 10 action=drop proto-base=6 src-addr=1.1.1.1 src-port=11 dst-addr=2.2.0.0/22 dst-port=22 dscp=8 fragment=y
#   npf-cfg attach interface:dp0p1s1 acl-in acl:PF1


import os
import sys
import argparse
import vplaned
from vyatta import configd

# Root of the IP packet filter configuration tree. The configd lookups in this
# file append " group NAME" or " interface NAME" to this path.
ROOTPATH = "security ip-packet-filter"


def store(key, config, action):
    """
    Push a single configuration item to the dataplane

    @input:  key:    the configuration key under which the item is stored

    @input:  config: the configuration command string to be sent

    @input:  action: the configuration action, passed through to the
                     controller as its 'action' keyword

    @output: none
    """

    # A controller is opened for this one message and released by the
    # context manager as soon as the store call returns.
    controller_ctx = vplaned.Controller()
    with controller_ctx as ctrl:
        ctrl.store(key, config, action=action)


def get_addr(kind, address):
    """
    Render an address dictionary as dataplane address fields

    @input:  kind: the field prefix, "src" or "dst"

    @input:  address: the address dictionary,
                      eg {'host': 'n.n.n.n'} or {'prefix': 'n.n.n.n/NN'}

    @output: one "KIND-addr=VALUE " field per dictionary entry, concatenated,
             eg "src-addr=n.n.n.n "; an empty string for an empty dictionary
    """

    # Only the values matter: host and prefix are rendered the same way.
    fields = ["{}-addr={} ".format(kind, value) for value in address.values()]
    return "".join(fields)


def get_port_number(kind, ports):
    """
    Render a list of port numbers as one dataplane port field

    @input:  kind: the field prefix, "src" or "dst"

    @input:  ports: an iterable of port numbers

    @output: port config string with the ports comma-separated,
             eg "src-port=80 " or "src-port=80,443 "
    """

    port_csv = ",".join(map(str, ports))
    return "{}-port={} ".format(kind, port_csv)


def get_port(kind, ports):
    """
    Render the ports dictionary as dataplane port fields

    @input:  kind: the field prefix, "src" or "dst"

    @input:  ports: the ports dictionary, keyed by the form of the port match
                    (only 'number' is handled here)

    @output: the concatenated port config string, eg "dst-port=22 ".
             A key with no handler raises KeyError.
    """

    handlers = {
        'number': get_port_number,
    }

    rendered = [handlers[form](kind, value) for form, value in ports.items()]
    return "".join(rendered)


def get_address(kind, addresses):
    """
    Render a source or destination match as dataplane fields

    @input:  kind: the field prefix, "src" or "dst"

    @input:  addresses: the addresses dictionary
                        eg {'ipv4': {'host': 'n.n.n.n'}, 'port': {'number': NN}}

    @output: the concatenated config string, eg "src-addr=n.n.n.n src-port=NN ".
             A key with no handler raises KeyError.
    """

    # IPv4 and IPv6 addresses share one renderer; ports have their own.
    handlers = {
        'ipv4': get_addr,
        'ipv6': get_addr,
        'port': get_port,
    }

    rendered = [handlers[name](kind, value) for name, value in addresses.items()]
    return "".join(rendered)


def destination_fn(addr):
    """
    Render the destination match using the "dst" field prefix

    @input:  addr: the addresses dictionary of the destination match

    @output: address config string, eg "dst-addr=n.n.n.n dst-port=NN "
    """

    dst_config = get_address("dst", addr)
    return dst_config


def source_fn(addr):
    """
    Render the source match using the "src" field prefix

    @input:  addr: the addresses dictionary of the source match

    @output: address config string, eg "src-addr=n.n.n.n src-port=NN "
    """

    src_config = get_address("src", addr)
    return src_config


def dscp_fn(dscp):
    """
    Render the DSCP match as a dataplane field

    @input:  dscp: the dscp dictionary, holding a single entry
                   eg {'name': 'AAA'} or {'value': NN}

    @output: DSCP config string, "dscp=NN ". A name is converted to its
             numeric value; a name missing from the table renders as
             "dscp=None ". Returns None for an empty dictionary.
    """

    # Name-to-value table, from
    #   https://www.iana.org/assignments/dscp-registry/dscp-registry.xhtml#dscp-registry-1
    # plus 'default', which maps to zero.
    name_to_value = {
        # class selectors
        'cs0': 0, 'cs1': 8, 'cs2': 16, 'cs3': 24,
        'cs4': 32, 'cs5': 40, 'cs6': 48, 'cs7': 56,
        # assured forwarding
        'af11': 10, 'af12': 12, 'af13': 14,
        'af21': 18, 'af22': 20, 'af23': 22,
        'af31': 26, 'af32': 28, 'af33': 30,
        'af41': 34, 'af42': 36, 'af43': 38,
        # others
        'ef': 46,
        'va': 44,
        'default': 0,
    }

    # Only the first (and only expected) entry is rendered.
    for form, raw in dscp.items():
        by_form = {
            'name':  name_to_value.get(raw),
            'value': raw,
        }
        return "dscp={} ".format(by_form[form])

    return None


def fragment_fn(frag_str):
    """
    Render the fragment match as a dataplane field

    @input:  frag_str: the configured fragment keyword,
                       one of 'any', 'initial-only', 'subsequent-only'

    @output: fragment config string, eg "fragment=initial ".
             NB 'any' is sent as "y"; an unrecognised keyword renders as
             "fragment=None ".
    """

    dataplane_token = {
        'any':             "y",
        'initial-only':    "initial",
        'subsequent-only': "subsequent"
    }.get(frag_str)

    return "fragment={} ".format(dataplane_token)


def get_icmp_name(kind, icmp_name):
    """
    Render an ICMP match given by name

    @input:  kind: the field name, either 'icmpv4' or 'icmpv6'

    @input:  icmp_name: the ICMP name, passed through unchanged

    @output: ICMP config string, eg "icmpv4=AAAA "
    """

    field = "{}={} ".format(kind, icmp_name)
    return field


def get_icmp_type(kind, icmp):
    """
    Build the ICMP type and code config

    @input:  kind: either 'icmpv4' or 'icmpv6'

    @input:  icmp: a list containing the ICMP dictionary
                   eg [{'type-number': TT, 'code': CC}]
                   Only the first list element is used.

    @output: ICMP config string: "icmpv4=TT:CC " when a code is configured,
             or "icmpv4=TT " (any code) when it is not
    """

    entry = icmp[0]
    icmp_type = entry.get("type-number")
    icmp_code = entry.get("code")

    # Code 0 is a legitimate ICMP code, so test for absence rather than
    # truthiness; otherwise "type T code 0" would silently match any code.
    if icmp_code is None:
        return "{}={} ".format(kind, icmp_type)

    return "{}={}:{} ".format(kind, icmp_type, icmp_code)


def get_icmp_class(kind, icmp):
    """
    Render an ICMP match given by class

    @input:  kind: the field name prefix, either 'icmpv4' or 'icmpv6'

    @input:  icmp: the ICMP class, passed through unchanged
                   eg "info" or "error"

    @output: ICMP class config string, eg "icmpv6-class=error "
    """

    field = "{}-class={} ".format(kind, icmp)
    return field


def icmp_fn(kind, icmp):
    """
    Render the ICMP match dictionary as dataplane fields

    @input:  kind: either 'icmpv4' or 'icmpv6'

    @input:  icmp: the ICMP dictionary
                   eg {'type': [{'code': CC, 'type-number': TT}]}
                   or {'name': 'AAAA'}
                   or {'class': 'AAAA'}

    @output: the concatenated ICMP config string,
             eg "icmpv4=TT:CC " or "icmpv4=AAAA ".
             A key with no handler raises KeyError.
    """

    builders = {
        'name':  get_icmp_name,
        'type':  get_icmp_type,
        'class': get_icmp_class,
    }

    rendered = [builders[form](kind, value) for form, value in icmp.items()]
    return "".join(rendered)


def icmpv4_fn(icmp):
    """
    Render the ICMPv4 match using the "icmpv4" field name

    @input:  icmp: the ICMP dictionary
                   eg {'type': [{'code': CC, 'type-number': TT}]}
                   or {'name': 'AAAA'}

    @output: ICMPv4 config string, eg "icmpv4=TT:CC " or "icmpv4=AAAA "
    """

    v4_config = icmp_fn("icmpv4", icmp)
    return v4_config


def icmpv6_fn(icmp):
    """
    Render the ICMPv6 match using the "icmpv6" field name

    @input:  icmp: the ICMP dictionary
                   eg {'type': [{'code': CC, 'type-number': TT}]}
                   or {'name': 'AAAA'}
                   or {'class': 'AAAA'}

    @output: ICMPv6 config string, eg "icmpv6=TT:CC ", "icmpv6=AAAA ",
             or "icmpv6-class=AAAA "
    """

    v6_config = icmp_fn("icmpv6", icmp)
    return v6_config


def get_proto(protocol):
    """
    Resolve a single protocol specification to its protocol ID

    @input:  protocol - dictionary holding a single entry
                        eg {'name': 'tcp'}, {'number': N}, or {'unknown': None}

    @output: the protocol ID; the special value 256 for 'unknown';
             None for a name missing from the table or an empty dictionary
    """

    # Name-to-number table from `getent protocols`, with "ip 0" removed.
    name_to_number = {
        'hopopt': 0, 'icmp': 1, 'igmp': 2, 'ggp': 3,
        'ipencap': 4, 'st': 5, 'tcp': 6, 'egp': 8,
        'igp': 9, 'pup': 12, 'udp': 17, 'hmp': 20,
        'xns-idp': 22, 'rdp': 27, 'iso-tp4': 29, 'dccp': 33,
        'xtp': 36, 'ddp': 37, 'idpr-cmtp': 38, 'ipv6': 41,
        'ipv6-route': 43, 'ipv6-frag': 44, 'idrp': 45, 'rsvp': 46,
        'gre': 47, 'esp': 50, 'ah': 51, 'skip': 57,
        'ipv6-icmp': 58, 'ipv6-nonxt': 59, 'ipv6-opts': 60, 'rspf': 73,
        'vmtp': 81, 'eigrp': 88, 'ospf': 89, 'ax.25': 93,
        'ipip': 94, 'etherip': 97, 'encap': 98, 'pim': 103,
        'ipcomp': 108, 'vrrp': 112, 'l2tp': 115, 'isis': 124,
        'sctp': 132, 'fc': 133, 'mobility-header': 135, 'udplite': 136,
        'mpls-in-ip': 137, 'manet': 138, 'hip': 139, 'shim6': 140,
        'wesp': 141, 'rohc': 142,
    }

    # Only the first (and only expected) entry is resolved.
    for form, raw in protocol.items():
        by_form = {
            'name':    name_to_number.get(raw),
            'number':  raw,
            'unknown': 256,     # special value
        }
        return by_form[form]

    return None


def protocol_fn(protocols):
    """
    Render the protocol match dictionary as dataplane fields

    @input:  protocols: the protocols dictionary, keyed by 'base' / 'final'
                        eg {'base': {'number': N}, 'final': {'number': N}}

    @output: protocol config string, eg "proto-base=N proto-final=N ";
             an empty string for an empty dictionary
    """

    fields = [
        "proto-{}={} ".format(level, get_proto(spec))
        for level, spec in protocols.items()
    ]
    return "".join(fields)


def ttl_fn(ttl):
    """
    Render the TTL match as a dataplane field

    @input:  ttl: the ttl dictionary, holding a single entry
                  eg {'equals': 255}

    @output: TTL config string, "ttl=NN ". Returns None for an empty
             dictionary; a key other than 'equals' raises KeyError.
    """

    # Only the first (and only expected) entry is rendered.
    for operator, value in ttl.items():
        by_operator = {'equals': value}
        return "ttl={} ".format(by_operator[operator])

    return None


def action_fn(action, rproc_list):
    """
    Render the rule action as a dataplane field

    @input:  action: the action name, eg "accept", "drop", "punt", "reject"

    @input:  rproc_list: list of rprocs collected for the rule; modified
                         in place

    @output: action config string, eg "action=accept ".
             "punt" is sent as "accept" and "reject" as "drop", with the
             original name appended to rproc_list as an rproc.
    """

    if action == "punt":
        dataplane_action, extra_rproc = "accept", "punt"
    elif action == "reject":
        dataplane_action, extra_rproc = "drop", "reject"
    else:
        # Any other action is passed through without an rproc
        dataplane_action, extra_rproc = action, None

    if extra_rproc is not None:
        rproc_list.append(extra_rproc)

    return "action={} ".format(dataplane_action)


def counter_fn(counter, rproc_list):
    """
    Add the counter rprocs for a rule's counter reference

    @input:  counter: the counter dictionary, holding a single entry
                      eg {'local': 'NNNN'} or {'global': 'NNNN'}

    @input:  rproc_list: list of rprocs collected for the rule; modified
                         in place

    @output: an empty string (nothing is added to the rule config directly);
             "ctr_ref(local=NNNN);ctr(NNNN)" or "ctr_ref(global=NNNN);ctr(NNNN)"
             is appended to rproc_list.
             Returns None, appending nothing, for an empty dictionary.
    """

    # Only the first (and only expected) entry is used.
    for scope, name in counter.items():
        rproc = "ctr_ref({}={});ctr({})".format(scope, name, name)
        rproc_list.append(rproc)
        return ""

    return None


def log_fn(_log, rproc_list):
    """
    Add the log rproc for a rule

    @input:  _log: unused. It's a presence param without any value
                   eg log=None

    @input:  rproc_list: list of rprocs collected for the rule; modified
                         in place

    @output: an empty string (nothing is added to the rule config directly);
             "log" is appended to rproc_list
    """

    rproc_list.append("log")
    return ""


def get_rprocs(rproc_list):
    """
    Render the collected rprocs as a single dataplane field

    @input:  rproc_list: the list of rprocs
                         eg ['item1', 'item2', 'item3']

    @output: rproc config string with the items separated by semicolons,
             eg "rproc=item1;item2;item3 ",
             or an empty string "" if there are no rprocs
    """

    if not rproc_list:
        return ""

    joined = ";".join(map(str, rproc_list))
    return "rproc=" + joined + " "


def get_counter_name(action):
    """
    Get the auto-per-action counter name from the action.

    @input:  action: action dict, eg "{'accept': None}".
                     May be None or empty for a rule without an action.

    @output: The required counter name, eg "accept" or "drop",
             or None if the dict holds none of the known actions.
    """

    if not action:
        # A rule without an action has no per-action counter
        return None

    # A tuple (not a set) keeps the lookup order fixed, so the result is
    # deterministic even if more than one action key is present.
    for name in ("accept", "drop", "punt", "reject"):
        if name in action:
            return name

    return None


def process_rule(group_name, rule, ctr_type):
    """
    Build and send the dataplane config for a single rule

    @input:  group_name: the name of the group containing the rule

    @input:  rule: a rule dictionary; the 'number', 'match', 'action' and
                   'disable' keys are read

    @input:  ctr_type: the group's counter type ("auto-per-rule",
                       "auto-per-action", "named") or None

    @output: send the rule configuration to the dataplane.
             Nothing is sent for a disabled rule.
             A match or action key with no handler raises KeyError.
    """

    if "disable" in rule:
        # Don't send disabled rules to the dataplane
        # There's no need to delete them,
        # because the whole group is deleted and rules resent on a change.
        return

    number = rule.get("number")
    match = rule.get("match")
    action = rule.get("action")

    match_table = {
        'destination': destination_fn,
        'dscp':        dscp_fn,
        'fragment':    fragment_fn,
        'icmp':        icmpv4_fn,
        'icmpv6':      icmpv6_fn,
        'protocol':    protocol_fn,
        'source':      source_fn,
        'ttl':         ttl_fn,
    }

    # Each entry is (handler, wants_value):
    #   wants_value False: the handler takes the keyword itself (the action
    #                      name, which is what action_fn inspects)
    #   wants_value True:  the handler takes the value stored under the
    #                      keyword (eg the counter dictionary)
    action_table = {
        'accept':      (action_fn, False),
        'drop':        (action_fn, False),
        'punt':        (action_fn, False),
        'reject':      (action_fn, False),
        'counter':     (counter_fn, True),
        'log':         (log_fn, True),
    }

    config = "npf-cfg add acl:{} {} ".format(group_name, number)

    if match:
        for key in match:
            config += match_table[key](match.get(key))

    rproc_list = []

    if action:
        for key in action:
            handler, wants_value = action_table[key]
            arg = action.get(key) if wants_value else key
            config += handler(arg, rproc_list)

    # Add ctr_ref for auto-per-rule and auto-per-action.
    # These go into the same rproc list as the others so that the rule
    # carries a single "rproc=" field with semicolon-separated entries.
    if ctr_type == "auto-per-rule":
        rproc_list.append("ctr_ref(numbered);ctr(numbered)")
    elif ctr_type == "auto-per-action":
        # A rule without an action has no per-action counter
        ctr_name = get_counter_name(action) if action else None
        if ctr_name:
            rproc_list.append("ctr_ref({});ctr({})".format(ctr_name, ctr_name))

    config += get_rprocs(rproc_list)

    key = "security acl group {} rule {}".format(group_name, number)

    store(key, config, "SET")


#
# Group
#

def process_group_counters_count(count):
    """
    Render the list of things to be counted (packets and/or bytes)

    @input:  count: list of the items to count
                    eg count = ['bytes', 'packets']

    @output: comma-separated "item=Y" config string,
             eg "bytes=Y,packets=Y"; None if there is nothing to count
    """

    if not count:
        return None

    enabled = ["{}=Y".format(item) for item in count]
    return ",".join(enabled)


def process_group_counters(counters):
    """
    Process the group counters configuration

    @input:  counters: dictionary containing the group counter configuration;
                       the 'count', 'sharing' and 'type' keys are all required

    @output: a (config, counter type) tuple.

             If counters are configured, config is the ctr_def handle
             eg: " handle=ctr_def(packets=Y,sharing=per-interface,type=named,named=name1)"
             and counter type is one of: auto-per-action, auto-per-rule, named.

             If any required key is missing, or there is nothing to count,
             return (None, None).

             Raises ValueError if the counter type is not recognised.
    """

    for required in ('count', 'sharing', 'type'):
        if required not in counters:
            return None, None

    # Count packets or bytes?
    count = process_group_counters_count(counters["count"])
    if not count:
        # There's nothing to do here if there's nothing to count
        return None, None

    # ctr_def is a handle rproc because it has no packet processing logic
    config = " handle=ctr_def({},sharing={}".format(count, counters["sharing"])

    ctr_type = counters["type"]

    if "named" in ctr_type:
        type_name = "named"
        config += ",type=named"
        # Add all the named counters, if any, as ",named=NAME" items.
        # Appending per name avoids a dangling comma when there are none.
        for name in ctr_type["named"] or []:
            config += ",named={}".format(name)

    elif "auto-per-action" in ctr_type:
        type_name = "auto-per-action"
        # The node may be present without any children
        auto_per_action = ctr_type["auto-per-action"] or {}
        if "action" in auto_per_action:
            # Only count the configured action
            action = auto_per_action["action"]
            if "accept" in action:
                config += ",type=named,named=accept"
            elif "drop" in action:
                config += ",type=named,named=drop"
        else:
            # Add all the actions
            config += ",type=named,named=accept,named=drop"

    elif "auto-per-rule" in ctr_type:
        type_name = "auto-per-rule"
        # The dataplane must name the counters after the rule numbers.
        config += ",type=numbered"

    else:
        # Raise rather than assert: asserts are removed when running with -O
        raise ValueError("Invalid counter type {}".format(ctr_type))

    config += ")"   # end of ctr_def

    # Return the config and the name of the counter type that was matched
    return config, type_name


def convert_af(af):
    """
    Map an address family name onto the protocol family name

    @input:  af: the address family name (ipv4, ipv6)

    @output: the corresponding protocol family name (inet, inet6),
             or None for any other input
    """

    if af == "ipv4":
        family = "inet"
    elif af == "ipv6":
        family = "inet6"
    else:
        family = None

    return family


def process_group_config(group):
    """
    Process the group configuration (ie, the config outwith the rules)

    @input:  group: the group dictionary; the 'group-name', 'ip-version'
                    and (optional) 'counters' keys are read

    @output: send the group configuration to the dataplane (in rule 0)
             eg "npf-cfg add acl:GroupName 0 family=inet handle=ctr_def(params)"

             The ctr_def handle is only added when the counters
             configuration is complete.

             Returns the group's counter type (auto-per-rule,
             auto-per-action, named), or None if no counters are in use.
    """

    group_name = group["group-name"]
    address_family = convert_af(group["ip-version"])

    config = "npf-cfg add acl:{} 0 family={}".format(group_name, address_family)

    # Add counter configuration, if any
    ctr_type = None
    counters = group.get("counters")
    if counters:
        ctr_cfg, ctr_type = process_group_counters(counters)
        # An incomplete counters config yields (None, None): there is
        # nothing to append, and concatenating None would raise TypeError.
        if ctr_cfg:
            config += ctr_cfg

    # Finally, send config to the dataplane in rule 0
    key = "security acl group {} rule 0".format(group_name)
    store(key, config, "SET")

    return ctr_type


def delete_group(group_name):
    """
    Remove a group from the dataplane

    @input:  group_name: name of the group to be deleted

    @output: send the group delete to the dataplane
    """

    store(
        "security acl group {}".format(group_name),
        "npf-cfg delete acl:{}".format(group_name),
        "DELETE",
    )


def process_group(client, group_name):
    """
    Process all the rules in a group

    The group is read from the candidate configuration, deleted from the
    dataplane, and then resent: the group config (rule 0) first, followed
    by each of its rules.

    @input:  client: configd handle

    @input:  group_name: name of the IPPF group to be processed

    @output: none. A message is printed, and nothing is sent, if the group
             cannot be read.
    """

    try:
        group = client.tree_get_full_dict(ROOTPATH + " group {}"
                                          .format(group_name), client.CANDIDATE)
    except Exception:
        print("Group {} doesn't exist\n".format(group_name))
        return

    # Ensure we got the right group
    assert group.get("group-name") == group_name, "Group name doesn't match"

    # Delete the existing group before sending the new rules.
    # This is best-effort: a failure here must not prevent
    # the new configuration from being sent.
    try:
        delete_group(group_name)
    except Exception:
        pass

    ctr_type = process_group_config(group)

    # A group may have no rules at all, in which case "rule" is absent
    for rule in group.get("rule") or []:
        process_rule(group_name, rule, ctr_type)


#
# Interface
#

def delete_interface(client, interface_name, err_msg):
    """
    Detach every IPPF group from the given interface

    The attachments to remove are read from the running configuration.

    @input:  client: configd handle

    @input:  interface_name: interface from which IPPF is to be deleted

    @input:  err_msg: whether to print a message when the interface
                      configuration cannot be read

    @output: send one detach per attached group to the dataplane,
             inbound groups first, then outbound
    """

    path = ROOTPATH + " interface {}".format(interface_name)
    try:
        interface = client.tree_get_full_dict(path, client.RUNNING)
    except Exception:
        # Interface doesn't exist. Only emit a message if required.
        if err_msg:
            print("Interface {} doesn't exist\n".format(interface_name))
        return

    # (config key, store key template, detach command template) per direction
    directions = (
        ("in",
         "security acl interface {} acl-in {}",
         "npf-cfg detach interface:{} acl-in acl:{}"),
        ("out",
         "security acl interface {} acl-out {}",
         "npf-cfg detach interface:{} acl-out acl:{}"),
    )

    for direction, key_fmt, cmd_fmt in directions:
        attached = interface.get(direction)
        if not attached:
            continue
        for group in attached:
            store(key_fmt.format(interface_name, group),
                  cmd_fmt.format(interface_name, group),
                  "DELETE")


def process_interface(client, interface_name):
    """
    Attach the configured IPPF groups to the given interface

    The attachments are read from the candidate configuration. Whatever is
    currently attached is detached first, then the new attachments are sent.

    @input:  client: configd handle

    @input:  interface_name: interface on which IPPF is to be configured

    @output: send one attach per configured group to the dataplane,
             inbound groups first, then outbound. A message is printed,
             and nothing is sent, if the interface cannot be read.
    """

    path = ROOTPATH + " interface {}".format(interface_name)
    try:
        interface = client.tree_get_full_dict(path, client.CANDIDATE)
    except Exception:
        print("Interface {} doesn't exist\n".format(interface_name))
        return

    # Delete the existing interface before sending the new rules.
    # No error message required, and any failure is ignored.
    try:
        delete_interface(client, interface_name, False)
    except Exception:
        pass

    # (config key, store key template, attach command template) per direction
    directions = (
        ("in",
         "security acl interface {} acl-in {}",
         "npf-cfg attach interface:{} acl-in acl:{}"),
        ("out",
         "security acl interface {} acl-out {}",
         "npf-cfg attach interface:{} acl-out acl:{}"),
    )

    for direction, key_fmt, cmd_fmt in directions:
        configured = interface.get(direction)
        if not configured:
            continue
        for group in configured:
            store(key_fmt.format(interface_name, group),
                  cmd_fmt.format(interface_name, group),
                  "SET")


#
# Commit
#

def commit():
    """
    Tell the dataplane to commit the config

    @input:  none

    @output: send the commit message to the dataplane
    """

    # The commit command doubles as its own store key
    commit_cmd = "npf-cfg commit"
    store(commit_cmd, commit_cmd, "SET")


#
# Main
#

def main():
    """
    Parse the arguments and establish a configd session
    before handing-off to the necessary helpers

    The COMMIT_ACTION environment variable selects between deleting
    ("DELETE") and (re)sending the named group and/or interface.

    @input:  none (reads the command line and the environment)

    @output: 0 on success; 1 if the configd session cannot be established
             or COMMIT_ACTION is not set
    """

    parser = argparse.ArgumentParser(
        prog='end-ippf-ruleset',
        description="Process IP Packet Filter rules")
    parser.add_argument("--group", dest="group",
                        help="group",
                        nargs='?', default="")
    parser.add_argument("--interface", dest="interface",
                        help="interface",
                        nargs='?', default="")
    parser.add_argument("--commit", dest="commit",
                        help="commit",
                        action='store_true')

    args = parser.parse_args()

    try:
        client = configd.Client()
    except Exception as exc:
        print("Cannot establish client session: '{}'".format(str(exc).strip()))
        return 1

    # Only a missing variable is expected here; a bare except would also
    # swallow KeyboardInterrupt and SystemExit.
    try:
        commit_action = os.environ["COMMIT_ACTION"]
    except KeyError:
        print("Unspecified commit action\n")
        return 1

    deleting = commit_action == "DELETE"

    if args.group:
        if deleting:
            delete_group(args.group)
        else:
            process_group(client, args.group)

    if args.interface:
        if deleting:
            # Report a missing interface when explicitly asked to delete it
            delete_interface(client, args.interface, True)
        else:
            process_interface(client, args.interface)

    if args.commit:
        commit()

    return 0


# When run as a script, use main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
