Verified Commit 51b60f42 authored by Kevin Morris's avatar Kevin Morris
Browse files

feat(auth): add requires_{auth,guest} decorators



These new decorators are meant to be used without any arguments
and provide aliases to auth_required:
- `auth_required(True) -> requires_auth`
- `auth_required(False) -> requires_guest`

These decorators should be used without arguments, e.g.:

    @router.get("/")
    @requires_guest
    async def my_route(request: Request):
        return HTMLResponse()

Signed-off-by: Kevin Morris's avatarKevin Morris <kevr@0cost.org>
parent 3e048e96
Pipeline #14288 waiting for manual action with stages
in 4 minutes
......@@ -2,6 +2,7 @@ import functools
from datetime import datetime
from http import HTTPStatus
from typing import Callable
import fastapi
......@@ -129,10 +130,15 @@ class BasicAuthBackend(AuthenticationBackend):
return (AuthCredentials(["authenticated"]), user)
def auth_required(auth_goal: bool = True):
""" Enforce a user's authentication status, bringing them to the login page
def _auth_required(auth_goal: bool = True):
"""
Enforce a user's authentication status, bringing them to the login page
or homepage if their authentication status does not match the goal.
NOTE: This function should not need to be used in downstream code.
See `requires_auth` and `requires_guest` for decorators meant to be
used on routes (they're a bit more implicitly understandable).
:param auth_goal: Whether authentication is required or entirely disallowed
for a user to perform this request.
:return: Return the FastAPI function this decorator wraps.
......@@ -167,6 +173,24 @@ def auth_required(auth_goal: bool = True):
return decorator
def requires_auth(func: Callable) -> Callable:
""" Require an authenticated session for a particular route. """
@functools.wraps(func)
async def wrapper(*args, **kwargs):
return await _auth_required(True)(func)(*args, **kwargs)
return wrapper
def requires_guest(func: Callable) -> Callable:
""" Require a guest (unauthenticated) session for a particular route. """
@functools.wraps(func)
async def wrapper(*args, **kwargs):
return await _auth_required(False)(func)(*args, **kwargs)
return wrapper
def account_type_required(one_of: set):
""" A decorator that can be used on FastAPI routes to dictate
that a user belongs to one of the types defined in one_of.
......
......@@ -10,7 +10,7 @@ from sqlalchemy import and_, or_
import aurweb.config
from aurweb import cookies, db, l10n, logging, models, util
from aurweb.auth import account_type_required, auth_required
from aurweb.auth import account_type_required, requires_auth, requires_guest
from aurweb.captcha import get_captcha_salts
from aurweb.exceptions import ValidationError
from aurweb.l10n import get_translator_for_request
......@@ -27,14 +27,14 @@ logger = logging.get_logger(__name__)
@router.get("/passreset", response_class=HTMLResponse)
@auth_required(False)
@requires_guest
async def passreset(request: Request):
context = await make_variable_context(request, "Password Reset")
return render_template(request, "passreset.html", context)
@router.post("/passreset", response_class=HTMLResponse)
@auth_required(False)
@requires_guest
async def passreset_post(request: Request,
user: str = Form(...),
resetkey: str = Form(default=None),
......@@ -224,7 +224,7 @@ def make_account_form_context(context: dict,
@router.get("/register", response_class=HTMLResponse)
@auth_required(False)
@requires_guest
async def account_register(request: Request,
U: str = Form(default=str()), # Username
E: str = Form(default=str()), # Email
......@@ -250,7 +250,7 @@ async def account_register(request: Request,
@router.post("/register", response_class=HTMLResponse)
@auth_required(False)
@requires_guest
async def account_register_post(request: Request,
U: str = Form(default=str()), # Username
E: str = Form(default=str()), # Email
......@@ -348,7 +348,7 @@ def cannot_edit(request: Request, user: models.User) \
@router.get("/account/{username}/edit", response_class=HTMLResponse)
@auth_required()
@requires_auth
async def account_edit(request: Request, username: str):
user = db.query(models.User, models.User.Username == username).first()
......@@ -364,7 +364,7 @@ async def account_edit(request: Request, username: str):
@router.post("/account/{username}/edit", response_class=HTMLResponse)
@auth_required()
@requires_auth
async def account_edit_post(request: Request,
username: str,
U: str = Form(default=str()), # Username
......@@ -461,7 +461,7 @@ async def account(request: Request, username: str):
@router.get("/account/{username}/comments")
@auth_required()
@requires_auth
async def account_comments(request: Request, username: str):
user = get_user_by_name(username)
context = make_context(request, "Accounts")
......@@ -472,7 +472,7 @@ async def account_comments(request: Request, username: str):
@router.get("/accounts")
@auth_required()
@requires_auth
@account_type_required({at.TRUSTED_USER,
at.DEVELOPER,
at.TRUSTED_USER_AND_DEV})
......@@ -482,7 +482,7 @@ async def accounts(request: Request):
@router.post("/accounts")
@auth_required()
@requires_auth
@account_type_required({at.TRUSTED_USER,
at.DEVELOPER,
at.TRUSTED_USER_AND_DEV})
......@@ -567,7 +567,7 @@ def render_terms_of_service(request: Request,
@router.get("/tos")
@auth_required()
@requires_auth
async def terms_of_service(request: Request):
# Query the database for terms that were previously accepted,
# but now have a bumped Revision that needs to be accepted.
......@@ -591,7 +591,7 @@ async def terms_of_service(request: Request):
@router.post("/tos")
@auth_required()
@requires_auth
async def terms_of_service_post(request: Request,
accept: bool = Form(default=False)):
# Query the database for terms that were previously accepted,
......
......@@ -7,7 +7,7 @@ from fastapi.responses import HTMLResponse, RedirectResponse
import aurweb.config
from aurweb import cookies, db
from aurweb.auth import auth_required
from aurweb.auth import requires_auth, requires_guest
from aurweb.l10n import get_translator_for_request
from aurweb.models import User
from aurweb.templates import make_variable_context, render_template
......@@ -29,7 +29,7 @@ async def login_get(request: Request, next: str = "/"):
@router.post("/login", response_class=HTMLResponse)
@auth_required(False)
@requires_guest
async def login_post(request: Request,
next: str = Form(...),
user: str = Form(default=str()),
......@@ -81,7 +81,7 @@ async def login_post(request: Request,
@router.post("/logout")
@auth_required()
@requires_auth
async def logout(request: Request, next: str = Form(default="/")):
if request.user.is_authenticated():
request.user.logout(request)
......
......@@ -7,7 +7,7 @@ from fastapi import APIRouter, Form, Request, Response
import aurweb.filters # noqa: F401
from aurweb import config, db, defaults, logging, models, util
from aurweb.auth import auth_required, creds
from aurweb.auth import creds, requires_auth
from aurweb.exceptions import InvariantError
from aurweb.models.relation_type import CONFLICTS_ID, PROVIDES_ID, REPLACES_ID
from aurweb.packages import util as pkgutil
......@@ -406,7 +406,7 @@ PACKAGE_ACTIONS = {
@router.post("/packages")
@auth_required()
@requires_auth
async def packages_post(request: Request,
IDs: List[int] = Form(default=[]),
action: str = Form(default=str()),
......
......@@ -6,7 +6,7 @@ from fastapi.responses import JSONResponse, RedirectResponse
from sqlalchemy import and_
from aurweb import config, db, l10n, logging, templates, util
from aurweb.auth import auth_required, creds
from aurweb.auth import creds, requires_auth
from aurweb.exceptions import InvariantError, ValidationError
from aurweb.models import PackageBase
from aurweb.models.package_comment import PackageComment
......@@ -116,7 +116,7 @@ async def pkgbase_keywords(request: Request, name: str,
@router.get("/pkgbase/{name}/flag")
@auth_required()
@requires_auth
async def pkgbase_flag_get(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase)
......@@ -131,7 +131,7 @@ async def pkgbase_flag_get(request: Request, name: str):
@router.post("/pkgbase/{name}/flag")
@auth_required()
@requires_auth
async def pkgbase_flag_post(request: Request, name: str,
comments: str = Form(default=str())):
pkgbase = get_pkg_or_base(name, PackageBase)
......@@ -157,7 +157,7 @@ async def pkgbase_flag_post(request: Request, name: str,
@router.post("/pkgbase/{name}/comments")
@auth_required()
@requires_auth
async def pkgbase_comments_post(
request: Request, name: str,
comment: str = Form(default=str()),
......@@ -189,7 +189,7 @@ async def pkgbase_comments_post(
@router.get("/pkgbase/{name}/comments/{id}/form")
@auth_required()
@requires_auth
async def pkgbase_comment_form(request: Request, name: str, id: int,
next: str = Query(default=None)):
"""
......@@ -229,7 +229,7 @@ async def pkgbase_comment_form(request: Request, name: str, id: int,
@router.get("/pkgbase/{name}/comments/{id}/edit")
@auth_required()
@requires_auth
async def pkgbase_comment_edit(request: Request, name: str, id: int,
next: str = Form(default=None)):
"""
......@@ -253,7 +253,7 @@ async def pkgbase_comment_edit(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/comments/{id}")
@auth_required()
@requires_auth
async def pkgbase_comment_post(
request: Request, name: str, id: int,
comment: str = Form(default=str()),
......@@ -293,7 +293,7 @@ async def pkgbase_comment_post(
@router.post("/pkgbase/{name}/comments/{id}/pin")
@auth_required()
@requires_auth
async def pkgbase_comment_pin(request: Request, name: str, id: int,
next: str = Form(default=None)):
"""
......@@ -327,7 +327,7 @@ async def pkgbase_comment_pin(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/comments/{id}/unpin")
@auth_required()
@requires_auth
async def pkgbase_comment_unpin(request: Request, name: str, id: int,
next: str = Form(default=None)):
"""
......@@ -360,7 +360,7 @@ async def pkgbase_comment_unpin(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/comments/{id}/delete")
@auth_required()
@requires_auth
async def pkgbase_comment_delete(request: Request, name: str, id: int,
next: str = Form(default=None)):
"""
......@@ -399,7 +399,7 @@ async def pkgbase_comment_delete(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/comments/{id}/undelete")
@auth_required()
@requires_auth
async def pkgbase_comment_undelete(request: Request, name: str, id: int,
next: str = Form(default=None)):
"""
......@@ -437,7 +437,7 @@ async def pkgbase_comment_undelete(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/vote")
@auth_required()
@requires_auth
async def pkgbase_vote(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase)
......@@ -461,7 +461,7 @@ async def pkgbase_vote(request: Request, name: str):
@router.post("/pkgbase/{name}/unvote")
@auth_required()
@requires_auth
async def pkgbase_unvote(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase)
......@@ -481,7 +481,7 @@ async def pkgbase_unvote(request: Request, name: str):
@router.post("/pkgbase/{name}/notify")
@auth_required()
@requires_auth
async def pkgbase_notify(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase)
actions.pkgbase_notify_instance(request, pkgbase)
......@@ -490,7 +490,7 @@ async def pkgbase_notify(request: Request, name: str):
@router.post("/pkgbase/{name}/unnotify")
@auth_required()
@requires_auth
async def pkgbase_unnotify(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase)
actions.pkgbase_unnotify_instance(request, pkgbase)
......@@ -499,7 +499,7 @@ async def pkgbase_unnotify(request: Request, name: str):
@router.post("/pkgbase/{name}/unflag")
@auth_required()
@requires_auth
async def pkgbase_unflag(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase)
actions.pkgbase_unflag_instance(request, pkgbase)
......@@ -508,7 +508,7 @@ async def pkgbase_unflag(request: Request, name: str):
@router.get("/pkgbase/{name}/disown")
@auth_required()
@requires_auth
async def pkgbase_disown_get(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase)
......@@ -524,7 +524,7 @@ async def pkgbase_disown_get(request: Request, name: str):
@router.post("/pkgbase/{name}/disown")
@auth_required()
@requires_auth
async def pkgbase_disown_post(request: Request, name: str,
comments: str = Form(default=str()),
confirm: bool = Form(default=False)):
......@@ -559,7 +559,7 @@ async def pkgbase_disown_post(request: Request, name: str,
@router.post("/pkgbase/{name}/adopt")
@auth_required()
@requires_auth
async def pkgbase_adopt_post(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase)
......@@ -575,7 +575,7 @@ async def pkgbase_adopt_post(request: Request, name: str):
@router.get("/pkgbase/{name}/comaintainers")
@auth_required()
@requires_auth
async def pkgbase_comaintainers(request: Request, name: str) -> Response:
# Get the PackageBase.
pkgbase = get_pkg_or_base(name, PackageBase)
......@@ -601,7 +601,7 @@ async def pkgbase_comaintainers(request: Request, name: str) -> Response:
@router.post("/pkgbase/{name}/comaintainers")
@auth_required()
@requires_auth
async def pkgbase_comaintainers_post(request: Request, name: str,
users: str = Form(default=str())) \
-> Response:
......@@ -643,7 +643,7 @@ async def pkgbase_comaintainers_post(request: Request, name: str,
@router.get("/pkgbase/{name}/request")
@auth_required()
@requires_auth
async def pkgbase_request(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase)
context = await make_variable_context(request, "Submit Request")
......@@ -652,7 +652,7 @@ async def pkgbase_request(request: Request, name: str):
@router.post("/pkgbase/{name}/request")
@auth_required()
@requires_auth
async def pkgbase_request_post(request: Request, name: str,
type: str = Form(...),
merge_into: str = Form(default=None),
......@@ -732,7 +732,7 @@ async def pkgbase_request_post(request: Request, name: str,
@router.get("/pkgbase/{name}/delete")
@auth_required()
@requires_auth
async def pkgbase_delete_get(request: Request, name: str):
if not request.user.has_credential(creds.PKGBASE_DELETE):
return RedirectResponse(f"/pkgbase/{name}",
......@@ -744,7 +744,7 @@ async def pkgbase_delete_get(request: Request, name: str):
@router.post("/pkgbase/{name}/delete")
@auth_required()
@requires_auth
async def pkgbase_delete_post(request: Request, name: str,
confirm: bool = Form(default=False),
comments: str = Form(default=str())):
......@@ -779,7 +779,7 @@ async def pkgbase_delete_post(request: Request, name: str,
@router.get("/pkgbase/{name}/merge")
@auth_required()
@requires_auth
async def pkgbase_merge_get(request: Request, name: str,
into: str = Query(default=str()),
next: str = Query(default=str())):
......@@ -810,7 +810,7 @@ async def pkgbase_merge_get(request: Request, name: str,
@router.post("/pkgbase/{name}/merge")
@auth_required()
@requires_auth
async def pkgbase_merge_post(request: Request, name: str,
into: str = Form(default=str()),
comments: str = Form(default=str()),
......
......@@ -6,7 +6,7 @@ from fastapi.responses import RedirectResponse
from sqlalchemy import case
from aurweb import db, defaults, util
from aurweb.auth import auth_required, creds
from aurweb.auth import creds, requires_auth
from aurweb.models import PackageRequest, User
from aurweb.models.package_request import PENDING_ID, REJECTED_ID
from aurweb.requests.util import get_pkgreq_by_id
......@@ -17,7 +17,7 @@ router = APIRouter()
@router.get("/requests")
@auth_required()
@requires_auth
async def requests(request: Request,
O: int = Query(default=defaults.O),
PP: int = Query(default=defaults.PP)):
......@@ -50,7 +50,7 @@ async def requests(request: Request,
@router.get("/requests/{id}/close")
@auth_required()
@requires_auth
async def request_close(request: Request, id: int):
pkgreq = get_pkgreq_by_id(id)
......@@ -64,7 +64,7 @@ async def request_close(request: Request, id: int):
@router.post("/requests/{id}/close")
@auth_required()
@requires_auth
async def request_close_post(request: Request, id: int,
comments: str = Form(default=str())):
pkgreq = get_pkgreq_by_id(id)
......
......@@ -10,7 +10,7 @@ from fastapi.responses import RedirectResponse, Response
from sqlalchemy import and_, or_
from aurweb import db, l10n, logging, models
from aurweb.auth import account_type_required, auth_required
from aurweb.auth import account_type_required, requires_auth
from aurweb.models.account_type import DEVELOPER, TRUSTED_USER, TRUSTED_USER_AND_DEV
from aurweb.templates import make_context, make_variable_context, render_template
......@@ -41,7 +41,7 @@ ADDVOTE_SPECIFICS = {
@router.get("/tu")
@auth_required()
@requires_auth
@account_type_required(REQUIRED_TYPES)
async def trusted_user(request: Request,
coff: int = 0, # current offset
......@@ -147,7 +147,7 @@ def render_proposal(request: Request,
@router.get("/tu/{proposal}")
@auth_required()
@requires_auth
@account_type_required(REQUIRED_TYPES)
async def trusted_user_proposal(request: Request, proposal: int):
context = await make_variable_context(request, "Trusted User")
......@@ -176,7 +176,7 @@ async def trusted_user_proposal(request: Request, proposal: int):
@router.post("/tu/{proposal}")
@auth_required()
@requires_auth
@account_type_required(REQUIRED_TYPES)
async def trusted_user_proposal_post(request: Request,
proposal: int,
......@@ -227,7 +227,7 @@ async def trusted_user_proposal_post(request: Request,
@router.get("/addvote")
@auth_required()
@requires_auth
@account_type_required({TRUSTED_USER, TRUSTED_USER_AND_DEV})
async def trusted_user_addvote(request: Request,
user: str = str(),
......@@ -247,7 +247,7 @@ async def trusted_user_addvote(request: Request,
@router.post("/addvote")
@auth_required()
@requires_auth
@account_type_required({TRUSTED_USER, TRUSTED_USER_AND_DEV})
async def trusted_user_addvote_post(request: Request,
user: str = Form(default=str()),
......
......@@ -7,7 +7,7 @@ from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from aurweb import config, db
from aurweb.auth import AnonymousUser, BasicAuthBackend, account_type_required, auth_required
from aurweb.auth import AnonymousUser, BasicAuthBackend, _auth_required, account_type_required
from aurweb.models.account_type import USER, USER_ID
from aurweb.models.session import Session
from aurweb.models.user import User
......@@ -105,7 +105,7 @@ async def test_auth_required_redirection_bad_referrer():
pass
# Get down to the nitty gritty internal wrapper.
bad_referrer_route = auth_required()(bad_referrer_route)
bad_referrer_route = _auth_required()(bad_referrer_route)
# Execute the route with a "./blahblahblah" Referer, which does not
# match aur_location; `./` has been used as a prefix to attempt to
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment