#!/usr/bin/env python3
#
# Copyright (c) 2019-2020, AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#

"""Scripts for CGNAT session op-mode commands

Shows all sessions in table format, ordered by subscriber address and port.

It does this by first fetching a list of IP addresses in uint format for
active subscribers and sorting that list numerically.

It then fetches the sessions for each subscriber in the ordered list in
turn.  It sorts (or filters) each sub-list by subscriber port, before
displaying it and moving on to the next subscriber.

An 'unordered' option allows the session table to be fetched and displayed
in batches of 1000.  Each batch resumes where the previous batch finished
such that the session table is only iterated over once.

"""

import sys
import getopt
import socket
import vplaned
import operator
import struct
from datetime import datetime


# Session state abbreviations.  This legend is printed above the brief
# session table and explains the two-letter codes shown in its State column.
CGN_SESS_STATE_HDR = "State codes: CL - CLOSED, OP - OPENING, " \
    "ES - ESTABLISHED, TR - TRANSITORY, CG - CLOSING\n"


#
# num2str
#
def num2str(count, approx):
    """Convert a count into a decimal string.

    count  -- integer to format
    approx -- any truthy value marks the count as approximate, in which
              case the returned string is prefixed with a tilde ('~')
    """
    # Use a local name that does not shadow the 'str' builtin
    text = "%u" % (count)
    if approx:
        text = "~" + text
    return text


#
# int2ip
#
def int2ip(addr):
    """Return the dotted-quad string for an unsigned 32-bit address."""

    # Pack in network byte order, then let the socket library format it
    packed = struct.pack("!I", addr)
    return socket.inet_ntoa(packed)


#
# class ProtocolTable
#
class ProtocolTable:
    """Protocol number to name lookup table built from /etc/protocols.

    Index the table with a protocol number to get the protocol name.
    Numbers that have no entry in the file yield "Unassigned".
    """

    def __init__(self, path='/etc/protocols'):
        """Parse the protocols file at 'path'.

        Each entry is "<name> <number> [aliases ...]", optionally followed
        by a '#' comment.  Aliases are optional, so a name and a number are
        all that an entry needs.  Raises OSError if the file cannot be
        opened.
        """
        self._protocol_names = {}
        with open(path, 'r') as f:
            for line in f:
                # Discard any trailing comment
                proto_data, _, _ = line.partition('#')
                fields = proto_data.split()
                if len(fields) < 2:
                    # Blank line, comment-only line or incomplete entry
                    continue
                try:
                    number = int(fields[1])
                except ValueError:
                    # Malformed entry; skip it rather than abort the parse
                    continue
                self._protocol_names[number] = fields[0]

    def __getitem__(self, protocol_number):
        """Return the name for protocol_number, or "Unassigned"."""
        return self._protocol_names.get(protocol_number, "Unassigned")


# Global proto_table.  Starts out as an empty (falsy) list and is replaced
# by a ProtocolTable instance the first time npf_num2proto needs to look up
# a protocol that is not one of the common ones.
proto_table = []


#
# npf_num2proto
#
def npf_num2proto(pnum):
    """Return the protocol name for a protocol number."""

    # The common protocols are answered without reading /etc/protocols.
    # 58 uses the short form of icmp-ipv6.
    common = {6: "tcp", 17: "udp", 1: "icmp", 58: "icmpv6"}
    if pnum in common:
        return common[pnum]

    # Build the protocol table on first use, then keep reusing it
    global proto_table
    if not proto_table:
        proto_table = ProtocolTable()

    name = proto_table[pnum]

    # Fall back to the number itself if the table gave no name.
    # NOTE(review): ProtocolTable currently answers "Unassigned" for unknown
    # numbers, so this fallback does not appear to be reachable.
    return name if name else str(pnum)


#
# npf_proto2num
#
# Protocol name or number to number.  (This is similar to the C function
# getprotoent used by perl in FWHelper.pm)
#
def npf_proto2num(name):
    """Convert a protocol name, or number string, to a protocol number.

    'ip' and 'ipv6' are ignored and yield 0, as does any name that cannot
    be resolved.  A string of digits is returned as an integer.
    """
    # Ignore 'ip' and 'ipv6'
    if name in ('ip', 'ipv6'):
        return 0

    # Already a number?
    if name.isdigit():
        return int(name)

    # Some likely protocol values
    proto2num = {
        'icmp': 1,
        'tcp': 6,
        'udp': 17,
        'ipv6-icmp': 58,
        # Short form, as produced by npf_num2proto
        'icmpv6': 58,
    }

    # Common protocol?
    if name in proto2num:
        return proto2num[name]

    # else uncommon protocol; ask the system protocol database
    try:
        return socket.getprotobyname(name)
    except OSError:
        return 0


#
# secs2time
#
def secs2time(secs):
    """Format a duration given in seconds.

    Returns a (string, units) tuple.  Durations shorter than one day are
    formatted as hh:mm:ss, durations of one day or more as dd:hh:mm.
    """
    secs_per_day = 24 * 60 * 60

    total_mins, rem_secs = divmod(secs, 60)
    total_hrs, rem_mins = divmod(total_mins, 60)

    if secs < secs_per_day:
        return "%02d:%02d:%02d" % (total_hrs, rem_mins, rem_secs), "hh:mm:ss"

    days, rem_hrs = divmod(total_hrs, 24)
    return "%02d:%02d:%02d" % (days, rem_hrs, rem_mins), "dd:hh:mm"


#
# cgn_sess_state_str
#
def cgn_sess_state_str(outer, inner, long_or_short):
    """Session state to string.

    Use the inner 2-tuple session if it exists, else use the outer 3-tuple
    session.

    outer         -- outer session dict; 'proto' and 'state' are read
    inner         -- inner session dict ('state' and 'hist' are read), or
                     None
    long_or_short -- True for the long state name, False for the
                     two-letter code

    Returns "?" if the session has no state, and uses "?" as the name of a
    state number that is not in the tables.  For a TCP inner session the
    state number and history are appended as "[state/hist]", with hist in
    hex.
    """

    state2short = ["NO", "CL", "OP", "ES", "TR", "CG", "CG", "CG"]
    state2long = ["None", "Closed", "Opening", "Established", "Transitory",
                  "Client-FIN-Rcv", "Server-FIN-Rcv", "CS-FIN-Rcv"]

    session = inner if inner else outer
    state = session.get('state')
    if state is None:
        return "?"

    names = state2long if long_or_short else state2short

    # Do not let a state number outside the table abort the display (or,
    # if negative, silently index from the end of the table).
    known = isinstance(state, int) and 0 <= state < len(names)
    state_str = names[state] if known else "?"

    if outer.get('proto') == 6 and inner:
        hist = inner.get('hist')
        # Only append the numeric detail when both values can be formatted
        if isinstance(state, int) and hist is not None:
            sep = " " if long_or_short else ""
            state_str = "%s%s[%u/%02X]" % (state_str, sep, state, hist)

    return state_str


#
# cgn_get_subscriber_list
#
def cgn_get_subscriber_list(prefix):
    """Get a sorted list of active cgnat subscribers.

    Fetches a list of uints from the dataplane, sorts them, then converts to
    IP address strings.  An optional address or prefix/length may be
    specified, in which case the dataplane will only return subscribers
    matching that value.

    A dataplane that returns nothing, or an error, contributes no
    subscribers.
    """

    cmd = "cgn-op list subscribers"
    if prefix:
        cmd = "%s prefix %s" % (cmd, prefix)

    # Collect into a set so that duplicates are dropped
    subs = set()

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                cgn_dict = dp.json_command(cmd)

                # Same response checks as the session fetch uses
                if not cgn_dict or "__error" in cgn_dict:
                    continue

                # Remove outer object
                subs.update(cgn_dict.get('subscribers') or [])

    # Sort while still in uint number format, then convert each address to
    # IP address string format
    return [int2ip(addr) for addr in sorted(subs)]


#
# cgn_get_sessions_subs
#
def cgn_get_sessions_subs(addr, fltrs):
    """Return the sessions of one subscriber, sorted by subscriber port.

    To keep things simple every session of the subscriber is fetched in one
    go instead of in batches of 1000.  Responses that are empty or that
    carry an "__error" entry are skipped.
    """

    cmd = "cgn-op show session subs-addr %s" % (addr)
    if fltrs:
        cmd = "%s %s" % (cmd, fltrs)

    collected = []

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                reply = dp.json_command(cmd)
                if not reply or "__error" in reply:
                    continue

                batch = reply.get('sessions')
                if not batch or "__error" in batch:
                    continue

                collected.extend(batch)

    # Order by subscriber port
    return sorted(collected, key=operator.itemgetter('subs_port'))


#
# col1_width
#
# Column 1 is the two-part session ID column and can vary in length quite a
# bit.  As such, we make it dynamic.  It starts at 6 chars wide and grows as
# the session IDs grow.
#
def col1_width(tmp):
    """Return the width of column 1, widening it to 'tmp' if necessary.

    The width is remembered between calls as an attribute of this function.
    It is never less than 6 and never shrinks.
    """

    widest = max(tmp, 6, getattr(col1_width, 'cnt', 0))
    col1_width.cnt = widest
    return widest


#
# cgn_sess_show_hdr
#
def cgn_sess_show_hdr():
    """Print the state legend and the column headings of the session table"""

    print(CGN_SESS_STATE_HDR)

    id_width = col1_width(6)

    # First heading row holds the group titles.  Its leading blank field
    # spans the ID, Proto and State columns.
    group_titles = (id_width + 6 + 9, "",
                    "Subscriber", "Public", "", "Destination")
    print("%-*s %21s %21s %10s %21s" % group_titles)

    # Second heading row names each column
    column_titles = (id_width, "ID", "Proto", "State",
                     "Address", "Port", "Address", "Port", "Intf",
                     "Address", "Port", "Timeout", "PktOut", "PktIn")
    print("%-*s %-5s %-8s %15s %5s %15s %5s %10s %15s %5s %7s %6s %6s" %
          column_titles)


#
# cgn_count_str
#
def cgn_count_str(c):
    """Number to string.  Adds a K, M or G suffix to large counts.

    Counts below 100000 are shown in full.  Larger counts are scaled down
    (truncating, never rounding up) to at most four significant
    characters plus the suffix, e.g. " 100K", " 1.2M", "  10G".

    Returns "-" if no count is available (c is None).
    """

    if c is None:
        return "-"

    # Integer (floor) division is used throughout so that the count stays
    # an exact integer; true division would turn it into a float and
    # mis-round very large counters.

    # Less than 100K?
    if c < 100000:
        return "%4u" % (c)

    # Change to units of 1K
    c //= 1000

    # Less than 1M?
    if c < 1000:
        return "%4uK" % (c)

    # Change to units of 100K
    c //= 100

    # Less than 10M?
    if c < 100:
        return "%2u.%1uM" % (c // 10, c % 10)

    # Change to units of 1M
    c //= 10

    # Less than 1G?
    if c < 1000:
        return "%4uM" % (c)

    # Change to units of 100M
    c //= 100

    # Less than 10G?
    if c < 100:
        return "%2u.%1uG" % (c // 10, c % 10)

    # Change to units of 1G
    c //= 10

    return "%4uG" % (c)


#
# cgn_sess_show_one
#
def cgn_sess_show_one(outer, inner, show_hdr):
    """Print one row of the brief session table.

    outer    -- outer 3-tuple session dict
    inner    -- inner 2-tuple session dict, or None if there is none
    show_hdr -- if true, print the table header before the row
    """

    state = cgn_sess_state_str(outer, inner, False)

    if inner:
        # A 2-tuple inner session exists, so the destination, timeout and
        # packet counts are taken from it
        counters = inner
        sid = "%u.%u" % (outer.get('id'), inner.get('id'))
        dst_addr = "%s" % (inner.get('dst_addr'))
        dst_port = "%s" % (inner.get('dst_port'))
    else:
        counters = outer
        sid = "%u.0" % (outer.get('id'))
        dst_addr = "-"
        init_port = outer.get('init_dst_port')
        dst_port = "%u" % (init_port) if init_port else "-"

    etime = counters.get('cur_to')
    if etime is None:
        etime = "-"
    pkts_out = counters.get('out_pkts')
    pkts_in = counters.get('in_pkts')

    id_width = col1_width(len(sid))

    # The header is held back until the width of column 1 is known
    if show_hdr:
        cgn_sess_show_hdr()

    row = (id_width, sid, npf_num2proto(outer.get('proto')), state,
           outer.get('subs_addr'), outer.get('subs_port'),
           outer.get('pub_addr'), outer.get('pub_port'),
           outer.get('intf'),
           dst_addr, dst_port, etime,
           cgn_count_str(pkts_out), cgn_count_str(pkts_in))

    print("%-*s %-5s %-8s %15s %5s %15s %5s %10s %15s %5s %7s %6s %6s" % row)


#
# cgn_sess_show_outer_brief
#
def cgn_sess_show_outer_brief(outer, show_hdr):
    """Print the brief table rows for one 3-tuple outer session.

    One row is printed per nested inner session, ordered by destination
    address and port.  If there are no inner sessions then a single row is
    printed for the outer session itself.
    """

    inner_list = []
    dests = outer.get('destinations')

    if dests:
        # NOTE(review): the addresses are compared as given.  If they are
        # dotted-quad strings the order is lexical, not numeric - confirm
        # that this is what is wanted.
        inner_list = sorted(dests.get('sessions'),
                            key=operator.itemgetter('dst_addr', 'dst_port'))

    if not inner_list:
        cgn_sess_show_one(outer, None, show_hdr)
        return

    for inner in inner_list:
        cgn_sess_show_one(outer, inner, show_hdr)
        # Only the first row may be preceded by the header
        show_hdr = False


#
# cgn_sess_show_one_detail
#
def cgn_sess_show_one_detail(outer, inner):
    """Print the multi-line detailed view of one session.

    outer -- outer 3-tuple session dict
    inner -- inner 2-tuple session dict, or None.  When given, the
             timeouts, times and counters are taken from it instead of
             from the outer session.
    """

    state = cgn_sess_state_str(outer, inner, True)
    proto = npf_num2proto(outer.get('proto'))

    if inner:
        # A 2-tuple inner session exists
        src = inner
        sid = "%u.%u" % (outer.get('id'), inner.get('id'))
        dst_info = "%s/%s" % (inner.get('dst_addr'), inner.get('dst_port'))
        inner_exprd = "Yes" if inner.get('exprd') else "No"
    else:
        src = outer
        sid = "%u.0" % (outer.get('id'))
        dst_info = "any"
        inner_exprd = "-"

    outer_exprd = "Yes" if outer.get('exprd') else "No"
    exprd = "%s/%s" % (outer_exprd, inner_exprd)

    cur_timeout = src.get('cur_to')
    max_timeout = src.get('max_to')
    pkts_out = src.get('out_pkts')
    pkts_in = src.get('in_pkts')
    bytes_out = src.get('out_bytes')
    bytes_in = src.get('in_bytes')
    unk_pkts_in = outer.get('unk_pkts_in')

    # start_time is in microseconds; the duration is scaled down to
    # seconds in the same way
    time, units = secs2time(src.get('duration') // 1000000)
    started = datetime.fromtimestamp(src.get('start_time') // 1000000)
    start_time = started.strftime('%Y-%m-%d %H:%M:%S')

    subs_addr = outer.get('subs_addr')
    subs_port = outer.get('subs_port')
    pub_addr = outer.get('pub_addr')
    pub_port = outer.get('pub_port')

    print("Session ID: %s, State: %s, Expired: %s" % (sid, state, exprd))
    print("  Interface:         %20s" % (outer.get('intf')))
    print("  Policy name:       %20s" % (outer.get('policy')))
    print("  NAT pool:          %20s" % (outer.get('pool')))
    if not (cur_timeout is None or max_timeout is None):
        print("  Timeout:           %20u (max %u)" % (cur_timeout, max_timeout))
    print("  Duration:          %20s (%s)" % (time, units))
    print("  Start time:        %20s" % (start_time))
    if inner and proto == "tcp":
        print("  External RTT: %u,   Internal RTT: %u" %
              (inner.get('rtt_ext'), inner.get('rtt_int')))
    init_dst_port = outer.get('init_dst_port')
    if init_dst_port:
        print("  Initial dest port: %20u" % (init_dst_port))
    print("  Mapping:   %s/%s | %s/%s" %
          (subs_addr, subs_port, pub_addr, pub_port))

    #
    # With 2-tuple sessions the packet and byte counts lag behind, which is
    # flagged to the user with a leading tilde, '~'.
    #
    print("  Out:       %s/%s --> %s, Proto: %s" %
          (subs_addr, subs_port, dst_info, proto))
    print("    packets: %s, bytes: %s" %
          (num2str(pkts_out, inner), num2str(bytes_out, inner)))

    print("  In:        %s --> %s/%s, Proto: %s" %
          (dst_info, pub_addr, pub_port, proto))
    print("    packets: %s, bytes: %s" %
          (num2str(pkts_in, inner), num2str(bytes_in, inner)))

    # The unknown source count is only shown for the 3-tuple session
    if not inner:
        print("    unknown source: %u" % (unk_pkts_in))

    print()


def cgn_sess_show_outer_detail(outer, sort_key):
    """Print the detailed view of one 3-tuple outer session.

    Each nested inner session is shown in detail.  If sort_key is given the
    inner sessions are ordered by that key and then by destination port,
    otherwise they are shown in the order received.  With no inner sessions
    the outer session itself is shown.
    """

    inner_list = []
    dests = outer.get('destinations')

    if dests:
        inner_list = dests.get('sessions')
        if sort_key:
            inner_list = sorted(
                inner_list,
                key=operator.itemgetter(sort_key, 'dst_port'))

    if not inner_list:
        cgn_sess_show_one_detail(outer, None)
        return

    for inner in inner_list:
        cgn_sess_show_one_detail(outer, inner)


#
# cgn_op_show_sess
#
def cgn_op_show_sess(fltrs, count_opt, d_opt, sa_opt):
    """Fetch sessions from the dataplane in batches and display them.

    Each command asks for count_opt sessions, or 1000 if no count was
    given.  Nested sessions are never split across commands, so one
    command may return many more sessions than were asked for.

    fltrs     -- filter string appended to the command, may be empty
    count_opt -- number of sessions wanted, or None
    d_opt     -- true for the detailed view, false for the brief table
    sa_opt    -- subscriber address to filter on, or None
    """

    # Ask for at least this many sessions per command
    count = count_opt if count_opt else 1000

    # Here the subscriber address is simply one more filter
    if sa_opt:
        subs_fltr = "subs-addr %s" % (sa_opt)
        fltrs = "%s %s" % (fltrs, subs_fltr) if fltrs else subs_fltr

    base_cmd = "cgn-op show session count %u" % (count)
    if fltrs:
        base_cmd = "%s %s" % (base_cmd, fltrs)

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                cmd = base_cmd

                while True:
                    reply = dp.json_command(cmd)
                    if not reply:
                        break

                    if "__error" in reply:
                        print(reply["__error"])
                        return

                    batch = reply.get('sessions')

                    # Stop once no more sessions come back
                    if not batch:
                        break

                    # Every batch of the brief table starts with a header
                    show_hdr = True

                    for sess in batch:
                        if "__error" in sess:
                            print(sess["__error"])
                            continue
                        if d_opt:
                            cgn_sess_show_outer_detail(sess, None)
                            continue
                        cgn_sess_show_outer_brief(sess, show_hdr)
                        show_hdr = False

                    # A user-supplied count means one batch is all that is
                    # wanted.  A short batch means the table is exhausted.
                    if count_opt or len(batch) < count:
                        break

                    # Resume from the last session of this batch
                    last = batch[-1]
                    cmd = ("%s tgt-addr %s tgt-port %u tgt-proto %u "
                           "tgt-intf %s") % (base_cmd,
                                             last.get('subs_addr'),
                                             last.get('subs_port'),
                                             last.get('proto'),
                                             last.get('intf'))


#
# cgn_op_show_sess_ordered
#
def cgn_op_show_sess_ordered(fltrs, count_opt, d_opt, sa_opt):
    """Display sessions ordered by subscriber address and port.

    fltrs     -- filter string passed on with each per-subscriber fetch
    count_opt -- stop after this many outer sessions, or None for all
    d_opt     -- true for the detailed view, false for the brief table
    sa_opt    -- address or prefix used to limit the subscriber list
    """

    shown = 0
    show_hdr = True

    # Walk the subscribers in sorted address order ...
    for subs_addr in cgn_get_subscriber_list(sa_opt):

        # ... and each subscriber's sessions in subscriber port order
        for sess in cgn_get_sessions_subs(subs_addr, fltrs):
            if d_opt:
                cgn_sess_show_outer_detail(sess, None)
            else:
                cgn_sess_show_outer_brief(sess, show_hdr)
                show_hdr = False

            # With a count option the user only wants that many sessions
            shown += 1
            if count_opt and shown >= count_opt:
                return


#
# Clear sessions
#
def cgn_op_clear_sess(clr_opts):
    """Clear CGN sessions.

    clr_opts -- filter string selecting the sessions to clear

    Sends the clear command to each dataplane in turn.  Gives up quietly
    at the first dataplane whose command raises an exception.
    """

    cmd = "cgn-op clear session %s" % (clr_opts)

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                try:
                    dp.string_command(cmd)
                except Exception:
                    # Likely a zmq timeout occurred.  However zmq exceptions
                    # are not translated back to vplaned exceptions, so just
                    # return.  This can occur when clearing a full (32 million
                    # entries) session table.
                    #
                    # Exception, not a bare except, so that KeyboardInterrupt
                    # and SystemExit are not swallowed.
                    return


#
# Clear session statistics
#
def cgn_op_clear_sess_stats(clr_opts):
    """Clear CGN session statistics.

    clr_opts -- filter string selecting the sessions whose statistics are
                cleared

    Sends the clear command to each dataplane in turn.  Gives up quietly
    at the first dataplane whose command raises an exception.
    """

    cmd = "cgn-op clear session %s statistics" % (clr_opts)

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                try:
                    dp.string_command(cmd)
                except Exception:
                    # Likely a zmq timeout occurred.  However zmq exceptions
                    # are not translated back to vplaned exceptions, so just
                    # return.  This can occur when clearing a full (32 million
                    # entries) session table.
                    #
                    # Exception, not a bare except, so that KeyboardInterrupt
                    # and SystemExit are not swallowed.
                    return


#
# Update session statistics
#
def cgn_op_update_sess_stats(upd_opts):
    """Update CGN session statistics.

    upd_opts -- filter string selecting the sessions to update; if empty,
                the command is sent without a filter

    Sends the update command to each dataplane in turn.  Gives up quietly
    at the first dataplane whose command raises an exception.
    """

    if upd_opts:
        cmd = "cgn-op update session %s statistics" % (upd_opts)
    else:
        cmd = "cgn-op update session statistics"

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                try:
                    dp.string_command(cmd)
                except Exception:
                    # Likely a zmq timeout occurred.  However zmq exceptions
                    # are not translated back to vplaned exceptions, so just
                    # return.  This can occur when operating on a full
                    # (32 million entries) session table.
                    #
                    # Exception, not a bare except, so that KeyboardInterrupt
                    # and SystemExit are not swallowed.
                    return


#
# usage
#
def cgn_usage():
    """Print a one-line usage summary to stderr"""

    usage = "usage: {} --show | --clear".format(sys.argv[0])
    print(usage, file=sys.stderr)


#
# cgn_op_main
#
def cgn_op_main():
    """Main function

    Parses the command line, builds the filter string that is passed to
    the dataplane, then dispatches to the clear, update or show handler.

    Exits with status 2 on an unknown option or on a port or count value
    that is not a number.
    """

    s_opt = False
    c_opt = False
    u_opt = False
    d_opt = False
    stats_opt = False
    sa_opt = None
    count = None
    unordered = False

    #
    # Parse options
    #
    try:
        opts, _ = getopt.getopt(sys.argv[1:],
                                "",
                                ['show', 'clear', 'detail',
                                 'update', 'exclude-inner',
                                 'interface=', 'pool=',
                                 'id1=', 'id2=',
                                 'pub-addr=', 'pub-port=',
                                 'subs-addr=', 'subs-port=',
                                 'dst-addr=', 'dst-port=',
                                 'proto=', 'count=', 'unordered',
                                 'stats'])

    except getopt.GetoptError as r:
        print(r, file=sys.stderr)
        cgn_usage()
        sys.exit(2)

    # Options whose value is copied as-is into the filter string, mapped
    # to the filter keyword
    str_fltrs = {
        '--interface': 'intf',
        '--pool': 'pool',
        '--id1': 'id1',
        '--id2': 'id2',
        '--pub-addr': 'pub-addr',
        '--dst-addr': 'dst-addr',
    }

    # Options whose value must be a number, mapped to the filter keyword
    int_fltrs = {
        '--pub-port': 'pub-port',
        '--subs-port': 'subs-port',
        '--dst-port': 'dst-port',
    }

    fltrs = ""

    # getopt always reports the full option name, so each option is
    # matched by equality (not by substring).  Filters are appended in the
    # order the options were given.
    for opt, arg in opts:
        if opt == '--show':
            s_opt = True

        elif opt == '--clear':
            c_opt = True

        elif opt == '--update':
            u_opt = True

        elif opt == '--detail':
            d_opt = True
            fltrs = "%s detail" % (fltrs)

        elif opt == '--exclude-inner':
            fltrs = "%s outer" % (fltrs)

        elif opt in str_fltrs:
            fltrs = "%s %s %s" % (fltrs, str_fltrs[opt], arg)

        elif opt == '--subs-addr':
            # sa_opt is passed to the show functions instead of
            # being added to the filter
            sa_opt = arg

        elif opt == '--proto':
            fltrs = "%s proto %u" % (fltrs, npf_proto2num(arg))

        elif opt in int_fltrs or opt == '--count':
            try:
                value = int(arg)
            except ValueError:
                # Report a bad number like any other bad option
                print("option %s requires a number, not '%s'" % (opt, arg),
                      file=sys.stderr)
                cgn_usage()
                sys.exit(2)

            if opt == '--count':
                count = value
            else:
                fltrs = "%s %s %u" % (fltrs, int_fltrs[opt], value)

        elif opt == '--unordered':
            unordered = True

        elif opt == '--stats':
            stats_opt = True

    #
    # sa_opt is passed to the show scripts outwith the fltr string, so only
    # add it to the fltr string for the update and clear commands
    #
    if (c_opt or u_opt) and sa_opt:
        fltrs = "%s subs-addr %s" % (fltrs, sa_opt)

    #
    # clearing sessions is all done in the dataplane
    #
    if c_opt:
        if not fltrs:
            fltrs = "all"

        if stats_opt:
            cgn_op_clear_sess_stats(fltrs)
        else:
            cgn_op_clear_sess(fltrs)
        return

    if u_opt and stats_opt:
        cgn_op_update_sess_stats(fltrs)

    if not s_opt:
        return

    #
    # show sessions.
    #

    #
    # Show sessions unordered.  All filtering is done in dataplane.
    #
    if unordered:
        cgn_op_show_sess(fltrs, count, d_opt, sa_opt)
    else:
        cgn_op_show_sess_ordered(fltrs, count, d_opt, sa_opt)


#
# main
#
# Run the command only when executed as a script, not when imported
if __name__ == '__main__':
    cgn_op_main()
