#!/usr/bin/env python3
#
# Copyright (c) 2019-2020, AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#
# Scripts for NAT64 configuration that begins either "service nat nat64", or
# the NAT64 configuration that begins "service nat nat46".
#

import sys
import vplaned
import socket
import getopt
from vyatta import configd


#
# nat64_store
#
def nat64_store(path, cmd, set_or_delete):
    """Store one command with the vplaned controller.

    'path' is the key the command is stored against, 'cmd' is the command
    string and 'set_or_delete' is passed to the controller as the store
    action.  Any action other than "SET" or "DELETE" is silently ignored.
    """
    # Flip to True to echo every stored command to stdout
    trace = False

    if set_or_delete not in ("SET", "DELETE"):
        return

    if trace:
        print("%6s: %s" % (set_or_delete, path))
        print("%6s  %s" % (" ", cmd))

    with vplaned.Controller() as ctrl:
        ctrl.store(path, cmd, action=set_or_delete)


#
# nat64_protocol_number
#
# Protocol name or number to number.  (This is similar to the C function
# getprotoent used by perl in FWHelper.pm)
#
def nat64_protocol_number(name):
    """Convert a protocol name or number to a protocol number.

    'name' may be a protocol name (e.g. "tcp"), a numeric string or an int.

    Returns the protocol number as an int, or None if 'name' is "ip" or
    "ipv6", or if the name cannot be resolved.
    """
    # The config value may arrive as an int rather than a string; normalise
    # it so that the string methods below are always available.
    name = str(name)

    # Ignore 'ip' and 'ipv6' (yang will disallow these anyway)
    if name in ("ip", "ipv6"):
        return None

    # Already a number?
    if name.isdigit():
        return int(name)

    # Some likely protocol values, resolved without a system lookup
    common = {'icmp': 1, 'tcp': 6, 'udp': 17, 'ipv6-icmp': 58}
    if name in common:
        return common[name]

    # Uncommon protocol.  Only a failed lookup is treated as "unknown";
    # anything else is a genuine error and is allowed to propagate.
    try:
        return int(socket.getprotobyname(name))
    except OSError:
        return None


#
# nat64_group_status
#
# Determine which nat64 or nat46 groups have been changed, added or deleted
# Returns three lists, where each list element is a sub-dictionary
#
# 'node' is one of the two nodes under "service nat", i.e. "nat64" or "nat46"
#
def nat64_group_status(client, node):
    """Classify the groups under "service nat <node> group".

    Groups in the running config whose candidate status is DELETED are
    reported as deleted; groups in the candidate config are reported as
    added or changed according to their candidate status.

    Returns a (changed, added, deleted) tuple of lists of group
    dictionaries.
    """
    base = "service nat %s group" % (node)

    def fetch(db):
        # Config tree for this database, or None if the node is absent
        if client.node_exists(db, base):
            return client.tree_get_full_dict(base, db, "json")
        return None

    def status_of(entry):
        # Status is always read from the candidate database
        entry_path = "%s %s" % (base, entry.get('ruleset-name'))
        return client.node_get_status(client.CANDIDATE, entry_path)

    running = fetch(client.RUNNING)
    candidate = fetch(client.CANDIDATE)

    # Running config tells us which groups have gone away
    removed = []
    if running:
        removed = [grp for grp in running.get('group')
                   if status_of(grp) == client.DELETED]

    # Candidate config tells us which groups are new or modified
    created = []
    modified = []
    if candidate:
        for grp in candidate.get('group'):
            state = status_of(grp)
            if state == client.ADDED:
                created.append(grp)
            elif state == client.CHANGED:
                modified.append(grp)

    return (modified, created, removed)


#
# nat64_group_rule_status
#
# Determine which rules in a group have been changed, added or deleted
# Returns three lists, where each list element is a sub-dictionary
#
def nat64_group_rule_status(client, node, group):
    """Classify the rules of one group as changed, added or deleted.

    'group' is a group dictionary; if it has no 'ruleset-name' three empty
    lists are returned.  Rules in the running config whose candidate status
    is DELETED are reported as deleted; rules in the candidate config are
    reported as added or changed according to their candidate status.

    Returns a (changed, added, deleted) tuple of lists of rule dictionaries.
    """
    modified = []
    created = []
    removed = []

    if 'ruleset-name' not in group:
        return (modified, created, removed)

    base = "service nat %s group %s rule" % (node, group.get('ruleset-name'))

    def fetch(db):
        # Config tree for this database, or None if the node is absent
        if client.node_exists(db, base):
            return client.tree_get_full_dict(base, db, "json")
        return None

    def status_of(entry):
        # Status is always read from the candidate database
        entry_path = "%s %s" % (base, entry.get('rule-number'))
        return client.node_get_status(client.CANDIDATE, entry_path)

    running = fetch(client.RUNNING)
    candidate = fetch(client.CANDIDATE)

    if running:
        removed.extend(rl for rl in running.get('rule')
                       if status_of(rl) == client.DELETED)

    if candidate:
        for rl in candidate.get('rule'):
            state = status_of(rl)
            if state == client.ADDED:
                created.append(rl)
            elif state == client.CHANGED:
                modified.append(rl)

    return (modified, created, removed)


#
# Port number or service name to number string
#
# The 'port' config may be a port number (e.g. 22) or a named service
# (e.g. ssh, http etc).  We only send down the number to the dataplane.
#
def nat64_port_str(number_or_string):
    """Convert a port number or service name to a port number string.

    'number_or_string' may be an int, a numeric string or a service name
    such as "ssh".

    Returns the port number as a string, or None if the service name cannot
    be resolved.
    """
    # Convert to string so that we can use the isdigit method
    port = str(number_or_string)

    # Is port already a number?
    if port.isdigit():
        return port

    # Is it a named service?  socket.getservbyname returns an int, so
    # convert it to keep the return type consistent with the numeric case.
    try:
        return str(socket.getservbyname(port))
    except OSError:
        # Silently ignore if named service not found.  This should
        # never happen as the yang and socket.getservbyname should
        # both use /etc/services
        return None


#
# nat64_group_rule_match_sd_str
#
# source or dest match string
#
# sd is "dst" or "src".  'sd_dict' is the src or dest dictionary
#
def nat64_group_rule_match_sd_str(sd_dict, sd):
    """Build the source or destination part of a rule match string.

    'sd_dict' is the source or destination match dictionary and 'sd' is the
    option prefix, "src" or "dst".

    Returns the space-separated match options with no leading or trailing
    whitespace (an empty string if nothing applies), or None if an address
    dictionary is present but holds none of 'address-group', 'host' or
    'prefix'.
    """
    parts = []

    # Address options are the same for IPv4 and IPv6
    key = None
    if 'ip6-address' in sd_dict:
        key = 'ip6-address'
    elif 'ip-address' in sd_dict:
        key = 'ip-address'

    if key:
        addr = sd_dict[key]
        if 'address-group' in addr:
            parts.append("%s-addr-group=%s" % (sd, addr['address-group']))
        elif 'host' in addr:
            parts.append("%s-addr=%s" % (sd, addr['host']))
        elif 'prefix' in addr:
            parts.append("%s-addr=%s" % (sd, addr['prefix']))
        else:
            # Should never happen
            return None

    # Port number or service name
    if 'port' in sd_dict:
        port = nat64_port_str(sd_dict.get('port'))
        if port:
            parts.append("%s-port=%s" % (sd, port))

    # Port range; only emitted when it is exactly "<digits>-<digits>"
    if 'port-range' in sd_dict:
        port_range = sd_dict.get('port-range')
        if '-' in port_range:
            bounds = port_range.split('-')
            if len(bounds) == 2 and bounds[0].isdigit() and \
               bounds[1].isdigit():
                parts.append("%s-port=%s-%s" % (sd, bounds[0], bounds[1]))

    # Port group
    if 'port-group' in sd_dict:
        parts.append("%s-port-group=%s" % (sd, sd_dict.get('port-group')))

    # Joining avoids the trailing space that a discarded str.rstrip() call
    # used to leave behind (strings are immutable).
    return " ".join(parts)


#
# nat64_group_rule_match_str
#
# Examine the nat64 match config and put together a corresponding nat64 rule
# match string
#
# e.g. "dst-addr=2001:101:2::/96"
#
def nat64_group_rule_match_str(client, rule):
    """Build the match string for one nat64 or nat46 rule.

    'client' is unused; it is kept so that existing callers keep working.
    'rule' is the rule dictionary.

    Returns the space-separated match options with no leading or trailing
    whitespace, e.g. "dst-addr=2001:101:2::/96", or None if the rule has no
    'match' entry.  An empty string is returned if 'match' yields no
    options.
    """
    if 'match' not in rule:
        return None

    criteria = rule.get('match')
    parts = []

    # protocol.  The cfgd value may be a number or a name; convert to a
    # number.  Note that a result of 0 is treated the same as "not set".
    if 'protocol' in criteria:
        number = nat64_protocol_number(criteria['protocol'])
        if number:
            parts.append("proto-final=%d" % (number))

    # protocol-group
    if 'protocol-group' in criteria:
        parts.append("protocol-group=%s" % (criteria['protocol-group']))

    # source, then destination
    for key, sd in (('source', "src"), ('destination', "dst")):
        if key not in criteria:
            continue
        sd_str = nat64_group_rule_match_sd_str(criteria[key], sd)
        # Strip so that stray whitespace in the sub-string cannot leak
        # into the joined result
        if sd_str and sd_str.strip():
            parts.append(sd_str.strip())

    # Joining avoids the trailing space that a discarded str.rstrip() call
    # used to leave behind (strings are immutable).
    return " ".join(parts)


#
# nat64_trans_map_str
#
# This handles all options for both source and dest of nat64 and nat46.  We
# trust the yang to have ensured the config is valid.
#
# 'sd' is either "s" or "d" to denote source or dest
#
def nat64_trans_map_str(map, sd):
    """Build the translation option string for one mapping dictionary.

    'map' is the 'mapping' dictionary of a translation source or
    destination and 'sd' is the option prefix, "s" or "d".  Exactly one of
    the 'rfc6052', 'overload' or 'host-to-host' entries is examined, in
    that order of precedence.

    Returns the comma-separated option string, or None if no option could
    be derived.
    """
    #
    # .. mapping rfc6052 prefix-length <pl>
    # .. mapping rfc6052 ip6-address prefix <pfx>
    #
    # The rfc6052 options differ between IPv4 and IPv6: one takes an IPv6
    # prefix, the other just a prefix length.
    #
    if 'rfc6052' in map:
        rfc = map['rfc6052']
        if 'prefix-length' in rfc:
            return "%stype=rfc6052,%spl=%s" % (sd, sd, rfc['prefix-length'])
        if 'ip6-address' in rfc:
            pfx = rfc['ip6-address'].get('prefix')
            if pfx:
                return "%stype=rfc6052,%saddr=%s" % (sd, sd, pfx)
        return None

    #
    # .. mapping overload ip-address-pool address-group <name>
    # .. mapping overload ip-address-pool address-range <x.x.x.x> to <x.x.x.x>
    # .. mapping overload ip-address-pool prefix <x.x.x.x/x>
    #
    if 'overload' in map:
        if 'ip-address-pool' not in map['overload']:
            return None
        pool = map['overload']['ip-address-pool']
        if 'prefix' in pool:
            return "%stype=overload,%saddr=%s" % (sd, sd, pool['prefix'])
        if 'address-group' in pool:
            return "%stype=overload,%sgroup=%s" % \
                (sd, sd, pool['address-group'])
        result = None
        if 'address-range' in pool:
            # address-range is a list with one member; were there several,
            # the last one would win
            for span in pool['address-range']:
                result = "%stype=overload,%srange=%s-%s" % \
                    (sd, sd, span.get('start'), span.get('to'))
        return result

    #
    # .. mapping host-to-host ip6-address host <h:h:h:h:h:h:h:h>
    # .. mapping host-to-host ip-address host <x.x.x.x>
    # .. mapping host-to-host port (<1..65535>|<service name>)
    #
    if 'host-to-host' not in map:
        return None

    h2h = map['host-to-host']
    result = None
    if 'ip-address' in h2h:
        v4 = h2h['ip-address']
        if 'host' in v4:
            result = "%stype=one2one,%saddr=%s/32" % (sd, sd, v4.get('host'))
    elif 'ip6-address' in h2h:
        v6 = h2h['ip6-address']
        if 'host' in v6:
            result = "%stype=one2one,%saddr=%s/128" % (sd, sd, v6.get('host'))

    # The port is only appended when an address was found; convert the
    # port number or service name to a number string
    if result and 'port' in h2h:
        port = nat64_port_str(h2h['port'])
        if port:
            result += ",%sport=%s" % (sd, port)

    return result


#
# nat64_group_rule_trans_str
#
# Examine the nat64 translation config and put together a corresponding nat64
# rproc rule
#
#
# 'translation': {'source': {'mapping': {'rfc6052': {'prefix-length': 96}}},
#             'destination': {'mapping': {'rfc6052': {'prefix-length': 96}}}}}
#
# e.g.  handle=nat64(stype=rfc6052,saddr=2001:101:1::/96,
#                   dtype=rfc6052,daddr=2001:101:2::/96)
#
# A nat46 rule will be "handle=nat46(...)"
#
def nat64_group_rule_trans_str(node, rule):
    """Build the "handle=<node>(...)" translation string for one rule.

    'node' is used as the handle name and 'rule' is the rule dictionary.
    Both a source and a destination mapping string are required.

    Returns the handle string, with ",log=1" appended inside the
    parentheses when the rule has a 'log' entry containing 'sessions', or
    None if the rule has no usable translation.
    """
    if 'translation' not in rule:
        return None

    trans = rule['translation']

    def mapping_str(direction, sd):
        # Mapping string for one direction, or None if not configured
        if direction in trans and 'mapping' in trans[direction]:
            return nat64_trans_map_str(trans[direction]['mapping'], sd)
        return None

    dst = mapping_str('destination', "d")
    src = mapping_str('source', "s")

    # Should never happen
    if not (src and dst):
        return None

    # Optional log parameter
    log_opt = ""
    if 'log' in rule and 'sessions' in rule['log']:
        log_opt = ",log=1"

    return "handle=%s(%s,%s" % (node, src, dst) + log_opt + ")"


#
# nat64_group_add
#
# Add a nat64 or nat46 ruleset group to the dataplane
#
def nat64_group_add(client, node, group):
    """Store an "npf-cfg add" command for every rule of a new group.

    'group' is the group dictionary; nothing is done unless it has both
    'rule' and 'ruleset-name'.  A rule is skipped when either its match
    string or its translation string is missing or empty.
    """
    if not ('rule' in group and 'ruleset-name' in group):
        return

    grp = group.get('ruleset-name')
    rule_path_pfx = "service nat %s group %s" % (node, grp)
    add_pfx = "npf-cfg add %s:%s" % (node, grp)

    for rule in group.get('rule'):
        num = rule.get('rule-number')

        match_str = nat64_group_rule_match_str(client, rule)
        trans_str = nat64_group_rule_trans_str(node, rule)
        if not (match_str and trans_str):
            continue

        nat64_store("%s rule %s" % (rule_path_pfx, num),
                    "%s %s %s %s" % (add_pfx, num, match_str, trans_str),
                    "SET")


#
# nat64_group_change
#
# A nat64 or nat46 ruleset group has changed.  Examine the rules in the group
# and remove deleted rules from the dataplane, or add new or changed rules.
#
def nat64_group_change(client, node, group):
    """Store the per-rule commands for a group that has changed.

    Deleted rules get an "npf-cfg delete" command; added and changed rules
    get an "npf-cfg add" command.  Deletes are stored before adds.  Nothing
    is done if 'group' has no 'ruleset-name'.
    """
    if 'ruleset-name' not in group:
        return

    grp = group.get('ruleset-name')
    rule_path_pfx = "service nat %s group %s" % (node, grp)
    del_pfx = "npf-cfg delete %s:%s" % (node, grp)
    add_pfx = "npf-cfg add %s:%s" % (node, grp)

    modified, created, removed = nat64_group_rule_status(client, node, group)

    # Rules that have gone away
    for old in removed:
        num = old.get('rule-number')
        nat64_store("%s rule %s" % (rule_path_pfx, num),
                    "%s %s" % (del_pfx, num),
                    "DELETE")

    # New rules first, then modified ones.  A rule is skipped when either
    # its match string or its translation string is missing or empty.
    for rule in created + modified:
        num = rule.get('rule-number')

        match_str = nat64_group_rule_match_str(client, rule)
        trans_str = nat64_group_rule_trans_str(node, rule)
        if not (match_str and trans_str):
            continue

        nat64_store("%s rule %s" % (rule_path_pfx, num),
                    "%s %s %s %s" % (add_pfx, num, match_str, trans_str),
                    "SET")


#
# nat64_group_delete
#
# Remove nat64 or nat46 ruleset group from the dataplane
#
def nat64_group_delete(client, node, group):
    """Store a single "npf-cfg delete" command for a whole group.

    Nothing is done if 'group' has no 'ruleset-name'.  'client' is unused.
    """
    if 'ruleset-name' in group:
        grp = group.get('ruleset-name')
        nat64_store("service nat %s group %s" % (node, grp),
                    "npf-cfg delete %s:%s" % (node, grp),
                    "DELETE")


#
# nat64_intf_status
#
# Determine which nat64 or nat46 interface configs have been changed, added
# or deleted
#
# node is one of the two nodes under "service nat", i.e. "nat64" or "nat46"
#
# Returns three lists - changed, added, deleted
#
# Each list entry is of the form:
#   {'in': [{'name': 'NAT64_GRP1'}, {'name': 'NAT64_GRP2'}], 'name': 'dp0p1s1'}
#
# An interface may have multiple groups:
#
# service nat nat64 interface dp0p1s1 in NAT64_GRP1
# service nat nat64 interface dp0p1s1 in NAT64_GRP2
#
def nat64_intf_status(client, node):
    """Classify the interfaces under "service nat <node> interface".

    Interfaces in the running config whose candidate status is DELETED are
    reported as deleted; interfaces in the candidate config are reported as
    added or changed according to their candidate status.  Entries without
    a 'name' are skipped.

    Returns a (changed, added, deleted) tuple of lists of interface
    dictionaries.
    """
    base = "service nat %s interface" % (node)

    def fetch(db):
        # Config tree for this database, or None if the node is absent
        if client.node_exists(db, base):
            return client.tree_get_full_dict(base, db, "json")
        return None

    def with_status(cfg):
        # Yield (interface, candidate status) for every named interface
        for entry in cfg.get('interface'):
            ifname = entry.get('name')
            if ifname is None:
                continue
            yield entry, client.node_get_status(client.CANDIDATE,
                                                "%s %s" % (base, ifname))

    running = fetch(client.RUNNING)
    candidate = fetch(client.CANDIDATE)

    removed = []
    if running:
        removed = [entry for entry, state in with_status(running)
                   if state == client.DELETED]

    created = []
    modified = []
    if candidate:
        for entry, state in with_status(candidate):
            if state == client.ADDED:
                created.append(entry)
            elif state == client.CHANGED:
                modified.append(entry)

    return (modified, created, removed)


#
# nat64_intf_delete
#
def nat64_intf_delete(client, node, intf):
    """Store an "npf-cfg detach" command for each group on an interface.

    'intf' is the interface dictionary, e.g.
      {'in': [{'name': 'NAT64_GRP1'}], 'name': 'dp0p1s1'}

    Nothing is done if the interface has no 'in' list or no name.  Group
    entries without a name are skipped.  'client' is unused.
    """
    if 'in' not in intf:
        return

    # Same guard as nat64_intf_add: without a name we would otherwise
    # issue a detach for "interface:None"
    iname = intf.get('name')
    if iname is None:
        return

    path_pfx = "service nat %s interface %s in" % (node, iname)
    cmd_pfx = "npf-cfg detach interface:%s %s" % (iname, node)

    for group in intf.get('in'):
        name = group.get('name')
        if name:
            path = "%s %s" % (path_pfx, name)
            cmd = "%s %s:%s" % (cmd_pfx, node, name)
            nat64_store(path, cmd, "DELETE")


#
# nat64_intf_add
# npf-ut attach interface:dpT21 fw-out fw:FW2_OUT
#
def nat64_intf_add(client, node, intf):
    """Store an "npf-cfg attach" command for each group on an interface.

    Nothing is done if the interface has no 'in' list or no name.  Group
    entries without a name are skipped.  'client' is unused.
    """
    if 'in' not in intf:
        return

    ifname = intf.get('name')
    if ifname is None:
        return

    attach_path = "service nat %s interface %s in" % (node, ifname)
    attach_cmd = "npf-cfg attach interface:%s %s" % (ifname, node)

    # Only entries that carry a non-empty group name are attached
    grp_names = (entry.get('name') for entry in intf.get('in'))
    for grp in grp_names:
        if not grp:
            continue
        nat64_store("%s %s" % (attach_path, grp),
                    "%s %s:%s" % (attach_cmd, node, grp),
                    "SET")


#
# nat64_intf_group_status
#
# Determine which groups assigned to an interface have been changed, added or
# deleted. Returns three lists, where each list element is a group name.
#
def nat64_intf_group_status(client, node, intf):
    """Classify the groups assigned to one interface.

    Groups in the running config whose candidate status is DELETED are
    reported as deleted; groups in the candidate config are reported as
    added or changed according to their candidate status.

    Returns a (changed, added, deleted) tuple of lists of group names.
    """
    base = "service nat %s interface %s" % (node, intf.get('name'))

    def fetch(db):
        # Config tree for this database, or None if the node is absent
        if client.node_exists(db, base):
            return client.tree_get_full_dict(base, db, "json")
        return None

    def with_status(cfg):
        # Yield (group name, candidate status) for every assigned group
        for entry in cfg.get('in'):
            grp = entry.get('name')
            yield grp, client.node_get_status(client.CANDIDATE,
                                              "%s in %s" % (base, grp))

    running = fetch(client.RUNNING)
    candidate = fetch(client.CANDIDATE)

    removed = []
    if running:
        removed = [grp for grp, state in with_status(running)
                   if state == client.DELETED]

    created = []
    modified = []
    if candidate:
        for grp, state in with_status(candidate):
            if state == client.ADDED:
                created.append(grp)
            elif state == client.CHANGED:
                modified.append(grp)

    return (modified, created, removed)


#
# nat64_intf_change
#
def nat64_intf_change(client, node, intf):
    """Store detach/attach commands for an interface whose groups changed.

    Groups removed from the interface are detached first, then newly added
    groups are attached.  Groups reported as changed need no command.
    """
    _, attached, detached = nat64_intf_group_status(client, node, intf)

    ifname = intf.get('name')
    grp_path_pfx = "service nat %s interface %s in" % (node, ifname)

    # (group names, command prefix, store action), in the order applied
    steps = (
        (detached,
         "npf-cfg detach interface:%s %s" % (ifname, node),
         "DELETE"),
        (attached,
         "npf-cfg attach interface:%s %s" % (ifname, node),
         "SET"),
    )

    for grp_names, cmd_pfx, action in steps:
        for grp in grp_names:
            nat64_store("%s %s" % (grp_path_pfx, grp),
                        "%s %s:%s" % (cmd_pfx, node, grp),
                        action)


#
# Main script for nat64 *and* nat46.  The type of nat is determined by the
# 'node' parameter, which is either "nat64" or "nat46".
#
def nat646_common(client, node):
    """Apply all pending group and interface changes for one nat type.

    'node' must be "nat64" or "nat46"; anything else is ignored.  The
    changes are applied in a fixed order: deleted interfaces, deleted
    groups, changed groups, added groups, changed interfaces, added
    interfaces.
    """
    if node not in ("nat64", "nat46"):
        return

    # Which groups have been changed, added or deleted?
    grp_ch, grp_add, grp_del = nat64_group_status(client, node)

    # Which interfaces have been changed, added or deleted?
    intf_ch, intf_add, intf_del = nat64_intf_status(client, node)

    # Detach interfaces before removing groups, and create or update
    # groups before attaching interfaces to them
    steps = (
        (nat64_intf_delete, intf_del),
        (nat64_group_delete, grp_del),
        (nat64_group_change, grp_ch),
        (nat64_group_add, grp_add),
        (nat64_intf_change, intf_ch),
        (nat64_intf_add, intf_add),
    )

    for handler, entries in steps:
        for entry in entries:
            handler(client, node, entry)


#
# nat64_cfg_main
#
def nat64_cfg_main():
    """Script entry point for both NAT64 and NAT46.

    Expects "-t <type>" or "--type <type>" on the command line, where
    <type> is "nat64" or "nat46".  Exits with status 2 on a bad or missing
    option and with status 1 if the configd client cannot be created.
    """
    try:
        opts, _ = getopt.getopt(sys.argv[1:], "t:", ['type='])
    except getopt.GetoptError as r:
        print(r, file=sys.stderr)
        sys.exit(2)

    # "nat64" or "nat46"; if the option is repeated the last one wins
    nat_type = None
    for flag, value in opts:
        if flag in ('-t', '--type'):
            nat_type = value

    # The type must be specified and must be one of the two known values
    if nat_type not in ("nat64", "nat46"):
        sys.exit(2)

    try:
        client = configd.Client()
    except configd.FatalException as fatal_exec:
        print("can't connect to configd: {}".format(fatal_exec),
              file=sys.stderr)
        sys.exit(1)

    nat646_common(client, nat_type)


if __name__ == '__main__':
    nat64_cfg_main()
