mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
238 lines
6.6 KiB
Python
238 lines
6.6 KiB
Python
import os
|
|
import re
|
|
import sqlite3
|
|
import tempfile
|
|
from unittest import mock
|
|
|
|
import pytest
|
|
from sqlalchemy.exc import OperationalError
|
|
|
|
import aurweb.config
|
|
import aurweb.initdb
|
|
from aurweb import db
|
|
from aurweb.models.account_type import AccountType
|
|
|
|
|
|
class Args:
    """Stand-in argument object for running aurweb.initdb in tests.

    Both options are exposed as plain class attributes and are enabled.
    """

    verbose = True
    use_alembic = True
|
|
|
|
|
|
class DBCursor:
    """A fake database cursor object used in tests.

    ``execute`` records its positional arguments and ``fetchall`` hands
    back whatever the most recent ``execute`` call recorded.
    """

    def __init__(self):
        # Per-instance storage. A class-level list would be shared by every
        # cursor that has not called execute() yet, so mutating the result
        # of fetchall() on one cursor would leak into all the others.
        self.items = []

    def execute(self, *args, **kwargs):
        """Record the positional arguments and return this cursor.

        Keyword arguments are accepted and ignored.

        :return: This cursor, so calls can be chained
        """
        self.items = list(args)
        return self

    def fetchall(self):
        """Return the list recorded by the last ``execute`` call.

        :return: The recorded arguments; an empty list before any execute
        """
        return self.items
|
|
|
|
|
|
class DBConnection:
    """A fake database connection object used in tests.

    Neither method needs instance state, so both are static.
    """

    @staticmethod
    def create_function(name, num_args, func):
        """Accept a function registration and discard it.

        :param name: Ignored
        :param num_args: Ignored
        :param func: Ignored
        :return: None
        """
        return None

    @staticmethod
    def cursor():
        """Return a freshly constructed fake cursor."""
        return DBCursor()
|
|
|
|
|
|
def make_temp_config(*replacements):
    """Generate a temporary config file with a set of replacements.

    The text of ``conf/config.dev`` under the configured ``aurwebdir`` is
    read, a fixed set of generic substitutions (database name, host, port,
    user, password and the YOUR_AUR_ROOT placeholder) is applied, and then
    the caller's *replacements* are applied in order. The result is written
    to ``config.tmp`` in a new temporary directory and ``conf/config.defaults``
    is copied next to it as ``config.tmp.defaults``.

    :param *replacements: A variable number of tuple regex replacement pairs
                          ``(pattern, repl)``; each is passed to ``re.sub``
                          unchanged
    :return: A tuple containing (temp directory, temp config file); the
             temp directory is a ``tempfile.TemporaryDirectory`` that the
             caller is responsible for cleaning up
    """
    aurwebdir = aurweb.config.get("options", "aurwebdir")
    config_file = os.path.join(aurwebdir, "conf", "config.dev")
    config_defaults = os.path.join(aurwebdir, "conf", "config.defaults")

    db_name = aurweb.config.get("database", "name")
    db_host = aurweb.config.get_with_fallback("database", "host", "localhost")
    db_port = aurweb.config.get_with_fallback("database", "port", "3306")
    db_user = aurweb.config.get_with_fallback("database", "user", "root")
    db_password = aurweb.config.get_with_fallback("database", "password", None)

    def literal(value):
        # re.sub() treats backslashes in a replacement string as escapes or
        # group references. Double them so configuration values (a password
        # or path containing "\") are inserted verbatim.
        return str(value).replace("\\", "\\\\")

    # Replacements to perform before *replacements.
    # These serve as generic replacements in config.dev
    perform = (
        (r"name = .+", f"name = {literal(db_name)}"),
        (r"host = .+", f"host = {literal(db_host)}"),
        # Only a line starting with ";port" is rewritten, and it keeps
        # its leading ";".
        (r";port = .+", f";port = {literal(db_port)}"),
        (r"user = .+", f"user = {literal(db_user)}"),
        (r"password = .+", f"password = {literal(db_password)}"),
        ("YOUR_AUR_ROOT", literal(aurwebdir)),
    )

    tmpdir = tempfile.TemporaryDirectory()
    tmp = os.path.join(tmpdir.name, "config.tmp")

    # Read the template fully, then release the handle before writing.
    with open(config_file) as f:
        config = f.read()

    # Caller-supplied pairs run after the generic ones so they can override.
    for pattern, repl in perform + replacements:
        config = re.sub(pattern, repl, config)

    with open(tmp, "w") as o:
        o.write(config)

    with open(config_defaults) as i, open(f"{tmp}.defaults", "w") as o:
        o.write(i.read())

    return tmpdir, tmp
|
|
|
|
|
|
def make_temp_sqlite_config():
    """Build a temporary config that selects the sqlite backend.

    The database name is pointed at /tmp/aurweb.sqlite3.

    :return: The (temp directory, temp config file) tuple produced by
             make_temp_config
    """
    backend = (r"backend = .*", "backend = sqlite")
    name = (r"name = .*", "name = /tmp/aurweb.sqlite3")
    return make_temp_config(backend, name)
|
|
|
|
|
|
def make_temp_postgres_config():
    """Build a temporary config that selects the postgres backend.

    The database name is set to aurweb_test.

    :return: The (temp directory, temp config file) tuple produced by
             make_temp_config
    """
    backend = (r"backend = .*", "backend = postgres")
    name = (r"name = .*", "name = aurweb_test")
    return make_temp_config(backend, name)
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def setup(db_test):
    """Remove a leftover /tmp/aurweb.sqlite3 before each test.

    :param db_test: Fixture requested for its setup; not used directly here
    """
    # Attempt the removal directly instead of checking first: the file can
    # vanish between an exists() check and the remove() call.
    try:
        os.remove("/tmp/aurweb.sqlite3")
    except FileNotFoundError:
        pass
|
|
|
|
|
|
def test_sqlalchemy_sqlite_url():
    """db.get_sqlalchemy_url() returns a truthy URL for a sqlite config."""
    tmpctx, tmp = make_temp_sqlite_config()
    with tmpctx:
        try:
            with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
                aurweb.config.rehash()
                assert db.get_sqlalchemy_url()
        finally:
            # Runs after AUR_CONFIG has been restored, and also when the
            # assertion fails, so the temporary config is not left loaded
            # for later tests (assumes rehash() reloads from the
            # environment -- confirm against aurweb.config).
            aurweb.config.rehash()
|
|
|
|
|
|
def test_sqlalchemy_postgres_url():
    """db.get_sqlalchemy_url() returns a truthy URL for a postgres config."""
    tmpctx, tmp = make_temp_postgres_config()
    with tmpctx:
        try:
            with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
                aurweb.config.rehash()
                assert db.get_sqlalchemy_url()
        finally:
            # Runs after AUR_CONFIG has been restored, and also when the
            # assertion fails, so the temporary config is not left loaded
            # for later tests (assumes rehash() reloads from the
            # environment -- confirm against aurweb.config).
            aurweb.config.rehash()
|
|
|
|
|
|
def test_sqlalchemy_postgres_port_url():
    """db.get_sqlalchemy_url() returns a truthy URL with an explicit port.

    The extra replacement turns a ``;port = 5432`` line, if present in the
    generated config, into ``port = 5432``.
    """
    tmpctx, tmp = make_temp_config((r";port = 5432", "port = 5432"))

    with tmpctx:
        try:
            with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
                aurweb.config.rehash()
                assert db.get_sqlalchemy_url()
        finally:
            # Runs after AUR_CONFIG has been restored, and also when the
            # assertion fails, so the temporary config is not left loaded
            # for later tests (assumes rehash() reloads from the
            # environment -- confirm against aurweb.config).
            aurweb.config.rehash()
|
|
|
|
|
|
def test_sqlalchemy_postgres_socket_url():
    """db.get_sqlalchemy_url() returns a truthy URL for the generic config.

    No extra replacements are applied, so only the generic substitutions
    made by make_temp_config are in effect.
    NOTE(review): the name suggests this exercises a socket-based URL;
    confirm that the dev config leads to that code path.
    """
    tmpctx, tmp = make_temp_config()

    with tmpctx:
        try:
            with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
                aurweb.config.rehash()
                assert db.get_sqlalchemy_url()
        finally:
            # Runs after AUR_CONFIG has been restored, and also when the
            # assertion fails, so the temporary config is not left loaded
            # for later tests (assumes rehash() reloads from the
            # environment -- confirm against aurweb.config).
            aurweb.config.rehash()
|
|
|
|
|
|
def test_sqlalchemy_unknown_backend():
    """db.get_sqlalchemy_url() raises ValueError for backend "blah"."""
    tmpctx, tmp = make_temp_config((r"backend = .+", "backend = blah"))

    with tmpctx:
        try:
            with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
                aurweb.config.rehash()
                with pytest.raises(ValueError):
                    db.get_sqlalchemy_url()
        finally:
            # Runs after AUR_CONFIG has been restored, and also when the
            # expectation fails, so the temporary config is not left loaded
            # for later tests (assumes rehash() reloads from the
            # environment -- confirm against aurweb.config).
            aurweb.config.rehash()
|
|
|
|
|
|
def test_db_connects_without_fail():
    """Connect using only the config that pytest itself was given.

    Nothing is asserted; the test passes when db.connect() does not raise.
    """
    db.connect()
|
|
|
|
|
|
def test_connection_class_unsupported_backend():
    """Constructing db.Connection raises ValueError for backend "blah"."""
    tmpctx, tmp = make_temp_config((r"backend = .+", "backend = blah"))

    with tmpctx:
        try:
            with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
                aurweb.config.rehash()
                with pytest.raises(ValueError):
                    db.Connection()
        finally:
            # Runs after AUR_CONFIG has been restored, and also when the
            # expectation fails, so the temporary config is not left loaded
            # for later tests (assumes rehash() reloads from the
            # environment -- confirm against aurweb.config).
            aurweb.config.rehash()
|
|
|
|
|
|
def test_create_delete():
    """Create an AccountType row, see it via a query, then delete it."""

    def lookup():
        # First AccountType row whose AccountType column equals "test".
        return db.query(AccountType, AccountType.AccountType == "test").first()

    with db.begin():
        created = db.create(AccountType, AccountType="test")
    assert lookup() is not None

    with db.begin():
        db.delete(created)
    assert lookup() is None
|
|
|
|
|
|
def test_add_commit():
    """Add an AccountType through db.add and read it back with a query."""
    # Build the record first, then add it inside a db.begin() block.
    pending = AccountType(AccountType="test")
    with db.begin():
        db.add(pending)

    # A truthy ID is taken as evidence that the row was created.
    assert pending.ID

    # The queried row must compare equal to the object we added.
    fetched = db.query(AccountType, AccountType.AccountType == "test").first()
    assert fetched == pending

    # Clean up the temporary record.
    with db.begin():
        db.delete(pending)
|
|
|
|
|
|
def test_connection_executor_postgres_paramstyle():
    """A postgres ConnectionExecutor reports the "format" paramstyle."""
    style = db.ConnectionExecutor(None, backend="postgres").paramstyle()
    assert style == "format"
|
|
|
|
|
|
@mock.patch("sqlite3.paramstyle", "pyformat")
def test_connection_executor_sqlite_paramstyle():
    """A sqlite ConnectionExecutor reports sqlite3.paramstyle.

    sqlite3.paramstyle is patched to "pyformat" for the duration of the test.
    """
    style = db.ConnectionExecutor(None, backend="sqlite").paramstyle()
    assert style == sqlite3.paramstyle
|
|
|
|
|
|
def test_name_without_pytest_current_test():
    """With an emptied environment, db.name() is the configured name."""
    # clear=True empties os.environ for the duration of the block.
    with mock.patch.dict("os.environ", {}, clear=True):
        resolved = aurweb.db.name()
    assert resolved == aurweb.config.get("database", "name")
|
|
|
|
|
|
def test_retry_deadlock():
    """A wrapped function that always deadlocks still raises in the end."""

    @db.retry_deadlock
    def always_deadlocks():
        raise OperationalError("Deadlock found", (), "")

    with pytest.raises(OperationalError):
        always_deadlocks()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_retry_deadlock():
    """A wrapped coroutine that always deadlocks still raises in the end."""

    @db.async_retry_deadlock
    async def always_deadlocks():
        raise OperationalError("Deadlock found", (), "")

    with pytest.raises(OperationalError):
        await always_deadlocks()
|