#!/usr/bin/python3
#
# Copyright (c) 2019-2020 AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#
# This is run after there has been a change in resource groups.
# It looks for the changes and programs them in the dataplane.

import sys
import getopt
import socket
import re

from vyatta import configd
from vyatta.npf.npf_debug import NpfDebug
from vyatta.npf.npf_store import store_cfg, dataplane_commit
from vyatta.npf.npf_traps import send_npf_snmp_traps
from vyatta.npf.npf_warning import npf_config_warning

# When True, groups are processed even if configd reports them unchanged.
# Set by process_options() for -f/--force.
FORCE = False

# configd database selectors
CONFIG_CANDIDATE = configd.Client.CANDIDATE
CONFIG_RUNNING = configd.Client.RUNNING

# Resource group keywords, one per supported group type
ADDR_CMD = "address-group"
ICMP_CMD = "icmp-group"
ICMP6_CMD = "icmpv6-group"
PORT_CMD = "port-group"
PROTO_CMD = "protocol-group"

# Configuration paths: the common base followed by the group keyword
RG_BASE = "resources group"
BASE_ADDR_PATH = "{} {}".format(RG_BASE, ADDR_CMD)
BASE_ICMP_PATH = "{} {}".format(RG_BASE, ICMP_CMD)
BASE_ICMP6_PATH = "{} {}".format(RG_BASE, ICMP6_CMD)
BASE_PORT_PATH = "{} {}".format(RG_BASE, PORT_CMD)
BASE_PROTO_PATH = "{} {}".format(RG_BASE, PROTO_CMD)

# Leading part of every address-group (table) command
COMMAND_ADDRESS_PREFIX = "npf-cfg fw table"

# Selectors accepted by program_icmp_groups()
ICMPv4 = 4
ICMPv6 = 6

# Object used for printing debugs
dbg = NpfDebug()


def get_configs(ext):
    """
    Fetch the candidate and running configuration for one group type.

    Reads the module-level client, cand_cfg and running_cfg set up by
    program_resource_groups_main().

    @ext: the resource group keyword, e.g. ADDR_CMD or PORT_CMD

    Returns a (candidate, running) tuple of dictionaries. Both are empty
    when the node is reported as unchanged and FORCE is not set. A side
    that has no entry for @ext is returned as an empty dictionary.
    """

    path = RG_BASE + " " + ext

    try:
        status = client.node_get_status(CONFIG_CANDIDATE, path)
    except configd.Exception:
        # The status query failed, so treat the candidate side as empty.
        dbg.pprint("there is no configuration under {}".
                   format(path))
        new_cfg = {}
    else:
        if status == client.UNCHANGED and not FORCE:
            dbg.pprint("unchanged: {} so no work to do".
                       format(path))
            return ({}, {})

        if ext in cand_cfg:
            new_cfg = cand_cfg[ext]
        else:
            dbg.pprint("failed getting candidate tree for {}".
                       format(path))
            new_cfg = {}

    if ext in running_cfg:
        old_cfg = running_cfg[ext]
    else:
        dbg.pprint("failed getting running tree for {}".
                   format(path))
        old_cfg = {}

    return (new_cfg, old_cfg)


# ===== Address =====

def create_address_group(group):
    """
    Create the named address group

    Hands store_cfg() the group's config path and its table "create"
    command, as a SET.

    @group: the name of the address group to be created
    """

    dbg.pprint("Creating address group {}".format(group))
    key = "{} {}".format(BASE_ADDR_PATH, group)
    command = "{} create {}".format(COMMAND_ADDRESS_PREFIX, group)
    store_cfg(key, command, "SET", dbg)


def delete_address_group(group):
    """
    Delete the named address group

    Hands store_cfg() the group's config path and its table "delete"
    command, as a DELETE.

    @group: the name of the address group to be deleted
    """

    dbg.pprint("Deleting address group {}".format(group))
    key = "{} {}".format(BASE_ADDR_PATH, group)
    command = "{} delete {}".format(COMMAND_ADDRESS_PREFIX, group)
    store_cfg(key, command, "DELETE", dbg)


def add_address(group, address):
    """
    Add a single address entry to an address group

    @group:   the name of the address group
    @address: the address entry to add, passed down unconverted
    """

    dbg.pprint("Adding address {} to group {}".format(address, group))
    key = "{} {} address {}".format(BASE_ADDR_PATH, group, address)
    command = "{} add {} {}".format(COMMAND_ADDRESS_PREFIX, group, address)
    store_cfg(key, command, "SET", dbg)


def delete_address(group, address):
    """
    Remove a single address entry from an address group

    @group:   the name of the address group
    @address: the address entry to remove, passed down unconverted
    """

    dbg.pprint("Deleting address {} from group {}".format(address, group))
    key = "{} {} address {}".format(BASE_ADDR_PATH, group, address)
    command = "{} remove {} {}".format(COMMAND_ADDRESS_PREFIX, group, address)
    store_cfg(key, command, "DELETE", dbg)


def add_address_range(group, start_addr, end_addr):
    """
    Add an address range to an address group

    The config path handed to store_cfg() contains only the start address;
    the command carries both ends of the range.

    @group:      the name of the address group
    @start_addr: first address of the range
    @end_addr:   last address of the range
    """

    dbg.pprint("Adding address-range {}-{} to group {}".format(start_addr,
               end_addr, group))
    key = "{} {} address-range {}".format(BASE_ADDR_PATH, group, start_addr)
    command = "{} add {} {} {}".format(COMMAND_ADDRESS_PREFIX, group,
                                       start_addr, end_addr)
    store_cfg(key, command, "SET", dbg)


def delete_address_range(group, start_addr, end_addr):
    """
    Remove an address range from an address group

    The config path handed to store_cfg() contains only the start address;
    the command carries both ends of the range.

    @group:      the name of the address group
    @start_addr: first address of the range
    @end_addr:   last address of the range
    """

    dbg.pprint("Deleting address-range {}-{} from group {}".format(start_addr,
               end_addr, group))
    key = "{} {} address-range {}".format(BASE_ADDR_PATH, group, start_addr)
    command = "{} remove {} {} {}".format(COMMAND_ADDRESS_PREFIX, group,
                                          start_addr, end_addr)
    store_cfg(key, command, "DELETE", dbg)


def program_address_groups():
    """
    Create and delete address groups.

    Compares the running and candidate address-group configuration returned
    by get_configs() in two passes:
      * walk the running config and remove groups, addresses and
        address-ranges that are no longer in the candidate config
      * walk the candidate config and create groups, addresses and
        address-ranges that are not in the running config

    We process two kinds of entry:
      * address <address>
      * address-range <start> to <end>
    """

    dbg.pprint("program_address_groups()")

    (cand_cfg, running_cfg) = get_configs(ADDR_CMD)

    # Look for deleted configuration
    for group in sorted(running_cfg):
        dbg.pprint("Processing old group {}". format(group))

        if group not in cand_cfg:
            dbg.pprint("No new group {}". format(group))
            # Whole group removed
            delete_address_group(group)
            continue

        try:
            for addr in running_cfg[group]['address']:
                dbg.pprint("  Processing old address {}". format(addr))

                try:
                    if addr in cand_cfg[group]['address']:
                        dbg.pprint("  also exists as new address - no change")
                    else:
                        delete_address(group, addr)
                except KeyError:
                    # no 'address' in candidate config, so all are removed
                    delete_address(group, addr)
        except KeyError:
            # no 'address' running config
            pass

        try:
            for start_addr in running_cfg[group]['address-range']:
                try:
                    end_addr = (running_cfg[group]['address-range']
                                [start_addr]['to'])
                except KeyError:
                    dbg.pprint("  Invalid old addr config for group {} "
                               "start addr {}".format(group, start_addr))
                    continue

                dbg.pprint("  Processing old address-range {}-{}".format(
                           start_addr, end_addr))

                # NOTE(review): a range whose start is in both configs is
                # left alone here even if its end changed; the pass below
                # re-adds it with the new end. Confirm that the add is
                # treated as a replacement of the old range.
                try:
                    if start_addr in cand_cfg[group]['address-range']:
                        dbg.pprint("  Start of range in old and new")
                    else:
                        delete_address_range(group, start_addr, end_addr)
                except KeyError:
                    # no 'address-range' in candidate config
                    delete_address_range(group, start_addr, end_addr)
        except KeyError:
            # no 'address-range' in running config
            pass

    # Look for added configuration
    for group in sorted(cand_cfg):
        dbg.pprint("Processing new group {}". format(group))

        if group not in running_cfg:
            dbg.pprint("New group {}". format(group))
            if group == "masquerade":
                npf_config_warning("address group 'masquerade' will be "
                                   "ignored in NAT rules")
            # New group to create
            create_address_group(group)

        try:
            # look for added addresses
            for addr in cand_cfg[group]['address']:
                dbg.pprint("  Processing new address {}". format(addr))

                try:
                    if addr in running_cfg[group]['address']:
                        dbg.pprint("  also exists as old address - no change")
                    else:
                        add_address(group, addr)
                except KeyError:
                    # new group, or no 'address' in running config
                    add_address(group, addr)
        except KeyError:
            # no 'address' in candidate config
            pass

        try:
            # look for added address ranges
            for start_addr in cand_cfg[group]['address-range']:
                try:
                    end_addr = (cand_cfg[group]['address-range']
                                [start_addr]['to'])
                except KeyError:
                    dbg.pprint("  Invalid new addr config for group {} "
                               "start addr {}".format(group, start_addr))
                    continue

                dbg.pprint("  Processing new address-range {}-{}".format(
                           start_addr, end_addr))

                try:
                    # The end addresses must be compared for equality. A
                    # membership test on a string is a substring match, so
                    # e.g. a new end of "10.0.0.1" would wrongly be treated
                    # as unchanged from an old end of "10.0.0.10".
                    if end_addr == (running_cfg[group]['address-range']
                                    [start_addr]['to']):
                        dbg.pprint("  also exists in old address-range"
                                   " - no change")
                        continue
                except KeyError:
                    # new group, new range, or old range without an end
                    pass
                add_address_range(group, start_addr, end_addr)
        except KeyError:
            # no 'address-range' in candidate config
            pass


# ===== ICMP and ICMPv6 =====

def create_icmp_group(group, proto, path, group_cfg):
    """
    Create the named ICMP or ICMPv6 group

    The items are joined with ';' to form the last argument of the
    "npf-cfg add" command.

    @group:     the name of the group to be created
    @proto:     icmp or icmpv6
    @path:      BASE_ICMP_PATH or BASE_ICMP6_PATH
    @group_cfg: the items (strings) to be added to the group
    """

    dbg.pprint("Creating {} group {}".format(proto, group))
    key = "{} {} {}".format(path, group, proto)
    members = ";".join(group_cfg)
    command = "npf-cfg add {}-group:{} 0 {}".format(proto, group, members)
    store_cfg(key, command, "SET", dbg)


def delete_icmp_group(group, proto, path):
    """
    Delete the named ICMP or ICMPv6 group

    @group: the name of the group to be deleted
    @proto: icmp or icmpv6
    @path:  BASE_ICMP_PATH or BASE_ICMP6_PATH
    """

    dbg.pprint("Deleting {} group {}".format(proto, group))
    key = "{} {}".format(path, group)
    command = "npf-cfg delete {}-group:{}".format(proto, group)
    store_cfg(key, command, "DELETE", dbg)


def _icmp_group_items(entry):
    """
    Build the list of items for one ICMP/ICMPv6 group.

    Names come first, in configured order, followed by the types in sorted
    order. A type with codes gives one "<type>:<code>" item per code;
    a type without codes gives a bare "<type>" item.

    @entry: the candidate configuration of the group
    """

    items = []

    configured_names = entry.get('name')
    if configured_names:
        items.extend(configured_names)

    configured_types = entry.get('type')
    if configured_types:
        for icmp_type in sorted(configured_types):
            codes = configured_types[icmp_type].get("code")
            if not codes:
                items.append("{}".format(icmp_type))
                continue
            for code in sorted(codes):
                items.append("{}:{}".format(icmp_type, code))

    return items


def program_icmp_groups(proto):
    """
    Create and delete ICMP and ICMPv6 groups.

    We process three kinds of configuration:
      * name <name>
      * type <type>
      * type <type> code <code>

    Names are passed down unconverted.

    A group whose configuration changed is deleted and then created again
    from the candidate configuration.

    @proto:  Either ICMPv4 or ICMPv6; anything else is ignored
    """

    variants = {
        ICMPv4: (ICMP_CMD, BASE_ICMP_PATH, "icmp"),
        ICMPv6: (ICMP6_CMD, BASE_ICMP6_PATH, "icmpv6"),
    }
    if proto not in variants:
        return
    (cmd, path, name) = variants[proto]

    dbg.pprint("program_{}_groups()".format(name))

    (new_cfg, old_cfg) = get_configs(cmd)

    if new_cfg == {} and old_cfg == {}:
        dbg.pprint("unchanged so no work to do")
        return

    # Look for deleted or changed groups
    for group in sorted(old_cfg.keys()):
        dbg.pprint("Processing old group {}". format(group))

        if group not in new_cfg:
            # Whole group removed
            dbg.pprint("No new group {}". format(group))
            delete_icmp_group(group, name, path)
            continue

        # Compare running and candidate configs
        dbg.pprint("Compare old and new group {}". format(group))
        if old_cfg[group] == new_cfg[group]:
            continue

        dbg.pprint("changed group {}". format(group))
        items = _icmp_group_items(new_cfg[group])
        delete_icmp_group(group, name, path)
        create_icmp_group(group, name, path, items)

    # Look for added configuration
    for group in sorted(new_cfg.keys()):
        dbg.pprint("Processing new group {}". format(group))

        if group in old_cfg:
            continue

        dbg.pprint("New group {}". format(group))
        create_icmp_group(group, name, path,
                          _icmp_group_items(new_cfg[group]))


# ===== Port =====

def create_port_group(group, group_cfg):
    """
    Create the named port group

    @group:     the name of the group to be created
    @group_cfg: the items to be added to the group, already formatted as
                the final argument of the "npf-cfg add" command
    """

    dbg.pprint("Creating port group {}".format(group))
    key = "{} {} port".format(BASE_PORT_PATH, group)
    command = "npf-cfg add port-group:{} 0 {}".format(group, group_cfg)
    store_cfg(key, command, "SET", dbg)


def delete_port_group(group):
    """
    Delete the named port group

    @group: the name of the port group to be deleted
    """

    dbg.pprint("Deleting port group {}".format(group))
    key = "{} {}".format(BASE_PORT_PATH, group)
    command = "npf-cfg delete port-group:{}".format(group)
    store_cfg(key, command, "DELETE", dbg)


def getPort(p):
    """
    Try to convert the given port.
    It could be a range, or a number, or a name.

    A number is returned as an int, a range is returned unchanged as the
    string "<digits>-<digits>", and a name is converted to its port number
    using socket.getservbyname().

    @p: the port to be converted

    Raises NameError if @p is none of the above.
    """

    try:
        return int(p)
    except ValueError:
        pass

    # Is it a port range, ie "<digits>-<digits>" ?
    # The whole string has to match: an unanchored search would accept
    # anything that merely contains a range, such as "x1-2y".
    if re.fullmatch(r"\d+-\d+", p):
        return p

    try:
        return socket.getservbyname(p)
    except OSError as exc:
        raise NameError("{} is not a recognised port".format(p)) from exc


def getPortKey(p):
    """
    Get the key to use for ordering, after converting any port names
    into a number. Note that ranges are sorted using the first number in
    the range.

    @p: the port to get the key for

    Returns an int. Raises NameError if @p is not a number, a range or a
    name known to socket.getservbyname().
    """

    try:
        return int(p)
    except ValueError:
        pass

    # Is it a port range, ie "<digits>-<digits>" ?
    # The whole string has to match: an unanchored search would produce a
    # key for anything that merely contains a range, such as "x12-34".
    match = re.fullmatch(r"(\d+)-\d+", p)
    if match:
        return int(match.group(1))

    try:
        return socket.getservbyname(p)
    except OSError as exc:
        raise NameError("{} is not a recognised port".format(p)) from exc


def _port_group_members(entry):
    """
    Build the ';' separated member string for one port group.

    Ports are ordered by getPortKey() and converted by getPort(). A group
    with no ports gets the default of 1-65535.

    @entry: the candidate configuration of the group
    """

    ports = entry.get('port')
    if not ports:
        return "1-65535"        # default

    ordered = sorted(ports, key=getPortKey)
    return ";".join("{}".format(getPort(port)) for port in ordered)


def program_port_groups():
    """
    Create and delete port groups.

    We process four kinds of configuration:
      * port <name>
      * port <value>
      * port <value>-<value>
      * <empty>

    Names are converted to values prior to being passed down.
    An empty group defaults to port 1-65535.

    A group whose ports changed is deleted and then created again from the
    candidate configuration.
    """

    dbg.pprint("program_port_groups()")

    (new_cfg, old_cfg) = get_configs(PORT_CMD)

    if new_cfg == {} and old_cfg == {}:
        dbg.pprint("unchanged so no work to do")
        return

    # Look for deleted or changed groups
    for group in sorted(old_cfg):
        dbg.pprint("Processing old group {}". format(group))

        if group not in new_cfg:
            dbg.pprint("No new group {}". format(group))
            # Whole group removed
            delete_port_group(group)
            continue

        # Compare running and candidate configs
        dbg.pprint("Compare old and new group {}". format(group))
        if old_cfg[group].get('port') == new_cfg[group].get('port'):
            continue

        dbg.pprint("changed group {}". format(group))
        # Convert the ports before deleting anything
        members = _port_group_members(new_cfg[group])
        delete_port_group(group)
        create_port_group(group, members)

    # Look for added configuration
    for group in sorted(new_cfg):
        dbg.pprint("Processing new group {}". format(group))

        if group in old_cfg:
            continue

        dbg.pprint("New group {}". format(group))
        create_port_group(group, _port_group_members(new_cfg[group]))


# ===== Protocol =====

def create_protocol_group(group, group_cfg):
    """
    Create the named protocol group

    @group:     the name of the group to be created
    @group_cfg: the items to be added to the group, already formatted as
                the final argument of the "npf-cfg add" command
    """

    dbg.pprint("Creating protocol group {}".format(group))
    key = "{} {} protocol".format(BASE_PROTO_PATH, group)
    command = "npf-cfg add protocol-group:{} 0 {}".format(group, group_cfg)
    store_cfg(key, command, "SET", dbg)


def delete_protocol_group(group):
    """
    Delete the named protocol group

    @group: the name of the protocol group to be deleted
    """

    dbg.pprint("Deleting protocol group {}".format(group))
    key = "{} {}".format(BASE_PROTO_PATH, group)
    command = "npf-cfg delete protocol-group:{}".format(group)
    store_cfg(key, command, "DELETE", dbg)


def getProtocol(p):
    """
    Convert the given protocol to its number.

    A numeric string is converted with int(); anything else is looked up
    as a protocol name using socket.getprotobyname().

    @p: the protocol to be converted

    Raises NameError if @p is neither a number nor a known protocol name.
    """

    try:
        number = int(p)
    except ValueError:
        try:
            number = socket.getprotobyname(p)
        except OSError:
            raise NameError("{} is not a recognised protocol".format(p))
    return number


def _protocol_group_members(entry):
    """
    Build the ';' separated member string for one protocol group.

    The configured protocols are sorted as configured (before conversion)
    and each one is converted by getProtocol().

    @entry: the candidate configuration of the group
    """

    ordered = sorted(entry['protocol'])
    return ";".join("{}".format(getProtocol(proto)) for proto in ordered)


def program_protocol_groups():
    """
    Create and delete protocol groups.

    We process two kinds of configuration:
      * protocol <name>
      * protocol <value>

    Names are converted to values prior to being passed down.

    A group whose protocols changed is deleted and then created again from
    the candidate configuration.
    """

    dbg.pprint("program_protocol_groups()")

    (new_cfg, old_cfg) = get_configs(PROTO_CMD)

    if new_cfg == {} and old_cfg == {}:
        dbg.pprint("unchanged so no work to do")
        return

    # Look for deleted or changed groups
    for group in sorted(old_cfg):
        dbg.pprint("Processing old group {}". format(group))

        if group not in new_cfg:
            dbg.pprint("No new group {}". format(group))
            # Whole group removed
            delete_protocol_group(group)
            continue

        # Compare running and candidate configs
        dbg.pprint("Compare old and new group {}". format(group))

        if old_cfg[group]['protocol'] == new_cfg[group]['protocol']:
            continue

        dbg.pprint("changed group {}". format(group))
        # Convert the protocols before deleting anything
        members = _protocol_group_members(new_cfg[group])
        delete_protocol_group(group)
        create_protocol_group(group, members)

    # Look for added configuration
    for group in sorted(new_cfg):
        dbg.pprint("Processing new group {}". format(group))

        if group in old_cfg:
            continue

        dbg.pprint("New group {}". format(group))
        create_protocol_group(group, _protocol_group_members(new_cfg[group]))


# ===== Main =====

def err(msg):
    """
    Write a message, followed by a newline, to standard error

    @msg: the message; anything that is not a string is converted
    """

    sys.stderr.write("{}\n".format(msg))


def process_options():
    """
    Parse the command line options

      -f, --force   sets FORCE, so that unchanged groups are processed too
      -d, --debug   enables the debug output

    An unknown option prints the error and a usage line to stderr and
    exits with status 2.
    """

    global FORCE

    try:
        parsed, _rest = getopt.getopt(sys.argv[1:], "fd", ['force', 'debug'])
    except getopt.GetoptError as exc:
        err(exc)
        err("usage: {} [-f|--force] [-d|--debug]".format(sys.argv[0]))
        sys.exit(2)

    for flag, _value in parsed:
        if flag in ('-f', '--force'):
            FORCE = True
        elif flag in ('-d', '--debug'):
            dbg.enable()


def program_resource_groups_main():
    """
    Program all resource group changes

    Opens a configd session, reads the candidate and running resource
    group trees into the module-level cand_cfg and running_cfg, runs the
    per-type programming functions, then calls dataplane_commit() and
    send_npf_snmp_traps().

    Returns 1 if the configd session cannot be established, otherwise 0.
    """

    global client, cand_cfg, running_cfg

    try:
        client = configd.Client()
    except Exception as exc:
        err("Cannot establish client session: '{}'".format(str(exc).strip()))
        return 1

    def fetch_groups(database, failure_msg):
        # A tree that cannot be read is treated as empty.
        try:
            tree = client.tree_get_dict(RG_BASE, database, 'internal')
            return tree["group"]
        except configd.Exception:
            dbg.pprint(failure_msg.format(RG_BASE))
            return {}

    # Get the candidate and running config trees.
    cand_cfg = fetch_groups(CONFIG_CANDIDATE,
                            "failed getting candidate tree for {}")
    running_cfg = fetch_groups(CONFIG_RUNNING,
                               "failed getting running tree for {}")

    program_address_groups()
    for icmp_version in (ICMPv4, ICMPv6):
        program_icmp_groups(icmp_version)
    program_port_groups()
    program_protocol_groups()

    dataplane_commit(dbg)

    send_npf_snmp_traps([RG_BASE], dbg)
    return 0


if __name__ == "__main__":
    # Parse the options first so that --debug and --force are in effect
    # before any configuration is read.
    process_options()
    # Exit with sys.exit() rather than the exit() helper: that one is added
    # by the site module for interactive use and does not exist when the
    # interpreter is started without site (python -S).
    sys.exit(program_resource_groups_main())
