#!/usr/bin/env python3
#
# Copyright (c) 2017-2019, AT&T Intellectual Property.  All rights reserved.
# Copyright (c) 2015-2017 Brocade Communications Systems, Inc.
# All Rights Reserved.

# SPDX-License-Identifier: GPL-2.0-only

import configparser
import re
import subprocess
import sys
import vplaned
import zmq

# Path of the configuration file read by get_controller_ip().
conf_file = '/etc/vyatta/controller.conf'

def bswap32(x):
    """Return the low 32 bits of *x* with their four bytes reversed.

    Bits above bit 31 are ignored, so the result is always in the range
    0..0xffffffff.
    """
    low_word = x & 0xffffffff
    # Serialise in one byte order and read back in the other.
    return int.from_bytes(low_word.to_bytes(4, 'little'), 'big')

def get_controller_ip(path=None):
    """Return the controller IP address from the controller configuration.

    Reads the ``ip`` option of the ``[Controller]`` section.

    Args:
        path: Configuration file to read.  When omitted, the module-level
            ``conf_file`` is used.

    Returns:
        The configured value, as a string.

    Raises:
        SystemExit: The configuration file could not be read; a message is
            written to stdout first.
        KeyError: The file was read but has no ``[Controller]`` section or
            no ``ip`` option in it.
    """
    if path is None:
        path = conf_file

    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns
    # the list of files it actually parsed, so an empty result is the only
    # indication that the configuration is missing or unreadable.
    if not config.read(path):
        sys.stdout.write('failed to open configuration\n')
        # Exit status 0 is what this failure branch has always specified;
        # it is deliberately left unchanged.
        sys.exit(0)
    return config['Controller']['ip']

def get_unique_reqid(reqid, spi):
    """Find the unique id of the CHILD_SA matching a reqid and an SPI.

    Runs ``ipsec stroke statusall`` and scans the lines following the
    "Security Associations" heading for an ESP entry whose reqid equals
    *reqid* and whose inbound or outbound SPI equals *spi*.

    Args:
        reqid: Integer reqid to look for.
        spi: Integer SPI.  It is byte-swapped on little-endian hosts before
            being compared (presumably because it arrives in network byte
            order -- confirm against the sender).

    Returns:
        The integer found inside the ``{N}`` prefix of the matching line,
        or None when no entry matches (a message is written to stdout).
    """
    stroke = subprocess.Popen(['ipsec', 'stroke', 'statusall'],
                              stdout=subprocess.PIPE)
    out, _ = stroke.communicate()
    # Only ASCII fields are parsed below; replace anything else instead of
    # failing on status output that contains non-ASCII bytes.
    out = out.decode("ascii", errors="replace")

    if sys.byteorder == 'little':
        spi = bswap32(spi)

    sa_line = re.compile(
        r"{(\d+)}:  \w+, \w+, reqid (\d+), ESP.* SPIs: ([0-9a-f]+)_i ([0-9a-f]+)_o")

    sa_section = False
    for line in out.splitlines():
        if not sa_section:
            # Skip everything up to and including the section heading.
            sa_section = line.startswith("Security Associations")
            continue

        m = sa_line.search(line)
        if m is None or int(m.group(2)) != reqid:
            continue

        # Compare numerically: the SPIs in the status output may carry
        # leading zeros, which a string comparison against hex(spi) would
        # never match because hex() drops them.
        if spi in (int(m.group(3), 16), int(m.group(4), 16)):
            return int(m.group(1))

    sys.stdout.write('unique reqid not found for reqid: {}\n'.format(reqid))
    return None

def rekey_and_replace(operands):
    """Rekey, then bring down, the CHILD_SA described by *operands*.

    *operands* must provide the keys ``'reqid'`` and ``'SPI'``.  They are
    resolved to a unique id with get_unique_reqid(); when one is found,
    ``ipsec stroke rekey`` and then ``ipsec stroke down`` are run against
    ``{<unique id>}``.  stdout is flushed in either case.
    """
    unique_id = get_unique_reqid(operands['reqid'], operands['SPI'])

    if unique_id is not None:
        sys.stdout.write('rekeying reqid {}\n'.format(unique_id))
        # The stroke commands take the unique id wrapped in braces: {N}.
        target = '{{{}}}'.format(unique_id)
        for action in ('rekey', 'down'):
            subprocess.call(['ipsec', 'stroke', action, target])

    sys.stdout.flush()


def main():
    """Publish a listener URL and serve rekey requests forever.

    Binds a ZeroMQ PULL socket to a random TCP port on the configured
    controller address, stores the resulting URL through
    ``vplaned.Controller`` under the key ``'ipsec listener'``, and then
    loops receiving JSON messages.  Each message is treated as a mapping of
    operator name to operands; ``'REKEY'`` is dispatched to
    rekey_and_replace() and any other operator is reported on stdout.
    """
    controller_ip = get_controller_ip()

    context = zmq.Context()
    receiver = context.socket(zmq.PULL)

    partial_url = 'tcp://{}'.format(controller_ip)
    port = receiver.bind_to_random_port(partial_url)
    full_url = '{}:{}'.format(partial_url, port)

    with vplaned.Controller() as ctrl:
        ctrl.store('ipsec listener',
                   'ipsec listener {}'.format(full_url),
                   action='SET')

    sys.stdout.write('listening on {}\n'.format(full_url))
    sys.stdout.flush()

    while True:
        msg = receiver.recv_json()
        for rator, rands in msg.items():
            if rator == 'REKEY':
                rekey_and_replace(rands)
            else:
                # Report and carry on; an unknown operator must not stop
                # the listener.
                sys.stdout.write('unknown operator {}\n'.format(rator))
                sys.stdout.flush()

# Start the listener only when this file is executed as a script.
if __name__ == "__main__":
    main()
