#!/usr/bin/python3
# Copyright (c) 2019-2021 AT&T intellectual property.
# All rights reserved
#
# SPDX-License-Identifier: GPL-2.0-only

"""Create or destroy a sandbox for login sessions.
The sandbox is created per user and this script is
intended and must be called from the systemd service
cli-sandbox@.service.
"""

import sys
import os
from argparse import ArgumentParser
import subprocess
import shutil
import contextlib
import glob
import grp
import pwd
import re
import tempfile
from vyatta.shared_storage import SharedStorage
from vyatta.sandbox import format_group_database_entry,  \
                           format_passwd_database_entry, \
                           parse_group_database

# Default sandbox settings. UserSandBox builds its configuration from this
# table, letting an environment variable with the same name as a key
# override that key's default value.
SANDBOX_DEFAULTS = {
    # Lower (read-only) directory of the sandbox root overlay; also the
    # source of the base passwd/group files.
    'DISTDIR': '/var/lib/sandbox',
    # Parent directory of the per-sandbox top directories.
    'RUNDIR': '/run/cli-sandbox',
    # nspawn settings file copied and extended for each sandbox.
    'NSPAWN_TEMPLATE': '/etc/cli-sandbox/nspawn-templates/cli_sandbox.nspawn',
    # Directory tree copied into the sandbox root at the same relative path.
    'INIT_HOOKS_DIR': '/etc/cli-sandbox/hooks/sandbox-init.d',
    # Directory run with run-parts once the sandbox root has been set up.
    'POST_CREATE_HOOKS_DIR': '/etc/cli-sandbox/hooks/sandbox-post-create.d',
    # Program copied to the top of the sandbox root and made executable.
    'INIT2': '/opt/vyatta/sbin/cli_sandbox_init',
    # Directory receiving the generated per-machine .nspawn settings file.
    'SETTINGS_DIR': '/run/systemd/nspawn',
    # Name of the environment file written in the sandbox top directory.
    'ENVFILE': 'cli-sandbox.env',
    # Host paths bind mounted under /.mounts in the sandbox (when they exist).
    'MOVED_MOUNTS': ['/dev/pts', '/run/utmp', '/var/log/wtmp', '/var/log/btmp']
    }

# ENVFILE vars
ENV_NSPAWN_TEMPLATE = 'CLI_SANDBOX_NSPAWN_TEMPLATE'
ENV_ROOT = 'CLI_SANDBOX_ROOT'
ENV_TOP = 'CLI_SANDBOX_TOP'
ENV_UID = 'CLI_SANDBOX_UID'
ENV_USER = 'CLI_SANDBOX_USER'

# Shared Storage configuration file
SHARED_RUN_CONF = '/run/vyatta/shared_storage/shared_storage.conf'

# The rootfs overlay isn't going to contain much data, as most of the contents
# come from the underlying read-only rootfs directory. It is going to contain
# only a little state information and a few scripts that are pulled in at runtime.
ROOT_TMP_FS_MAX = '1m'

# Passed as sb_run(cmd, IGNORE) when we want to ignore the result of a command
IGNORE = True

# User for password renewal
VYATTAPWDCFG = 'vyattapwdcfg'

# command-not-found handler script variables for sandbox
# CNF script is called:
# /usr/lib/command-not-found -- cmd
#
CNF_FILENAME = '/usr/lib/command-not-found'
CNF_SCRIPT = """#!/bin/bash
cmd="${@: -1}" # Last argument
user_shell="$(basename "${SHELL:-/bin/vbash}")"
echo "${user_shell}-sandbox: ${cmd}: command not found" >&2
"""

def sb_run(cmd, ignore=False, env=None):
    """Execute cmd, terminating the script with status 1 if it fails.

    cmd is an argument list and env, when given, is the environment
    handed to the child. When ignore is true a non-zero exit status of
    cmd is not treated as a failure.
    """
    must_succeed = not ignore
    try:
        subprocess.run(cmd, check=must_succeed, env=env)
    except subprocess.CalledProcessError as failure:
        message = 'Failed: {}:{}'.format(' '.join(cmd), str(failure))
        print(message, file=sys.stderr)
        sys.exit(1)

def get_current_shares():
    """Return the currently mounted shares listed in the run-time conf file."""
    storage = SharedStorage()
    storage.load(conf_file=SHARED_RUN_CONF)
    current_mounts = storage.get_current_mounts()
    return current_mounts

def run_shell_cmd(cmd):
    """Run cmd and return the first line of its standard output.

    cmd is an argument list (no shell is involved). The child's standard
    error is discarded. The returned line has surrounding whitespace
    removed; an empty string is returned when the command could not be
    started or produced no output.
    """
    try:
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                stderr=subprocess.DEVNULL)
    except Exception as e:
        # Deliberately best-effort: any failure to launch the command is
        # reported and treated as "no output".
        print("Failed to fetch op commands: {}".format(e))
        return ""

    # Only the first line is of interest. Leaving the 'with' block closes
    # the pipe and waits for the child, so neither a file descriptor nor a
    # zombie process is left behind.
    with proc:
        first_line = proc.stdout.readline()

    return first_line.decode(errors='replace').strip()

def remove_native_op_cmds(chroot):
    """Hide executables in chroot that share a name with an op command.

    The op command names are taken from the first parenthesised,
    space-separated list in the output of 'opc -op complete'. For each
    name (other than 'configure') that 'which' resolves to a path, the
    regular file at that path inside chroot is replaced by a symlink to
    a non-existent file.

    Note that PATH in this process's environment is overwritten as a
    side effect whenever the completion command returns any output.
    """
    cmd = ['/opt/vyatta/sbin/lu', '-user', 'configd', '/opt/vyatta/bin/opc', '-op', 'complete']
    out = run_shell_cmd(cmd)
    if not out:
        return

    os.environ["PATH"] = '/usr/local/bin:/usr/bin:/bin:/opt/vyatta/bin:/usr/local/sbin:/usr/sbin:/sbin:/opt/vyatta/sbin'

    # The completion output is expected to contain "(op1 op2 ...)"; bail
    # out cleanly rather than crash if it does not.
    match = re.search(r'\((.*?)\)', out)
    if match is None:
        print("No op command list found in: {}".format(out))
        return

    for op in match.group(1).split(' '):
        op = op.strip("'")
        try:
            if not op or op == 'configure':
                continue

            host_path = run_shell_cmd(['which', op])
            if not host_path:
                continue

            target = chroot + host_path
            if os.path.isfile(target):
                os.remove(target)
                # link to non-existing file under read only directory
                os.symlink('/var/run/vyatta/.does_not_exist', target)
        except OSError as e:
            print("Exception when removing native op commands: {}".format(e))

def gen_machine_name(uid):
    """Return the machine name used for the sandbox of the given uid."""
    return f'cli-{uid}'

def write_command_not_found(chroot):
    """Install the sandbox command-not-found handler in chroot.

    CNF_SCRIPT is written to CNF_FILENAME below chroot and made
    executable. This is best-effort: a filesystem error is reported
    but not raised.
    """
    fname = chroot + CNF_FILENAME
    try:
        with open(fname, 'w') as cnf:
            cnf.write(CNF_SCRIPT)
        os.chmod(fname, 0o755)
    except OSError as err:
        # Only filesystem errors are tolerated; anything else (including
        # SystemExit/KeyboardInterrupt) must not be silently swallowed.
        print("Exception when writing {}: {}".format(CNF_FILENAME, err))

def generate_passwd_database(chroot, pw_entry, pwdcfg_entry, base_passwd=None):
    """
    Generates a limited /etc/passwd in chroot, overwriting any
    existing file.

    The user represented by pw_entry is referred to as the
    "sandboxed user".

    The generated passwd file is guaranteed to contain an
    entry corresponding to pw_entry, and one corresponding to
    pwdcfg_entry unless pwdcfg_entry is None.

    base_passwd may optionally be specified as an iterator which
    returns strings representing lines (entries) of a passwd file.

    All lines from base_passwd are included in the generated passwd
    file (except any relating to the sandboxed user or to the
    pwdcfg_entry user).

    The file is written to a temporary file in chroot and then renamed
    into place; the temporary file is removed if anything fails.
    """
    generated = [entry for entry in (pw_entry, pwdcfg_entry)
                 if entry is not None]
    # pw_entry itself is required; only pwdcfg_entry is optional.
    skip_prefixes = tuple(entry.pw_name + ':' for entry in generated) \
        if pw_entry is not None else (pw_entry.pw_name + ':',)

    tmp_passwd = tempfile.NamedTemporaryFile(
        'w', dir=chroot, delete=False)
    try:
        with tmp_passwd:
            if base_passwd is not None:
                for line in base_passwd:
                    if line.startswith(skip_prefixes):
                        continue
                    # A final line lacking its newline would otherwise be
                    # merged with the first generated entry.
                    if not line.endswith('\n'):
                        line += '\n'
                    tmp_passwd.write(line)

            for entry in generated:
                print(format_passwd_database_entry(entry), file=tmp_passwd)
            os.chmod(tmp_passwd.file.fileno(), 0o0644)

        os.rename(tmp_passwd.name, os.path.join(chroot, 'etc', 'passwd'))
    except BaseException:
        # Do not leave a stray temporary file behind in the sandbox root.
        with contextlib.suppress(OSError):
            os.unlink(tmp_passwd.name)
        raise

def generate_group_database(chroot, pw_entry, pwdcfg_entry, base_group=None):
    """
    Generates a limited /etc/group in chroot, overwriting any
    existing file.

    The user represented by pw_entry is referred to as the
    "sandboxed user".

    The generated group file is guaranteed to contain entries
    corresponding to the primary and supplementary groups of the
    sandboxed user.

    base_group may optionally be specified as an iterator which
    returns strings representing lines (entries) of a group file.

    All lines from base_group are included in the generated group
    file as provided (including any member information). If the
    sandboxed user is a member of any of these groups, the entry
    will be updated to indicate this.

    The entry for any of the sandboxed user's groups, which are not
    returned from base_group, will show the user as the only member.

    Unless pwdcfg_entry is None, an entry is also appended for each of
    that user's groups, showing that user as the only member.

    The file is written to a temporary file in chroot and then renamed
    into place; the temporary file is removed if anything fails.
    """
    user_groups = {g: grp.getgrgid(g) for g in os.getgrouplist(
        pw_entry.pw_name, pw_entry.pw_gid)}
    if pwdcfg_entry is None:
        pcfg_user_groups = {}
    else:
        pcfg_user_groups = {g: grp.getgrgid(g) for g in os.getgrouplist(
            pwdcfg_entry.pw_name, pwdcfg_entry.pw_gid)}

    tmp_group = tempfile.NamedTemporaryFile(
        'w', dir=chroot, delete=False)
    try:
        with tmp_group:
            # Write all base groups including any existing
            # members. If the sandboxed user should be a member
            # ensure this is recorded.
            if base_group is not None:
                for group in parse_group_database(base_group):
                    user_group = user_groups.get(group.gr_gid)
                    if user_group is not None:
                        if user_group.gr_name != group.gr_name:
                            # In the unlikely event of a GID being assigned
                            # to multiple groups, prefer the group from the
                            # system database.
                            print("Group ID {} collision - using {} over {}".format(
                                group.gr_gid, user_group.gr_name, group.gr_name))
                            continue

                        # This group is handled here, whether or not the
                        # user was already listed, so it must not be
                        # written a second time below.
                        user_groups.pop(group.gr_gid)
                        if pw_entry.pw_name not in group.gr_mem:
                            group.gr_mem.append(pw_entry.pw_name)

                    print(format_group_database_entry(group), file=tmp_group)

            # Add any groups the user is a member of which haven't
            # already been processed, ensuring the sandboxed user is
            # the only member.
            for group in user_groups.values():
                group.gr_mem.clear()
                group.gr_mem.append(pw_entry.pw_name)
                print(format_group_database_entry(group), file=tmp_group)

            # NOTE(review): these are appended even when the same group
            # was already written above - confirm duplicates are intended.
            for group in pcfg_user_groups.values():
                group.gr_mem.clear()
                group.gr_mem.append(pwdcfg_entry.pw_name)
                print(format_group_database_entry(group), file=tmp_group)

            os.chmod(tmp_group.file.fileno(), 0o0644)

        os.rename(tmp_group.name, os.path.join(chroot, 'etc', 'group'))
    except BaseException:
        # Do not leave a stray temporary file behind in the sandbox root.
        with contextlib.suppress(OSError):
            os.unlink(tmp_group.name)
        raise

def generate_user_databases(chroot, pw_entry, pwdcfg_entry,
                            base_passwd=None, base_group=None):
    """
    Rebuilds the user databases of chroot.

    Backup and shadow databases found in the chroot's etc directory
    are deleted, then limited /etc/passwd and /etc/group files are
    generated, replacing any existing ones.

    See generate_passwd_database() and generate_group_database()
    for more details.
    """
    etc_dir = os.path.join(chroot, 'etc')

    # The backup/shadow databases are not needed in the sandbox.
    # 'passwd' and 'group' themselves are left alone here since a
    # running sandbox may be using them.
    stale_patterns = ("passwd-", "group-", "shadow*", "gshadow*")
    for pattern in stale_patterns:
        candidates = glob.glob(os.path.join(etc_dir, pattern))
        for stale in filter(os.path.exists, candidates):
            os.remove(stale)

    generate_passwd_database(chroot, pw_entry, pwdcfg_entry, base_passwd)
    generate_group_database(chroot, pw_entry, pwdcfg_entry, base_group)


class UserSandBox:
    """Per-user CLI sandbox: set-up, update, teardown and start-up.

    Configuration values (the SANDBOX_DEFAULTS keys) are available as
    attributes of an instance, e.g. self.RUNDIR.
    """
    # Each default may be overridden by an environment variable of the
    # same name; note that an overriding value is always a string.
    conf = {k: os.environ.get(k, v) for k, v in SANDBOX_DEFAULTS.items()}

    def __init__(self, name):
        """Work out the names and paths of the sandbox for user name.

        When name is not a known user the UID is taken from the
        CLI_SANDBOX_UID environment variable instead; KeyError is raised
        if that is not set either.
        """
        self.name = name

        try:
            self.pw_entry = pwd.getpwnam(name)
        except KeyError:
            self.pw_entry = None
            uid = os.environ[ENV_UID]
        else:
            uid = self.pw_entry.pw_uid
        # Computed only once uid is known, so that a missing environment
        # variable surfaces as the KeyError it is.
        self.machine = gen_machine_name(uid)

        try:
            self.pwdcfg_entry = pwd.getpwnam(VYATTAPWDCFG)
        except KeyError:
            self.pwdcfg_entry = None

        self.topdir = os.path.join(self.RUNDIR, name)
        self.settings = os.path.join(self.SETTINGS_DIR, self.machine + '.nspawn')
        self.root = os.path.join(self.topdir, self.machine)
        # Exported to hooks and written to the sandbox environment file.
        self.env = {
            ENV_NSPAWN_TEMPLATE: self.settings,
            ENV_ROOT: self.root,
            ENV_TOP: self.topdir,
            ENV_UID: str(uid),
            ENV_USER: self.name,
        }

    def __getattr__(self, field):
        """Resolve otherwise unknown attributes from the configuration."""
        try:
            return self.conf[field]
        except KeyError:
            # Follow the attribute protocol so that hasattr(), getattr()
            # with a default, copy, etc. behave as expected.
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(
                    type(self).__name__, field)) from None

    def _run_hooks(self, hook_dir):
        """Run the scripts in hook_dir (if it exists) using run-parts.

        The hooks receive the current environment extended with the
        sandbox variables from self.env.
        """
        if os.path.isdir(hook_dir):
            env = os.environ.copy()
            env.update(self.env)
            sb_run(["run-parts", "--report", hook_dir], env=env)

    def _write_env(self):
        """Write self.env as NAME=value lines to the sandbox env file.

        The file is created with a umask of 077; the previous umask is
        restored afterwards, even if writing fails.
        """
        prev_mask = os.umask(0o077)
        try:
            with open(os.path.join(self.topdir, self.ENVFILE), 'w') as f:
                for n, v in self.env.items():
                    f.write("{}={}\n".format(n, v))
        finally:
            os.umask(prev_mask)

    def get_moved_mounts(self):
        """Returns mounts to be moved after entering the sandbox namespace.
        These are extra bind mounts for files and directories that can't be
        directly shared mounted before launching the sandbox and need to be
        moved after entering the sandbox's mount namespace.
        They are mounted under /.mounts in the sandbox root directory.

        Only MOVED_MOUNTS paths which currently exist are returned, each
        formatted as "<path>:/.mounts/<path with '/' replaced by '_'>".
        """
        def bind_str(fname):
            return "{}:/.mounts/{}".format(fname, fname.replace('/', '_'))

        return [bind_str(os.path.abspath(x))
                for x in self.MOVED_MOUNTS if os.path.exists(x)]

    def edit_nspawn_template(self):
        """create nspawn settings file.
        Copies the nspawn template and appends a separate [Files] section
        with extra bind mounts for the home directory (when the user is
        known and it exists), the shared-storage mounts and the moved
        mounts.
        """
        settings_dir = os.path.dirname(self.settings)
        os.makedirs(settings_dir, mode=0o755, exist_ok=True)
        shutil.copy2(self.NSPAWN_TEMPLATE, self.settings)
        bindfmt = "Bind={}\n"
        with open(self.settings, 'a') as fout:
            fout.write("\n[Files]\n")
            # pw_entry is None when the user is not in the passwd database
            if self.pw_entry is not None and os.path.isdir(self.pw_entry.pw_dir):
                fout.write(bindfmt.format(self.pw_entry.pw_dir))
            for shared_dir in get_current_shares():
                fout.write(bindfmt.format(shared_dir))
            for moved_mount in self.get_moved_mounts():
                fout.write(bindfmt.format(moved_mount))

    def _generate_user_databases(self):
        """Generate the passwd and group databases in the sandbox root."""
        # Use any existing /etc/passwd from the pristine chroot to
        # determine the base system users to setup in the sandbox.
        try:
            base_passwd = open(os.path.join(self.DISTDIR, 'etc', 'passwd'), 'r')
        except OSError:
            base_passwd = contextlib.nullcontext([])

        # Similarly use any existing /etc/group to determine the
        # base system groups to setup.
        try:
            base_groups = open(os.path.join(self.DISTDIR, 'etc', 'group'), 'r')
        except OSError:
            base_groups = contextlib.nullcontext([])

        with base_passwd as bp_iter, base_groups as bg_iter:
            generate_user_databases(self.root, self.pw_entry, self.pwdcfg_entry, bp_iter, bg_iter)

    def create_sandbox(self):
        """create and setup sandbox rootfs"""
        self.edit_nspawn_template()

        # Small tmpfs holding the env file and the overlay upper/work dirs
        os.makedirs(self.topdir)
        mntcmd = ['/bin/mount', '-t', 'tmpfs',
                  '-o', 'size=' + ROOT_TMP_FS_MAX + ',mode=0755',
                  'tmpfs', self.topdir
                 ]
        sb_run(mntcmd)
        self._write_env()

        # Sandbox root: writable overlay on top of the pristine rootfs
        os.mkdir(self.root)
        upper = os.path.join(self.topdir, 'empty')
        work = os.path.join(self.topdir, 'work')
        os.mkdir(work)
        os.mkdir(upper)
        mntcmd = ['/bin/mount', '-t', 'overlay',
                  '-o', 'lowerdir={},upperdir={},workdir={}'.format(self.DISTDIR, upper, work),
                  'overlayfs',
                  self.root,
                 ]
        sb_run(mntcmd)

        # Populate the sandbox root
        self._generate_user_databases()
        remove_native_op_cmds(self.root)
        write_command_not_found(self.root)
        shutil.copy2(self.INIT2, self.root)
        os.chmod(os.path.join(self.root, os.path.basename(self.INIT2)), 0o755)
        os.mkdir(os.path.join(self.root, '.mounts'), 0o700)
        hname = os.uname()[1]
        with open(os.path.join(self.root, 'hostname.sandbox'), 'w') as outfile:
            print(hname, file=outfile)

        if os.path.isdir(self.INIT_HOOKS_DIR):
            shutil.copytree(self.INIT_HOOKS_DIR, os.path.join(
                self.root, self.INIT_HOOKS_DIR.lstrip(os.path.sep)))

        self._run_hooks(self.POST_CREATE_HOOKS_DIR)

    def update_sandbox(self):
        """
        Updates an existing sandbox:
            * re-generate the sandbox user databases

        The mtime of the sandbox's "ready" file is
        updated on completion of the update.

        Nothing is done when the "ready" file does not exist.
        """
        ready_file = os.path.join(self.topdir, "ready")
        if os.path.exists(ready_file):
            self._generate_user_databases()
            os.utime(ready_file)

    def destroy_sandbox(self):
        """cleanup sandbox rootfs

        Unmounts the sandbox root and top directory (ignoring umount
        failures), removes the top directory and deletes the nspawn
        settings file. Removal failures are reported but not raised.
        """
        if os.path.isdir(self.root) and os.path.ismount(self.root):
            sb_run(['/bin/umount', self.root], IGNORE)
        if os.path.isdir(self.topdir) and os.path.ismount(self.topdir):
            sb_run(['/bin/umount', self.topdir], IGNORE)
            try:
                os.rmdir(self.topdir)
            except OSError as exc:
                print("Failed to remove {}:{}".format(self.topdir, str(exc)))
        if os.path.exists(self.settings):
            try:
                os.remove(self.settings)
            except OSError as exc:
                print("Failed to remove {}:{}".format(self.settings, str(exc)))

    def start_sandbox(self):
        """Replace this process with systemd-nspawn running the sandbox.

        Does not return on success; OSError is raised if the exec fails.
        """
        nspawn = '/usr/bin/systemd-nspawn'

        # Typically we are invoked as a systemd service of type notify and
        # systemd-nspawn is responsible for calling sd_notify().
        # By default systemd expects to receive the notify message from the
        # main process of the service.
        # Therefore use an exec() function, ie. without forking, so that
        # systemd-nspawn maintains our PID.

        os.execl(nspawn, nspawn, '--quiet', '--keep-unit', '--private-users=no',
                 '--private-network', '--settings=override',
                 '--directory={}'.format(self.root),
                 '--machine={}'.format(self.machine))

def main():
    """Parse the command line and carry out the requested sandbox action.

    Exactly one of create/update/destroy/start must be selected. The
    script exits with status 1 if the action fails with an OSError and,
    for every action except start, with status 0 once it has completed.
    """
    parser = ArgumentParser(description='Vyatta CLI Sanbox Service Rootfs Setup tool')
    action_grp = parser.add_mutually_exclusive_group(required=True)
    option_table = (
        ('-c', '--create', 'create temporary rootfs for sandbox'),
        ('-u', '--update', 'update an existing sandbox'),
        ('-d', '--destroy', 'destroy temporary rootfs for sandbox'),
        ('-s', '--start', 'start an existing sandbox'),
    )
    for short_opt, long_opt, help_text in option_table:
        action_grp.add_argument(short_opt, long_opt,
                                help=help_text, action='store_true')
    parser.add_argument('name', nargs=1, help='sandbox name')

    args = parser.parse_args()

    sandbox = UserSandBox(args.name[0])

    # (selected, handler, failure message, exit with 0 when handler returns)
    dispatch = (
        (args.create, sandbox.create_sandbox,
         'Failed to create sandbox {}:{}', True),
        (args.update, sandbox.update_sandbox,
         'Failed to update sandbox {}:{}', True),
        (args.destroy, sandbox.destroy_sandbox,
         'Failed to cleanup sandbox {}:{}', True),
        # start replaces the process, so normally never returns
        (args.start, sandbox.start_sandbox,
         'Failed to start sandbox {}: {}', False),
    )
    for selected, handler, failure_fmt, exit_when_done in dispatch:
        if not selected:
            continue
        try:
            handler()
        except OSError as err:
            print(failure_fmt.format(sandbox.name, err))
            sys.exit(1)
        if exit_when_done:
            sys.exit(0)

    # unreached #

if __name__ == '__main__':
    main()
