Verified Commit a5943bf2 authored by Kevin Morris's avatar Kevin Morris
Browse files

[FastAPI] Refactor db modifications



For SQLAlchemy to automatically understand updates from the
external world, it must use an `autocommit=True` in its session.

This change breaks how we were using commit previously, as
`autocommit=True` causes SQLAlchemy to commit when a
SessionTransaction context hits __exit__.

So, a refactoring was required of our tests: All usage of
any `db.{create,delete}` must be called **within** a
SessionTransaction context, created via new `db.begin()`.

From this point forward, we're going to require:

```
with db.begin():
    db.create(...)
    db.delete(...)
    db.session.delete(object)
```

With this, we now get external DB modifications automatically
without reloading or restarting the FastAPI server, which we
absolutely need for production.

Signed-off-by: Kevin Morris's avatarKevin Morris <kevr@0cost.org>
parent b52059d4
Pipeline #10864 passed with stage
in 4 minutes and 33 seconds
......@@ -59,20 +59,15 @@ def query(model, *args, **kwargs):
return session.query(model).filter(*args, **kwargs)
def create(model, autocommit: bool = True, *args, **kwargs):
def create(model, *args, **kwargs):
instance = model(*args, **kwargs)
add(instance)
if autocommit is True:
commit()
return instance
return add(instance)
def delete(model, *args, autocommit: bool = True, **kwargs):
def delete(model, *args, **kwargs):
instance = session.query(model).filter(*args, **kwargs)
for record in instance:
session.delete(record)
if autocommit is True:
commit()
def rollback():
......@@ -84,8 +79,25 @@ def add(model):
return model
def commit():
session.commit()
def begin():
""" Begin an SQLAlchemy SessionTransaction.
This context is **required** to perform an modifications to the
database.
Example:
with db.begin():
object = db.create(...)
# On __exit__, db.commit() is run.
with db.begin():
object = db.delete(...)
# On __exit__, db.commit() is run.
:return: A new SessionTransaction based on session
"""
return session.begin()
def get_sqlalchemy_url():
......@@ -155,23 +167,23 @@ def get_engine(echo: bool = False):
connect_args=connect_args,
echo=echo)
Session = sessionmaker(autocommit=True, autoflush=False, bind=engine)
session = Session()
if db_backend == "sqlite":
# For SQLite, we need to add some custom functions as
# they are used in the reference graph method.
def regexp(regex, item):
return bool(re.search(regex, str(item)))
@event.listens_for(engine, "begin")
def do_begin(conn):
@event.listens_for(engine, "connect")
def do_begin(conn, record):
create_deterministic_function = functools.partial(
conn.connection.create_function,
conn.create_function,
deterministic=True
)
create_deterministic_function("REGEXP", 2, regexp)
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
return engine
......
......@@ -102,7 +102,7 @@ class User(Base):
def login(self, request: Request, password: str, session_time=0):
""" Login and authenticate a request. """
from aurweb.db import session
from aurweb import db
from aurweb.models.session import Session, generate_unique_sid
if not self._login_approved(request):
......@@ -112,10 +112,7 @@ class User(Base):
if not self.authenticated:
return None
self.LastLogin = now_ts = datetime.utcnow().timestamp()
self.LastLoginIPAddress = request.client.host
session.commit()
now_ts = datetime.utcnow().timestamp()
session_ts = now_ts + (
session_time if session_time
else aurweb.config.getint("options", "login_timeout")
......@@ -123,22 +120,23 @@ class User(Base):
sid = None
if not self.session:
sid = generate_unique_sid()
self.session = Session(UsersID=self.ID, SessionID=sid,
LastUpdateTS=session_ts)
session.add(self.session)
else:
last_updated = self.session.LastUpdateTS
if last_updated and last_updated < now_ts:
self.session.SessionID = sid = generate_unique_sid()
with db.begin():
self.LastLogin = now_ts
self.LastLoginIPAddress = request.client.host
if not self.session:
sid = generate_unique_sid()
self.session = Session(UsersID=self.ID, SessionID=sid,
LastUpdateTS=session_ts)
db.add(self.session)
else:
# Session is still valid; retrieve the current SID.
sid = self.session.SessionID
self.session.LastUpdateTS = session_ts
last_updated = self.session.LastUpdateTS
if last_updated and last_updated < now_ts:
self.session.SessionID = sid = generate_unique_sid()
else:
# Session is still valid; retrieve the current SID.
sid = self.session.SessionID
session.commit()
self.session.LastUpdateTS = session_ts
request.cookies["AURSID"] = self.session.SessionID
return self.session.SessionID
......@@ -149,13 +147,11 @@ class User(Base):
return aurweb.auth.has_credential(self, cred, approved)
def logout(self, request):
from aurweb.db import session
del request.cookies["AURSID"]
self.authenticated = False
if self.session:
session.delete(self.session)
session.commit()
with db.begin():
db.session.delete(self.session)
def is_trusted_user(self):
return self.AccountType.ID in {
......
......@@ -43,8 +43,6 @@ async def passreset_post(request: Request,
resetkey: str = Form(default=None),
password: str = Form(default=None),
confirm: str = Form(default=None)):
from aurweb.db import session
context = await make_variable_context(request, "Password Reset")
# The user parameter being required, we can match against
......@@ -86,12 +84,11 @@ async def passreset_post(request: Request,
# We got to this point; everything matched up. Update the password
# and remove the ResetKey.
user.ResetKey = str()
user.update_password(password)
if user.session:
session.delete(user.session)
session.commit()
with db.begin():
user.ResetKey = str()
if user.session:
db.session.delete(user.session)
user.update_password(password)
# Render ?step=complete.
return RedirectResponse(url="/passreset?step=complete",
......@@ -99,8 +96,8 @@ async def passreset_post(request: Request,
# If we got here, we continue with issuing a resetkey for the user.
resetkey = db.make_random_value(User, User.ResetKey)
user.ResetKey = resetkey
session.commit()
with db.begin():
user.ResetKey = resetkey
executor = db.ConnectionExecutor(db.get_engine().raw_connection())
ResetKeyNotification(executor, user.ID).send()
......@@ -364,8 +361,6 @@ async def account_register_post(request: Request,
ON: bool = Form(default=False),
captcha: str = Form(default=None),
captcha_salt: str = Form(...)):
from aurweb.db import session
context = await make_variable_context(request, "Register")
args = dict(await request.form())
......@@ -394,11 +389,13 @@ async def account_register_post(request: Request,
AccountType.AccountType == "User").first()
# Create a user given all parameters available.
user = db.create(User, Username=U, Email=E, HideEmail=H, BackupEmail=BE,
RealName=R, Homepage=HP, IRCNick=I, PGPKey=K,
LangPreference=L, Timezone=TZ, CommentNotify=CN,
UpdateNotify=UN, OwnershipNotify=ON, ResetKey=resetkey,
AccountType=account_type)
with db.begin():
user = db.create(User, Username=U,
Email=E, HideEmail=H, BackupEmail=BE,
RealName=R, Homepage=HP, IRCNick=I, PGPKey=K,
LangPreference=L, Timezone=TZ, CommentNotify=CN,
UpdateNotify=UN, OwnershipNotify=ON,
ResetKey=resetkey, AccountType=account_type)
# If a PK was given and either one does not exist or the given
# PK mismatches the existing user's SSHPubKey.PubKey.
......@@ -410,10 +407,10 @@ async def account_register_post(request: Request,
# Remove the host part.
pubkey = parts[0] + " " + parts[1]
fingerprint = get_fingerprint(pubkey)
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
PubKey=pubkey,
Fingerprint=fingerprint)
session.commit()
with db.begin():
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
PubKey=pubkey,
Fingerprint=fingerprint)
# Send a reset key notification to the new user.
executor = db.ConnectionExecutor(db.get_engine().raw_connection())
......@@ -499,63 +496,67 @@ async def account_edit_post(request: Request,
status_code=int(HTTPStatus.BAD_REQUEST))
# Set all updated fields as needed.
user.Username = U or user.Username
user.Email = E or user.Email
user.HideEmail = bool(H)
user.BackupEmail = BE or user.BackupEmail
user.RealName = R or user.RealName
user.Homepage = HP or user.Homepage
user.IRCNick = I or user.IRCNick
user.PGPKey = K or user.PGPKey
user.InactivityTS = datetime.utcnow().timestamp() if J else 0
with db.begin():
user.Username = U or user.Username
user.Email = E or user.Email
user.HideEmail = bool(H)
user.BackupEmail = BE or user.BackupEmail
user.RealName = R or user.RealName
user.Homepage = HP or user.Homepage
user.IRCNick = I or user.IRCNick
user.PGPKey = K or user.PGPKey
user.InactivityTS = datetime.utcnow().timestamp() if J else 0
# If we update the language, update the cookie as well.
if L and L != user.LangPreference:
request.cookies["AURLANG"] = L
user.LangPreference = L
with db.begin():
user.LangPreference = L
context["language"] = L
# If we update the timezone, also update the cookie.
if TZ and TZ != user.Timezone:
user.Timezone = TZ
with db.begin():
user.Timezone = TZ
request.cookies["AURTZ"] = TZ
context["timezone"] = TZ
user.CommentNotify = bool(CN)
user.UpdateNotify = bool(UN)
user.OwnershipNotify = bool(ON)
with db.begin():
user.CommentNotify = bool(CN)
user.UpdateNotify = bool(UN)
user.OwnershipNotify = bool(ON)
# If a PK is given, compare it against the target user's PK.
if PK:
# Get the second token in the public key, which is the actual key.
pubkey = PK.strip().rstrip()
parts = pubkey.split(" ")
if len(parts) == 3:
# Remove the host part.
pubkey = parts[0] + " " + parts[1]
fingerprint = get_fingerprint(pubkey)
if not user.ssh_pub_key:
# No public key exists, create one.
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
PubKey=pubkey,
Fingerprint=fingerprint)
elif user.ssh_pub_key.PubKey != pubkey:
# A public key already exists, update it.
user.ssh_pub_key.PubKey = pubkey
user.ssh_pub_key.Fingerprint = fingerprint
elif user.ssh_pub_key:
# Else, if the user has a public key already, delete it.
session.delete(user.ssh_pub_key)
# Commit changes, if any.
session.commit()
with db.begin():
if PK:
# Get the second token in the public key, which is the actual key.
pubkey = PK.strip().rstrip()
parts = pubkey.split(" ")
if len(parts) == 3:
# Remove the host part.
pubkey = parts[0] + " " + parts[1]
fingerprint = get_fingerprint(pubkey)
if not user.ssh_pub_key:
# No public key exists, create one.
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
PubKey=pubkey,
Fingerprint=fingerprint)
elif user.ssh_pub_key.PubKey != pubkey:
# A public key already exists, update it.
user.ssh_pub_key.PubKey = pubkey
user.ssh_pub_key.Fingerprint = fingerprint
elif user.ssh_pub_key:
# Else, if the user has a public key already, delete it.
session.delete(user.ssh_pub_key)
if P and not user.valid_password(P):
# Remove the fields we consumed for passwords.
context["P"] = context["C"] = str()
# If a password was given and it doesn't match the user's, update it.
user.update_password(P)
with db.begin():
user.update_password(P)
if user == request.user:
# If the target user is the request user, login with
# the updated password and update AURSID.
......@@ -731,21 +732,17 @@ async def terms_of_service_post(request: Request,
accept_needed = sorted(unaccepted + diffs)
return render_terms_of_service(request, context, accept_needed)
# For each term we found, query for the matching accepted term
# and update its Revision to the term's current Revision.
for term in diffs:
accepted_term = request.user.accepted_terms.filter(
AcceptedTerm.TermsID == term.ID).first()
accepted_term.Revision = term.Revision
# For each term that was never accepted, accept it!
for term in unaccepted:
db.create(AcceptedTerm, User=request.user,
Term=term, Revision=term.Revision,
autocommit=False)
if diffs or unaccepted:
# If we had any terms to update, commit the changes.
db.commit()
with db.begin():
# For each term we found, query for the matching accepted term
# and update its Revision to the term's current Revision.
for term in diffs:
accepted_term = request.user.accepted_terms.filter(
AcceptedTerm.TermsID == term.ID).first()
accepted_term.Revision = term.Revision
# For each term that was never accepted, accept it!
for term in unaccepted:
db.create(AcceptedTerm, User=request.user,
Term=term, Revision=term.Revision)
return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER))
......@@ -44,8 +44,6 @@ async def language(request: Request,
setting the language on any page, we want to preserve query
parameters across the redirect.
"""
from aurweb.db import session
if next[0] != '/':
return HTMLResponse(b"Invalid 'next' parameter.", status_code=400)
......@@ -53,8 +51,8 @@ async def language(request: Request,
# If the user is authenticated, update the user's LangPreference.
if request.user.is_authenticated():
request.user.LangPreference = set_lang
session.commit()
with db.begin():
request.user.LangPreference = set_lang
# In any case, set the response's AURLANG cookie that never expires.
response = RedirectResponse(url=f"{next}{query_string}",
......
......@@ -214,10 +214,9 @@ async def trusted_user_proposal_post(request: Request,
return Response("Invalid 'decision' value.",
status_code=int(HTTPStatus.BAD_REQUEST))
vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo,
autocommit=False)
voteinfo.ActiveTUs += 1
db.commit()
with db.begin():
vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo)
voteinfo.ActiveTUs += 1
context["error"] = "You've already voted for this proposal."
return render_proposal(request, context, proposal, voteinfo, voters, vote)
......@@ -294,12 +293,13 @@ async def trusted_user_addvote_post(request: Request,
agenda = re.sub(r'<[/]?style.*>', '', agenda)
# Create a new TUVoteInfo (proposal)!
voteinfo = db.create(TUVoteInfo,
User=user,
Agenda=agenda,
Submitted=timestamp, End=timestamp + duration,
Quorum=quorum,
Submitter=request.user)
with db.begin():
voteinfo = db.create(TUVoteInfo,
User=user,
Agenda=agenda,
Submitted=timestamp, End=timestamp + duration,
Quorum=quorum,
Submitter=request.user)
# Redirect to the new proposal.
return RedirectResponse(f"/tu/{voteinfo.ID}",
......
import pytest
from aurweb.db import create, delete, query
from aurweb.db import begin, create, delete, query
from aurweb.models.account_type import AccountType
from aurweb.models.user import User
from aurweb.testing import setup_test_db
......@@ -14,11 +14,13 @@ def setup():
global account_type
account_type = create(AccountType, AccountType="TestUser")
with begin():
account_type = create(AccountType, AccountType="TestUser")
yield account_type
delete(AccountType, AccountType.ID == account_type.ID)
with begin():
delete(AccountType, AccountType.ID == account_type.ID)
def test_account_type():
......@@ -38,12 +40,14 @@ def test_account_type():
def test_user_account_type_relationship():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
with begin():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
assert user.AccountType == account_type
# This must be deleted here to avoid foreign key issues when
# deleting the temporary AccountType in the fixture.
delete(User, User.ID == user.ID)
with begin():
delete(User, User.ID == user.ID)
......@@ -11,9 +11,9 @@ import pytest
from fastapi.testclient import TestClient
from aurweb import captcha
from aurweb import captcha, db
from aurweb.asgi import app
from aurweb.db import commit, create, query
from aurweb.db import create, query
from aurweb.models.accepted_term import AcceptedTerm
from aurweb.models.account_type import DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID, TRUSTED_USER_ID, AccountType
from aurweb.models.ban import Ban
......@@ -57,9 +57,11 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL,
RealName="Test UserZ", Passwd="testPassword",
IRCNick="testZ", AccountType=account_type)
with db.begin():
user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL,
RealName="Test UserZ", Passwd="testPassword",
IRCNick="testZ", AccountType=account_type)
yield user
......@@ -70,9 +72,10 @@ def setup():
@pytest.fixture
def tu_user():
user.AccountType = query(AccountType,
AccountType.ID == TRUSTED_USER_AND_DEV_ID).first()
commit()
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
yield user
......@@ -149,11 +152,9 @@ def test_post_passreset_user():
def test_post_passreset_resetkey():
from aurweb.db import session
user.session = Session(UsersID=user.ID, SessionID="blah",
LastUpdateTS=datetime.utcnow().timestamp())
session.commit()
with db.begin():
user.session = Session(UsersID=user.ID, SessionID="blah",
LastUpdateTS=datetime.utcnow().timestamp())
# Prepare a password reset.
with client as request:
......@@ -357,7 +358,8 @@ def test_post_register_error_invalid_captcha():
def test_post_register_error_ip_banned():
# 'testclient' is used as request.client.host via FastAPI TestClient.
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow())
with db.begin():
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow())
with client as request:
response = post_register(request)
......@@ -576,7 +578,8 @@ def test_post_register_error_ssh_pubkey_taken():
# Take the sha256 fingerprint of the ssh public key, create it.
fp = get_fingerprint(pk)
create(SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fp)
with db.begin():
create(SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fp)
with client as request:
response = post_register(request, PK=pk)
......@@ -660,13 +663,11 @@ def test_post_account_edit():
def test_post_account_edit_dev():
from aurweb.db import session
# Modify our user to be a "Trusted User & Developer"
name = "Trusted User & Developer"
tu_or_dev = query(AccountType, AccountType.AccountType == name).first()
user.AccountType = tu_or_dev
session.commit()
with db.begin():
user.AccountType = tu_or_dev
request = Request()
sid = user.login(request, "testPassword")
......@@ -1001,21 +1002,19 @@ def get_rows(html):
def test_post_accounts(tu_user):
# Set a PGPKey.
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
with db.begin():
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
# Create a few more users.
users = [user]
for i in range(10):
_user = create(User, Username=f"test_{i}",
Email=f"test_{i}@example.org",
RealName=f"Test #{i}",
Passwd="testPassword",
IRCNick=f"test_#{i}",
autocommit=False)
users.append(_user)
# Commit everything to the database.
commit()
with db.begin():
for i in range(10):
_user = create(User, Username=f"test_{i}",
Email=f"test_{i}@example.org",
RealName=f"Test #{i}",
Passwd="testPassword",
IRCNick=f"test_#{i}")
users.append(_user)
sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid}
......@@ -1085,11 +1084,12 @@ def test_post_accounts_account_type(tu_user):
# test the `u` parameter.
account_type = query(AccountType,
AccountType.AccountType == "User").first()
create(User, Username="test_2",
Email="test_2@example.org",
RealName="Test User 2",
Passwd="testPassword",
AccountType=account_type)
with db.begin():
create(User, Username="test_2",
Email="test_2@examp