#!/usr/bin/env python3
import logging
from argparse import ArgumentParser
from asyncio import create_subprocess_exec, gather, run
from asyncio.subprocess import DEVNULL, PIPE
from pathlib import Path
from shlex import quote
from warnings import warn

log = logging.getLogger(__name__)


class DBParser:
    """Line-oriented parser for the ``desc`` entries of a repository database.

    Use it as a context manager: feed lines with :meth:`parse_line` and call
    :meth:`result` once at the end to get a mapping of the form
    ``{base: {"names": {name, ...}, "version": version}}``.

    ``_state`` tracks the lifecycle: ``None`` (fresh) -> ``"parsing"``
    (entered) -> ``"done"`` (result taken) -> ``"finalized"`` (exited).
    While parsing it is briefly ``"name"``, ``"base"`` or ``"version"`` when
    the next line carries the value of that field.
    """

    def __init__(self):
        self._state = self._name = self._base = self._version = None
        self._packages = {}

    def _assert_state(self, *wanted):
        """Raise ``AssertionError`` unless the current state is in *wanted*.

        The exception is raised explicitly instead of through an ``assert``
        statement so the check is not stripped when Python runs with ``-O``.
        """
        if self._state not in wanted:
            raise AssertionError(
                "state is {!r}, not in {!r}".format(self._state, wanted)
            )

    def _add_package(self):
        """Store the package collected so far, if any, and reset the fields.

        Packages are grouped by base; a package without a base uses its own
        name as the base. A warning is issued when two packages of the same
        base carry different versions (the first version seen is kept).
        """
        name, base, version = self._name, self._base, self._version

        if name is None:
            # Nothing collected yet (e.g. the very first %NAME% header).
            return
        if base is None:
            base = name

        log.debug("Adding package name=%r base=%r version=%r", name, base, version)
        pkg = self._packages.setdefault(base, {})
        pkg.setdefault("names", set()).add(name)
        if pkg.setdefault("version", version) != version:
            warn(
                "Conflicting versions for pkgbase {!r}: {!r} {!r}".format(
                    base, pkg["version"], version
                )
            )

        self._name = self._base = self._version = None

    def parse_line(self, line):
        """Consume one line of input.

        In the ``"parsing"`` state only the ``%NAME%``, ``%BASE%`` and
        ``%VERSION%`` headers are significant; every other line is ignored.
        The line following a header is stored as that field's value.

        Raises:
            AssertionError: if called outside the ``with`` block or after
                :meth:`result`.
        """
        line = line.strip()
        state = self._state

        log.debug("Parsing line %r, state %r", line, state)

        if state == "parsing":
            if line == "%NAME%":
                self._state = "name"
                # A new %NAME% header starts the next package, so flush the
                # previous one first.
                self._add_package()
            elif line == "%BASE%":
                self._state = "base"
            elif line == "%VERSION%":
                self._state = "version"
        elif state == "name":
            self._name = line
            self._state = "parsing"
        elif state == "base":
            self._base = line
            self._state = "parsing"
        elif state == "version":
            self._version = line
            self._state = "parsing"
        else:
            # Not a state in which lines may be fed: always raises, and the
            # message lists the states that would have been acceptable.
            self._assert_state("parsing", "name", "base", "version")

    def result(self):
        """Finish parsing and return the collected packages.

        Returns:
            dict mapping each base to ``{"names": set, "version": value}``.

        Raises:
            AssertionError: if the parser is not in the ``"parsing"`` state
                (not entered, in the middle of a field, or already finished).
        """
        self._assert_state("parsing")
        self._state = "done"
        # The last package has no following %NAME% header to flush it.
        self._add_package()
        return self._packages

    def __enter__(self):
        """Start parsing; a parser instance can only be entered once."""
        self._assert_state(None)
        self._state = "parsing"
        return self

    def __exit__(self, exc_type, exc, exc_tb):
        """Finalize the parser; warn if :meth:`result` was never called.

        When the block raised, nothing is done and the exception propagates.
        """
        if exc is not None:
            return
        if self._state != "done":
            warn("parsing result discarded")
        self._state = "finalized"
        self._add_package()


async def open_db(dbfile):
    """Start ``bsdtar`` streaming the ``*/desc`` members of *dbfile* to a pipe.

    The child runs with an empty environment, its stderr is discarded and its
    stdout is a pipe. The shell-quoted command line is attached to the
    returned process object as ``cmdline`` for use in error messages.
    """
    argv = ["/usr/bin/bsdtar", "-xOf", str(dbfile), "*/desc"]
    printable = " ".join(map(quote, argv))
    log.debug("Running %s", printable)
    child = await create_subprocess_exec(
        *argv,
        stdout=PIPE,
        stderr=DEVNULL,
        env={},
    )
    child.cmdline = printable  # type: ignore
    return child


async def read_db(dbfile):
    """Read the package database *dbfile* and return its parsed contents.

    The output of the process started by ``open_db`` is fed line by line into
    a ``DBParser``.

    Returns:
        The mapping produced by ``DBParser.result()``.

    Raises:
        RuntimeError: if the extraction process exits with a non-zero status.
        Any exception raised while reading or parsing is re-raised after the
        child process has been terminated and reaped.
    """
    process = await open_db(dbfile)

    try:
        with DBParser() as parser:
            while True:
                line = await process.stdout.readline()  # type: ignore
                if not line:
                    break
                parser.parse_line(line.decode())
            parsed = parser.result()
    except BaseException:
        # Reading stopped early, so the child may still be running and
        # blocked on a full pipe. Kill it and drain the pipe so it is reaped
        # instead of being leaked, then let the original error propagate.
        if process.returncode is None:
            try:
                process.kill()
            except ProcessLookupError:
                # The child exited on its own in the meantime.
                pass
        await process.communicate()
        raise

    # stdout is already at EOF here; this just waits for the exit status.
    await process.communicate()
    if process.returncode != 0:
        raise RuntimeError("{0.cmdline!r} returned {0.returncode}".format(process))

    return parsed


async def read_repo(name, archs):
    """Load repository *name* for every arch in *archs* and merge the results.

    Databases are read concurrently from
    ``/srv/ftp/<name>/os/<arch>/<name>.db``. A missing or unreadable database
    counts as empty (read failures are logged as warnings).

    Returns:
        ``{base: {frozenset_of_archs: package}}`` where arches whose package
        entries compare equal share a single key.
    """

    async def try_read(arch):
        # Best effort: any failure leaves this arch with an empty database.
        db_path = Path("/srv/ftp", name, "os", arch, name + ".db")
        contents = {}
        if db_path.exists():
            try:
                contents = await read_db(db_path)
                log.debug(
                    "Loaded repo name=%r arch=%r pkgs=%d", name, arch, len(contents)
                )
            except Exception as err:
                log.warning(
                    "Failed to read repo name=%r arch=%r: %s", name, arch, err
                )
        return frozenset((arch,)), contents

    merged = {}
    per_arch = await gather(*map(try_read, archs))
    for arch_set, db in per_arch:
        for base, pkg in db.items():
            variants = merged.setdefault(base, {})
            # Look for an arch group that already holds an identical package.
            same = next(
                (known for known, other in variants.items() if other == pkg),
                None,
            )
            if same is None:
                variants[arch_set] = pkg
            else:
                # Re-key the existing entry under the enlarged arch set.
                variants[same | arch_set] = variants.pop(same)
    return merged


async def read_repos(names, archs):
    """Read all repositories in *names* concurrently.

    Returns:
        dict mapping each repository name to the result of ``read_repo`` for
        that name and *archs*.
    """
    loaded = await gather(*[read_repo(repo, archs) for repo in names])
    return {repo: contents for repo, contents in zip(names, loaded)}


def bases(repos):
    """Return a lazy iterator over the keys of every repository in *repos*.

    *repos* maps repository names to per-repository mappings; the keys of
    those inner mappings are yielded repository by repository, so a key
    present in several repositories is yielded once per repository.
    """

    def _each_base(per_repo):
        for pkgs in per_repo:
            yield from pkgs.keys()

    return _each_base(iter(repos.values()))


def packages(repos):
    """Return a lazy iterator over every package entry in *repos*.

    *repos* is nested three levels deep (repository -> base -> arch group);
    the innermost values are yielded in iteration order.
    """

    def _each_package(per_repo):
        for by_base in per_repo:
            for by_arch in by_base.values():
                yield from by_arch.values()

    return _each_package(iter(repos.values()))


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read them from ``sys.argv``.

    Returns:
        Namespace with ``mode`` (``"bases"``, ``"names"`` or ``None``),
        ``archs`` (list of strings, or ``None`` when ``-a`` was not given),
        ``debug`` (bool) and ``repos`` (non-empty list of strings).
    """
    parser = ArgumentParser(description="List packages in FTP repositories")
    # -b and -n share one destination and cannot be combined.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument(
        "-b",
        "--bases",
        dest="mode",
        action="store_const",
        const="bases",
        help="Only list pkgbases",
    )
    mode.add_argument(
        "-n",
        "--names",
        dest="mode",
        action="store_const",
        const="names",
        help="Only list pkgnames",
    )
    parser.add_argument(
        "-a",
        "--arch",
        metavar="ARCH",
        dest="archs",
        action="append",
        help="Arch to read",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug output")
    parser.add_argument("repos", metavar="REPO", nargs="+", help="Repository to read")
    return parser.parse_args(argv)


# Script body: parse the command line, load the repositories, print a listing.
logging.basicConfig()

args = parse_args()
if args.debug:
    logging.root.setLevel(logging.DEBUG)
# Fall back to x86_64 when no -a/--arch option was given.
archs = frozenset(args.archs or ["x86_64"])

repos = run(read_repos(args.repos, archs))

if args.mode == "bases":
    # Sorted so the output order is stable; set iteration order is not.
    for base in sorted(set(bases(repos))):
        print(base)
elif args.mode == "names":
    for name in sorted({n for p in packages(repos) for n in p["names"]}):
        print(name)
else:
    # Column widths are computed over all repositories so the tables align.
    longestbase = longestarch = longestver = 0
    if any(repos.values()):
        longestbase = max(len(b) for b in bases(repos))
        longestarch = len(" ".join(archs))
        longestver = max(len(p["version"]) for p in packages(repos))
    for repo in args.repos:
        pkgs = repos.get(repo, {})
        if not pkgs:
            print("\033[1mRepository \033[31m{}\033[39m is empty\033[0m".format(repo))
            continue

        print("\033[1mPackages in \033[32m{}\033[39m:\033[0m".format(repo))
        for base in sorted(pkgs):
            # Order the arch groups by their sorted arch names: the stored
            # order follows frozenset iteration and can differ between runs.
            variants = sorted(pkgs[base].items(), key=lambda item: sorted(item[0]))
            for arch, pkg in variants:
                print(
                    "  \033[32m{:>{}s}\033[39m"
                    "  \033[{}m{:^{}s}\033[39m"
                    "  \033[36m{:<{}s}\033[39m"
                    "  {}".format(
                        base,
                        longestbase,
                        # Blue when present for every requested arch, else red.
                        34 if arch == archs else 31,
                        " ".join(sorted(arch)),
                        longestarch,
                        pkg["version"],
                        longestver,
                        " ".join(sorted(pkg["names"])),
                    )
                )