#!/usr/bin/python3
#
# Use avahi to find an _apt_proxy._tcp provider and return
# an HTTP proxy string suitable for apt.

import asyncio
import functools
import socket
import sys
import time

from subprocess import PIPE, Popen
from syslog import syslog, LOG_INFO, LOG_USER

DEFAULT_CONNECT_TIMEOUT_SEC = 2


def DEBUG(msg):
    """Write *msg* followed by a newline to stderr when debugging is on.

    Debugging is considered on when the literal argument ``--debug``
    appears anywhere in ``sys.argv``; otherwise the call does nothing.
    """
    debug_enabled = "--debug" in sys.argv
    if not debug_enabled:
        return
    sys.stderr.write(msg + "\n")


def get_avahi_discover_timeout():
    """Return the avahi discovery timeout (in seconds) configured for apt.

    Runs ``/usr/bin/apt-config shell TIMEOUT APT::Avahi-Discover::Timeout``
    and parses its ``TIMEOUT='<n>'`` output.

    Returns:
        The configured value as an ``int``, or
        ``DEFAULT_CONNECT_TIMEOUT_SEC`` when apt-config prints nothing.

    Raises:
        ValueError: if the output does not start with ``TIMEOUT=`` or the
            value is not an integer.
    """
    config_key = "APT::Avahi-Discover::Timeout"
    command = ["/usr/bin/apt-config", "shell", "TIMEOUT", config_key]
    proc = Popen(command, stdout=PIPE, universal_newlines=True)
    # Only stdout is piped, so the second element of the result is unused.
    output = proc.communicate()[0]

    if output:
        if not output.startswith("TIMEOUT="):
            raise ValueError(f"got unexpected apt-config output: '{output}'")
        # Keep everything after the first "=" and drop the shell quoting.
        raw_value = output.strip().partition("=")[2]
        timeout = int(raw_value.strip("'"))
        DEBUG(f"using timeout: '{timeout}'")
        return timeout

    DEBUG(f"no timeout set, using default '{DEFAULT_CONNECT_TIMEOUT_SEC}'")
    return DEFAULT_CONNECT_TIMEOUT_SEC


@functools.total_ordering
class AptAvahiClient:
    """Candidate apt proxy endpoint, ordered by measured TCP connect time.

    Instances compare (and therefore sort) by ``time_to_connect``, so the
    quickest responder sorts first. An endpoint that has not been probed,
    or whose probe failed, keeps ``sys.maxsize`` and sorts last.
    """

    # Message categories that log_info() drops instead of sending to syslog.
    # Without this attribute log_info() would raise AttributeError.
    # NOTE(review): "warning" mirrors asyncore.dispatcher's default set;
    # confirm this is the desired default.
    ignore_log_types = frozenset({"warning"})

    def __init__(self, addr):
        """Remember *addr*, an indexable ``(host, port)`` pair."""
        self.time_to_connect = sys.maxsize
        self.address = addr

    async def time_connect(self):
        """Time a TCP connect to ``self.address`` and store the duration.

        On success ``time_to_connect`` is set to the elapsed seconds and the
        connection is closed. On any failure the previous value is kept, so
        the endpoint keeps sorting after every reachable one.
        """
        try:
            # Monotonic clock: a wall-clock adjustment during the probe must
            # not distort (or make negative) the measured duration.
            started = time.monotonic()
            _reader, writer = await asyncio.open_connection(
                self.address[0], self.address[1]
            )
            self.time_to_connect = time.monotonic() - started
            writer.close()
        except Exception:
            # Deliberately best effort: an unreachable or refusing host is
            # an expected outcome and must not abort the other probes.
            pass

    def __eq__(self, other):
        if not isinstance(other, AptAvahiClient):
            return NotImplemented
        return self.time_to_connect == other.time_to_connect

    def __lt__(self, other):
        if not isinstance(other, AptAvahiClient):
            return NotImplemented
        return self.time_to_connect < other.time_to_connect

    def __repr__(self):
        return "<{}> {}: {}".format(
            self.__class__.__name__, self.address, self.time_to_connect
        )

    def log(self, message):
        """Send *message* to syslog at ``LOG_INFO | LOG_USER``."""
        syslog((LOG_INFO | LOG_USER), "%s\n" % str(message))

    def log_info(self, message, type="info"):
        """Log ``"<type>: <message>"`` unless *type* is in ``ignore_log_types``."""
        if type not in self.ignore_log_types:
            self.log("{}: {}".format(type, message))


def is_ipv6(a):
    """Return True when *a* contains a colon, i.e. looks like an IPv6 address."""
    has_colon = ":" in a
    return has_colon


def is_linklocal(addr):
    """Return True when *addr* begins with the IPv6 link-local prefix."""
    # A link-local address is fe80 followed by six null bytes, which the
    # compressed textual form writes as "fe80::".
    link_local_prefix = "fe80::"
    return addr.startswith(link_local_prefix)


def get_proxy_host_port_from_avahi():
    """Find the fastest reachable ``_apt_proxy._tcp`` service announced via avahi.

    Runs ``avahi-browse`` to list resolved services, drops IPv6 link-local
    addresses, keeps only the addresses ``getaddrinfo`` accepts with
    ``AI_ADDRCONFIG``, then times a TCP connect to every remaining candidate
    concurrently, bounded by the timeout from get_avahi_discover_timeout().

    Returns:
        The ``(address, port)`` tuple of the candidate with the shortest
        connect time, or ``None`` when no usable service was announced or
        no candidate accepted a connection within the timeout.
    """
    service = "_apt_proxy._tcp"

    # Obtain all of the services addresses from avahi, pulling the IPv6
    # addresses to the top.
    addr4 = []
    addr6 = []
    # The context manager closes the pipe and reaps the child process.
    with Popen(
        ["avahi-browse", "-kprtf", service], stdout=PIPE, universal_newlines=True
    ) as p:
        DEBUG("avahi-browse output:")
        for line in p.stdout:
            DEBUG(f" '{line}'")
            # Only "=" records are parsed; the address and port are read
            # from fields 7 and 8 of the ";"-separated line.
            if not line.startswith("="):
                continue
            tokens = line.split(";")
            try:
                addr = tokens[7]
                port = int(tokens[8])
            except (IndexError, ValueError):
                # A truncated or garbled record must not abort discovery.
                DEBUG(f"skipping malformed avahi-browse record: '{line}'")
                continue
            if is_ipv6(addr):
                # We need to skip ipv6 link-local addresses since
                # APT can't use them
                if not is_linklocal(addr):
                    addr6.append((addr, port))
            else:
                addr4.append((addr, port))

    # Run through the offered addresses and see if we have a bound local
    # address for it.
    addrs = []
    for ip, port in addr6 + addr4:
        try:
            res = socket.getaddrinfo(ip, port, 0, 0, 0, socket.AI_ADDRCONFIG)
            if res:
                addrs.append((ip, port))
        except socket.gaierror:
            pass
    if not addrs:
        return None

    # sort by answering speed
    hosts = [AptAvahiClient(addr) for addr in addrs]
    # Overall budget for all probes together (configurable through apt).
    timeout = get_avahi_discover_timeout()

    async def gather_hosts():
        await asyncio.gather(*[h.time_connect() for h in hosts])

    try:
        asyncio.run(asyncio.wait_for(gather_hosts(), timeout=timeout))
    except (TimeoutError, asyncio.TimeoutError):
        # Before Python 3.11 asyncio.TimeoutError is not the builtin
        # TimeoutError, so both are caught. Hosts that answered before the
        # deadline keep their measured time; the rest stay at sys.maxsize.
        pass

    ranked = sorted(hosts)
    DEBUG(f"sorted hosts: '{ranked}'")

    # No host wanted to connect: even the best entry is still at the
    # "never connected" sentinel.
    fastest_host = ranked[0]
    if fastest_host.time_to_connect == sys.maxsize:
        return None

    return fastest_host.address


if __name__ == "__main__":
    # Print the discovered proxy as a URL (no trailing newline); print
    # nothing at all when no proxy was found.
    address = get_proxy_host_port_from_avahi()
    if address:
        ip, port = address
        # The IPv6 form wraps the address in square brackets.
        proxyurl = f"http://[{ip}]:{port}/" if is_ipv6(ip) else f"http://{ip}:{port}/"
        print(proxyurl, end="")
