#!/usr/bin/python3
#
# Copyright (c) 2020 AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#
# This script validates the firewall rulesets attached to the "originate"
# firewall attachment point of interfaces. If at least one rule in such a
# ruleset references a protocol-group that contains the protocol ipv6-frag,
# an error is generated.
# All loopback and dataplane interfaces and the firewalls they reference are
# validated in a single run, and an error is generated for each offending
# firewall attached to an interface's originate point.

import sys

from vyatta.npf.npf_debug import NpfDebug
from vyatta import configd

# NOTE(review): FORCE is not referenced anywhere in the visible part of this
# file - confirm whether it is still needed.
FORCE = False

# Short aliases for the configd database selectors. CONFIG_CANDIDATE is
# passed to every tree_get_dict()/node_get_status() call in this script.
CONFIG_CANDIDATE = configd.Client.CANDIDATE
CONFIG_RUNNING = configd.Client.RUNNING

# Configuration paths (space-separated node names) that are queried or used
# as a prefix for more specific paths.
RESOURCES_GROUP_PATH = "resources group protocol-group"
SECURITY_FIREWALL_PATH = "security firewall name"
INTERFACES_PATH = "interfaces"

# NpfDebug instance used for all debug output (via dbg.pprint)
dbg = NpfDebug()


def err(msg):
    """Write *msg*, followed by a newline, to standard error."""
    sys.stderr.write(str(msg) + "\n")


def validate_res_group(protocol_group_name, intf_err_msg, fw_name):
    """Check one protocol-group for the ipv6-frag protocol.

    Reads the "protocol" node of the named protocol-group from the
    candidate configuration using the module-level configd client.

    protocol_group_name -- name of the protocol-group to inspect
    intf_err_msg        -- interface prefix printed ahead of the error
    fw_name             -- firewall ruleset that references the group

    Returns 1 when the group contains ipv6-frag or when its candidate
    configuration cannot be read (an error is printed in both cases),
    otherwise 0.
    """
    dbg.pprint("validate_res_group: protocol-group = {}".format(
        protocol_group_name))

    rg_proto_path = " ".join(
        (RESOURCES_GROUP_PATH, protocol_group_name, "protocol"))
    try:
        proto_tree = client.tree_get_dict(
            rg_proto_path, CONFIG_CANDIDATE, 'internal')
        protocols = proto_tree["protocol"]
        dbg.pprint("protocol-group {}".format(protocols))

        if "ipv6-frag" not in protocols:
            return 0

        # Offending group: report the interface first, then the details.
        err(intf_err_msg)
        err("{} ruleset has rules with protocol-group {} that has "
            "protocol equal to ipv6-frag which can't be "
            "configured for originate firewall".format(
                fw_name, protocol_group_name))
        return 1
    except configd.Exception:
        err("failed getting candidate tree for '{}'".format(rg_proto_path))
        return 1


def validate_fw(fw_name, intf_err_msg):
    """Validate every rule of one firewall ruleset in the candidate config.

    Each rule that has a "protocol-group" entry is passed to
    validate_res_group(). All rules are checked even after a failure so
    that every offending rule is reported.

    fw_name      -- name of the firewall ruleset to inspect
    intf_err_msg -- interface prefix handed on for error reporting

    Returns 1 when any referenced protocol-group fails validation or when
    the ruleset cannot be read from the candidate configuration,
    otherwise 0.
    """
    dbg.pprint("validate_fw: fw_name {}".format(fw_name))

    fw_path = " ".join((SECURITY_FIREWALL_PATH, fw_name))
    status = 0
    try:
        fw_tree = client.tree_get_dict(fw_path, CONFIG_CANDIDATE, 'internal')
        fw_cfg = fw_tree[fw_name]
        rules = fw_cfg["rule"] if "rule" in fw_cfg else {}

        for rule in rules.values():
            dbg.pprint("validate_fw: rule {}".format(rule))
            if "protocol-group" not in rule:
                continue
            rc = validate_res_group(rule["protocol-group"],
                                    intf_err_msg, fw_name)
            if rc != 0:
                status = 1
    except configd.Exception:
        err("failed getting candidate tree for '{}'".format(fw_path))
        status = 1
    return status


def validate_fw_interface(intf, intf_err_msg):
    """Validate the originate firewalls attached to one interface.

    intf         -- configuration of a single interface
    intf_err_msg -- interface prefix handed on for error reporting

    Returns 0 when the interface has no firewall or no originate entry, or
    when every originate firewall validates cleanly; returns 1 otherwise.
    """
    if 'firewall' not in intf:
        return 0

    dbg.pprint("validate_fw_interfaces: {}".format(intf))

    if 'originate' not in intf['firewall']:
        return 0

    # Collect every result before combining them so that all rulesets are
    # validated (and reported) even after one of them has failed.
    results = [validate_fw(fw_name, intf_err_msg)
               for fw_name in intf['firewall']['originate']]
    return 1 if any(rc != 0 for rc in results) else 0


def validate_interfaces():
    """Validate originate firewalls on all loopback and dataplane interfaces.

    Reads the "interfaces" tree from the candidate configuration and runs
    validate_fw_interface() on each loopback and dataplane interface. Other
    interface types are skipped.

    Returns 1 when any interface fails validation or when the interfaces
    tree cannot be read, otherwise 0.
    """
    dbg.pprint("validate_interfaces:")
    status = 0
    try:
        tree = client.tree_get_dict(
            INTERFACES_PATH, CONFIG_CANDIDATE, 'internal')
        intf_cfg = tree['interfaces']

        for iftype, interfaces in intf_cfg.items():
            dbg.pprint("validate_interfaces: iftype {}".format(iftype))
            if iftype not in ('loopback', 'dataplane'):
                continue

            for ifname, intf in interfaces.items():
                intf_err_msg = "[interface {} {}]".format(iftype, ifname)
                if validate_fw_interface(intf, intf_err_msg) != 0:
                    status = 1

    except configd.Exception:
        err("failed getting candidate tree for '{}'".format(INTERFACES_PATH))
        status = 1

    return status


def check_if_changed(path):
    """Report whether the candidate configuration under *path* has changed.

    Returns 1 when the node status differs from client.UNCHANGED. Returns 0
    when the node is unchanged, or when querying its status raises
    configd.Exception (treated as "no configuration under this path").
    """
    try:
        if client.node_get_status(CONFIG_CANDIDATE, path) != client.UNCHANGED:
            return 1

        dbg.pprint("unchanged: {} so no work to do".format(path))
        return 0

    except configd.Exception:
        dbg.pprint("there is no configuration under '{}'".format(path))
        return 0


def is_firewall_changed():
    """Return check_if_changed()'s result for the firewall ruleset path."""
    return check_if_changed(path=SECURITY_FIREWALL_PATH)


def is_resource_group_changed():
    """Return check_if_changed()'s result for the protocol-group path."""
    return check_if_changed(path=RESOURCES_GROUP_PATH)


def is_interface_changed():
    """Return check_if_changed()'s result for the interfaces path."""
    return check_if_changed(path=INTERFACES_PATH)


def validate_orig_fw():
    """Entry point: validate originate firewalls if relevant config changed.

    Creates the module-level configd client, then runs the interface
    validation only when the firewall, protocol-group or interface
    configuration is reported as changed.

    Returns 1 when the client session cannot be established, the result of
    validate_interfaces() when something changed, otherwise 0.
    """
    global client
    try:
        client = configd.Client()
    except Exception as exc:
        err("Cannot establish client session: '{}'".format(
            str(exc).strip()))
        return 1

    # Evaluated lazily and in this order; stops at the first changed subtree.
    change_checks = (is_firewall_changed,
                     is_resource_group_changed,
                     is_interface_changed)
    if not any(check() for check in change_checks):
        return 0

    return validate_interfaces()


# Script entry point: run the validation and use its result (0 = success,
# 1 = failure) as the process exit status.
if __name__ == "__main__":
    dbg.pprint("validate_fw_protocol_groups")
    # sys.exit() is used instead of the bare exit() helper: exit() is injected
    # by the site module for interactive use and is not available when the
    # interpreter runs without site (e.g. "python3 -S").
    sys.exit(validate_orig_fw())
