db.py 3.71 KB
Newer Older
1
2
3
4
5
6
7
8
9
try:
    import mysql.connector
except ImportError:
    pass

try:
    import sqlite3
except ImportError:
    pass
10

11
import aurweb.config
12

13
14
engine = None  # Global SQLAlchemy engine; created lazily and cached by get_engine().

15

16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def get_sqlalchemy_url():
    """
    Build an SQLAlchemy URL for use with create_engine based on the aurweb
    configuration.

    Reads the [database] section: `backend` selects the dialect. For 'mysql'
    the user, password, host, name and socket options are used. For 'sqlite'
    only `name`, the database file, is used.

    :returns: a ``sqlalchemy.engine.url.URL`` for the configured backend.
    :raises ValueError: if ``database.backend`` is neither 'mysql' nor 'sqlite'.
    """
    import sqlalchemy

    url_cls = sqlalchemy.engine.url.URL
    # SQLAlchemy 1.4 added URL.create() and deprecated calling URL() directly.
    # SQLAlchemy 2.0 no longer accepts the old keyword-style constructor call.
    # Releases older than 1.4 have no URL.create(), so fall back to the
    # constructor there.
    make_url = getattr(url_cls, 'create', url_cls)

    aur_db_backend = aurweb.config.get('database', 'backend')
    if aur_db_backend == 'mysql':
        return make_url(
            'mysql+mysqlconnector',
            username=aurweb.config.get('database', 'user'),
            password=aurweb.config.get('database', 'password'),
            host=aurweb.config.get('database', 'host'),
            database=aurweb.config.get('database', 'name'),
            query={
                'unix_socket': aurweb.config.get('database', 'socket'),
            },
        )
    elif aur_db_backend == 'sqlite':
        return make_url(
            'sqlite',
            database=aurweb.config.get('database', 'name'),
        )
    else:
        raise ValueError('unsupported database backend')


42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def get_engine():
    """
    Return the global SQLAlchemy engine.

    The engine is created on the first call to get_engine and then stored in the
    `engine` global variable for the next calls.
    """
    from sqlalchemy import create_engine
    global engine
    if engine is None:
        connect_args = {}
        # check_same_thread is a SQLite-only driver option, needed because the
        # pooled connection may be used from a thread other than its creator:
        # https://fastapi.tiangolo.com/tutorial/sql-databases/#note
        # It must not be sent to other backends. For MySQL it would be passed
        # to mysql.connector.connect, which does not accept it.
        if aurweb.config.get('database', 'backend') == 'sqlite':
            connect_args["check_same_thread"] = False
        engine = create_engine(get_sqlalchemy_url(),
                               connect_args=connect_args)
    return engine


def connect():
    """
    Open and return a new SQLAlchemy connection from the global engine.

    Connections are usually drawn from the engine's pool; see
    <https://docs.sqlalchemy.org/en/13/core/connections.html>.

    SQLAlchemy connections are also context managers. Use the returned object
    with Python's `with` statement, or through FastAPI's dependency injection.
    """
    shared_engine = get_engine()
    return shared_engine.connect()


70
71
class Connection:
    """
    Thin wrapper around a raw DB-API connection (MySQL or SQLite), chosen by
    the `backend` option of the [database] aurweb configuration section.

    Queries are written with `?` placeholders (qmark style). They are
    rewritten for drivers whose paramstyle is 'format' or 'pyformat'.
    """

    # Underlying DB-API connection object; assigned in __init__.
    _conn = None
    # DB-API paramstyle reported by the driver module backing _conn.
    _paramstyle = None

    def __init__(self):
        """
        Open a connection using the settings in the [database] section.

        :raises ValueError: if the configured backend is neither 'mysql'
            nor 'sqlite'.
        """
        aur_db_backend = aurweb.config.get('database', 'backend')

        if aur_db_backend == 'mysql':
            aur_db_host = aurweb.config.get('database', 'host')
            aur_db_name = aurweb.config.get('database', 'name')
            aur_db_user = aurweb.config.get('database', 'user')
            aur_db_pass = aurweb.config.get('database', 'password')
            aur_db_socket = aurweb.config.get('database', 'socket')
            self._conn = mysql.connector.connect(host=aur_db_host,
                                                 user=aur_db_user,
                                                 passwd=aur_db_pass,
                                                 db=aur_db_name,
                                                 unix_socket=aur_db_socket,
                                                 buffered=True)
            self._paramstyle = mysql.connector.paramstyle
        elif aur_db_backend == 'sqlite':
            aur_db_name = aurweb.config.get('database', 'name')
            self._conn = sqlite3.connect(aur_db_name)
            self._paramstyle = sqlite3.paramstyle
        else:
            raise ValueError('unsupported database backend')

    def execute(self, query, params=()):
        """
        Execute `query` with `params` on a new cursor and return that cursor.

        :param query: SQL text using `?` placeholders.
        :param params: sequence of values bound to the placeholders.
        :returns: the DB-API cursor the query was executed on.
        :raises ValueError: if the driver's paramstyle is not supported.
        """
        if self._paramstyle in ('format', 'pyformat'):
            # Escape literal '%' first so the driver does not read it as a
            # placeholder, then turn every '?' into a '%s' placeholder.
            # NOTE(review): this also rewrites '?' characters that appear
            # inside SQL string literals.
            query = query.replace('%', '%%').replace('?', '%s')
        elif self._paramstyle == 'qmark':
            pass
        else:
            raise ValueError('unsupported paramstyle')

        cur = self._conn.cursor()
        cur.execute(query, params)

        return cur

    def commit(self):
        """Commit the current transaction on the underlying connection."""
        self._conn.commit()

    def close(self):
        """Close the underlying connection."""
        self._conn.close()