#!/usr/bin/env python3
#
# Copyright (c) 2017-2020 AT&T Intellectual Property.
# All rights reserved.

# SPDX-License-Identifier: GPL-2.0-only

import os
import sys
import socket
import re
import getopt
import pprint

from enum import Enum

import dbus
import vici

from vplaned import Controller
from vyatta import configd

from vyatta.vpn.ipsec.config import setup_dbus

# Pretty-printer kept for the (commented-out) debug dumps of VICI data.
pp = pprint.PrettyPrinter()


# UNIX socket path that setup_vici() connects to.
vici_socket = '/var/run/charon.vici'

# Prefixes of the IKE connection names built in get_ipsec_ra_configs()
# for remote-access client and server profiles.
IPSEC_RA_CLIENT_PREFIX = 'ipsec_ra_client-'
IPSEC_RA_SERVER_PREFIX = 'ipsec-remote-access-server-'

# D-Bus interface name passed to the reset call in
# reset_ipsec_ra_client_profile().
DBUS_INTERFACE = 'net.vyatta.eng.security.vpn.ipsec'

class RARole(Enum):
    """Remote-access role; selects which profile subtree is queried."""
    CLIENT = 1  # "remote-access-client" profiles
    SERVER = 2  # "remote-access-server" profiles

class Mode(Enum):
    """Output layout selected for dump_child_sa()."""
    BRIEF = 1     # one table row per tunnel
    DETAILED = 2  # multi-line record per tunnel
    STATS = 3     # one 'in' and one 'out' row per tunnel

def bswap32(x):
    """Return *x* with the byte order of its low 32 bits reversed."""
    low32 = x & 0xffffffff
    return int.from_bytes(low32.to_bytes(4, 'little'), 'big')

def conv_proto_port(val):
    """Split a bracketed selector suffix into ``(protocol, port)``.

    ``val`` is expected to look like ``[proto/port]``.  A suffix that
    carries only a protocol (``[proto]``) yields ``'all'`` as the port,
    and anything that is not bracketed at all yields ``('all', 'all')``
    instead of raising on a failed regex match.
    """
    m = re.match(r'\[([^\/]*)\/([^\/]*)\]', val)
    if m:
        return m.group(1), m.group(2)

    # No '/' separator: only a protocol was given.
    m = re.match(r'\[([^\/\]]*)\]', val)
    if m:
        return m.group(1), 'all'

    return 'all', 'all'

def conv_ts(val):
    """Convert a traffic-selector list into ``(net, proto, port)``.

    Only the first entry of ``val`` is looked at.  It is decoded with
    dec() and split into the network part and an optional bracketed
    ``[proto/port]`` suffix; proto and port default to ``'all'``.
    Returns ``('n/a', 'n/a', 'n/a')`` when ``val`` is missing, empty or
    cannot be decoded.
    """
    not_available = ('n/a', 'n/a', 'n/a')

    try:
        selector = dec(val[0])
    except Exception:
        # Best effort for display purposes, but do not swallow
        # KeyboardInterrupt / SystemExit as BaseException would.
        return not_available

    # dec() returns None for an empty value; re.match() would raise on it.
    if not selector:
        return not_available

    proto = 'all'
    port = 'all'

    m = re.match(r'([^\s\[]*)(\[.*\])?', selector)
    net = m.group(1)
    if m.group(2):
        proto, port = conv_proto_port(m.group(2))

    return net, proto, port

def conv_blocked(val):
    """Map a 'blocked' flag to 'yes' (1), 'no' (0) or 'n/a' (anything else)."""
    return 'yes' if val == 1 else 'no' if val == 0 else 'n/a'

def conv_bytes(kern_bytes, dp_bytes):
    """Sum both byte counters and format the total with one decimal.

    Falsy counters (None, 0) are ignored.  Totals of at least 1024,
    1024**2 or 1024**3 are scaled and suffixed with 'K', 'M' or 'G'.
    """
    total = 0
    for counter in (kern_bytes, dp_bytes):
        if counter:
            total += counter

    # Largest unit first so the first matching threshold wins.
    for threshold, unit in ((1073741824, 'G'), (1048576, 'M'), (1024, 'K')):
        if total >= threshold:
            return '{:.1f}{}'.format(total / threshold, unit)

    return '{:.1f}'.format(total)

def conv_enc(cipher, cipher_keysize):
    """Build the short display token for an encryption algorithm.

    Returns 'n/a' when the cipher is missing, when the key size is
    missing (except for 3DES, which is always shown as '3des'), or when
    the name does not split into the expected number of '_' parts.
    """
    if not cipher:
        return 'n/a'
    if cipher.startswith('3DES'):
        return '3des'
    if not cipher_keysize:
        return 'n/a'

    renamed = {'AES_GCM_16': 'AES_GCM_128'}
    parts = (renamed.get(cipher, cipher) + '_' + cipher_keysize).split('_')

    if len(parts) == 3:
        # <algo>_<mode>_<keysize>  ->  <algo><keysize>
        algo, _, keysize = parts
        return '{}{}'.format(algo.lower(), keysize)

    if len(parts) == 4:
        # <algo>_<mode>_<icv>_<keysize>  ->  <algo><keysize><mode><icv>
        algo, mode, icv, keysize = parts
        return '{}{}{}{}'.format(algo.lower(), keysize, mode.lower(),
                                 int(icv.lower()))

    return 'n/a'

def conv_prf(prf):
    """Build the short display token for a PRF / integrity algorithm.

    A missing value is shown as 'null'.  Otherwise one leading 'AUTH_'
    or 'PRF_' marker and a leading 'HMAC_' marker are removed, the name
    is lower-cased and finally mapped to its legacy token if it has one.
    """
    if not prf:
        return 'null'

    for marker in ('AUTH_', 'PRF_'):
        if prf.startswith(marker):
            prf = prf.replace(marker, '')
            break

    if prf.startswith('HMAC_'):
        prf = prf.replace('HMAC_', '')

    token = prf.lower()

    # legacy hash tokens for show commands as with strongswan 4.x
    #
    # src/libstrongswan/crypto/proposal/proposal_keywords_static.c
    # static const struct proposal_token wordlist[] =
    legacy_tokens = {
        'md5_96': 'md5',
        'sha1_96': 'sha1',
        'sha2_256_128': 'sha2_256',
        'sha2_384_192': 'sha2_384',
        'sha2_512_256': 'sha2_512',
        'aes_xcbc_96': 'aesxcbc',
    }
    return legacy_tokens.get(token, token)

def conv_dh_group(dhgrp):
    """Translate a DH group name into its group number, or 'n/a'."""
    group_numbers = {
        'MODP_768': 1,
        'MODP_1024': 2,
        'MODP_1536': 5,
        'MODP_2048': 14,
        'MODP_3072': 15,
        'MODP_4096': 16,
        'MODP_6144': 17,
        'MODP_8192': 18,
        'ECP_256': 19,
        'ECP_384': 20,
    }
    return group_numbers.get(dhgrp, 'n/a')

def get_dataplane_sad(vrf):
    """Query the dataplane for its IPsec SAD.

    ``vrf`` is either falsy (query without a VRF filter) or a dict whose
    ``'id'`` entry is appended to the command as ``vrf_id``.

    Returns the JSON reply of the first dataplane the controller lists,
    or None when there is no dataplane or the query fails.
    """
    try:
        with Controller() as controller:
            for dp in controller.get_dataplanes():
                with dp:
                    if vrf:
                        cmd = 'ipsec sad vrf_id {}'.format(vrf['id'])
                    else:
                        cmd = 'ipsec sad'

                    # Only the first dataplane is consulted.
                    return dp.json_command(cmd)
    except Exception:
        # Best effort: a failing dataplane query is reported as "no data".
        # Exception (not BaseException) so that Ctrl-C and sys.exit()
        # still propagate.
        pass

    return None

def get_ipsec_ra_client_tunnel(conn):
    """Return the digits of a trailing '-tunnel-<N>' in *conn*, else None."""
    found = re.search(r'-tunnel-(\d+)$', conn)
    return found.group(1) if found else None

def configd_client():
    """Return a new configd client; print an error and exit(1) on failure."""
    try:
        return configd.Client()
    except configd.FatalException as err:
        print("can't connect to configd: {}".format(err))
        sys.exit(1)

def configd_query(cfgpath):
    """Fetch the running configuration below *cfgpath* as a dict.

    Returns None when the lookup fails.  Connection failures are handled
    by configd_client(), which exits the process.
    """
    client = configd_client()

    try:
        return client.tree_get_dict(cfgpath, configd.Client.RUNNING, "json")
    except Exception:
        # Best effort: a failed lookup is reported as "no data".
        # Exception (not BaseException) so that Ctrl-C and sys.exit()
        # still propagate.
        return None

def dec(string):
    """Decode *string* from UTF-8 bytes; falsy input yields None."""
    return string.decode('utf-8') if string else None

def dec_for_print(string):
    """Like dec(), but substitute 'n/a' when there is nothing to show."""
    return dec(string) or 'n/a'

def conv_ip(string):
    """Decode an address for display.

    Missing values become 'n/a'; '%any' and values starting with '@'
    are shown as '0.0.0.0'.
    """
    addr = dec(string)
    if not addr:
        return 'n/a'

    is_placeholder = addr == '%any' or addr.startswith('@')
    return '0.0.0.0' if is_placeholder else addr

def enc(string):
    """Encode *string* as UTF-8 bytes; falsy input yields None."""
    return string.encode('utf-8') if string else None

def get_state(val):
    """Return 'up' for INSTALLED / ESTABLISHED, 'down' for anything else."""
    # In the stroke version the non-matching case was undefined.
    return 'up' if val in ('INSTALLED', 'ESTABLISHED') else 'down'

def dump_ike_sa(profile_config, ike_sa):
    """Print a brief, two-table summary of one IKE_SA.

    ``profile_config`` supplies the optional 'description' and the
    'ike-lifetime' column; everything else is read from ``ike_sa`` (a
    VICI list-sas entry).  ``ike_sa`` is not modified.
    """
    local_host = conv_ip(ike_sa.get('local-host'))
    remote_host = conv_ip(ike_sa.get('remote-host'))

    header = """
Peer ID / IP                            Local ID / IP
------------                            -------------
{: <39} {: <39}
"""
    print(header.format(remote_host, local_host))

    if profile_config.get('description'):
        print('Description: {}'.format(profile_config.get('description')))

    header = """
    State    Encrypt       Hash    D-H Grp  A-Time  L-Time IKEv
    -----  ------------  --------  -------  ------  ------ ----
"""
    print(header, end='')

    state = get_state(dec(ike_sa.get('state')))

    # None cannot be rendered with a '<7' format spec, so fall back to
    # the same placeholder the other columns use.
    ikelifetime = profile_config.get('ike-lifetime')
    if ikelifetime is None:
        ikelifetime = 'n/a'

    ikeatime = dec_for_print(ike_sa.get('established'))

    cipher = conv_enc(dec(ike_sa.get('encr-alg')), dec(ike_sa.get('encr-keysize')))
    prf = conv_prf(dec(ike_sa.get('prf-alg')))
    dhgrp = conv_dh_group(dec(ike_sa.get('dh-group')))
    ike_ver = dec_for_print(ike_sa.get('version'))

    print('    {:<6} {:<13} {:<9} {:<8} {:<7} {:<7} {:<2}'.format(
        state, cipher, prf, dhgrp, ikeatime, ikelifetime, ike_ver))

def dump_child_sa(profile_config, child_sas, ike_sa, mode, vrf):
    """Print the CHILD_SAs of one IKE_SA in the layout selected by *mode*.

    ``child_sas`` maps child names to VICI CHILD_SA entries (None is
    treated as empty).  ``ike_sa`` supplies the peer/local header data,
    ``profile_config`` the optional description.  ``vrf`` is handed to
    get_dataplane_sad(); only CHILD_SAs whose inbound and outbound SPI
    both appear in that dataplane SAD reply are printed.
    """
    local_host = conv_ip(ike_sa.get('local-host'))
    local_id = dec_for_print(ike_sa.get('local-id'))
    remote_host = conv_ip(ike_sa.get('remote-host'))
    remote_id = dec_for_print(ike_sa.get('remote-id'))
    nat_local = conv_ip(ike_sa.get('nat-local'))
    nat_remote = conv_ip(ike_sa.get('nat-remote'))

    # BRIEF / STATS header
    if mode == Mode.BRIEF or mode == Mode.STATS:
        header_brief = """
Peer ID / IP                            Local ID / IP
------------                            -------------
{: <39} {: <39}
"""
        print(header_brief.format(remote_host, local_host))

    # DETAILED header
    elif mode == Mode.DETAILED:
        header_detailed = """
------------------------------------------------------------------
Peer IP:\t\t{peerip}
Peer ID:\t\t{peerid}
Local IP:\t\t{localip}
Local ID:\t\t{localid}
NAT Traversal:\t\t{natt}
NAT Source Port:\t{natsrc}
NAT Dest Port:\t\t{natdst}
""".format(peerip=remote_host, peerid=remote_id, localip=local_host, localid=local_id,
           natt='n/a', natsrc=nat_local, natdst=nat_remote)

        print(header_detailed)

    if profile_config.get('description'):
        print('Description: {}'.format(profile_config.get('description')))

    if mode == Mode.BRIEF:
        header_brief = """
    Tunnel  Id          State  Bytes Out/In   Encrypt       Hash      DH A-Time  L-Time
    ------  ----------  -----  -------------  ------------  --------  -- ------  ------
"""
        print(header_brief, end='')

    elif mode == Mode.STATS:
        header_stats = """
  Tunnel Id         Dir Source Network               Destination Network          Bytes
  ------ ---------- --- --------------               -------------------          -----
"""
        print(header_stats, end='')

    # dataplane SAD, indexed by SPI as an 8-digit lower-case hex string.
    # get_dataplane_sad() returns None when the query fails; treat that
    # like an empty SAD instead of crashing.
    sad = get_dataplane_sad(vrf) or {}
    dp_sas = {}
    for dp_sa in sad.get('sas') or []:
        # Always parse the hex string; the byte swap is only needed on
        # little-endian hosts.
        spi = int(dp_sa['spi'], 16)
        if sys.byteorder == 'little':
            spi = bswap32(spi)

        dp_sas['{:08x}'.format(spi)] = dp_sa

    for child_sa, sa in (child_sas or {}).items():
        spi_in = dec(sa.get('spi-in'))
        spi_out = dec(sa.get('spi-out'))

        # Dataplane SAD is queried per VRF ID.
        # VICI child_sas lists all SAs regardless of the VRF ID.
        # Only display SAs that are part of the dataplane SAD reply,
        # since this limits the result to the queried VRF/routing-instance.
        dp_spi_in = dp_sas.get(spi_in)
        dp_spi_out = dp_sas.get(spi_out)
        if not dp_spi_in or not dp_spi_out:
            continue

        if sa.get('name'):
            child_name = dec_for_print(sa.get('name'))
        # Handle strongswan version < 5.5.2
        else:
            child_name = child_sa

        # Child names without a '-tunnel-<N>' suffix have no tunnel number.
        tunnel = get_ipsec_ra_client_tunnel(child_name)
        tunnelnum = int(tunnel) if tunnel is not None else 'n/a'

        sa_id = dec_for_print(sa.get('uniqueid'))
        state = get_state(dec(sa.get('state')))

        atime = dec(sa.get('install-time'))

        # total lifetime = install-time + life-time
        life = int(dec(sa.get('life-time'))) + int(atime)

        lsnet, lproto, lport = conv_ts(sa.get('local-ts'))
        rsnet, rproto, rport = conv_ts(sa.get('remote-ts'))

        cipher = conv_enc(dec(sa.get('encr-alg')), dec(sa.get('encr-keysize')))
        integ = conv_prf(dec(sa.get('integ-alg')))

        pfsgrp = conv_dh_group(dec(sa.get('dh-group')))

        # stats from kernel, added to the dataplane counters
        # (as of today, never really mixed)
        bytes_in = int(dec(sa.get('bytes-in')))
        bytes_out = int(dec(sa.get('bytes-out')))

        bytes_in = conv_bytes(bytes_in, dp_spi_in.get('bytes'))
        bytes_out = conv_bytes(bytes_out, dp_spi_out.get('bytes'))
        bytesp = '{}/{}'.format(bytes_out, bytes_in)

        # An SA that is blocked in either direction is reported as down.
        if state == 'up':
            if dp_spi_in.get('blocked') or dp_spi_out.get('blocked'):
                state = 'down'

        # blocked
        inblocked = conv_blocked(dp_spi_in.get('blocked'))
        outblocked = conv_blocked(dp_spi_out.get('blocked'))

        if mode == Mode.BRIEF:
            print('    {:<7} {:<11} {:<6} {:<14} {:<13} {:<9} {:<2} {:<7} {:<7}'.format(
                tunnelnum, sa_id, state, bytesp, cipher, integ, pfsgrp, atime, life))
        elif mode == Mode.DETAILED:
            detailed_tunnel_1 = """
    Tunnel {tunnum}:
        State:\t\t\t{state}
        Id:\t\t\t{saId}
        Inbound SPI:\t\t{inspi}
        Outbound SPI:\t\t{outspi}
        Encryption:\t\t{cipher}
        Hash:\t\t\t{integ}
        DH Group:\t\t{pfsgrp}""".format(tunnum=tunnelnum, state=state, saId=sa_id,
                                        inspi=spi_in, outspi=spi_out, cipher=cipher,
                                        integ=integ, pfsgrp=pfsgrp)

            print(detailed_tunnel_1)

            detailed_tunnel_2 = """
        Local Net:\t\t{srcnet}
        Local Protocol:\t\t{lproto}
        Local Port: \t\t{lport}

        Remote Net:\t\t{dstnet}
        Remote Protocol:\t{rproto}
        Remote Port: \t\t{rport}

        Inbound Bytes:\t\t{inbytes}
        Outbound Bytes:\t\t{outbytes}

        Inbound Blocked:\t{inblocked}
        Outbound Blocked:\t{outblocked}

        Active Time (s):\t{atime}
        Lifetime (s):\t\t{life}""".format(srcnet=lsnet, lproto=lproto, lport=lport,
                                          dstnet=rsnet, rproto=rproto, rport=rport,
                                          inbytes=bytes_in, outbytes=bytes_out,
                                          inblocked=inblocked, outblocked=outblocked,
                                          atime=atime, life=life)
            print(detailed_tunnel_2)

        elif mode == Mode.STATS:
            print('  {:<6} {:<10} {:<3} {:<28} {:<28} {:<8}'.format(
                tunnelnum, sa_id, 'in', rsnet, lsnet, bytes_in))
            print('  {:<6} {:<10} {:<3} {:<28} {:<28} {:<8}'.format(
                tunnelnum, sa_id, 'out', lsnet, rsnet, bytes_out))

    print("\n\n")

def get_ipsec_ra_configs(role):
    """Collect the remote-access profiles of *role*, keyed by connection id.

    Returns a dict mapping each connection id to a profile dict with the
    keys 'description', 'profile-name', 'ike-lifetime' and 'serveraddr'.
    Server profiles get one id per profile.  Client profiles get one id
    per server, or per server and source interface; all ids of one client
    profile share the same profile dict.  An empty dict is returned when
    nothing is configured.

    Raises ValueError for a role other than RARole.CLIENT / RARole.SERVER.
    """
    if role == RARole.CLIENT:
        role_name = "remote-access-client"
    elif role == RARole.SERVER:
        role_name = "remote-access-server"
    else:
        raise ValueError('unsupported remote-access role: {!r}'.format(role))

    config = configd_query('security vpn ipsec {} profile'.format(role_name))
    if not config:
        return {}

    ret = {}

    for p in config.get('profile') or []:

        ike_group = {}

        if p.get('ike-group'):
            # configd_query() returns None on failure; keep a dict so the
            # lifetime lookup below cannot raise.
            ike_group = configd_query(
                'security vpn ipsec ike-group {}'.format(p.get('ike-group'))) or {}

        profile_config = {
            "description": p.get('description'),
            "profile-name": p.get('profile-name'),
            "ike-lifetime": ike_group.get('lifetime'),
            "serveraddr": [],
            }

        if role == RARole.SERVER:
            connid = IPSEC_RA_SERVER_PREFIX + p.get('profile-name')
            ret[connid] = profile_config
            continue

        # A client profile without any server simply contributes nothing.
        for server in p.get('server') or []:
            serveraddr = server.get('serveraddr')
            profile_config['serveraddr'].append(serveraddr)

            source_interfaces = server.get('source-interface')
            if not source_interfaces:
                connid = "{}{}-{}".format(IPSEC_RA_CLIENT_PREFIX, p.get('profile-name'), serveraddr)
                ret[connid] = profile_config
                continue

            for source_interface in source_interfaces:
                ifname = source_interface.get('ifname')
                connid = "{}{}-{}-{}".format(IPSEC_RA_CLIENT_PREFIX, p.get('profile-name'),
                                             ifname, serveraddr)
                ret[connid] = profile_config

    return ret

def show_ipsec_sa(mode=Mode.BRIEF, vrf=None, show_ike=False, peer=None):
    """Print the SAs of every configured remote-access connection.

    For each client and server connection id the matching IKE_SAs are
    listed via VICI.  With ``peer`` set, IKE_SAs whose displayed remote
    host differs are skipped.  ``show_ike`` selects the IKE_SA summary;
    otherwise the CHILD_SAs are printed using ``mode`` and ``vrf``.
    """
    session = setup_vici()

    # Server entries are merged last, so they win on a duplicate id.
    configs = {}
    configs.update(get_ipsec_ra_configs(RARole.CLIENT))
    configs.update(get_ipsec_ra_configs(RARole.SERVER))

    for connid, profile_config in configs.items():
        for list_sa_event in session.list_sas({'ike': connid}):
            for ike_sa in list_sa_event.values():
                if peer and conv_ip(ike_sa.get('remote-host')) != peer:
                    continue

                if show_ike:
                    # IKE_SA
                    dump_ike_sa(profile_config, ike_sa)
                else:
                    # CHILD_SA
                    dump_child_sa(profile_config, ike_sa.get('child-sas'),
                                  ike_sa, mode, vrf)


def show_ike_sa(vrf=None, peer=None):
    """Print the IKE_SA summaries; show_ipsec_sa() with show_ike enabled."""
    show_ipsec_sa(show_ike=True, peer=peer, vrf=vrf)

def get_ra_server_peers_for_cli():
    """Print the remote host of every remote-access-server IKE_SA.

    Each distinct remote host is printed once, one per line.  If listing
    the SAs fails, nothing is printed.
    """
    vs = setup_vici()
    try:
        # Consume the reply inside the try block: if list_sas() streams
        # its results lazily, errors only surface while iterating.
        vici_ike = list(vs.list_sas())
    except Exception:
        # Best effort for CLI completion; Exception (not BaseException)
        # so that Ctrl-C and sys.exit() still propagate.
        vici_ike = []

    # dict used as an insertion-ordered set
    peers = {}

    for list_conn_event in vici_ike:
        for ike_conn_name, ike_sa in list_conn_event.items():
            if not ike_conn_name.startswith(IPSEC_RA_SERVER_PREFIX):
                continue

            rhost = dec(ike_sa.get('remote-host'))
            if rhost:
                peers[rhost] = True

            # The remote ID is deliberately not offered: a full X.509
            # subject is too long a string for the CLI.

    for peer in peers:
        print(peer)


def get_ra_client_peers_for_cli():
    """Print every configured remote-access-client server address once."""
    configs = get_ipsec_ra_configs(RARole.CLIENT)

    # dict.fromkeys() de-duplicates while keeping first-seen order.
    peers = dict.fromkeys(
        addr
        for config in configs.values()
        for addr in config.get('serveraddr'))

    for peer in peers:
        print(peer)

def get_ipsec_ra_profiles(role):
    """Print the profile name behind each connection id of *role*."""
    for profile_config in get_ipsec_ra_configs(role).values():
        print(profile_config.get('profile-name'))

def reset_ipsec_ra_client_profile(profile):
    """Request a reset of a remote-access-client profile over D-Bus.

    Prints a message and exits with status 1 when the D-Bus service
    cannot be reached or the reset call raises a DBusException.
    """
    targets = [IPSEC_RA_CLIENT_PREFIX + profile]

    try:
        service = setup_dbus()
    except ConnectionRefusedError:
        print("Failed to connect to IKE SA daemon\n")
        sys.exit(1)

    try:
        service.reset(targets, dbus_interface=DBUS_INTERFACE)
    except dbus.DBusException as err:
        message = "Failed to trigger reset of remote-access-client profile \"{}\": {}"
        print(message.format(profile, err))
        sys.exit(1)

def reset_ipsec_ra_server_peer(profile, peer):
    """Invoke the remote-access-server 'reset-client' RPC for one peer."""
    rpc_input = {'profile': profile, 'peer': peer}
    configd_client().call_rpc_dict(
        "vyatta-security-vpn-ipsec-remote-access-server-v1",
        "reset-client",
        rpc_input)

def flush_certs():
    """Invoke the 'clear-vpn-x509-status' RPC with empty input."""
    configd_client().call_rpc_dict(
        "vyatta-security-vpn-ipsec-v1", "clear-vpn-x509-status", {})

def setup_vici():
    """Connect to the charon VICI socket and return a vici.Session on it.

    Raises OSError when the socket at ``vici_socket`` cannot be reached.
    """
    s = socket.socket(socket.AF_UNIX)
    try:
        s.connect(vici_socket)
    except OSError:
        # Do not leak the descriptor when charon is not reachable.
        s.close()
        raise

    # Hand the connected socket to the session; otherwise it would be
    # left open and unused.
    return vici.Session(s)

def usage():
    """Print the command-line synopsis, listing every accepted option."""
    print('\nusage: {} [--show-(ike|ipsec)-sa=(brief|detail|stats)] [--vrf=<vrfname>]'
          .format(os.path.basename(sys.argv[0])))
    print('\t\t[--show-(ike|ipsec)-sa-peer=<peer>]')
    print('\t\t[--show-ipsec-sa-stats-peer=<peer>]')
    print('\t\t[--show-ipsec-sa-peer-detail=<peer>]')
    print('\t\t[--get-ra-client-peers-for-cli]')
    print('\t\t[--get-ra-server-peers-for-cli]')
    print('\t\t[--get-all-ipsec-ra-client-profiles]')
    print('\t\t[--get-all-ipsec-ra-server-profiles]')
    print('\t\t[--reset-ipsec-ra-client-profile=<profile>]')
    print('\t\t[--reset-ipsec-ra-server-peer=<peer> --profile=<profile>]')
    print('\t\t[--flush-certs]')

# Command-line entry point: parse the options, resolve --vrf / --profile
# first, then run every requested action in command-line order.
if __name__ == '__main__':
    try:
        opts, args = getopt.getopt(sys.argv[1:], '',
                                   ['show-ike-sa=', 'vrf=', 'profile=',
                                    'show-ipsec-sa=', 'show-ike-sa-peer=',
                                    'show-ipsec-sa-peer=',
                                    'show-ipsec-sa-stats-peer=',
                                    'show-ipsec-sa-peer-detail=',
                                    'get-all-ipsec-ra-client-profiles',
                                    'get-all-ipsec-ra-server-profiles',
                                    'reset-ipsec-ra-client-profile=',
                                    'reset-ipsec-ra-server-peer=',
                                    'get-ra-server-peers-for-cli',
                                    'get-ra-client-peers-for-cli',
                                    'flush-certs'])
    except getopt.GetoptError as err:
        print(err)
        usage()
        sys.exit(2)

    # Without a charon pid file there is nothing to do; exit quietly.
    if not os.path.isfile('/var/run/charon.pid'):
        sys.exit(0)

    vrf_arg = None
    profile_arg = None

    # First pass: options that modify the actions handled below.
    for o, a in opts:
        if o == '--profile':
            profile_arg = a
        elif o == '--vrf':
            try:
                vrf_ifindex_path = '/sys/class/net/vrf{}/ifindex'.format(a)
                with open(vrf_ifindex_path, mode='r', encoding='utf-8') as f:
                    vrf_id = int(f.read().strip())
            except EnvironmentError:
                print("Could not find routing-instance \"{}\".".format(a))
                sys.exit(1)

            vrf_arg = {'name': a, 'id': vrf_id}

    # Second pass: the actions.  getopt reports long options by their
    # full name, so they are compared for equality rather than with a
    # substring test.
    for o, a in opts:
        if o == '--show-ike-sa':
            if a == 'brief':
                show_ike_sa(vrf=vrf_arg)
            elif a == 'detail':
                print("Detailed IKE SA show command not implemented.")
            elif a == 'stats':
                print("Statistic IKE SA show command not implemented.")
        elif o == '--show-ipsec-sa':
            if a == 'brief':
                show_ipsec_sa(mode=Mode.BRIEF, vrf=vrf_arg)
            elif a == 'detail':
                show_ipsec_sa(mode=Mode.DETAILED, vrf=vrf_arg)
            elif a == 'stats':
                show_ipsec_sa(mode=Mode.STATS, vrf=vrf_arg)
        elif o == '--show-ike-sa-peer':
            show_ike_sa(peer=a)
        elif o == '--show-ipsec-sa-peer':
            show_ipsec_sa(peer=a)
        elif o == '--show-ipsec-sa-stats-peer':
            show_ipsec_sa(peer=a, mode=Mode.STATS)
        elif o == '--show-ipsec-sa-peer-detail':
            show_ipsec_sa(peer=a, mode=Mode.DETAILED)
        elif o == '--get-ra-client-peers-for-cli':
            get_ra_client_peers_for_cli()
        elif o == '--get-ra-server-peers-for-cli':
            get_ra_server_peers_for_cli()
        elif o == '--get-all-ipsec-ra-client-profiles':
            get_ipsec_ra_profiles(RARole.CLIENT)
        elif o == '--get-all-ipsec-ra-server-profiles':
            get_ipsec_ra_profiles(RARole.SERVER)
        elif o == '--reset-ipsec-ra-client-profile':
            reset_ipsec_ra_client_profile(a)
        elif o == '--reset-ipsec-ra-server-peer':
            reset_ipsec_ra_server_peer(profile_arg, a)
        elif o == '--flush-certs':
            flush_certs()
        elif o in ('--vrf', '--profile'):
            # skip, already processed.
            continue
        else:
            # Explicit raise: an assert statement is stripped under -O.
            raise AssertionError('unhandled option: {}'.format(o))
