#!/usr/bin/python3
#
# Copyright (c) 2019, AT&T Intellectual Property. All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only


"""vyatta_static_host_mapping: commit script to update /etc/hosts or
   /run/dns/vrf/<vrf>/hosts file based on current configuration"""

import os
import os.path
import re
import tempfile
from argparse import ArgumentParser
from vyatta import configd

# Accepted names are two or more characters long, drawn from letters, digits,
# '.' and '-', and both begin and end with a letter or digit. \Z pins the
# match to the very end of the string, so a trailing newline is rejected.
HOST_RE = re.compile(r'[a-zA-Z0-9][a-zA-Z0-9.-]*[a-zA-Z0-9]\Z')


def is_valid_host(pat):
    """Check pat against HOST_RE.

    Returns the match object when the whole of pat is an acceptable host
    name, otherwise None, so the result can be used directly as a truth
    value.
    """
    matched = HOST_RE.match(pat)
    return matched


def create_hosts_dir(host_file):
    """Ensure the directory that will hold host_file exists.

    Missing directories are created (mode 0o755 is requested for the leaf
    directory). An already existing directory is left untouched.

    Args:
        host_file: path of the hosts file; only its directory part is used.

    Raises:
        FileExistsError: the directory path exists but is not a directory.
    """
    dpath = os.path.dirname(host_file)
    if not dpath:
        # A bare file name lives in the current directory, which already
        # exists; os.makedirs('') would raise FileNotFoundError.
        return

    # exist_ok=True tolerates an existing directory but still raises
    # FileExistsError when the path is occupied by a non-directory.
    os.makedirs(dpath, 0o755, exist_ok=True)

def get_cfg_mapping(cfg, vrf):
    """Read the static host mapping from the candidate configuration.

    The tree is looked up under 'system static-host-mapping'; for a vrf
    other than 'default' that path is nested below
    'routing routing-instance <vrf>'.

    Returns the 'host-name' dict found under 'static-host-mapping', or an
    empty dict when configd raises or either key is missing.
    """
    base = 'system static-host-mapping'
    in_default = not vrf or vrf == 'default'
    lookup = base if in_default else 'routing routing-instance {} {}'.format(vrf, base)

    try:
        tree = cfg.tree_get_dict(lookup, cfg.CANDIDATE, 'internal')
        return tree['static-host-mapping']['host-name']
    except (configd.Exception, KeyError):
        return {}



def process_one_entry(name, entry):
    """Create the hosts-file line for one host.

    name is the host name and entry its config dict: entry['inet'] supplies
    the address and the optional entry['alias'] an iterable of aliases.
    Aliases rejected by is_valid_host are reported on stdout and left out.

    Returns '' (after printing a notice) when name itself is invalid,
    otherwise '<inet>\\t<name>[ <aliases>] #vyatta entry\\n'.
    Raises KeyError if entry has no 'inet' key.
    """
    if not is_valid_host(name):
        print("ignoring invalid name {}".format(name))
        return ''

    address = entry['inet']

    # Sort every alias into exactly one of the two buckets in a single pass.
    accepted, rejected = [], []
    for alias in entry.get('alias') or []:
        bucket = accepted if is_valid_host(alias) else rejected
        bucket.append(alias)

    if rejected:
        print("Ignoring invalid aliases '{}' for host {}".format(','.join(rejected), name))

    if accepted:
        return "{}\t{} {} #vyatta entry\n".format(address, name, ' '.join(accepted))

    return "{}\t{} #vyatta entry\n".format(address, name)


def update_hosts(host_file, static_host_map):
    """Update hosts file from map.

    Lines of the existing file that end with the '#vyatta entry' marker are
    dropped, all other lines are kept in order, and one line per valid
    configured host is appended. The result is written to a temporary file
    in the same directory, set to mode 0o644 and renamed over host_file.

    Args:
        host_file: path of the hosts file to regenerate; it need not exist.
        static_host_map: dict of host name -> entry dict, each pair being
            passed to process_one_entry.

    Raises:
        OSError: the existing file cannot be read (other than not existing)
            or the replacement cannot be written, chmod-ed or renamed. The
            temporary file is removed before the error propagates.
    """
    marker = "#vyatta entry"
    create_hosts_dir(host_file)

    kept = []
    try:
        # Decode with the same codec used for writing below, independent of
        # the process locale; surrogateescape round-trips undecodable bytes.
        with open(host_file, encoding="utf-8", errors="surrogateescape") as hfile:
            # rstrip so a marker on an unterminated last line is dropped too.
            kept = [line for line in hfile
                    if not line.rstrip("\n").endswith(marker)]
    except FileNotFoundError:
        pass

    # Terminate a dangling last line so the first new entry is not glued
    # onto it.
    if kept and not kept[-1].endswith("\n"):
        kept[-1] += "\n"

    # Render everything before the temporary file exists, so a failure here
    # leaves nothing behind on disk.
    rendered = (process_one_entry(host, entry)
                for host, entry in static_host_map.items())
    kept.extend(line for line in rendered if line)
    payload = "".join(kept).encode("utf-8", "surrogateescape")

    temp = tempfile.NamedTemporaryFile(dir=os.path.dirname(host_file), delete=False)
    try:
        with temp:
            temp.write(payload)
        os.chmod(temp.name, 0o644)
        # Same directory as host_file, so the rename stays on one filesystem.
        os.rename(temp.name, host_file)
    except BaseException:
        # Do not leave a stray temporary file next to the hosts file.
        try:
            os.unlink(temp.name)
        except OSError:
            pass
        raise


def main():
    """Parse the command line and regenerate the hosts file.

    Options:
        -r/--vrf  routing instance name; absent, bare or 'default' selects
                  /etc/hosts, any other name selects
                  /run/dns/vrf/<name>/hosts.
        -f/--dir  prefix prepended to the hosts file path (default '/').

    After the hosts file has been rewritten from the candidate
    configuration, /opt/vyatta/sbin/vyatta_update_syslog.pl is run; its
    exit status is not checked.
    """
    parser = ArgumentParser(description="Vyatta system static_host_mapping end script")
    parser.add_argument('-r', '--vrf', nargs='?', const='default', help='vrfname')
    # With nargs='?' a bare "-f" yields const; without an explicit const that
    # is None, which would break the string concatenation below.
    parser.add_argument('-f', '--dir', nargs='?', const='/', default='/',
                        help='root directory prepended to the hosts file path')

    args = parser.parse_args()

    host_dir = '/etc'
    if args.vrf and args.vrf != 'default':
        host_dir = '/run/dns/vrf/' + args.vrf

    host_file = os.path.normpath(args.dir + host_dir + '/' + 'hosts')

    cfg = configd.Client()
    static_host_map = get_cfg_mapping(cfg, args.vrf)
    update_hosts(host_file, static_host_map)
    os.system('/opt/vyatta/sbin/vyatta_update_syslog.pl')


# Run main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
