# aurweb/test/test_db.py
# moson 122df968dc
# feat: Switch to postgres
# Migrate from MariaDB to PostgreSQL.
#
# Signed-off-by: moson <moson@archlinux.org>
# 2024-12-10 16:13:12 +01:00
#
# 238 lines
# 6.6 KiB
# Python

import os
import re
import sqlite3
import tempfile
from unittest import mock
import pytest
from sqlalchemy.exc import OperationalError
import aurweb.config
import aurweb.initdb
from aurweb import db
from aurweb.models.account_type import AccountType
class Args:
    """Stand-in argument namespace handed to aurweb.initdb in tests."""

    verbose = True
    use_alembic = True
class DBCursor:
    """A fake database cursor object used in tests.

    ``execute()`` records its positional arguments, and ``fetchall()`` hands
    them back as the "result rows".
    """

    def __init__(self):
        # Per-instance storage. A class-level ``items = []`` is a single list
        # shared by every cursor, so mutating what fetchall() returns on a
        # fresh cursor would leak into all other cursors.
        self.items = []

    def execute(self, *args, **kwargs):
        """Remember the positional arguments (keyword arguments are ignored).

        :return: this cursor, so calls can be chained with ``fetchall()``
        """
        self.items = list(args)
        return self

    def fetchall(self):
        """Return the list recorded by the most recent ``execute()`` call."""
        return self.items
class DBConnection:
    """Stub connection for tests: hands out fake cursors and ignores
    ``create_function()`` registrations."""

    @staticmethod
    def create_function(name, num_args, func):
        """Accept a function registration and silently discard it."""
        return None

    @staticmethod
    def cursor():
        """Build and return a brand-new ``DBCursor``."""
        return DBCursor()
def make_temp_config(*replacements):
    """Generate a temporary config file with a set of replacements.

    ``conf/config.dev`` is copied into a fresh temporary directory as
    ``config.tmp``. Before the copy is written, the ``[database]`` settings
    of the currently loaded config and the aurweb root directory are
    substituted in, and then *replacements* are applied in order.
    ``conf/config.defaults`` is copied next to it as ``config.tmp.defaults``.

    :param *replacements: A variable number of tuple regex replacement pairs
        ``(pattern, replacement)``. Each is applied with :func:`re.sub`, so
        ``replacement`` keeps ``re.sub`` template semantics.
    :return: A tuple containing (temp directory, temp config file). The
        caller owns the ``TemporaryDirectory`` and should use it as a context
        manager (or call ``cleanup()``) to remove the files.
    """
    aurwebdir = aurweb.config.get("options", "aurwebdir")
    config_file = os.path.join(aurwebdir, "conf", "config.dev")
    config_defaults = os.path.join(aurwebdir, "conf", "config.defaults")

    db_name = aurweb.config.get("database", "name")
    db_host = aurweb.config.get_with_fallback("database", "host", "localhost")
    # 5432 is PostgreSQL's default port. The old "3306" was MariaDB's, and it
    # stopped test_sqlalchemy_postgres_port_url's ";port = 5432" replacement
    # from ever matching when the live config sets no port.
    db_port = aurweb.config.get_with_fallback("database", "port", "5432")
    db_user = aurweb.config.get_with_fallback("database", "user", "root")
    db_password = aurweb.config.get_with_fallback("database", "password", None)

    # Replacements to perform before *replacements.
    # These serve as generic replacements in config.dev.
    # NOTE(review): the patterns are unanchored, so e.g. "name = .+" would
    # also match inside a longer key such as "username = ..." (confirm
    # config.dev has no such keys).
    perform = (
        (r"name = .+", f"name = {db_name}"),
        (r"host = .+", f"host = {db_host}"),
        (r";port = .+", f";port = {db_port}"),
        (r"user = .+", f"user = {db_user}"),
        (r"password = .+", f"password = {db_password}"),
        ("YOUR_AUR_ROOT", aurwebdir),
    )

    tmpdir = tempfile.TemporaryDirectory()
    try:
        tmp = os.path.join(tmpdir.name, "config.tmp")
        with open(config_file) as f:
            config = f.read()

        for pattern, value in perform:
            # These values come from configuration data, so insert them
            # literally. A backslash (e.g. in a password or path) must not
            # be parsed by re.sub as a group reference.
            config = re.sub(pattern, lambda _match, text=value: text, config)
        for repl in replacements:
            # Caller-supplied pairs keep full re.sub replacement semantics.
            config = re.sub(repl[0], repl[1], config)

        with open(tmp, "w") as o:
            o.write(config)

        with open(config_defaults) as i, open(f"{tmp}.defaults", "w") as o:
            o.write(i.read())
    except BaseException:
        # Don't leave a stray temporary directory behind if setup fails.
        tmpdir.cleanup()
        raise

    return tmpdir, tmp
def make_temp_sqlite_config():
    """Return ``make_temp_config()`` output switched to the sqlite backend,
    with the database name pointing at /tmp/aurweb.sqlite3."""
    sqlite_overrides = (
        (r"backend = .*", "backend = sqlite"),
        (r"name = .*", "name = /tmp/aurweb.sqlite3"),
    )
    return make_temp_config(*sqlite_overrides)
def make_temp_postgres_config():
    """Return ``make_temp_config()`` output switched to the postgres backend,
    with the database name set to aurweb_test."""
    backend_override = (r"backend = .*", "backend = postgres")
    name_override = (r"name = .*", "name = aurweb_test")
    return make_temp_config(backend_override, name_override)
@pytest.fixture(autouse=True)
def setup(db_test):
    """Pull in the db_test fixture for every test here and delete any
    /tmp/aurweb.sqlite3 left over from an earlier run."""
    leftover = "/tmp/aurweb.sqlite3"
    if not os.path.exists(leftover):
        return
    os.remove(leftover)
def test_sqlalchemy_sqlite_url():
    """get_sqlalchemy_url() produces a URL for a sqlite-backed config."""
    tmpctx, tmp = make_temp_sqlite_config()
    try:
        with tmpctx:
            with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
                aurweb.config.rehash()
                assert db.get_sqlalchemy_url()
    finally:
        # Rehash outside the patched environment even if the assertion fails,
        # so the temporary config loaded above doesn't leak into later tests.
        aurweb.config.rehash()
def test_sqlalchemy_postgres_url():
    """get_sqlalchemy_url() produces a URL for a postgres-backed config."""
    tmpctx, tmp = make_temp_postgres_config()
    try:
        with tmpctx:
            with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
                aurweb.config.rehash()
                assert db.get_sqlalchemy_url()
    finally:
        # Rehash outside the patched environment even if the assertion fails,
        # so the temporary config loaded above doesn't leak into later tests.
        aurweb.config.rehash()
def test_sqlalchemy_postgres_port_url():
    """get_sqlalchemy_url() handles a config with an explicit
    ``port = 5432`` line (uncommented from ``;port = 5432``)."""
    tmpctx, tmp = make_temp_config((r";port = 5432", "port = 5432"))
    try:
        with tmpctx:
            with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
                aurweb.config.rehash()
                assert db.get_sqlalchemy_url()
    finally:
        # Rehash outside the patched environment even if the assertion fails,
        # so the temporary config loaded above doesn't leak into later tests.
        aurweb.config.rehash()
def test_sqlalchemy_postgres_socket_url():
    """get_sqlalchemy_url() handles the unmodified config.dev settings,
    where the port line is left as-is (``;``-prefixed in the template)."""
    tmpctx, tmp = make_temp_config()
    try:
        with tmpctx:
            with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
                aurweb.config.rehash()
                assert db.get_sqlalchemy_url()
    finally:
        # Rehash outside the patched environment even if the assertion fails,
        # so the temporary config loaded above doesn't leak into later tests.
        aurweb.config.rehash()
def test_sqlalchemy_unknown_backend():
    """get_sqlalchemy_url() raises ValueError for an unrecognized backend."""
    tmpctx, tmp = make_temp_config((r"backend = .+", "backend = blah"))
    try:
        with tmpctx:
            with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
                aurweb.config.rehash()
                with pytest.raises(ValueError):
                    db.get_sqlalchemy_url()
    finally:
        # Rehash outside the patched environment even if the check fails,
        # so the bogus-backend config doesn't leak into later tests.
        aurweb.config.rehash()
def test_db_connects_without_fail():
    """db.connect() succeeds with whatever config pytest was launched with.

    No alternate configs are exercised here.
    """
    db.connect()
def test_connection_class_unsupported_backend():
    """db.Connection() raises ValueError for an unrecognized backend."""
    tmpctx, tmp = make_temp_config((r"backend = .+", "backend = blah"))
    try:
        with tmpctx:
            with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
                aurweb.config.rehash()
                with pytest.raises(ValueError):
                    db.Connection()
    finally:
        # Rehash outside the patched environment even if the check fails,
        # so the bogus-backend config doesn't leak into later tests.
        aurweb.config.rehash()
def test_create_delete():
    """A row made with db.create() can be queried until db.delete() drops it."""

    def find_test_type():
        return db.query(AccountType, AccountType.AccountType == "test").first()

    with db.begin():
        created = db.create(AccountType, AccountType="test")
    assert find_test_type() is not None

    with db.begin():
        db.delete(created)
    assert find_test_type() is None
def test_add_commit():
    """db.add() inside db.begin() persists an object built by the caller."""
    new_type = AccountType(AccountType="test")
    with db.begin():
        db.add(new_type)

    # After the transaction the object should have been given an ID.
    assert new_type.ID

    # Reading it back must yield a record equal to the one we added.
    fetched = db.query(AccountType, AccountType.AccountType == "test").first()
    assert fetched == new_type

    # Clean up the temporary row.
    with db.begin():
        db.delete(new_type)
def test_connection_executor_postgres_paramstyle():
    """A postgres ConnectionExecutor reports the "format" paramstyle."""
    style = db.ConnectionExecutor(None, backend="postgres").paramstyle()
    assert style == "format"
@mock.patch("sqlite3.paramstyle", "pyformat")
def test_connection_executor_sqlite_paramstyle():
    """A sqlite ConnectionExecutor reports sqlite3.paramstyle, which is
    patched to "pyformat" for this test."""
    sqlite_executor = db.ConnectionExecutor(None, backend="sqlite")
    reported = sqlite_executor.paramstyle()
    assert reported == sqlite3.paramstyle
def test_name_without_pytest_current_test():
    """With an emptied environment (no PYTEST_CURRENT_TEST), db.name()
    returns the [database] name from the loaded config."""
    expected = aurweb.config.get("database", "name")
    with mock.patch.dict("os.environ", {}, clear=True):
        actual = db.name()
    assert actual == expected
def test_retry_deadlock():
    """A function that always raises a deadlock OperationalError still
    raises it to the caller when wrapped in db.retry_deadlock."""

    @db.retry_deadlock
    def always_deadlocks():
        raise OperationalError("Deadlock found", (), "")

    with pytest.raises(OperationalError):
        always_deadlocks()
@pytest.mark.asyncio
async def test_async_retry_deadlock():
    """A coroutine that always raises a deadlock OperationalError still
    raises it to the awaiter when wrapped in db.async_retry_deadlock."""

    @db.async_retry_deadlock
    async def always_deadlocks():
        raise OperationalError("Deadlock found", (), "")

    with pytest.raises(OperationalError):
        await always_deadlocks()