#!/usr/bin/env python3
#
# Copyright (c) 2020, AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only
#

import sys
import getopt
import vplaned
from netaddr import IPAddress
from collections import deque
from vyatta.npf.IPProto import num2proto
from vyatta.npf.IPProto import proto2num


# Session json features type.  A session's 'features' list entry whose
# 'type' equals this value is the npf entry (see sess_feature_npf).
SESSION_FEATURE_NPF = 3

# Session json 'trans_type'.  Compared against the npf nat entry's
# 'trans_type'; the display code treats any value other than DNAT as SNAT.
NAT_TRANS_TYPE_DNAT = 1
NAT_TRANS_TYPE_SNAT = 2

# npf session flags, tested against the npf feature entry's 'flags' value.
# PFIL_IN/PFIL_OUT give the session direction (sess_in_or_out); the SE_*
# bits are decoded into the detailed flags string by sess_flags_str.
# NOTE(review): SE_GC_PASS_TWO is not decoded by sess_flags_str.
PFIL_IN = 0x0001
PFIL_OUT = 0x0002
SE_ACTIVE = 0x0004
SE_PASS = 0x0008
SE_EXPIRE = 0x0010
SE_GC_PASS_TWO = 0x0020
SE_SECONDARY = 0x0040
SE_LOCAL_ZONE_NAT = 0x0080
SE_IF_DISABLED = 0x0100
SE_NAT_PINHOLE = 0x0200

# Base dataplane commands.  get_item_list() appends the address family and
# the cmd_option_string() options to base_list_cmd; presumably the show and
# clear commands are extended the same way by code not visible here.
base_show_cmd = "session-op show dataplane sessions"
base_clear_cmd = "session-op clear dataplane sessions"
base_list_cmd = "session-op list"

#
# Fetch at least this many sessions per batch.  get_start_end_vals() may
# return a larger batch so that sessions sharing the same item value are
# never split across two batches.
#
batch_size = 2000

"""Sessions are fetched and displayed as follows:

1. A list of source addresses is fetched from the sessions.  This is returned
in a simple json array.  (Note that one entry is returned for each session,
*not* per source address.)  uints are returned for IPv4, and strings for IPv6.

2. The list is sorted in Python.

3. From the sorted list of source addresses, we select a start address and an
end address.  These are used to fetch a 'batch' of sessions from the dataplane
where the sessions' source addresses are >= start and <= end.

4. The batch of sessions is sorted by source address and source port.

5. The batch of sessions is displayed.

6. We repeat steps 3, 4 and 5 until the complete session table is
displayed.

The approx times taken for 1 million sessions in a vRouter in one test are:

  1. Fetch list of source addresses:    0.44 secs
  2. Sort list of source addresses:     3.10 secs
  3. Fetch batch of 2270 sessions:      0.21 secs
  4. Sort batch of 2270 sessions:       0.017 secs

Sessions may be sorted by src address and port, destination address and port,
timeout, or session ID.  For each of these, the relevant items (one from each
session) are fetched in the list in step 1.

"""


#
# This returns the abbreviated state code key for the session table banner.
#
def state_banner():
    """Return the key to the two-letter state codes shown in the session
    table banner."""
    return ("State codes: CL - CLOSED, OP - OPENING, ES - ESTABLISHED, "
            "CG - CLOSING")


#
# Session state to string
#
# 'state' is the numeric value in the json returned by the dataplane.
#
def state2str(state, short):
    """Return a long or short string for the given state value.  The long
    string is used in the detailed output.  The short string is used in the
    table output.

    Any value that is not a valid index into the state tables (too large or
    negative) is shown as state 0 ("None" / "NO") rather than raising
    IndexError or wrapping around to the end of the table.

    """

    state_long = ["None", "Closed", "Opening", "Established", "Closing"]
    state_short = ["NO", "CL", "OP", "ES", "CG"]

    # Valid indices are 0 .. len - 1
    if not 0 <= state < len(state_short):
        state = 0

    return state_short[state] if short else state_long[state]


#
# Create context/options dictionary and populate with defaults
#
def create_ctx(show):
    """Return a new options dictionary populated with default values.

    When 'show' is true the defaults are those of the show command: sessions
    ordered ascending by source address.  Otherwise no order is set.
    """

    # Only the show command orders sessions by default
    default_order = 'ascending' if show else None
    default_orderby = 'src_addr' if show else None

    return {
        # Filters.  These limit both the items list and the sessions
        # returned from the dataplane.  These are all command options.
        'ip': False,
        'ip6': False,
        'intf': None,
        'id': None,
        'dir': None,
        'feat': None,
        'saddr': None,
        'daddr': None,
        'sport': None,
        'dport': None,
        'proto': None,
        'taddr': None,
        'tport': None,

        # Session display order:
        #   order      - None, 'ascending', or 'descending'
        #   orderby    - None, 'src_addr', 'dst_addr', 'trans_addr', 'id' or
        #                'time_to_expire'
        #   start-with - start displaying sessions on or after a specific
        #                session determined by 'orderby'
        #   start, end - only used within the script for fetching batches
        #                of sessions from the dataplane
        'order': default_order,
        'orderby': default_orderby,
        'start-with-type': None,
        'start-with': None,
        'count': 0,
        'start': None,
        'end': None,

        # Other options
        'detail': False,
        'brief': False,
        'summary': False,
    }


#
# Create a base command string from the options.
#
# Not all options are used here. Specifically, 'ip', 'ip6', and 'count' are
# handled separately:
#
# sess_op_show_unordered will set either 'ip' or 'ip6' in the command string,
# or nothing if both address families are to be returned.
#
# sess_op_show will add either 'ip' or 'ip6' to the command string if sessions
# are to be ordered by src or dest address.  If the user wants both, then the
# command is repeated once for each address family.
#
# If sess_op_show is ordering by session ID or timeout then it will set either
# 'ip' or 'ip6' in the command string, or nothing if both address families are
# to be returned.
#
# The 'count' value entered by the user (and stored in ctx) is never sent to
# the dataplane.  Sessions are fetched in batches, and the 'count' value is
# used to determine when to stop fetching batches.  See sess_op_show_unordered
# and sess_op_show_ordered.
#
def cmd_option_string(ctx):
    """Return the dataplane command options for the filters, 'orderby' and
    'brief' settings in ctx, each preceded by a space.

    Options whose ctx value is falsy are omitted.  The protocol name is
    translated to its number.  'order' is deliberately not sent: the
    dataplane infers the direction from the start/end values.
    """

    # (ctx key, option format) for filters that are passed through as-is
    passthrough = (
        ('intf', " intf %s"),
        ('id', " id %d"),
        ('dir', " dir %s"),
        ('feat', " feat %s"),
        ('saddr', " src-addr %s"),
        ('daddr', " dst-addr %s"),
        ('sport', " src-port %s"),
        ('dport', " dst-port %s"),
    )
    fragments = [fmt % ctx[key] for key, fmt in passthrough if ctx[key]]

    # The dataplane expects a protocol number, not a name
    if ctx['proto']:
        fragments.append(" proto %s" % proto2num(ctx['proto']))

    # Translation filters, then 'orderby' which gives the dataplane context
    # for interpreting the batch start and end values
    for key, fmt in (('taddr', " trans-addr %s"),
                     ('tport', " trans-port %s"),
                     ('orderby', " orderby %s")):
        if ctx[key]:
            fragments.append(fmt % ctx[key])

    if ctx['brief']:
        fragments.append(" brief")

    return "".join(fragments)


#
# Are we ordering by an address?
#
def orderby_is_addr(orderby):
    """Return True if 'orderby' names one of the address fields."""
    return orderby in ('src_addr', 'dst_addr', 'trans_addr')


#
# Parse the options
#
def sess_op_parse_options(options, ctx):
    """Parse the script options other than the '--show' type options.

    'options' is a sequence of option tokens.  Keywords such as 'interface',
    'count' or 'starting-with' consume one or more following tokens as their
    value.  Parsed values are stored in 'ctx'; unrecognised tokens are
    ignored.

    Returns None on success.  Returns an error string if a keyword is missing
    its value, or if sess_op_finalize_options() rejects the option
    combination.
    """

    #
    # Convert list to a 'deque' list from 'collections' module.  This
    # allows us to efficiently pop items off the front of the list.
    #
    options = deque(options)

    while options:
        keyword = options.popleft()

        # Every value token is fetched with popleft(), so running out of
        # tokens part way through an option surfaces as IndexError.  Report
        # it like any other option error instead of crashing.
        try:
            _sess_op_parse_one(keyword, options, ctx)
        except IndexError:
            return "Missing value for option '%s'" % (keyword)

    # Check and finalize options
    return sess_op_finalize_options(ctx)


def _sess_op_parse_one(opt, options, ctx):
    """Apply option keyword 'opt' to ctx, consuming any value tokens from the
    front of the 'options' deque.  Raises IndexError (from popleft) when a
    required value token is missing.  Unrecognised keywords are ignored."""

    # ip
    if opt == "ip":
        ctx['ip'] = True

    # ip6
    elif opt == "ip6":
        ctx['ip6'] = True

    # interface
    elif opt == "interface":
        ctx['intf'] = options.popleft()

    # session ID
    elif opt == "id":
        ctx['id'] = int(options.popleft())

    # direction; only 'in' and 'out' are accepted, anything else is ignored
    elif opt == 'direction':
        direction = options.popleft()
        if direction in ("in", "out"):
            ctx['dir'] = direction

    # feature
    elif opt == "feature":
        ctx['feat'] = options.popleft()

    # source address/port
    elif opt == 'source':
        what = options.popleft()
        if what == 'address':
            ctx['saddr'] = options.popleft()
        elif what == 'port':
            ctx['sport'] = int(options.popleft())

    # destination address/port
    elif opt == 'destination':
        what = options.popleft()
        if what == 'address':
            ctx['daddr'] = options.popleft()
        elif what == 'port':
            ctx['dport'] = int(options.popleft())

    # protocol
    elif opt == 'protocol':
        ctx['proto'] = options.popleft()

    # translation address/port
    elif opt == 'translation':
        what = options.popleft()
        if what == 'address':
            ctx['taddr'] = options.popleft()
        elif what == 'port':
            ctx['tport'] = int(options.popleft())

    # unordered
    elif opt == "unordered":
        ctx['order'] = None
        ctx['orderby'] = None

    # order ascending/descending, followed by the field to order by
    elif opt == "ascending" or opt == "descending":
        ctx['order'] = opt

        orderby_opt = options.popleft()

        #
        # The 'orderby' values are the same as is returned in the session
        # json dictionary keys.  An unrecognised field leaves 'orderby'
        # unchanged.
        #
        if orderby_opt == 'source-address':
            ctx['orderby'] = 'src_addr'
        elif orderby_opt == 'destination-address':
            ctx['orderby'] = 'dst_addr'
        elif orderby_opt == 'translation-address':
            ctx['orderby'] = 'trans_addr'
        elif orderby_opt == 'id':
            ctx['orderby'] = 'id'
        elif orderby_opt == 'timeout':
            ctx['orderby'] = 'time_to_expire'

    # starting-with <address|id|timeout> <value>
    elif opt == "starting-with":
        start_with_type = options.popleft()
        ctx['start-with-type'] = start_with_type

        if start_with_type == 'address':
            ctx['start-with'] = IPAddress(options.popleft())
        elif start_with_type == 'id' or start_with_type == 'timeout':
            ctx['start-with'] = int(options.popleft())

    # count
    elif opt == "count":
        ctx['count'] = int(options.popleft())

    # Show detailed output.  Whichever of 'detail' and 'brief' comes first
    # wins.
    elif opt == "detail":
        if not ctx['brief']:
            ctx['detail'] = True

    # Brief is only used to reduce the size of the json returned from the
    # dataplane.  The 'features' sub dictionary is not returned.  This
    # *may* become useful if there are scale issues returning a large
    # session table.
    elif opt == "brief":
        if not ctx['detail']:
            ctx['brief'] = True

    # Summary output displays session and state counts
    elif opt == "summary":
        ctx['summary'] = True


#
# Check and finalize the options
#
# Return None or an error string
#
def sess_op_finalize_options(ctx):
    """Validate the parsed options in 'ctx' and fill in implied defaults.

    'ctx' is updated in place.  Returns None if the options are consistent,
    otherwise an error string for the first problem found.  Any defaults
    applied before that point remain in 'ctx'.
    """

    # Neither address family requested: use both
    if not (ctx['ip'] or ctx['ip6']):
        ctx['ip'] = True
        ctx['ip6'] = True

    #
    # starting-with 'id' / 'timeout' need the matching orderby.  Each entry
    # maps the start-with type to that orderby and to the error returned
    # when a different order is already set.
    #
    implied_orderby = {
        'id': ('id',
               "Mismatch between starting-with 'id' "
               "and order-by '%s'"),
        'timeout': ('time_to_expire',
                    "Mismatch between starting-with 'timeout' "
                    "and order-by '%s'"),
    }

    start_type = ctx['start-with-type']

    if start_type == 'address':
        if not orderby_is_addr(ctx['orderby']):
            return ("Mismatch between starting-with 'address' "
                    "and order-by '%s'" % (ctx['orderby']))

    elif start_type in implied_orderby:
        wanted, mismatch = implied_orderby[start_type]
        if ctx['orderby'] != wanted:
            # NOTE(review): ctx['order'] is truthy whenever any order is
            # set, including the 'ascending' default that create_ctx(True)
            # installs, so in that case a mismatch is reported instead of
            # switching to the implied order.  Confirm that is intended.
            if ctx['order']:
                return mismatch % (ctx['orderby'])

            # No order set: adopt the order implied by starting-with
            ctx['orderby'] = wanted
            ctx['order'] = 'ascending'

    #
    # Filtering or ordering by translation address only makes sense for
    # 'snat' or 'dnat'.  With no feature given, default to 'snat' since that
    # is the most common type of NAT.
    #
    nat_features = ('snat', 'dnat')

    if ctx['taddr']:
        if not ctx['feat']:
            ctx['feat'] = 'snat'
        elif ctx['feat'] not in nat_features:
            return ("Mismatch between feature '%s' and translation address filter" %
                    (ctx['feat']))

    if ctx['orderby'] == 'trans_addr':
        if not ctx['feat']:
            ctx['feat'] = 'snat'
        elif ctx['feat'] not in nat_features:
            return ("Mismatch between feature '%s' and order-by '%s'" %
                    (ctx['feat'], ctx['orderby']))

    # 'brief' omits the features from the json, so a feature filter
    # overrides it
    if ctx['brief'] and ctx['feat']:
        ctx['brief'] = False

    return None


#
# Sort the item list returned from the dataplane.  This is a list of src
# addrs, dest addrs, session IDs or timeout values.
#
def sort_item_list(item_list, ctx, af):
    """Sort item_list according to ctx['order'] and ctx['orderby'].

    Address items are returned as IPAddress objects so that they can be
    printed as addresses *and* compared with < and >.  IPv4 items (af 'ip')
    arrive as uints and are sorted before conversion, which is faster; IPv6
    items (af 'ip6') arrive as strings and must be converted first.  For an
    address orderby with any other af, or when no order is set, the list is
    returned unchanged.

    """

    if not (ctx['order'] and ctx['orderby']):
        # Nothing to do
        return item_list

    descending = ctx['order'] == 'descending'

    if not orderby_is_addr(ctx['orderby']):
        # 'id' and 'timeout' are uints
        return sorted(item_list, reverse=descending)

    if af == 'ip':
        return [IPAddress(i) for i in sorted(item_list, reverse=descending)]

    if af == 'ip6':
        return sorted((IPAddress(i) for i in item_list), reverse=descending)

    return item_list


#
# Slice item_list if a 'start-with' option was specified
#
def slice_item_list(item_list, ctx):
    """Drop the items that come before the 'start-with' value.

    'item_list' must already be sorted in the direction given by
    ctx['order'].  Returns the tail of the list starting at the first item
    that is >= start-with (ascending) or <= start-with (descending).  If no
    item qualifies, an empty list is returned, because nothing is on or after
    the start point.  The list is returned unchanged when no order or no
    start-with value is set.
    """

    start_with = ctx['start-with']

    # Compare with None rather than testing truthiness, so that a
    # start-with value of 0 is still honoured
    if not ctx['order'] or start_with is None:
        return item_list

    descending = (ctx['order'] == 'descending')

    for i, item in enumerate(item_list):
        if (item <= start_with) if descending else (item >= start_with):
            return item_list[i:]

    # Every item precedes the start point
    return []


#
# Get a sorted list of items from the dataplane sessions.
#
def get_item_list(ctx, af):
    """Get a sorted list of items from the dataplane sessions.  Returns a
    list of uints (IPv4 addr, ID, or timeout) or strings (IPv6 addr), with
    addresses converted to IPAddress objects by sort_item_list().

    If we are sorting by source or dest address, then 'af' will be specified
    as either 'ip' or 'ip6'.  Each address family is fetched and processed
    separately.

    If we are sorting by ID or timeout, then 'af' may be either 'ip', 'ip6',
    or 'None'.  'None' will return items from both IP and IPv6 sessions.

    Returns an empty list when no order is requested, since the list is
    only used to fetch sessions in a particular order.

    """

    if not (ctx['order'] and ctx['orderby']):
        return []

    cmd = base_list_cmd + (" %s" % (af) if af else "") + cmd_option_string(ctx)

    items = []

    # Gather the items from every dataplane, skipping error replies
    with vplaned.Controller() as controller:
        for dataplane in controller.get_dataplanes():
            with dataplane:
                reply = dataplane.json_command(cmd)

                if reply and '__error' not in reply and 'list' in reply:
                    items += reply['list']

    # Sort, then drop anything before a 'starting-with' value
    return slice_item_list(sort_item_list(items, ctx, af), ctx)


#
# Get a batch of sessions from the dataplane
#
def get_sessions(cmd):
    """Run the session command 'cmd' on every dataplane and return the
    concatenated 'sessions' lists.  Replies that are empty, carry an
    '__error' key, or lack 'sessions' are skipped."""

    batch = []

    with vplaned.Controller() as ctl:
        for dataplane in ctl.get_dataplanes():
            with dataplane:
                resp = dataplane.json_command(cmd)

                if not resp or '__error' in resp or 'sessions' not in resp:
                    continue

                batch.extend(resp['sessions'])

    return batch


#
# Determine next 'start' and 'end' values from the given item list.
#
# 'start_index' is the start point.
#
# 'item_list' is a list of source addrs, dest addrs, session IDs or timeout
# values
#
# 'batch_size' is the minimum number of sessions we want when determining the
# start and end values.
#
# Returns the start_index value to be used in the next call of this function,
# and a start and end value that will be used to fetch a batch of sessions
# from the dataplane.
#
def get_start_end_vals(start_index, item_list, batch_size):
    """Pick the next batch boundaries from a sorted item list.

    Starting at item_list[start_index], take at least 'batch_size' items,
    extending the batch so that a run of equal items is never split across
    two batches.  A split would make the sessions for that item appear twice
    in the show output.

    Returns (next_index, start, end).  next_index is the start_index to pass
    on the following call; start and end are the first and last item values
    of this batch.  Returns (0, None, None) when the list is empty or
    start_index is past its end.

    """

    if not item_list:
        return 0, None, None

    last = len(item_list) - 1
    if start_index > last:
        return 0, None, None

    first = item_list[start_index]
    taken = 0
    idx = start_index

    while idx <= last:
        taken += 1

        # Close the batch at the end of the list, or once it is full and
        # the next item starts a new run of values
        if idx == last or (item_list[idx + 1] != item_list[idx]
                           and taken >= batch_size):
            return idx + 1, first, item_list[idx]

        idx += 1

    # Not reached: the loop always returns on the last item
    return len(item_list), first, None


#
# NAT addr and port are buried in the session feature array.  In order to keep
# the sorting algorithms simple, we promote them to the session itself.
#
# Defensive coding here means that in theory the new list may be smaller than
# the input list.  In practise that should never happen since sorting by NAT
# address means only NAT sessions should be in sess_list.
#
def orderby_trans_addr_fixup(sess_list):
    """Copy each session's NAT 'trans_addr' and 'trans_port' from its npf
    feature entry up onto the session dictionary itself, so sorting can treat
    them like any other session field.

    Returns a new list holding only the sessions that have an npf 'nat'
    entry; the session dictionaries are modified in place.
    """

    promoted = []

    for sess in sess_list:
        npf = sess_feature_npf(sess)
        if not npf or 'nat' not in npf:
            continue

        nat = npf['nat']
        sess['trans_addr'] = nat['trans_addr']
        sess['trans_port'] = nat['trans_port']
        promoted.append(sess)

    return promoted


#
# Get the npf feature from a dataplane session
#
def sess_feature_npf(sess):
    """Return the first entry of the session's 'features' list whose 'type'
    is SESSION_FEATURE_NPF, or None if there is no such entry."""

    features = sess['features'] if 'features' in sess else ()

    for feature in features:
        if feature['type'] == SESSION_FEATURE_NPF:
            return feature

    return None


#
# Returns True if the session is a firewall session.  Input parameter is the
# npf feature json returned from sess_feature_npf().
#
def sess_is_firewall(feat):
    """Return True if the npf feature json 'feat' (as returned by
    sess_feature_npf()) has the SE_PASS flag set, otherwise False.

    'feat' may be None or lack a 'flags' key; both give False.  The result
    is always a bool rather than the falsy 'feat' value itself.
    """
    return bool(feat) and 'flags' in feat and (feat['flags'] & SE_PASS) != 0


#
# Firewall detailed output
#
def feat_firewall_detail(fw, col1, col2):
    """Print the firewall line of the detailed session output.  When the
    firewall json includes a rule, it is shown as <name>/<number>."""

    if 'rule' not in fw:
        print("  %-*s" % (col1, "Firewall"))
        return

    fw_rule = fw['rule']
    rule_desc = "%s/%u" % (fw_rule['name'], fw_rule['number'])
    print("  %-*s %*s" % (col1, "Firewall rule", col2, rule_desc))


#
# NAT feature string
#
def feat_nat_str(nat):
    """Return the NAT part of the table 'features' column, e.g.
    "snat:10.0.0.1 1024".  Any trans_type other than DNAT is shown as snat."""
    kind = "dnat:" if nat['trans_type'] == NAT_TRANS_TYPE_DNAT else "snat:"
    return kind + "%s %d" % (nat['trans_addr'], nat['trans_port'])


#
# NAT detailed output
#
def feat_nat_detail(nat, col1, col2):
    """Print the NAT section of the detailed session output.  It shows the
    translation type (DNAT, otherwise SNAT, noting masquerade), the
    translated address and port, and the NAT rule number when present."""

    label = "DNAT" if nat['trans_type'] == NAT_TRANS_TYPE_DNAT else "SNAT"
    if nat['masquerade']:
        label += " (masquerade)"

    # Address and port are indented two more spaces than the heading
    sub_col = col1 - 2

    print("  %-*s" % (col1, label))
    print("    %-*s %*s" % (sub_col, "address", col2, nat['trans_addr']))
    print("    %-*s %*d" % (sub_col, "port", col2, nat['trans_port']))

    if 'rule' in nat:
        print("  %-*s %*s" % (col1, "NAT rule", col2, nat['rule']['number']))


#
# NAT64/NAT46 feature string.  nat64 and nat46 are in the same dictionary
# within the npf features dictionary
#
def feat_nat64_str(nat64):
    """Return the NAT64/NAT46 part of the table 'features' column, e.g.
    "nat64:in peer:12"."""
    direction = "in" if nat64['in'] else "out"
    return "%s:%s peer:%d" % (nat64['type'], direction, nat64['peer_id'])


#
# NAT64 detailed output
#
def feat_nat64_detail(nat64, col1, col2):
    """Print the NAT64/NAT46 section of the detailed session output: the
    type heading followed by the direction and peer session ID."""

    direction = "in" if nat64['in'] else "out"

    # Sub-items are indented two more spaces than the heading
    inner = col1 - 2

    print("  %-*s" % (col1, nat64['type']))
    print("    %-*s %*s" % (inner, "Direction", col2, direction))
    print("    %-*s %*s" % (inner, "Peer", col2, nat64['peer_id']))


#
# ALG
#
# We return a string showing the 'family tree' of ALG session IDs.  The
# session being shown has square brackets around it.
#
# We show up to 2 preceding generations and one succeeding generation (the
# 'children').  The first session shown is always the base parent.  If there
# are more than 2 preceding generations (very unlikely) then there will be
# missing sessions between the first and second ones shown.
#
# alg:sip [27]->30
# alg:sip 27->[30]->31
# alg:sip 27->30->[31]->32/33
# alg:sip 27..31->[32]
# alg:sip 27..31->[33]
#
def feat_alg_tree(alg, my_id):
    """Return the ALG session 'family tree' string for session 'my_id'.

    The result is [base_parent{->|..}][parent->][my_id][->child/child...],
    with this session in square brackets.  'bp_is_gp' selects "->" (base
    parent is the grandparent) or ".." (generations are skipped).

    A base_parent equal to the parent is not shown twice.  The json is
    expected to omit it in that case, but feat_alg_detail() applies the same
    check, so the two outputs agree.
    """

    parent = alg.get('parent')
    base_parent = alg.get('base_parent')

    tree = ""

    if base_parent and base_parent != parent:
        tree = "%d%s" % (base_parent, "->" if alg['bp_is_gp'] else "..")

    if parent:
        tree += "%d->" % (parent)

    tree += "[%d]" % (my_id)

    children = alg.get('children')
    if children:
        tree += "->" + "/".join("%d" % (child['id']) for child in children)

    return tree


#
# ALG feature string
#
def feat_alg_str(alg, my_id):
    """Return the ALG part of the table 'features' column: the ALG name
    followed by the session family tree, e.g. "alg:sip 27->[30]"."""
    family = feat_alg_tree(alg, my_id)
    return "alg:%s %s" % (alg['name'], family)


#
# Display one SIP media
#
def feat_sip_media(m, indent, col1, col2):
    """Print one SIP media entry: a 'media' heading with the protocol, then
    the original and (if present) translated RTP and RTCP address/port
    pairs.  A pair is shown only when both its address and port keys are
    present."""

    row = "%*s%-*s %*s"
    print(row % (indent, "", col1, "media", col2, m['proto']))

    # The stream rows are nested one level below the 'media' heading
    sub_indent = indent + 2
    sub_col = col1 - 2

    streams = (
        ("RTP,  orig", 'rtp_addr', 'rtp_port', 'trtp_addr', 'trtp_port'),
        ("RTCP, orig", 'rtcp_addr', 'rtcp_port', 'trtcp_addr', 'trtcp_port'),
    )

    for label, addr, port, taddr, tport in streams:
        if addr not in m or port not in m:
            continue

        orig = "%s %d" % (m[addr], m[port])
        print(row % (sub_indent, "", sub_col, label, col2, orig))

        if taddr in m and tport in m:
            trans = "%s %d" % (m[taddr], m[tport])
            print(row % (sub_indent, "", sub_col, "      trans", col2, trans))


#
# Display one SIP call ID
#
def feat_sip_callid(cid, indent, col1, col2):
    """Print one SIP call ID line followed by each of its media entries,
    nested one level deeper."""

    print("%*s%-*s %*s" % (indent, "", col1, "Call ID", col2, cid['number']))

    for media in cid.get('media') or ():
        feat_sip_media(media, indent + 2, col1 - 2, col2)


#
# Display SIP detailed output
#
def feat_sip(sip, indent, col1, col2):
    """Print the SIP ALG details: a 'SIP:' heading, the VIA address and
    port when present, then every call ID, all nested one level deeper."""

    print("%*s%-*s" % (indent, "", col1, "SIP:"))

    inner_indent = indent + 2
    inner_col = col1 - 2
    row = "%*s%-*s %*s"

    for key, label in (('via_addr', "VIA address"), ('via_port', "VIA port")):
        if key in sip:
            print(row % (inner_indent, "", inner_col, label, col2, sip[key]))

    for callid in sip.get('callids') or ():
        feat_sip_callid(callid, inner_indent, inner_col, col2)


#
# ALG detailed output
#
def feat_alg_detail(alg, col1, col2):
    """Print the ALG section of the detailed session output: the ALG name,
    then base parent ID (only when it differs from the parent), parent ID,
    child IDs, and any SIP details."""

    row = "%*s%-*s %*s"
    print(row % (2, "", col1, "ALG", col2, alg['name']))

    # Everything below the heading is indented two more spaces
    indent = 4
    col1 -= 2

    parent = alg.get('parent')
    base_parent = alg.get('base_parent')

    if base_parent and base_parent != parent:
        print(row % (indent, "", col1, "Base parent ID", col2, base_parent))
    if parent:
        print(row % (indent, "", col1, "Parent ID", col2, parent))

    children = alg.get('children')
    if children:
        child_ids = "/".join("%d" % (child['id']) for child in children)
        print(row % (indent, "", col1, "Child IDs", col2, child_ids))

    if 'sip' in alg:
        feat_sip(alg['sip'], indent, col1, col2)


#
# Do we want to display info for this application/dpi name?
#
def dpi_name_ok(name):
    """Return True if a DPI application/protocol/type name is worth
    displaying.  Empty names and the placeholders "Unknown", "Unavailable"
    and "None" are not."""
    return bool(name) and name not in ("Unknown", "Unavailable", "None")


#
# DPI feature string for use in the table output format
#
# app:Facebook l5-proto:DNS type:SocialNetwork
#
def feat_dpi_str(dpi):
    """Return the DPI part of the table 'features' column.

    Uses the first engine whose app-name, proto-name or type passes
    dpi_name_ok(), e.g. "app:Facebook l5-proto:DNS type:SocialNetwork".
    Returns plain "app" when there are no engines or none qualify.
    """

    engines = dpi['engines'] if 'engines' in dpi else None

    for engine in engines or ():
        app_name = engine['app-name']
        proto_name = engine['proto-name']
        app_type = engine['type']

        if any(dpi_name_ok(n) for n in (app_name, proto_name, app_type)):
            return "app:%s l5-proto:%s type:%s" % (app_name, proto_name, app_type)

    return "app"


#
# DPI detailed output
#
def feat_dpi_detail(dpi, col1, col2):
    """Print the DPI section of the detailed session output.  There is one
    App / L5 Proto / Type group for every engine that reports at least one
    usable name (see dpi_name_ok()).

    A dpi dictionary with no 'engines' entry, or an empty/None one, prints
    nothing.  This matches feat_dpi_str(), which tolerates the same input,
    instead of raising KeyError/TypeError.
    """

    for engine in dpi.get('engines') or []:
        app_name = engine['app-name']
        proto_name = engine['proto-name']
        app_type = engine['type']

        if dpi_name_ok(app_name) or dpi_name_ok(proto_name) or dpi_name_ok(app_type):
            print("  %-*s %*s" % (col1, "App", col2, app_name))
            print("    %-*s %*s" % (col1-2, "L5 Proto", col2, proto_name))
            print("    %-*s %*s" % (col1-2, "Type", col2, app_type))


#
# Firewall feature string for use in the table output format
#
def feat_fw_str(fw):
    """Return the firewall part of the table 'features' column: "fw", or
    "fw:<rule name>/<rule number>" when the json includes a rule."""

    if 'rule' not in fw:
        return "fw"

    fw_rule = fw['rule']
    return "fw" + ":%s/%u" % (fw_rule['name'], fw_rule['number'])


#
# Return a short features string for the table output format
#
def sess_feat_str(sess):
    """Return a short 'features' string for the table output format.

    Returns "-" when the session has no npf feature entry.  Otherwise the
    dpi, firewall, nat64, nat and alg strings that apply are joined with
    "; " in that order.  "other" is returned when none apply.

    """

    feat = sess_feature_npf(sess)
    if not feat:
        return "-"

    # dpi is the most interesting, so it is listed first
    renderers = (
        ('dpi', feat_dpi_str),
        ('firewall', feat_fw_str),
        ('nat64', feat_nat64_str),
        ('nat', feat_nat_str),
    )
    parts = [render(feat[key]) for key, render in renderers if key in feat]

    # The ALG tree also needs this session's ID
    if 'alg' in feat:
        parts.append(feat_alg_str(feat['alg'], sess['id']))

    # Catch-all for session types we do not handle yet
    return "; ".join(parts) if parts else "other"


#
# Display feature information for the detailed output format
#
def sess_feat_detail(sess, col1, col2):
    """Print the per-feature sections of the detailed output format, in the
    order firewall, dpi, nat64, nat, alg.  Prints nothing if the session has
    no npf feature entry."""

    feat = sess_feature_npf(sess)
    if not feat:
        return

    sections = (
        ('firewall', feat_firewall_detail),
        ('dpi', feat_dpi_detail),
        ('nat64', feat_nat64_detail),
        ('nat', feat_nat_detail),
        ('alg', feat_alg_detail),
    )

    for key, show in sections:
        if key in feat:
            show(feat[key], col1, col2)


#
# npf session flags string.  Return a short flags string for the detailed
# output format
#
def sess_flags_str(sess):
    """Return a short, comma-separated description of the npf session flags
    for the detailed output format.

    Returns "-" when the session has no npf feature entry or no flag
    applies.  A negative time_to_expire is also reported as "expired".

    """

    feat = sess_feature_npf(sess)
    if not feat:
        return "-"

    flags = feat['flags']
    words = []

    # SE_ACTIVE is reported when it is *clear*
    if not flags & SE_ACTIVE:
        words.append("!active, ")

    if flags & SE_PASS:
        words.append("pass, ")

    if flags & SE_EXPIRE or sess['time_to_expire'] < 0:
        words.append("expired, ")

    for bit, word in ((SE_SECONDARY, "secondary, "),
                      (SE_LOCAL_ZONE_NAT, "local zone nat, "),
                      (SE_IF_DISABLED, "intf disabled, "),
                      (SE_NAT_PINHOLE, "nat pinhole, ")):
        if flags & bit:
            words.append(word)

    joined = "".join(words)

    # Drop the trailing ", " left by the last entry
    return joined[:-2] if joined else "-"


#
# Determine direction from the npf session flags and return a string
#
def sess_in_or_out(sess, short):
    """Return the session direction from the npf flags: "In"/"Out" (or "I"/
    "O" when 'short' is set).  PFIL_IN takes precedence over PFIL_OUT.
    Returns "-" when there is no npf entry or neither flag is set."""

    feat = sess_feature_npf(sess)

    if feat:
        flags = feat['flags']
        if flags & PFIL_IN:
            return "I" if short else "In"
        if flags & PFIL_OUT:
            return "O" if short else "Out"

    return "-"


#
# Initial column widths
#
def init_col_widths(ctx):
    """Set the initial widths of the table columns that can grow to fit
    their contents: ID, src/dst address, interface, and pkts out/in."""
    ctx.update(idcol=5, addrcol=15, ifcol=8, outcol=6, incol=6)


#
# Increase the column widths?  Returns True if any are increased.
#
def check_col_widths(sess, ctx):
    """Widen any dynamic column that is too narrow for this session.

    Updates the widths in 'ctx' in place and returns True if at least one
    column was widened, else False.  The source and destination addresses
    share the single 'addrcol' width.
    """

    def required_widths():
        # Yielded lazily so each field is only read when it is checked
        yield 'idcol', len(str(sess['id']))
        yield 'ifcol', len(sess['interface'])
        yield 'addrcol', len(sess['src_addr'])
        yield 'addrcol', len(sess['dst_addr'])
        yield 'outcol', len(str(sess['counters']['packets_out']))
        yield 'incol', len(str(sess['counters']['packets_in']))

    widened = False
    for column, needed in required_widths():
        if needed > ctx[column]:
            ctx[column] = needed
            widened = True

    return widened


#
# Session table header and entry format strings
#
# Strings and addresses are left-justified (eg. %-8s), and numbers are
# right-justified.
#
# Anywhere a left-justified column follows a right-justified column, or where
# a left-justified column follows a variable width column, we leave two
# spaces.  This gives a nice trade-off between efficient spacing and
# readability.
#
def sess_op_fmt():
    """Return (header_format, entry_format) for the session table.

    Both formats share one column layout.  They differ only in the numeric
    columns (ID, ports, timeout, packet counts), which are '%s' in the
    header (text labels) and '%d' in the entries.
    """

    layout = ("%-*{n}  %-*s %5{n}  %-*s %5{n}  %-*s %1s  %-8s %-5s "
              "%7{n} %*{n} %*{n}  %-s")
    return layout.format(n="s"), layout.format(n="d")


#
# Show banner
#
# Some of the columns have a dynamic width.  i.e. the width will
# increase to accommodate the size of the session element being
# displayed.  The 'ctx' dictionary holds these widths.
#
# Columns are:
#  ID
#  Source Addr
#  Source Port
#  Dest Addr
#  Dest Port
#  Interface
#  Direction (D)
#  Protocol
#  State
#  Timeout
#  Packets out
#  Packets in
#  Features
#
def sess_op_show_banner(ctx, fmt):
    """Print the table column headings using the current widths in 'ctx'.

    The two port columns have empty headings; they sit under the
    "Source" and "Destination" headings.
    """

    id_w, addr_w, if_w = ctx['idcol'], ctx['addrcol'], ctx['ifcol']
    out_w, in_w = ctx['outcol'], ctx['incol']

    headings = (id_w, "ID",
                addr_w, "Source", "",
                addr_w, "Destination", "",
                if_w, "Intf",
                "D", "Proto", "State", "Timeout",
                out_w, "PktOut",
                in_w, "PktIn",
                "Features")
    print(fmt % headings)


#
# Display one session for table output format
#
def sess_op_show_one(sess, ctx, fmt):
    """Print one session as a single row of the table output format.

    'fmt' is the entry format string from sess_op_fmt(); 'ctx' holds the
    current widths of the dynamic columns.
    """

    features = sess_feat_str(sess)

    row = [ctx['idcol'], sess['id'],
           ctx['addrcol'], sess['src_addr'], sess['src_port'],
           ctx['addrcol'], sess['dst_addr'], sess['dst_port'],
           ctx['ifcol'], sess['interface'],
           sess_in_or_out(sess, True),
           num2proto(sess['proto']),
           state2str(sess['gen_state'], True),
           sess['time_to_expire'],
           ctx['outcol'], sess['counters']['packets_out'],
           ctx['incol'], sess['counters']['packets_in'],
           features]

    print(fmt % tuple(row))


#
# Show one session in detail
#
def sess_op_show_one_detail(sess):
    """Display one session in the detailed (multi-line) output format.

    Output, in order:
      - a summary line with the session ID, state and npf flags;
      - a two-column table of interface/direction, protocol, source and
        destination address and port, timeout, and packet/byte counters;
      - any per-feature detail from sess_feat_detail();
      - a blank line.
    """

    print("Session ID: %d, State: %s, Flags: %s" %
          (sess['id'], state2str(sess['gen_state'], False),
           sess_flags_str(sess)))

    timeout = "%d/%d" % (sess['time_to_expire'],
                         sess['state_expire_window'])
    proto = "%s (%d)" % (num2proto(sess['proto']), sess['proto'])
    counts = sess['counters']

    intf_dir = "%s %s" % (sess['interface'], sess_in_or_out(sess, False))

    col1 = 28

    # Column 2 is right-justified.  It is at least 24 wide, and is widened
    # to fit the longest address or byte count so the values stay aligned.
    col2 = max(24,
               len(sess['src_addr']),
               len(sess['dst_addr']),
               len(str(counts['bytes_out'])),
               len(str(counts['bytes_in'])))

    print("  %-*s %*s" % (col1, "Interface", col2, intf_dir))
    print("  %-*s %*s" % (col1, "Protocol", col2, proto))
    print("  %-*s %*s" % (col1, "Source address", col2, sess['src_addr']))
    print("  %-*s %*s" % (col1, "  port", col2, sess['src_port']))
    print("  %-*s %*s" % (col1, "Destination address", col2, sess['dst_addr']))
    # Destination port (previously printed the source port by mistake)
    print("  %-*s %*s" % (col1, "  port", col2, sess['dst_port']))
    print("  %-*s %*s" % (col1, "Timeout", col2, timeout))

    print("  %-*s %*d" % (col1, "Out, packets", col2, counts['packets_out']))
    print("  %-*s %*d" % (col1, "     bytes", col2, counts['bytes_out']))

    print("  %-*s %*d" % (col1, "In,  packets", col2, counts['packets_in']))
    print("  %-*s %*d" % (col1, "     bytes", col2, counts['bytes_in']))

    # Features
    sess_feat_detail(sess, col1, col2)
    print()


#
# Display unordered sessions
#
# Fetches batches of sessions from the dataplane using a 'start' and 'count'
# value, where the 'start' value is simply the iteration count through the
# hash table.
#
def sess_op_show_unordered(ctx, batch_size):
    """Display sessions in the order they are returned from the dataplane
    hash table.

    Sessions are fetched in batches of at most 'batch_size' using a 'start'
    value (the number of sessions already iterated in the hash table) and a
    'count' value.  If ctx['count'] is non-zero, at most that many sessions
    are displayed.
    """

    base_cmd = base_show_cmd

    #
    # If neither 'ip' nor 'ip6' is in the command then the dataplane will
    # return sessions belonging to both address families.  'ip' or 'ip6' is
    # *only* specified when we want sessions for that specific address family.
    #
    if ctx['ip'] and not ctx['ip6']:
        base_cmd += " ip"
    elif ctx['ip6'] and not ctx['ip']:
        base_cmd += " ip6"

    base_cmd += cmd_option_string(ctx)

    detail = ctx['detail']

    if not detail:
        # Get format strings
        hfmt, efmt = sess_op_fmt()

        # Set initial column widths
        init_col_widths(ctx)

        # Initial banner output
        print(state_banner())

    reqd_count = ctx['count']
    start = 0
    sess_count = 0

    while not reqd_count or sess_count < reqd_count:

        # Do not fetch more than we need in this batch
        if reqd_count and (reqd_count - sess_count) < batch_size:
            batch_size = reqd_count - sess_count

        cmd = base_cmd + " start %d count %d" % (start, batch_size)

        sess_list = get_sessions(cmd)
        if not sess_list:
            break

        for sess in sess_list:
            # Never display more than requested, even if the dataplane
            # returned more sessions than were asked for.
            if reqd_count and sess_count >= reqd_count:
                break

            if detail:
                sess_op_show_one_detail(sess)
            else:
                if (sess_count % 40) == 0:
                    #
                    # On the very first session and every 40 sessions,
                    # let the columns shrink back to their initial widths.
                    # Re-check them against *this* session before printing
                    # the banner, so the banner and row use widths that fit.
                    #
                    init_col_widths(ctx)
                    check_col_widths(sess, ctx)
                    sess_op_show_banner(ctx, hfmt)
                elif check_col_widths(sess, ctx):
                    # A column widened; re-print the banner to match
                    sess_op_show_banner(ctx, hfmt)

                sess_op_show_one(sess, ctx, efmt)

            sess_count += 1

        start += len(sess_list)


#
# Sort list of sessions
#
def sort_sessions(sess_list, order, orderby):
    """Return a new list of the sessions sorted on 'orderby'.

    'order' is "descending" for a reverse sort; anything else sorts in
    ascending order.  A secondary key breaks ties:

      Primary sort      Secondary sort
      --------------    --------------
      source address    source port
      dest address      dest port
      trans address     trans port
      timeout           session ID
      session ID        n/a (IDs are unique)

    An unrecognised 'orderby' should never happen; if it does, the sort
    falls back to source address and port.
    """

    tie_breakers = {
        'src_addr': 'src_port',
        'dst_addr': 'dst_port',
        'trans_addr': 'trans_port',
        'timeout': 'id',
        'id': None,
    }

    if orderby not in tie_breakers:
        orderby = 'src_addr'
    orderby2 = tie_breakers[orderby]

    def primary(sess):
        # Address strings must be converted to IPAddress objects before
        # they can be compared.
        value = sess[orderby]
        return IPAddress(value) if orderby_is_addr(orderby) else value

    if orderby2 is None:
        key = primary
    else:
        def key(sess):
            return primary(sess), sess[orderby2]

    return sorted(sess_list, key=key, reverse=(order == "descending"))


#
# Fetch and print sessions in batches, using the given ordered list as a guide
#
def sess_op_show_ordered(ctx, item_list, af, batch_size):
    """Fetch and display sessions in batches, using 'item_list' as a guide.

    ctx        -- command context (order, orderby, count, detail, and the
                  dynamic column widths)
    item_list  -- sorted list of sort-key values; get_start_end_vals()
                  derives each batch's 'start'/'end' bounds from it
    af         -- 'ip', 'ip6', or None for both address families
    batch_size -- approximate number of sessions to fetch per batch

    Each batch is sorted locally with sort_sessions() before display.  If
    ctx['count'] is non-zero and smaller than the list, at most that many
    sessions are shown.
    """

    if not item_list or not ctx['order'] or not ctx['orderby']:
        return

    base_cmd = base_show_cmd

    if af:
        base_cmd += " %s" % (af)

    base_cmd += cmd_option_string(ctx)

    detail = ctx['detail']

    if not detail:
        # Get format strings
        hfmt, efmt = sess_op_fmt()

        # Set initial column widths
        init_col_widths(ctx)

        # Initial banner output
        print(state_banner())

    if ctx['count'] > 0 and ctx['count'] < len(item_list):
        reqd_count = ctx['count']
    else:
        reqd_count = len(item_list)

    index = 0
    sess_count = 0

    while index < reqd_count and sess_count < reqd_count:

        # Do not fetch more than we need in this batch
        if (reqd_count - sess_count) < batch_size:
            batch_size = reqd_count - sess_count

        # Get start and end values for next batch
        index, start, end = get_start_end_vals(index, item_list, batch_size)
        # NOTE(review): a falsy bound such as 0 also ends the loop here;
        # confirm get_start_end_vals() never returns 0 as a valid bound
        # (e.g. a timeout of 0).
        if not start or not end:
            break

        cmd = base_cmd + " start %s end %s" % (start, end)

        # Get batch of sessions
        sess_list = get_sessions(cmd)

        # Nothing to show for this range.  A None result would otherwise
        # crash the fixup/sort below.  'index' has already been moved on by
        # get_start_end_vals(), so carry on with the next batch.
        if not sess_list:
            continue

        # Fixup required if ordering by NAT translation address
        if ctx['orderby'] == 'trans_addr':
            sess_list = orderby_trans_addr_fixup(sess_list)

        # Sort sessions
        sess_list = sort_sessions(sess_list, ctx['order'], ctx['orderby'])

        # Display the sessions
        for sess in sess_list:

            if detail:
                sess_op_show_one_detail(sess)
            else:
                #
                # Check column widths before displaying banner.
                # Display banner if column widths change or if this is
                # very first session or every 40 sessions.
                #
                if check_col_widths(sess, ctx) or (sess_count % 40) == 0:
                    sess_op_show_banner(ctx, hfmt)

                sess_op_show_one(sess, ctx, efmt)

            sess_count += 1
            if sess_count >= reqd_count:
                break


#
# Show Summary
#
def sess_op_show_summary():
    """Fetch the session summary from the dataplane and print it.

    Only the first dataplane returned by the controller is queried.  Nothing
    is printed if no reply is received or the reply has no 'summary' object.
    """

    cmd = base_show_cmd + " summary"

    summary = None

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                summary = dp.json_command(cmd)
                break

    if not summary or 'summary' not in summary:
        return

    # Named 'stats' rather than 'sum' to avoid shadowing the builtin
    stats = summary['summary']

    col1 = 20
    col2 = 14

    def heading(label):
        # Section heading, indented two spaces
        print("  %-*s" % (col1 - 2, label))

    def row(label, value):
        # Top-level row: two-space indent, value right-justified in col2
        print("  %-*s%*d" % (col1 - 2, label, col2, value))

    def subrow(label, value):
        # Nested row: four-space indent, values aligned with row()
        print("    %-*s%*d" % (col1 - 4, label, col2, value))

    print("%-*s" % (col1, "Dataplane sessions"))

    row("Total", stats['total'])

    heading("Address family:")
    subrow("IP", stats['address-family']['ip'])
    subrow("IPv6", stats['address-family']['ip6'])

    heading("Direction:")
    subrow("In", stats['direction']['in'])
    subrow("Out", stats['direction']['out'])

    for name in ('TCP', 'UDP', 'Other'):
        proto = stats['protocol'][name.lower()]
        row(name, proto['total'])
        subrow("Closed", proto['closed'])
        subrow("Opening", proto['opening'])
        subrow("Established", proto['established'])
        if name == 'TCP':
            subrow("Closing", proto['closing'])

    heading("Feature:")
    feature = stats['feature']
    for label, key in (("Firewall", 'other'),
                       ("DNAT", 'dnat'),
                       ("SNAT", 'snat'),
                       ("ALG", 'alg'),
                       ("NAT64", 'nat64'),
                       ("NAT46", 'nat46'),
                       ("App", 'app')):
        subrow(label, feature[key])


#
# Show
#
def sess_op_show(ctx):
    """Handle the 'show' command.

    Shows the summary if requested.  Otherwise sessions are shown ordered
    by an address (one pass per requested address family), ordered by
    another key (ID or timeout), or unordered.
    """

    if ctx['summary']:
        sess_op_show_summary()
        return

    orderby = ctx['orderby']

    if orderby_is_addr(orderby):
        # When ordering by address, v4 and v6 are handled separately
        for family in ("ip", "ip6"):
            if ctx[family]:
                items = get_item_list(ctx, family)
                sess_op_show_ordered(ctx, items, family, batch_size)
        return

    if not orderby:
        sess_op_show_unordered(ctx, batch_size)
        return

    #
    # Ordering by ID or timeout.  A single address family is passed only
    # when exactly one of 'ip'/'ip6' was requested.  With neither or both,
    # no family is given and the dataplane returns list items or sessions
    # for both ip and ip6.
    #
    want_ip, want_ip6 = ctx['ip'], ctx['ip6']
    if want_ip and not want_ip6:
        family = 'ip'
    elif want_ip6 and not want_ip:
        family = 'ip6'
    else:
        family = None

    items = get_item_list(ctx, family)
    sess_op_show_ordered(ctx, items, family, batch_size)


#
# Clear dataplane sessions
#
def sess_op_clear(ctx):
    """Clear dataplane sessions.

    The clear command is built from base_clear_cmd, an optional address
    family and the options from cmd_option_string().  It is then sent to
    each dataplane in turn.  If a command fails, the function returns
    without trying the remaining dataplanes.
    """

    cmd = base_clear_cmd

    #
    # If an address family has not been specified (or both 'ip' and 'ip6' are
    # specified) then both ip and ip6 sessions will be cleared.  We only
    # specifically add 'ip' or 'ip6' to the command string when we want to
    # clear sessions belonging to just that address family.
    #
    if ctx['ip'] and not ctx['ip6']:
        cmd += " ip"
    elif ctx['ip6'] and not ctx['ip']:
        cmd += " ip6"

    cmd += cmd_option_string(ctx)

    with vplaned.Controller() as controller:
        for dp in controller.get_dataplanes():
            with dp:
                try:
                    dp.string_command(cmd)
                except Exception:
                    # Likely a zmq timeout occurred.  However zmq exceptions
                    # are not translated back to vplaned exceptions, so just
                    # return.  'Exception' rather than a bare except so that
                    # KeyboardInterrupt/SystemExit still propagate.
                    return


#
# Parse options and call show or clear commands
#
def sess_op_main():
    """Main function.

    Parses '--show' and/or '--clear' from the command line.  Everything
    after them is passed to sess_op_parse_options().  Then either the show
    or the clear command is run.  '--show' takes precedence if both are
    given; if neither is given, nothing is run.  Exits with status 2 on an
    option error.
    """

    show = False
    clear = False

    #
    # Parse options
    #
    try:
        opts, args = getopt.getopt(sys.argv[1:],
                                   "", ['show', 'clear'])

    except getopt.GetoptError as err:
        print(err, file=sys.stderr)
        sys.exit(2)

    for opt, _ in opts:
        # Exact comparison; 'opt in "--show"' would be a substring test
        if opt == '--show':
            show = True
        elif opt == '--clear':
            clear = True

    ctx = create_ctx(show)

    # getopt stops at the first non-option argument and returns the rest
    # in 'args'.  These are the show/clear options, however many of
    # '--show'/'--clear' preceded them.
    error_str = sess_op_parse_options(args, ctx)
    if error_str:
        print(error_str, file=sys.stderr)
        sys.exit(2)

    if show:
        sess_op_show(ctx)

    elif clear:
        sess_op_clear(ctx)


#
# main
#
if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module
    sess_op_main()
