diff --git a/aurweb/db.py b/aurweb/db.py index 4c53730a5c5588f421a6a6f54e5fda90c46275a9..94514d35c2e97d368b5cce409d7843d3887eb3c4 100644 --- a/aurweb/db.py +++ b/aurweb/db.py @@ -1,34 +1,15 @@ -import functools -import hashlib -import math -import os -import re - -from typing import Iterable, NewType - -import sqlalchemy - -from sqlalchemy import create_engine, event -from sqlalchemy.engine.base import Engine -from sqlalchemy.engine.url import URL -from sqlalchemy.orm import Query, Session, SessionTransaction, scoped_session, sessionmaker - -import aurweb.config -import aurweb.util - +# Supported database drivers. DRIVERS = { "mysql": "mysql+mysqldb" } -# Some types we don't get access to in this module. -Base = NewType("Base", "aurweb.models.declarative_base.Base") - def make_random_value(table: str, column: str, length: int): """ Generate a unique, random value for a string column in a table. :return: A unique string that is not in the database """ + import aurweb.util string = aurweb.util.make_random_string(length) while query(table).filter(column == string).first(): string = aurweb.util.make_random_string(length) @@ -52,6 +33,10 @@ def test_name() -> str: :return: Unhashed database name """ + import os + + import aurweb.config + db = os.environ.get("PYTEST_CURRENT_TEST", aurweb.config.get("database", "name")) return db.split(":")[0] @@ -70,7 +55,10 @@ def name() -> str: dbname = test_name() if not dbname.startswith("test/"): return dbname + + import hashlib sha1 = hashlib.sha1(dbname.encode()).hexdigest() + return "db" + sha1 @@ -78,12 +66,13 @@ def name() -> str: _sessions = dict() -def get_session(engine: Engine = None) -> Session: +def get_session(engine=None): """ Return aurweb.db's global session. """ dbname = name() global _sessions if dbname not in _sessions: + from sqlalchemy.orm import scoped_session, sessionmaker if not engine: # pragma: no cover engine = get_engine() @@ -106,13 +95,17 @@ def pop_session(dbname: str) -> None: _sessions.pop(dbname) -def refresh(model: Base) -> Base: - """ Refresh the session's knowledge of `model`. """ +def refresh(model): + """ + Refresh the session's knowledge of `model`. + + :returns: Passed in `model` + """ get_session().refresh(model) return model -def query(Model: Base, *args, **kwargs) -> Query: +def query(Model, *args, **kwargs): """ Perform an ORM query against the database session. @@ -124,7 +117,7 @@ def query(Model: Base, *args, **kwargs) -> Query: return get_session().query(Model).filter(*args, **kwargs) -def create(Model: Base, *args, **kwargs) -> Base: +def create(Model, *args, **kwargs): """ Create a record and add() it to the database session. @@ -135,7 +128,7 @@ def create(Model: Base, *args, **kwargs) -> Base: return add(instance) -def delete(model: Base) -> None: +def delete(model) -> None: """ Delete a set of records found by Query.filter(*args, **kwargs). @@ -144,8 +137,9 @@ def delete(model: Base) -> None: get_session().delete(model) -def delete_all(iterable: Iterable) -> None: +def delete_all(iterable) -> None: """ Delete each instance found in `iterable`. """ + import aurweb.util session_ = get_session() aurweb.util.apply_all(iterable, session_.delete) @@ -155,23 +149,29 @@ def rollback() -> None: get_session().rollback() -def add(model: Base) -> Base: +def add(model): """ Add `model` to the database session. """ get_session().add(model) return model -def begin() -> SessionTransaction: +def begin(): """ Begin an SQLAlchemy SessionTransaction. """ return get_session().begin() -def get_sqlalchemy_url() -> URL: +def get_sqlalchemy_url(): """ Build an SQLAlchemy URL for use with create_engine. :return: sqlalchemy.engine.url.URL """ + import sqlalchemy + + from sqlalchemy.engine.url import URL + + import aurweb.config + constructor = URL parts = sqlalchemy.__version__.split('.') @@ -209,13 +209,17 @@ def get_sqlalchemy_url() -> URL: def sqlite_regexp(regex, item) -> bool: # pragma: no cover """ Method which mimics SQL's REGEXP for SQLite. """ + import re return bool(re.search(regex, str(item))) -def setup_sqlite(engine: Engine) -> None: # pragma: no cover +def setup_sqlite(engine) -> None: # pragma: no cover """ Perform setup for an SQLite engine. """ + from sqlalchemy import event + @event.listens_for(engine, "connect") def do_begin(conn, record): + import functools create_deterministic_function = functools.partial( conn.create_function, deterministic=True @@ -227,7 +231,7 @@ def setup_sqlite(engine: Engine) -> None: # pragma: no cover _engines = dict() -def get_engine(dbname: str = None, echo: bool = False) -> Engine: +def get_engine(dbname: str = None, echo: bool = False): """ Return the SQLAlchemy engine for `dbname`. @@ -238,6 +242,8 @@ def get_engine(dbname: str = None, echo: bool = False) -> Engine: :param echo: Flag passed through to sqlalchemy.create_engine :return: SQLAlchemy Engine instance """ + import aurweb.config + if not dbname: dbname = name() @@ -254,6 +260,7 @@ def get_engine(dbname: str = None, echo: bool = False) -> Engine: "echo": echo, "connect_args": connect_args } + from sqlalchemy import create_engine _engines[dbname] = create_engine(get_sqlalchemy_url(), **kwargs) if is_sqlite: # pragma: no cover @@ -301,7 +308,10 @@ class ConnectionExecutor: _conn = None _paramstyle = None - def __init__(self, conn, backend=aurweb.config.get("database", "backend")): + def __init__(self, conn, backend=None): + import aurweb.config + + backend = backend or aurweb.config.get("database", "backend") self._conn = conn if backend == "mysql": self._paramstyle = "format" @@ -339,6 +349,7 @@ class Connection: _conn = None def __init__(self): + import aurweb.config aur_db_backend = aurweb.config.get('database', 'backend') if aur_db_backend == 'mysql': @@ -357,7 +368,9 @@ class Connection: elif aur_db_backend == 'sqlite': # pragma: no cover # TODO: SQLite support has been removed in FastAPI. It remains # here to fund its support for PHP until it is removed. + import math import sqlite3 + aur_db_name = aurweb.config.get('database', 'name') self._conn = sqlite3.connect(aur_db_name) self._conn.create_function("POWER", 2, math.pow)