Skip to content
Snippets Groups Projects
Commit d18cfad6 authored by Kevin Morris's avatar Kevin Morris
Browse files

use djangos method of wiping sqlite3 tables


Django uses a reference graph to determine the order
in table deletions that occur. Do the same here.

This commit also adds in the `REGEXP` sqlite function,
exactly how Django uses it in its reference graphing.

Signed-off-by: Kevin Morris's avatarKevin Morris <kevr@0cost.org>
parent 5de7ff64
No related branches found
No related tags found
1 merge request!72use djangos method of wiping sqlite3 tables
Pipeline #8347 passed
import functools
import math
import re
from sqlalchemy import event
import aurweb.config
import aurweb.util
......@@ -129,13 +133,31 @@ def get_engine(echo: bool = False):
if engine is None:
connect_args = dict()
if aurweb.config.get("database", "backend") == "sqlite":
db_backend = aurweb.config.get("database", "backend")
if db_backend == "sqlite":
# check_same_thread is for a SQLite technicality
# https://fastapi.tiangolo.com/tutorial/sql-databases/#note
connect_args["check_same_thread"] = False
engine = create_engine(get_sqlalchemy_url(),
connect_args=connect_args,
echo=echo)
if db_backend == "sqlite":
# For SQLite, we need to add some custom functions as
# they are used in the reference graph method.
def regexp(regex, item):
return bool(re.search(regex, str(item)))
@event.listens_for(engine, "begin")
def do_begin(conn):
create_deterministic_function = functools.partial(
conn.connection.create_function,
deterministic=True
)
create_deterministic_function("REGEXP", 2, regexp)
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
......
from itertools import chain
import aurweb.db
def references_graph(table):
""" Taken from Django's sqlite3/operations.py. """
query = """
WITH tables AS (
SELECT :table name
UNION
SELECT sqlite_master.name
FROM sqlite_master
JOIN tables ON (sql REGEXP :regexp_1 || tables.name || :regexp_2)
) SELECT name FROM tables;
"""
params = {
"table": table,
"regexp_1": r'(?i)\s+references\s+("|\')?',
"regexp_2": r'("|\')?\s*\(',
}
cursor = aurweb.db.session.execute(query, params=params)
return [row[0] for row in cursor.fetchall()]
def setup_test_db(*args):
""" This function is to be used to setup a test database before
using it. It takes a variable number of table strings, and for
......@@ -25,8 +47,22 @@ def setup_test_db(*args):
aurweb.db.get_engine()
tables = list(args)
db_backend = aurweb.config.get("database", "backend")
if db_backend != "sqlite":
aurweb.db.session.execute("SET FOREIGN_KEY_CHECKS = 0")
else:
# We're using sqlite, setup tables to be deleted without violating
# foreign key constraints by graphing references.
tables = set(chain.from_iterable(
references_graph(table) for table in tables))
for table in tables:
aurweb.db.session.execute(f"DELETE FROM {table}")
if db_backend != "sqlite":
aurweb.db.session.execute("SET FOREIGN_KEY_CHECKS = 1")
# Expunge all objects from SQLAlchemy's IdentityMap.
aurweb.db.session.expunge_all()
......@@ -200,6 +200,9 @@ def test_connection_execute_paramstyle_format():
aurweb.db.kill_engine()
aurweb.initdb.run(Args())
# Test SQLite route of clearing tables.
setup_test_db("Users", "Bans")
conn = db.Connection()
# First, test ? to %s format replacement.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment