#!/usr/bin/env python3

# Copyright (c) 2021, AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: LGPL-2.1-only

from vplaned import Controller
import logging
import subprocess
from re import search
from json import dumps, load
import sys
import logging.handlers
from systemd.journal import JournalHandler

# Accepted values of the "source" field in the JSON request read from stdin;
# they select get_arp_dp() or get_arp_kernel() respectively.
arp_dp = "dataplane"
arp_kernel = "control-plane"

# Keys looked up in the JSON request read from stdin.
arp_type_key = "source"
arp_intf_key = "ifname"
arp_addr_key = "ip"

# Helper executable run by get_arp_dp() with "--format-platform-state ip-neigh"
# for entries that carry a "platform_state" member.
PLATFORM_STATE_CMD = "/opt/vyatta/bin/vyatta-platform-util"


def get_arp_dp(intf, addr):
    """
    Get dataplane ARP entries.

    Sends "arp show [intf]" to every dataplane returned by the module-level
    ``controller`` and merges the replies.  A dataplane whose command raises
    is reported through the module-level ``logger`` and skipped.

    Arguments:
        intf: interface name appended to the "arp show" command; None or an
              empty string sends the command without an interface argument.
        addr: when not None, only entries whose "ip" equals this value are
              kept.

    Returns:
        A dict {"arp-entry-list": [...]}.  Each entry has the keys "ip",
        "hwaddr", "ifname" and "flags", plus "platform_state" when the
        dataplane reply for that entry contains a "platform_state" member.
    """
    intf_cmd = intf if intf else ""
    arp_entries = []

    for dp in controller.get_dataplanes():
        with dp:
            try:
                data = dp.json_command("arp show {}".format(intf_cmd))
            except Exception as e:
                logger.error(
                    "Error with the command 'arp show {}' ".format(intf_cmd))
                logger.error(e)
                continue
            for entry in data["arp"]:
                if (addr is not None) and (addr != entry["ip"]):
                    continue
                arp_entry = {
                    "ip": entry["ip"],
                    # Left-pad every colon-separated field of the MAC to two
                    # characters (e.g. "0:1:2:a:b:c" -> "00:01:02:0a:0b:0c").
                    "hwaddr": ":".join(
                        field.zfill(2) for field in entry["mac"].split(":")),
                    "ifname": entry["ifname"],
                    "flags": entry["flags"],
                }
                if "platform_state" in entry:
                    # The whole entry is handed to the helper as JSON on
                    # stdin.  stderr has to be captured explicitly: without
                    # stderr=PIPE, CompletedProcess.stderr is None and a
                    # failing helper would produce a null "platform_state"
                    # instead of its error text.
                    platform_state_process = subprocess.run(
                        [PLATFORM_STATE_CMD, "--format-platform-state",
                         "ip-neigh"],
                        stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                        input=dumps(entry),
                        encoding='ascii', text=True)
                    if platform_state_process.returncode == 0:
                        arp_entry["platform_state"] = \
                            platform_state_process.stdout
                    else:
                        arp_entry["platform_state"] = \
                            platform_state_process.stderr
                arp_entries.append(arp_entry)

    return {"arp-entry-list": arp_entries}


def get_arp_kernel(intf, addr):
    """
    Get kernel ARP entries.

    Runs "ip -4 neigh" and converts every output line into an entry.

    Arguments:
        intf: when not None, only entries on this interface are kept.
        addr: when not None, only entries whose "ip" equals this value are
              kept.

    Returns:
        A dict {"arp-entry-list": [...]}.  Each entry has the keys "ip",
        "ifname", "hwaddr" and "flags".

    Exits the process with status 1 (after logging through the module-level
    ``logger``) when the "ip" command returns non-zero or when an output
    line matches none of the known formats.
    """

    zero_mac = "00:00:00:00:00:00"
    # Kernel neighbour state -> flag reported to the caller.
    kernel_flags = {
        "REACHABLE": "VALID",
        "STALE": "VALID",
        "DELAY": "VALID",
        "PROBE": "VALID",
        "PERMANENT": "STATIC",
        "NOARP": "STATIC"
    }

    # Patterns are assembled once here instead of on every loop iteration.
    resolved_pattern = (
        "([^ ]+) dev ([^ ]+) lladdr ([^ ]+) (" + "|".join(kernel_flags) + ")")
    # Lines without a link-layer address: (pattern, flag reported).
    unresolved_patterns = (
        ("([^ ]+) dev ([^ ]+)  FAILED", "FAILED"),
        ("([^ ]+) dev ([^ ]+)  INCOMPLETE", "PENDING"),
    )

    arp_output = subprocess.run(
        ["ip", "-4", "neigh"], stdout=subprocess.PIPE, text=True)
    if arp_output.returncode != 0:
        logger.error("Getting kernel ARP failed")
        sys.exit(1)

    arp_entries = []

    for line in arp_output.stdout.splitlines():
        arp_entry = None

        match = search(resolved_pattern, line)
        if match:
            arp_entry = {
                "ip": match.group(1),
                "ifname": match.group(2),
                "hwaddr": match.group(3),
                "flags": kernel_flags[match.group(4)],
            }
        else:
            for pattern, flags in unresolved_patterns:
                match = search(pattern, line)
                if match:
                    arp_entry = {
                        "ip": match.group(1),
                        "ifname": match.group(2),
                        "hwaddr": zero_mac,
                        "flags": flags,
                    }
                    break

        if arp_entry is None:
            logger.error("Unknown kernel ARP format")
            # Record the offending line so the failure can be diagnosed
            # from the journal.
            logger.error(line)
            sys.exit(1)

        if (intf is not None) and (intf != arp_entry["ifname"]):
            continue
        if (addr is not None) and (addr != arp_entry["ip"]):
            continue
        arp_entries.append(arp_entry)

    return {"arp-entry-list": arp_entries}


if __name__ == "__main__":
    # get_arp_dp() and get_arp_kernel() read the module-level name `logger`,
    # so it is bound here before either of them is called.  Records go to
    # the systemd journal under the identifier 'vyatta-op-arp'.
    logger = logging.getLogger()
    logging.root.addHandler(
        JournalHandler(SYSLOG_IDENTIFIER='vyatta-op-arp'))

    # The request is a JSON object on stdin; every key is optional and a
    # missing key yields None.
    try:
        request = load(sys.stdin)
        arp_type = request.get(arp_type_key)
        intf = request.get(arp_intf_key)
        addr = request.get(arp_addr_key)
    except Exception as e:
        # Top-level boundary: malformed JSON or a non-object document both
        # end up here and terminate the command.
        logger.error("Error parsing input. \n {}".format(e))
        sys.exit(1)

    if arp_type == arp_dp:
        # get_arp_dp() reads the module-level name `controller`.  The kernel
        # path never uses it, so the controller is only opened here.
        with Controller() as controller:
            print(dumps(get_arp_dp(intf, addr)))
    elif arp_type == arp_kernel:
        print(dumps(get_arp_kernel(intf, addr)))
    else:
        logger.error("Incorrect function choice {}.".format(
            arp_type))
        sys.exit(1)
