#!/usr/bin/env python3
#
# Copyright (c) 2019-2020, AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#

"""Scripts for CGNAT subscriber op-mode commands"""


import getopt
import socket
import struct
import sys
from time import gmtime, localtime, strftime

import vplaned


#
# search_dicts
#
def search_dicts(key, value, list_of_dicts):
    """Return every dictionary in a list whose 'key' entry equals 'value'.

    key           -- dictionary key to look up in each element
    value         -- value the entry must compare equal to
    list_of_dicts -- list of dictionaries to search; may be None or empty

    Returns a new list of the matching dictionaries in their original
    order, or an empty list if nothing matches.  Dictionaries that do not
    contain 'key' are skipped rather than raising KeyError.
    """

    if not list_of_dicts:
        return []

    # Test membership explicitly so that searching for a value of None
    # does not match elements that simply lack the key.
    return [element for element in list_of_dicts
            if key in element and element[key] == value]


#
# num2str
#
def num2str(count, approx):
    """Convert a count into a string.

    count  -- number to format (formatted with '%u')
    approx -- if true, the result is prefixed with '~' to mark the count
              as approximate

    Returns the formatted string.
    """
    # Avoid naming the result 'str', which would shadow the builtin
    prefix = "~" if approx else ""
    return "%s%u" % (prefix, count)


#
# int2ip
#
def int2ip(addr):
    """Return the dotted-quad IPv4 string for an unsigned 32-bit integer."""

    # Pack as a 4-byte unsigned int in network (big-endian) byte order
    packed = struct.pack("!I", addr)
    return socket.inet_ntoa(packed)


#
# cgn_get_subscriber_list
#
def cgn_get_subscriber_list(prefix):
    """Return the active cgnat subscribers as a sorted list of IP strings.

    Every dataplane is asked for its subscriber list, which arrives as
    uints under the 'subscribers' key of the json reply.  The combined
    list has duplicates removed and is sorted numerically before being
    converted to IP address strings.

    prefix -- optional address or prefix/length; when given it is appended
              to the dataplane command so that only matching subscribers
              are returned
    """

    # The command is identical for every dataplane, so build it once
    cmd = "cgn-op list subscribers"
    if prefix:
        cmd = "%s prefix %s" % (cmd, prefix)

    found = []

    with vplaned.Controller() as controller:
        for dataplane in controller.get_dataplanes():
            with dataplane:
                reply = dataplane.json_command(cmd)

                # Strip the outer object; the list may be absent or empty
                found.extend(reply.get('subscribers') or [])

    # De-duplicate, then sort while the addresses are still numbers so
    # that the order is numeric rather than lexical
    unique_sorted = sorted(dict.fromkeys(found))

    return [int2ip(addr) for addr in unique_sorted]


#
# cgn_get_subs
#
def cgn_get_subs(addr):
    """Fetch detailed subscriber entries from every dataplane.

    addr -- optional subscriber address or prefix; when given the query
            is restricted to it

    Returns the concatenation of the 'subscribers' lists from each
    dataplane's json reply (possibly empty).
    """

    parts = ["cgn-op show subscriber"]
    if addr:
        parts.append("address %s" % (addr))
    parts.append("detail")
    cmd = " ".join(parts)

    entries = []

    with vplaned.Controller() as controller:
        for dataplane in controller.get_dataplanes():
            with dataplane:
                reply = dataplane.json_command(cmd)

                # A dataplane may return no subscribers at all
                entries.extend(reply.get('subscribers') or [])

    return entries


#
# cgn_subs_show_hdr
#
def cgn_subs_show_hdr():
    """Print the column banner for the brief subscriber listing."""

    columns = ("Subscriber", "Paired Addr", "Sessions", "Blks", "Ports",
               "Map Reqs", "Map Fails", "Duration")

    # Column widths mirror the per-subscriber row format
    fmt = "%15s  %15s  %18s  %4s  %5s  %10s  %9s  %8s"

    print(fmt % columns)


#
# secs2time
#
def secs2time(secs):
    """Convert a number of seconds into a time string and its units.

    Returns a (time, units) tuple.  Durations shorter than one day are
    formatted as 'hh:mm:ss'; durations of a day or more are formatted as
    'dd:hh:mm'.  'units' is the literal string naming the format used.
    """
    one_day = 24 * 60 * 60

    total_mins, rem_secs = divmod(secs, 60)
    total_hrs, rem_mins = divmod(total_mins, 60)

    if secs >= one_day:
        # Seconds are dropped once the duration reaches a day
        days, rem_hrs = divmod(total_hrs, 24)
        return "%02d:%02d:%02d" % (days, rem_hrs, rem_mins), "dd:hh:mm"

    return "%02d:%02d:%02d" % (total_hrs, rem_mins, rem_secs), "hh:mm:ss"


#
# cgn_subs_show_one
#
def cgn_subs_show_one(subs):
    """Print the one-line summary row for a single cgnat subscriber.

    subs -- subscriber dictionary as returned by the dataplane
    """

    # 'duration' is divided by 1000000 to give the seconds secs2time expects
    time, units = secs2time(subs.get('duration') // 1000000)

    # Active counts are created minus destroyed
    active = subs.get('sess_crtd') - subs.get('sess_dstrd')
    active2 = subs.get('sess2_crtd') - subs.get('sess2_dstrd')

    # Only show the sub-session count when there are some
    sess_str = "%u/%u" % (active, active2) if active2 > 0 else "%u" % (active)

    ports_used = (subs.get('tcp_ports_used') + subs.get('udp_ports_used') +
                  subs.get('other_ports_used'))

    row = (subs.get('address'),
           subs.get('paired_addr'),
           sess_str,
           subs.get('block_count'),
           ports_used,
           subs.get('map_reqs'),
           subs.get('map_fails'),
           time, units)

    print("%15s  %15s  %18s  %4s  %5s  %10s  %9s  %8s (%s)" % row)


#
# cgn_subs_show_block_detail
#
def cgn_subs_show_block_detail(subs):
    """Show the port blocks of one subscriber.

    subs -- subscriber dictionary; its 'port_blocks' entry, if present,
            is a list of port-block dictionaries

    Prints a column banner followed by one line per port block, ordered
    by public address and then by start port.  Only the banner is printed
    when the subscriber has no port blocks.
    """

    print("  %5s   %15s   %11s   %5s   %5s %s" %
          ("Block", "Public Address", "Port Range", "Total",
           "Used", "(tcp/udp/other)"))

    block_list = subs.get('port_blocks')

    # block list might be missing or empty.  If so, just display banner.
    if not block_list:
        return

    # NOTE(review): 'pub_addr' is compared as-is; if it is a dotted-quad
    # string the address order is lexical rather than numeric - confirm.
    ordered = sorted(block_list,
                     key=lambda d: (d['pub_addr'], d['port_start']))

    for block in ordered:
        port_start = block.get('port_start')
        port_end = block.get('port_end')
        port_range = "%u-%u" % (port_start, port_end)

        # Port range is inclusive at both ends
        nports = port_end - port_start + 1

        tcp_uports = block.get('tcp_ports_used')
        udp_uports = block.get('udp_ports_used')
        oth_uports = block.get('other_ports_used')
        tot_uports = tcp_uports + udp_uports + oth_uports

        print("  %5u   %15s   %11s   %5s   "
              "%5u (%u/%u/%u)" % (block.get('block'),
                                  block.get('pub_addr'),
                                  port_range,
                                  nports,
                                  tot_uports,
                                  tcp_uports, udp_uports, oth_uports))


#
# cgn_subs_show_detail
#
def _cgn_max_rate_time(rate, epoch_usecs):
    """Return the UTC time string at which a maximum rate was recorded.

    rate        -- the maximum session rate
    epoch_usecs -- timestamp of that maximum; divided by 1000000 to
                   obtain seconds since the epoch

    Returns "-" when rate is not greater than zero.
    """
    if rate > 0:
        # The string is labelled "+0000", so it must be formatted from
        # UTC (gmtime) rather than from the system's local time zone.
        return strftime("%Y-%m-%d %H:%M:%S +0000",
                        gmtime(epoch_usecs / 1000000))
    return "-"


def cgn_subs_show_detail(subs):
    """Show the detailed, multi-line view of one cgnat subscriber.

    subs -- subscriber dictionary as returned by the dataplane for a
            'detail' query

    Prints the subscriber's duration, session counts and rates, mapping
    request counts, packet/byte counts and port usage, followed by its
    port-block list and a blank line.
    """

    duration = subs.get('duration')
    # 'duration' is divided by 1000000 to give the seconds secs2time expects
    time, units = secs2time(duration // 1000000)
    map_reqs = subs.get('map_reqs')
    map_fails = subs.get('map_fails')
    map_ok = map_reqs - map_fails

    # Bit 0 of 'flags' is reported as the "Expired" state
    flags = subs.get('flags')
    exprd = "Yes" if (flags & 0x01) != 0 else "No"

    tcp_uports = subs.get('tcp_ports_used')
    udp_uports = subs.get('udp_ports_used')
    oth_uports = subs.get('other_ports_used')
    tot_nports = subs.get('port_count')

    sess_crtd = subs.get('sess_crtd')
    sess_dstrd = subs.get('sess_dstrd')
    rate_20s = subs.get('sess_rate_20s')
    rate_1m = subs.get('sess_rate_1m')
    rate_5m = subs.get('sess_rate_5m')
    max_rate = subs.get('sess_rate_max')
    max_rate_1m = subs.get('sess_rate_1m_max')

    print("Subscriber: %s" % (subs.get('address')))
    print("  %-32s %18s" % ("Paired address", subs.get('paired_addr')))
    print("  %-32s %18s (%s)" % ("Duration", time, units))
    print("  %-32s %18s" % ("Expired", exprd))

    print("  %-32s %18u" % ("Active sessions", sess_crtd - sess_dstrd))
    print("    %-30s %18u" % ("Sessions created", sess_crtd))
    print("    %-30s %18u" % ("Sessions destroyed", sess_dstrd))

    # Times at which the maximum rates were seen, or "-" if never
    tm_20s = _cgn_max_rate_time(max_rate, subs.get('sess_rate_max_tm'))
    tm_1m = _cgn_max_rate_time(max_rate_1m,
                               subs.get('sess_rate_1m_max_tm'))

    print("  %-19s %15s %15s" % ("Session rate:", "Current", "Maximum"))
    print("    %-17s %15u %15u (%s)" % ("20 secs", rate_20s, max_rate, tm_20s))
    print("    %-17s %15u %15u (%s)" % ("1 minute", rate_1m, max_rate_1m,
                                        tm_1m))
    print("    %-17s %15u" % ("5 minutes", rate_5m))

    sess2_crtd = subs.get('sess2_crtd')
    sess2_dstrd = subs.get('sess2_dstrd')

    print("  %-32s %18u" % ("Active sub-sessions",
                            sess2_crtd - sess2_dstrd))
    print("    %-30s %18u" % ("Sub-sessions created", sess2_crtd))
    print("    %-30s %18u" % ("Sub-sessions destroyed", sess2_dstrd))

    print("  %-32s %18u" % ("Mapping requests", map_reqs))
    print("    %-30s %18u" % ("Ok", map_ok))
    print("    %-30s %18u" % ("Failed", map_fails))

    # Packet and byte counts are displayed as approximate ('~' prefix)
    print("  %-32s %18s" % ("Out, packets",
                            num2str(subs.get('out_pkts'), True)))
    print("  %-32s %18s" % ("     bytes",
                            num2str(subs.get('out_bytes'), True)))
    print("  %-32s %18s" % ("In,  packets",
                            num2str(subs.get('in_pkts'), True)))
    print("  %-32s %18s" % ("     bytes",
                            num2str(subs.get('in_bytes'), True)))
    print("  %-32s %18s" % ("     unknown source",
                            num2str(subs.get('unk_pkts_in'), True)))

    print("  %-32s %18u" % ("Port blocks", subs.get('block_count')))

    # Each protocol's usage is shown against the subscriber's 'port_count'
    tcp_str = "%u/%u" % (tcp_uports, tot_nports)
    udp_str = "%u/%u" % (udp_uports, tot_nports)
    oth_str = "%u/%u" % (oth_uports, tot_nports)

    print("  %-25s   %-23s" % ("Ports in-use:", "Active Block number:"))
    print("    %-11s %11s     %21u" % ("TCP", tcp_str,
                                       subs.get('tcp_active_block')))
    print("    %-11s %11s     %21u" % ("UDP", udp_str,
                                       subs.get('udp_active_block')))
    print("    %-11s %11s     %21u" % ("Other", oth_str,
                                       subs.get('other_active_block')))

    # Show port-block list
    cgn_subs_show_block_detail(subs)
    print()


#
# cgn_op_show_subs
#
def cgn_op_show_subs(sa_opt, d_opt):
    """Show cgnat subscribers in either brief or detailed form.

    sa_opt -- optional subscriber address or prefix to restrict the output
    d_opt  -- true for the detailed per-subscriber view, false for the
              one-line-per-subscriber table
    """

    # Fetch the subscriber json for each address, in sorted address order
    subs_list = []
    for addr in cgn_get_subscriber_list(sa_opt):
        subs_list.extend(cgn_get_subs(addr))

    if d_opt:
        for subs in subs_list:
            cgn_subs_show_detail(subs)
        return

    # Brief output: banner first, then one row per subscriber
    cgn_subs_show_hdr()
    for subs in subs_list:
        cgn_subs_show_one(subs)


#
# cgn_op_clear_subs_stats
#
def cgn_op_clear_subs_stats(sa_opt):
    """Send the clear-subscriber-statistics command to every dataplane.

    sa_opt -- subscriber address substituted into the command
    """

    # NOTE(review): a None sa_opt is formatted as the text "None" -
    # confirm callers always supply --subs-addr.
    clear_cmd = "cgn-op clear subscriber %s statistics" % (sa_opt)

    with vplaned.Controller() as controller:
        for dataplane in controller.get_dataplanes():
            with dataplane:
                dataplane.string_command(clear_cmd)


#
# cgn_op_update_subs_stats
#
def cgn_op_update_subs_stats(sa_opt):
    """Send the update-subscriber-statistics command to every dataplane.

    sa_opt -- subscriber address substituted into the command
    """

    # NOTE(review): a None sa_opt is formatted as the text "None" -
    # confirm callers always supply --subs-addr.
    update_cmd = "cgn-op update subscriber %s statistics" % (sa_opt)

    with vplaned.Controller() as controller:
        for dataplane in controller.get_dataplanes():
            with dataplane:
                dataplane.string_command(update_cmd)


#
# cgn_usage
#
def cgn_usage():
    """Show command help.

    Writes a usage summary to stderr listing each action handled by
    cgn_op_main() together with the options it takes.
    """

    # The clear and update commands embed the subscriber address in the
    # dataplane command, so it is shown as required for those actions.
    print("usage: {0} --show [--detail] [--subs-addr=<addr>]\n"
          "       {0} --clear --stats --subs-addr=<addr>\n"
          "       {0} --update --stats --subs-addr=<addr>"
          .format(sys.argv[0]),
          file=sys.stderr)


#
# cgn_op_main
#
def cgn_op_main():
    """Parse the command line and run the requested subscriber command.

    Recognised long options:
      --show              show subscribers
      --detail            with --show, use the detailed view
      --subs-addr=<addr>  subscriber address or prefix to act on
      --clear --stats     clear subscriber statistics
      --update --stats    update subscriber statistics

    An unrecognised option prints the error and usage to stderr and exits
    with status 2.
    """

    s_opt = False
    d_opt = False
    sa_opt = None
    c_opt = False
    u_opt = False
    stats_opt = False

    #
    # Parse options
    #
    try:
        opts, _ = getopt.getopt(sys.argv[1:],
                                "",
                                ['show', 'detail',
                                 'subs-addr=',
                                 'clear', 'stats', 'update'])

    except getopt.GetoptError as r:
        print(r, file=sys.stderr)
        cgn_usage()
        sys.exit(2)

    # Compare for equality; 'in' on a string is a substring test and would
    # also match any option name that is a fragment of the literal.
    for opt, arg in opts:
        if opt == '--show':
            s_opt = True
        elif opt == '--detail':
            d_opt = True
        elif opt == '--subs-addr':
            sa_opt = arg
        elif opt == '--clear':
            c_opt = True
        elif opt == '--update':
            u_opt = True
        elif opt == '--stats':
            stats_opt = True

    # show ...
    if s_opt:
        cgn_op_show_subs(sa_opt, d_opt)

    # clear ...
    if c_opt and stats_opt:
        cgn_op_clear_subs_stats(sa_opt)

    # update ...
    if u_opt and stats_opt:
        cgn_op_update_subs_stats(sa_opt)


#
# main
#
# Run the command only when executed as a script, not when imported
if __name__ == '__main__':
    cgn_op_main()
