aurweb/aurweb/db.py
Kevin Morris fa43f6bc3e
change(aurweb): add parallel tests and improve aurweb.db
This change utilizes pytest-xdist to perform a multiproc test
run and reworks aurweb.db's code. We no longer use a global
engine, session or Session, but we now use a memo of engines
and sessions as they are requested, based on the PYTEST_CURRENT_TEST
environment variable, which is available during testing.

Additionally, this change strips several SQLite components
out of the Python code-base.

SQLite is still compatible with PHP and sharness tests, but
not with our FastAPI implementation.

More changes:
------------
- Remove use of aurweb.db.session global in other code.
- Use new aurweb.db.name() dynamic db name function in env.py.
- Added 'addopts' to pytest.ini which utilizes multiprocessing.
    - Highly recommended to leave this be or modify `-n auto` to
      `-n {cpu_threads}` where cpu_threads is at least 2.

Signed-off-by: Kevin Morris <kevr@0cost.org>
2021-11-17 01:34:59 -08:00

412 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import functools
import hashlib
import math
import os
import re
from typing import Iterable, NewType
import sqlalchemy
from sqlalchemy import create_engine, event
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Query, Session, SessionTransaction, scoped_session, sessionmaker
import aurweb.config
import aurweb.util
from aurweb import logging
logger = logging.get_logger(__name__)
# Maps a configured database backend name to the SQLAlchemy driver string.
DRIVERS = {"mysql": "mysql+mysqldb"}

# Global introspected object memo.
introspected = {}

# Some types we don't get access to in this module.
Base = NewType("Base", "aurweb.models.declarative_base.Base")
def make_random_value(table: str, column: str):
    """ Generate a unique, random value for a string column in a table.

    Useful for values such as session IDs, which must fit the size of
    the database column they are stored in.

    The column's length is found through SQLAlchemy introspection; the
    introspected column object is memoized in the module-level
    `introspected` dict, keyed by `str(column)`.

    :param table: Object handed to sqlalchemy.inspect() and query()
    :param column: Column compared against each generated candidate
    :return: A random string for which no matching row was found
    :raises IndexError: When no inspected column has the target name
    """
    # The memo is keyed by the string form of the column.
    key = str(column)

    if key not in introspected:
        from sqlalchemy import inspect

        # Only the last dotted component is the bare column name.
        wanted = key.split('.')[-1]
        matches = [c for c in inspect(table).columns if c.name == wanted]
        introspected[key] = matches[0]

    length = introspected[key].type.length

    # Keep generating candidates until one is not present in the table.
    while True:
        candidate = aurweb.util.make_random_string(length)
        if not query(table).filter(column == candidate).first():
            return candidate
def test_name() -> str:
    """
    Return the unhashed database name.

    The unhashed database name is taken from the first available source:
    -------------------------------------------
    1. {test_suite} portion of PYTEST_CURRENT_TEST
    2. aurweb.config.get("database", "name")

    During `pytest` runs, the PYTEST_CURRENT_TEST environment variable
    is set to the current test in the format `{test_suite}::{test_func}`.
    This allows tests to use a suite-specific database for its runs,
    which decouples database state from test suites.

    :return: Unhashed database name
    """
    db = os.environ.get("PYTEST_CURRENT_TEST")
    if db is None:
        # Only consult the configuration when the environment variable is
        # absent, so the config is not read needlessly on every call made
        # during a test run.
        db = aurweb.config.get("database", "name")

    # Keep everything before the first ':' (the {test_suite} portion).
    return db.split(":")[0]
def name() -> str:
    """
    Return a sanitized database name usable for tests or production.

    When test_name() starts with "test/", the result is 'db' followed by
    the SHA-1 hex digest of that name. Any other value of test_name() is
    returned unchanged.

    :return: 'db' + SHA-1 hex digest for "test/" names, otherwise the
             unmodified test_name() value
    """
    raw = test_name()
    if raw.startswith("test/"):
        digest = hashlib.sha1(raw.encode()).hexdigest()
        return "db" + digest
    return raw
# Module-private global memo used to store SQLAlchemy sessions.
_sessions = {}


def get_session(engine: Engine = None) -> Session:
    """
    Return the memoized session for the current database name.

    Sessions are stored in the `_sessions` memo, keyed by name(). On the
    first request for a given name, a new session is created, stored and
    the database name is logged at debug level.

    :param engine: Engine to bind a newly created session to; when falsy,
                   get_engine() is used. Ignored if a session for the
                   current name is already memoized.
    :return: Session stored in the memo for name()
    """
    dbname = name()

    # Fast path: a session was already created for this database name.
    if dbname in _sessions:
        return _sessions[dbname]

    bound_engine = engine or get_engine()  # pragma: no cover
    factory = scoped_session(
        sessionmaker(autocommit=True, autoflush=False, bind=bound_engine))
    session = factory()
    _sessions[dbname] = session

    # This is the first grab of this session; log out the database
    # name used.
    raw_dbname = test_name()
    logger.debug(f"DBName({raw_dbname}): {dbname}")

    return session
def pop_session(dbname: str) -> None:
    """
    Remove the Session stored under `dbname` from the private
    _sessions memo.

    :param dbname: Database name
    :raises KeyError: When `dbname` does not exist in the memo
    """
    del _sessions[dbname]
def refresh(model: Base) -> Base:
    """
    Refresh `model` through the current session and hand it back.

    :param model: Instance passed to Session.refresh
    :return: The same `model` object
    """
    session = get_session()
    session.refresh(model)
    return model
def query(Model: Base, *args, **kwargs) -> Query:
    """
    Start an ORM query for `Model` on the current session.

    The positional and keyword arguments are forwarded unchanged to
    Query.filter on the freshly created query.

    :param Model: Declarative ORM class
    :return: The Query returned by Query.filter
    """
    base_query = get_session().query(Model)
    return base_query.filter(*args, **kwargs)
def create(Model: Base, *args, **kwargs) -> Base:
    """
    Instantiate `Model` with the given arguments and add() the new
    instance to the database session.

    :param Model: Declarative ORM class
    :return: Model instance
    """
    return add(Model(*args, **kwargs))
def delete(model: Base) -> None:
    """
    Pass `model` to the current session's delete().

    :param model: Instance handed to Session.delete
    """
    session = get_session()
    session.delete(model)
def delete_all(iterable: Iterable) -> None:
    """
    Hand `iterable` and the current session's delete method to
    aurweb.util.apply_all.

    :param iterable: Collection of instances to be deleted
    """
    remove = get_session().delete
    aurweb.util.apply_all(iterable, remove)
def rollback() -> None:
    """ Call rollback() on the current database session. """
    session = get_session()
    session.rollback()
def add(model: Base) -> Base:
    """
    Add `model` to the current database session.

    :param model: Instance passed to Session.add
    :return: The same `model` object
    """
    session = get_session()
    session.add(model)
    return model
def begin() -> SessionTransaction:
    """
    Call begin() on the current session.

    :return: Whatever Session.begin() returns (annotated as a
             SessionTransaction)
    """
    session = get_session()
    return session.begin()
def get_sqlalchemy_url() -> URL:
    """
    Build an SQLAlchemy URL for use with create_engine.

    The backend is read from the [database] `backend` config option.
    For 'mysql', the URL carries the configured user, optional password,
    host, the database name returned by name(), and either the
    configured port or, when no port is configured, a `unix_socket`
    query parameter taken from the `socket` option. For 'sqlite', only
    the configured database `name` is used.

    :return: sqlalchemy.engine.url.URL
    :raises ValueError: When the configured backend is neither 'mysql'
                        nor 'sqlite'
    """
    # SQLAlchemy >= 1.4 provides the URL.create() factory; older releases
    # only offer the plain URL constructor. Detecting the factory directly
    # (instead of comparing version numbers) also selects it on 2.x, where
    # a `major == 1` check would wrongly fall back to URL().
    constructor = getattr(URL, "create", URL)

    aur_db_backend = aurweb.config.get('database', 'backend')
    if aur_db_backend == 'mysql':
        param_query = {}
        port = aurweb.config.get_with_fallback("database", "port", None)
        if not port:
            # Without a port, connect through the configured unix socket.
            param_query["unix_socket"] = aurweb.config.get(
                "database", "socket")

        return constructor(
            DRIVERS.get(aur_db_backend),
            username=aurweb.config.get('database', 'user'),
            password=aurweb.config.get_with_fallback('database', 'password',
                                                     fallback=None),
            host=aurweb.config.get('database', 'host'),
            database=name(),
            port=port,
            query=param_query
        )
    elif aur_db_backend == 'sqlite':
        return constructor(
            'sqlite',
            database=aurweb.config.get('database', 'name'),
        )
    else:
        raise ValueError('unsupported database backend')
def sqlite_regexp(regex, item) -> bool:  # pragma: no cover
    """
    Method which mimics SQL's REGEXP for SQLite.

    :param regex: Pattern handed to re.search
    :param item: Value to search; converted with str() first
    :return: True when the pattern is found anywhere in str(item)
    """
    match = re.search(regex, str(item))
    return match is not None
def setup_sqlite(engine: Engine) -> None:  # pragma: no cover
    """
    Perform setup for an SQLite engine.

    Registers a "connect" event listener on `engine` which defines a
    two-argument, deterministic REGEXP function backed by sqlite_regexp
    on the connection it receives.

    :param engine: Engine to attach the listener to
    """
    @event.listens_for(engine, "connect")
    def register_regexp(conn, record):
        conn.create_function("REGEXP", 2, sqlite_regexp,
                             deterministic=True)
# Module-private global memo used to store SQLAlchemy engines.
_engines = {}


def get_engine(dbname: str = None, echo: bool = False) -> Engine:
    """
    Return the SQLAlchemy engine memoized under `dbname`.

    The engine is created on the first call for a given `dbname` and
    stored in the private `_engines` memo; later calls with the same
    name return the stored engine.

    :param dbname: Memo key (default: aurweb.db.name())
    :param echo: Flag passed through to sqlalchemy.create_engine when
                 the engine is first created
    :return: SQLAlchemy Engine instance
    """
    key = dbname if dbname else name()

    # Reuse an engine that was already built for this key.
    if key in _engines:
        return _engines[key]

    uses_sqlite = aurweb.config.get("database", "backend") == "sqlite"

    connect_args = {}
    if uses_sqlite:  # pragma: no cover
        connect_args["check_same_thread"] = False

    engine = create_engine(get_sqlalchemy_url(),
                           echo=echo,
                           connect_args=connect_args)
    _engines[key] = engine

    if uses_sqlite:  # pragma: no cover
        setup_sqlite(engine)

    return engine
def pop_engine(dbname: str) -> None:
    """
    Remove the Engine stored under `dbname` from the private
    _engines memo.

    :param dbname: Database name
    :raises KeyError: When `dbname` does not exist in the memo
    """
    del _engines[dbname]
def kill_engine() -> None:
    """
    Close the current session, dispose of the current engine and drop
    both from their memos.
    """
    dbname = name()

    # Session first: close it, then forget it.
    get_session().close()
    pop_session(dbname)

    # Then the engine: dispose of it and forget it.
    get_engine().dispose()
    pop_engine(dbname)
def connect():
    """
    Return an SQLAlchemy connection obtained from get_engine().

    Connections are usually pooled. See
    <https://docs.sqlalchemy.org/en/13/core/connections.html>.

    Since SQLAlchemy connections are context managers too, you should
    use the result with Python's `with` statement, or with FastAPI's
    dependency injection.
    """
    engine = get_engine()
    return engine.connect()
class ConnectionExecutor:
    """
    Thin wrapper around a DB-API style connection which translates
    qmark-style ('?') queries to the wrapped backend's parameter style.
    """
    # Wrapped connection object.
    _conn = None
    # Parameter style of the backend; stays None for unknown backends.
    _paramstyle = None

    def __init__(self, conn, backend=None):
        """
        :param conn: Connection providing cursor(), commit() and close()
        :param backend: 'mysql' or 'sqlite'; when None, the [database]
                        `backend` config option is read at call time
        """
        # Resolve the default lazily. Evaluating the config lookup in the
        # signature would run it once at import time and freeze the value.
        if backend is None:
            backend = aurweb.config.get("database", "backend")

        self._conn = conn
        if backend == "mysql":
            self._paramstyle = "format"
        elif backend == "sqlite":
            import sqlite3
            self._paramstyle = sqlite3.paramstyle

    def paramstyle(self):
        """ Return the parameter style selected for the backend. """
        return self._paramstyle

    def execute(self, query, params=()):  # pragma: no cover
        """
        Execute `query` with `params` on a new cursor.

        :param query: SQL text using '?' placeholders
        :param params: Parameters passed to cursor.execute
        :return: The cursor the query was executed on
        :raises ValueError: When the parameter style is not supported
        """
        # TODO: SQLite support has been removed in FastAPI. It remains
        # here to fund its support for PHP until it is removed.
        if self._paramstyle in ('format', 'pyformat'):
            # Escape literal '%' first, then turn '?' into '%s'.
            query = query.replace('%', '%%').replace('?', '%s')
        elif self._paramstyle == 'qmark':
            pass
        else:
            raise ValueError('unsupported paramstyle')

        cur = self._conn.cursor()
        cur.execute(query, params)
        return cur

    def commit(self):
        """ Commit on the wrapped connection. """
        self._conn.commit()

    def close(self):
        """ Close the wrapped connection. """
        self._conn.close()
class Connection:
    """
    Raw database connection for the configured backend, wrapped in a
    ConnectionExecutor.
    """
    _executor = None
    # Holds the ConnectionExecutor once __init__ has completed.
    _conn = None

    def __init__(self):
        """
        Open a connection for the [database] `backend` config option.

        :raises ValueError: When the backend is neither 'mysql' nor
                            'sqlite'
        """
        backend = aurweb.config.get('database', 'backend')

        if backend == 'mysql':
            import MySQLdb
            db_host = aurweb.config.get('database', 'host')
            db_name = name()
            db_user = aurweb.config.get('database', 'user')
            db_pass = aurweb.config.get_with_fallback(
                'database', 'password', str())
            db_socket = aurweb.config.get('database', 'socket')
            raw_conn = MySQLdb.connect(host=db_host,
                                       user=db_user,
                                       passwd=db_pass,
                                       db=db_name,
                                       unix_socket=db_socket)
        elif backend == 'sqlite':  # pragma: no cover
            # TODO: SQLite support has been removed in FastAPI. It remains
            # here to fund its support for PHP until it is removed.
            import sqlite3
            raw_conn = sqlite3.connect(aurweb.config.get('database', 'name'))
            # Register a two-argument POWER function backed by math.pow.
            raw_conn.create_function("POWER", 2, math.pow)
        else:
            raise ValueError('unsupported database backend')

        self._conn = ConnectionExecutor(raw_conn, backend)

    def execute(self, query, params=()):
        """ Delegate to the executor's execute() and return its result. """
        return self._conn.execute(query, params)

    def commit(self):
        """ Delegate to the executor's commit(). """
        self._conn.commit()

    def close(self):
        """ Delegate to the executor's close(). """
        self._conn.close()