mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
In my opinion, this kind of handling of transactions is pretty ugly. The being said, we have issues with running into deadlocks on aur.al, so this commit works against that immediate bug. An ideal solution would be to deal with retrying transactions through the `db.begin()` scope, so we wouldn't have to explicitly annotate functions as "retry functions," which is what this commit does. Closes #376 Signed-off-by: Kevin Morris <kevr@0cost.org>
432 lines
12 KiB
Python
432 lines
12 KiB
Python
# Supported database drivers.
|
||
DRIVERS = {"mysql": "mysql+mysqldb"}
|
||
|
||
|
||
def make_random_value(table: str, column: str, length: int):
|
||
"""Generate a unique, random value for a string column in a table.
|
||
|
||
:return: A unique string that is not in the database
|
||
"""
|
||
import aurweb.util
|
||
|
||
string = aurweb.util.make_random_string(length)
|
||
while query(table).filter(column == string).first():
|
||
string = aurweb.util.make_random_string(length)
|
||
return string
|
||
|
||
|
||
def test_name() -> str:
|
||
"""
|
||
Return the unhashed database name.
|
||
|
||
The unhashed database name is determined (lower = higher priority) by:
|
||
-------------------------------------------
|
||
1. {test_suite} portion of PYTEST_CURRENT_TEST
|
||
2. aurweb.config.get("database", "name")
|
||
|
||
During `pytest` runs, the PYTEST_CURRENT_TEST environment variable
|
||
is set to the current test in the format `{test_suite}::{test_func}`.
|
||
|
||
This allows tests to use a suite-specific database for its runs,
|
||
which decouples database state from test suites.
|
||
|
||
:return: Unhashed database name
|
||
"""
|
||
import os
|
||
|
||
import aurweb.config
|
||
|
||
db = os.environ.get("PYTEST_CURRENT_TEST", aurweb.config.get("database", "name"))
|
||
return db.split(":")[0]
|
||
|
||
|
||
def name() -> str:
|
||
"""
|
||
Return sanitized database name that can be used for tests or production.
|
||
|
||
If test_name() starts with "test/", the database name is SHA-1 hashed,
|
||
prefixed with 'db', and returned. Otherwise, test_name() is passed
|
||
through and not hashed at all.
|
||
|
||
:return: SHA1-hashed database name prefixed with 'db'
|
||
"""
|
||
dbname = test_name()
|
||
if not dbname.startswith("test/"):
|
||
return dbname
|
||
|
||
import hashlib
|
||
|
||
sha1 = hashlib.sha1(dbname.encode()).hexdigest()
|
||
|
||
return "db" + sha1
|
||
|
||
|
||
# Module-private global memo used to store SQLAlchemy sessions.
|
||
_sessions = dict()
|
||
|
||
|
||
def get_session(engine=None):
|
||
"""Return aurweb.db's global session."""
|
||
dbname = name()
|
||
|
||
global _sessions
|
||
if dbname not in _sessions:
|
||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||
|
||
if not engine: # pragma: no cover
|
||
engine = get_engine()
|
||
|
||
Session = scoped_session(
|
||
sessionmaker(autocommit=True, autoflush=False, bind=engine)
|
||
)
|
||
_sessions[dbname] = Session()
|
||
|
||
return _sessions.get(dbname)
|
||
|
||
|
||
def pop_session(dbname: str) -> None:
|
||
"""
|
||
Pop a Session out of the private _sessions memo.
|
||
|
||
:param dbname: Database name
|
||
:raises KeyError: When `dbname` does not exist in the memo
|
||
"""
|
||
global _sessions
|
||
_sessions.pop(dbname)
|
||
|
||
|
||
def refresh(model):
|
||
"""
|
||
Refresh the session's knowledge of `model`.
|
||
|
||
:returns: Passed in `model`
|
||
"""
|
||
get_session().refresh(model)
|
||
return model
|
||
|
||
|
||
def query(Model, *args, **kwargs):
|
||
"""
|
||
Perform an ORM query against the database session.
|
||
|
||
This method also runs Query.filter on the resulting model
|
||
query with *args and **kwargs.
|
||
|
||
:param Model: Declarative ORM class
|
||
"""
|
||
return get_session().query(Model).filter(*args, **kwargs)
|
||
|
||
|
||
def create(Model, *args, **kwargs):
|
||
"""
|
||
Create a record and add() it to the database session.
|
||
|
||
:param Model: Declarative ORM class
|
||
:return: Model instance
|
||
"""
|
||
instance = Model(*args, **kwargs)
|
||
return add(instance)
|
||
|
||
|
||
def delete(model) -> None:
|
||
"""
|
||
Delete a set of records found by Query.filter(*args, **kwargs).
|
||
|
||
:param Model: Declarative ORM class
|
||
"""
|
||
get_session().delete(model)
|
||
|
||
|
||
def delete_all(iterable) -> None:
|
||
"""Delete each instance found in `iterable`."""
|
||
import aurweb.util
|
||
|
||
session_ = get_session()
|
||
aurweb.util.apply_all(iterable, session_.delete)
|
||
|
||
|
||
def rollback() -> None:
|
||
"""Rollback the database session."""
|
||
get_session().rollback()
|
||
|
||
|
||
def add(model):
|
||
"""Add `model` to the database session."""
|
||
get_session().add(model)
|
||
return model
|
||
|
||
|
||
def begin():
|
||
"""Begin an SQLAlchemy SessionTransaction."""
|
||
return get_session().begin()
|
||
|
||
|
||
def retry_deadlock(func):
|
||
from sqlalchemy.exc import OperationalError
|
||
|
||
def wrapper(*args, _i: int = 0, **kwargs):
|
||
# Retry 10 times, then raise the exception
|
||
# If we fail before the 10th, recurse into `wrapper`
|
||
# If we fail on the 10th, continue to throw the exception
|
||
limit = 10
|
||
try:
|
||
return func(*args, **kwargs)
|
||
except OperationalError as exc:
|
||
if _i < limit and "Deadlock found" in str(exc):
|
||
# Retry on deadlock by recursing into `wrapper`
|
||
return wrapper(*args, _i=_i + 1, **kwargs)
|
||
# Otherwise, just raise the exception
|
||
raise exc
|
||
|
||
return wrapper
|
||
|
||
|
||
def async_retry_deadlock(func):
|
||
from sqlalchemy.exc import OperationalError
|
||
|
||
async def wrapper(*args, _i: int = 0, **kwargs):
|
||
# Retry 10 times, then raise the exception
|
||
# If we fail before the 10th, recurse into `wrapper`
|
||
# If we fail on the 10th, continue to throw the exception
|
||
limit = 10
|
||
try:
|
||
return await func(*args, **kwargs)
|
||
except OperationalError as exc:
|
||
if _i < limit and "Deadlock found" in str(exc):
|
||
# Retry on deadlock by recursing into `wrapper`
|
||
return await wrapper(*args, _i=_i + 1, **kwargs)
|
||
# Otherwise, just raise the exception
|
||
raise exc
|
||
|
||
return wrapper
|
||
|
||
|
||
def get_sqlalchemy_url():
|
||
"""
|
||
Build an SQLAlchemy URL for use with create_engine.
|
||
|
||
:return: sqlalchemy.engine.url.URL
|
||
"""
|
||
import sqlalchemy
|
||
from sqlalchemy.engine.url import URL
|
||
|
||
import aurweb.config
|
||
|
||
constructor = URL
|
||
|
||
parts = sqlalchemy.__version__.split(".")
|
||
major = int(parts[0])
|
||
minor = int(parts[1])
|
||
if major == 1 and minor >= 4: # pragma: no cover
|
||
constructor = URL.create
|
||
|
||
aur_db_backend = aurweb.config.get("database", "backend")
|
||
if aur_db_backend == "mysql":
|
||
param_query = {}
|
||
port = aurweb.config.get_with_fallback("database", "port", None)
|
||
if not port:
|
||
param_query["unix_socket"] = aurweb.config.get("database", "socket")
|
||
|
||
return constructor(
|
||
DRIVERS.get(aur_db_backend),
|
||
username=aurweb.config.get("database", "user"),
|
||
password=aurweb.config.get_with_fallback(
|
||
"database", "password", fallback=None
|
||
),
|
||
host=aurweb.config.get("database", "host"),
|
||
database=name(),
|
||
port=port,
|
||
query=param_query,
|
||
)
|
||
elif aur_db_backend == "sqlite":
|
||
return constructor(
|
||
"sqlite",
|
||
database=aurweb.config.get("database", "name"),
|
||
)
|
||
else:
|
||
raise ValueError("unsupported database backend")
|
||
|
||
|
||
def sqlite_regexp(regex, item) -> bool: # pragma: no cover
|
||
"""Method which mimics SQL's REGEXP for SQLite."""
|
||
import re
|
||
|
||
return bool(re.search(regex, str(item)))
|
||
|
||
|
||
def setup_sqlite(engine) -> None: # pragma: no cover
|
||
"""Perform setup for an SQLite engine."""
|
||
from sqlalchemy import event
|
||
|
||
@event.listens_for(engine, "connect")
|
||
def do_begin(conn, record):
|
||
import functools
|
||
|
||
create_deterministic_function = functools.partial(
|
||
conn.create_function, deterministic=True
|
||
)
|
||
create_deterministic_function("REGEXP", 2, sqlite_regexp)
|
||
|
||
|
||
# Module-private global memo used to store SQLAlchemy engines.
|
||
_engines = dict()
|
||
|
||
|
||
def get_engine(dbname: str = None, echo: bool = False):
|
||
"""
|
||
Return the SQLAlchemy engine for `dbname`.
|
||
|
||
The engine is created on the first call to get_engine and then stored in the
|
||
`engine` global variable for the next calls.
|
||
|
||
:param dbname: Database name (default: aurweb.db.name())
|
||
:param echo: Flag passed through to sqlalchemy.create_engine
|
||
:return: SQLAlchemy Engine instance
|
||
"""
|
||
import aurweb.config
|
||
|
||
if not dbname:
|
||
dbname = name()
|
||
|
||
global _engines
|
||
if dbname not in _engines:
|
||
db_backend = aurweb.config.get("database", "backend")
|
||
connect_args = dict()
|
||
|
||
is_sqlite = bool(db_backend == "sqlite")
|
||
if is_sqlite: # pragma: no cover
|
||
connect_args["check_same_thread"] = False
|
||
|
||
kwargs = {"echo": echo, "connect_args": connect_args}
|
||
from sqlalchemy import create_engine
|
||
|
||
_engines[dbname] = create_engine(get_sqlalchemy_url(), **kwargs)
|
||
|
||
if is_sqlite: # pragma: no cover
|
||
setup_sqlite(_engines.get(dbname))
|
||
|
||
return _engines.get(dbname)
|
||
|
||
|
||
def pop_engine(dbname: str) -> None:
|
||
"""
|
||
Pop an Engine out of the private _engines memo.
|
||
|
||
:param dbname: Database name
|
||
:raises KeyError: When `dbname` does not exist in the memo
|
||
"""
|
||
global _engines
|
||
_engines.pop(dbname)
|
||
|
||
|
||
def kill_engine() -> None:
|
||
"""Close the current session and dispose of the engine."""
|
||
dbname = name()
|
||
|
||
session = get_session()
|
||
session.close()
|
||
pop_session(dbname)
|
||
|
||
engine = get_engine()
|
||
engine.dispose()
|
||
pop_engine(dbname)
|
||
|
||
|
||
def connect():
|
||
"""
|
||
Return an SQLAlchemy connection. Connections are usually pooled. See
|
||
<https://docs.sqlalchemy.org/en/13/core/connections.html>.
|
||
|
||
Since SQLAlchemy connections are context managers too, you should use it
|
||
with Python’s `with` operator, or with FastAPI’s dependency injection.
|
||
"""
|
||
return get_engine().connect()
|
||
|
||
|
||
class ConnectionExecutor:
|
||
_conn = None
|
||
_paramstyle = None
|
||
|
||
def __init__(self, conn, backend=None):
|
||
import aurweb.config
|
||
|
||
backend = backend or aurweb.config.get("database", "backend")
|
||
self._conn = conn
|
||
if backend == "mysql":
|
||
self._paramstyle = "format"
|
||
elif backend == "sqlite":
|
||
import sqlite3
|
||
|
||
self._paramstyle = sqlite3.paramstyle
|
||
|
||
def paramstyle(self):
|
||
return self._paramstyle
|
||
|
||
def execute(self, query, params=()): # pragma: no cover
|
||
# TODO: SQLite support has been removed in FastAPI. It remains
|
||
# here to fund its support for PHP until it is removed.
|
||
if self._paramstyle in ("format", "pyformat"):
|
||
query = query.replace("%", "%%").replace("?", "%s")
|
||
elif self._paramstyle == "qmark":
|
||
pass
|
||
else:
|
||
raise ValueError("unsupported paramstyle")
|
||
|
||
cur = self._conn.cursor()
|
||
cur.execute(query, params)
|
||
|
||
return cur
|
||
|
||
def commit(self):
|
||
self._conn.commit()
|
||
|
||
def close(self):
|
||
self._conn.close()
|
||
|
||
|
||
class Connection:
|
||
_executor = None
|
||
_conn = None
|
||
|
||
def __init__(self):
|
||
import aurweb.config
|
||
|
||
aur_db_backend = aurweb.config.get("database", "backend")
|
||
|
||
if aur_db_backend == "mysql":
|
||
import MySQLdb
|
||
|
||
aur_db_host = aurweb.config.get("database", "host")
|
||
aur_db_name = name()
|
||
aur_db_user = aurweb.config.get("database", "user")
|
||
aur_db_pass = aurweb.config.get_with_fallback("database", "password", str())
|
||
aur_db_socket = aurweb.config.get("database", "socket")
|
||
self._conn = MySQLdb.connect(
|
||
host=aur_db_host,
|
||
user=aur_db_user,
|
||
passwd=aur_db_pass,
|
||
db=aur_db_name,
|
||
unix_socket=aur_db_socket,
|
||
)
|
||
elif aur_db_backend == "sqlite": # pragma: no cover
|
||
# TODO: SQLite support has been removed in FastAPI. It remains
|
||
# here to fund its support for PHP until it is removed.
|
||
import math
|
||
import sqlite3
|
||
|
||
aur_db_name = aurweb.config.get("database", "name")
|
||
self._conn = sqlite3.connect(aur_db_name)
|
||
self._conn.create_function("POWER", 2, math.pow)
|
||
else:
|
||
raise ValueError("unsupported database backend")
|
||
|
||
self._conn = ConnectionExecutor(self._conn, aur_db_backend)
|
||
|
||
def execute(self, query, params=()):
|
||
return self._conn.execute(query, params)
|
||
|
||
def commit(self):
|
||
self._conn.commit()
|
||
|
||
def close(self):
|
||
self._conn.close()
|