#!/usr/bin/env python3

# Copyright (c) 2021, AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: LGPL-2.1-only

from argparse import ArgumentParser
import socket
from vyatta import configd
import sys

# Values accepted for the positional "function_choice" argument.
func_arp = "arp"
func_arp_all = "arp-all"

# Values stored under the "source" key of the RPC input: dataplane entries
# or kernel (control-plane) entries.
arp_dp = "dataplane"
arp_kernel = "control-plane"

# Keys of the input dictionary handed to the "get-arp" RPC.
get_arp_type_key = "source"
get_arp_intf_key = "ifname"
get_arp_addr_key = "ip"

# Command-line interface. Arguments are parsed as soon as this file is
# executed, before any of the functions below are defined.
arg_parser = ArgumentParser()
arg_parser.add_argument(
    "function_choice",
    choices=[func_arp, func_arp_all],
    help="ARP function",
)
arg_parser.add_argument(
    "--show-intf",
    required=False,
    help="Interface name",
)
arg_parser.add_argument(
    "--addr",
    required=False,
    help="IP address",
)
args = arg_parser.parse_args()


def get_interfaces():
    """
    Returns a list of network interfaces in the system. Excludes hidden
    interfaces starting with '.' and vrf interfaces (names starting with
    'vrf').

    The names are taken from socket.if_nameindex() and returned in the
    order that call reports them.
    """

    # A name is excluded when it starts with either prefix. The previous
    # "not vrf or not hidden" test was true for every name, so nothing was
    # ever filtered out.
    excluded_prefixes = ("vrf", ".")
    return [name for _, name in socket.if_nameindex()
            if not name.startswith(excluded_prefixes)]


def show_arp(intf, addr):
    """
    Show dataplane ARP entries.

    Entries are retrieved with the "get-arp" RPC of "vyatta-arp-v1", using
    the dataplane as source.

    intf -- name of the interface to restrict the query to, or None to
            query without an interface filter.
    addr -- IP address to restrict the query to, or None. It is only used
            when intf is given as well.

    When both intf and addr are given, each entry is printed in a detailed
    multi-line form; otherwise the entries are printed as a table with a
    header line.

    Exits with status 1, after printing a message to stderr, when intf is
    not one of the interfaces returned by get_interfaces() or when the
    entries cannot be retrieved.
    """

    output_format = "{:20} {:10} {:17} {}"
    input_arg = {get_arp_type_key: arp_dp}

    if intf is not None:
        if intf not in get_interfaces():
            # Errors go to stderr, like the RPC failure message below.
            print("interface " + intf + " does not exist on system",
                  file=sys.stderr)
            sys.exit(1)
        input_arg[get_arp_intf_key] = intf
        if addr is not None:
            input_arg[get_arp_addr_key] = addr

    # The detailed form is only used for an interface + address query.
    detailed = (intf is not None) and (addr is not None)

    cfg = configd.Client()

    try:
        data = cfg.call_rpc_dict("vyatta-arp-v1", "get-arp", input_arg)
        arp_list = data["arp-entry-list"]
    except Exception as e:
        print("Retrieving ARP entries failed: \n {}".format(e), file=sys.stderr)
        sys.exit(1)

    if not detailed:
        # Printed only once the entries are available, so that a failed
        # query does not leave a header without a table.
        print(output_format.format("IP Address", "Flags", "HW address", "Device"))

    for entry in arp_list:
        if not detailed:
            print(output_format.format(
                entry["ip"], entry["flags"], entry["hwaddr"], entry["ifname"]))
            continue

        print("{} {}".format(entry["ip"], entry["ifname"]))
        print("    Flags: {}".format(entry["flags"]))
        print("    HW Address: {}".format(entry["hwaddr"]))
        # The platform state is optional in an entry.
        if "platform_state" in entry:
            print("    Platform state:")
            print(entry["platform_state"])


def show_arp_all(intf, addr):
    """
    Show both kernel and dataplane ARP entries.

    Both sets of entries are retrieved with the "get-arp" RPC of
    "vyatta-arp-v1", once with the dataplane and once with the
    control-plane as source, and merged into one table with the columns
    IP address, HW address, dataplane flags, controller flags and device.

    intf -- name of the interface to restrict the queries to, or None to
            query without an interface filter.
    addr -- IP address to restrict the queries to, or None. It is only
            used when intf is given as well.

    Exits with status 1, after printing a message to stderr, when intf is
    not one of the interfaces returned by get_interfaces() or when either
    query fails.
    """
    zero_mac = "00:00:00:00:00:00"
    output_format = "{:18} {:17} {:10} {:10} {}"
    input_arg_dp = {get_arp_type_key: arp_dp}
    input_arg_kernel = {get_arp_type_key: arp_kernel}

    if intf is not None:
        # Validate the interface before building either filter.
        if intf not in get_interfaces():
            print("interface " + intf + " does not exist on system",
                  file=sys.stderr)
            sys.exit(1)
        for input_arg in (input_arg_dp, input_arg_kernel):
            input_arg[get_arp_intf_key] = intf
            if addr is not None:
                input_arg[get_arp_addr_key] = addr

    cfg = configd.Client()

    try:
        data = cfg.call_rpc_dict("vyatta-arp-v1", "get-arp", input_arg_dp)
        arp_list_dp = data["arp-entry-list"]
    except Exception as e:
        print("show arp failed : ", e, file=sys.stderr)
        sys.exit(1)

    try:
        data = cfg.call_rpc_dict(
            "vyatta-arp-v1", "get-arp", input_arg_kernel)
        arp_list_kernel = data["arp-entry-list"]
    except Exception as e:
        print("show arp failed : ", e, file=sys.stderr)
        sys.exit(1)

    # NOTE(review): kernel entries are indexed by IP address only. If the
    # same address can show up on more than one interface, later entries
    # replace earlier ones here - confirm whether the key should also
    # include the interface name.
    arp_dict_kernel = {value["ip"]: value for value in arp_list_kernel}

    print(output_format.format("IP Address", "HW address",
                               "Dataplane", "Controller", "Device"))

    for entry in arp_list_dp:
        kentry = arp_dict_kernel.get(entry["ip"])
        # A kernel entry is shown on the same line as the dataplane entry
        # when its hardware address is the same or is all zeroes.
        if kentry is not None and kentry["hwaddr"] in (entry["hwaddr"], zero_mac):
            controller_flags = kentry["flags"]
            # Consumed, so it is not printed again by the loop below.
            del arp_dict_kernel[entry["ip"]]
        else:
            controller_flags = ""
        print(output_format.format(
            entry["ip"], entry["hwaddr"], entry["flags"], controller_flags,
            entry["ifname"]))

    # Whatever is left was not paired with a dataplane entry.
    for ip, kentry in arp_dict_kernel.items():
        print(output_format.format(
            ip, kentry["hwaddr"], "", kentry["flags"], kentry["ifname"]))


# Pick the handler that matches the requested function and run it with the
# optional interface and address filters.
handlers = {
    func_arp_all: show_arp_all,
    func_arp: show_arp,
}
handler = handlers.get(args.function_choice)
if handler is None:
    print("Incorrect function choice")
else:
    handler(args.show_intf, args.addr)
