#!/usr/bin/python3
# Copyright (c) 2017-2019 AT&T Intellectual Property.
# All Rights Reserved.

# SPDX-License-Identifier: GPL-2.0-only

import os
import sys
import time
import socket
import getopt
import vici
import pprint
import json

from collections import OrderedDict
from vyatta import configd
from vyatta.vpn.ipsec.config import IPsecRAVPNServer, ClientProfile, IKEGroup, ESPGroup, get_cert_or_key, setup_vici, setup_dbus

# Runtime switches; process_options() flips them from the command line
# (-d/--debug and -f/--force).
DEBUG = False
FORCE = False
# PrettyPrinter instance used by dbg(); created on first use.
PP = None
# Database selectors passed to the configd client calls in this file.
CONFIG_CANDIDATE=configd.Client.CANDIDATE
CONFIG_RUNNING=configd.Client.RUNNING

def dbg(msg):
    """Pretty-print *msg* when debugging is enabled; otherwise do nothing.

    The shared PrettyPrinter is created on the first debug message and
    kept in the module-level PP for later calls.
    """
    global PP
    if DEBUG:
        if PP is None:
            PP = pprint.PrettyPrinter(indent=2)
        PP.pprint(msg)

def err(msg):
    """Write *msg* followed by a newline to standard error."""
    sys.stderr.write(str(msg) + "\n")

# Prefix of every connection name built by conn_name_of_profile(); also
# used to pick this script's connections out of the loaded ones when
# looking for stale entries.
CLIENT_NAMEPREFIX = "ipsec_ra_client"

# configd paths whose subtrees are fetched in vpn_main().
IPSEC_CFG_PATH = 'security vpn ipsec'
INTF_TUNNEL_CFG_PATH = 'interfaces tunnel'

def conn_name_of_profile(profile, tunnel_id=None):
    """Build the connection name for a client profile or one of its tunnels.

    Without a (truthy) tunnel_id the result is "<prefix>-<profile>";
    with one it is "<prefix>-<profile>-tunnel-<tunnel_id>".
    """
    profile_name = '{}-{}'.format(CLIENT_NAMEPREFIX, profile)
    if not tunnel_id:
        return profile_name

    return '{}-tunnel-{}'.format(profile_name, tunnel_id)

# pass config dictionary for the ipsec node and create
# vici connections and secrets.
#

def _running_tunnel_ids(ipsec_cfg_existing, profile_name):
    """Return the tunnel ids listed for *profile_name* in *ipsec_cfg_existing*.

    The 'profile' node is a list of dicts keyed by 'profile-name', so the
    profile has to be searched for; it cannot be indexed by name.
    Returns an empty list when the config is None, the profile is not
    present or it has no tunnels.
    """
    try:
        profiles = ipsec_cfg_existing['ipsec']['remote-access-client']['profile']
        for existing in profiles:
            if existing['profile-name'] == profile_name:
                return [t['tunnel-id'] for t in existing['tunnel']]
    except (KeyError, TypeError):
        # Missing config, missing nodes, or a profile without tunnels.
        pass
    return []


def ra_profile_changed(client, ipsec_cfg_existing, prof):
    """Work out which connections of one client profile need attention.

    Parameters:
      client             -- configd client used for node_get_status()
      ipsec_cfg_existing -- config tree that still contains the tunnels
                            which may have been deleted (may be None)
      prof               -- profile object; .name, .ike, .esp and .cfg
                            are read here

    Returns a tuple (changed, added, deleted) of connection-name lists.
    A profile-level name means the whole profile connection is affected,
    a tunnel-level name only that tunnel.
    """
    profile_conn_name = conn_name_of_profile(prof.name)

    path = 'security vpn ipsec remote-access-client profile ' + prof.name
    path_state = client.node_get_status(CONFIG_CANDIDATE, path)
    if path_state == client.DELETED:
        return ([], [], [profile_conn_name])
    if path_state == client.ADDED:
        return ([], [profile_conn_name], [])
    if path_state == client.UNCHANGED:
        # NOTE(review): the group checks below are only reached when the
        # profile node itself is reported as changed - confirm configd
        # flags the profile when only a referenced group was edited.
        return ([], [], [])

    # A change to either referenced group affects the whole profile.
    for group_path in ('security vpn ipsec ike-group ' + prof.ike.name,
                       'security vpn ipsec esp-group ' + prof.esp.name):
        if client.node_get_status(CONFIG_CANDIDATE, group_path) == client.CHANGED:
            return ([profile_conn_name], [], [])

    # Check for changes we actually care about
    changed_tunnels = []
    added_tunnels = []
    deleted_tunnels = []

    for k in prof.cfg:
        # Ignore:
        # - description: not relevant for IKE backend
        # - profile-name: special key, without status
        if k in ('description', 'profile-name'):
            continue

        key_state = client.node_get_status(CONFIG_CANDIDATE, path + ' ' + k)
        if key_state == client.UNCHANGED:
            continue

        # If tunnel got changed this requires extra work
        # to determine if a CHILD_SA requires a reset.
        # Everything else would be a IKE_SA reset and
        # we do not look any further.
        if k != 'tunnel':
            return ([profile_conn_name], [], [])

        # Determine which tunnels got changed or added.
        for t in prof.cfg['tunnel']:
            tid = t['tunnel-id']
            tunnel_state = client.node_get_status(CONFIG_CANDIDATE, path + ' tunnel ' + str(tid))
            if tunnel_state == client.CHANGED:
                changed_tunnels.append(conn_name_of_profile(prof.name, tid))
            if tunnel_state == client.ADDED:
                added_tunnels.append(conn_name_of_profile(prof.name, tid))
            # Do not break here, check if any other parameter might require a IKE_SA reset

    # Check for deleted tunnels: they only show up in the existing config.
    for tid in _running_tunnel_ids(ipsec_cfg_existing, prof.name):
        try:
            tunnel_state = client.node_get_status(CONFIG_CANDIDATE, path + ' tunnel ' + str(tid))
        except configd.Exception:
            # Deletion detection stays best effort, as before.
            continue
        if tunnel_state == client.DELETED:
            deleted_tunnels.append(conn_name_of_profile(prof.name, tid))

    # Most fine grain changes are tunnel changes. Which comes last.
    return (changed_tunnels, added_tunnels, deleted_tunnels)


# Returns list of conn_names for IKE SA or CHILD SA which changed or got deleted.
# Or empty lists in case nothing changed.
def ra_config_changed(client, ipsec_cfg, ra_profiles):
    """Collect the connection names that changed, were added or deleted.

    Parameters:
      client      -- configd client
      ipsec_cfg   -- candidate ipsec config tree (kept for interface
                     compatibility; the running tree is fetched here)
      ra_profiles -- profile objects built from the candidate config

    Returns a tuple (changed, added, deleted) of de-duplicated, sorted
    connection-name lists; all three are empty when nothing changed.
    """
    changed = []
    added   = []
    deleted = []

    # Existing profiles - deleted?
    ipsec_cfg_existing = None
    try:
        ipsec_cfg_existing = client.tree_get_dict(IPSEC_CFG_PATH, CONFIG_RUNNING, "json")
        rac = ipsec_cfg_existing['ipsec']['remote-access-client']
        for p in rac['profile']:
            path = 'security vpn ipsec remote-access-client profile ' + p['profile-name']
            if client.node_get_status(CONFIG_CANDIDATE, path) == client.DELETED:
                deleted.append(conn_name_of_profile(p['profile-name']))
    except (configd.Exception, KeyError):
        pass

    # Candidate profiles - changed/added?
    # Deleted tunnels are gone from the candidate tree, so the per-profile
    # check needs the running tree (None when it could not be read).
    for p in ra_profiles:
        c, a, d = ra_profile_changed(client, ipsec_cfg_existing, p)
        changed.extend(c)
        added.extend(a)
        deleted.extend(d)

    # De-duplicate; sorting keeps the result order deterministic.
    return (sorted(set(changed)), sorted(set(added)), sorted(set(deleted)))

def read_global_config(top_cfg, feature):
    """Extract IKE groups, ESP groups and one feature subtree.

    Returns (ikes, esps, feature_cfgs): two dicts mapping each group's
    'tagnode' to an IKEGroup/ESPGroup, plus whatever is stored under
    top_cfg['ipsec'][feature]. If either argument is None, or any of
    the needed keys is missing, ({}, {}, []) is returned instead.
    """
    if top_cfg is None or feature is None:
        return {}, {}, []

    ikes = {}
    esps = {}
    try:
        ipsec = top_cfg['ipsec']
        feature_cfgs = ipsec[feature]
        for group in ipsec['ike-group']:
            tag = group['tagnode']
            ikes[tag] = IKEGroup(group)
        for group in ipsec['esp-group']:
            tag = group['tagnode']
            esps[tag] = ESPGroup(group)
    except KeyError:
        return {}, {}, []

    return ikes, esps, feature_cfgs

def read_vrf_interfaces_config(top_cfg):
    """Map interface names to the routing instance they are assigned to.

    Reads top_cfg['routing-instance']; instances without interfaces are
    skipped. Returns {} when top_cfg is None or when any expected key
    is missing (partial results are discarded in that case).
    """
    if top_cfg is None:
        return {}

    intf_to_instance = {}
    try:
        for instance in top_cfg['routing-instance']:
            instance_name = instance['instance-name']
            for member in instance.get('interface') or ():
                intf_to_instance[member['name']] = instance_name
    except KeyError:
        return {}

    return intf_to_instance

def get_connections(ra_profiles):
    """Merge the section() mapping of every profile into one OrderedDict.

    Profiles are merged in the given order; a later profile overwrites
    an earlier entry with the same key.
    """
    merged = OrderedDict()
    for section in (profile.section() for profile in ra_profiles):
        merged.update(section)
    return merged

def get_shared_secrets(ra_profiles):
    """Return one flat list with the shared_secrets() of all profiles."""
    secrets = []
    for profile in ra_profiles:
        secrets.extend(profile.shared_secrets())
    return secrets

def get_keys(ra_profiles):
    """Return one flat list with the key() entries of all profiles."""
    keys = []
    for profile in ra_profiles:
        keys.extend(profile.key())
    return keys

def stale_connections(vs, prefix, connections):
    """List loaded connections that carry *prefix* but are not configured.

    The names come from vs.get_conns()['conns'] (bytes, decoded as
    ASCII); a name is stale when it starts with *prefix* and is not
    contained in *connections*.
    """
    loaded = (raw.decode('ascii') for raw in vs.get_conns()['conns'])
    return [name for name in loaded
            if name.startswith(prefix) and name not in connections]

def load_secrets(vs, secrets, keys):
    """Replace the credentials held by the VICI session *vs*.

    All credentials are cleared first, then every entry of *secrets* is
    loaded with load_shared() and every entry of *keys* with load_key().
    """
    for message in ("loading secrets and keys", secrets, keys):
        dbg(message)

    # clean up secrets before loading
    vs.clear_creds()

    for secret in secrets:
        vs.load_shared(secret)
    dbg("shared secrets done")

    for key in keys:
        vs.load_key(key)
    dbg("keys done")


def load_vpn_x509(vs, configd_client):
    """Load CA certificates and CRLs from "security vpn x509" into *vs*.

    Reads the candidate config; returns silently when configd cannot
    provide the subtree. Entries for which get_cert_or_key() yields
    nothing are skipped.
    """
    # security vpn x509 - hierarchy
    dbg("loading certs and crls")
    try:
        x509_cfg = configd_client.tree_get_dict("security vpn x509", CONFIG_CANDIDATE, "json")['x509']
        authorities = x509_cfg.get('ca-certs')
        crls = x509_cfg.get('crls')
    except configd.Exception:
        return

    # CA certificates first, CRLs second.
    for cert_type, flag, entries in (('X509', 'CA', authorities),
                                     ('X509_CRL', 'NONE', crls)):
        for entry in entries or ():
            cert = get_cert_or_key(cert_type, flag, entry)
            if cert:
                vs.load_cert(cert)

    dbg("done loading certs and crls")

def load_ipsec_ra_vpn_server(vs, configd_client, ipsec_cfg):
    """Sync the remote-access-server configuration, on forced reloads only.

    Nothing happens unless FORCE is set (e.g. via "restart vpn").
    configd_client is accepted but not used here.
    """
    if FORCE:
        ike_groups, esp_groups, server_cfgs = read_global_config(ipsec_cfg, 'remote-access-server')
        server = IPsecRAVPNServer(vs, server_cfgs, ike_groups, esp_groups)
        server.sync()

# unload all connections not part of the config.
# load all connections.
def load_ipsec_ra_vpn_client(vs, configd_client, ipsec_cfg):
    """Push the remote-access-client configuration to the VICI session.

    Parameters:
      vs             -- VICI session
      configd_client -- configd client
      ipsec_cfg      -- candidate ipsec config tree (may be None)

    Returns None when nothing changed and no reload was forced,
    otherwise the tuple (changed, added, deleted) of connection names
    that the IKE SA daemon has to be told about.
    """
    try:
        ri_cfg = configd_client.tree_get_dict("routing routing-instance", CONFIG_CANDIDATE, "json")
    except configd.Exception:
        ri_cfg = None

    ike_groups, esp_groups, feature_cfgs = read_global_config(ipsec_cfg, 'remote-access-client')
    vrf_interfaces = read_vrf_interfaces_config(ri_cfg)

    # read_global_config() falls back to a list when the config is
    # incomplete, and the client node need not contain any profile;
    # treat both as "no profiles" instead of failing on the lookup.
    profile_cfgs = feature_cfgs.get('profile', []) if isinstance(feature_cfgs, dict) else []
    ra_profiles = [ClientProfile(v, ike_groups, esp_groups, vrf_interfaces) for v in profile_cfgs]

    changed, added, deleted = ra_config_changed(configd_client, ipsec_cfg, ra_profiles)
    has_updates = bool(changed or added or deleted)
    if not has_updates and not FORCE:
        dbg("remote-access-client configuration unchanged")
        return None

    connections = get_connections(ra_profiles)
    secrets = get_shared_secrets(ra_profiles)
    keys = get_keys(ra_profiles)
    load_secrets(vs, secrets, keys)

    if not has_updates:
        # Forced reload without config changes: bring every profile up.
        added = [conn_name_of_profile(p.name) for p in ra_profiles]

    # 1. identify all stale connections
    # 2. perform the VICI conn configuration changes
    # 3. signal reload to vyatta-ike-sa-daemon
    # 4. signal up/down jobs

    dbg(connections)
    dbg(added)

    # 1. identify all stale connections
    stale = stale_connections(vs, CLIENT_NAMEPREFIX, connections)

    # 2. perform the VICI conn configuration changes
    try:
        if stale:
            dbg("unloading stale connections")
            for cn in stale:
                dbg(cn)
                vs.unload_conn(OrderedDict(name=cn))
        if connections:
            dbg("loading connections")
            dbg(connections)
            vs.load_conn(connections)
            dbg("successfully loaded connections")
    except (vici.exception.CommandException, vici.exception.SessionException) as e:
        # Reported but not fatal: the updates are still returned.
        err("failed to load connections:{}".format(e))

    return (changed, added, deleted)

def trigger_ipsec_ra_vpn_client(updates):
    """Signal the up/down/reset jobs for *updates* to the IKE SA daemon.

    *updates* is None (nothing to signal) or a tuple
    (changed, added, deleted) of connection-name lists. Exits the
    process with status 1 when the daemon cannot be reached.
    """
    try:
        ike_sa_daemon = setup_dbus()
    except ConnectionRefusedError:
        err("Failed to connect to IKE SA daemon\n")
        sys.exit(1)

    if updates is None:
        return

    changed, added, deleted = updates

    # 4. signal up/down jobs
    if deleted:
        # Only names handed over in *updates* are known here; the old
        # reference to a 'stale' list was undefined in this function.
        ike_sa_daemon.down(deleted)
    if changed:
        ike_sa_daemon.reset(changed)
    if added:
        ike_sa_daemon.up(added)

def generate_ike_sa_daemon_config(ipsec_cfg, intf_tunnel_cfg):
    """Write the IKE SA daemon configuration file.

    The file /var/run/vyatta-ike-sa-daemon/config.json receives the
    content of *ipsec_cfg* plus an 'interfaces' key holding
    *intf_tunnel_cfg*. Either argument may be None and then counts as
    an empty dict. The caller's dicts are left untouched.
    """
    # Work on a shallow copy so the caller's ipsec_cfg does not gain an
    # 'interfaces' key as a side effect.
    full_config_dict = dict(ipsec_cfg) if ipsec_cfg is not None else {}
    full_config_dict['interfaces'] = intf_tunnel_cfg if intf_tunnel_cfg is not None else {}

    # generate IKE SA daemon configuration file
    os.makedirs('/var/run/vyatta-ike-sa-daemon', mode=0o700, exist_ok=True)
    # Restrictive umask while the file is created; restored even when
    # opening or writing fails.
    old_umask = os.umask(0o177)
    try:
        with open('/var/run/vyatta-ike-sa-daemon/config.json', 'w') as fd:
            json.dump(full_config_dict, fd, indent=4, sort_keys=False)
    finally:
        os.umask(old_umask)

def is_charon_running():
    """Report whether the charon pid file /var/run/charon.pid exists."""
    return os.path.isfile('/var/run/charon.pid')

def do_ipsec_feature(configd_client, feature_path):
    """Tell whether the feature at *feature_path* should be processed.

    False when configd raises for the status lookup on the candidate
    config or reports the node as DELETED; True otherwise.
    """
    try:
        state = configd_client.node_get_status(CONFIG_CANDIDATE, feature_path)
    except configd.Exception:
        return False

    return state != configd_client.DELETED

def do_vpn_x509(configd_client):
    """Tell whether any feature that uses the X.509 configuration is present.

    Returns True as soon as do_ipsec_feature() accepts one of the
    listed paths, False when none of them does.
    """
    ipsec_features = (
        'security vpn ipsec remote-access-client',
        'security vpn ipsec site-to-site',
        'security vpn ipsec profile',
        # Was misspelled 'secrutiy', so this path could never match.
        'security vpn l2tp remote-access',
    )

    return any(do_ipsec_feature(configd_client, feature_path)
               for feature_path in ipsec_features)

def do_ipsec_ra_vpn_client(configd_client):
    """Tell whether the remote-access-client feature should be processed."""
    return do_ipsec_feature(configd_client, 'security vpn ipsec remote-access-client')

def do_ipsec_ra_vpn_server(configd_client):
    """Tell whether the remote-access-server feature should be processed."""
    return do_ipsec_feature(configd_client, 'security vpn ipsec remote-access-server')

def process_options():
    """Parse the command line and set the DEBUG / FORCE switches.

    Recognised options: -f/--force and -d/--debug. An unknown option
    prints the error and a usage line to stderr and exits with status 2.
    """
    global DEBUG, FORCE
    try:
        opts, _rest = getopt.getopt(sys.argv[1:], "fd", ['force', 'debug'])
    except getopt.GetoptError as e:
        # The exception used to be bound to a different name than the one
        # printed, which turned every bad option into a NameError.
        err(e)
        err("usage: {} [--force] [--debug]".format(sys.argv[0]))
        sys.exit(2)

    for opt, _arg in opts:
        if opt in ('-f', '--force'):
            FORCE = True
        elif opt in ('-d', '--debug'):
            DEBUG = True

# Main

def vpn_main():
    """Load the VPN configuration and notify the IKE SA daemon.

    Steps: parse options, read the candidate ipsec and tunnel-interface
    config, write the IKE SA daemon config file, load X.509 material,
    the remote-access server and client configuration as far as they
    are configured, ask the daemon to reload when charon is running and
    finally signal the client connection updates.
    Exits with status 1 when configd or the daemon cannot be reached.
    """
    process_options()
    vs = setup_vici()

    try:
        configd_client = configd.Client()
    except configd.FatalException as f:
        err("can't connect to configd: {}".format(f))
        sys.exit(1)

    try:
        ipsec_cfg = configd_client.tree_get_dict(IPSEC_CFG_PATH, CONFIG_CANDIDATE, "json")
    except configd.Exception:
        # continue there might be a L2TP/IPsec RA server which requires the X.509 configuration
        ipsec_cfg = None

    try:
        intf_tunnel_cfg = configd_client.tree_get_dict(INTF_TUNNEL_CFG_PATH, CONFIG_CANDIDATE, "json")
    except configd.Exception:
        intf_tunnel_cfg = None

    generate_ike_sa_daemon_config(ipsec_cfg, intf_tunnel_cfg)

    if do_vpn_x509(configd_client):
        dbg('loading: vpn x509')
        load_vpn_x509(vs, configd_client)

    if do_ipsec_ra_vpn_server(configd_client):
        dbg('loading: vpn ipsec remote-access-server')
        load_ipsec_ra_vpn_server(vs, configd_client, ipsec_cfg)

    # Decide once whether the client feature is handled, and bind
    # 'updates' unconditionally so the trigger step below can never hit
    # an unbound name.
    ra_client_configured = do_ipsec_ra_vpn_client(configd_client)
    updates = None
    if ra_client_configured:
        dbg('loading: vpn ipsec remote-access-client')
        updates = load_ipsec_ra_vpn_client(vs, configd_client, ipsec_cfg)

    # always reload, to get a fresh copy, even on non IPsec RA VPN client changes
    # to pick up connection-manager logging changes
    if is_charon_running():
        try:
            ike_sa_daemon = setup_dbus()
        except ConnectionRefusedError:
            err("Failed to connect to IKE SA daemon\n")
            sys.exit(1)

        # 3. signal reload to vyatta-ike-sa-daemon
        dbg('trigger reload for vyatta-ike-sa-daemon')
        ike_sa_daemon.reload()

    if ra_client_configured:
        dbg('trigger events for: vpn ipsec remote-access-client')
        trigger_ipsec_ra_vpn_client(updates)


# Script entry point.
if __name__ == "__main__":
    vpn_main()
    # sys.exit() rather than the exit() helper, which the site module only
    # provides for interactive use.
    sys.exit(0)
