#!/usr/bin/env python3
#
# Copyright (c) 2019-2020, AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#
# Scripts for NAT64 configuration that begins either "service nat nat64", or
# the NAT64 configuration that begins "service nat nat46".
#

import json
import sys
import getopt
import vplaned


#
# npf_attach_point_list
#
# Query the dataplane and compile a list of attach-points containing the given
# ruleset-class.
#
# Each list element is a dictionary with keys: "attach_point", "attach_type",
# and "rulesets".
#
def npf_attach_point_list(rsclass):
    """Gather the attach-point entries for a ruleset class.

    Sends "npf-op show all: <rsclass>" to every dataplane known to the
    controller and concatenates the 'config' lists of the replies, in the
    order the dataplanes are returned.  Replies without a 'config' key
    contribute nothing.

    rsclass -- ruleset-class name placed in the command.

    Returns the combined list (empty when no dataplane returned 'config').
    """
    cmd = "npf-op show all: %s" % (rsclass)
    attach_points = []

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                reply = dp.json_command(cmd)
                if 'config' in reply:
                    attach_points.extend(reply['config'])

    return attach_points


#
# npf_ruleset_class_dict
#
# Each element in intf_list is a dictionary with keys: "attach_point",
# "attach_type", and "rulesets".
#
# "rulesets" is itself a dictionary with keys "groups" and "ruleset_type".
#
# A given attach-point may appear more than once in the list.
#
# This function converts the interface list into a dictionary keyed by
# attach-point name, where each element of the dictionary is a list of
# rulesets of the given ruleset class.
#
# Thus we end up with a more flattened data structure than that returned from
# the dataplane.
#
def npf_ruleset_class_dict(intf_list, rsclass):
    """Flatten an attach-point list into {attach-point name: [groups]}.

    intf_list -- list of dictionaries; each is read for 'attach_point' and
                 'rulesets'.  The same attach-point may appear several
                 times, in which case its groups are accumulated in order.
    rsclass   -- only 'rulesets' entries whose 'ruleset_type' equals this
                 value contribute their 'groups'.

    Every attach-point seen gets an entry, even when it has no matching
    ruleset (the value is then an empty list).  A missing or empty
    'rulesets' or 'groups' value is treated as "no groups" rather than
    raising.

    Returns the dictionary keyed by attach-point name.
    """
    intf_dict = {}

    for intf in intf_list:
        # setdefault both creates the entry on first sight and hands back
        # the list to accumulate into on later sightings
        groups = intf_dict.setdefault(intf.get('attach_point'), [])

        for rstype in intf.get('rulesets') or []:
            if rstype.get('ruleset_type') == rsclass:
                groups.extend(rstype.get('groups') or [])

    return intf_dict


#
# nat64_group_filter
#
# Filter a list of rulesets based on whether they are nat46 or nat64 rulesets
#
# groups - list of rulesets
# type   - "nat46" or "nat64".  Relates to the rproc type.
#
def nat64_group_filter(groups, type):
    """Return the groups that contain at least one rule of the given type.

    groups -- list of ruleset dictionaries; each is read for 'rules', a
              dictionary of rule dictionaries that carry a 'config' string.
    type   -- "nat46" or "nat64".  A rule matches when its config contains
              "handle=<type>".  (The name shadows the builtin but is part
              of the existing signature, so it is kept.)

    Groups without rules, and rules without a config string, simply do not
    match; they no longer abort the scan of the rest of the group.

    Returns a new list holding the matching groups in their original order.
    """
    handle = "handle=%s" % (type)

    def _has_handle(group):
        rules = group.get('rules') or {}
        return any(handle in (rule.get('config') or "")
                   for rule in rules.values())

    return [group for group in groups if _has_handle(group)]


#
# If this is a host address then drop the /32 or /128
#
def nat64_addr_str(addr):
    """Drop the prefix length from a host address, else return addr as is.

    addr -- address string, optionally followed by "/<prefix-length>".

    The host part is treated as IPv6 when it contains ":", so "/128" is
    dropped only from IPv6 hosts and "/32" only from IPv4 hosts.  An IPv6
    prefix such as "2001:db8::/32" is therefore left untouched, and an
    address that starts with ":" (e.g. "::1/128") is recognised as IPv6.

    Returns the bare host for a host address, otherwise addr unchanged.
    """
    host, sep, pl = addr.partition("/")
    if not sep:
        return addr

    host_pl = "128" if ":" in host else "32"
    if pl == host_pl:
        return host

    return addr


#
# Extract proto or protocol-group from "proto=6 dst-addr=2002::1 dst-port=80"
#
def nat64_parse_match_proto(match):
    """Find the protocol item in a space-separated match string.

    match -- e.g. "proto=6 dst-addr=2002::1 dst-port=80".

    Tokens are examined left to right; only tokens containing exactly one
    "=" are considered.  The first "proto-final" or "proto" item gives
    ("Protocol", value) and the first "protocol-group" item gives
    ("Protocol-group", value).

    Returns the (label, value) pair, or ("Protocol", "any") if none found.
    """
    labels = {
        "proto-final": "Protocol",
        "proto": "Protocol",
        "protocol-group": "Protocol-group",
    }

    for token in match.split(" "):
        parts = token.split("=")
        if len(parts) == 2 and parts[0] in labels:
            return labels[parts[0]], parts[1]

    return "Protocol", "any"


#
# nat64_parse_match_str
#
# dst-addr, dst-addr-group, dst-port, dst-port-group
#   or
# src-addr, src-addr-group, src-port, src-port-group
#
def nat64_parse_match_str(match, sd):
    """Render the source or destination part of a rule's match string.

    match -- space-separated "item=value" tokens, e.g.
             "proto=6 dst-addr=2002::1 dst-port=80".
    sd    -- "src" or "dst"; selects which <sd>-addr, <sd>-addr-group,
             <sd>-port and <sd>-port-group items are used.

    The address and the ports are collected separately so the result does
    not depend on token order, and a port given without an address is
    shown as "any, port N" instead of starting with a stray ", ".

    Returns None when sd is neither "src" nor "dst"; otherwise the rendered
    text, or "any" when no relevant item is present.
    """
    if sd not in ("src", "dst"):
        return None

    addr = None
    ports = []

    for token in match.split(" "):
        parts = token.split("=")
        if len(parts) != 2:
            continue
        key, value = parts

        if key == "%s-addr" % (sd):
            addr = "%s" % (value)
        elif key == "%s-addr-group" % (sd):
            addr = "address-group: %s" % (value)
        elif key == "%s-port" % (sd):
            ports.append("port %s" % (value))
        elif key == "%s-port-group" % (sd):
            ports.append("port-group: %s" % (value))

    return ", ".join([addr or "any"] + ports)


#
# nat64_parse_map_str
#
# Parse the rproc string, "handle=(..)", returned from the dataplane in a
# rule. This function looks for either source or dest strings which are of the
# format "item=value".  The first letter of the item string is either 's' or
# 'd', which indicates it is a source or destination item.
#
# sd is either "s" or "d".
#
# Output is a string that is then used in the translation part of "show nat
# nat64 rules" output. e.g. the part just after '->'
#
# show nat nat64 rules
# Interface: dp0p1s1, Ruleset: NAT64_GRP2
#   Rule: 10, Protocol: any, Log: sessions
#     Source: any            -> overload, 10.10.1.10-10.10.1.12, Used 1/196605
#     Dest:   2001:2::2      -> one2one, 10.10.2.2/32
#
def nat64_parse_map_str(map, sd):
    """Render one side of a rule's rproc string for "show ... rules".

    map -- rproc text such as "nat64(stype=overload,srange=...,log=1)".
           Only the comma-separated "item=value" pairs between the first
           "(" and the following ")" are examined.  (The name shadows the
           builtin but is part of the existing signature, so it is kept.)
    sd  -- "s" to select the source items, "d" for the destination items.

    Recognised items are <sd>type, <sd>addr, <sd>group, <sd>range, <sd>port
    and <sd>pl.  They are emitted in the order they appear, separated by a
    single space, with the prefix-length item set off by ", ".  The
    original code called str.rstrip() and discarded the result, which left
    trailing spaces and a space before the comma; the pieces are now
    joined so that no stray whitespace is produced.

    Returns None when sd is neither "s" nor "d".  Otherwise returns a tuple
    (text, log) where text is "unknown" if nothing matched and log is the
    value of the "log" item, or None when there is none.
    """
    if sd not in ("s", "d"):
        return None

    # Remove up to, and including, "("
    body = map
    start = body.find("(")
    if start != -1:
        body = body[start + 1:]

    # Remove ")" onwards
    end = body.find(")")
    if end != -1:
        body = body[:end]

    mstr = ""
    log = None

    for item in body.split(","):
        fields = item.split("=")
        if len(fields) != 2:
            continue
        key, value = fields

        if key == "log":
            log = value
            continue

        joiner = " "
        if key == "%stype" % (sd):
            piece = "%s" % (value)
        elif key == "%saddr" % (sd):
            piece = "%s" % (nat64_addr_str(value))
        elif key == "%sgroup" % (sd):
            piece = "address-group: %s" % (value)
        elif key == "%srange" % (sd):
            piece = "%s" % (value)
        elif key == "%sport" % (sd):
            piece = "port: %s" % (value)
        elif key == "%spl" % (sd):
            piece = "prefix-length: %s" % (value)
            joiner = ", "
        else:
            continue

        mstr = mstr + joiner + piece if mstr else piece

    if not mstr:
        mstr = "unknown"

    return mstr, log


#
# Show a rule
#
def nat64_show_rule(rlnum, rule, col1):
    """Print one rule and return the (possibly widened) match column width.

    rlnum -- rule number, used only in the printed headline.
    rule  -- rule dictionary; 'config' holds the match items followed by
             the rproc "handle=(...)".  'total_ts' and 'protocols' are
             appended to the source mapping when present.
    col1  -- current width of the match column.

    Nothing is printed, and col1 is returned unchanged, when the rule has
    no 'config' or its config has no "handle" part.  Otherwise col1 grows
    to fit the longer of the source and destination match strings so the
    "->" of both lines stays aligned.

    Returns the column width to use for subsequent rules.
    """
    # match and map info is in the 'config' string
    config = rule.get('config') or ""

    # The rproc part is of the form "handle=(...)"; everything before it
    # is the match part
    loc = config.find("handle")
    if loc == -1:
        return col1

    match = config[:loc]
    rproc = config[loc + len("handle="):]

    ptype, proto = nat64_parse_match_proto(match)

    match_src = nat64_parse_match_str(match, "src")
    match_dst = nat64_parse_match_str(match, "dst")
    # The "log" item is not specific to either side, so one copy suffices
    map_src, log = nat64_parse_map_str(rproc, "s")
    map_dst, _ = nat64_parse_map_str(rproc, "d")

    # Strip unconditionally so the appended counters and the end of each
    # printed line never carry a stray trailing space
    map_src = map_src.rstrip()
    map_dst = map_dst.rstrip()

    if 'total_ts' in rule:
        map_src += ", Total %u" % (rule.get('total_ts'))

    if 'protocols' in rule:
        used = {}
        for prot in rule['protocols']:
            used[prot.get('protocol')] = prot.get('used_ts')
        # A protocol missing from the list is reported as zero used
        map_src += ", Used TCP %u, Used UDP %u, Used other %u" % \
                   (used.get('tcp', 0), used.get('udp', 0),
                    used.get('other', 0))

    rl_head = "  Rule: %s, %s: %s" % (rlnum, ptype, proto)
    if log == "1":
        rl_head += ", Log: sessions"

    print("%s" % (rl_head))

    col1 = max(col1, len(match_src), len(match_dst))
    if match_src and map_src:
        print("    Source: %-*s -> %s" % (col1, match_src, map_src))
    if match_dst and map_dst:
        print("    Dest:   %-*s -> %s" % (col1, match_dst, map_dst))

    return col1


#
# Show a ruleset
#
def nat64_show_ruleset(intf_name, group, col1):
    """Print a ruleset headline followed by each of its rules.

    intf_name -- interface name shown in the headline.
    group     -- ruleset dictionary, read for 'name' and 'rules' (a
                 dictionary of rules keyed by rule number).
    col1      -- current width of the match column.

    Returns the column width handed back by the last rule shown, or col1
    unchanged when the ruleset has no rules.
    """
    print("Interface: %s, Ruleset: %s" % (intf_name, group.get('name')))

    width = col1
    for number, rule in group.get('rules').items():
        width = nat64_show_rule(number, rule, width)

    return width


#
# nat64_show_intf
#
# name   - interface name
# groups - list of rulesets
#
def nat64_show_intf(name, groups, col1):
    """Show every ruleset attached to one interface.

    name   -- interface name.
    groups -- list of rulesets.
    col1   -- current width of the match column.

    Each ruleset is followed by a blank line.  With the local debug switch
    turned on, the raw ruleset is dumped as JSON instead.

    Returns the column width after the last ruleset shown.
    """
    debug = False

    for group in groups:
        if not debug:
            col1 = nat64_show_ruleset(name, group, col1)
            print("")
        else:
            print(json.dumps(group, sort_keys=True, indent=4))

    return col1


#
# usage
#
def usage():
    """Print the command synopsis to stderr.

    Lists every option the option parser accepts, including the
    -c|--clear and -d|--detail options that were previously left out.
    """
    print("usage: {} -t|--type=<nat64|nat46> "
          "[-s|--show=<translations|rules> [-d|--detail]] "
          "[-c|--clear=<translations>]".format(sys.argv[0]),
          file=sys.stderr)


#
# nat64_show_rules
#
# The 'match' and 'map' parts of the rules are split into a column each. col1
# starts out as 14, and expands as required.
#
# Interface: dp0p1s1, Ruleset: NAT64_GRP2
#   Rule: 10, Protocol: any, Log: sessions
#     Source: any            -> overload, 10.10.1.10-10.10.1.12, Used 1/196605
#     Dest:   2001:2::2      -> one2one, 10.10.2.2/32
#             \____________/
#               col1 = 14
#
def nat64_show_rules(t_opt, detail):
    """Show the rules of every attach-point, ordered by attach-point name.

    t_opt  -- "nat64" or "nat46"; used directly as the ruleset-class.
    detail -- accepted for symmetry with the translations command; it is
              not used here.

    The match column starts 14 wide and keeps whatever width earlier
    rules expanded it to.
    """
    attach_points = npf_attach_point_list(t_opt)
    by_intf = npf_ruleset_class_dict(attach_points, t_opt)

    width = 14
    for intf_name in sorted(by_intf):
        width = nat64_show_intf(intf_name, by_intf[intf_name], width)


#
# Create a table from /etc/protocols that is indexed with a protocol number.
#
class ProtocolTable:
    """Protocol-number to protocol-name lookup built from a protocols file.

    Each non-comment line is expected to be "name number [aliases...]".
    Text from "#" onwards is ignored.  Aliases are optional, so a line
    with just a name and a number is accepted; lines with fewer fields, or
    whose number column is not an integer, are skipped.  When a number is
    listed more than once, the last entry wins.
    """

    def __init__(self, path='/etc/protocols'):
        """Load the table.

        path -- file to read; defaults to the system protocols database.

        Raises OSError if the file cannot be opened.
        """
        self._protocol_names = {}
        with open(path, 'r') as f:
            for line in f:
                fields = line.partition('#')[0].split()
                if len(fields) < 2:
                    continue
                try:
                    number = int(fields[1])
                except ValueError:
                    continue
                self._protocol_names[number] = fields[0]

    def __getitem__(self, protocol_number):
        """Return the name for protocol_number, or "Unassigned"."""
        return self._protocol_names.get(protocol_number, "Unassigned")


# Global proto_table
proto_table = []


#
# Protocol number to name
#
def npf_num2proto(pnum):
    """Translate a protocol number into a display name.

    pnum -- protocol number.

    TCP, UDP, ICMP and ICMPv6 (shown in its short form "icmpv6") are
    answered from a small built-in map.  Anything else is looked up in the
    module-level proto_table, which is replaced by a ProtocolTable the
    first time it is needed and reused afterwards.

    Returns the name; if the lookup yields an empty value the number is
    returned as a string instead (ProtocolTable itself answers
    "Unassigned" for numbers it does not know).
    """
    global proto_table

    common = {6: "tcp", 17: "udp", 1: "icmp", 58: "icmpv6"}
    if pnum in common:
        return common[pnum]

    # Build the table lazily and keep it for later calls
    if not proto_table:
        proto_table = ProtocolTable()

    pname = proto_table[pnum]
    return pname if pname else str(pnum)


#
# npf_sess_state_str
#
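# Get generic state from session.  Returns one of "New", "Active", "Closing"
# or "Closed", or "Unkn" if the session has no protocol.  When detail is
# requested for a TCP session, the TCP state name is returned instead.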
#
def npf_sess_state_str(se, detail):
    """Return a display string for a session's state.

    se     -- session dictionary, read for 'proto' and 'state'; may be
              None or empty.
    detail -- when true and the session is TCP (proto 6), return the TCP
              state name instead of the generic state.

    Returns "Closed" when there is no session and "Unkn" when the session
    lacks a protocol or a state.  Otherwise the state is compared with the
    protocol's established state (4 for TCP, 2 for everything else):
    below it is "New", equal is "Active", above is "Closing".  A state
    outside the TCP name table, including a negative one, falls back to
    the generic comparison rather than indexing the table.
    """
    if not se:
        return "Closed"

    proto = se.get('proto')
    state = se.get('state')

    if proto is None or state is None:
        return "Unkn"

    # state is an index from one of the following arrays:
    #
    # tcp_states   = ["NO", "SS", "SS", "SR", "ES", "FW", "FW", "CW",
    #                 "FW", "CG", "LA", "TW", "RR", "CL"]
    # other_states = ["NO", "new", "ES", "TM", "CL"]
    #
    tcp_state = ("None", "syn-sent", "simsyn-sent", "syn-received",
                 "established", "fin-sent", "fin-received", "close-wait",
                 "fin-wait", "closing", "last-ack", "time-wait",
                 "rst-received", "closed")

    if proto == 6:
        # TCP
        if detail and 0 <= state < len(tcp_state):
            return tcp_state[state]
        estd_state = 4
    else:
        # Other
        estd_state = 2

    if state < estd_state:
        return "New"
    if state == estd_state:
        return "Active"

    return "Closing"


#
# Get a given feature dict from a session dict
#
# The npf feature is "type" 3
#
def npf_sess_feat_dict(sess, feat_num):
    """Return the first feature of a session whose 'type' is feat_num.

    sess     -- session dictionary, read for its 'features' list; may be
                None.
    feat_num -- feature type to look for (the npf feature is type 3).

    Returns the feature dictionary, or None when the session is missing,
    has no 'features' list, or has no feature of that type.
    """
    if not sess:
        return None

    for feat in sess.get('features') or []:
        if feat.get('type') == feat_num:
            return feat

    return None


#
# nat64_show_trans_detail_one
#
def nat64_show_trans_detail_one(is_ingress, se, feat):
    """Print the detailed view of one half of a translation.

    is_ingress -- true for the ingress ("In:") session, false for the
                  egress ("Out:") session.
    se         -- session dictionary, or None/empty if that half is gone,
                  in which case only a "Closed" stub is printed.
    feat       -- the session's npf feature dictionary, or None.  When
                  given, its flags (with the two low direction bits masked
                  out) and its 'nat64' packet/byte counters are printed.

    For the ingress session the forward counters are 'stats_in' and the
    reverse counters 'stats_out'; for the egress session it is the other
    way round.
    """
    dir_str = "In: " if is_ingress else "Out:"

    if not se:
        print("  %s" % (dir_str))
        print("    Session State: Closed")
        return

    source = "%s/%s" % (se.get('src_addr'), se.get('src_port'))
    dest = "%s/%s" % (se.get('dst_addr'), se.get('dst_port'))
    proto_name = npf_num2proto(se.get('proto'))
    state_name = npf_sess_state_str(se, True)

    # The expiry time dips below zero for a few seconds until the garbage
    # collector next runs; show that as zero
    remaining = max(se.get('time_to_expire'), 0)

    print("  %-4s %s --> %s, %s" % (dir_str, source, dest, proto_name))
    print("    Interface: %s" % (se.get('interface')))
    print("    VRF ID: %s" % (se.get('vrf_id')))
    print("    Session State: %s" % (state_name))
    print("    Timeout: %u of %u" % (remaining,
                                     se.get('state_expire_window')))

    if not feat:
        return

    flags = feat.get('flags')
    if flags:
        # Mask out direction flags
        print("    Flags: 0x%x" % (flags & ~0x3))

    if is_ingress:
        fwd_key, rev_key = 'stats_in', 'stats_out'
    else:
        fwd_key, rev_key = 'stats_out', 'stats_in'
    fwd_stats = feat['nat64'][fwd_key]
    rev_stats = feat['nat64'][rev_key]

    print("    Forward Pkts: %d, Bytes: %d" % (fwd_stats['packets'],
                                               fwd_stats['bytes']))
    print("    Reverse Pkts: %d, Bytes: %d" % (rev_stats['packets'],
                                               rev_stats['bytes']))


#
# nat64_show_trans_detail
#
def nat64_show_trans_detail(in_key, in_se, out_key, out_se):
    """Print the detailed view of an ingress/egress translation pair.

    in_key, out_key -- session IDs shown in the headline.
    in_se, out_se   -- session dictionaries; either may be None.

    The "Rule:" line is taken from the ingress session's npf feature
    (type 3) when it names a ruleset, else from the egress session's,
    else "N/A".  A session whose npf feature is missing, or has no
    'nat64' entry, is shown without rule, flag or counter information
    instead of raising.
    """
    debug = False

    def _nat64_feat(se):
        # Only hand back a feature that actually carries a 'nat64' entry,
        # since everything printed from it lives under that key
        feat = npf_sess_feat_dict(se, 3) if se else None
        if feat and feat.get('nat64'):
            return feat
        return None

    def _rule_str(feat):
        if feat and 'ruleset' in feat['nat64']:
            return "%s %u" % (feat['nat64']['ruleset'],
                              feat['nat64']['rule'])
        return None

    in_feat = _nat64_feat(in_se)
    out_feat = _nat64_feat(out_se)
    rlset_str = _rule_str(in_feat) or _rule_str(out_feat) or "N/A"

    if debug:
        if in_se:
            print("In:")
            print(json.dumps(in_se, indent=4, sort_keys=True))
        if out_se:
            print("Out:")
            print(json.dumps(out_se, indent=4, sort_keys=True))

    # Headline block
    print("Session ID: In %s Out %s" % (in_key, out_key))
    print("  Rule: %s" % (rlset_str))

    nat64_show_trans_detail_one(True, in_se, in_feat)
    nat64_show_trans_detail_one(False, out_se, out_feat)


#
# nat64_show_trans_brief_headline
#
def nat64_show_trans_brief_headline(in_key, in_se, out_key, out_se):
    """Print the one-line headline of a translation pair in brief view.

    in_key, out_key -- session IDs; an empty or None ID is shown as "-".
    in_se, out_se   -- session dictionaries; either may be None.

    The headline gives both IDs, the rule (from whichever session's npf
    feature names a ruleset, ingress first, else "-"), both generic
    states and both remaining timeouts.  A missing or negative
    'time_to_expire' is shown as 0, matching the detailed view, which
    also clamps the value while a session waits for garbage collection.
    """
    in_state = npf_sess_state_str(in_se, False)
    out_state = npf_sess_state_str(out_se, False)

    in_to = 0
    out_to = 0
    feats = []

    if in_se:
        in_to = max(in_se.get('time_to_expire') or 0, 0)
        feats.append(npf_sess_feat_dict(in_se, 3))

    if out_se:
        out_to = max(out_se.get('time_to_expire') or 0, 0)
        feats.append(npf_sess_feat_dict(out_se, 3))

    # Get ruleset info from either session
    ruleset_str = "-"
    for feat in feats:
        nat64 = feat.get('nat64') if feat else None
        if nat64 and 'ruleset' in nat64:
            ruleset_str = "%s %u" % (nat64['ruleset'], nat64['rule'])
            break

    print("Session IDs: %s/%s, Rule: %s, State: %s/%s, Timeout: %u/%u" %
          (in_key or "-", out_key or "-", ruleset_str,
           in_state, out_state, in_to, out_to))


#
# nat64_show_trans_brief
#
# Whilst this is not a tabular output, we do make some attempt to line up the
# dest addresses using a 'src_col' width parameter.
#
def nat64_show_trans_brief(is_ingress, id, se, src_col):
    """Print the one-line brief view of one half of a translation.

    is_ingress -- true for the ingress ("In: ") line, false for egress.
    id         -- session ID; accepted but not used here.
    se         -- session dictionary; nothing is printed when it is
                  None or empty.
    src_col    -- current width of the source column.

    The source column widens to fit this session's "addr/port" string and
    the destination is padded to at least 10 characters.  Packet and byte
    counters come from the 'nat64' entry of the session's npf feature
    (type 3), or are shown as "-" when that is absent.

    Returns the source column width to use for subsequent lines.
    """
    if not se:
        return src_col

    source = "%s/%s" % (se.get('src_addr'), se.get('src_port'))
    dest = "%s/%s" % (se.get('dst_addr'), se.get('dst_port'))
    proto_name = npf_num2proto(se.get('proto'))
    intf_name = se.get('interface')

    if len(source) > src_col:
        src_col = len(source)

    if is_ingress:
        dir_str, stats_key = "In: ", 'stats_in'
    else:
        dir_str, stats_key = "Out:", 'stats_out'

    # Counters live in the nat64 rproc dict of the npf feature
    counters = "-"
    npf_feat = npf_sess_feat_dict(se, 3)
    if npf_feat and 'nat64' in npf_feat:
        stats = npf_feat['nat64'][stats_key]
        counters = "Pkts: %d, Bytes: %d" % (stats['packets'],
                                            stats['bytes'])

    # Allow at least 10 spaces for dest string
    dst_width = len(dest) if len(dest) > 10 else 10

    print("  %s  %-*s --> %-*s; %s, If: %s, %s" %
          (dir_str, src_col, source, dst_width, dest,
           proto_name, intf_name, counters))

    return src_col


#
# nat64_show_trans
#
def nat64_show_trans(in_key, in_se, out_key, out_se, src_col, detail):
    """Show one translation pair in either detailed or brief form.

    in_key, in_se   -- ID and dictionary of the ingress session.
    out_key, out_se -- ID and dictionary of the egress session.
    src_col         -- current width of the brief view's source column.
    detail          -- true for the detailed view, false for brief.

    Returns src_col unchanged for the detailed view, else the width after
    the ingress and egress brief lines have been printed.
    """
    if not detail:
        nat64_show_trans_brief_headline(in_key, in_se, out_key, out_se)
        halves = ((True, in_key, in_se), (False, out_key, out_se))
        for is_ingress, key, se in halves:
            src_col = nat64_show_trans_brief(is_ingress, key, se, src_col)
        return src_col

    nat64_show_trans_detail(in_key, in_se, out_key, out_se)
    return src_col


#
# nat64_show_trans_all
#
# keys      - List of sessions IDs in string format sorted numerically.
#             Each member item may be used as a key to sess_dict.
# sess_dict - Dictionary of all sessions. Keyed by session ID string,
#             e.g. "12"
#
# Whilst this is not a tabular output, we do make some attempt to line up the
# dest addresses using a 'src_col' width parameter.  This will expand, as
# required.
#
def nat64_show_trans_all(keys, sess_dict, detail):
    """Show every nat64/nat46 translation pair among the given sessions.

    keys      -- list of session ID strings, in the order to show them.
                 The list is not modified.
    sess_dict -- dictionary of sessions keyed by session ID string.
    detail    -- true for the detailed view, false for brief.

    A session is shown together with the peer named by 'peer_id' in the
    'nat64' entry of its npf feature (type 3).  Sessions without that
    entry are skipped.  IDs already shown as part of a pair are remembered
    in a set and skipped when reached, instead of being removed from the
    list while it is being iterated, which could skip unrelated entries.
    """
    src_col = 14
    shown = set()

    for sess_id in keys:
        if sess_id in shown:
            continue

        se = sess_dict.get(sess_id)

        # npf feature is 3.  A "nat64" object exists inside the npf feature
        # for either nat64 or nat46.
        feat = npf_sess_feat_dict(se, 3)
        if not feat or 'nat64' not in feat:
            continue

        nat64 = feat.get('nat64')
        peer_id = str(nat64.get('peer_id'))
        # A peer_id of 0 means there is no peer session to look up
        peer = sess_dict.get(peer_id) if peer_id != "0" else None

        # Do not show either of these sessions a second time
        shown.add(sess_id)
        shown.add(peer_id)

        # Determine which is ingress session and which is egress session
        if nat64.get('in'):
            in_key, in_se = sess_id, se
            out_key, out_se = peer_id, peer
        else:
            in_key, in_se = peer_id, peer
            out_key, out_se = sess_id, se

        # show nat64 or nat46 pair
        src_col = nat64_show_trans(in_key, in_se, out_key, out_se,
                                   src_col, detail)


#
# nat64_show_translations
#
def nat64_show_translations(t_opt, detail):
    """Fetch and show the nat64 or nat46 sessions of every dataplane.

    t_opt  -- "nat64" or "nat46"; passed to the dataplane as the filter.
    detail -- true for the detailed view, false for brief.

    Sessions are requested in batches of 500 with
    "session-op show sessions <t_opt> <start> <count>".  Fetching from a
    dataplane stops when a reply has no 'config', has no sessions, or
    holds fewer sessions than were requested; the next dataplane is then
    queried.  The batch size is now taken from the reply; previously the
    returned count was never recorded, so fetching always stopped after
    the first batch.
    """
    debug = False
    req_count = 500

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                start = 0

                while True:
                    # Fetch either the nat46 or nat64 sessions only
                    cmd = "session-op show sessions %s %u %u" % (t_opt, start,
                                                                 req_count)
                    npf_dict = dp.json_command(cmd)

                    # Nothing usable from this dataplane; move on to the
                    # next one rather than abandoning the rest
                    if 'config' not in npf_dict:
                        break

                    config = npf_dict.get('config') or {}
                    sess_dict = config.get('sessions') or {}
                    if not sess_dict:
                        break

                    if debug:
                        print(json.dumps(sess_dict, indent=4, sort_keys=True))

                    # Session numbers (dictionary keys) sorted numerically
                    keys = sorted(sess_dict, key=int)

                    nat64_show_trans_all(keys, sess_dict, detail)

                    # A short batch means there is nothing more to fetch
                    if len(sess_dict) < req_count:
                        break
                    start += req_count


#
# nat64_clear_translations
#
def nat64_clear_translations(t_opt):
    """Send "session-op clear session <t_opt>" to every dataplane.

    t_opt -- "nat64" or "nat46"; placed in the command as given.
    """
    cmd = "session-op clear session %s" % (t_opt)

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                dp.string_command(cmd)


#
# nat64_op_main
#
def nat64_op_main():
    """
    NAT64 and NAT46

    Parse the command line and run the requested show and/or clear
    command.

    Options:
      -t, --type=<nat64|nat46>          required
      -s, --show=<translations|rules>   what to show
      -c, --clear=<translations>        what to clear
      -d, --detail                      detailed output for show

    Exits with status 2, after printing the usage text, when the options
    cannot be parsed, the type is missing or invalid, or the show/clear
    argument is not one of the values above.  An unrecognised show or
    clear argument was previously ignored without any message.
    """
    # "nat64" or "nat46"
    t_opt = None

    # "translations"
    c_opt = None

    # "translations" or "rules"
    s_opt = None

    d_opt = False

    #
    # Parse options
    #
    try:
        opts, _ = getopt.getopt(sys.argv[1:],
                                "c:ds:t:", ['clear=', 'detail',
                                            'show=', 'type='])
    except getopt.GetoptError as r:
        print(r, file=sys.stderr)
        usage()
        sys.exit(2)

    for opt, arg in opts:
        if opt in ('-d', '--detail'):
            d_opt = True
        elif opt in ('-c', '--clear'):
            c_opt = arg
        elif opt in ('-t', '--type'):
            t_opt = arg
        elif opt in ('-s', '--show'):
            s_opt = arg

    # t_opt must be specified, and any show/clear argument must be known
    if (t_opt not in ("nat64", "nat46") or
            (s_opt is not None and
             s_opt not in ("rules", "translations")) or
            (c_opt is not None and c_opt != "translations")):
        usage()
        sys.exit(2)

    # show ...
    if s_opt == "rules":
        nat64_show_rules(t_opt, d_opt)
    elif s_opt == "translations":
        nat64_show_translations(t_opt, d_opt)

    # delete ...
    if c_opt == "translations":
        nat64_clear_translations(t_opt)


#
# main
#
if __name__ == '__main__':
    # Run the command-line handler only when this file is executed as a
    # script, not when it is imported as a module.
    nat64_op_main()
