Verified Commit 28c4e969 authored by Kevin Morris's avatar Kevin Morris
Browse files

change(fastapi): simplify model imports across code-base

Closes: #133



Signed-off-by: Kevin Morris's avatarKevin Morris <kevr@0cost.org>
parent bfdc85d7
Pipeline #12141 passed with stage
in 8 minutes
......@@ -16,8 +16,7 @@ import aurweb.logging
from aurweb.auth import BasicAuthBackend
from aurweb.db import get_engine, query
from aurweb.models.accepted_term import AcceptedTerm
from aurweb.models.term import Term
from aurweb.models import AcceptedTerm, Term
from aurweb.routers import accounts, auth, errors, html, packages, rpc, rss, sso, trusted_user
# Setup the FastAPI app.
......
......@@ -14,9 +14,8 @@ from starlette.requests import HTTPConnection
import aurweb.config
from aurweb import l10n, util
from aurweb.models import Session, User
from aurweb.models.account_type import ACCOUNT_TYPE_ID
from aurweb.models.session import Session
from aurweb.models.user import User
from aurweb.templates import make_variable_context, render_template
......
......@@ -4,7 +4,7 @@ import hashlib
from jinja2 import pass_context
from aurweb.db import query
from aurweb.models.user import User
from aurweb.models import User
def get_captcha_salts():
......
from sqlalchemy import and_, case, or_, orm
from aurweb import config, db
from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase
from aurweb.models.package_comaintainer import PackageComaintainer
from aurweb.models.package_keyword import PackageKeyword
from aurweb.models.package_notification import PackageNotification
from aurweb.models.package_vote import PackageVote
from aurweb.models.user import User
from aurweb import config, db, models
DEFAULT_MAX_RESULTS = 2500
......@@ -18,22 +11,22 @@ class PackageSearch:
# A constant mapping of short to full name sort orderings.
FULL_SORT_ORDER = {"d": "desc", "a": "asc"}
def __init__(self, user: User):
def __init__(self, user: models.User):
""" Construct an instance of PackageSearch.
This constructors performs several steps during initialization:
1. Setup self.query: an ORM query of Package joined by PackageBase.
"""
self.user = user
self.query = db.query(Package).join(PackageBase).join(
PackageVote,
and_(PackageVote.PackageBaseID == PackageBase.ID,
PackageVote.UsersID == self.user.ID),
self.query = db.query(models.Package).join(models.PackageBase).join(
models.PackageVote,
and_(models.PackageVote.PackageBaseID == models.PackageBase.ID,
models.PackageVote.UsersID == self.user.ID),
isouter=True
).join(
PackageNotification,
and_(PackageNotification.PackageBaseID == PackageBase.ID,
PackageNotification.UserID == self.user.ID),
models.PackageNotification,
and_(models.PackageNotification.PackageBaseID == models.PackageBase.ID,
models.PackageNotification.UserID == self.user.ID),
isouter=True
)
self.ordering = "d"
......@@ -65,59 +58,64 @@ class PackageSearch:
def _search_by_namedesc(self, keywords: str) -> orm.Query:
self.query = self.query.filter(
or_(Package.Name.like(f"%{keywords}%"),
Package.Description.like(f"%{keywords}%"))
or_(models.Package.Name.like(f"%{keywords}%"),
models.Package.Description.like(f"%{keywords}%"))
)
return self
def _search_by_name(self, keywords: str) -> orm.Query:
self.query = self.query.filter(Package.Name.like(f"%{keywords}%"))
self.query = self.query.filter(
models.Package.Name.like(f"%{keywords}%"))
return self
def _search_by_exact_name(self, keywords: str) -> orm.Query:
self.query = self.query.filter(Package.Name == keywords)
self.query = self.query.filter(
models.Package.Name == keywords)
return self
def _search_by_pkgbase(self, keywords: str) -> orm.Query:
self.query = self.query.filter(PackageBase.Name.like(f"%{keywords}%"))
self.query = self.query.filter(
models.PackageBase.Name.like(f"%{keywords}%"))
return self
def _search_by_exact_pkgbase(self, keywords: str) -> orm.Query:
self.query = self.query.filter(PackageBase.Name == keywords)
self.query = self.query.filter(
models.PackageBase.Name == keywords)
return self
def _search_by_keywords(self, keywords: str) -> orm.Query:
self.query = self.query.join(PackageKeyword).filter(
PackageKeyword.Keyword == keywords
self.query = self.query.join(models.PackageKeyword).filter(
models.PackageKeyword.Keyword == keywords
)
return self
def _search_by_maintainer(self, keywords: str) -> orm.Query:
self.query = self.query.join(
User, User.ID == PackageBase.MaintainerUID
).filter(User.Username == keywords)
models.User, models.User.ID == models.PackageBase.MaintainerUID
).filter(models.User.Username == keywords)
return self
def _search_by_comaintainer(self, keywords: str) -> orm.Query:
self.query = self.query.join(PackageComaintainer).join(
User, User.ID == PackageComaintainer.UsersID
).filter(User.Username == keywords)
self.query = self.query.join(models.PackageComaintainer).join(
models.User, models.User.ID == models.PackageComaintainer.UsersID
).filter(models.User.Username == keywords)
return self
def _search_by_co_or_maintainer(self, keywords: str) -> orm.Query:
self.query = self.query.join(
PackageComaintainer,
models.PackageComaintainer,
isouter=True
).join(
User, or_(User.ID == PackageBase.MaintainerUID,
User.ID == PackageComaintainer.UsersID)
).filter(User.Username == keywords)
models.User,
or_(models.User.ID == models.PackageBase.MaintainerUID,
models.User.ID == models.PackageComaintainer.UsersID)
).filter(models.User.Username == keywords)
return self
def _search_by_submitter(self, keywords: str) -> orm.Query:
self.query = self.query.join(
User, User.ID == PackageBase.SubmitterUID
).filter(User.Username == keywords)
models.User, models.User.ID == models.PackageBase.SubmitterUID
).filter(models.User.Username == keywords)
return self
def search_by(self, search_by: str, keywords: str) -> orm.Query:
......@@ -128,17 +126,17 @@ class PackageSearch:
return result
def _sort_by_name(self, order: str):
column = getattr(Package.Name, order)
column = getattr(models.Package.Name, order)
self.query = self.query.order_by(column())
return self
def _sort_by_votes(self, order: str):
column = getattr(PackageBase.NumVotes, order)
column = getattr(models.PackageBase.NumVotes, order)
self.query = self.query.order_by(column())
return self
def _sort_by_popularity(self, order: str):
column = getattr(PackageBase.Popularity, order)
column = getattr(models.PackageBase.Popularity, order)
self.query = self.query.order_by(column())
return self
......@@ -147,10 +145,10 @@ class PackageSearch:
# in terms of performance. We should improve this; there's no
# reason it should take _longer_.
column = getattr(
case([(PackageVote.UsersID == self.user.ID, 1)], else_=0),
case([(models.PackageVote.UsersID == self.user.ID, 1)], else_=0),
order
)
self.query = self.query.order_by(column(), Package.Name.desc())
self.query = self.query.order_by(column(), models.Package.Name.desc())
return self
def _sort_by_notify(self, order: str):
......@@ -158,21 +156,24 @@ class PackageSearch:
# in terms of performance. We should improve this; there's no
# reason it should take _longer_.
column = getattr(
case([(PackageNotification.UserID == self.user.ID, 1)], else_=0),
case([(models.PackageNotification.UserID == self.user.ID, 1)],
else_=0),
order
)
self.query = self.query.order_by(column(), Package.Name.desc())
self.query = self.query.order_by(column(), models.Package.Name.desc())
return self
def _sort_by_maintainer(self, order: str):
column = getattr(User.Username, order)
column = getattr(models.User.Username, order)
self.query = self.query.join(
User, User.ID == PackageBase.MaintainerUID, isouter=True
models.User,
models.User.ID == models.PackageBase.MaintainerUID,
isouter=True
).order_by(column())
return self
def _sort_by_last_modified(self, order: str):
column = getattr(PackageBase.ModifiedTS, order)
column = getattr(models.PackageBase.ModifiedTS, order)
self.query = self.query.order_by(column())
return self
......
......@@ -7,43 +7,35 @@ import orjson
from fastapi import HTTPException
from sqlalchemy import and_, orm
from aurweb import db
from aurweb.models.official_provider import OFFICIAL_BASE, OfficialProvider
from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase
from aurweb.models.package_comment import PackageComment
from aurweb.models.package_dependency import PackageDependency
from aurweb.models.package_notification import PackageNotification
from aurweb.models.package_relation import PackageRelation
from aurweb.models.package_vote import PackageVote
from aurweb.models.relation_type import PROVIDES_ID, RelationType
from aurweb.models.user import User
from aurweb import db, models
from aurweb.models.official_provider import OFFICIAL_BASE
from aurweb.models.relation_type import PROVIDES_ID
from aurweb.redis import redis_connection
from aurweb.templates import register_filter
def dep_depends_extra(dep: PackageDependency) -> str:
def dep_depends_extra(dep: models.PackageDependency) -> str:
""" A function used to produce extra text for dependency display. """
return str()
def dep_makedepends_extra(dep: PackageDependency) -> str:
def dep_makedepends_extra(dep: models.PackageDependency) -> str:
""" A function used to produce extra text for dependency display. """
return "(make)"
def dep_checkdepends_extra(dep: PackageDependency) -> str:
def dep_checkdepends_extra(dep: models.PackageDependency) -> str:
""" A function used to produce extra text for dependency display. """
return "(check)"
def dep_optdepends_extra(dep: PackageDependency) -> str:
def dep_optdepends_extra(dep: models.PackageDependency) -> str:
""" A function used to produce extra text for dependency display. """
return "(optional)"
@register_filter("dep_extra")
def dep_extra(dep: PackageDependency) -> str:
def dep_extra(dep: models.PackageDependency) -> str:
""" Some dependency types have extra text added to their
display. This function provides that output. However, it
**assumes** that the dep passed is bound to a valid one
......@@ -53,7 +45,7 @@ def dep_extra(dep: PackageDependency) -> str:
@register_filter("dep_extra_desc")
def dep_extra_desc(dep: PackageDependency) -> str:
def dep_extra_desc(dep: models.PackageDependency) -> str:
extra = dep_extra(dep)
if not dep.DepDesc:
return extra
......@@ -63,30 +55,30 @@ def dep_extra_desc(dep: PackageDependency) -> str:
@register_filter("pkgname_link")
def pkgname_link(pkgname: str) -> str:
base = "/".join([OFFICIAL_BASE, "packages"])
official = db.query(OfficialProvider).filter(
OfficialProvider.Name == pkgname)
official = db.query(models.OfficialProvider).filter(
models.OfficialProvider.Name == pkgname)
if official.scalar():
return f"{base}/?q={pkgname}"
return f"/packages/{pkgname}"
@register_filter("package_link")
def package_link(package: Package) -> str:
def package_link(package: models.Package) -> str:
base = "/".join([OFFICIAL_BASE, "packages"])
official = db.query(OfficialProvider).filter(
OfficialProvider.Name == package.Name)
official = db.query(models.OfficialProvider).filter(
models.OfficialProvider.Name == package.Name)
if official.scalar():
return f"{base}/?q={package.Name}"
return f"/packages/{package.Name}"
@register_filter("provides_list")
def provides_list(package: Package, depname: str) -> list:
providers = db.query(Package).join(
PackageRelation).join(RelationType).filter(
def provides_list(package: models.Package, depname: str) -> list:
providers = db.query(models.Package).join(
models.PackageRelation).join(models.RelationType).filter(
and_(
PackageRelation.RelName == depname,
RelationType.ID == PROVIDES_ID
models.PackageRelation.RelName == depname,
models.RelationType.ID == PROVIDES_ID
)
)
......@@ -102,7 +94,9 @@ def provides_list(package: Package, depname: str) -> list:
return string
def get_pkg_or_base(name: str, cls: Union[Package, PackageBase] = PackageBase):
def get_pkg_or_base(
name: str,
cls: Union[models.Package, models.PackageBase] = models.PackageBase):
""" Get a PackageBase instance by its name or raise a 404 if
it can't be found in the database.
......@@ -110,20 +104,21 @@ def get_pkg_or_base(name: str, cls: Union[Package, PackageBase] = PackageBase):
:raises HTTPException: With status code 404 if record doesn't exist
:return: {Package,PackageBase} instance
"""
provider = db.query(OfficialProvider).filter(
OfficialProvider.Name == name).first()
provider = db.query(models.OfficialProvider).filter(
models.OfficialProvider.Name == name).first()
if provider:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
instance = db.query(cls).filter(cls.Name == name).first()
if cls == PackageBase and not instance:
if cls == models.PackageBase and not instance:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
return instance
def get_pkgbase_comment(pkgbase: PackageBase, id: int) -> PackageComment:
comment = pkgbase.comments.filter(PackageComment.ID == id).first()
def get_pkgbase_comment(
pkgbase: models.PackageBase, id: int) -> models.PackageComment:
comment = pkgbase.comments.filter(models.PackageComment.ID == id).first()
if not comment:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
return comment
......@@ -131,10 +126,11 @@ def get_pkgbase_comment(pkgbase: PackageBase, id: int) -> PackageComment:
@register_filter("out_of_date")
def out_of_date(packages: orm.Query) -> orm.Query:
return packages.filter(PackageBase.OutOfDateTS.isnot(None))
return packages.filter(models.PackageBase.OutOfDateTS.isnot(None))
def updated_packages(limit: int = 0, cache_ttl: int = 600) -> List[Package]:
def updated_packages(limit: int = 0,
cache_ttl: int = 600) -> List[models.Package]:
""" Return a list of valid Package objects ordered by their
ModifiedTS column in descending order from cache, after setting
the cache when no key yet exists.
......@@ -149,10 +145,10 @@ def updated_packages(limit: int = 0, cache_ttl: int = 600) -> List[Package]:
# If we already have a cache, deserialize it and return.
return orjson.loads(packages)
query = db.query(Package).join(PackageBase).filter(
PackageBase.PackagerUID.isnot(None)
query = db.query(models.Package).join(models.PackageBase).filter(
models.PackageBase.PackagerUID.isnot(None)
).order_by(
PackageBase.ModifiedTS.desc()
models.PackageBase.ModifiedTS.desc()
)
if limit:
......@@ -178,7 +174,8 @@ def updated_packages(limit: int = 0, cache_ttl: int = 600) -> List[Package]:
return packages
def query_voted(query: List[Package], user: User) -> Dict[int, bool]:
def query_voted(query: List[models.Package],
user: models.User) -> Dict[int, bool]:
""" Produce a dictionary of package base ID keys to boolean values,
which indicate whether or not the package base has a vote record
related to user.
......@@ -189,18 +186,19 @@ def query_voted(query: List[Package], user: User) -> Dict[int, bool]:
"""
output = defaultdict(bool)
query_set = {pkg.PackageBase.ID for pkg in query}
voted = db.query(PackageVote).join(
PackageBase,
PackageBase.ID.in_(query_set)
voted = db.query(models.PackageVote).join(
models.PackageBase,
models.PackageBase.ID.in_(query_set)
).filter(
PackageVote.UsersID == user.ID
models.PackageVote.UsersID == user.ID
)
for vote in voted:
output[vote.PackageBase.ID] = True
return output
def query_notified(query: List[Package], user: User) -> Dict[int, bool]:
def query_notified(query: List[models.Package],
user: models.User) -> Dict[int, bool]:
""" Produce a dictionary of package base ID keys to boolean values,
which indicate whether or not the package base has a notification
record related to user.
......@@ -211,11 +209,11 @@ def query_notified(query: List[Package], user: User) -> Dict[int, bool]:
"""
output = defaultdict(bool)
query_set = {pkg.PackageBase.ID for pkg in query}
notified = db.query(PackageNotification).join(
PackageBase,
PackageBase.ID.in_(query_set)
notified = db.query(models.PackageNotification).join(
models.PackageBase,
models.PackageBase.ID.in_(query_set)
).filter(
PackageNotification.UserID == user.ID
models.PackageNotification.UserID == user.ID
)
for notify in notified:
output[notify.PackageBase.ID] = True
......
......@@ -11,17 +11,12 @@ from sqlalchemy import and_, func, or_
import aurweb.config
from aurweb import db, l10n, time, util
from aurweb import db, l10n, models, time, util
from aurweb.auth import account_type_required, auth_required
from aurweb.captcha import get_captcha_answer, get_captcha_salts, get_captcha_token
from aurweb.l10n import get_translator_for_request
from aurweb.models.accepted_term import AcceptedTerm
from aurweb.models.account_type import (DEVELOPER, DEVELOPER_ID, TRUSTED_USER, TRUSTED_USER_AND_DEV, TRUSTED_USER_AND_DEV_ID,
TRUSTED_USER_ID, USER_ID, AccountType)
from aurweb.models.ban import Ban
from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint
from aurweb.models.term import Term
from aurweb.models.user import User
from aurweb.models import account_type
from aurweb.models.ssh_pub_key import get_fingerprint
from aurweb.scripts.notify import ResetKeyNotification, WelcomeNotification
from aurweb.templates import make_context, make_variable_context, render_template
......@@ -46,8 +41,8 @@ async def passreset_post(request: Request,
context = await make_variable_context(request, "Password Reset")
# The user parameter being required, we can match against
user = db.query(User, or_(User.Username == user,
User.Email == user)).first()
user = db.query(models.User, or_(models.User.Username == user,
models.User.Email == user)).first()
if not user:
context["errors"] = ["Invalid e-mail."]
return render_template(request, "passreset.html", context,
......@@ -72,13 +67,13 @@ async def passreset_post(request: Request,
return render_template(request, "passreset.html", context,
status_code=HTTPStatus.BAD_REQUEST)
if len(password) < User.minimum_passwd_length():
if len(password) < models.User.minimum_passwd_length():
# Translate the error here, which simplifies error output
# in the jinja2 template.
_ = get_translator_for_request(request)
context["errors"] = [_(
"Your password must be at least %s characters.") % (
str(User.minimum_passwd_length()))]
str(models.User.minimum_passwd_length()))]
return render_template(request, "passreset.html", context,
status_code=HTTPStatus.BAD_REQUEST)
......@@ -95,7 +90,7 @@ async def passreset_post(request: Request,
status_code=HTTPStatus.SEE_OTHER)
# If we got here, we continue with issuing a resetkey for the user.
resetkey = db.make_random_value(User, User.ResetKey)
resetkey = db.make_random_value(models.User, models.User.ResetKey)
with db.begin():
user.ResetKey = resetkey
......@@ -107,7 +102,7 @@ async def passreset_post(request: Request,
status_code=HTTPStatus.SEE_OTHER)
def process_account_form(request: Request, user: User, args: dict):
def process_account_form(request: Request, user: models.User, args: dict):
""" Process an account form. All fields are optional and only checks
requirements in the case they are present.
......@@ -129,11 +124,11 @@ def process_account_form(request: Request, user: User, args: dict):
_ = get_translator_for_request(request)
host = request.client.host
ban = db.query(Ban, Ban.IPAddress == host).first()
ban = db.query(models.Ban, models.Ban.IPAddress == host).first()
if ban:
return False, [
"Account registration has been disabled for your " +
"IP address, probably due to sustained spam attacks. " +
"Account registration has been disabled for your "
"IP address, probably due to sustained spam attacks. "
"Sorry for the inconvenience."
]
......@@ -181,12 +176,12 @@ def process_account_form(request: Request, user: User, args: dict):
timezone = args.get("TZ", None)
def username_exists(username):
return and_(User.ID != user.ID,
func.lower(User.Username) == username.lower())
return and_(models.User.ID != user.ID,
func.lower(models.User.Username) == username.lower())
def email_exists(email):
return and_(User.ID != user.ID,
func.lower(User.Email) == email.lower())
return and_(models.User.ID != user.ID,
func.lower(models.User.Email) == email.lower())
if not util.valid_email(email):
return False, ["The email address is invalid."]
......@@ -203,13 +198,13 @@ def process_account_form(request: Request, user: User, args: dict):
return False, ["Language is not currently supported."]
elif timezone and timezone not in time.SUPPORTED_TIMEZONES:
return False, ["Timezone is not currently supported."]
elif db.query(User, username_exists(username)).first():
elif db.query(models.User, username_exists(username)).first():
# If the username already exists...
return False, [
_("The username, %s%s%s, is already in use.") % (
"<strong>", username, "</strong>")
]
elif db.query(User, email_exists(email)).first():
elif db.query(models.User, email_exists(email)).first():
# If the email already exists...
return False, [
_("The address, %s%s%s, is already in use.") % (
......@@ -217,15 +212,16 @@ def process_account_form(request: Request, user: User, args: dict):
]
def ssh_fingerprint_exists(fingerprint):
return and_(SSHPubKey.UserID != user.ID,
SSHPubKey.Fingerprint == fingerprint)
return and_(models.SSHPubKey.UserID != user.ID,
models.SSHPubKey.Fingerprint == fingerprint)
if ssh_pubkey:
fingerprint = get_fingerprint(ssh_pubkey.strip().rstrip())
if fingerprint is None:
return False, ["The SSH public key is invalid."]
if db.query(SSHPubKey, ssh_fingerprint_exists(fingerprint)).first():
if db.query(models.SSHPubKey,
ssh_fingerprint_exists(fingerprint)).first():
return False, [
_("The SSH public key, %s%s%s, is already in use.") % (
"<strong>", fingerprint, "</strong>")
......@@ -246,7 +242,7 @@ def process_account_form(request: Request, user: User, args: dict):
def make_account_form_context(context: dict,
request: Request,
user: User,
user: models.User,
args: dict):
""" Modify a FastAPI context and add attributes for the account form.
......@@ -382,20 +378,20 @@ async def account_register_post(request: Request,
# Create a user with no password with a resetkey, then send
# an email off about it.
resetkey = db.make_random_value(User, User.ResetKey)
resetkey = db.make_random_value(models.User, models.User.ResetKey)
# By default, we grab the User account type to associate with.
account_type = db.query(AccountType,
AccountType.AccountType == "User").first()
atype = db.query(models.AccountType,
models.AccountType.AccountType == "User").first()
# Create a user given all parameters available.
with db.begin():
user = db.create(User, Username=U,
user = db.create(models.User, Username=U,
Email=E, HideEmail=H, BackupEmail=BE,