Verified Commit 4103ab49 authored by Kevin Morris's avatar Kevin Morris
Browse files

housekeep(fastapi): rework aurweb.db session API



Changes:
-------
- Add aurweb.db.get_session()
    - Returns aurweb.db's global `session` instance
    - Provides us a way to change the implementation of the session
      instance without interrupting user code.
- Use aurweb.db.get_session() in session API methods
- Add docstrings to session API methods
- Refactor aurweb.db.delete
    - Normalize aurweb.db.delete to an alias of session.delete
- Refresh instances in places we depend on their non-PK columns
  being up to date.

Signed-off-by: Kevin Morris's avatarKevin Morris <kevr@0cost.org>
parent f8ba2c53
......@@ -13,7 +13,7 @@ from starlette.requests import HTTPConnection
import aurweb.config
from aurweb import l10n, util
from aurweb import db, l10n, util
from aurweb.models import Session, User
from aurweb.models.account_type import ACCOUNT_TYPE_ID
from aurweb.templates import make_variable_context, render_template
......@@ -98,14 +98,12 @@ class AnonymousUser:
class BasicAuthBackend(AuthenticationBackend):
async def authenticate(self, conn: HTTPConnection):
from aurweb.db import session
sid = conn.cookies.get("AURSID")
if not sid:
return (None, AnonymousUser())
now_ts = datetime.utcnow().timestamp()
record = session.query(Session).filter(
record = db.query(Session).filter(
and_(Session.SessionID == sid,
Session.LastUpdateTS >= now_ts)).first()
......@@ -116,7 +114,7 @@ class BasicAuthBackend(AuthenticationBackend):
# At this point, we cannot have an invalid user if the record
# exists, due to ForeignKey constraints in the schema upheld
# by mysqlclient.
user = session.query(User).filter(User.ID == record.UsersID).first()
user = db.query(User).filter(User.ID == record.UsersID).first()
user.nonce = util.make_nonce()
user.authenticated = True
......
......@@ -2,10 +2,10 @@ import functools
import math
import re
from typing import Iterable
from typing import Iterable, NewType
from sqlalchemy import event
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import Query, scoped_session
import aurweb.config
import aurweb.util
......@@ -22,6 +22,9 @@ session = None
# Global introspected object memo.
introspected = dict()
# A mocked up type.
Base = NewType("aurweb.models.declarative_base.Base", "Base")
def make_random_value(table: str, column: str):
""" Generate a unique, random value for a string column in a table.
......@@ -58,55 +61,69 @@ def make_random_value(table: str, column: str):
return string
def query(model, *args, **kwargs):
return session.query(model).filter(*args, **kwargs)
def get_session():
""" Return aurweb.db's global session. """
return session
def create(model, *args, **kwargs):
instance = model(*args, **kwargs)
return add(instance)
def refresh(model: Base) -> Base:
""" Refresh the session's knowledge of `model`. """
get_session().refresh(model)
return model
def delete(model, *args, **kwargs):
instance = session.query(model).filter(*args, **kwargs)
for record in instance:
session.delete(record)
def query(Model: Base, *args, **kwargs) -> Query:
"""
Perform an ORM query against the database session.
This method also runs Query.filter on the resulting model
query with *args and **kwargs.
def delete_all(iterable: Iterable):
with begin():
for obj in iterable:
session.delete(obj)
:param Model: Declarative ORM class
"""
return get_session().query(Model).filter(*args, **kwargs)
def rollback():
session.rollback()
def create(Model: Base, *args, **kwargs) -> Base:
"""
Create a record and add() it to the database session.
:param Model: Declarative ORM class
:return: Model instance
"""
instance = Model(*args, **kwargs)
return add(instance)
def add(model):
session.add(model)
return model
def delete(model: Base) -> None:
"""
Delete a set of records found by Query.filter(*args, **kwargs).
def begin():
""" Begin an SQLAlchemy SessionTransaction.
:param Model: Declarative ORM class
"""
get_session().delete(model)
This context is **required** to perform an modifications to the
database.
Example:
def delete_all(iterable: Iterable) -> None:
""" Delete each instance found in `iterable`. """
session_ = get_session()
aurweb.util.apply_all(iterable, session_.delete)
with db.begin():
object = db.create(...)
# On __exit__, db.commit() is run.
with db.begin():
object = db.delete(...)
# On __exit__, db.commit() is run.
def rollback() -> None:
""" Rollback the database session. """
get_session().rollback()
:return: A new SessionTransaction based on session
"""
return session.begin()
def add(model: Base) -> Base:
""" Add `model` to the database session. """
get_session().add(model)
return model
def begin():
""" Begin an SQLAlchemy SessionTransaction. """
return get_session().begin()
def get_sqlalchemy_url():
......
from fastapi import Request
from aurweb import schema
from aurweb import db, schema
from aurweb.models.declarative import Base
......@@ -10,11 +10,10 @@ class Ban(Base):
__mapper_args__ = {"primary_key": [__table__.c.IPAddress]}
def __init__(self, **kwargs):
self.IPAddress = kwargs.get("IPAddress")
self.BanTS = kwargs.get("BanTS")
super().__init__(**kwargs)
def is_banned(request: Request):
from aurweb.db import session
ip = request.client.host
return session.query(Ban).filter(Ban.IPAddress == ip).first() is not None
exists = db.query(Ban).filter(Ban.IPAddress == ip).exists()
return db.query(exists).scalar()
......@@ -146,7 +146,7 @@ class User(Base):
self.authenticated = False
if self.session:
with db.begin():
db.session.delete(self.session)
db.delete(self.session)
def is_trusted_user(self):
return self.AccountType.ID in {
......
......@@ -110,18 +110,26 @@ def get_pkg_or_base(
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
instance = db.query(cls).filter(cls.Name == name).first()
if cls == models.PackageBase and not instance:
if not instance:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
return instance
return db.refresh(instance)
def get_pkgbase_comment(
pkgbase: models.PackageBase, id: int) -> models.PackageComment:
def get_pkgbase_comment(pkgbase: models.PackageBase, id: int) \
-> models.PackageComment:
comment = pkgbase.comments.filter(models.PackageComment.ID == id).first()
if not comment:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
return comment
return db.refresh(comment)
def get_pkgreq_by_id(id: int):
pkgreq = db.query(models.PackageRequest).filter(
models.PackageRequest.ID == id).first()
if not pkgreq:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
return db.refresh(pkgreq)
@register_filter("out_of_date")
......
......@@ -40,8 +40,10 @@ def _update_ratelimit_db(request: Request):
now = int(datetime.utcnow().timestamp())
time_to_delete = now - window_length
records = db.query(ApiRateLimit).filter(
ApiRateLimit.WindowStart < time_to_delete)
with db.begin():
db.delete(ApiRateLimit, ApiRateLimit.WindowStart < time_to_delete)
db.delete_all(records)
host = request.client.host
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
......
......@@ -4,7 +4,7 @@ import typing
from datetime import datetime
from http import HTTPStatus
from fastapi import APIRouter, Form, HTTPException, Request
from fastapi import APIRouter, Form, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from sqlalchemy import and_, func, or_
......@@ -20,6 +20,7 @@ from aurweb.models.account_type import (DEVELOPER, DEVELOPER_ID, TRUSTED_USER, T
from aurweb.models.ssh_pub_key import get_fingerprint
from aurweb.scripts.notify import ResetKeyNotification, WelcomeNotification
from aurweb.templates import make_context, make_variable_context, render_template
from aurweb.users.util import get_user_by_name
router = APIRouter()
logger = logging.get_logger(__name__)
......@@ -49,6 +50,7 @@ async def passreset_post(request: Request,
return render_template(request, "passreset.html", context,
status_code=HTTPStatus.NOT_FOUND)
db.refresh(user)
if resetkey:
context["resetkey"] = resetkey
......@@ -83,7 +85,7 @@ async def passreset_post(request: Request,
with db.begin():
user.ResetKey = str()
if user.session:
db.session.delete(user.session)
db.delete(user.session)
user.update_password(password)
# Render ?step=complete.
......@@ -458,15 +460,15 @@ def cannot_edit(request, user):
@router.get("/account/{username}/edit", response_class=HTMLResponse)
@auth_required(True, redirect="/account/{username}")
async def account_edit(request: Request,
username: str):
async def account_edit(request: Request, username: str):
user = db.query(models.User, models.User.Username == username).first()
response = cannot_edit(request, user)
if response:
return response
context = await make_variable_context(request, "Accounts")
context["user"] = user
context["user"] = db.refresh(user)
context = make_account_form_context(context, request, user, dict())
return render_template(request, "account/edit.html", context)
......@@ -497,16 +499,14 @@ async def account_edit_post(request: Request,
ON: bool = Form(default=False), # Owner Notify
T: int = Form(default=None),
passwd: str = Form(default=str())):
from aurweb.db import session
user = session.query(models.User).filter(
user = db.query(models.User).filter(
models.User.Username == username).first()
response = cannot_edit(request, user)
if response:
return response
context = await make_variable_context(request, "Accounts")
context["user"] = user
context["user"] = db.refresh(user)
args = dict(await request.form())
context = make_account_form_context(context, request, user, args)
......@@ -575,7 +575,7 @@ async def account_edit_post(request: Request,
user.ssh_pub_key.Fingerprint = fingerprint
elif user.ssh_pub_key:
# Else, if the user has a public key already, delete it.
session.delete(user.ssh_pub_key)
db.delete(user.ssh_pub_key)
if T and T != user.AccountTypeID:
with db.begin():
......@@ -617,27 +617,16 @@ account_template = (
status_code=HTTPStatus.UNAUTHORIZED)
async def account(request: Request, username: str):
_ = l10n.get_translator_for_request(request)
context = await make_variable_context(request,
_("Account") + " " + username)
user = db.query(models.User, models.User.Username == username).first()
if not user:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
context["user"] = user
context = await make_variable_context(
request, _("Account") + " " + username)
context["user"] = get_user_by_name(username)
return render_template(request, "account/show.html", context)
@router.get("/account/{username}/comments")
@auth_required(redirect="/account/{username}/comments")
async def account_comments(request: Request, username: str):
user = db.query(models.User).filter(
models.User.Username == username
).first()
if not user:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
user = get_user_by_name(username)
context = make_context(request, "Accounts")
context["username"] = username
context["comments"] = user.package_comments.order_by(
......@@ -725,7 +714,7 @@ async def accounts_post(request: Request,
# Finally, order and truncate our users for the current page.
users = query.order_by(*order_by).limit(pp).offset(offset)
context["users"] = users
context["users"] = util.apply_all(users, db.refresh)
return render_template(request, "account/index.html", context)
......@@ -751,6 +740,9 @@ async def terms_of_service(request: Request):
unaccepted = db.query(models.Term).filter(
~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID))).all()
for record in (diffs + unaccepted):
db.refresh(record)
# Translate the 'Terms of Service' part of our page title.
_ = l10n.get_translator_for_request(request)
title = f"AUR {_('Terms of Service')}"
......@@ -782,18 +774,21 @@ async def terms_of_service_post(request: Request,
# We already did the database filters here, so let's just use
# them instead of reiterating the process in terms_of_service.
accept_needed = sorted(unaccepted + diffs)
return render_terms_of_service(request, context, accept_needed)
return render_terms_of_service(
request, context, util.apply_all(accept_needed, db.refresh))
with db.begin():
# For each term we found, query for the matching accepted term
# and update its Revision to the term's current Revision.
for term in diffs:
db.refresh(term)
accepted_term = request.user.accepted_terms.filter(
models.AcceptedTerm.TermsID == term.ID).first()
accepted_term.Revision = term.Revision
# For each term that was never accepted, accept it!
for term in unaccepted:
db.refresh(term)
db.create(models.AcceptedTerm, User=request.user,
Term=term, Revision=term.Revision)
......
......@@ -4,7 +4,7 @@ from typing import Any, Dict, List
from fastapi import APIRouter, Form, HTTPException, Query, Request, Response
from fastapi.responses import JSONResponse, RedirectResponse
from sqlalchemy import and_, case
from sqlalchemy import case
import aurweb.filters
import aurweb.packages.util
......@@ -15,9 +15,9 @@ from aurweb.models.package_request import ACCEPTED_ID, PENDING_ID, REJECTED_ID
from aurweb.models.relation_type import CONFLICTS_ID, PROVIDES_ID, REPLACES_ID
from aurweb.models.request_type import DELETION_ID, MERGE, MERGE_ID
from aurweb.packages.search import PackageSearch
from aurweb.packages.util import get_pkg_or_base, get_pkgbase_comment, query_notified, query_voted
from aurweb.packages.util import get_pkg_or_base, get_pkgbase_comment, get_pkgreq_by_id, query_notified, query_voted
from aurweb.scripts import notify, popupdate
from aurweb.scripts.rendercomment import update_comment_render
from aurweb.scripts.rendercomment import update_comment_render_fastapi
from aurweb.templates import make_context, make_variable_context, render_raw_template, render_template
logger = logging.get_logger(__name__)
......@@ -92,7 +92,10 @@ async def packages_get(request: Request, context: Dict[str, Any],
# Insert search results into the context.
results = search.results()
context["packages"] = results.limit(per_page).offset(offset)
packages = results.limit(per_page).offset(offset)
util.apply_all(packages, db.refresh)
context["packages"] = packages
context["packages_voted"] = query_voted(
context.get("packages"), request.user)
context["packages_notified"] = query_notified(
......@@ -132,6 +135,7 @@ def create_request_if_missing(requests: List[models.PackageRequest],
ClosedTS=now,
Closer=user)
requests.append(pkgreq)
return pkgreq
def delete_package(deleter: models.User, package: models.Package):
......@@ -147,8 +151,9 @@ def delete_package(deleter: models.User, package: models.Package):
).first()
with db.begin():
create_request_if_missing(
pkgreq = create_request_if_missing(
requests, reqtype, deleter, package)
db.refresh(pkgreq)
bases_to_delete.append(package.PackageBase)
......@@ -171,8 +176,9 @@ def delete_package(deleter: models.User, package: models.Package):
)
# Perform all the deletions.
db.delete_all([package])
db.delete_all(bases_to_delete)
with db.begin():
db.delete(package)
db.delete_all(bases_to_delete)
# Send out all the notifications.
util.apply_all(notifications, lambda n: n.send())
......@@ -221,8 +227,7 @@ async def make_single_context(request: Request,
async def package(request: Request, name: str) -> Response:
# Get the Package.
pkg = get_pkg_or_base(name, models.Package)
pkgbase = (get_pkg_or_base(name, models.PackageBase)
if not pkg else pkg.PackageBase)
pkgbase = pkg.PackageBase
# Add our base information.
context = await make_single_context(request, pkgbase)
......@@ -312,7 +317,7 @@ async def pkgbase_comments_post(
db.create(models.PackageNotification,
User=request.user,
PackageBase=pkgbase)
update_comment_render(comment.ID)
update_comment_render_fastapi(comment)
# Redirect to the pkgbase page.
return RedirectResponse(f"/pkgbase/{pkgbase.Name}#comment-{comment.ID}",
......@@ -374,7 +379,7 @@ async def pkgbase_comment_post(
db.create(models.PackageNotification,
User=request.user,
PackageBase=pkgbase)
update_comment_render(db_comment.ID)
update_comment_render_fastapi(db_comment)
if not next:
next = f"/pkgbase/{pkgbase.Name}"
......@@ -539,7 +544,7 @@ def remove_users(pkgbase, usernames):
conn, comaintainer.User.ID, pkgbase.ID
)
)
db.session.delete(comaintainer)
db.delete(comaintainer)
# Send out notifications if need be.
for notify_ in notifications:
......@@ -679,14 +684,8 @@ async def requests(request: Request,
@router.get("/pkgbase/{name}/request")
@auth_required(True, redirect="/pkgbase/{name}/request")
async def package_request(request: Request, name: str):
pkgbase = get_pkg_or_base(name, models.PackageBase)
context = await make_variable_context(request, "Submit Request")
pkgbase = db.query(models.PackageBase).filter(
models.PackageBase.Name == name).first()
if not pkgbase:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
context["pkgbase"] = pkgbase
return render_template(request, "pkgbase/request.html", context)
......@@ -729,6 +728,7 @@ async def pkgbase_request_post(request: Request, name: str,
]
return render_template(request, "pkgbase/request.html", context)
db.refresh(target)
if target.ID == pkgbase.ID:
# TODO: This error needs to be translated.
context["errors"] = [
......@@ -767,8 +767,7 @@ async def pkgbase_request_post(request: Request, name: str,
@router.get("/requests/{id}/close")
@auth_required(True, redirect="/requests/{id}/close")
async def requests_close(request: Request, id: int):
pkgreq = db.query(models.PackageRequest).filter(
models.PackageRequest.ID == id).first()
pkgreq = get_pkgreq_by_id(id)
if not request.user.is_elevated() and request.user != pkgreq.User:
# Request user doesn't have permission here: redirect to '/'.
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)
......@@ -783,8 +782,7 @@ async def requests_close(request: Request, id: int):
async def requests_close_post(request: Request, id: int,
reason: int = Form(default=0),
comments: str = Form(default=str())):
pkgreq = db.query(models.PackageRequest).filter(
models.PackageRequest.ID == id).first()
pkgreq = get_pkgreq_by_id(id)
if not request.user.is_elevated() and request.user != pkgreq.User:
# Request user doesn't have permission here: redirect to '/'.
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)
......@@ -823,13 +821,17 @@ async def pkgbase_keywords(request: Request, name: str,
keywords = set(keywords.split(" "))
# Delete all keywords which are not supplied by the user.
other_keywords = pkgbase.keywords.filter(
~models.PackageKeyword.Keyword.in_(keywords))
other_keyword_strings = [kwd.Keyword for kwd in other_keywords]
existing_keywords = set(
kwd.Keyword for kwd in
pkgbase.keywords.filter(
~models.PackageKeyword.Keyword.in_(other_keyword_strings))
)
with db.begin():
db.delete(models.PackageKeyword,
and_(models.PackageKeyword.PackageBaseID == pkgbase.ID,
~models.PackageKeyword.Keyword.in_(keywords)))
existing_keywords = set(kwd.Keyword for kwd in pkgbase.keywords.all())
with db.begin():
db.delete_all(other_keywords)
for keyword in keywords.difference(existing_keywords):
db.create(models.PackageKeyword,
PackageBase=pkgbase,
......@@ -940,7 +942,7 @@ def pkgbase_unnotify_instance(request: Request, pkgbase: models.PackageBase):
has_cred = request.user.has_credential("CRED_PKGBASE_NOTIFY")
if has_cred and notif:
with db.begin():
db.session.delete(notif)
db.delete(notif)
@router.post("/pkgbase/{name}/unnotify")
......@@ -988,7 +990,7 @@ async def pkgbase_unvote(request: Request, name: str):
has_cred = request.user.has_credential("CRED_PKGBASE_VOTE")
if has_cred and vote:
with db.begin():
db.session.delete(vote)
db.delete(vote)
# Update NumVotes/Popularity.
conn = db.ConnectionExecutor(db.get_engine().raw_connection())
......@@ -1015,7 +1017,7 @@ def pkgbase_disown_instance(request: Request, pkgbase: models.PackageBase):
if co:
with db.begin():
pkgbase.Maintainer = co.User
db.session.delete(co)
db.delete(co)
else:
pkgbase.Maintainer = None
......@@ -1463,8 +1465,8 @@ def pkgbase_merge_instance(request: Request, pkgbase: models.PackageBase,
with db.begin():
# Delete pkgbase and its packages now that everything's merged.
for pkg in pkgbase.packages:
db.session.delete(pkg)
db.session.delete(pkgbase)
db.delete(pkg)
db.delete(pkgbase)
# Accept merge requests related to this pkgbase and target.
for pkgreq in requests:
......
from collections import defaultdict
from typing import Any, Dict, List
from typing import Any, Callable, Dict, List, NewType
from sqlalchemy import and_
......@@ -25,6 +25,10 @@ REL_TYPES = {
}
DataGenerator = NewType("DataGenerator",
Callable[[models.Package], Dict[str, Any]])
class RPCError(Exception):
pass
......@@ -188,15 +192,32 @@ class RPC:
self._update_json_relations(package, data)
return data
def _handle_multiinfo_type(self, args: List[str] = [], **kwargs):
def _assemble_json_data(self, packages: List[models.Package],
data_generator: DataGenerator) \
-> List[Dict[str, Any]]:
"""