#!/usr/bin/python3
#
# Copyright (c) 2019 AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#

""" This is run to get the CGNAT session information from the dataplane
and provide it as YANG RPC information. """


import os
import sys
import getopt
import vplaned
import json
from vyatta.npf.npf_debug import NpfDebug


# Name of this script file; used to prefix error messages.
PROGNAME = os.path.basename(__file__)
# Base dataplane command; options from the RPC input are appended to it.
DATAPLANE_CMD = 'cgn-op show session detail'

# Object used for printing debugs; process_options() calls dbg.enable()
# when -d/--debug is given on the command line.
dbg = NpfDebug()


def err(msg):
    """Write msg, followed by a newline, to standard error."""
    sys.stderr.write(str(msg) + "\n")


def process_options():
    """Parse the command line of this script.

    The only option recognised is -d/--debug, which enables debug
    output.  On an unrecognised option the getopt error and a usage
    line are written to stderr and the script exits with status 2.
    """
    try:
        parsed, _ = getopt.getopt(sys.argv[1:], "d", ['debug'])

    except getopt.GetoptError as exc:
        err(exc)
        err("usage: {} [-d|--debug] ".format(sys.argv[0]))
        sys.exit(2)

    # Option values are ignored: the debug flag takes no argument
    for flag, _ in parsed:
        if flag == '-d' or flag == '--debug':
            dbg.enable()


# Session state number found in the dataplane output -> state name
# reported in the RPC output.
state_mappings = {
    1: 'closed',
    2: 'opening',
    3: 'established',
    4: 'transitory',
}
# States 5, 6 and 7 are all reported as 'closing'
state_mappings.update(dict.fromkeys((5, 6, 7), 'closing'))


def get_base_sess_info(entry, rpc_entry, unique_id):
    """Fill rpc_entry with the fields taken from the main session.

    entry     -- session dict taken from the dataplane output
    rpc_entry -- dict being built for the RPC output (updated in place)
    unique_id -- value stored under the 'id' key of rpc_entry

    A state number with no entry in state_mappings is reported as
    'closed'.
    """
    # This ID is needed due to limitations in YANG implementation of not
    # allowing multiple keys.
    rpc_entry['id'] = unique_id
    rpc_entry['session-id'] = entry['id']
    rpc_entry['ip-protocol'] = entry['proto']

    state_name = state_mappings.get(entry['state'], 'closed')
    rpc_entry['state'] = state_name

    # Remaining fields are copied unchanged: (RPC key, dataplane key)
    copied_fields = (
        ('subscriber-ip-address', 'subs_addr'),
        ('subscriber-port', 'subs_port'),
        ('public-ip-address', 'pub_addr'),
        ('public-port', 'pub_port'),
        ('interface', 'intf'),
        ('pool-name', 'pool'),
    )
    for rpc_key, dp_key in copied_fields:
        rpc_entry[rpc_key] = entry[dp_key]


def get_other_sess_info(xentry, rpc_entry):
    """Copy the timeout, time and counter fields of xentry into rpc_entry.

    xentry    -- dict holding the dataplane fields to copy
    rpc_entry -- dict being built for the RPC output (updated in place)

    'unk_pkts_in' is copied only when present in xentry; every other
    field is required and a missing one raises KeyError.
    """
    # (RPC key, dataplane key, mandatory), in output order
    fields = (
        ('timeout', 'cur_to', True),
        ('max-timeout', 'max_to', True),
        ('start-time', 'start_time', True),
        ('duration', 'duration', True),
        ('packets-out', 'out_pkts', True),
        ('packets-in', 'in_pkts', True),
        ('bytes-out', 'out_bytes', True),
        ('bytes-in', 'in_bytes', True),
        ('unknown-source', 'unk_pkts_in', False),
        ('expired', 'exprd', True),
    )
    for rpc_key, dp_key, mandatory in fields:
        if mandatory or dp_key in xentry:
            rpc_entry[rpc_key] = xentry[dp_key]


def sess_rpc(entry, list):
    """Append the RPC entries describing one dataplane session to list.

    entry -- session dict taken from the dataplane output
    list  -- list of RPC entries built so far (updated in place); its
             current length is used as the 'id' of each new entry

    A session holding 'destinations' yields one RPC entry per
    destination sub-session; any other session yields a single entry
    with a 'sub-session-id' of 0.
    """
    if 'destinations' not in entry:
        rpc_entry = {}
        get_base_sess_info(entry, rpc_entry, len(list))
        # A false value for the initial destination port is reported as 0
        rpc_entry['destination-port'] = entry['init_dst_port'] or 0
        rpc_entry['sub-session-id'] = 0
        get_other_sess_info(entry, rpc_entry)

        list.append(rpc_entry)
        dbg.pprint("non-dest rpc entry: {}".format(rpc_entry))
        return

    # Fields only copied when the dataplane supplied them:
    # (RPC key, dataplane key)
    optional_fields = (
        ('rtt-internal', 'rtt_int'),
        ('rtt-external', 'rtt_ext'),
        ('state-history', 'hist'),
    )

    for dentry in entry['destinations']['sessions']:
        rpc_entry = {}
        # Base fields come from the main session, the rest from the
        # destination sub-session
        get_base_sess_info(entry, rpc_entry, len(list))
        rpc_entry['sub-session-id'] = dentry['id']
        rpc_entry['destination-ip-address'] = dentry['dst_addr']
        rpc_entry['destination-port'] = dentry['dst_port']
        for rpc_key, dp_key in optional_fields:
            if dp_key in dentry:
                rpc_entry[rpc_key] = dentry[dp_key]
        get_other_sess_info(dentry, rpc_entry)

        list.append(rpc_entry)
        dbg.pprint("dest rpc entry: {}".format(rpc_entry))


# RPC input parameter name -> option keyword of the dataplane command.
# 'count' and 'target' are not appended to the command as they stand;
# get_cgnat_session_info() handles those two itself.
param_mappings = dict((
    ('subscriber-ip-address-prefix', 'subs-addr'),
    ('subscriber-port', 'subs-port'),
    ('interface', 'intf'),
    ('public-ip-address-prefix', 'pub-addr'),
    ('public-port', 'pub-port'),
    ('destination-ip-address-prefix', 'dst-addr'),
    ('destination-port', 'dst-port'),
    ('session-id', 'id1'),
    ('sub-session-id', 'id2'),
    ('pool-name', 'pool'),
    ('req-entries', 'count'),
    ('target', 'target'),
))


def _add_target(args, address, port, protocol, interface):
    """Return the command args with the target-session options appended."""
    return "%s tgt-addr %s tgt-port %u tgt-proto %u tgt-intf %s" % (
        args, address, port, protocol, interface)


def get_cgnat_session_info():
    """Fetch the CGNAT sessions requested by the RPC input on stdin.

    The RPC input is a JSON object.  Each member name is translated
    through param_mappings into an option of the dataplane command;
    'req-entries' and 'target' are handled separately as the session
    count and the session to start from.  The command is sent to every
    dataplane known to the controller.

    Returns a (result, exit_code) tuple:
      ({'sessions': [...]}, 0)  at least one session was returned
      (None, 0)                 no session was returned
      (None, 1)                 stdin did not hold a valid JSON object
      (None, 2)                 the input held an unknown option
    """
    dbg.pprint("get_cgnat_session_info()")

    try:
        rpc_input = json.load(sys.stdin)
    except ValueError as exc:
        err("Failed to parse input JSON: {}".format(exc))
        return None, 1

    # json.load() also accepts lists, strings, numbers and null.  Only an
    # object can hold options, so reject anything else with a proper
    # error rather than failing on .items() below.
    if not isinstance(rpc_input, dict):
        err("Failed to parse input JSON: expected an object, got {}".format(
            type(rpc_input).__name__))
        return None, 1

    args = DATAPLANE_CMD
    count_opt = None
    target = None

    for param, value in rpc_input.items():
        dp_param = param_mappings.get(param)
        if dp_param is None:
            err("{}: unknown rpc input option: {}".format(PROGNAME, param))
            return None, 2

        if dp_param == 'count':
            count_opt = value
        elif dp_param == 'target':
            target = value
        else:
            args += " {} {}".format(dp_param, value)

    # Fetch session in batches of 1000 if a count has not been specified
    count = count_opt or 1000
    args += " count {}".format(count)

    sess_info_list = []
    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:

                # Target variables change for each request to the dataplane,
                # so the command is rebuilt from args for every dataplane
                if target:
                    cmd = _add_target(args, target['address'],
                                      target['port'], target['protocol'],
                                      target['interface'])
                else:
                    cmd = args

                dbg.pprint("dp command: {}".format(cmd))

                while True:
                    dp_dict = dp.json_command(cmd)
                    if not dp_dict:
                        break

                    sess_list = dp_dict.get('sessions')

                    # Exit when no sessions are returned
                    if not sess_list:
                        break

                    for sess in sess_list:
                        sess_rpc(sess, sess_info_list)

                    # If a count was specified then assume user only wants
                    # that number
                    if count_opt:
                        break

                    # Target session is last session from previous batch;
                    # 'sess' is still bound to it after the loop above
                    cmd = _add_target(args, sess.get('subs_addr'),
                                      sess.get('subs_port'),
                                      sess.get('proto'),
                                      sess.get('intf'))

    if sess_info_list:
        return {'sessions': sess_info_list}, 0

    return None, 0


if __name__ == "__main__":
    process_options()
    info, ret = get_cgnat_session_info()
    # Nothing is printed when no session information was returned
    if info:
        print(json.dumps(info))
    # Use sys.exit() rather than the exit() helper injected by the site
    # module: the latter is meant for interactive use and is not defined
    # when the interpreter is started without site (python -S).
    sys.exit(ret)
