#!/usr/bin/env python3
#
# Copyright (c) 2019-2020, AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#

"""Scripts for CGNAT error, summary, and policy op-mode commands"""

import sys
import getopt
import vplaned
from time import localtime, strftime
from vyatta.npf.npf_addr_group import npf_show_address_group


#
# num2str
#
def num2str(count, approx):
    """Convert a count into a string.

    count  -- number to format; rendered with the '%u' conversion
    approx -- when true the number is prefixed with '~' to mark it as
              approximate

    Returns the formatted string.
    """
    # Build the result directly rather than through a local called 'str',
    # which would shadow the builtin type inside this function.
    prefix = "~" if approx else ""
    return "%s%u" % (prefix, count)


# Yes or No string
def yes_or_no(val):
    """Return "Yes" for a truthy value and "No" for anything else"""
    if val:
        return "Yes"
    return "No"


#
# usage
#
def cgn_usage():
    """Write the command synopsis to stderr"""

    # Doubled braces are literal braces once str.format() has run
    synopsis = "usage: {} --show {{errors | summary | policy [<name>]}}"
    program = sys.argv[0]

    print(synopsis.format(program), file=sys.stderr)


#
# cgn_op_show_summary
#
def cgn_op_show_summary():
    """Print the CGNAT summary counters returned by the dataplane"""

    # Row layouts.  The label column narrows as the indent grows so that the
    # value column stays aligned.
    top_uint = "  %-32s %18u"
    sub_uint = "    %-30s %18u"
    sub_text = "    %-30s %18s"
    leaf_text = "      %-28s %18s"

    def render(stats):
        """Print one 'summary' dictionary"""

        def heading(text):
            print("  %-32s" % (text))

        def exact(label, key, fmt=sub_uint):
            # Counter printed as-is
            print(fmt % (label, stats.get(key)))

        def approx(label, key, fmt=sub_text):
            # Counter printed with the '~' approximate marker
            print(fmt % (label, num2str(stats.get(key), True)))

        table_full = yes_or_no(stats.get('sess_table_full'))

        #
        # Only one dataplane is expected.  Should multiple dataplanes ever
        # be supported then each summary ought to be preceded by a
        # description of the dataplane it came from.
        #
        print("CGNAT Summary")

        heading("Sessions:")
        exact("Active sessions", 'sess_count')
        approx("Sessions created", 'sess_created', leaf_text)
        approx("Sessions destroyed", 'sess_destroyed', leaf_text)
        exact("Active sub-sessions", 'sess2_count')
        approx("Sub-sessions created", 'sess2_created', leaf_text)
        approx("Sub-sessions destroyed", 'sess2_destroyed', leaf_text)
        exact("Maximum table size", 'max_sess')
        print(sub_text % ("Table full", table_full))

        heading("Public address mapping table:")
        exact("Used", 'apm_table_used')

        heading("Subscriber address table:")
        exact("Used", 'subs_table_used')
        exact("Max", 'subs_table_max')

        # Outbound direction
        heading("Out:")
        approx("Translated packets", 'pkts_out')
        approx("           bytes", 'bytes_out')
        exact("Did not match CGNAT policy", 'nopolicy')
        # Shown only when the counter is present and non-zero
        if stats.get('bypass'):
            exact("ALG packets", 'bypass')
        exact("Untranslatable packets", 'etrans')
        exact("Hairpinned packets", 'pkts_hairpinned')

        # Inbound direction
        heading("In:")
        approx("Translated packets", 'pkts_in')
        approx("           bytes", 'bytes_in')
        approx("Unknown source addr or port", 'unk_pkts_in')
        exact("Did not match CGNAT session", 'nosess')
        if 'nopool' in stats:
            exact("Did not match CGNAT pool", 'nopool')

        # Destination address/port hash tables
        if 'sess_ht_created' in stats:
            heading("Session hash tables:")
            exact("Created", 'sess_ht_created')
            exact("Destroyed", 'sess_ht_destroyed')

        # PCP counters are shown only when the dataplane reports them
        if 'pcp_ok' in stats:
            exact("PCP sessions created", 'pcp_ok', top_uint)
        if 'pcp_err' in stats:
            exact("PCP errors", 'pcp_err', top_uint)

        # Remaining failure counters
        exact("Memory allocation failures", 'enomem', top_uint)
        exact("Resource limitation failures", 'enospc', top_uint)
        exact("Thread contention errors", 'ethread', top_uint)
        exact("Packet buffer errors", 'embuf', top_uint)
        if 'icmp_echoreq' in stats:
            exact("ICMP Echo Req for CGNAT addr", 'icmp_echoreq', top_uint)
        print()

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                reply = dp.json_command("cgn-op show summary")
                stats = reply.get('summary')
                # Nothing to show; give up on the whole command
                if not stats:
                    return

                render(stats)


#
# Get error counts
#
def cgn_op_get_errors():
    """Get cgnat error counts.

    Sends "cgn-op show errors" to every dataplane and accumulates the 'in'
    and 'out' error lists found under 'errors' in each reply.

    Returns a tuple of two dictionaries, (in_d, out_d), each keyed by error
    name, e.g. 'PCY_ENOENT'.  Each entry contains a sub-dictionary of
    'count' and 'desc'.  Counts for a name seen more than once are summed.
    Both dictionaries are empty when no reply contained any errors.
    """

    def accumulate(totals, entries):
        """Fold a list of error entries into the totals dictionary"""
        # A reply may lack one of the two direction lists
        for entry in entries or []:
            key = entry.get('name')
            count = entry.get('count')

            if key in totals:
                # Index by the literal 'count' key, not by the count value
                totals[key]['count'] += count
            else:
                totals[key] = {'count': count, 'desc': entry.get('desc')}

    in_d = {}
    out_d = {}

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                cgn_dict = dp.json_command("cgn-op show errors")
                errors = cgn_dict.get('errors')
                if not errors:
                    # Nothing from this dataplane.  Keep whatever has been
                    # gathered so far; the caller always unpacks exactly
                    # two dictionaries.
                    continue

                accumulate(in_d, errors.get('in'))
                accumulate(out_d, errors.get('out'))

    # Returns 2 dictionaries
    return in_d, out_d


#
# Print one line in error/global count output
#
def cgn_op_show_error_one(key, fmt, in_d, out_d):
    """Print one line in the error/global count output.

    key   -- error name to look up, e.g. 'PCY_ENOENT'
    fmt   -- '%' format string taking (description, in count, out count)
    in_d  -- dictionary keyed by error name whose entries hold 'count' and
             'desc'; supplies the "In" column
    out_d -- dictionary of the same layout; supplies the "Out" column

    A name found in neither dictionary prints nothing, so a counter that
    was not reported is simply omitted instead of raising KeyError.  A name
    found in only one dictionary shows a count of 0 for the other.
    """
    in_entry = in_d.get(key, {})
    out_entry = out_d.get(key, {})

    if not in_entry and not out_entry:
        return

    # Prefer the inbound description, as before; fall back to the outbound
    # one when the name was only reported for that direction.
    desc = in_entry.get('desc', out_entry.get('desc'))

    print(fmt % (desc, in_entry.get('count', 0), out_entry.get('count', 0)))


#
# cgn_op_show_errors
#
def cgn_op_show_errors():
    """Print the CGNAT global error counts, grouped by category"""

    # Each section is a heading line plus the error names listed under it,
    # in display order.
    sections = (
        ("  Unable to translate packet:",
         ('PCY_ENOENT', 'SESS_ENOENT', 'POOL_ENOENT',
          'PCY_BYPASS', 'BUF_PROTO', 'BUF_ICMP')),
        ("  Resource limitations:",
         ('MBU_ENOSPC', 'BLK_ENOSPC', 'POOL_ENOSPC', 'SRC_ENOSPC',
          'APM_ENOSPC', 'S1_ENOSPC', 'S2_ENOSPC')),
        ("  Memory allocation failures:",
         ('S1_ENOMEM', 'S2_ENOMEM', 'PB_ENOMEM',
          'APM_ENOMEM', 'SRC_ENOMEM')),
        ("  Thread contention errors:",
         ('S1_EEXIST', 'S2_EEXIST', 'SRC_ENOENT')),
        ("  Packet buffer errors:",
         ('BUF_ENOL3', 'BUF_ENOL4', 'BUF_ENOMEM')),
        ("  PCP errors:",
         ('PCP_EINVAL', 'PCP_ENOSPC')),
        ("  Other:",
         ('ERR_UNKWN',)),
    )

    inbound, outbound = cgn_op_get_errors()
    row_fmt = "    %-54s %12u %12u"

    print("%-58s %12s %12s" % ("CGNAT Global Counts", "In", "Out"))

    for title, names in sections:
        print(title)
        for name in names:
            cgn_op_show_error_one(name, row_fmt, inbound, outbound)

    print()


#
# cgn_op_show_policy_one
#
def cgn_op_show_policy_one(pol):
    """Show one CGN policy.

    pol -- dictionary describing a single policy, i.e. one element of the
           'policies' list in a "cgn-op show policy" reply.

    Prints the policy's configuration, its session and packet counters and
    the list of subscriber maximum session rates.  Returns nothing.
    """
    # Imported locally so that this function does not rely on a change to
    # the file-level import list.
    from time import gmtime

    print("Policy: %s" % (pol.get('name')))

    if "match_group" in pol:
        npf_show_address_group(pol.get('match_group'), "ipv4", None,
                               "Match address-group",
                               2, 22, 28,
                               4, 14, 34)

    print("  %-32s %18s" % ("Interface", pol.get('interface')))
    print("  %-32s %18s" % ("Priority", pol.get('priority')))
    print("  %-32s %18s" % ("Pool", pol.get('pool')))
    print("  %-32s %18s" % ("Log all sessions",
                            yes_or_no(pol.get('log_sess_all'))))

    log_grp = pol.get('log_sess_group')
    if log_grp:
        npf_show_address_group(log_grp, "ipv4", None,
                               "Log select sessions",
                               2, 22, 28,
                               6, 12, 34)

    print("    %-30s %18s" % ("Log session start",
                              yes_or_no(pol.get('log_sess_start'))))
    print("    %-30s %18s" % ("Log session end",
                              yes_or_no(pol.get('log_sess_end'))))
    print("    %-30s %18s" % ("Log session periodically",
                              yes_or_no(pol.get('log_sess_periodic'))))

    sess_crtd = pol.get('sess_created')
    sess_dstrd = pol.get('sess_destroyed')

    # The active count is derived from the created/destroyed counters, so
    # it is shown with the approximate marker like its inputs.
    print("  %-32s %18u" % ("Active subscribers", pol.get('source_count')))
    print("  %-32s %18s" % ("Active sessions",
                            num2str(sess_crtd - sess_dstrd, True)))
    print("    %-30s %18s" % ("Sessions created", num2str(sess_crtd, True)))
    print("    %-30s %18s" % ("Sessions destroyed", num2str(sess_dstrd, True)))

    sess2_crtd = pol.get('sess2_created')
    sess2_dstrd = pol.get('sess2_destroyed')

    # Sub-session counters are only shown once a sub-session has existed
    if sess2_crtd > 0:
        print("  %-32s %18s" % ("Active sub-sessions",
                                num2str(sess2_crtd - sess2_dstrd, True)))
        print("    %-30s %18s" % ("Sub-sessions created",
                                  num2str(sess2_crtd, True)))
        print("    %-30s %18s" % ("Sub-sessions destroyed",
                                  num2str(sess2_dstrd, True)))

    print("  %-32s %18s" % ("Out, packets", num2str(pol.get('out_pkts'), True)))
    print("  %-32s %18s" % ("     bytes", num2str(pol.get('out_bytes'), True)))
    print("  %-32s %18s" % ("In,  packets", num2str(pol.get('in_pkts'), True)))
    print("  %-32s %18s" % ("     bytes", num2str(pol.get('in_bytes'), True)))
    print("  %-32s %18s" % ("     unknown source",
                            num2str(pol.get('unk_pkts_in'), True)))

    print("  %s" % ("Max Session Rates:"))
    print("    %-16s %-8s  %-8s" % ("Subscriber", "Max Rate", "Time"))

    # Subscriber max session rate list.  Always returns 5 entries.
    # Stop when we reach the first 'empty' entry.
    #
    # A reply without the list is treated as having no entries.
    #
    for rate in pol.get('subs_sess_rates') or []:
        if rate.get('max_sess_rate') == 0:
            break

        # NOTE(review): 'time' is scaled by 1/1000000 to get seconds, so it
        # is presumably microseconds since the epoch - confirm against the
        # dataplane.
        max_rate_tm = rate.get('time') / 1000000

        # The output is labelled "+0000", so the broken-down time must be
        # UTC.  localtime() would print wall-clock time with a UTC label on
        # any system whose timezone is not UTC.
        tmp = strftime("%F %H:%M:%S +0000", gmtime(max_rate_tm))

        print("    %-16s %8u  %s" % (rate.get('subscriber'),
                                     rate.get('max_sess_rate'),
                                     tmp))

    print()


#
# cgn_op_show_policy
#
def cgn_op_show_policy(n_opt):
    """Show every CGN policy, or only the named one when n_opt is set"""

    request = "cgn-op show policy"
    if n_opt:
        # Restrict the request to a single policy
        request = "cgn-op show policy %s" % (n_opt,)

    # Gather the policies from all dataplanes before printing any of them
    collected = []

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                reply = dp.json_command(request)
                found = reply.get('policies')
                if found:
                    collected += found

    #
    # No sorting is done here; the dataplane is expected to have returned
    # the policies already ordered by interface and priority.
    #
    for policy in collected:
        cgn_op_show_policy_one(policy)


#
# cgn_op_clear_policy_stats
#
def cgn_op_clear_policy_stats(n_opt):
    """Ask every dataplane to clear the statistics of policy n_opt"""

    # n_opt is substituted as given; no check is made that it is set
    request = "cgn-op clear policy %s statistics" % n_opt

    with vplaned.Controller() as ctrl:
        for plane in ctrl.get_dataplanes():
            with plane:
                plane.string_command(request)


#
# cgn_op_clear_errors
#
def cgn_op_clear_errors():
    """Ask every dataplane to clear its CGNAT error counts"""

    with vplaned.Controller() as ctrl:
        for plane in ctrl.get_dataplanes():
            with plane:
                plane.string_command("cgn-op clear errors")


#
# cgn_op_main
#
def cgn_op_main():
    """Main function.

    Parses the long options in sys.argv and runs the requested actions:

      --show policy [--name <name>]   show all policies, or just <name>
      --show errors                   show the global error counts
      --show summary                  show the summary counters
      --clear policy --stats          clear a policy's statistics; the value
                                      of --name is passed on unchecked
      --clear errors                  clear the error counts

    A --show action, when given, runs before a --clear action.  Any other
    --show or --clear value is ignored.  An option parsing error prints the
    error and the usage text to stderr and exits with status 2.
    """

    s_opt = None
    n_opt = None
    c_opt = None
    stats_opt = False

    #
    # Parse options
    #
    try:
        opts, _ = getopt.getopt(sys.argv[1:],
                                "",
                                ['show=', 'name=',
                                 'clear=', 'stats'])

    except getopt.GetoptError as r:
        print(r, file=sys.stderr)
        cgn_usage()
        sys.exit(2)

    #
    # getopt reports each long option by its full name, so compare for
    # equality.  ('opt in "--show"' is a substring test on the string.)
    #
    for opt, arg in opts:
        if opt == '--show':
            s_opt = arg
        elif opt == '--name':
            n_opt = arg
        elif opt == '--clear':
            c_opt = arg
        elif opt == '--stats':
            stats_opt = True

    # show ...
    if s_opt == 'policy':
        cgn_op_show_policy(n_opt)
    elif s_opt == 'errors':
        cgn_op_show_errors()
    elif s_opt == 'summary':
        cgn_op_show_summary()

    # clear ...
    if c_opt == 'policy' and stats_opt:
        cgn_op_clear_policy_stats(n_opt)
    elif c_opt == 'errors':
        cgn_op_clear_errors()


#
# main
#
# Run the command only when this file is executed as a script, so that
# importing it does not parse sys.argv.
if __name__ == '__main__':
    cgn_op_main()
