#!/usr/bin/env python3
#
# Copyright (c) 2020, AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#

"""Handles top level statistics for npf related features (firewall, nat, and
packet filter).  These statistics comprise per-interface and per-direction
counts for unmatched, pass, block and drop.  Provides functions for show,
clear and rpc commands.
"""

import sys
import getopt
import vplaned
import json
from vyatta.npf.npf_debug import NpfDebug

# class used for printing debugs
# Single shared instance; the functions below pass it json.dumps() output
# through dbg.pprint().
# NOTE(review): presumably only emits when debugging is enabled - confirm
# in vyatta.npf.npf_debug.
dbg = NpfDebug()

#
# Return code types
#
# Used to pick the rc-type dictionaries out of each interface entry returned
# by the dataplane, to validate the --show/--clear argument and the rpc
# 'feat' input, and to order the rpc output.
rct_key = ['ip', 'ip6', 'local', 'l2', 'ip-packet-filter',
           'ip6-packet-filter', 'nat64']

#
# Return code types under 'firewall'.
#
# These relate to the options under the 'show/clear dataplane statistics
# firewall' prefix.  In the rpc output these types are grouped inside a
# 'firewall' dict; all other types sit beside it.
#
rct_key_fw = ['ip', 'ip6', 'local', 'l2']

# Valid values for the rpc 'direction' input
dir_key = ['in', 'out']

# Count 'categories'.  The long show output lists categories in this order;
# the list also validates the category passed to npf_rc_count and the rpc
# 'category' input.
cat_key = ['pass', 'unmatched', 'block', 'drop']

# Heading printed for each return code type in the long (normal and detail)
# show output.
rct_long = {'ip':                'IPv4 firewall and NAT',
            'ip6':               'IPv6 firewall',
            'local':             'Local and originate firewall',
            'l2':                'Layer 2 firewall',
            'ip-packet-filter':  'IPv4 Packet Filter',
            'ip6-packet-filter': 'IPv6 Packet Filter',
            'nat64':             'NAT64'}

# Text for the first ("Process") column of the brief show output.
rct_short = {'ip':                'IPv4 firewall',
             'ip6':               'IPv6 firewall',
             'local':             'Local firewall',
             'l2':                'L2 firewall',
             'ip-packet-filter':  'IPv4 pkt filter',
             'ip6-packet-filter': 'IPv6 pkt filter',
             'nat64':             'NAT64'}

# Description displayed for each individual return code counter found in a
# category's 'detail' dict.  A counter with no entry here is displayed under
# its raw name.
rc_desc = {'RC_UNMATCHED':      'unmatched',
           'RC_PASS':           'pass',
           'RC_BLOCK':          'block',
           'RC_L3_SHORT':       'invalid layer 3 header',
           'RC_L4_SHORT':       'invalid layer 4 header',
           'RC_L3_PROTO':       'protocol mismatch',
           'RC_L4_PROTO':       'invalid layer 4 protocol',
           'RC_L3_HDR_VER':     'invalid IP header version field',
           'RC_L3_HDR_LEN':     'invalid IP header length field',
           'RC_NON_IP':         'non-IP packet',
           'RC_ICMP_ECHO':      'unsolicited ICMP echo reply',
           'RC_ENOSTR':         'unknown TCP reset',
           'RC_TCP_SYN':        'missing TCP SYN',
           'RC_TCP_STATE':      'invalid TCP flags',
           'RC_TCP_WIN':        'TCP window error',
           'RC_SESS_ENOMEM':    'no memory to create session',
           'RC_SESS_LIMIT':     'session limiter',
           'RC_SESS_HOOK':      'session hook',
           'RC_DP_SESS_ESTB':   'failed to create dataplane session',
           'RC_MBUF_ENOMEM':    'failed to allocate packet memory',
           'RC_MBUF_ERR':       'failed to prepend or adjust packet buffer',
           'RC_NAT_ENOSPC':     'failed to get NAT port mapping',
           'RC_NAT_ENOMEM':     'no memory to create NAT',
           'RC_NAT_EADDRINUSE': 'fragmented NAT port map',
           'RC_NAT_ERANGE':     'NAT port range too small',
           'RC_NAT_E2BIG':      'unable to fragment packet',
           'RC_ICMP_ERR_NAT':   'failed to translate ICMP error embedded pkt',
           'RC_ALG_EEXIST':     'ALG race condition',
           'RC_ALG_ERR':        'ALG error',
           'RC_NAT64_4T6':      'IPv4 to IPv6',
           'RC_NAT64_6T4':      'IPv6 to IPv4',
           'RC_NAT64_ENOSPC':   'failed to get NAT64 port mapping',
           'RC_NAT64_ENOMEM':   'failed to allocate NAT64 memory',
           'RC_NAT64_6052':     'failed to extract or encode rfc6052 NAT64 addr',
           'RC_INTL':           'internal error'}

#
# We hide several detailed show counts if they are zero.
# RC_ENOSTR is a 'pass' count that is not very interesting.
# RC_INTL is a 'drop' count that should never happen.
#
# A counter listed here is only hidden when both its 'in' and 'out' values
# are zero.
#
hidden_if_zero = ['RC_ENOSTR', 'RC_INTL']


#
# npf_dpc_option_string:
#
def npf_dpc_option_string(rct_name, ctx):
    """
    Build the option suffix appended to a dataplane show or clear command.

    rct_name - return code type, e.g. 'ip'; omitted from the result if falsy
    ctx      - context dict; the keys 'intf', 'dir', 'cat', 'detail',
               'brief', 'rpc' and 'nonzero_only' must all be present

    Returns a string that is either empty or starts with a space, so it can
    be appended directly to the base command.
    """

    # Options that carry a value: (ctx key, format of the emitted fragment).
    # 'dir' is 'in' or 'out'; 'cat' is 'pass', 'block', 'unmatched' or 'drop'.
    valued = ((('intf'), " interface %s"),
              (('dir'), " dir %s"),
              (('cat'), " cat %s"))

    # Boolean options: (ctx key, fragment emitted when the flag is set)
    flags = (('detail', " detail true"),
             ('brief', " brief true"),
             ('rpc', " rpc true"),
             ('nonzero_only', " nonzero true"))

    pieces = []

    if rct_name:
        pieces.append(" type %s" % (rct_name))

    for key, template in valued:
        value = ctx[key]
        if value:
            pieces.append(template % (value))

    for key, fragment in flags:
        if ctx[key]:
            pieces.append(fragment)

    return "".join(pieces)


#
# npf_get_dp_counts
#
def npf_get_dp_counts(rct_name, ctx):
    """
    Fetch npf return code counts from every dataplane and regroup them.

    rct_name - return code type to request, or a falsy value for no type
               filter
    ctx      - context dict handed to npf_dpc_option_string

    Returns a dict keyed by return code type.  Each value is a dict with a
    single 'interfaces' list, sorted by interface name.  The returned dict
    is empty if no dataplane reported any interface.
    """

    #
    # Each dataplane replies in the format below.  This shows one interface,
    # 'dp0p1s1', with one return code type, 'ip'.
    #
    # The 'in' and 'out' dicts will contain 'pass', 'unmatched', 'block',
    # and 'drop', with detailed counters under each of those if requested
    # (and available).
    #
    # "npf-rc-counts": {
    #     "interfaces": [
    #         {
    #             "name": "dp0p1s1",
    #             "ip": {
    #                 "in": {
    #                     "pass": {
    #                         "count": 6,
    #                         "detail": {
    #                             "RC_PASS": 6,
    #                             "RC_ENOSTR": 0
    #                         }
    #                     }
    #                 },
    #                 "out": {
    #                     "drop": {
    #                         "count": 0,
    #                         "detail": {
    #                             "RC_L4_SHORT": 0,
    #                             "RC_L3_PROTO": 0,
    #                             "RC_TCP_WIN": 0
    #                             ...
    #                         }
    #                     }
    #                 }
    #             }
    #         }
    #     ]
    # }
    #
    cmd = "npf-op rc show counters" + npf_dpc_option_string(rct_name, ctx)

    # Interface entries gathered from all dataplanes
    reported = []

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                reply = dp.json_command(cmd)

                # Ignore replies that carry no interface list
                if ('npf-rc-counts' in reply and
                        'interfaces' in reply['npf-rc-counts']):
                    reported.extend(reply['npf-rc-counts']['interfaces'])

    #
    # Regroup so that the interface list sits inside a dictionary keyed by
    # rc-type:
    #
    # {
    #     "ip": {
    #         "interfaces": [
    #             {
    #                 "name": "dp0p1s1",
    #                 "in": {
    #                     "pass": { "count": 6 },
    #                     "block": { "count": 0 },
    #                     ...
    #                 },
    #                 "out": {
    #                     "pass": { "count": 0 },
    #                     "block": { "count": 0 },
    #                     ...
    #                 }
    #             }
    #         ]
    #     }
    # }
    #
    by_type = {}

    for entry in reported:
        #
        # An interface entry holds the interface name plus one or more
        # rc-type dictionaries.  Keys that are not rc-types (such as 'name')
        # are skipped.
        #
        for key in entry:
            if key not in rct_key:
                continue

            per_type = by_type.setdefault(key, {'interfaces': []})

            record = {'name': entry['name']}

            # A direction is copied only if the dataplane supplied it
            for direction in ('in', 'out'):
                if direction in entry[key]:
                    record[direction] = entry[key][direction]

            per_type['interfaces'].append(record)

    # Order each rc-type's interface list by interface name
    for per_type in by_type.values():
        per_type['interfaces'].sort(key=lambda intf: intf['name'])

    # Dump the regrouped structure through the debug helper
    dbg.pprint(json.dumps(by_type, indent=4, sort_keys=False))

    return by_type


#
# npf_rc_count
#
# counts   - The 'in' or 'out' dict for a particular type and intf
# cat_name - 'pass', 'unmatched', 'block', or 'drop'
# rc_name  - Counter in the 'detail' dict, e.g. "RC_TCP_WIN"
#
def npf_rc_count(counts, cat_name, rc_name):
    """
    Get the counter for one or all categories, or for a single return code.

    counts   - the 'in' or 'out' dict for a particular type and interface,
               keyed by category name
    cat_name - only look at this category; None (or falsy) for all of them
    rc_name  - name of a counter in a category's 'detail' dict, e.g.
               "RC_TCP_WIN"; None (or falsy) to use each category's 'count'

    Returns the sum over every selected category of either its 'count' or
    its rc_name detail counter.  A selected category without a 'detail'
    dict, or without rc_name in it, contributes zero.

    Prints an error and exits with status 2 if cat_name is not a known
    category.
    """

    # cat_name should always be None or one of cat_key
    if cat_name and cat_name not in cat_key:
        print("unknown rc category \"%s\"" % (cat_name), file=sys.stderr)
        sys.exit(2)

    tot = 0

    for cat, entry in counts.items():
        if cat_name and cat != cat_name:
            continue

        if rc_name:
            #
            # Keep looking through the remaining categories rather than
            # returning from the first one visited; when no category is
            # given, the counter may live in any of them.
            #
            tot += entry.get('detail', {}).get(rc_name, 0)
        else:
            tot += entry['count']

    return tot


#
# Handle nat64 'pass' counts differently.  We want to show the v6-to-v4 and
# v4-to-v6 counts at the same level as pass, block etc.
#
def npf_show_long_nat64(counts, ctx, fmt2):
    """
    Print the nat64 'pass' lines for one interface.

    counts - dict containing the 'in' and 'out' dicts; both must have a
             'pass' entry
    ctx    - context dict; 'col1' gives the first column width
    fmt2   - format string taking (width, label, in count, out count)

    Prints three lines: 'ipv6-to-ipv4', 'ipv4-to-ipv6' and 'pass'.  If
    either direction has no 'detail' dict, prints a single 'pass' line with
    the overall counts instead.
    """
    cat = 'pass'

    dbg.pprint(json.dumps(counts, indent=4, sort_keys=False))

    inbound = counts['in']
    outbound = counts['out']
    width = ctx['col1'] - 4

    #
    # The dataplane should also make an exception for nat64 pass, and return
    # a 'detail' dict, but check anyway
    #
    have_detail = 'detail' in inbound[cat] and 'detail' in outbound[cat]
    if not have_detail:
        print(fmt2 % (width, cat,
                      inbound[cat]['count'],
                      outbound[cat]['count']))
        return

    # (label, detail counters summed to give that line's value)
    rows = (('ipv6-to-ipv4', ('RC_NAT64_6T4',)),
            ('ipv4-to-ipv6', ('RC_NAT64_4T6',)),
            ('pass', ('RC_PASS', 'RC_ENOSTR')))

    for label, rc_names in rows:
        count_in = sum(npf_rc_count(inbound, cat, rc) for rc in rc_names)
        count_out = sum(npf_rc_count(outbound, cat, rc) for rc in rc_names)
        print(fmt2 % (width, label, count_in, count_out))


#
# npf_show_long_intf
#
# counts - dict containing 'in' and 'out' dicts
#
def _npf_rc_hidden(rc, in_detail, out_detail):
    """
    Return True if rc is a hide-when-zero counter that is zero in both the
    'in' and 'out' detail dicts (a missing counter is treated as zero).
    """
    return (rc in hidden_if_zero and
            in_detail.get(rc, 0) == 0 and
            out_detail.get(rc, 0) == 0)


def npf_show_long_intf(rct_name, intf_name, counts, ctx):
    """
    Show npf dataplane stats for one return code type, e.g. 'ip', on one
    interface.

    rct_name  - return code type; 'nat64' gets special 'pass' handling
    intf_name - interface name printed as the heading line
    counts    - dict containing the 'in' and 'out' dicts
    ctx       - context dict; 'col1' is the first column width and 'detail'
                selects the per-return-code lines

    Prints an error and exits with status 2 if counts lacks 'in' or 'out'.
    """

    if 'in' not in counts or 'out' not in counts:
        print("error in npf_show_long_intf", file=sys.stderr)
        sys.exit(2)

    fmt1 = "  %-*s"
    fmt2 = "    %-*s %14u %14u"
    fmt3 = "      %-*s %14u %14u"

    print(fmt1 % (ctx['col1']-2, intf_name))

    for cat in cat_key:
        # Only show categories reported for both directions
        if cat not in counts['in'] or cat not in counts['out']:
            continue

        # We handle nat64 pass differently
        if rct_name == 'nat64' and cat == 'pass':
            npf_show_long_nat64(counts, ctx, fmt2)
            continue

        in_cat = counts['in'][cat]
        out_cat = counts['out'][cat]

        print(fmt2 % (ctx['col1']-4, cat, in_cat['count'], out_cat['count']))

        # Show detailed return code counts?
        if not ctx['detail'] or 'detail' not in in_cat:
            continue

        #
        # The 'in' and 'out' detail dicts are expected to hold the same
        # counters, so the 'in' side drives the listing.  Fall back to zero
        # rather than raising KeyError should the 'out' side lack a counter.
        #
        in_detail = in_cat['detail']
        out_detail = out_cat.get('detail', {})

        visible = [rc for rc in in_detail
                   if not _npf_rc_hidden(rc, in_detail, out_detail)]

        # Do not display detailed counts if there is one or less
        if len(visible) <= 1:
            continue

        for rc in visible:
            # Counters without a description are shown by their raw name
            print(fmt3 % (ctx['col1']-6, rc_desc.get(rc, rc),
                          in_detail[rc], out_detail.get(rc, 0)))


#
# npf_show_long
#
# Show normal and detailed output.
#
def npf_show_long(rct_dict, ctx, first):
    """
    Print the long form (normal or detailed) of the npf dataplane stats.

    rct_dict - dict keyed by return code type, each holding an 'interfaces'
               list
    ctx      - context dict; 'col1' is the first column width
    first    - True if the column banner has not been printed yet
    """

    for rct_name, per_type in rct_dict.items():
        heading = rct_long[rct_name]

        #
        # This function may be called once for each rc type.  The "In" and
        # "Out" banner is only added to the first heading printed per show
        # command.
        #
        if first:
            print("%-*s %14s %14s" % (ctx['col1'], heading, "In", "Out"))
            first = False
        else:
            print("%-*s" % (ctx['col1'], heading))

        for intf in per_type['interfaces']:
            npf_show_long_intf(rct_name, intf['name'], intf, ctx)


#
# npf_show_brief
#
# Show brief output.
#
def npf_show_brief(rct_dict, first):
    """
    Show npf dataplane stats short form. Prints one line for each
    rc-type and interface.

    rct_dict - dict keyed by return code type, each holding an 'interfaces'
               list
    first    - True if the column banner has not been printed yet
    """

    fmt = "  %-16s %-15s %14u %14u"

    for rct_name in rct_dict:
        #
        # This function may be called once for each rc type.  Only display the
        # banner the first time it is called per show command.
        #
        if first:
            print("  %-16s %-15s %14s %14s" %
                  ("Process", "Interface", "In", "Out"))
            first = False

        desc = rct_short[rct_name]

        for intf in rct_dict[rct_name]['interfaces']:
            #
            # Sum all counts in all categories.  An interface entry only
            # carries the directions the dataplane reported, so a missing
            # direction is shown as zero instead of raising KeyError.
            #
            in_tot = npf_rc_count(intf.get('in', {}), None, None)
            out_tot = npf_rc_count(intf.get('out', {}), None, None)

            print(fmt % (desc, intf['name'], in_tot, out_tot))

            #
            # First column (e.g. "IPv4 firewall") only displayed for
            # first interface of that process
            #
            desc = ""


#
# npf_dp_stats_show
#
def npf_dp_stats_show(rct_name, ctx, first):
    """
    Fetch and display the counts for one return code type.

    rct_name - return code type, e.g. 'ip'; also recorded in ctx['rct']
    ctx      - context dict; 'brief' selects the short form output
    first    - True if the column banner has not been printed yet

    Prints nothing if no counts were returned.
    """

    ctx['rct'] = rct_name

    rct_dict = npf_get_dp_counts(rct_name, ctx)

    if rct_dict and ctx['brief']:
        npf_show_brief(rct_dict, first)
    elif rct_dict:
        npf_show_long(rct_dict, ctx, first)


#
# npf_dp_stats_rpc
#
def npf_dp_stats_rpc():
    """
    Handle the rpc request.

    Reads the rpc input as JSON from stdin.  The optional inputs 'feat',
    'interface', 'direction' and 'category' narrow the request.

    Returns a (output, status) tuple:
      - ({}, 1) if the input is not valid JSON
      - ({}, 0) if 'feat', 'direction' or 'category' has an unknown value,
        or if no counts were returned
      - otherwise a dict of the form {'dataplane': {'statistics': {...}}}
        and status 0
    """
    try:
        request = json.load(sys.stdin)
    except ValueError as exc:
        print("Failed to parse input JSON: {}".format(exc), file=sys.stderr)
        return {}, 1

    # The rpc always asks for detailed counts
    ctx = dict(rct=None,
               intf=None,
               dir=None,
               cat=None,
               detail=True,
               brief=False,
               nonzero_only=False,
               rpc=True)

    if request:
        # (rpc input name, ctx key, permitted values or None if unchecked)
        fields = (('feat', 'rct', rct_key),
                  ('interface', 'intf', None),
                  ('direction', 'dir', dir_key),
                  ('category', 'cat', cat_key))

        for rpc_key, ctx_key, allowed in fields:
            if rpc_key not in request:
                continue

            ctx[ctx_key] = request[rpc_key]

            # An unknown value gives an empty, but successful, reply
            if allowed is not None and ctx[ctx_key] not in allowed:
                return {}, 0

    rct_dict = npf_get_dp_counts(ctx['rct'], ctx)
    if not rct_dict:
        return {}, 0

    #
    # The rct_dict has all rc-types at the same level.  For the rpc output,
    # we want the firewall ones in a 'firewall' dict
    #
    stats = {}

    for rct_name in rct_key:
        if rct_name not in rct_dict:
            continue

        if rct_name in rct_key_fw:
            stats.setdefault('firewall', {})[rct_name] = rct_dict[rct_name]
        else:
            stats[rct_name] = rct_dict[rct_name]

    return {'dataplane': {'statistics': stats}}, 0


#
# npf_dp_stats_main
#
def npf_dp_stats_main():
    """
    Main function.

    Command line:
        --show <rc-type> [keywords]    display counts
        --clear <rc-type> [keywords]   clear counts in every dataplane
        --rpc                          read a JSON request from stdin and
                                       print a JSON reply

    Keywords following the options are: 'interface <name>',
    'direction <dir>', 'category <cat>', 'non-zero', 'detail' and 'brief'.
    'detail' and 'brief' are mutually exclusive; whichever comes first wins.

    Exits with status 2 on an option parsing error.  For --rpc, exits with
    the status returned by npf_dp_stats_rpc.
    """

    show_opt = None
    clr_opt = None
    rpc_opt = None

    #
    # Parse options
    #
    try:
        opts, args = getopt.getopt(sys.argv[1:],
                                   "",
                                   ['show=', 'clear=', 'rpc'])

    except getopt.GetoptError as r:
        print(r, file=sys.stderr)
        sys.exit(2)

    #
    # Options and context
    #
    ctx = {'rct': None,
           'intf': None,
           'dir': None,
           'cat': None,
           'detail': False,
           'brief': False,
           'nonzero_only': False,
           'rpc': False,
           'col1': 20}   # Column 1 width for normal show output

    for opt, arg in opts:
        if opt == '--show':
            # show_opt is the rc type name, e.g. 'ip'
            show_opt = arg
        elif opt == '--clear':
            clr_opt = arg
        elif opt == '--rpc':
            rpc_opt = True

    #
    # Parse the remaining keywords.  Use the non-option arguments returned
    # by getopt instead of a fixed offset into sys.argv, so that the
    # '--show=ip ...' spelling does not swallow the first keyword.
    #
    options = args

    # Keywords that take a value, and the ctx key the value is stored under
    valued = {'interface': 'intf', 'direction': 'dir', 'category': 'cat'}

    i = 0
    while i < len(options):
        opt = options[i]
        i += 1

        if opt in valued:
            # A trailing keyword with no value is ignored
            if i < len(options):
                ctx[valued[opt]] = options[i]
                i += 1

        elif opt == "non-zero":
            ctx['nonzero_only'] = True

        elif opt == "detail":
            # Widen column 1 once only, even if 'detail' is repeated
            if not ctx['brief'] and not ctx['detail']:
                ctx['detail'] = True
                ctx['col1'] += 32

        elif opt == "brief":
            if not ctx['detail']:
                ctx['brief'] = True

    #
    # show counts
    #
    if show_opt:
        if show_opt in rct_key:
            npf_dp_stats_show(show_opt, ctx, True)

    #
    # clear counts
    #
    elif clr_opt:
        if clr_opt not in rct_key:
            return

        cmd = "npf-op rc clear counters"
        cmd += npf_dpc_option_string(clr_opt, ctx)

        with vplaned.Controller() as controller:
            for dp in controller.get_dataplanes():
                with dp:
                    dp.string_command(cmd)

    #
    # rpc
    #
    elif rpc_opt:
        info, ret = npf_dp_stats_rpc()
        if info:
            print(json.dumps(info))
        sys.exit(ret)


#
# main
#
if __name__ == '__main__':
    # Only run the command handler when executed as a script, not on import
    npf_dp_stats_main()
