#!/usr/bin/python3
#
# Copyright (c) 2019-2020 AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#
# This should perform any validation required for configuration under
# 'resource groups'. Note that currently validation is only required for
# address groups.

import sys
import getopt

from collections import defaultdict
from vyatta.npf.npf_debug import NpfDebug
from vyatta import configd
from netaddr import AddrFormatError, IPRange, IPNetwork, IPAddress, IPSet

# Set to True by the -f/--force option; when set, validation is performed
# even if the address-group configuration is reported as unchanged.
FORCE = False

# Group header most recently printed by address_group_err(), remembered so
# that consecutive errors for the same group share a single header.
LAST_ERROR_HEADER = None

# Database identifiers passed to the configd client when querying the
# candidate and running configurations.
CONFIG_CANDIDATE = configd.Client.CANDIDATE
CONFIG_RUNNING = configd.Client.RUNNING

# Configuration path holding the address groups validated by this script.
BASE_ADDRESS_PATH = "resources group address-group"

# NpfDebug instance used for printing debugs (enabled by -d/--debug)
dbg = NpfDebug()


# Nested dictionaries that keep track of the entries already compared, so
# that A is not compared against B and then B against A again, which would
# report two similar errors for the same overlap.
def nested_dict():
    """Return a dictionary that creates nested dictionaries on demand.

    Looking up a missing key at any depth yields another dictionary of the
    same kind, so multi-level keys can be assigned without first creating
    the intermediate levels.
    """
    auto_tree = defaultdict(nested_dict)
    return auto_tree


# ADDR_COMPARED[group][addr] is set to True once addr has been processed
ADDR_COMPARED = nested_dict()
# RANGE_COMPARED[group][start][end] is set to True once that range has
# been processed
RANGE_COMPARED = nested_dict()


def err(msg):
    """Write msg, followed by a newline, to standard error."""
    sys.stderr.write(str(msg) + "\n")


def process_options():
    """Parse the command-line options of the script.

    -f/--force sets the module-level FORCE flag and -d/--debug enables
    debug output. An unrecognised option prints the parse error and a
    usage line to stderr, then exits with status 2.
    """
    global FORCE

    short_opts = "fd"
    long_opts = ['force', 'debug']
    try:
        parsed, _ = getopt.getopt(sys.argv[1:], short_opts, long_opts)
    except getopt.GetoptError as exc:
        err(exc)
        err("usage: {} [-f|--force] [-d|--debug]".format(sys.argv[0]))
        sys.exit(2)

    for flag, _ in parsed:
        if flag == '-f' or flag == '--force':
            FORCE = True
        elif flag == '-d' or flag == '--debug':
            dbg.enable()


def address_group_err(group, err_msg):
    """Report an error for an address group on stderr.

    A header naming the group's configuration path is printed first,
    unless it is the same header as the one printed for the previous
    error.

    group   -- name of the address group the error belongs to
    err_msg -- text of the error
    """
    global LAST_ERROR_HEADER

    header = "[{} {}]\n".format(BASE_ADDRESS_PATH, group)
    header_is_new = header != LAST_ERROR_HEADER
    if header_is_new:
        err(header)
        LAST_ERROR_HEADER = header

    # trailing newline leaves a blank line after each message
    err(err_msg + "\n")


def address_group_overlap_error(group, new_start_addr, new_end_addr,
                                start_addr, end_addr):
    """Report that one entry of an address group overlaps with another.

    Each entry is described as "range START-END" when it has an end
    address and by its start address alone otherwise.

    group          -- name of the address group
    new_start_addr -- address, or start of the range, being validated
    new_end_addr   -- end of the range being validated, or None
    start_addr     -- address, or start of the range, it overlaps with
    end_addr       -- end of the range it overlaps with, or None
    """
    def describe(first, last):
        if last:
            return "range {}-{}".format(first, last)
        return "{}".format(first)

    err_msg = "{} overlaps with {}".format(
        describe(new_start_addr, new_end_addr),
        describe(start_addr, end_addr))

    address_group_err(group, err_msg)


def valid_address_start_and_end(start_addr, end_addr):
    """Return True if start_addr and end_addr form a usable address range.

    The range is usable when both values parse as IP addresses, both are
    of the same IP version, and the start is not greater than the end.

    start_addr -- first address of the range
    end_addr   -- last address of the range
    """
    try:
        start_ip = IPAddress(start_addr)
        end_ip = IPAddress(end_addr)
    except (AddrFormatError, ValueError):
        # IPAddress() signals unparsable input by raising, it never
        # returns None, so an unusable address must be caught here.
        return False

    return start_ip.version == end_ip.version and start_ip <= end_ip


def overlaps(new_start_addr, new_end_addr, start_addr, end_addr, group):
    """Check whether two address-group entries share any address.

    An entry with an end address of None is treated as a single address
    or prefix (IPNetwork); otherwise it is the range start..end (IPRange).
    When the entries intersect, an overlap error is reported for the group.

    new_start_addr -- address, or start of the range, being validated
    new_end_addr   -- end of the range being validated, or None
    start_addr     -- address, or start of the range, to compare against
    end_addr       -- end of the range to compare against, or None
    group          -- name of the address group both entries belong to

    Returns 1 if the entries overlap, 0 otherwise.
    """
    dbg.pprint("checking new addr: {} to {} against {} to {}".format(
               new_start_addr, new_end_addr, start_addr, end_addr))

    if new_end_addr is None:
        new_entry = IPNetwork(new_start_addr)
    else:
        new_entry = IPRange(new_start_addr, new_end_addr)
    new_entry_set = IPSet(new_entry)
    dbg.pprint("Set of new addresses:")
    dbg.pprint(new_entry_set)

    if end_addr is None:
        existing_entry = IPNetwork(start_addr)
    else:
        existing_entry = IPRange(start_addr, end_addr)
    existing_set = IPSet(existing_entry)
    dbg.pprint("Set of existing addresses:")
    dbg.pprint(existing_set)

    common_set = new_entry_set & existing_set
    dbg.pprint("Intersection:")
    dbg.pprint(common_set)
    # Use .size rather than len(): len() of an IPSet raises IndexError
    # when the set holds more than sys.maxsize addresses, which a large
    # IPv6 intersection easily does.
    if common_set.size != 0:
        address_group_overlap_error(group, new_start_addr, new_end_addr,
                                    start_addr, end_addr)
        return 1

    return 0


def validate_against_ranges(new_start_addr, new_end_addr, addr_ranges, group):
    """Check an entry against the address ranges of its group.

    Ranges without a 'to' value, ranges already recorded in
    RANGE_COMPARED for this group, and ranges whose start and end are not
    valid are skipped.

    new_start_addr -- address, or start of the range, being validated
    new_end_addr   -- end of the range being validated, or None
    addr_ranges    -- mapping of range start address to its configuration
    group          -- name of the address group

    Returns the number of overlaps found.
    """
    dbg.pprint("ranges: {}".format(addr_ranges))
    overlap_total = 0

    for range_start in addr_ranges:
        try:
            range_end = addr_ranges[range_start]['to']
        except KeyError:
            continue

        # skip the entry itself and any range that was compared earlier
        try:
            seen_before = bool(RANGE_COMPARED[group][range_start][range_end])
        except KeyError:
            seen_before = False
        if seen_before:
            dbg.pprint("Ignoring already compared: {}-{}".format(
                       range_start, range_end))
            continue

        # an invalid range cannot be compared against
        if valid_address_start_and_end(range_start, range_end):
            overlap_total += overlaps(new_start_addr, new_end_addr,
                                      range_start, range_end, group)

    return overlap_total


def validate_against_addresses(new_start_addr, new_end_addr, addresses, group):
    """Check an entry against the individual addresses of its group.

    Addresses already recorded in ADDR_COMPARED for this group are
    skipped.

    new_start_addr -- address, or start of the range, being validated
    new_end_addr   -- end of the range being validated, or None
    addresses      -- addresses configured in the group
    group          -- name of the address group

    Returns the number of overlaps found.
    """
    dbg.pprint("addresses: {}".format(addresses))
    overlap_total = 0

    for candidate in addresses:
        # skip the entry itself and any address that was compared earlier
        try:
            seen_before = bool(ADDR_COMPARED[group][candidate])
        except KeyError:
            seen_before = False
        if seen_before:
            dbg.pprint("Ignoring already compared: {}".format(candidate))
            continue

        overlap_total += overlaps(new_start_addr, new_end_addr, candidate,
                                  None, group)

    return overlap_total


def _validate_new_addresses(group, cand_group, running_group):
    """Check the addresses that are new in the candidate against its ranges.

    Returns the number of errors found.
    """
    error_cnt = 0
    running_addrs = running_group.get('address', ())

    for addr in cand_group.get('address', ()):
        dbg.pprint("  Processing address {}".format(addr))

        # Ignore if already in running config
        if addr in running_addrs:
            dbg.pprint("    Ignoring address {} as not new".format(addr))
            continue

        ADDR_COMPARED[group][addr] = True

        if 'address-range' in cand_group:
            error_cnt += validate_against_ranges(
                addr, None, cand_group['address-range'], group)

    return error_cnt


def _validate_new_ranges(group, cand_group, running_group):
    """Check the ranges that are new or changed in the candidate config.

    Each such range must be valid in itself and must not overlap with the
    other ranges or with the addresses of the group.

    Returns the number of errors found.
    """
    error_cnt = 0
    cand_ranges = cand_group.get('address-range', {})
    running_ranges = running_group.get('address-range', {})

    for start_addr in cand_ranges:
        try:
            end_addr = cand_ranges[start_addr]['to']
        except KeyError:
            continue

        dbg.pprint("  Processing address range {} to {}".format(
                   start_addr, end_addr))

        # Ignore if already in running config. The end addresses must be
        # equal: a substring test ('in') would also treat e.g. a new end
        # of 10.0.0.10 as unchanged from a running end of 10.0.0.100.
        try:
            unchanged = running_ranges[start_addr]['to'] == end_addr
        except KeyError:
            unchanged = False
        if unchanged:
            dbg.pprint("    Ignoring address range {} to {} as "
                       "not new".format(start_addr, end_addr))
            continue

        RANGE_COMPARED[group][start_addr][end_addr] = True

        if not valid_address_start_and_end(start_addr, end_addr):
            address_group_err(group, "Invalid range {} to {}".
                              format(start_addr, end_addr))
            # count it, so that a reported error also fails validation
            error_cnt += 1
            continue

        error_cnt += validate_against_ranges(
            start_addr, end_addr, cand_ranges, group)

        if 'address' in cand_group:
            error_cnt += validate_against_addresses(
                start_addr, end_addr, cand_group['address'], group)

    return error_cnt


def validate_address_groups():
    """Validate the address groups of the candidate configuration.

    Entries that are already present in the running configuration are
    not re-validated. Every other address is checked against the ranges
    of its group, and every other range is checked for validity and
    against the ranges and addresses of its group.

    Returns the number of errors found (0 when the configuration is
    valid, unchanged or absent), or 1 if no configd client session could
    be established.
    """
    process_options()
    dbg.pprint("validate_address_groups()")
    error_cnt = 0

    try:
        client = configd.Client()
    except Exception as exc:
        err("Cannot establish client session: '{}'".format(str(exc).strip()))
        return 1

    try:
        status = client.node_get_status(CONFIG_CANDIDATE, BASE_ADDRESS_PATH)
    except configd.Exception:
        dbg.pprint("there is no configuration under {}".format(
                   BASE_ADDRESS_PATH))
        # address groups do not exist, so configuration is valid
        return 0

    if status == client.UNCHANGED and not FORCE:
        dbg.pprint("unchanged: {}".format(BASE_ADDRESS_PATH))
        return 0

    try:
        cand_cfg = client.tree_get_dict(BASE_ADDRESS_PATH, CONFIG_CANDIDATE,
                                        'internal')['address-group']
    except configd.Exception:
        # no configuration, so passes validation
        dbg.pprint("failed getting candidtate tree for {}".format(
                   BASE_ADDRESS_PATH))
        return 0

    # Start from an empty running config so that the lookups below work
    # when the address groups are being configured for the first time.
    running_cfg = {}
    try:
        running_cfg = client.tree_get_dict(BASE_ADDRESS_PATH, CONFIG_RUNNING,
                                           'internal')['address-group']
    except configd.Exception:
        # this is okay, as may be getting configured for the first time
        dbg.pprint("failed getting running tree for {}".format(
                   BASE_ADDRESS_PATH))

    for group in cand_cfg:
        dbg.pprint("Processing group {}". format(group))
        cand_group = cand_cfg[group]
        running_group = running_cfg.get(group, {})

        # Addresses are handled before ranges: each new address is marked
        # in ADDR_COMPARED, so the range pass does not report the same
        # overlap a second time.
        error_cnt += _validate_new_addresses(group, cand_group,
                                             running_group)
        error_cnt += _validate_new_ranges(group, cand_group, running_group)

    dbg.pprint("all checks performed - error count {}".format(error_cnt))
    return error_cnt


def validate_resource_groups_main():
    """Run all resource group validation and return the error count.

    Only address groups currently require validation.
    """
    return validate_address_groups()


if __name__ == "__main__":
    ret = validate_resource_groups_main()
    # Only the low 8 bits of an exit status reach the caller, so an error
    # count that is a multiple of 256 would be seen as success. Clamp it
    # to keep every failure non-zero.
    sys.exit(min(ret, 255))
