Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • wackbyte/aurweb
  • aliu/aurweb
  • morganamilo/aurweb
  • tex/aurweb
  • abitrolly/aurweb
  • muflone/aurweb
  • anthraxx/aurweb
  • jafari/aurweb
  • levitating/aurweb
  • freso/aurweb
  • okabe/aurweb
  • rafaelff/aurweb
  • zoorat/aurweb
  • auerhuhn/aurweb
  • nils/aurweb
  • antiz/aurweb
  • henry-zhr/aurweb
  • segaja/aurweb
  • som015/aurweb
  • gromit/aurweb
  • belongingtome47/aurweb
  • moson/aurweb
  • steppaa23/aurweb
  • bittin/aurweb
  • jkhsjdhjs/aurweb
  • whynothugo/aurweb
  • matt/aurweb
  • fosskers/aurweb
  • awalgarg/aurweb
  • robertoszek/aurweb
  • ainola/aurweb
  • fluix/aurweb
  • hwittenborn/aurweb
  • jocke-l/aurweb
  • eschwartz/aurweb
  • mackilanu/aurweb
  • artafinde/aurweb
  • klausenbusk/aurweb
  • felixonmars/aurweb
  • kevr/aurweb
  • hashworks/aurweb
  • freswa/aurweb
  • lahwaacz/aurweb
  • jelle/aurweb
  • ffy00/aurweb
  • archlinux/aurweb
46 results
Show changes
Showing
with 1810 additions and 95 deletions
# aurweb.archives.spec
from pathlib import Path
from typing import Any, Dict, Iterable, List, Set
class GitInfo:
"""Information about a Git repository."""
""" Path to Git repository. """
path: str
""" Local Git repository configuration. """
config: Dict[str, Any]
def __init__(self, path: str, config: Dict[str, Any] = dict()) -> "GitInfo":
self.path = Path(path)
self.config = config
class SpecOutput:
"""Class used for git_archive.py output details."""
""" Filename relative to the Git repository root. """
filename: Path
""" Git repository information. """
git_info: GitInfo
""" Bytes bound for `SpecOutput.filename`. """
data: bytes
def __init__(self, filename: str, git_info: GitInfo, data: bytes) -> "SpecOutput":
self.filename = filename
self.git_info = git_info
self.data = data
class SpecBase:
"""
Base for Spec classes defined in git_archve.py --spec modules.
All supported --spec modules must contain the following classes:
- Spec(SpecBase)
"""
""" A list of SpecOutputs, each of which contain output file data. """
outputs: List[SpecOutput] = list()
""" A set of repositories to commit changes to. """
repos: Set[str] = set()
def generate(self) -> Iterable[SpecOutput]:
"""
"Pure virtual" output generator.
`SpecBase.outputs` and `SpecBase.repos` should be populated within an
overridden version of this function in SpecBase derivatives.
"""
raise NotImplementedError()
def add_output(self, filename: str, git_info: GitInfo, data: bytes) -> None:
"""
Add a SpecOutput instance to the set of outputs.
:param filename: Filename relative to the git repository root
:param git_info: GitInfo instance
:param data: Binary data bound for `filename`
"""
if git_info.path not in self.repos:
self.repos.add(git_info.path)
self.outputs.append(
SpecOutput(
filename,
git_info,
data,
)
)
from typing import Iterable
import orjson
from aurweb import config, db
from aurweb.models import Package, PackageBase, User
from aurweb.rpc import RPC
from .base import GitInfo, SpecBase, SpecOutput
ORJSON_OPTS = orjson.OPT_SORT_KEYS | orjson.OPT_INDENT_2
class Spec(SpecBase):
def __init__(self) -> "Spec":
self.metadata_repo = GitInfo(
config.get("git-archive", "metadata-repo"),
)
def generate(self) -> Iterable[SpecOutput]:
# Base query used by the RPC.
base_query = (
db.query(Package)
.join(PackageBase)
.join(User, PackageBase.MaintainerUID == User.ID, isouter=True)
)
# Create an instance of RPC, use it to get entities from
# our query and perform a metadata subquery for all packages.
rpc = RPC(version=5, type="info")
print("performing package database query")
packages = rpc.entities(base_query).all()
print("performing package database subqueries")
rpc.subquery({pkg.ID for pkg in packages})
pkgbases, pkgnames = dict(), dict()
for package in packages:
# Produce RPC type=info data for `package`
data = rpc.get_info_json_data(package)
pkgbase_name = data.get("PackageBase")
pkgbase_data = {
"ID": data.pop("PackageBaseID"),
"URLPath": data.pop("URLPath"),
"FirstSubmitted": data.pop("FirstSubmitted"),
"LastModified": data.pop("LastModified"),
"OutOfDate": data.pop("OutOfDate"),
"Maintainer": data.pop("Maintainer"),
"Keywords": data.pop("Keywords"),
"NumVotes": data.pop("NumVotes"),
"Popularity": data.pop("Popularity"),
"PopularityUpdated": package.PopularityUpdated.timestamp(),
}
# Store the data in `pkgbases` dict. We do this so we only
# end up processing a single `pkgbase` if repeated after
# this loop
pkgbases[pkgbase_name] = pkgbase_data
# Remove Popularity and NumVotes from package data.
# These fields change quite often which causes git data
# modification to explode.
# data.pop("NumVotes")
# data.pop("Popularity")
# Remove the ID key from package json.
data.pop("ID")
# Add the `package`.Name to the pkgnames set
name = data.get("Name")
pkgnames[name] = data
# Add metadata outputs
self.add_output(
"pkgname.json",
self.metadata_repo,
orjson.dumps(pkgnames, option=ORJSON_OPTS),
)
self.add_output(
"pkgbase.json",
self.metadata_repo,
orjson.dumps(pkgbases, option=ORJSON_OPTS),
)
return self.outputs
from typing import Iterable
import orjson
from aurweb import config, db
from aurweb.models import PackageBase
from .base import GitInfo, SpecBase, SpecOutput
ORJSON_OPTS = orjson.OPT_SORT_KEYS | orjson.OPT_INDENT_2
class Spec(SpecBase):
def __init__(self) -> "Spec":
self.pkgbases_repo = GitInfo(config.get("git-archive", "pkgbases-repo"))
def generate(self) -> Iterable[SpecOutput]:
query = db.query(PackageBase.Name).order_by(PackageBase.Name.asc()).all()
pkgbases = [pkgbase.Name for pkgbase in query]
self.add_output(
"pkgbase.json",
self.pkgbases_repo,
orjson.dumps(pkgbases, option=ORJSON_OPTS),
)
return self.outputs
from typing import Iterable
import orjson
from aurweb import config, db
from aurweb.models import Package, PackageBase
from .base import GitInfo, SpecBase, SpecOutput
ORJSON_OPTS = orjson.OPT_SORT_KEYS | orjson.OPT_INDENT_2
class Spec(SpecBase):
def __init__(self) -> "Spec":
self.pkgnames_repo = GitInfo(config.get("git-archive", "pkgnames-repo"))
def generate(self) -> Iterable[SpecOutput]:
query = (
db.query(Package.Name)
.join(PackageBase, PackageBase.ID == Package.PackageBaseID)
.order_by(Package.Name.asc())
.all()
)
pkgnames = [pkg.Name for pkg in query]
self.add_output(
"pkgname.json",
self.pkgnames_repo,
orjson.dumps(pkgnames, option=ORJSON_OPTS),
)
return self.outputs
from typing import Iterable
import orjson
from aurweb import config, db
from aurweb.models import User
from .base import GitInfo, SpecBase, SpecOutput
ORJSON_OPTS = orjson.OPT_SORT_KEYS | orjson.OPT_INDENT_2
class Spec(SpecBase):
def __init__(self) -> "Spec":
self.users_repo = GitInfo(config.get("git-archive", "users-repo"))
def generate(self) -> Iterable[SpecOutput]:
query = db.query(User.Username).order_by(User.Username.asc()).all()
users = [user.Username for user in query]
self.add_output(
"users.json",
self.users_repo,
orjson.dumps(users, option=ORJSON_OPTS),
)
return self.outputs
import hashlib
import http
import io
import os
import re
import sys
import traceback
import typing
from contextlib import asynccontextmanager
from urllib.parse import quote_plus
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
import requests
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
from jinja2 import TemplateNotFound
from sqlalchemy import and_
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.sessions import SessionMiddleware
import aurweb.captcha # noqa: F401
import aurweb.config
import aurweb.filters # noqa: F401
from aurweb import aur_logging, prometheus, util
from aurweb.aur_redis import redis_connection
from aurweb.auth import BasicAuthBackend
from aurweb.db import get_engine, query
from aurweb.models import AcceptedTerm, Term
from aurweb.packages.util import get_pkg_or_base
from aurweb.prometheus import instrumentator
from aurweb.routers import APP_ROUTES
from aurweb.templates import make_context, render_template
from aurweb.routers import sso
logger = aur_logging.get_logger(__name__)
session_secret = aurweb.config.get("fastapi", "session_secret")
app = FastAPI()
session_secret = aurweb.config.get("fastapi", "session_secret")
if not session_secret:
raise Exception("[fastapi] session_secret must not be empty")
@asynccontextmanager
async def lifespan(app: FastAPI):
await app_startup()
yield
app.add_middleware(SessionMiddleware, secret_key=session_secret)
app.include_router(sso.router)
# Setup the FastAPI app.
app = FastAPI(lifespan=lifespan)
# Instrument routes with the prometheus-fastapi-instrumentator
# library with custom collectors and expose /metrics.
instrumentator().add(prometheus.http_api_requests_total())
instrumentator().add(prometheus.http_requests_total())
instrumentator().instrument(app)
if aurweb.config.get("tracing", "otlp_endpoint"):
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
# Instrument FastAPI for tracing
FastAPIInstrumentor.instrument_app(app)
resource = Resource(attributes={"service.name": "aurweb"})
otlp_endpoint = aurweb.config.get("tracing", "otlp_endpoint")
otlp_exporter = OTLPSpanExporter(endpoint=otlp_endpoint)
span_processor = BatchSpanProcessor(otlp_exporter)
trace.set_tracer_provider(TracerProvider(resource=resource))
trace.get_tracer_provider().add_span_processor(span_processor)
async def app_startup():
# https://stackoverflow.com/questions/67054759/about-the-maximum-recursion-error-in-fastapi
# Test failures have been observed by internal starlette code when
# using starlette.testclient.TestClient. Looking around in regards
# to the recursion error has really not recommended a course of action
# other than increasing the recursion limit. For now, that is how
# we handle the issue: an optional TEST_RECURSION_LIMIT env var
# provided by the user. Docker uses .env's TEST_RECURSION_LIMIT
# when running test suites.
# TODO: Find a proper fix to this issue.
recursion_limit = int(
os.environ.get("TEST_RECURSION_LIMIT", sys.getrecursionlimit() + 1000)
)
sys.setrecursionlimit(recursion_limit)
backend = aurweb.config.get("database", "backend")
if backend not in aurweb.db.DRIVERS:
raise ValueError(
f"The configured database backend ({backend}) is unsupported. "
f"Supported backends: {str(aurweb.db.DRIVERS.keys())}"
)
if not session_secret:
raise Exception("[fastapi] session_secret must not be empty")
if not os.environ.get("PROMETHEUS_MULTIPROC_DIR", None):
logger.warning(
"$PROMETHEUS_MULTIPROC_DIR is not set, the /metrics "
"endpoint is disabled."
)
app.mount("/static", StaticFiles(directory="static"), name="static_files")
# Add application routes.
def add_router(module):
app.include_router(module.router)
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
util.apply_all(APP_ROUTES, add_router)
# Initialize the database engine and ORM.
get_engine()
async def internal_server_error(request: Request, exc: Exception) -> Response:
"""
Dirty HTML error page to replace the default JSON error responses.
In the future this should use a proper Arch-themed HTML template.
Catch all uncaught Exceptions thrown in a route.
:param request: FastAPI Request
:return: Rendered 500.html template with status_code 500
"""
repo = aurweb.config.get("notifications", "gitlab-instance")
project = aurweb.config.get("notifications", "error-project")
token = aurweb.config.get("notifications", "error-token")
context = make_context(request, "Internal Server Error")
# Print out the exception via `traceback` and store the value
# into the `traceback` context variable.
tb_io = io.StringIO()
traceback.print_exc(file=tb_io)
tb = tb_io.getvalue()
context["traceback"] = tb
# Produce a SHA1 hash of the traceback string.
tb_hash = hashlib.sha1(tb.encode()).hexdigest()
tb_id = tb_hash[:7]
redis = redis_connection()
key = f"tb:{tb_hash}"
retval = redis.get(key)
if not retval:
# Expire in one hour; this is just done to make sure we
# don't infinitely store these values, but reduce the number
# of automated reports (notification below). At this time of
# writing, unexpected exceptions are not common, thus this
# will not produce a large memory footprint in redis.
pipe = redis.pipeline()
pipe.set(key, tb)
pipe.expire(key, 86400) # One day.
pipe.execute()
# Send out notification about it.
if "set-me" not in (project, token):
proj = quote_plus(project)
endp = f"{repo}/api/v4/projects/{proj}/issues"
base = f"{request.url.scheme}://{request.url.netloc}"
title = f"Traceback [{tb_id}]: {base}{request.url.path}"
desc = [
"DISCLAIMER",
"----------",
"**This issue is confidential** and should be sanitized "
"before sharing with users or developers. Please ensure "
"you've completed the following tasks:",
"- [ ] I have removed any sensitive data and "
"the description history.",
"",
"Exception Details",
"-----------------",
f"- Route: `{request.url.path}`",
f"- User: `{request.user.Username}`",
f"- Email: `{request.user.Email}`",
]
# Add method-specific information to the description.
if request.method.lower() == "get":
# get
if request.url.query:
desc = desc + [f"- Query: `{request.url.query}`"]
desc += ["", f"```{tb}```"]
else:
# post
form_data = str(dict(request.state.form_data))
desc = desc + [f"- Data: `{form_data}`"] + ["", f"```{tb}```"]
headers = {"Authorization": f"Bearer {token}"}
data = {
"title": title,
"description": "\n".join(desc),
"labels": ["triage"],
"confidential": True,
}
logger.info(endp)
resp = requests.post(endp, json=data, headers=headers)
if resp.status_code != http.HTTPStatus.CREATED:
logger.error(f"Unable to report exception to {repo}: {resp.text}")
else:
logger.warning(
"Unable to report an exception found due to "
"unset notifications.error-{{project,token}}"
)
# Log details about the exception traceback.
logger.error(f"FATAL[{tb_id}]: An unexpected exception has occurred.")
logger.error(tb)
else:
retval = retval.decode()
return render_template(
request,
"errors/500.html",
context,
status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR,
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
"""Handle an HTTPException thrown in a route."""
phrase = http.HTTPStatus(exc.status_code).phrase
return HTMLResponse(f"<h1>{exc.status_code} {phrase}</h1><p>{exc.detail}</p>",
status_code=exc.status_code)
context = make_context(request, phrase)
context["exc"] = exc
context["phrase"] = phrase
# Additional context for some exceptions.
if exc.status_code == http.HTTPStatus.NOT_FOUND:
tokens = request.url.path.split("/")
matches = re.match("^([a-z0-9][a-z0-9.+_-]*?)(\\.git)?$", tokens[1])
if matches and len(tokens) == 2:
try:
pkgbase = get_pkg_or_base(matches.group(1))
context["pkgbase"] = pkgbase
context["git_clone_uri_anon"] = aurweb.config.get(
"options", "git_clone_uri_anon"
)
context["git_clone_uri_priv"] = aurweb.config.get(
"options", "git_clone_uri_priv"
)
except HTTPException:
pass
try:
return render_template(
request, f"errors/{exc.status_code}.html", context, exc.status_code
)
except TemplateNotFound:
return render_template(request, "errors/detail.html", context, exc.status_code)
@app.middleware("http")
async def add_security_headers(request: Request, call_next: typing.Callable):
"""This middleware adds the CSP, XCTO, XFO and RP security
headers to the HTTP response associated with request.
CSP: Content-Security-Policy
XCTO: X-Content-Type-Options
RP: Referrer-Policy
XFO: X-Frame-Options
"""
try:
response = await util.error_or_result(call_next, request)
except Exception as exc:
return await internal_server_error(request, exc)
# Add CSP header.
nonce = request.user.nonce
csp = "default-src 'self'; "
# swagger-ui needs access to cdn.jsdelivr.net javascript
script_hosts = ["cdn.jsdelivr.net"]
csp += f"script-src 'self' 'unsafe-inline' 'nonce-{nonce}' " + " ".join(
script_hosts
)
# swagger-ui needs access to cdn.jsdelivr.net css
css_hosts = ["cdn.jsdelivr.net"]
csp += "; style-src 'self' 'unsafe-inline' " + " ".join(css_hosts)
response.headers["Content-Security-Policy"] = csp
# Add XTCO header.
xcto = "nosniff"
response.headers["X-Content-Type-Options"] = xcto
# Add Referrer Policy header.
rp = "same-origin"
response.headers["Referrer-Policy"] = rp
# Add X-Frame-Options header.
xfo = "SAMEORIGIN"
response.headers["X-Frame-Options"] = xfo
return response
@app.middleware("http")
async def check_terms_of_service(request: Request, call_next: typing.Callable):
"""This middleware function redirects authenticated users if they
have any outstanding Terms to agree to."""
if request.user.is_authenticated() and request.url.path != "/tos":
accepted = (
query(Term)
.join(AcceptedTerm)
.filter(
and_(
AcceptedTerm.UsersID == request.user.ID,
AcceptedTerm.TermsID == Term.ID,
AcceptedTerm.Revision >= Term.Revision,
),
)
)
if query(Term).count() - accepted.count() > 0:
return RedirectResponse("/tos", status_code=int(http.HTTPStatus.SEE_OTHER))
return await util.error_or_result(call_next, request)
@app.middleware("http")
async def id_redirect_middleware(request: Request, call_next: typing.Callable):
id = request.query_params.get("id")
if id is not None:
# Preserve query string.
qs = []
for k, v in request.query_params.items():
if k != "id":
qs.append(f"{k}={quote_plus(str(v))}")
qs = str() if not qs else "?" + "&".join(qs)
path = request.url.path.rstrip("/")
return RedirectResponse(f"{path}/{id}{qs}")
return await util.error_or_result(call_next, request)
# Add application middlewares.
app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend())
app.add_middleware(SessionMiddleware, secret_key=session_secret)
import logging
import logging.config
import os
import aurweb.config
# For testing, users should set LOG_CONFIG=logging.test.conf
# We test against various debug log output.
aurwebdir = aurweb.config.get("options", "aurwebdir")
log_config = os.environ.get("LOG_CONFIG", "logging.conf")
config_path = os.path.join(aurwebdir, log_config)
logging.config.fileConfig(config_path, disable_existing_loggers=False)
logging.getLogger("root").addHandler(logging.NullHandler())
def get_logger(name: str) -> logging.Logger:
"""A logging.getLogger wrapper. Importing this function and
using it to get a module-local logger ensures that logging.conf
initialization is performed wherever loggers are used.
:param name: Logger name; typically `__name__`
:returns: name's logging.Logger
"""
return logging.getLogger(name)
import fakeredis
from redis import ConnectionPool, Redis
import aurweb.config
from aurweb import aur_logging
logger = aur_logging.get_logger(__name__)
pool = None
if aurweb.config.get("tracing", "otlp_endpoint"):
from opentelemetry.instrumentation.redis import RedisInstrumentor
RedisInstrumentor().instrument()
class FakeConnectionPool:
"""A fake ConnectionPool class which holds an internal reference
to a fakeredis handle.
We normally deal with Redis by keeping its ConnectionPool globally
referenced so we can persist connection state through different calls
to redis_connection(), and since FakeRedis does not offer a ConnectionPool,
we craft one up here to hang onto the same handle instance as long as the
same instance is alive; this allows us to use a similar flow from the
redis_connection() user's perspective.
"""
def __init__(self):
self.handle = fakeredis.FakeStrictRedis()
def disconnect(self):
pass
def redis_connection(): # pragma: no cover
global pool
disabled = aurweb.config.get("options", "cache") != "redis"
# If we haven't initialized redis yet, construct a pool.
if disabled:
if pool is None:
logger.debug("Initializing fake Redis instance.")
pool = FakeConnectionPool()
return pool.handle
else:
if pool is None:
logger.debug("Initializing real Redis instance.")
redis_addr = aurweb.config.get("options", "redis_address")
pool = ConnectionPool.from_url(redis_addr)
# Create a connection to the pool.
return Redis(connection_pool=pool)
def kill_redis():
global pool
if pool:
pool.disconnect()
pool = None
import functools
from http import HTTPStatus
from typing import Callable
import fastapi
from fastapi import HTTPException
from fastapi.responses import RedirectResponse
from starlette.authentication import AuthCredentials, AuthenticationBackend
from starlette.requests import HTTPConnection
import aurweb.config
from aurweb import db, filters, l10n, time, util
from aurweb.models import Session, User
from aurweb.models.account_type import ACCOUNT_TYPE_ID
class StubQuery:
"""Acts as a stubbed version of an orm.Query. Typically used
to masquerade fake records for an AnonymousUser."""
def filter(self, *args):
return StubQuery()
def scalar(self):
return 0
class AnonymousUser:
"""A stubbed User class used when an unauthenticated User
makes a request against FastAPI."""
# Stub attributes used to mimic a real user.
ID = 0
Username = "N/A"
Email = "N/A"
class AccountType:
"""A stubbed AccountType static class. In here, we use an ID
and AccountType which do not exist in our constant records.
All records primary keys (AccountType.ID) should be non-zero,
so using a zero here means that we'll never match against a
real AccountType."""
ID = 0
AccountType = "Anonymous"
# AccountTypeID == AccountType.ID; assign a stubbed column.
AccountTypeID = AccountType.ID
LangPreference = aurweb.config.get("options", "default_lang")
Timezone = aurweb.config.get("options", "default_timezone")
Suspended = 0
InactivityTS = 0
# A stub ssh_pub_key relationship.
ssh_pub_key = None
# Add stubbed relationship backrefs.
notifications = StubQuery()
package_votes = StubQuery()
# A nonce attribute, needed for all browser sessions; set in __init__.
nonce = None
def __init__(self):
self.nonce = util.make_nonce()
@staticmethod
def is_authenticated():
return False
@staticmethod
def is_package_maintainer():
return False
@staticmethod
def is_developer():
return False
@staticmethod
def is_elevated():
return False
@staticmethod
def has_credential(credential, **kwargs):
return False
@staticmethod
def voted_for(package):
return False
@staticmethod
def notified(package):
return False
class BasicAuthBackend(AuthenticationBackend):
@db.async_retry_deadlock
async def authenticate(self, conn: HTTPConnection):
unauthenticated = (None, AnonymousUser())
sid = conn.cookies.get("AURSID")
if not sid:
return unauthenticated
timeout = aurweb.config.getint("options", "login_timeout")
remembered = conn.cookies.get("AURREMEMBER") == "True"
if remembered:
timeout = aurweb.config.getint("options", "persistent_cookie_timeout")
# If no session with sid and a LastUpdateTS now or later exists.
now_ts = time.utcnow()
record = db.query(Session).filter(Session.SessionID == sid).first()
if not record:
return unauthenticated
elif record.LastUpdateTS < (now_ts - timeout):
with db.begin():
db.delete_all([record])
return unauthenticated
# At this point, we cannot have an invalid user if the record
# exists, due to ForeignKey constraints in the schema upheld
# by mysqlclient.
user = db.query(User).filter(User.ID == record.UsersID).first()
user.nonce = util.make_nonce()
user.authenticated = True
return AuthCredentials(["authenticated"]), user
def _auth_required(auth_goal: bool = True):
"""
Enforce a user's authentication status, bringing them to the login page
or homepage if their authentication status does not match the goal.
NOTE: This function should not need to be used in downstream code.
See `requires_auth` and `requires_guest` for decorators meant to be
used on routes (they're a bit more implicitly understandable).
:param auth_goal: Whether authentication is required or entirely disallowed
for a user to perform this request.
:return: Return the FastAPI function this decorator wraps.
"""
def decorator(func):
@functools.wraps(func)
async def wrapper(request, *args, **kwargs):
if request.user.is_authenticated() == auth_goal:
return await func(request, *args, **kwargs)
url = "/"
if auth_goal is False:
return RedirectResponse(url, status_code=int(HTTPStatus.SEE_OTHER))
# Use the request path when the user can visit a page directly but
# is not authenticated and use the Referer header if visiting the
# page itself is not directly possible (e.g. submitting a form).
if request.method in ("GET", "HEAD"):
url = request.url.path
elif referer := request.headers.get("Referer"):
aur = aurweb.config.get("options", "aur_location") + "/"
if not referer.startswith(aur):
_ = l10n.get_translator_for_request(request)
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=_("Bad Referer header."),
)
url = referer[len(aur) - 1 :]
url = "/login?" + filters.urlencode({"next": url})
return RedirectResponse(url, status_code=int(HTTPStatus.SEE_OTHER))
return wrapper
return decorator
def requires_auth(func: Callable) -> Callable:
"""Require an authenticated session for a particular route."""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
return await _auth_required(True)(func)(*args, **kwargs)
return wrapper
def requires_guest(func: Callable) -> Callable:
"""Require a guest (unauthenticated) session for a particular route."""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
return await _auth_required(False)(func)(*args, **kwargs)
return wrapper
def account_type_required(one_of: set):
"""A decorator that can be used on FastAPI routes to dictate
that a user belongs to one of the types defined in one_of.
This decorator should be run after an @auth_required(True) is
dictated.
- Example code:
@router.get('/some_route')
@auth_required(True)
@account_type_required({"Package Maintainer", "Package Maintainer & Developer"})
async def some_route(request: fastapi.Request):
return Response()
:param one_of: A set consisting of strings to match against AccountType.
:return: Return the FastAPI function this decorator wraps.
"""
# Convert any account type string constants to their integer IDs.
one_of = {ACCOUNT_TYPE_ID[atype] for atype in one_of if isinstance(atype, str)}
def decorator(func):
@functools.wraps(func)
async def wrapper(request: fastapi.Request, *args, **kwargs):
if request.user.AccountTypeID not in one_of:
return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER))
return await func(request, *args, **kwargs)
return wrapper
return decorator
from aurweb.models.account_type import (
DEVELOPER_ID,
PACKAGE_MAINTAINER_AND_DEV_ID,
PACKAGE_MAINTAINER_ID,
USER_ID,
)
from aurweb.models.user import User
ACCOUNT_CHANGE_TYPE = 1
ACCOUNT_EDIT = 2
ACCOUNT_EDIT_DEV = 3
ACCOUNT_LAST_LOGIN = 4
ACCOUNT_SEARCH = 5
ACCOUNT_LIST_COMMENTS = 28
COMMENT_DELETE = 6
COMMENT_UNDELETE = 27
COMMENT_VIEW_DELETED = 22
COMMENT_EDIT = 25
COMMENT_PIN = 26
PKGBASE_ADOPT = 7
PKGBASE_SET_KEYWORDS = 8
PKGBASE_DELETE = 9
PKGBASE_DISOWN = 10
PKGBASE_EDIT_COMAINTAINERS = 24
PKGBASE_FLAG = 11
PKGBASE_LIST_VOTERS = 12
PKGBASE_NOTIFY = 13
PKGBASE_UNFLAG = 15
PKGBASE_VOTE = 16
PKGREQ_FILE = 23
PKGREQ_CLOSE = 17
PKGREQ_LIST = 18
PM_ADD_VOTE = 19
PM_LIST_VOTES = 20
PM_VOTE = 21
PKGBASE_MERGE = 29
user_developer_or_package_maintainer = set(
[USER_ID, PACKAGE_MAINTAINER_ID, DEVELOPER_ID, PACKAGE_MAINTAINER_AND_DEV_ID]
)
package_maintainer_or_dev = set(
[PACKAGE_MAINTAINER_ID, DEVELOPER_ID, PACKAGE_MAINTAINER_AND_DEV_ID]
)
developer = set([DEVELOPER_ID, PACKAGE_MAINTAINER_AND_DEV_ID])
package_maintainer = set([PACKAGE_MAINTAINER_ID, PACKAGE_MAINTAINER_AND_DEV_ID])
cred_filters = {
PKGBASE_FLAG: user_developer_or_package_maintainer,
PKGBASE_NOTIFY: user_developer_or_package_maintainer,
PKGBASE_VOTE: user_developer_or_package_maintainer,
PKGREQ_FILE: user_developer_or_package_maintainer,
ACCOUNT_CHANGE_TYPE: package_maintainer_or_dev,
ACCOUNT_EDIT: package_maintainer_or_dev,
ACCOUNT_LAST_LOGIN: package_maintainer_or_dev,
ACCOUNT_LIST_COMMENTS: package_maintainer_or_dev,
ACCOUNT_SEARCH: package_maintainer_or_dev,
COMMENT_DELETE: package_maintainer_or_dev,
COMMENT_UNDELETE: package_maintainer_or_dev,
COMMENT_VIEW_DELETED: package_maintainer_or_dev,
COMMENT_EDIT: package_maintainer_or_dev,
COMMENT_PIN: package_maintainer_or_dev,
PKGBASE_ADOPT: package_maintainer_or_dev,
PKGBASE_SET_KEYWORDS: package_maintainer_or_dev,
PKGBASE_DELETE: package_maintainer_or_dev,
PKGBASE_EDIT_COMAINTAINERS: package_maintainer_or_dev,
PKGBASE_DISOWN: package_maintainer_or_dev,
PKGBASE_LIST_VOTERS: package_maintainer_or_dev,
PKGBASE_UNFLAG: package_maintainer_or_dev,
PKGREQ_CLOSE: package_maintainer_or_dev,
PKGREQ_LIST: package_maintainer_or_dev,
PM_ADD_VOTE: package_maintainer,
PM_LIST_VOTES: package_maintainer_or_dev,
PM_VOTE: package_maintainer,
ACCOUNT_EDIT_DEV: developer,
PKGBASE_MERGE: package_maintainer_or_dev,
}
def has_credential(user: User, credential: int, approved: list = tuple()):
if user in approved:
return True
return user.AccountTypeID in cred_filters[credential]
from datetime import UTC, datetime
class Benchmark:
def __init__(self):
self.start()
def _timestamp(self) -> float:
"""Generate a timestamp."""
return float(datetime.now(UTC).timestamp())
def start(self) -> int:
"""Start a benchmark."""
self.current = self._timestamp()
return self.current
def end(self):
"""Return the diff between now - start()."""
n = self._timestamp() - self.current
self.current = float(0)
return n
import pickle
from typing import Any, Callable
from sqlalchemy import orm
from aurweb import config
from aurweb.aur_redis import redis_connection
from aurweb.prometheus import SEARCH_REQUESTS
_redis = redis_connection()
def lambda_cache(key: str, value: Callable[[], Any], expire: int = None) -> list:
"""Store and retrieve lambda results via redis cache.
:param key: Redis key
:param value: Lambda callable returning the value
:param expire: Optional expiration in seconds
:return: result of callable or cache
"""
result = _redis.get(key)
if result is not None:
return pickle.loads(result)
_redis.set(key, (pickle.dumps(result := value())), ex=expire)
return result
def db_count_cache(key: str, query: orm.Query, expire: int = None) -> int:
"""Store and retrieve a query.count() via redis cache.
:param key: Redis key
:param query: SQLAlchemy ORM query
:param expire: Optional expiration in seconds
:return: query.count()
"""
result = _redis.get(key)
if result is None:
_redis.set(key, (result := int(query.count())))
if expire:
_redis.expire(key, expire)
return int(result)
def db_query_cache(key: str, query: orm.Query, expire: int = None) -> list:
"""Store and retrieve query results via redis cache.
:param key: Redis key
:param query: SQLAlchemy ORM query
:param expire: Optional expiration in seconds
:return: query.all()
"""
result = _redis.get(key)
if result is None:
SEARCH_REQUESTS.labels(cache="miss").inc()
if _redis.dbsize() > config.getint("cache", "max_search_entries", 50000):
return query.all()
_redis.set(key, (result := pickle.dumps(query.all())))
if expire:
_redis.expire(key, expire)
else:
SEARCH_REQUESTS.labels(cache="hit").inc()
return pickle.loads(result)
""" This module consists of aurweb's CAPTCHA utility functions and filters. """
import hashlib
from jinja2 import pass_context
from sqlalchemy import func
from aurweb.db import query
from aurweb.models import User
from aurweb.templates import register_filter
def get_captcha_salts():
"""Produce salts based on the current user count."""
count = query(func.count(User.ID)).scalar()
salts = []
for i in range(0, 6):
salts.append(f"aurweb-{count - i}")
return salts
def get_captcha_token(salt):
"""Produce a token for the CAPTCHA salt."""
return hashlib.md5(salt.encode()).hexdigest()[:3]
def get_captcha_challenge(salt):
"""Get a CAPTCHA challenge string (shell command) for a salt."""
token = get_captcha_token(salt)
return f"LC_ALL=C pacman -V|sed -r 's#[0-9]+#{token}#g'|md5sum|cut -c1-6"
def get_captcha_answer(token):
"""Compute the answer via md5 of the real template text, return the
first six digits of the hexadecimal hash."""
text = r"""
.--. Pacman v%s.%s.%s - libalpm v%s.%s.%s
/ _.-' .-. .-. .-. Copyright (C) %s-%s Pacman Development Team
\ '-. '-' '-' '-' Copyright (C) %s-%s Judd Vinet
'--'
This program may be freely redistributed under
the terms of the GNU General Public License.
""" % tuple(
[token] * 10
)
return hashlib.md5((text + "\n").encode()).hexdigest()[:6]
@register_filter("captcha_salt")
@pass_context
def captcha_salt_filter(context):
"""Returns the most recent CAPTCHA salt in the list of salts."""
salts = get_captcha_salts()
return salts[0]
@register_filter("captcha_cmdline")
@pass_context
def captcha_cmdline_filter(context, salt):
"""Returns a CAPTCHA challenge for a given salt."""
return get_captcha_challenge(salt)
import configparser
import os
from typing import Any
import tomlkit
_parser = None
......@@ -8,10 +11,11 @@ def _get_parser():
global _parser
if not _parser:
path = os.environ.get('AUR_CONFIG', '/etc/aurweb/config')
defaults = os.environ.get('AUR_CONFIG_DEFAULTS', path + '.defaults')
path = os.environ.get("AUR_CONFIG", "/etc/aurweb/config")
defaults = os.environ.get("AUR_CONFIG_DEFAULTS", path + ".defaults")
_parser = configparser.RawConfigParser()
_parser.optionxform = lambda option: option
if os.path.isfile(defaults):
with open(defaults) as f:
_parser.read_file(f)
......@@ -20,13 +24,56 @@ def _get_parser():
return _parser
def rehash():
"""Globally rehash the configuration parser."""
global _parser
_parser = None
_get_parser()
def get_with_fallback(section, option, fallback):
return _get_parser().get(section, option, fallback=fallback)
def get(section, option):
return _get_parser().get(section, option)
def _get_project_meta():
with open(os.path.join(get("options", "aurwebdir"), "pyproject.toml")) as pyproject:
file_contents = pyproject.read()
return tomlkit.parse(file_contents)["tool"]["poetry"]
# Publicly visible version of aurweb. This is used to display
# aurweb versioning in the footer and must be maintained.
AURWEB_VERSION = str(_get_project_meta()["version"])
def getboolean(section, option):
return _get_parser().getboolean(section, option)
def getint(section, option):
return _get_parser().getint(section, option)
def getint(section, option, fallback=None):
return _get_parser().getint(section, option, fallback=fallback)
def get_section(section):
if section in _get_parser().sections():
return _get_parser()[section]
def unset_option(section: str, option: str) -> None:
_get_parser().remove_option(section, option)
def set_option(section: str, option: str, value: Any) -> None:
_get_parser().set(section, option, value)
return value
def save() -> None:
aur_config = os.environ.get("AUR_CONFIG", "/etc/aurweb/config")
with open(aur_config, "w") as fp:
_get_parser().write(fp)
def samesite() -> str:
"""Produce cookie SameSite value.
Currently this is hard-coded to return "lax"
:returns "lax"
"""
return "lax"
try:
import mysql.connector
except ImportError:
pass
# Supported database drivers.
DRIVERS = {"mysql": "mysql+mysqldb"}
try:
import sqlite3
except ImportError:
pass
import aurweb.config
def make_random_value(table: str, column: str, length: int):
"""Generate a unique, random value for a string column in a table.
engine = None # See get_engine
:return: A unique string that is not in the database
"""
import aurweb.util
string = aurweb.util.make_random_string(length)
while query(table).filter(column == string).first():
string = aurweb.util.make_random_string(length)
return string
def test_name() -> str:
"""
Return the unhashed database name.
The unhashed database name is determined (lower = higher priority) by:
-------------------------------------------
1. {test_suite} portion of PYTEST_CURRENT_TEST
2. aurweb.config.get("database", "name")
During `pytest` runs, the PYTEST_CURRENT_TEST environment variable
is set to the current test in the format `{test_suite}::{test_func}`.
This allows tests to use a suite-specific database for its runs,
which decouples database state from test suites.
:return: Unhashed database name
"""
import os
import aurweb.config
db = os.environ.get("PYTEST_CURRENT_TEST", aurweb.config.get("database", "name"))
return db.split(":")[0]
def name() -> str:
"""
Return sanitized database name that can be used for tests or production.
If test_name() starts with "test/", the database name is SHA-1 hashed,
prefixed with 'db', and returned. Otherwise, test_name() is passed
through and not hashed at all.
:return: SHA1-hashed database name prefixed with 'db'
"""
dbname = test_name()
if not dbname.startswith("test/"):
return dbname
import hashlib
sha1 = hashlib.sha1(dbname.encode()).hexdigest()
return "db" + sha1
# Module-private global memo used to store SQLAlchemy sessions.
_sessions = dict()
def get_session(engine=None):
"""Return aurweb.db's global session."""
dbname = name()
global _sessions
if dbname not in _sessions:
from sqlalchemy.orm import scoped_session, sessionmaker
if not engine: # pragma: no cover
engine = get_engine()
Session = scoped_session(
sessionmaker(autocommit=True, autoflush=False, bind=engine)
)
_sessions[dbname] = Session()
return _sessions.get(dbname)
def pop_session(dbname: str) -> None:
"""
Pop a Session out of the private _sessions memo.
:param dbname: Database name
:raises KeyError: When `dbname` does not exist in the memo
"""
global _sessions
_sessions.pop(dbname)
def refresh(model):
"""
Refresh the session's knowledge of `model`.
:returns: Passed in `model`
"""
get_session().refresh(model)
return model
def query(Model, *args, **kwargs):
"""
Perform an ORM query against the database session.
This method also runs Query.filter on the resulting model
query with *args and **kwargs.
:param Model: Declarative ORM class
"""
return get_session().query(Model).filter(*args, **kwargs)
def create(Model, *args, **kwargs):
"""
Create a record and add() it to the database session.
:param Model: Declarative ORM class
:return: Model instance
"""
instance = Model(*args, **kwargs)
return add(instance)
def delete(model) -> None:
"""
Delete a set of records found by Query.filter(*args, **kwargs).
:param Model: Declarative ORM class
"""
get_session().delete(model)
def delete_all(iterable) -> None:
"""Delete each instance found in `iterable`."""
import aurweb.util
session_ = get_session()
aurweb.util.apply_all(iterable, session_.delete)
def rollback() -> None:
"""Rollback the database session."""
get_session().rollback()
def add(model):
"""Add `model` to the database session."""
get_session().add(model)
return model
def begin():
"""Begin an SQLAlchemy SessionTransaction."""
return get_session().begin()
def retry_deadlock(func):
from sqlalchemy.exc import OperationalError
def wrapper(*args, _i: int = 0, **kwargs):
# Retry 10 times, then raise the exception
# If we fail before the 10th, recurse into `wrapper`
# If we fail on the 10th, continue to throw the exception
limit = 10
try:
return func(*args, **kwargs)
except OperationalError as exc:
if _i < limit and "Deadlock found" in str(exc):
# Retry on deadlock by recursing into `wrapper`
return wrapper(*args, _i=_i + 1, **kwargs)
# Otherwise, just raise the exception
raise exc
return wrapper
def async_retry_deadlock(func):
from sqlalchemy.exc import OperationalError
async def wrapper(*args, _i: int = 0, **kwargs):
# Retry 10 times, then raise the exception
# If we fail before the 10th, recurse into `wrapper`
# If we fail on the 10th, continue to throw the exception
limit = 10
try:
return await func(*args, **kwargs)
except OperationalError as exc:
if _i < limit and "Deadlock found" in str(exc):
# Retry on deadlock by recursing into `wrapper`
return await wrapper(*args, _i=_i + 1, **kwargs)
# Otherwise, just raise the exception
raise exc
return wrapper
def get_sqlalchemy_url():
"""
Build an SQLAlchemy for use with create_engine based on the aurweb configuration.
Build an SQLAlchemy URL for use with create_engine.
:return: sqlalchemy.engine.url.URL
"""
import sqlalchemy
aur_db_backend = aurweb.config.get('database', 'backend')
if aur_db_backend == 'mysql':
return sqlalchemy.engine.url.URL(
'mysql+mysqlconnector',
username=aurweb.config.get('database', 'user'),
password=aurweb.config.get('database', 'password'),
host=aurweb.config.get('database', 'host'),
database=aurweb.config.get('database', 'name'),
query={
'unix_socket': aurweb.config.get('database', 'socket'),
},
from sqlalchemy.engine.url import URL
import aurweb.config
constructor = URL
parts = sqlalchemy.__version__.split(".")
major = int(parts[0])
minor = int(parts[1])
if major == 1 and minor >= 4: # pragma: no cover
constructor = URL.create
aur_db_backend = aurweb.config.get("database", "backend")
if aur_db_backend == "mysql":
param_query = {}
port = aurweb.config.get_with_fallback("database", "port", None)
if not port:
param_query["unix_socket"] = aurweb.config.get("database", "socket")
return constructor(
DRIVERS.get(aur_db_backend),
username=aurweb.config.get("database", "user"),
password=aurweb.config.get_with_fallback(
"database", "password", fallback=None
),
host=aurweb.config.get("database", "host"),
database=name(),
port=port,
query=param_query,
)
elif aur_db_backend == 'sqlite':
return sqlalchemy.engine.url.URL(
'sqlite',
database=aurweb.config.get('database', 'name'),
elif aur_db_backend == "sqlite":
return constructor(
"sqlite",
database=aurweb.config.get("database", "name"),
)
else:
raise ValueError('unsupported database backend')
raise ValueError("unsupported database backend")
def sqlite_regexp(regex, item) -> bool: # pragma: no cover
"""Method which mimics SQL's REGEXP for SQLite."""
import re
return bool(re.search(regex, str(item)))
def setup_sqlite(engine) -> None: # pragma: no cover
"""Perform setup for an SQLite engine."""
from sqlalchemy import event
@event.listens_for(engine, "connect")
def do_begin(conn, record):
import functools
create_deterministic_function = functools.partial(
conn.create_function, deterministic=True
)
create_deterministic_function("REGEXP", 2, sqlite_regexp)
# Module-private global memo used to store SQLAlchemy engines.
_engines = dict()
def get_engine():
def get_engine(dbname: str = None, echo: bool = False):
"""
Return the global SQLAlchemy engine.
Return the SQLAlchemy engine for `dbname`.
The engine is created on the first call to get_engine and then stored in the
`engine` global variable for the next calls.
:param dbname: Database name (default: aurweb.db.name())
:param echo: Flag passed through to sqlalchemy.create_engine
:return: SQLAlchemy Engine instance
"""
from sqlalchemy import create_engine
global engine
if engine is None:
import aurweb.config
if not dbname:
dbname = name()
global _engines
if dbname not in _engines:
db_backend = aurweb.config.get("database", "backend")
connect_args = dict()
if aurweb.config.get("database", "backend") == "sqlite":
# check_same_thread is for a SQLite technicality
# https://fastapi.tiangolo.com/tutorial/sql-databases/#note
is_sqlite = bool(db_backend == "sqlite")
if is_sqlite: # pragma: no cover
connect_args["check_same_thread"] = False
engine = create_engine(get_sqlalchemy_url(), connect_args=connect_args)
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
return engine
kwargs = {"echo": echo, "connect_args": connect_args}
from sqlalchemy import create_engine
if aurweb.config.get("tracing", "otlp_endpoint"):
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
engine = create_engine(get_sqlalchemy_url(), **kwargs)
SQLAlchemyInstrumentor().instrument(engine=engine)
_engines[dbname] = engine
else:
_engines[dbname] = create_engine(get_sqlalchemy_url(), **kwargs)
if is_sqlite: # pragma: no cover
setup_sqlite(_engines.get(dbname))
return _engines.get(dbname)
def pop_engine(dbname: str) -> None:
"""
Pop an Engine out of the private _engines memo.
:param dbname: Database name
:raises KeyError: When `dbname` does not exist in the memo
"""
global _engines
_engines.pop(dbname)
def kill_engine() -> None:
"""Close the current session and dispose of the engine."""
dbname = name()
session = get_session()
session.close()
pop_session(dbname)
engine = get_engine()
engine.dispose()
pop_engine(dbname)
def connect():
......@@ -72,40 +350,34 @@ def connect():
return get_engine().connect()
class Connection:
class ConnectionExecutor:
_conn = None
_paramstyle = None
def __init__(self):
aur_db_backend = aurweb.config.get('database', 'backend')
if aur_db_backend == 'mysql':
aur_db_host = aurweb.config.get('database', 'host')
aur_db_name = aurweb.config.get('database', 'name')
aur_db_user = aurweb.config.get('database', 'user')
aur_db_pass = aurweb.config.get('database', 'password')
aur_db_socket = aurweb.config.get('database', 'socket')
self._conn = mysql.connector.connect(host=aur_db_host,
user=aur_db_user,
passwd=aur_db_pass,
db=aur_db_name,
unix_socket=aur_db_socket,
buffered=True)
self._paramstyle = mysql.connector.paramstyle
elif aur_db_backend == 'sqlite':
aur_db_name = aurweb.config.get('database', 'name')
self._conn = sqlite3.connect(aur_db_name)
def __init__(self, conn, backend=None):
import aurweb.config
backend = backend or aurweb.config.get("database", "backend")
self._conn = conn
if backend == "mysql":
self._paramstyle = "format"
elif backend == "sqlite":
import sqlite3
self._paramstyle = sqlite3.paramstyle
else:
raise ValueError('unsupported database backend')
def execute(self, query, params=()):
if self._paramstyle in ('format', 'pyformat'):
query = query.replace('%', '%%').replace('?', '%s')
elif self._paramstyle == 'qmark':
def paramstyle(self):
return self._paramstyle
def execute(self, query, params=()): # pragma: no cover
# TODO: SQLite support has been removed in FastAPI. It remains
# here to fund its support for the Sharness testsuite.
if self._paramstyle in ("format", "pyformat"):
query = query.replace("%", "%%").replace("?", "%s")
elif self._paramstyle == "qmark":
pass
else:
raise ValueError('unsupported paramstyle')
raise ValueError("unsupported paramstyle")
cur = self._conn.cursor()
cur.execute(query, params)
......@@ -117,3 +389,51 @@ class Connection:
def close(self):
self._conn.close()
class Connection:
_executor = None
_conn = None
def __init__(self):
import aurweb.config
aur_db_backend = aurweb.config.get("database", "backend")
if aur_db_backend == "mysql":
import MySQLdb
aur_db_host = aurweb.config.get("database", "host")
aur_db_name = name()
aur_db_user = aurweb.config.get("database", "user")
aur_db_pass = aurweb.config.get_with_fallback("database", "password", str())
aur_db_socket = aurweb.config.get("database", "socket")
self._conn = MySQLdb.connect(
host=aur_db_host,
user=aur_db_user,
passwd=aur_db_pass,
db=aur_db_name,
unix_socket=aur_db_socket,
)
elif aur_db_backend == "sqlite": # pragma: no cover
# TODO: SQLite support has been removed in FastAPI. It remains
# here to fund its support for Sharness testsuite.
import math
import sqlite3
aur_db_name = aurweb.config.get("database", "name")
self._conn = sqlite3.connect(aur_db_name)
self._conn.create_function("POWER", 2, math.pow)
else:
raise ValueError("unsupported database backend")
self._conn = ConnectionExecutor(self._conn, aur_db_backend)
def execute(self, query, params=()):
return self._conn.execute(query, params)
def commit(self):
self._conn.commit()
def close(self):
self._conn.close()
""" Constant default values centralized in one place. """
# Default [O]ffset
O = 0
# Default [P]er [P]age
PP = 50
# Default Comments Per Page
COMMENTS_PER_PAGE = 10
# A whitelist of valid PP values
PP_WHITELIST = {50, 100, 250}
# Default `by` parameter for RPC search.
RPC_SEARCH_BY = "name-desc"
def fallback_pp(per_page: int) -> int:
"""If `per_page` is a valid value in PP_WHITELIST, return it.
Otherwise, return defaults.PP."""
if per_page not in PP_WHITELIST:
return PP
return per_page
import functools
from typing import Any, Callable
import fastapi
class AurwebException(Exception):
pass
......@@ -12,64 +18,95 @@ class BannedException(AurwebException):
class PermissionDeniedException(AurwebException):
def __init__(self, user):
msg = 'permission denied: {:s}'.format(user)
msg = "permission denied: {:s}".format(user)
super(PermissionDeniedException, self).__init__(msg)
class BrokenUpdateHookException(AurwebException):
def __init__(self, cmd):
msg = 'broken update hook: {:s}'.format(cmd)
msg = "broken update hook: {:s}".format(cmd)
super(BrokenUpdateHookException, self).__init__(msg)
class InvalidUserException(AurwebException):
def __init__(self, user):
msg = 'unknown user: {:s}'.format(user)
msg = "unknown user: {:s}".format(user)
super(InvalidUserException, self).__init__(msg)
class InvalidPackageBaseException(AurwebException):
def __init__(self, pkgbase):
msg = 'package base not found: {:s}'.format(pkgbase)
msg = "package base not found: {:s}".format(pkgbase)
super(InvalidPackageBaseException, self).__init__(msg)
class InvalidRepositoryNameException(AurwebException):
def __init__(self, pkgbase):
msg = 'invalid repository name: {:s}'.format(pkgbase)
msg = "invalid repository name: {:s}".format(pkgbase)
super(InvalidRepositoryNameException, self).__init__(msg)
class PackageBaseExistsException(AurwebException):
def __init__(self, pkgbase):
msg = 'package base already exists: {:s}'.format(pkgbase)
msg = "package base already exists: {:s}".format(pkgbase)
super(PackageBaseExistsException, self).__init__(msg)
class InvalidReasonException(AurwebException):
def __init__(self, reason):
msg = 'invalid reason: {:s}'.format(reason)
msg = "invalid reason: {:s}".format(reason)
super(InvalidReasonException, self).__init__(msg)
class InvalidCommentException(AurwebException):
def __init__(self, comment):
msg = 'comment is too short: {:s}'.format(comment)
msg = "comment is too short: {:s}".format(comment)
super(InvalidCommentException, self).__init__(msg)
class AlreadyVotedException(AurwebException):
def __init__(self, comment):
msg = 'already voted for package base: {:s}'.format(comment)
msg = "already voted for package base: {:s}".format(comment)
super(AlreadyVotedException, self).__init__(msg)
class NotVotedException(AurwebException):
def __init__(self, comment):
msg = 'missing vote for package base: {:s}'.format(comment)
msg = "missing vote for package base: {:s}".format(comment)
super(NotVotedException, self).__init__(msg)
class InvalidArgumentsException(AurwebException):
def __init__(self, msg):
super(InvalidArgumentsException, self).__init__(msg)
class RPCError(AurwebException):
pass
class ValidationError(AurwebException):
def __init__(self, data: Any, *args, **kwargs):
super().__init__(*args, **kwargs)
self.data = data
class InvariantError(AurwebException):
pass
def handle_form_exceptions(route: Callable) -> fastapi.Response:
"""
A decorator required when fastapi POST routes are defined.
This decorator populates fastapi's `request.state` with a `form_data`
attribute, which is then used to report form data when exceptions
are caught and reported.
"""
@functools.wraps(route)
async def wrapper(request: fastapi.Request, *args, **kwargs):
request.state.form_data = await request.form()
return await route(request, *args, **kwargs)
return wrapper
import copy
import math
from datetime import UTC, datetime
from typing import Any, Union
from urllib.parse import quote_plus, urlencode
from zoneinfo import ZoneInfo
import fastapi
import paginate
from jinja2 import pass_context
from jinja2.filters import do_format
import aurweb.models
from aurweb import config, l10n
from aurweb.templates import register_filter, register_function
@register_filter("pager_nav")
@pass_context
def pager_nav(context: dict[str, Any], page: int, total: int, prefix: str) -> str:
page = int(page) # Make sure this is an int.
pp = context.get("PP", 50)
# Setup a local query string dict, optionally passed by caller.
q = context.get("q", dict())
search_by = context.get("SeB", None)
if search_by:
q["SeB"] = search_by
sort_by = context.get("SB", None)
if sort_by:
q["SB"] = sort_by
def create_url(page: int):
nonlocal q
offset = max(page * pp - pp, 0)
qs = to_qs(extend_query(q, ["O", offset]))
return f"{prefix}?{qs}"
# Use the paginate module to produce our linkage.
pager = paginate.Page(
[], page=page + 1, items_per_page=pp, item_count=total, url_maker=create_url
)
return pager.pager(
link_attr={"class": "page"},
curpage_attr={"class": "page"},
separator="&nbsp",
format="$link_first $link_previous ~5~ $link_next $link_last",
symbol_first="« First",
symbol_previous="‹ Previous",
symbol_next="Next ›",
symbol_last="Last »",
)
@register_function("config_getint")
def config_getint(section: str, key: str) -> int:
return config.getint(section, key)
@register_function("round")
def do_round(f: float) -> int:
return round(f)
@register_filter("tr")
@pass_context
def tr(context: dict[str, Any], value: str):
"""A translation filter; example: {{ "Hello" | tr("de") }}."""
_ = l10n.get_translator_for_request(context.get("request"))
return _(value)
@register_filter("tn")
@pass_context
def tn(context: dict[str, Any], count: int, singular: str, plural: str) -> str:
"""A singular and plural translation filter.
Example:
{{ some_integer | tn("singular %d", "plural %d") }}
:param context: Response context
:param count: The number used to decide singular or plural state
:param singular: The singular translation
:param plural: The plural translation
:return: Translated string
"""
gettext = l10n.get_raw_translator_for_request(context.get("request"))
return gettext.ngettext(singular, plural, count)
@register_filter("dt")
def timestamp_to_datetime(timestamp: int):
return datetime.fromtimestamp(timestamp, UTC)
@register_filter("as_timezone")
def as_timezone(dt: datetime, timezone: str):
return dt.astimezone(tz=ZoneInfo(timezone))
@register_filter("extend_query")
def extend_query(query: dict[str, Any], *additions) -> dict[str, Any]:
"""Add additional key value pairs to query."""
q = copy.copy(query)
for k, v in list(additions):
q[k] = v
return q
@register_filter("urlencode")
def to_qs(query: dict[str, Any]) -> str:
return urlencode(query, doseq=True)
@register_filter("get_vote")
def get_vote(voteinfo, request: fastapi.Request):
from aurweb.models import Vote
return voteinfo.votes.filter(Vote.User == request.user).first()
@register_filter("number_format")
def number_format(value: float, places: int):
"""A converter function similar to PHP's number_format."""
return f"{value:.{places}f}"
@register_filter("account_url")
@pass_context
def account_url(context: dict[str, Any], user: "aurweb.models.user.User") -> str:
base = aurweb.config.get("options", "aur_location")
return f"{base}/account/{user.Username}"
@register_filter("quote_plus")
def _quote_plus(*args, **kwargs) -> str:
return quote_plus(*args, **kwargs)
@register_filter("ceil")
def ceil(*args, **kwargs) -> int:
return math.ceil(*args, **kwargs)
@register_function("date_strftime")
@pass_context
def date_strftime(context: dict[str, Any], dt: Union[int, datetime], fmt: str) -> str:
if isinstance(dt, int):
dt = timestamp_to_datetime(dt)
tz = context.get("timezone")
return as_timezone(dt, tz).strftime(fmt)
@register_function("date_display")
@pass_context
def date_display(context: dict[str, Any], dt: Union[int, datetime]) -> str:
return date_strftime(context, dt, "%Y-%m-%d (%Z)")
@register_function("datetime_display")
@pass_context
def datetime_display(context: dict[str, Any], dt: Union[int, datetime]) -> str:
return date_strftime(context, dt, "%Y-%m-%d %H:%M (%Z)")
@register_filter("format")
def safe_format(value: str, *args: Any, **kwargs: Any) -> str:
"""Wrapper for jinja2 format function to perform additional checks."""
# If we don't have anything to be formatted, just return the value.
# We have some translations that do not contain placeholders for replacement.
# In these cases the jinja2 function is throwing an error:
# "TypeError: not all arguments converted during string formatting"
if "%" not in value:
return value
return do_format(value, *args, **kwargs)