#!/usr/bin/python3
#
# The Qubes OS Project, http://www.qubes-os.org
#
# Copyright (C) 2022  Marek Marczykowski-Górecki
#                               <marmarek@invisiblethingslab.com>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#

from __future__ import annotations
import dbus
import qubesdb
from typing import List
from ipaddress import IPv4Address
import os

def get_dns_resolv_conf():
    """Return the IPv4 nameservers listed in /etc/resolv.conf.

    Only lines whose first whitespace-separated token is ``nameserver`` are
    considered, and only those whose second token parses as an IPv4 address
    are kept (anything else, e.g. an IPv6 address, is skipped).

    Returns:
        A list of ``IPv4Address`` objects in file order; empty when the file
        cannot be opened or holds no usable entry.
    """
    found = []
    try:
        conf = open("/etc/resolv.conf", "r", encoding="UTF-8")
    except IOError:
        # No readable resolv.conf: report "no nameservers" instead of failing.
        return found
    with conf:
        for entry in conf:
            fields = entry.split(None, 2)
            if len(fields) >= 2 and fields[0] == "nameserver":
                try:
                    address = IPv4Address(fields[1])
                except ValueError:
                    # Not a valid IPv4 address; ignore this entry.
                    continue
                found.append(address)
    return found

def get_dns_resolved():
    """Return the IPv4 DNS servers reported by systemd-resolved.

    The ``DNS`` property of ``org.freedesktop.resolve1.Manager`` is read over
    the system D-Bus.  When the bus does not reply, or the resolve1 service
    is not available, the servers from ``get_dns_resolv_conf()`` are returned
    instead.

    Returns:
        A list of ``IPv4Address`` objects, with entries whose first field is
        0 (the global ones) ordered before the others.

    Raises:
        dbus.exceptions.DBusException: for any D-Bus error that is not one of
            the "service unavailable" conditions handled here.
    """
    try:
        bus = dbus.SystemBus()
    except dbus.exceptions.DBusException as err:
        if err.get_dbus_name() == 'org.freedesktop.DBus.Error.NoReply':
            return get_dns_resolv_conf()
        raise
    try:
        resolve1 = bus.get_object('org.freedesktop.resolve1',
                                  '/org/freedesktop/resolve1')
        dns = resolve1.Get('org.freedesktop.resolve1.Manager',
                           'DNS',
                           dbus_interface='org.freedesktop.DBus.Properties')
    except dbus.exceptions.DBusException as err:
        error = err.get_dbus_name()
        # get_dbus_name() may return None for exceptions that carry no D-Bus
        # error name; guard against that so the original exception is
        # re-raised rather than masked by an AttributeError.
        if error is not None and (
            error in (
                'org.freedesktop.DBus.Error.ServiceUnknown',
                'org.freedesktop.DBus.Error.NameHasNoOwner',
                'org.freedesktop.DBus.Error.NoSuchUnit',
            ) or error.startswith('org.freedesktop.systemd1.')
        ):
            return get_dns_resolv_conf()
        raise
    # Use global entries first (the sort is stable, so relative order within
    # each group is preserved).
    dns.sort(key=lambda x: x[0] != 0)
    # Only keep IPv4 entries. systemd-resolved is trusted to return valid
    # addresses.
    return [IPv4Address(bytes(addr)) for _g, family, addr in dns if family == 2]

def install_firewall_rules(dns):
    """Replace the ``dnat-dns`` nftables chain with DNAT rules for DNS.

    The VM-facing nameserver addresses are read from the QubesDB keys
    ``/qubes-netvm-primary-dns`` and ``/qubes-netvm-secondary-dns``; missing
    or malformed values are skipped.  For each of them, UDP and TCP port 53
    traffic is redirected to one of the upstream servers in *dns*.

    Args:
        dns: iterable of upstream DNS server addresses (anything whose
            ``str()`` is an address nft accepts, e.g. ``IPv4Address``).

    This function does not return on success: the process is replaced by
    ``nft`` via ``os.execvp``.
    """
    qdb = qubesdb.QubesDB()
    qubesdb_dns = []
    for key in ('/qubes-netvm-primary-dns', '/qubes-netvm-secondary-dns'):
        ns_maybe = qdb.read(key)
        if ns_maybe is None:
            continue
        try:
            qubesdb_dns.append(IPv4Address(ns_maybe.decode("ascii", "strict")))
        except (UnicodeDecodeError, ValueError):
            # Ignore values that are not a plain ASCII IPv4 address.
            pass
    res = [
        'add table ip qubes',
        # Add the chain so that the subsequent delete will work. If the chain already
        # exists this is a harmless no-op.
        'add chain ip qubes dnat-dns',
        # Delete the chain so that if the chain already exists, it will be removed.
        # The removal of the old chain and addition of the new one happen as a single
        # atomic operation, so there is no period where neither chain is present or
        # where both are present.
        'delete chain ip qubes dnat-dns',
        'table ip qubes {',
        'chain dnat-dns {',
        'type nat hook prerouting priority dstnat; policy accept;',
    ]
    # Use the servers supplied by the caller instead of querying
    # systemd-resolved a second time.
    upstream = list(dns)
    if upstream:
        for index, vm_nameserver in enumerate(qubesdb_dns):
            # Reuse upstream servers round-robin so that every VM-facing
            # address gets a rule even when fewer upstream servers exist.
            dns_ = str(upstream[index % len(upstream)])
            res += [
                f"ip daddr {vm_nameserver} udp dport 53 dnat to {dns_}",
                f"ip daddr {vm_nameserver} tcp dport 53 dnat to {dns_}",
            ]
    res += ["}\n}\n"]
    # "--" ends option parsing so the ruleset is taken as a single argument.
    os.execvp("nft", ("nft", "--", "\n".join(res)))

if __name__ == '__main__':
    # Script entry point: look up the upstream DNS servers, then install
    # the DNAT rules for them.
    upstream_dns = get_dns_resolved()
    install_firewall_rules(upstream_dns)
