#!/usr/bin/env python3
#
# Copyright (c) 2019-2020, AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#

"""Scripts for NAT Pool op-mode commands"""

import sys
import getopt
import vplaned
from vyatta.npf.npf_addr_group import npf_show_address_group
from vyatta.npf.npf_addr_group import npf_show_address_group_one


# Yes or No string
def yes_or_no(val):
    """Return "Yes" when *val* is truthy, otherwise "No"."""
    if val:
        return "Yes"
    return "No"


#
# usage
#
def np_usage():
    """Print the command-line synopsis to standard error.

    Both long options accepted by this script are listed:
    ``--pool <name>`` and ``--detail``.
    """

    print("usage: {} [--pool <name>] [--detail]".format(sys.argv[0]),
          file=sys.stderr)


#
# Return ratio of two numbers to at least 3 digit accuracy, e.g.
# 100:1, 10.5:1, 1:2.34 etc.
#
def ratio_str(val1, val2):
    """Format the ratio of two quantities as ``N:1`` or ``1:N``.

    The larger value is always expressed relative to the smaller one, so
    the variable side of the ratio is never below 1.  That side is shown
    with roughly three significant digits: no decimals from 100 upwards,
    one decimal from 10 upwards, and two decimals below 10.

    Args:
        val1: Left-hand quantity of the ratio.
        val2: Right-hand quantity of the ratio.

    Returns:
        "n/a" when either quantity is zero or missing (None), otherwise
        the formatted ratio, e.g. "100:1", "10.5:1" or "1:2.34".
    """
    # A ratio is undefined if either side is zero; a missing (None) value
    # is treated the same way rather than raising TypeError below.
    if not val1 or not val2:
        return "n/a"

    forward = val1 >= val2
    ratio = val1 / val2 if forward else val2 / val1

    # Inclusive bounds: exactly 100 prints as "100" and exactly 10 as
    # "10.0", keeping the output at three significant digits.
    if ratio >= 100:
        digits = "%d" % ratio
    elif ratio >= 10:
        digits = "%.1f" % ratio
    else:
        digits = "%.2f" % ratio

    if forward:
        return "%s:1" % digits
    return "1:%s" % digits


def np_op_addr_or_none(addr):
    """Return "none" for the all-zeros IPv4 address, else *addr* unchanged."""
    return "none" if addr == "0.0.0.0" else addr


#
# np_op_show_pool_one
#
def np_op_show_pool_one(pool, d_opt):
    """Print the report for a single NAT pool.

    pool is the dictionary describing the pool.  When d_opt is true and
    the pool carries an 'address-group' entry, that address-group is
    printed as well.
    """

    # Row layouts.  The digit is the label indent; 's'/'u' is the
    # conversion applied to the value column.
    head2 = "  %-32s"
    row2s = "  %-32s %18s"
    row2u = "  %-32s %18u"
    head4 = "    %-30s"
    row4s = "    %-30s %18s"
    row4u = "    %-30s %18u"
    row6u = "      %-28s %18u"
    range_row = "        %-12s %32s"

    # Derived values are worked out before anything is printed
    ports = "%u-%u" % (pool.get('port_start'), pool.get('port_end'))

    # Contention ratio: private (user) addresses to public addresses
    contention = ratio_str(pool.get('nuser_addrs'), pool.get('naddrs'))

    print("NAT Pool %s" % (pool.get('name')))

    print(row2s % ("Active", yes_or_no(pool.get('active'))))
    print(row2s % ("Type", pool.get('type')))
    if "full" in pool:
        print(row2s % ("Full", yes_or_no(pool.get('full'))))
    print(row2u % ("User count", pool.get('nusers')))
    print(row4u % ("User addresses", pool.get('nuser_addrs')))

    print(head2 % ("Addresses:"))
    for label, key in (("Address pooling", 'addr_pooling'),
                       ("Address allocation", 'addr_allocn')):
        print(row4s % (label, pool.get(key)))
    print(row4u % ("Address count", pool.get('naddrs')))
    print(row4s % ("Contention ratio", contention))
    print(head4 % ("Address Ranges:"))
    for rng in pool.get('address_ranges'):
        kind = rng.get('type')
        print("      %s (%s)" % (rng.get('name'), kind))
        print(range_row % ("Range", rng.get('range')))
        print("        %-16s %28s" % ("Address count", rng.get('naddrs')))
        # Prefix and subnet ranges carry one extra line named after the type
        for key, label in (("prefix", "Prefix:"), ("subnet", "Subnet:")):
            if kind == key:
                print(range_row % (label, rng.get(key)))

    # The 'hidden' address-group is only shown on request, and only if
    # it is present in the json
    if d_opt and "address-group" in pool:
        npf_show_address_group_one(pool.get('address-group'),
                                   "Address-group",
                                   4, 24, 24,
                                   6, 14, 32)

    last = pool.get('current')
    print(head2 % ("Last Allocated Address:"))
    for label, proto in (("TCP", 'tcp'), ("UDP", 'udp'), ("Other", 'other')):
        print(row4s % (label, np_op_addr_or_none(last.get(proto))))

    # A CGNAT blacklist should only contain IPv4 addresses
    if "blacklist" in pool:
        npf_show_address_group(pool.get('blacklist'), "ipv4", None,
                               "Blacklist address-group",
                               2, 24, 26,
                               4, 14, 34)

    print(head2 % ("Ports:"))
    print(row4s % ("Port allocation", pool.get('port_allocn')))
    print(row4s % ("Port range", ports))
    for label, key in (("Port count", 'nports'),
                       ("Port-block size", 'block_sz'),
                       ("Max port-blocks per user", 'mbpu')):
        print(row4u % (label, pool.get(key)))

    maps = pool.get('map_stats')
    print(head2 % ("Translation Mappings:"))
    print(row4u % ("Active", maps.get('active')))
    print(row4u % ("Total requests", maps.get('reqs')))
    # Successful requests are not reported directly; derive them
    print(row6u % ("Ok", maps.get('reqs') - maps.get('fails')))
    print(row6u % ("Failed", maps.get('fails')))

    blocks = pool.get('block_stats')
    print(head2 % ("Port Block Allocation:"))
    for label, key in (("Active", 'active'),
                       ("Total", 'total'),
                       ("Total Freed", 'freed'),
                       ("Total Failures", 'failures'),
                       ("Failures exceeding max", 'subs_limit')):
        print(row4u % (label, blocks.get(key)))
    print()


#
# np_op_show_pool
#
def np_op_show_pool(p_opt, d_opt):
    """Fetch NAT pools from every dataplane and print each one.

    p_opt, when set, is appended to the dataplane command and also used
    to filter the returned pools by name.  d_opt is passed through to
    np_op_show_pool_one for each pool shown.
    """

    command = "nat-op show pool"
    if p_opt:
        command = "%s %s" % (command, p_opt)

    # Gather the pools from all dataplanes first, then print
    collected = []
    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                reply = dp.json_command(command)
                collected.extend(reply.get('pools') or [])

    for entry in collected:
        # With a pool name given, skip everything that does not match it
        if p_opt and p_opt != entry.get('name'):
            continue
        np_op_show_pool_one(entry, d_opt)


#
# np_op_main
#
def np_op_main():
    """Parse the command line and show the requested NAT pool(s).

    Recognised options:
        --pool <name>   only show the pool with this name
        --detail        request the detailed per-pool output

    Non-option arguments are ignored.  If option parsing fails, the
    error and the usage text are written to standard error and the
    script exits with status 2.
    """

    p_opt = None
    d_opt = False

    #
    # Parse options
    #
    try:
        opts, _args = getopt.getopt(sys.argv[1:],
                                    "",
                                    ['pool=', 'detail'])

    except getopt.GetoptError as err:
        print(err, file=sys.stderr)
        np_usage()
        sys.exit(2)

    # Compare option names for equality.  Using "in" against a string is
    # a substring test, which would also match "", "-" or "--p".
    for opt, arg in opts:
        if opt == '--pool':
            p_opt = arg
        elif opt == '--detail':
            d_opt = True

    np_op_show_pool(p_opt, d_opt)


#
# main
#
if __name__ == '__main__':
    # Run the command only when executed as a script, not when imported.
    np_op_main()
