#!/usr/bin/env python3

# Copyright (c) 2018-2019 AT&T Intellectual Property.
# All rights reserved.

# SPDX-License-Identifier: GPL-2.0-only

import time
import sys
import os

import vci

import vici
import pprint
import json

import syslog

from collections import OrderedDict

from vyatta.vpn.ipsec.config import IKEGroup, ESPGroup, IPsecRAVPNServer, setup_vici 

IPSEC_RA_SERVER_PREFIX = 'ipsec-remote-access-server-'
NS_IPSEC_RA_SERVER_PREFIX = 'vyatta-security-vpn-ipsec-remote-access-server-v1:'


def dec(string):
    """Decode a UTF-8 byte string to str.

    Any falsy input (None, b'') yields None instead of an empty string.
    """
    return string.decode('utf-8') if string else None


def log(msg):
    """Write msg, formatted as text, to syslog at LOG_DEBUG priority."""
    text = format(msg)
    syslog.syslog(syslog.LOG_DEBUG, text)

def err(msg):
    """Write msg, formatted as text, to syslog at LOG_ERR priority."""
    text = format(msg)
    syslog.syslog(syslog.LOG_ERR, text)

class CharonConfigSync():
    """Apply the IPsec part of a VCI configuration tree.

    The constructor picks the ``ipsec`` subtree out of the full
    configuration and opens a VICI session through setup_vici();
    sync() then hands the relevant pieces to IPsecRAVPNServer.
    """

    def __init__(self, vci_cfg):
        """Locate the ``ipsec`` subtree of vci_cfg and set up VICI.

        self.cfg is None when any container on the path to the subtree
        is missing from vci_cfg.
        """
        try:
            security = vci_cfg['vyatta-security-v1:security']
            vpn = security['vyatta-security-vpn-ipsec-v1:vpn']
            ipsec = vpn['ipsec']
        except KeyError:
            ipsec = None
        self.cfg = ipsec
        self.vs = setup_vici()

    def sync(self):
        """Build the IKE/ESP group maps and sync the remote-access server.

        The remote-access-server configuration is only passed on when
        both ike-group and esp-group entries are present; otherwise
        IPsecRAVPNServer is invoked with None for all three values.
        """
        ike_groups = None
        esp_groups = None
        ra_server_cfg = None

        if self.cfg:
            ike_nodes = self.cfg.get('ike-group')
            esp_nodes = self.cfg.get('esp-group')
            ra_node = self.cfg.get(NS_IPSEC_RA_SERVER_PREFIX + 'remote-access-server')

            if ike_nodes and esp_nodes:
                # Groups are keyed by their 'tagnode' value.
                ike_groups = {node['tagnode']: IKEGroup(node) for node in ike_nodes}
                esp_groups = {node['tagnode']: ESPGroup(node) for node in esp_nodes}
                ra_server_cfg = ra_node

        server = IPsecRAVPNServer(self.vs, ra_server_cfg, ike_groups, esp_groups)
        server.sync()


class Config(vci.Config):
    """VCI configuration handler for the IPsec component.

    set() stores the configuration tree it is given and applies it via
    CharonConfigSync; get() returns the stored tree; check() currently
    accepts everything.
    """
    conf = {}

    def set(self, input):
        """Remember the configuration tree and apply it.

        input -- the configuration tree; it is stored as self.cfg and
                 passed to CharonConfigSync, whose sync() is then run.
        """
        self.cfg = input
        CharonConfigSync(self.cfg).sync()
        return

    def get(self):
        """Return the tree last passed to set().

        Returns an empty dict when set() has not been called yet, rather
        than failing with AttributeError on the missing self.cfg.
        """
        return getattr(self, 'cfg', {})

    def check(self, input):
        """Validate a candidate configuration (currently a no-op)."""
        # TODO: is there a charon VICI check function?
        # Or a no-op parameter? Or do we have to write our own VICI checker?
        return

def ipsec_ra_server_reset_client(input):
    """RPC handler: terminate the IKE SAs of one remote-access client.

    input -- RPC input dict; the 'profile' and 'peer' leaves are read
             under the NS_IPSEC_RA_SERVER_PREFIX namespace.

    Every IKE SA whose connection name starts with
    IPSEC_RA_SERVER_PREFIX + profile and whose remote-host or remote-id
    equals peer is terminated. Always returns an empty dict.
    """
    profile = input.get(NS_IPSEC_RA_SERVER_PREFIX + 'profile')
    peer = input.get(NS_IPSEC_RA_SERVER_PREFIX + 'peer')

    vs = setup_vici()

    # Collect all terminate requests first, then issue them once the
    # SA listing has been fully consumed.
    requests = []
    for sa in vs.list_sas():
        # There might be more than one IKE SA for a certain remote-host
        # or remote-id.
        for conn_id in sa:
            if conn_id.startswith(IPSEC_RA_SERVER_PREFIX + profile):
                ike_sa = sa[conn_id]

                if (dec(ike_sa.get('remote-host')) == peer
                        or dec(ike_sa.get('remote-id')) == peer):
                    requests.append(OrderedDict([
                        ('ike-id', ike_sa.get('uniqueid')),
                        # Don't wait for client responses.
                        ('force', True),
                    ]))

    for request in requests:
        # Drain the iterable returned by terminate(); its items are
        # not used.
        for _ in vs.terminate(request):
            pass

    return {}


def clear_vpn_x509_status(input):
    """RPC handler: flush X509_CRL and OCSP_RESPONSE certificates.

    input -- RPC input; not used.

    Issues one flush_certs() request per certificate type on a VICI
    session obtained from setup_vici(). Always returns an empty dict.
    """
    vs = setup_vici()

    for cert_type in ('X509_CRL', 'OCSP_RESPONSE'):
        vs.flush_certs(OrderedDict([('type', cert_type)]))

    return {}


if __name__ == "__main__":
    PP = pprint.PrettyPrinter(indent=2)
    syslog.openlog('vci-security-vpn-ipsec', facility=syslog.LOG_DAEMON)

    # Start strongSwan before any configuration is applied; the exit
    # status of the command is not checked.
    os.system("systemctl start strongswan.service")

    # Test mode: when TESTFILE is set, apply the JSON configuration it
    # points to once and exit without registering the VCI component.
    test_file = os.getenv("TESTFILE")
    if test_file:
        with open(test_file) as fd:
            top_cfg = json.load(fd)
        Config().set(top_cfg)
        sys.exit(0)

    # Normal mode: register the config handler and both RPCs on the v1
    # model, then run the component and wait on it.
    component = vci.Component("net.vyatta.vci.ipsec")
    model = vci.Model("net.vyatta.vci.ipsec.v1")
    model = model.config(Config())
    model = model.rpc("vyatta-security-vpn-ipsec-v1",
                      "clear-vpn-x509-status",
                      clear_vpn_x509_status)
    model = model.rpc("vyatta-security-vpn-ipsec-remote-access-server-v1",
                      "reset-client",
                      ipsec_ra_server_reset_client)
    component.model(model).run().wait()
