aurweb/aurweb/db.py
moson-mo 8ca63075e9
housekeep: remove PHP implementation
removal of the PHP codebase

Signed-off-by: moson-mo <mo-son@mailbox.org>
2023-04-28 16:10:32 +02:00

432 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Maps each supported aurweb database backend name to the SQLAlchemy
# driver name used for it when building connection URLs.
DRIVERS = {
    "mysql": "mysql+mysqldb",
}
def make_random_value(table: str, column: str, length: int):
    """Generate a random string that is not yet present in a table column.

    Candidates are produced by aurweb.util.make_random_string and checked
    against the database until one is found that matches no existing row.

    NOTE(review): `table` and `column` are annotated as `str`, but `table`
    is handed to query() (which documents an ORM class) and
    `column == candidate` is used as a filter criterion, so callers
    presumably pass an ORM model and one of its columns -- confirm.

    :param table: Model queried for existing values
    :param column: Column compared against each candidate value
    :param length: Length passed to aurweb.util.make_random_string
    :return: A unique string that is not in the database
    """
    import aurweb.util

    while True:
        candidate = aurweb.util.make_random_string(length)
        # Keep drawing until no record already carries this value.
        if not query(table).filter(column == candidate).first():
            return candidate
def test_name() -> str:
    """
    Return the unhashed database name.

    The value is taken from the first available source:

    1. The PYTEST_CURRENT_TEST environment variable
    2. aurweb.config.get("database", "name")

    During `pytest` runs, PYTEST_CURRENT_TEST is set to the current test in
    the format `{test_suite}::{test_func}`. Only the part before the first
    ":" is kept, so every test in a suite shares one suite-specific
    database, which decouples database state between test suites.

    :return: Unhashed database name
    """
    import os

    import aurweb.config

    # The configured name is read unconditionally; it only serves as the
    # fallback when PYTEST_CURRENT_TEST is not set.
    configured = aurweb.config.get("database", "name")
    current = os.environ.get("PYTEST_CURRENT_TEST", configured)

    # Keep everything before the first ":" (the test suite portion).
    suite, _, _ = current.partition(":")
    return suite
def name() -> str:
    """
    Return a sanitized database name usable for tests or production.

    When test_name() starts with "test/", the name is SHA-1 hashed and the
    hex digest is returned with a 'db' prefix. Any other name is passed
    through unchanged.

    :return: 'db' + SHA-1 hex digest for "test/" names, otherwise the
             unmodified result of test_name()
    """
    raw = test_name()
    if raw.startswith("test/"):
        import hashlib

        digest = hashlib.sha1(raw.encode()).hexdigest()
        return "db" + digest
    return raw
# Module-private memo of SQLAlchemy sessions, keyed by database name.
# Filled by get_session() and emptied entry-by-entry via pop_session().
_sessions = {}
def get_session(engine=None):
    """
    Return aurweb.db's memoized session for the current database name.

    The session is created on first use for a given name() and stored in
    the module-private `_sessions` memo; later calls return the stored one.

    :param engine: Engine to bind a newly created session to; when falsy,
                   get_engine() is used. Ignored if a session for the
                   current database name already exists.
    :return: The session stored for name()
    """
    key = name()
    if key in _sessions:
        return _sessions[key]

    from sqlalchemy.orm import scoped_session, sessionmaker

    if not engine:  # pragma: no cover
        engine = get_engine()

    factory = sessionmaker(autocommit=True, autoflush=False, bind=engine)
    registry = scoped_session(factory)

    # Calling the scoped_session registry yields the session we memoize.
    _sessions[key] = registry()
    return _sessions[key]
def pop_session(dbname: str) -> None:
    """
    Remove the session stored for `dbname` from the private _sessions memo.

    The session itself is not closed here; only the memo entry is dropped.

    :param dbname: Database name
    :raises KeyError: When `dbname` does not exist in the memo
    """
    del _sessions[dbname]
def refresh(model):
    """
    Ask the current session to refresh `model`.

    :param model: Instance passed to the session's refresh()
    :returns: Passed in `model`
    """
    session = get_session()
    session.refresh(model)
    return model
def query(Model, *args, **kwargs):
    """
    Start an ORM query for `Model` on the database session.

    The session query is immediately narrowed by calling its filter()
    with the given *args and **kwargs.

    :param Model: Declarative ORM class
    :param args: Positional arguments forwarded to filter()
    :param kwargs: Keyword arguments forwarded to filter()
    :return: The filtered query
    """
    base = get_session().query(Model)
    return base.filter(*args, **kwargs)
def create(Model, *args, **kwargs):
    """
    Instantiate `Model` and add() the new record to the database session.

    :param Model: Declarative ORM class
    :param args: Positional arguments for the Model constructor
    :param kwargs: Keyword arguments for the Model constructor
    :return: Model instance
    """
    return add(Model(*args, **kwargs))
def delete(model) -> None:
    """
    Delete a single record through the database session.

    The instance is handed to the current session's delete(); no query or
    filtering is performed by this function.

    :param model: Instance to delete
    """
    get_session().delete(model)
def delete_all(iterable) -> None:
    """
    Delete each instance found in `iterable`.

    The current session's delete method is applied to the iterable via
    aurweb.util.apply_all.

    :param iterable: Instances to delete
    """
    import aurweb.util

    remove = get_session().delete
    aurweb.util.apply_all(iterable, remove)
def rollback() -> None:
    """Roll back the current database session."""
    session = get_session()
    session.rollback()
def add(model):
    """
    Add `model` to the database session.

    :param model: Instance to add
    :return: Passed in `model`
    """
    session = get_session()
    session.add(model)
    return model
def begin():
    """
    Begin an SQLAlchemy SessionTransaction on the current session.

    :return: Whatever the session's begin() returns
    """
    session = get_session()
    return session.begin()
def retry_deadlock(func):
    """
    Decorate `func` so that it is retried when the database reports a
    deadlock.

    An OperationalError whose message contains "Deadlock found" triggers a
    retry; up to 10 retries are made after the initial call. Any other
    OperationalError, or a deadlock on the final attempt, is re-raised
    unchanged.

    :param func: Callable to wrap
    :return: Wrapper accepting the same arguments as `func`, plus the
             keyword-only `_i` (number of retries already consumed)
    """
    import functools

    from sqlalchemy.exc import OperationalError

    # Maximum number of retries after the first attempt.
    limit = 10

    @functools.wraps(func)
    def wrapper(*args, _i: int = 0, **kwargs):
        attempt = _i
        # Loop instead of recursing so retries neither grow the stack nor
        # chain every failed attempt into the final traceback.
        while True:
            try:
                return func(*args, **kwargs)
            except OperationalError as exc:
                if attempt >= limit or "Deadlock found" not in str(exc):
                    # Not a deadlock, or retries exhausted: propagate.
                    raise
                attempt += 1

    return wrapper
def async_retry_deadlock(func):
    """
    Decorate coroutine function `func` so that it is retried when the
    database reports a deadlock.

    An OperationalError whose message contains "Deadlock found" triggers a
    retry; up to 10 retries are made after the initial call. Any other
    OperationalError, or a deadlock on the final attempt, is re-raised
    unchanged.

    :param func: Coroutine function to wrap
    :return: Async wrapper accepting the same arguments as `func`, plus
             the keyword-only `_i` (number of retries already consumed)
    """
    import functools

    from sqlalchemy.exc import OperationalError

    # Maximum number of retries after the first attempt.
    limit = 10

    @functools.wraps(func)
    async def wrapper(*args, _i: int = 0, **kwargs):
        attempt = _i
        # Loop instead of recursing so retries neither grow the stack nor
        # chain every failed attempt into the final traceback.
        while True:
            try:
                return await func(*args, **kwargs)
            except OperationalError as exc:
                if attempt >= limit or "Deadlock found" not in str(exc):
                    # Not a deadlock, or retries exhausted: propagate.
                    raise
                attempt += 1

    return wrapper
def get_sqlalchemy_url():
    """
    Build an SQLAlchemy URL for use with create_engine.

    For the "mysql" backend the URL carries user, password, host, the
    database name from name() and either the configured port or, when no
    port is configured, the configured unix socket as a query parameter.
    For the "sqlite" backend only the configured database name is used.

    :raises ValueError: When the configured backend is neither "mysql"
                        nor "sqlite"
    :return: sqlalchemy.engine.url.URL
    """
    from sqlalchemy.engine.url import URL

    import aurweb.config

    # Prefer the URL.create factory whenever the installed SQLAlchemy
    # provides it, and fall back to direct construction otherwise.
    # Detecting the attribute (rather than matching "1.x with x >= 4")
    # keeps this correct on releases newer than 1.x as well.
    constructor = getattr(URL, "create", URL)

    aur_db_backend = aurweb.config.get("database", "backend")
    if aur_db_backend == "mysql":
        param_query = {}
        port = aurweb.config.get_with_fallback("database", "port", None)
        if not port:
            # Without a port, connect through the unix socket instead.
            param_query["unix_socket"] = aurweb.config.get("database", "socket")

        return constructor(
            DRIVERS.get(aur_db_backend),
            username=aurweb.config.get("database", "user"),
            password=aurweb.config.get_with_fallback(
                "database", "password", fallback=None
            ),
            host=aurweb.config.get("database", "host"),
            database=name(),
            port=port,
            query=param_query,
        )

    if aur_db_backend == "sqlite":
        return constructor(
            "sqlite",
            database=aurweb.config.get("database", "name"),
        )

    raise ValueError("unsupported database backend")
def sqlite_regexp(regex, item) -> bool:  # pragma: no cover
    """
    Mimic SQL's REGEXP operator for SQLite.

    `item` is converted with str() before searching, so non-string values
    are matched against their string representation.

    :param regex: Regular expression pattern
    :param item: Value to search in
    :return: True if `regex` matches anywhere in str(item), else False
    """
    import re

    found = re.search(regex, str(item))
    return found is not None
def setup_sqlite(engine) -> None:  # pragma: no cover
    """
    Perform setup for an SQLite engine.

    Registers a "connect" event listener on `engine` which installs
    sqlite_regexp as the deterministic two-argument SQL function REGEXP
    on each connection passed to the listener.

    :param engine: Engine to attach the listener to
    """
    from sqlalchemy import event

    @event.listens_for(engine, "connect")
    def register_regexp(conn, record):
        conn.create_function("REGEXP", 2, sqlite_regexp, deterministic=True)
# Module-private memo of SQLAlchemy engines, keyed by database name.
# Filled by get_engine() and emptied entry-by-entry via pop_engine().
_engines = {}
def get_engine(dbname: str = None, echo: bool = False):
    """
    Return the SQLAlchemy engine for `dbname`.

    The engine is created on the first call for a given database name and
    stored in the module-private `_engines` memo; subsequent calls return
    the stored engine, so `echo` only has an effect on creation.

    The connection URL always comes from get_sqlalchemy_url(); `dbname`
    only selects the memo entry.

    :param dbname: Database name (default: aurweb.db.name())
    :param echo: Flag passed through to sqlalchemy.create_engine
    :return: SQLAlchemy Engine instance
    """
    import aurweb.config

    key = dbname if dbname else name()
    if key in _engines:
        return _engines[key]

    sqlite_backend = aurweb.config.get("database", "backend") == "sqlite"

    connect_args = {}
    if sqlite_backend:  # pragma: no cover
        connect_args["check_same_thread"] = False

    from sqlalchemy import create_engine

    engine = create_engine(
        get_sqlalchemy_url(), echo=echo, connect_args=connect_args
    )
    _engines[key] = engine

    if sqlite_backend:  # pragma: no cover
        setup_sqlite(engine)

    return engine
def pop_engine(dbname: str) -> None:
    """
    Remove the engine stored for `dbname` from the private _engines memo.

    The engine itself is not disposed here; only the memo entry is dropped.

    :param dbname: Database name
    :raises KeyError: When `dbname` does not exist in the memo
    """
    del _engines[dbname]
def kill_engine() -> None:
    """
    Close the current session and dispose of the engine.

    Both the session and the engine for the current database name are
    also removed from their memos, so the next get_session()/get_engine()
    call creates fresh ones.
    """
    current = name()

    # Session first: close it and forget it.
    get_session().close()
    pop_session(current)

    # Then the engine: dispose of it and forget it.
    get_engine().dispose()
    pop_engine(current)
def connect():
    """
    Return an SQLAlchemy connection obtained from get_engine().

    Connections are usually pooled. See
    <https://docs.sqlalchemy.org/en/13/core/connections.html>.

    Since SQLAlchemy connections are context managers too, you should use
    them with Python's `with` statement, or with FastAPI's dependency
    injection.

    :return: SQLAlchemy connection
    """
    engine = get_engine()
    return engine.connect()
class ConnectionExecutor:
    """
    Thin wrapper around a DB-API connection that translates "?"
    placeholders into the parameter style of the configured backend.
    """

    # Wrapped DB-API connection.
    _conn = None
    # DB-API paramstyle of the backend; stays None for unknown backends.
    _paramstyle = None

    def __init__(self, conn, backend=None):
        """
        :param conn: DB-API connection to wrap
        :param backend: Backend name ("mysql" or "sqlite"); when falsy,
                        the configured database backend is used
        """
        import aurweb.config

        if not backend:
            backend = aurweb.config.get("database", "backend")
        self._conn = conn

        if backend == "mysql":
            self._paramstyle = "format"
        elif backend == "sqlite":
            import sqlite3

            self._paramstyle = sqlite3.paramstyle

    def paramstyle(self):
        """Return the parameter style selected for the backend."""
        return self._paramstyle

    def execute(self, query, params=()):  # pragma: no cover
        """
        Execute `query` with `params` on a new cursor.

        Queries are written with "?" placeholders. For the "format" and
        "pyformat" styles, literal "%" characters are doubled and each "?"
        becomes "%s"; for "qmark" the query is passed through untouched.

        :param query: SQL text using "?" placeholders
        :param params: Parameters bound to the placeholders
        :raises ValueError: When the parameter style is not supported
        :return: The cursor the query was executed on
        """
        # TODO: SQLite support has been removed in FastAPI. It remains
        # here to support the Sharness testsuite.
        style = self._paramstyle
        if style == "qmark":
            sql = query
        elif style in ("format", "pyformat"):
            # Escape "%" first so the "%s" inserted next is not doubled.
            sql = query.replace("%", "%%").replace("?", "%s")
        else:
            raise ValueError("unsupported paramstyle")

        cursor = self._conn.cursor()
        cursor.execute(sql, params)
        return cursor

    def commit(self):
        """Commit on the wrapped connection."""
        self._conn.commit()

    def close(self):
        """Close the wrapped connection."""
        self._conn.close()
class Connection:
    """
    Direct (non-ORM) database connection for the configured backend,
    exposing execute/commit/close through a ConnectionExecutor.
    """

    _executor = None
    # Holds the ConnectionExecutor wrapping the raw DB-API connection.
    _conn = None

    def __init__(self):
        """
        Open a connection to the configured database backend.

        :raises ValueError: When the configured backend is neither
                            "mysql" nor "sqlite"
        """
        import aurweb.config

        backend = aurweb.config.get("database", "backend")

        if backend == "mysql":
            import MySQLdb

            # Settings are read in this order: host, name, user,
            # password, socket.
            settings = {
                "host": aurweb.config.get("database", "host"),
                "db": name(),
                "user": aurweb.config.get("database", "user"),
                "passwd": aurweb.config.get_with_fallback(
                    "database", "password", str()
                ),
                "unix_socket": aurweb.config.get("database", "socket"),
            }
            raw = MySQLdb.connect(**settings)
        elif backend == "sqlite":  # pragma: no cover
            # TODO: SQLite support has been removed in FastAPI. It remains
            # here to support the Sharness testsuite.
            import math
            import sqlite3

            raw = sqlite3.connect(aurweb.config.get("database", "name"))
            # Provide a two-argument POWER SQL function backed by math.pow.
            raw.create_function("POWER", 2, math.pow)
        else:
            raise ValueError("unsupported database backend")

        self._conn = ConnectionExecutor(raw, backend)

    def execute(self, query, params=()):
        """
        Execute `query` with `params` via the wrapped executor.

        :param query: SQL text
        :param params: Parameters for the query
        :return: Whatever the executor's execute() returns
        """
        return self._conn.execute(query, params)

    def commit(self):
        """Commit via the wrapped executor."""
        self._conn.commit()

    def close(self):
        """Close via the wrapped executor."""
        self._conn.close()