db.py 3.82 KB
Newer Older
Kevin Morris's avatar
Kevin Morris committed
1
2
import math

3
4
5
6
7
8
9
10
11
try:
    import mysql.connector
except ImportError:
    pass

try:
    import sqlite3
except ImportError:
    pass
12

13
import aurweb.config
14

15
16
# Global SQLAlchemy engine; stays None until get_engine() creates it on first use.
engine = None  # See get_engine

17

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def get_sqlalchemy_url():
    """
    Build an SQLAlchemy URL for use with create_engine based on the aurweb
    configuration.

    :return: sqlalchemy.engine.url.URL for the configured [database] backend
    :raises ValueError: if [database] backend is neither 'mysql' nor 'sqlite'
    """
    import sqlalchemy

    # SQLAlchemy 1.4 deprecated calling the URL constructor directly in
    # favour of URL.create() (and 2.0 requires it); older releases such as
    # 1.3 only have the constructor, so fall back to it there.
    url_class = sqlalchemy.engine.url.URL
    make_url = getattr(url_class, 'create', url_class)

    aur_db_backend = aurweb.config.get('database', 'backend')
    if aur_db_backend == 'mysql':
        return make_url(
            'mysql+mysqlconnector',
            username=aurweb.config.get('database', 'user'),
            password=aurweb.config.get('database', 'password'),
            host=aurweb.config.get('database', 'host'),
            database=aurweb.config.get('database', 'name'),
            query={
                'unix_socket': aurweb.config.get('database', 'socket'),
                # URL.create() only accepts string query values; the
                # mysqlconnector dialect coerces this back to a boolean.
                'buffered': 'true',
            },
        )
    elif aur_db_backend == 'sqlite':
        return make_url(
            'sqlite',
            database=aurweb.config.get('database', 'name'),
        )
    else:
        raise ValueError('unsupported database backend')


45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def get_engine():
    """
    Return the global SQLAlchemy engine.

    The engine is created on the first call to get_engine and then stored in the
    `engine` global variable for the next calls.

    :raises ValueError: (from get_sqlalchemy_url) if the configured database
        backend is unsupported
    """
    from sqlalchemy import create_engine
    global engine
    if engine is None:
        connect_args = {}
        if aurweb.config.get('database', 'backend') == 'sqlite':
            # check_same_thread is for a SQLite technicality
            # https://fastapi.tiangolo.com/tutorial/sql-databases/#note
            # It is an sqlite3-only argument: mysql.connector raises on
            # unknown connection arguments, so only pass it for SQLite.
            connect_args['check_same_thread'] = False
        engine = create_engine(get_sqlalchemy_url(),
                               connect_args=connect_args)
    return engine


def connect():
    """
    Hand out a connection from the global SQLAlchemy engine.

    The engine normally pools its connections; see
    <https://docs.sqlalchemy.org/en/13/core/connections.html>.

    The returned connection is also a context manager. Use it in a `with`
    block, or obtain it through FastAPI's dependency injection, so that it
    gets closed when you are done with it.
    """
    db_engine = get_engine()
    return db_engine.connect()


73
74
class Connection:
    """
    Thin wrapper around a raw DB-API connection (mysql.connector or sqlite3).

    Queries are written with qmark (`?`) placeholders. execute() rewrites
    them for drivers that use the format/pyformat paramstyle.
    """

    _conn = None  # underlying DB-API connection object
    _paramstyle = None  # DB-API paramstyle string of the active driver

    def __init__(self):
        """
        Open a connection using the [database] section of the aurweb config.

        :raises ValueError: if [database] backend is neither 'mysql' nor
            'sqlite'
        """
        aur_db_backend = aurweb.config.get('database', 'backend')

        if aur_db_backend == 'mysql':
            aur_db_host = aurweb.config.get('database', 'host')
            aur_db_name = aurweb.config.get('database', 'name')
            aur_db_user = aurweb.config.get('database', 'user')
            aur_db_pass = aurweb.config.get('database', 'password')
            aur_db_socket = aurweb.config.get('database', 'socket')
            self._conn = mysql.connector.connect(host=aur_db_host,
                                                 user=aur_db_user,
                                                 passwd=aur_db_pass,
                                                 db=aur_db_name,
                                                 unix_socket=aur_db_socket,
                                                 buffered=True)
            self._paramstyle = mysql.connector.paramstyle
        elif aur_db_backend == 'sqlite':
            aur_db_name = aurweb.config.get('database', 'name')
            self._conn = sqlite3.connect(aur_db_name)
            # SQLite may not provide POWER() built in; register one so
            # queries shared with the MySQL backend keep working.
            self._conn.create_function("POWER", 2, self._sqlite_power)
            self._paramstyle = sqlite3.paramstyle
        else:
            raise ValueError('unsupported database backend')

    @staticmethod
    def _sqlite_power(base, exponent):
        """
        SQL POWER(base, exponent) implementation registered with SQLite.

        Returns None (SQL NULL) when either argument is NULL. Passing None to
        math.pow directly raises TypeError, which sqlite3 reports as a failed
        query. Propagating NULL matches standard SQL and MySQL's POWER().
        """
        if base is None or exponent is None:
            return None
        return math.pow(base, exponent)

    def execute(self, query, params=()):
        """
        Execute `query` with `params` and return the cursor it ran on.

        `query` uses `?` placeholders. For format/pyformat drivers, literal
        `%` characters are doubled so they are not read as format directives,
        and every `?` becomes `%s`. This also rewrites `?` characters inside
        string literals.

        :raises ValueError: if the driver's paramstyle is not supported
        """
        if self._paramstyle in ('format', 'pyformat'):
            query = query.replace('%', '%%').replace('?', '%s')
        elif self._paramstyle == 'qmark':
            pass
        else:
            raise ValueError('unsupported paramstyle')

        cur = self._conn.cursor()
        cur.execute(query, params)

        return cur

    def commit(self):
        """Commit the current transaction on the underlying connection."""
        self._conn.commit()

    def close(self):
        """Close the underlying connection."""
        self._conn.close()