#!/usr/bin/python3

# Copyright (c) 2020, AT&T Intellectual Property. All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only

# Script to handle block devices scheduler and discard unused blocks.

import os
import sys
import json
import subprocess
import configparser
import io
import syslog

from datetime import datetime
from vyatta import configd

# File listing block devices and partitions; read by parse_block_devices().
PROC_PARTITIONS = '/proc/partitions'
# BLOCK_PATH, a device name and SCHEDULER_PATH are joined to locate the
# scheduler file of a block device (see get_scheduler_path()).
BLOCK_PATH = '/sys/block'
SCHEDULER_PATH = 'queue/scheduler'
# Directory in which the generated .service and .timer files are written.
SYSTEMD_PATH = '/lib/systemd/system/'
# Common prefix of the generated unit names; the block device name and
# DOT_SERVICE or DOT_TIMER are appended to it.
SVC_NAME = 'discard-unused-blocks-'
# Command written to ExecStart of the generated service; the block device
# name is appended to it.
DISCARD_BLOCKS_CMD = '/opt/vyatta/bin/vyatta-block-device clear-unused-blocks '
# Path prefix (SYSTEMD_PATH joined with SVC_NAME) of the generated unit files.
DISCARD_UNUSED_BLOCKS_SERVICE = os.path.join(SYSTEMD_PATH, SVC_NAME)
DOT_SERVICE = '.service'
DOT_TIMER = '.timer'
# fstrim invocation; the mount point is appended as the last argument.
FSTRIM_CMD = '/usr/sbin/fstrim -v'
# systemctl command prefixes; "<device>.timer" is appended before running.
RESTART_DISCARD_UNUSED_BLOCKS_TIMER = 'systemctl restart ' + SVC_NAME
STOP_DISCARD_UNUSED_BLOCKS_TIMER = 'systemctl stop ' + SVC_NAME

# Template of the generated .service file. update_service_file() parses it
# and replaces the ExecStart placeholder before writing the file.
DISCARD_UNUSED_BLOCKS_SERVICE_CFG = '''
[Unit]
Description=Discard unused blocks on filesystem.

[Service]
Type=oneshot
ExecStart=none
'''

# Template of the generated .timer file. update_service_file() parses it
# and replaces the OnCalendar placeholder before writing the file.
DISCARD_UNUSED_BLOCKS_TIMER_CFG = '''
[Unit]
Description=Time to discard unused blocks.

[Timer]
OnCalendar=none
Persistent=true

[Install]
WantedBy=timers.target
'''

# Maps a configured repeat-interval to the OnCalendar value of the timer.
# 'hourly' is used as is; for the other intervals the configured start
# time is appended (see update_discard_unused_blocks_schedule()).
interval_map = {
    'hourly': '*-*-* *:00:00',
    'daily': '*-*-* ',
    'weekly': 'Sun *-*-* ',
    'monthly': '*-*-01 '
    }

# Maps a requested scheduler to the names it may be written as in the
# device's scheduler file. update_scheduler() picks the first candidate
# that the device actually offers.
# 'noop' is the name used by the legacy (single-queue) block layer;
# 'no-op' is retained from the original list.
schedulers_map = {
    'deadline': ['deadline', 'mq-deadline'],
    'none': ['none', 'noop', 'no-op']
    }


def err(msg):
    '''
    Print a message on standard error.
    '''
    stream = sys.stderr
    print(msg, file=stream)


def parse_block_devices():
    '''
    Return the block device names found in /proc/partitions.

    Sample content:
        major minor  #blocks  name

        1        0       8192 ram0
        ...
        8        0  125034840 sda

    Only the last column of each non-empty line is considered. The
    header entry ("name") and every name ending in a digit are left
    out, which drops partitions and devices such as ram0.

    NOTE(review): whole-disk names that end in a digit (e.g. nvme0n1)
    are dropped by the same test - confirm this is intended.

    Returns an empty list when the file cannot be read.
    '''
    found = []

    try:
        with open(PROC_PARTITIONS, 'r') as partitions:
            for entry in partitions:
                columns = entry.split()
                if not columns:
                    # Blank separator line.
                    continue

                candidate = columns[-1]
                if candidate[-1].isdigit() or candidate == "name":
                    continue

                found.append(candidate)
    except Exception as e:
        err("Failed to open /proc/partitions: '{}'".format(e))
        return []
    return found


def list_block_devices(param):
    '''
    Print the known block device names, one per line.

    When param is "list-all-block-devices" the extra entry "all" is
    printed after the device names.
    '''
    names = parse_block_devices()
    if param == "list-all-block-devices":
        names = names + ["all"]

    print("\n".join(names))


def get_scheduler_path(name):
    '''
    Return the path of the scheduler file for the block device name.
    '''
    components = (BLOCK_PATH, name, SCHEDULER_PATH)
    return os.path.join(*components)


def get_schedulers(name):
    '''
    Read the schedulers of a block device from its scheduler file.

    The file is expected to list the available schedulers separated by
    whitespace with the active one in square brackets, for example
    "mq-deadline kyber [none]".

    Returns a (current_scheduler, schedulers) tuple. When the file
    cannot be read, (None, []) is returned so that callers unpacking
    two values keep working and can test for an empty list.
    '''
    path = get_scheduler_path(name)
    try:
        with open(path, 'r') as r:
            out = r.read()
    except (OSError, ValueError) as e:
        err("Failed to read file output '{}': '{}'".format(path, e))
        return None, []

    start = out.find('[')
    end = out.find(']', start + 1)
    if start != -1 and end != -1:
        current_scheduler = out[start + 1:end]
    else:
        # No bracketed entry: report the whole content instead of
        # slicing with the -1 that find() returns.
        current_scheduler = out.strip()

    schedulers = out.replace('[', '').replace(']', '').split()
    return current_scheduler, schedulers


def update_scheduler(name, scheduler):
    '''
    Set the scheduler of a block device.

    If the requested scheduler is not offered by the device under that
    exact name, the first alternative from schedulers_map that the
    device offers is used instead. A scheduler the device rejects is
    reported as a warning and otherwise ignored.
    '''
    result = get_schedulers(name)
    # An empty result means the scheduler file could not be read.
    schedulers = result[1] if result else []

    if not schedulers:
        print("\nWARNING: No schedulers found for the device!")
        return

    # Find the correct scheduler for the block-device
    if scheduler not in schedulers:
        for candidate in schedulers_map.get(scheduler, []):
            if candidate in schedulers:
                scheduler = candidate
                break

    path = get_scheduler_path(name)
    try:
        with open(path, 'w') as s:
            s.write(scheduler)
    except OSError as exc:
        print("\nWARNING: Invalid Scheduler type will be ignored!")
        print(str(exc).strip())
    else:
        # Log only after the file was closed: a buffered write is not
        # rejected before it is flushed.
        syslog.syslog(syslog.LOG_INFO, 'Updated {} block-device '
                      'scheduler to {}'.format(name, scheduler))


def configure_block_device_scheduler():
    '''
    Get the block device info from configured tree and configure scheduler.

    Exits with status 0 when nothing is configured and with status 1
    when the configuration cannot be fetched or applied. Devices that
    have no name or no scheduler configured are skipped individually.
    '''
    CONFIG_STRING = 'system storage block-device'
    cfg = None
    try:
        client = configd.Client()
        cfg = client.tree_get_dict(CONFIG_STRING)
    except configd.Exception:
        # Treated the same as "nothing configured" below.
        pass
    except Exception as e:
        print("Failed to get tree on '{}': '{}'".format(CONFIG_STRING, e),
              file=sys.stderr)
        sys.exit(1)

    if not cfg:
        sys.exit(0)

    for device in cfg.get('block-device') or []:
        name = device.get('name')
        scheduler = device.get('scheduler')
        if name is None or scheduler is None:
            # Nothing to apply for this device; keep going with the rest.
            continue

        try:
            update_scheduler(name, scheduler)
        except KeyError as e:
            # Do not let one device stop the remaining ones.
            print("Skipping scheduler update for '{}': unknown key {}"
                  .format(name, e), file=sys.stderr)
        except Exception as e:
            print("Failed to configure block device scheduler: '{}'"
                  .format(e), file=sys.stderr)
            sys.exit(1)


class Block_Device:
    '''
    Block Device Class. Contains a single block device entry
    '''
    # Attribute name -> key used in the JSON representation.
    _attrs = {
        'name': 'name',
        'current_scheduler': 'current-scheduler',
        'available_schedulers': 'available-schedulers'
    }

    def __init__(self, **kwargs):
        '''
        Accepts the keys of _attrs as keyword arguments. An attribute
        that is not passed, or passed as None, keeps the placeholder
        value "unavailable".
        '''
        self.name = "unavailable"
        self.current_scheduler = 'unavailable'
        self.available_schedulers = 'unavailable'

        for attr in self._attrs:
            value = kwargs.get(attr)
            if value is not None:
                setattr(self, attr, value)

    def json_dict(self):
        '''
        Return the entry as a dict keyed by the JSON names from _attrs.
        '''
        return {self._attrs[k]: v for (k, v) in self.__dict__.items()}

    def __str__(self):
        return "Block-Devices({})".format(str(self.__dict__))

    def __repr__(self):
        return "Block-Devices({})".format(repr(self.__dict__))


def create_block_device_entry(device):
    '''
    Build the Block_Device entry for the block device named device,
    filled with its current and available schedulers.
    '''
    active, available = get_schedulers(device)

    return Block_Device(name=device,
                        current_scheduler=active,
                        available_schedulers=available)


def get_block_devices():
    '''
    Return a Block_Device entry for every name reported by
    parse_block_devices(), leaving out entries that came back as None.
    '''
    entries = (create_block_device_entry(name)
               for name in parse_block_devices())
    return [entry for entry in entries if entry is not None]


def run_cmd(cmd):
    '''
    Run cmd (an argument list) and capture its output.

    Returns (return code, decoded stderr, decoded stdout). If the
    command cannot be run or its output cannot be decoded, an error is
    reported and the script exits with status 1.
    '''
    pipe = subprocess.PIPE
    try:
        finished = subprocess.run(cmd, stdout=pipe, stderr=pipe)
        status = finished.returncode
        stderr_text = finished.stderr.decode()
        stdout_text = finished.stdout.decode()
        return status, stderr_text, stdout_text
    except Exception as e:
        err("Failed running command {}: {}".format(cmd, e))
        sys.exit(1)


def clear_block_device_unused_blocks(mountpoint, device):
    '''
    Run fstrim on mountpoint, log the outcome to syslog under the
    service name of device and print it as JSON ({"output": ...}).

    The output is the fstrim output on success, or a failure message
    containing its stderr otherwise.
    '''
    ret, error, out = run_cmd(FSTRIM_CMD.split() + [mountpoint])
    if ret == 0:
        output = out
    else:
        output = "Failed clearing device's unused blocks: {}".format(error)

    ident = SVC_NAME+device
    syslog.openlog(ident, facility=syslog.LOG_DAEMON)
    syslog.syslog(syslog.LOG_INFO, '{} started: {}'.format(ident, output))

    payload = {'output': output}
    try:
        print(json.dumps(payload))
    except ValueError as e:
        err("Failed to decode output JSON: {}".format(e))
        sys.exit(1)


def get_mount_points(block_device):
    '''
    Get all valid mount points for the specified or all block-devices.

    Runs "findmnt -J" and looks only at the mounts listed below the
    /run entry of the first filesystem. A mount is selected when the
    name of a wanted device occurs in its source. For "all" the wanted
    devices are those returned by parse_block_devices().

    Returns a list of mount targets, each at most once; the list is
    empty when nothing matches or the mount table cannot be obtained.
    '''
    ret, error, out = run_cmd(['findmnt', '-J'])
    if ret != 0:
        err("Failed to fetch mount points for block device {}: {}"
                .format(block_device, error))
        return []

    try:
        result = json.loads(out)
    except ValueError as e:
        err("Failed to parse mount points for block device {}: {}"
                .format(block_device, e))
        return []

    # Resolve the wanted devices once instead of once per mount entry.
    if block_device == "all":
        devices = parse_block_devices()
    else:
        devices = [block_device]

    filesystems = result.get("filesystems") or [{}]
    mps = []
    for mount in filesystems[0].get("children", []):
        if mount.get("target") != "/run":
            continue
        # Entries without nested mounts carry no "children" key.
        for child in mount.get("children", []):
            source = child.get("source") or ""
            if any(dev in source for dev in devices):
                mps.append(child["target"])
    return mps


def clear_unused_blocks(device):
    '''
    Common routine to clear unused blocks when clear command is issued
    or when discard-unused-blocks timer service is run.

    When device is "none" the device name is taken from the
    'block-device' entry of a JSON document read from standard input;
    the script exits with status 1 if that document is invalid or has
    no such entry.
    '''

    if device == "none":
        try:
            json_data = json.load(sys.stdin)
            device = json_data['block-device']
        except ValueError as exc:
            err("Failed to parse input JSON to clear unused blocks: {}"
                    .format(exc))
            sys.exit(1)
        except (KeyError, TypeError):
            err("Input JSON to clear unused blocks has no "
                "'block-device' entry")
            sys.exit(1)

    mount_points = get_mount_points(device)

    if not mount_points:
        print("\nWARNING: No mount points found for the block-device!")
        return

    for mp in mount_points:
        clear_block_device_unused_blocks(mp, device)


def run_service_timer(service):
    '''
    Run the given command line (split on whitespace), e.g. the command
    restarting or stopping a timer, and exit with status 1 if it fails.
    '''
    ret, error, _ = run_cmd(service.split())
    if ret == 0:
        return

    err("Failed to execute command {} : {}".format(service, error))
    sys.exit(1)


def update_service_file(path, data):
    '''
    Creates and updates the .service and .timer file.

    path: file to write. A path ending in DOT_SERVICE is written from
          the service template, any other path from the timer template.
    data: value stored in the template's placeholder:
          ExecStart: script to be executed based on the timer service.
          OnCalendar: schedule of the timer.

    Exits with status 1 when the file cannot be written.
    '''
    # The values are written verbatim, so '%' interpolation is disabled.
    config = configparser.ConfigParser(interpolation=None)
    # Keep option names as written instead of lower-casing them.
    config.optionxform = str

    if path.endswith(DOT_SERVICE):
        config.read_file(io.StringIO(DISCARD_UNUSED_BLOCKS_SERVICE_CFG))
        config['Service']['ExecStart'] = data
    else:
        config.read_file(io.StringIO(DISCARD_UNUSED_BLOCKS_TIMER_CFG))
        config['Timer']['OnCalendar'] = data

    try:
        with open(path, 'w') as cfg:
            config.write(cfg)
    except Exception as e:
        err("Failed to update config file '{}': '{}'".format(path, e))
        sys.exit(1)


def remove_service_file(path):
    '''
    Remove the file at path; a file that does not exist is not an error.

    Any other failure is reported on standard error and ends the script
    with status 1.
    '''
    try:
        os.remove(path)
    except FileNotFoundError:
        # Already gone, nothing to do.
        return
    except OSError as exc:
        print("Failed to remove {}:{}".format(path, str(exc)),
              file=sys.stderr)
        sys.exit(1)


def delete_discard_unused_blocks_schedule(device):
    '''
    Stops the systemctl service to discard-unused-blocks.
    Removes .service and .timer file for the block device being deleted.

    Does nothing when neither file exists, i.e. no schedule was set up
    for the device.
    '''
    block_device = device['name']
    service_path = DISCARD_UNUSED_BLOCKS_SERVICE + block_device + DOT_SERVICE
    timer_path = DISCARD_UNUSED_BLOCKS_SERVICE + block_device + DOT_TIMER

    has_timer = os.path.exists(timer_path)
    if not has_timer and not os.path.exists(service_path):
        return

    # Stopping a timer that was never created makes systemctl fail,
    # which would abort the script in run_service_timer().
    if has_timer:
        run_service_timer(STOP_DISCARD_UNUSED_BLOCKS_TIMER +
                          block_device + DOT_TIMER)

    remove_service_file(service_path)
    remove_service_file(timer_path)

    # Make systemd forget the removed unit files.
    run_service_timer('systemctl daemon-reload')


def update_discard_unused_blocks_schedule(device):
    '''
    Update the .service and .timer file to discard unused blocks
    and start the timer service.
    For hourly: service is run beginning at the next hour.
    For daily, weekly, monthly: use start time.

    Does nothing when the device has no 'discard-unused-blocks'
    configuration.
    '''
    block_device = device['name']
    schedule = device.get('discard-unused-blocks')
    if not schedule:
        return

    repeat_interval = schedule['repeat-interval']
    calendar_time = interval_map[repeat_interval]
    if repeat_interval != 'hourly':
        # The start time is only needed, and only read, for these.
        calendar_time += schedule['start-time']

    svc_cmd = DISCARD_BLOCKS_CMD + block_device

    update_service_file(DISCARD_UNUSED_BLOCKS_SERVICE +
                        block_device + DOT_SERVICE, svc_cmd)
    update_service_file(DISCARD_UNUSED_BLOCKS_SERVICE +
                        block_device + DOT_TIMER, calendar_time)

    # Have systemd re-read the rewritten unit files before restarting,
    # otherwise an already loaded timer keeps its previous schedule.
    run_service_timer('systemctl daemon-reload')
    run_service_timer(RESTART_DISCARD_UNUSED_BLOCKS_TIMER +
                      block_device + DOT_TIMER)


def configure_discard_unused_blocks_schedule():
    '''
    Determine which block-device configuration has been changed,
    added or deleted.

    Devices of the running configuration whose candidate status is
    DELETED lose their schedule. Added or changed devices get their
    schedule written when 'discard-unused-blocks' is configured;
    changed devices without it lose their schedule.
    '''
    cfg_run = None
    cfg_cand = None
    changed = []
    added = []
    deleted = []

    CONFIG_STRING = 'system storage block-device'

    try:
        client = configd.Client()
    except configd.FatalException as e:
        print("Can't connect to configd: '{}'".format(e), file=sys.stderr)
        sys.exit(1)

    if client.node_exists(client.RUNNING, CONFIG_STRING):
        cfg_run = client.tree_get_full_dict(CONFIG_STRING,
                                            client.RUNNING, "json")

    if client.node_exists(client.CANDIDATE, CONFIG_STRING):
        cfg_cand = client.tree_get_full_dict(CONFIG_STRING,
                                             client.CANDIDATE, "json")

    if cfg_run:
        # "or []" guards against a tree without a 'block-device' list.
        for device in cfg_run.get('block-device') or []:
            name = device.get('name')
            if name is None:
                continue
            device_path = "{} {}".format(CONFIG_STRING, name)
            status = client.node_get_status(client.CANDIDATE, device_path)
            if status == client.DELETED:
                deleted.append(device)

    if cfg_cand:
        for device in cfg_cand.get('block-device') or []:
            name = device.get('name')
            if name is None:
                continue
            device_path = "{} {}".format(CONFIG_STRING, name)
            status = client.node_get_status(client.CANDIDATE, device_path)
            if status == client.ADDED:
                added.append(device)
            elif status == client.CHANGED:
                changed.append(device)

    # Handle changed and added block-devices
    for device in added:
        # A newly added device has no schedule to remove, so one
        # without 'discard-unused-blocks' needs no action.
        if 'discard-unused-blocks' in device:
            update_discard_unused_blocks_schedule(device)

    for device in changed:
        if 'discard-unused-blocks' in device:
            update_discard_unused_blocks_schedule(device)
        else:
            delete_discard_unused_blocks_schedule(device)

    # Handle deleted block-devices
    for device in deleted:
        delete_discard_unused_blocks_schedule(device)

def validate_discard_unused_blocks_start_time(start_time):
    '''
    Check that start_time can be parsed as %H:%M:%S.

    Returns normally when it can; otherwise reports the reason and
    exits with status 1.
    '''
    try:
        datetime.strptime(start_time, "%H:%M:%S")
        return
    except Exception as parse_error:
        reason = "Invalid time specified: {}".format(parse_error)

    err(reason)
    sys.exit(1)

if __name__ == "__main__":
    # The first argument selects the action; some actions take one more
    # argument. An unrecognised action does nothing and exits with 0.
    if len(sys.argv) < 2:
        err("Usage: {} <command> [argument]".format(sys.argv[0]))
        sys.exit(1)

    command = sys.argv[1]

    if command == "scheduler":
        configure_block_device_scheduler()
    elif command in ("list-block-devices", "list-all-block-devices"):
        list_block_devices(command)
    elif command == "get-block-devices":
        jout = [v.json_dict() for v in get_block_devices()]
        print(json.dumps({'block-device': (jout)}))
    elif command == "clear-unused-blocks":
        # Without a device argument the device is read from stdin.
        if len(sys.argv) == 2:
            clear_unused_blocks("none")
        else:
            clear_unused_blocks(sys.argv[2])
    elif command == "discard-unused-blocks":
        configure_discard_unused_blocks_schedule()
    elif command == "validate-time":
        if len(sys.argv) < 3:
            err("validate-time requires a start time argument")
            sys.exit(1)
        validate_discard_unused_blocks_start_time(sys.argv[2])

    # sys.exit() rather than the exit() helper provided by the site
    # module, which is not guaranteed to be available.
    sys.exit(0)
