#!/usr/bin/python3 -u
# -*- coding: utf-8 -*-
# Copyright (c) 2019-2021, AT&T Intellectual Property.  All rights reserved.
#
# SPDX-License-Identifier: LGPL-2.1-only

import zmq
import select
import subprocess
import sys
from vyatta.platform.sfpmgr import SfpStateManager
from vyatta.phy.phy import PhyBus
from vyatta.phy.basephy import PhyException
from vyatta.platform.detect import PlatformError, detect
from vyatta.platform.basesfphelper import ModuleNotPresentException
import configparser
from collections import defaultdict
import os

# Path of the INI-style file in which the daemon persists SFP presence data.
sfpd_presence_file = '/var/run/vyatta/sfpd-presence'

#
# EEPROM offsets per type
#
# For each port type:
#   'start' / 'length'  - the EEPROM window fetched by get_vendor_data()
#   'v_*'               - one vendor field each; 'start'/'end' are slice
#                         bounds *within that window* and 'format' selects
#                         how get_vendor_field() renders the bytes.
#
_SFP_LAYOUT = {
    'start': 20,
    'length': 40,
    'v_name': {'start': 0, 'end': 16, 'format': 'utf-8'},
    'v_oui': {'start': 17, 'end': 20, 'format': 'hex'},
    'v_part': {'start': 20, 'end': 36, 'format': 'utf-8'},
    'v_rev': {'start': 36, 'end': 40, 'format': 'utf-8'},
}

_QSFP_LAYOUT = {
    'start': 148,
    'length': 38,
    'v_name': {'start': 0, 'end': 16, 'format': 'utf-8'},
    'v_oui': {'start': 17, 'end': 20, 'format': 'hex'},
    'v_part': {'start': 20, 'end': 36, 'format': 'utf-8'},
    'v_rev': {'start': 36, 'end': 38, 'format': 'utf-8'},
}

# Kept as a defaultdict so lookups of an unlisted port type behave as before.
eeprom_fields = defaultdict(lambda: defaultdict(dict))
eeprom_fields['SFP'] = _SFP_LAYOUT
eeprom_fields['QSFP'] = _QSFP_LAYOUT

class SfpDaemon(object):
    '''
    Tracks SFP/QSFP module presence.

    The daemon keeps an in-memory table of present modules
    (self.sfp_presence[port_type][port] -> dict of details), persists it to
    the presence file and publishes 'SFP_PRESENCE_NOTIFY' on the monitor
    socket every time that file is rewritten.

    Each run of the daemon has an epoch.  When the daemon restarts and finds
    an existing presence file, the entries are reloaded with the previous
    epoch; entries that were not re-confirmed by the time
    boot_walk_complete() runs are treated as stale and removed.
    '''

    # QSFP ports are stored in the presence file under section
    # "<port + 128>" so they cannot collide with SFP port numbers.
    _QSFP_SECTION_OFFSET = 128

    def __init__(self, pub_endpoint, rep_endpoint, req_endpoint, monitor_endpoint, helper_module):
        '''
        Set up state, the platform helper, the state manager and the
        monitor PUB socket.

        pub_endpoint, rep_endpoint, req_endpoint are passed through to
        SfpStateManager; monitor_endpoint is where presence notifications
        are published; helper_module must provide new_helper(daemon).
        '''
        self._ctx = zmq.Context.instance()
        self._ctx.IPV6 = 1
        self._ctx.LINGER = 0
        self.restarted = False
        self.mismatch = False
        # Set by sweep_stale(); initialised here so the attribute always
        # exists, even before the first new_epoch_sweep().
        self.swept = False
        self.epoch = 0
        self.boot_scan_end_time = 0
        self.boot_walk_complete_notified = False
        self.pub_endpoint = pub_endpoint
        self.rep_endpoint = rep_endpoint
        self.req_endpoint = req_endpoint
        self.monitor_endpoint = monitor_endpoint
        self.sfphelper = helper_module.new_helper(self)
        self.sfpmgr = SfpStateManager(self.pub_endpoint, self.rep_endpoint,
                                      self.req_endpoint, self.sfphelper)
        self.sfp_presence = defaultdict(lambda: defaultdict(dict))
        self.monitor_socket = self._ctx.socket(zmq.PUB)
        self.monitor_socket.bind(monitor_endpoint)
        if monitor_endpoint.startswith("ipc://"):
            # Make it user/group readable/writable so it's possible
            # for clients not running as the same user to use it
            os.chmod(monitor_endpoint[len("ipc://"):], 0o770)

    @staticmethod
    def _uptime_seconds():
        '''
        Return the whole seconds since boot, as a string, taken from the
        first field of /proc/uptime.
        '''
        with open('/proc/uptime') as f:
            return f.readline().split('.')[0]

    def get_vendor_field(self, data, porttype, fieldname):
        '''
        Extract one vendor field from an EEPROM window.

        data is the bytes-like window returned by read_eeprom(); porttype
        and fieldname select the slice bounds and format from
        eeprom_fields.  Returns a right-stripped string for 'utf-8'
        fields, a hex string for 'hex' fields and None for any other
        format.
        '''
        field = eeprom_fields[porttype][fieldname]
        raw = data[field['start']:field['end']]
        encoding = field['format']

        if encoding == 'utf-8':
            # A module with junk in its vendor area must not take the
            # daemon down: undecodable bytes are rendered as \xNN escapes
            # (which also keeps the result plain ASCII) instead of raising
            # UnicodeDecodeError.
            return raw.decode('utf-8', errors='backslashreplace').rstrip()
        if encoding == 'hex':
            return raw.hex()
        return None

    def get_vendor_data(self, porttype, port):
        '''
        Read the vendor area of a module's EEPROM.

        Returns a (vendor name, vendor OUI, part id, revision) tuple.  If
        the helper raises ModuleNotPresentException the placeholder tuple
        ('Unknown', '000000', 'Unknown', 'Unknown') is returned instead.
        '''
        start = eeprom_fields[porttype]['start']
        length = eeprom_fields[porttype]['length']
        try:
            data = self.sfphelper.read_eeprom(porttype, port, offset=start, length=length)
        except ModuleNotPresentException:
            print('Failed to read vendor data from EEPROM: {} {}\n'.format(porttype, port))
            return 'Unknown', '000000', 'Unknown', 'Unknown'

        return (self.get_vendor_field(data, porttype, 'v_name'),
                self.get_vendor_field(data, porttype, 'v_oui'),
                self.get_vendor_field(data, porttype, 'v_part'),
                self.get_vendor_field(data, porttype, 'v_rev'))

    def write_presence_file(self):
        '''
        Persist the presence table and notify listeners.

        The file holds an 'epoch' section, a 'boot_scan_end_time' section
        and one section per present port, named after the port number
        (QSFP ports are offset by 128).  After writing,
        'SFP_PRESENCE_NOTIFY' is published on the monitor socket.
        '''
        # Interpolation is disabled because the values come straight from
        # module EEPROMs: with the default interpolation a '%' in a vendor
        # string makes the assignment raise ValueError.
        presence_file = configparser.ConfigParser(interpolation=None)
        presence_file.add_section('epoch')
        presence_file['epoch']['value'] = str(self.epoch)
        presence_file.add_section('boot_scan_end_time')
        presence_file['boot_scan_end_time']['value'] = str(self.boot_scan_end_time)
        for ports in self.sfp_presence.values():
            for pinfo in ports.values():
                section_int = pinfo['port']
                if pinfo['port_type'] == 'QSFP':
                    section_int += self._QSFP_SECTION_OFFSET
                section = str(section_int)
                presence_file.add_section(section)

                entry = presence_file[section]
                entry['port_name'] = pinfo['port_name']
                entry['vendor_name'] = pinfo['vendor_name']
                entry['vendor_oui'] = pinfo['vendor_oui']
                entry['part_id'] = pinfo['vendor_part_id']
                entry['vendor_rev'] = pinfo['vendor_rev']
                entry['detection_time'] = pinfo['time']

        with open(sfpd_presence_file, 'w') as f:
            os.chmod(sfpd_presence_file, 0o644)
            presence_file.write(f)

        print('Notifying dataplane of SFP presence update\n')
        self.monitor_socket.send_string('SFP_PRESENCE_NOTIFY')

    def read_presence_file(self):
        '''
        After daemon restart, re-read the presence file

        Loads every port section into self.sfp_presence tagged with the
        previous epoch, sets self.epoch to previous epoch + 1 and restores
        self.boot_scan_end_time.
        '''
        print('Reading SFP presence file\n')

        # Interpolation disabled so values are read back exactly as they
        # were written (see write_presence_file).
        existing = configparser.ConfigParser(interpolation=None)
        existing.read(sfpd_presence_file)

        # Resolve the epoch before walking the port sections so the result
        # does not depend on section order; previously a port section seen
        # before (or without) an 'epoch' section hit an unbound local.
        previous_epoch = existing.getint('epoch', 'value', fallback=self.epoch)
        self.epoch = previous_epoch + 1
        self.boot_scan_end_time = existing.getint(
            'boot_scan_end_time', 'value', fallback=self.boot_scan_end_time)

        for section in existing.sections():
            if section in ('epoch', 'boot_scan_end_time'):
                continue

            port = int(section)
            if port < self._QSFP_SECTION_OFFSET:
                porttype = 'SFP'
            else:
                porttype = 'QSFP'
                port -= self._QSFP_SECTION_OFFSET

            saved = existing[section]
            pinfo = self.sfp_presence[porttype][port]
            pinfo['port'] = port
            pinfo['port_type'] = porttype
            pinfo['port_name'] = saved['port_name']
            pinfo['vendor_name'] = saved['vendor_name']
            pinfo['vendor_oui'] = saved['vendor_oui']
            pinfo['vendor_part_id'] = saved['part_id']
            pinfo['vendor_rev'] = saved['vendor_rev']
            pinfo['time'] = saved['detection_time']
            pinfo['epoch'] = previous_epoch

    def sweep_stale(self, stype, port):
        '''
        Remove one stale entry from the presence table and record, in
        self.swept, that something was removed.
        '''
        print('Sweeping stale entry {} {}\n'.format(stype, port))
        del self.sfp_presence[stype][port]
        self.swept = True

    def new_epoch_sweep(self):
        '''
        After sfpd restart and boot walk is complete, purge stale entries
        from presence dict.

        An entry is stale when its epoch is older than the current one.
        Returns True if at least one entry was removed.
        '''
        self.swept = False

        # Collect first, delete afterwards: the table must not change
        # size while it is being iterated.
        stale = [
            (stype, port)
            for stype, ports in self.sfp_presence.items()
            for port, pinfo in ports.items()
            if pinfo['epoch'] < self.epoch
        ]
        for stype, port in stale:
            self.sweep_stale(stype, port)

        return self.swept

    def boot_walk_complete(self):
        '''
        Mark the initial port walk as finished and save the presence file
        if needed.

        On a fresh start the end-of-scan uptime is recorded and the file is
        always written.  After a restart the file is only rewritten when
        stale entries were swept or a mismatch was recorded during the
        walk.
        NOTE(review): presumably invoked by the sfphelper once it has
        reported every port - the caller is not visible in this file.
        '''
        self.boot_walk_complete_notified = True
        if self.restarted:
            self.restarted = False
            save = self.new_epoch_sweep() or self.mismatch
        else:
            # Stored as int, matching what read_presence_file() restores
            # after a restart (it used to be a str on this path only).
            self.boot_scan_end_time = int(self._uptime_seconds())
            save = True

        if save:
            self.write_presence_file()

        self.mismatch = False

    def record_presence_change(self, portname, porttype, port, presence):
        '''
        Update the presence table for one port.

        presence True: (re)read the vendor data and store the entry.
        presence False: drop the entry if there is one.
        Once the boot walk has completed, every call rewrites the presence
        file.
        '''
        print('Record presence {} for port {}\n'.format(presence, portname))
        ports = self.sfp_presence[porttype]
        exists = port in ports

        if presence:
            old_pinfo = ports[port].copy() if exists else None
            seconds_since_boot = self._uptime_seconds()

            vname, oui, part, rev = self.get_vendor_data(porttype, port)

            pinfo = ports[port]
            pinfo['port'] = port
            pinfo['port_type'] = porttype
            pinfo['port_name'] = 'dp0' + portname
            pinfo['vendor_name'] = vname
            pinfo['vendor_oui'] = '{}.{}.{}'.format(oui[0:2], oui[2:4], oui[4:6])
            pinfo['vendor_part_id'] = part
            pinfo['vendor_rev'] = rev
            pinfo['epoch'] = self.epoch

            #
            # If this is a restart and the dictionary entry already exists,
            # there is a mismatch if the details (apart from epoch) are not
            # the same when entry inserted.  The detection time is only
            # refreshed when the module actually changed.
            # Save and notify if mismatch.
            #
            if self.restarted and exists:
                old_pinfo['epoch'] = pinfo['epoch']
                if pinfo != old_pinfo:
                    self.mismatch = True
                    print('Port {} mismatch since daemon restart\n'.format(portname))
                    pinfo['time'] = seconds_since_boot
            else:
                # NOTE(review): a module found in a previously empty port
                # during a restart walk does not set self.mismatch, so
                # boot_walk_complete() may not rewrite the file for it -
                # confirm whether that is intended.
                pinfo['time'] = seconds_since_boot
        elif exists:
            del ports[port]
            # An entry loaded from the file has disappeared since restart.
            if self.restarted:
                self.mismatch = True

        if self.boot_walk_complete_notified:
            self.write_presence_file()

    def on_sfp_presence_change(self, portname, porttype, port, presence,
                               extra_state=None):
        '''
        Called when sfphelper detects that the presence of a port has
        changed

        Logs the event, forwards it to the state manager and records it in
        the presence table.
        '''
        if presence:
            _, _, part, _ = self.get_vendor_data(porttype, port)
        else:
            # The module is gone, so its EEPROM is not read (presumably
            # that read could only fail - the helper is not visible here);
            # report the part id that was recorded when it was inserted.
            recorded = self.sfp_presence[porttype].get(port, {})
            part = recorded.get('vendor_part_id', 'Unknown')

        print("%s: %s %s has been %s" % ("dp0" + portname, porttype, part, "inserted" if presence else "removed"), flush=True)
        self.sfpmgr.on_sfp_presence_change(portname, porttype, port, presence, extra_state)
        self.record_presence_change(portname, porttype, port, presence)

    def on_file_event(self, file, event):
        '''
        Called when a file event is triggered by the sfphelper's
        main_loop function

        Only the state manager's REP socket is expected; anything else
        raises Exception.
        '''
        if file == self.sfpmgr.get_rep_socket_fd():
            self.sfpmgr.process_rep_socket()
        else:
            raise Exception("unexpected event for file {}".format(file))

    def main(self):
        '''
        Reload saved state if this is a restart, then hand control to the
        sfphelper's main loop, polling the state manager's REP socket.
        '''
        #
        # If our presence file is already there, this must be a restart.
        #
        self.restarted = os.path.exists(sfpd_presence_file)
        if self.restarted:
            self.read_presence_file()
        self.sfphelper.main_loop([(self.sfpmgr.get_rep_socket_fd(), select.POLLIN)])

def main():
    '''
    Command-line entry point.

    Expects exactly four endpoint arguments; prints a usage line and exits
    with status 1 otherwise.  Exits quietly with status 0 when the platform
    cannot be detected or offers no SFP helper module, else builds the
    daemon and runs it.
    '''
    argv = sys.argv
    if len(argv) != 5:
        print("\nUsage: " + argv[0] + " <pub_endpoint> <rep_endpoint> <req_endpoint> <monitor_endpoint>")
        sys.exit(1)

    # No detectable platform, or a platform without SFP helper support,
    # simply means there is nothing for this daemon to do.
    try:
        helper_module = detect().get_sfp_helper_module()
    except (AttributeError, PlatformError):
        helper_module = None
    if helper_module is None:
        sys.exit(0)

    pub_endpoint, rep_endpoint, req_endpoint, monitor_endpoint = argv[1:5]
    daemon = SfpDaemon(pub_endpoint, rep_endpoint, req_endpoint,
                       monitor_endpoint, helper_module)
    daemon.main()

# Run the daemon only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
