diff --git a/aurweb/db.py b/aurweb/db.py index 500cf95a..590712e0 100644 --- a/aurweb/db.py +++ b/aurweb/db.py @@ -98,9 +98,11 @@ def get_sqlalchemy_url(): param_query = None else: port = None - param_query = {'unix_socket': aurweb.config.get('database', 'socket')} + param_query = { + 'unix_socket': aurweb.config.get('database', 'socket') + } return constructor( - 'mysql+mysqlconnector', + 'mysql+mysqldb', username=aurweb.config.get('database', 'user'), password=aurweb.config.get('database', 'password'), host=aurweb.config.get('database', 'host'), @@ -117,7 +119,7 @@ def get_sqlalchemy_url(): raise ValueError('unsupported database backend') -def get_engine(): +def get_engine(echo: bool = False): """ Return the global SQLAlchemy engine. @@ -135,13 +137,24 @@ def get_engine(): # check_same_thread is for a SQLite technicality # https://fastapi.tiangolo.com/tutorial/sql-databases/#note connect_args["check_same_thread"] = False - engine = create_engine(get_sqlalchemy_url(), connect_args=connect_args) + + engine = create_engine(get_sqlalchemy_url(), + connect_args=connect_args, + echo=echo) Session = sessionmaker(autocommit=False, autoflush=True, bind=engine) session = Session() return engine +def kill_engine(): + global engine, Session, session + if engine: + session.close() + engine.dispose() + engine = Session = session = None + + def connect(): """ Return an SQLAlchemy connection. Connections are usually pooled. See @@ -160,8 +173,7 @@ class ConnectionExecutor: def __init__(self, conn, backend=aurweb.config.get("database", "backend")): self._conn = conn if backend == "mysql": - import mysql.connector - self._paramstyle = mysql.connector.paramstyle + self._paramstyle = "format" elif backend == "sqlite": import sqlite3 self._paramstyle = sqlite3.paramstyle @@ -197,18 +209,17 @@ class Connection: aur_db_backend = aurweb.config.get('database', 'backend') if aur_db_backend == 'mysql': - import mysql.connector + import MySQLdb aur_db_host = aurweb.config.get('database', 'host') aur_db_name = aurweb.config.get('database', 'name') aur_db_user = aurweb.config.get('database', 'user') aur_db_pass = aurweb.config.get('database', 'password') aur_db_socket = aurweb.config.get('database', 'socket') - self._conn = mysql.connector.connect(host=aur_db_host, - user=aur_db_user, - passwd=aur_db_pass, - db=aur_db_name, - unix_socket=aur_db_socket, - buffered=True) + self._conn = MySQLdb.connect(host=aur_db_host, + user=aur_db_user, + passwd=aur_db_pass, + db=aur_db_name, + unix_socket=aur_db_socket) elif aur_db_backend == 'sqlite': import sqlite3 aur_db_name = aurweb.config.get('database', 'name') @@ -217,7 +228,7 @@ class Connection: else: raise ValueError('unsupported database backend') - self._conn = ConnectionExecutor(self._conn) + self._conn = ConnectionExecutor(self._conn, aur_db_backend) def execute(self, query, params=()): return self._conn.execute(query, params) diff --git a/aurweb/initdb.py b/aurweb/initdb.py index 5f55bfc9..46f079c0 100644 --- a/aurweb/initdb.py +++ b/aurweb/initdb.py @@ -2,7 +2,6 @@ import argparse import alembic.command import alembic.config -import sqlalchemy import aurweb.db import aurweb.schema @@ -34,6 +33,8 @@ def feed_initial_data(conn): def run(args): + aurweb.config.rehash() + # Ensure Alembic is fine before we do the real work, in order not to fail at # the last step and leave the database in an inconsistent state. The # configuration is loaded lazily, so we query it to force its loading. @@ -42,8 +43,7 @@ def run(args): alembic_config.get_main_option('script_location') alembic_config.attributes["configure_logger"] = False - engine = sqlalchemy.create_engine(aurweb.db.get_sqlalchemy_url(), - echo=(args.verbose >= 1)) + engine = aurweb.db.get_engine(echo=(args.verbose >= 1)) aurweb.schema.metadata.create_all(engine) feed_initial_data(engine.connect()) diff --git a/aurweb/models/accepted_term.py b/aurweb/models/accepted_term.py index 6e8ffe99..483109f1 100644 --- a/aurweb/models/accepted_term.py +++ b/aurweb/models/accepted_term.py @@ -1,3 +1,4 @@ +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import mapper from aurweb.db import make_relationship @@ -11,7 +12,19 @@ class AcceptedTerm: User: User = None, Term: Term = None, Revision: int = None): self.User = User + if not self.User: + raise IntegrityError( + statement="Foreign key UserID cannot be null.", + orig="AcceptedTerms.UserID", + params=("NULL")) + self.Term = Term + if not self.Term: + raise IntegrityError( + statement="Foreign key TermID cannot be null.", + orig="AcceptedTerms.TermID", + params=("NULL")) + self.Revision = Revision diff --git a/aurweb/models/api_rate_limit.py b/aurweb/models/api_rate_limit.py index 44e7a463..8b945b6a 100644 --- a/aurweb/models/api_rate_limit.py +++ b/aurweb/models/api_rate_limit.py @@ -1,3 +1,4 @@ +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import mapper from aurweb.schema import ApiRateLimit as _ApiRateLimit @@ -8,8 +9,20 @@ class ApiRateLimit: Requests: int = None, WindowStart: int = None): self.IP = IP + self.Requests = Requests + if self.Requests is None: + raise IntegrityError( + statement="Column Requests cannot be null.", + orig="ApiRateLimit.Requests", + params=("NULL")) + self.WindowStart = WindowStart + if self.WindowStart is None: + raise IntegrityError( + statement="Column WindowStart cannot be null.", + orig="ApiRateLimit.WindowStart", + params=("NULL")) mapper(ApiRateLimit, _ApiRateLimit, primary_key=[_ApiRateLimit.c.IP]) diff --git a/aurweb/models/group.py b/aurweb/models/group.py index 5d4f3834..c5583eb4 100644 --- a/aurweb/models/group.py +++ b/aurweb/models/group.py @@ -1,3 +1,4 @@ +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import mapper from aurweb.schema import Groups @@ -6,6 +7,11 @@ from aurweb.schema import Groups class Group: def __init__(self, Name: str = None): self.Name = Name + if not self.Name: + raise IntegrityError( + statement="Column Name cannot be null.", + orig="Groups.Name", + params=("NULL")) mapper(Group, Groups) diff --git a/aurweb/models/license.py b/aurweb/models/license.py index 1c174925..bcc02713 100644 --- a/aurweb/models/license.py +++ b/aurweb/models/license.py @@ -1,3 +1,4 @@ +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import mapper from aurweb.schema import Licenses @@ -6,6 +7,11 @@ from aurweb.schema import Licenses class License: def __init__(self, Name: str = None): self.Name = Name + if not self.Name: + raise IntegrityError( + statement="Column Name cannot be null.", + orig="Licenses.Name", + params=("NULL")) mapper(License, Licenses) diff --git a/aurweb/models/package.py b/aurweb/models/package.py index fa82bb74..28a13791 100644 --- a/aurweb/models/package.py +++ b/aurweb/models/package.py @@ -1,3 +1,4 @@ +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import mapper from aurweb.db import make_relationship @@ -11,7 +12,19 @@ class Package: Name: str = None, Version: str = None, Description: str = None, URL: str = None): self.PackageBase = PackageBase + if not self.PackageBase: + raise IntegrityError( + statement="Foreign key UserID cannot be null.", + orig="Packages.PackageBaseID", + params=("NULL")) + self.Name = Name + if not self.Name: + raise IntegrityError( + statement="Column Name cannot be null.", + orig="Packages.Name", + params=("NULL")) + self.Version = Version self.Description = Description self.URL = URL diff --git a/aurweb/models/package_base.py b/aurweb/models/package_base.py index 57e5a46b..699559d5 100644 --- a/aurweb/models/package_base.py +++ b/aurweb/models/package_base.py @@ -1,5 +1,6 @@ from datetime import datetime +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import mapper from aurweb.db import make_relationship @@ -12,6 +13,12 @@ class PackageBase: Maintainer: User = None, Submitter: User = None, Packager: User = None, **kwargs): self.Name = Name + if not self.Name: + raise IntegrityError( + statement="Column Name cannot be null.", + orig="PackageBases.Name", + params=("NULL")) + self.Flagger = Flagger self.Maintainer = Maintainer self.Submitter = Submitter diff --git a/aurweb/models/package_dependency.py b/aurweb/models/package_dependency.py index ae6ae62a..21801802 100644 --- a/aurweb/models/package_dependency.py +++ b/aurweb/models/package_dependency.py @@ -1,3 +1,4 @@ +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import mapper from aurweb.db import make_relationship @@ -12,8 +13,26 @@ class PackageDependency: DepName: str = None, DepDesc: str = None, DepCondition: str = None, DepArch: str = None): self.Package = Package + if not self.Package: + raise IntegrityError( + statement="Foreign key PackageID cannot be null.", + orig="PackageDependencies.PackageID", + params=("NULL")) + self.DependencyType = DependencyType - self.DepName = DepName # nullable=False + if not self.DependencyType: + raise IntegrityError( + statement="Foreign key DepTypeID cannot be null.", + orig="PackageDependencies.DepTypeID", + params=("NULL")) + + self.DepName = DepName + if not self.DepName: + raise IntegrityError( + statement="Column DepName cannot be null.", + orig="PackageDependencies.DepName", + params=("NULL")) + self.DepDesc = DepDesc self.DepCondition = DepCondition self.DepArch = DepArch diff --git a/aurweb/models/package_group.py b/aurweb/models/package_group.py index c155fe00..19a11c80 100644 --- a/aurweb/models/package_group.py +++ b/aurweb/models/package_group.py @@ -1,5 +1,5 @@ -from sqlalchemy.orm import mapper from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import mapper from aurweb.db import make_relationship from aurweb.models.group import Group diff --git a/aurweb/models/package_keyword.py b/aurweb/models/package_keyword.py index 4a66f38e..2bae223c 100644 --- a/aurweb/models/package_keyword.py +++ b/aurweb/models/package_keyword.py @@ -1,5 +1,5 @@ -from sqlalchemy.orm import mapper from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import mapper from aurweb.db import make_relationship from aurweb.models.package_base import PackageBase diff --git a/aurweb/models/package_license.py b/aurweb/models/package_license.py index 6f23f84a..491874a4 100644 --- a/aurweb/models/package_license.py +++ b/aurweb/models/package_license.py @@ -1,5 +1,5 @@ -from sqlalchemy.orm import mapper from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import mapper from aurweb.db import make_relationship from aurweb.models.license import License diff --git a/aurweb/models/package_relation.py b/aurweb/models/package_relation.py index 196f1dee..d9ade727 100644 --- a/aurweb/models/package_relation.py +++ b/aurweb/models/package_relation.py @@ -1,3 +1,4 @@ +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import mapper from aurweb.db import make_relationship @@ -12,8 +13,26 @@ class PackageRelation: RelName: str = None, RelCondition: str = None, RelArch: str = None): self.Package = Package + if not self.Package: + raise IntegrityError( + statement="Foreign key PackageID cannot be null.", + orig="PackageRelations.PackageID", + params=("NULL")) + self.RelationType = RelationType + if not self.RelationType: + raise IntegrityError( + statement="Foreign key RelTypeID cannot be null.", + orig="PackageRelations.RelTypeID", + params=("NULL")) + self.RelName = RelName # nullable=False + if not self.RelName: + raise IntegrityError( + statement="Column RelName cannot be null.", + orig="PackageRelations.RelName", + params=("NULL")) + self.RelCondition = RelCondition self.RelArch = RelArch diff --git a/aurweb/models/session.py b/aurweb/models/session.py index 60749303..f1e0fff5 100644 --- a/aurweb/models/session.py +++ b/aurweb/models/session.py @@ -1,16 +1,20 @@ -from sqlalchemy import Column, Integer +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import backref, mapper, relationship -from aurweb.db import make_random_value +from aurweb.db import make_random_value, query from aurweb.models.user import User from aurweb.schema import Sessions class Session: - UsersID = Column(Integer, nullable=True) - def __init__(self, **kwargs): self.UsersID = kwargs.get("UsersID") + if not query(User, User.ID == self.UsersID).first(): + raise IntegrityError( + statement="Foreign key UsersID cannot be null.", + orig="Sessions.UsersID", + params=("NULL")) + self.SessionID = kwargs.get("SessionID") self.LastUpdateTS = kwargs.get("LastUpdateTS") diff --git a/aurweb/models/term.py b/aurweb/models/term.py index 1b4902f7..1a0780df 100644 --- a/aurweb/models/term.py +++ b/aurweb/models/term.py @@ -1,3 +1,4 @@ +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import mapper from aurweb.schema import Terms @@ -8,7 +9,19 @@ class Term: Description: str = None, URL: str = None, Revision: int = None): self.Description = Description + if not self.Description: + raise IntegrityError( + statement="Column Description cannot be null.", + orig="Terms.Description", + params=("NULL")) + self.URL = URL + if not self.URL: + raise IntegrityError( + statement="Column URL cannot be null.", + orig="Terms.URL", + params=("NULL")) + self.Revision = Revision diff --git a/conf/config.dev b/conf/config.dev index 94775a92..45d940e6 100644 --- a/conf/config.dev +++ b/conf/config.dev @@ -6,17 +6,19 @@ ; development-specific options too. [database] -backend = sqlite -name = YOUR_AUR_ROOT/aurweb.sqlite3 +; Options: mysql, sqlite. +backend = mysql -; Alternative MySQL configuration (Use either port of socket, if both defined port takes priority) -;backend = mysql -;name = aurweb -;user = aur -;password = aur -;host = localhost +; If using sqlite, set name to the database file path. +name = aurweb + +; MySQL database information. User defaults to root for containerized +; testing with mysqldb. This should be set to a non-root user. +user = root +;password = non-root-user-password +host = localhost ;port = 3306 -;socket = /var/run/mysqld/mysqld.sock +socket = /var/run/mysqld/mysqld.sock [options] aurwebdir = YOUR_AUR_ROOT diff --git a/test/Makefile b/test/Makefile index 060e57c2..920c7113 100644 --- a/test/Makefile +++ b/test/Makefile @@ -8,7 +8,7 @@ MAKEFLAGS = -j1 check: sh pytest pytest: - cd .. && AUR_CONFIG=conf/config coverage run --append /usr/bin/pytest test + cd .. && coverage run --append /usr/bin/pytest test ifdef PROVE sh: diff --git a/test/test_accounts_routes.py b/test/test_accounts_routes.py index 0f813823..3080a505 100644 --- a/test/test_accounts_routes.py +++ b/test/test_accounts_routes.py @@ -802,18 +802,40 @@ def test_post_account_edit_ssh_pub_key(): assert response.status_code == int(HTTPStatus.OK) # Now let's update what's already there to gain coverage over that path. - pk = str() - with tempfile.TemporaryDirectory() as tmpdir: - with open("/dev/null", "w") as null: - proc = Popen(["ssh-keygen", "-f", f"{tmpdir}/test.ssh", "-N", ""], - stdout=null, stderr=null) - proc.wait() - assert proc.returncode == 0 + post_data["PK"] = make_ssh_pubkey() - # Read in the public key, then delete the temp dir we made. - pk = open(f"{tmpdir}/test.ssh.pub").read().rstrip() + with client as request: + response = request.post("/account/test/edit", cookies={ + "AURSID": sid + }, data=post_data, allow_redirects=False) - post_data["PK"] = pk + assert response.status_code == int(HTTPStatus.OK) + + +def test_post_account_edit_missing_ssh_pubkey(): + request = Request() + sid = user.login(request, "testPassword") + + post_data = { + "U": user.Username, + "E": user.Email, + "PK": make_ssh_pubkey(), + "passwd": "testPassword" + } + + with client as request: + response = request.post("/account/test/edit", cookies={ + "AURSID": sid + }, data=post_data, allow_redirects=False) + + assert response.status_code == int(HTTPStatus.OK) + + post_data = { + "U": user.Username, + "E": user.Email, + "PK": str(), # Pass an empty string now to walk the delete path. + "passwd": "testPassword" + } with client as request: response = request.post("/account/test/edit", cookies={ diff --git a/test/test_api_rate_limit.py b/test/test_api_rate_limit.py index c599ddcf..536e3841 100644 --- a/test/test_api_rate_limit.py +++ b/test/test_api_rate_limit.py @@ -34,5 +34,5 @@ def test_api_rate_key_null_requests_raises_exception(): def test_api_rate_key_null_window_start_raises_exception(): from aurweb.db import session with pytest.raises(IntegrityError): - create(ApiRateLimit, IP="127.0.0.1", WindowStart=1) + create(ApiRateLimit, IP="127.0.0.1", Requests=1) session.rollback() diff --git a/test/test_auth.py b/test/test_auth.py index 7837e7f7..42eac040 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -4,6 +4,8 @@ import pytest from starlette.authentication import AuthenticationError +import aurweb.config + from aurweb.auth import BasicAuthBackend, has_credential from aurweb.db import create, query from aurweb.models.account_type import AccountType @@ -53,13 +55,12 @@ async def test_auth_backend_invalid_sid(): async def test_auth_backend_invalid_user_id(): # Create a new session with a fake user id. now_ts = datetime.utcnow().timestamp() - create(Session, UsersID=666, SessionID="realSession", - LastUpdateTS=now_ts + 5) + db_backend = aurweb.config.get("database", "backend") + with pytest.raises(IntegrityError): + create(Session, UsersID=666, SessionID="realSession", + LastUpdateTS=now_ts + 5) - # Here, we specify a real SID; but it's user is not there. - request.cookies["AURSID"] = "realSession" - with pytest.raises(AuthenticationError, match="Invalid User ID: 666"): - await backend.authenticate(request) + session.rollback() @pytest.mark.asyncio diff --git a/test/test_ban.py b/test/test_ban.py index a4fa5a28..b728644b 100644 --- a/test/test_ban.py +++ b/test/test_ban.py @@ -33,8 +33,7 @@ def test_ban(): def test_invalid_ban(): from aurweb.db import session - with pytest.raises(sa_exc.IntegrityError, - match="NOT NULL constraint failed: Bans.IPAddress"): + with pytest.raises(sa_exc.IntegrityError): bad_ban = Ban(BanTS=datetime.utcnow()) session.add(bad_ban) diff --git a/test/test_db.py b/test/test_db.py index e0946ed5..3911134f 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -5,16 +5,22 @@ import tempfile from unittest import mock -import mysql.connector import pytest import aurweb.config +import aurweb.initdb from aurweb import db from aurweb.models.account_type import AccountType from aurweb.testing import setup_test_db +class Args: + """ Stub arguments used for running aurweb.initdb. """ + use_alembic = True + verbose = True + + class DBCursor: """ A fake database cursor object used in tests. """ items = [] @@ -38,27 +44,73 @@ class DBConnection: pass +def make_temp_config(config_file, *replacements): + """ Generate a temporary config file with a set of replacements. + + :param *replacements: A variable number of tuple regex replacement pairs + :return: A tuple containing (temp directory, temp config file) + """ + tmpdir = tempfile.TemporaryDirectory() + tmp = os.path.join(tmpdir.name, "config.tmp") + with open(config_file) as f: + config = f.read() + for repl in list(replacements): + config = re.sub(repl[0], repl[1], config) + with open(tmp, "w") as o: + o.write(config) + aurwebdir = aurweb.config.get("options", "aurwebdir") + defaults = os.path.join(aurwebdir, "conf/config.defaults") + with open(defaults) as i: + with open(f"{tmp}.defaults", "w") as o: + o.write(i.read()) + return tmpdir, tmp + + +def make_temp_sqlite_config(config_file): + return make_temp_config(config_file, + (r"backend = .*", "backend = sqlite"), + (r"name = .*", "name = /tmp/aurweb.sqlite3")) + + +def make_temp_mysql_config(config_file): + return make_temp_config(config_file, + (r"backend = .*", "backend = mysql"), + (r"name = .*", "name = aurweb")) + + @pytest.fixture(autouse=True) def setup_db(): - setup_test_db("Bans") + if os.path.exists("/tmp/aurweb.sqlite3"): + os.remove("/tmp/aurweb.sqlite3") + + # In various places in this test, we reinitialize the engine. + # Make sure we kill the previous engine before initializing + # it via setup_test_db(). + aurweb.db.kill_engine() + setup_test_db() def test_sqlalchemy_sqlite_url(): - with mock.patch.dict(os.environ, {"AUR_CONFIG": "conf/config.dev"}): - aurweb.config.rehash() - assert db.get_sqlalchemy_url() + tmpctx, tmp = make_temp_sqlite_config("conf/config") + with tmpctx: + with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): + aurweb.config.rehash() + assert db.get_sqlalchemy_url() aurweb.config.rehash() def test_sqlalchemy_mysql_url(): - with mock.patch.dict(os.environ, {"AUR_CONFIG": "conf/config.defaults"}): - aurweb.config.rehash() - assert db.get_sqlalchemy_url() + tmpctx, tmp = make_temp_mysql_config("conf/config") + with tmpctx: + with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): + aurweb.config.rehash() + assert db.get_sqlalchemy_url() aurweb.config.rehash() def test_sqlalchemy_mysql_port_url(): - tmpctx, tmp = make_temp_config("conf/config.defaults", ";port = 3306", "port = 3306") + tmpctx, tmp = make_temp_config("conf/config", + (r";port = 3306", "port = 3306")) with tmpctx: with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): @@ -67,18 +119,9 @@ def test_sqlalchemy_mysql_port_url(): aurweb.config.rehash() -def make_temp_config(config_file, src_str, replace_with): - tmpdir = tempfile.TemporaryDirectory() - tmp = os.path.join(tmpdir.name, "config.tmp") - with open(config_file) as f: - config = re.sub(src_str, f'{replace_with}', f.read()) - with open(tmp, "w") as o: - o.write(config) - return tmpdir, tmp - - def test_sqlalchemy_unknown_backend(): - tmpctx, tmp = make_temp_config("conf/config", "backend = sqlite", "backend = blah") + tmpctx, tmp = make_temp_config("conf/config", + (r"backend = mysql", "backend = blah")) with tmpctx: with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): @@ -89,22 +132,31 @@ def test_sqlalchemy_unknown_backend(): def test_db_connects_without_fail(): + """ This only tests the actual config supplied to pytest. """ db.connect() assert db.engine is not None -def test_connection_class_without_fail(): - conn = db.Connection() +def test_connection_class_sqlite_without_fail(): + tmpctx, tmp = make_temp_sqlite_config("conf/config") + with tmpctx: + with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): + aurweb.config.rehash() - cur = conn.execute( - "SELECT AccountType FROM AccountTypes WHERE ID = ?", (1,)) - account_type = cur.fetchone()[0] + aurweb.db.kill_engine() + aurweb.initdb.run(Args()) - assert account_type == "User" + conn = db.Connection() + cur = conn.execute( + "SELECT AccountType FROM AccountTypes WHERE ID = ?", (1,)) + account_type = cur.fetchone()[0] + assert account_type == "User" + aurweb.config.rehash() def test_connection_class_unsupported_backend(): - tmpctx, tmp = make_temp_config("conf/config", "backend = sqlite", "backend = blah") + tmpctx, tmp = make_temp_config("conf/config", + (r"backend = mysql", "backend = blah")) with tmpctx: with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): @@ -114,10 +166,9 @@ def test_connection_class_unsupported_backend(): aurweb.config.rehash() -@mock.patch("mysql.connector.connect", mock.MagicMock(return_value=True)) -@mock.patch.object(mysql.connector, "paramstyle", "qmark") +@mock.patch("MySQLdb.connect", mock.MagicMock(return_value=True)) def test_connection_mysql(): - tmpctx, tmp = make_temp_config("conf/config", "backend = sqlite", "backend = mysql") + tmpctx, tmp = make_temp_mysql_config("conf/config") with tmpctx: with mock.patch.dict(os.environ, { "AUR_CONFIG": tmp, @@ -137,44 +188,78 @@ def test_connection_sqlite(): @mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection())) @mock.patch.object(sqlite3, "paramstyle", "format") def test_connection_execute_paramstyle_format(): - conn = db.Connection() + tmpctx, tmp = make_temp_sqlite_config("conf/config") - # First, test ? to %s format replacement. - account_types = conn\ - .execute("SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"])\ - .fetchall() - assert account_types == \ - ["SELECT * FROM AccountTypes WHERE AccountType = %s", ["User"]] + with tmpctx: + with mock.patch.dict(os.environ, { + "AUR_CONFIG": tmp, + "AUR_CONFIG_DEFAULTS": "conf/config.defaults" + }): + aurweb.config.rehash() - # Test other format replacement. - account_types = conn\ - .execute("SELECT * FROM AccountTypes WHERE AccountType = %", ["User"])\ - .fetchall() - assert account_types == \ - ["SELECT * FROM AccountTypes WHERE AccountType = %%", ["User"]] + aurweb.db.kill_engine() + aurweb.initdb.run(Args()) + + conn = db.Connection() + + # First, test ? to %s format replacement. + account_types = conn\ + .execute("SELECT * FROM AccountTypes WHERE AccountType = ?", + ["User"]).fetchall() + assert account_types == \ + ["SELECT * FROM AccountTypes WHERE AccountType = %s", ["User"]] + + # Test other format replacement. + account_types = conn\ + .execute("SELECT * FROM AccountTypes WHERE AccountType = %", + ["User"]).fetchall() + assert account_types == \ + ["SELECT * FROM AccountTypes WHERE AccountType = %%", ["User"]] + aurweb.config.rehash() @mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection())) @mock.patch.object(sqlite3, "paramstyle", "qmark") def test_connection_execute_paramstyle_qmark(): - conn = db.Connection() - # We don't modify anything when using qmark, so test equality. - account_types = conn\ - .execute("SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"])\ - .fetchall() - assert account_types == \ - ["SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"]] + tmpctx, tmp = make_temp_sqlite_config("conf/config") + + with tmpctx: + with mock.patch.dict(os.environ, { + "AUR_CONFIG": tmp, + "AUR_CONFIG_DEFAULTS": "conf/config.defaults" + }): + aurweb.config.rehash() + + aurweb.db.kill_engine() + aurweb.initdb.run(Args()) + + conn = db.Connection() + # We don't modify anything when using qmark, so test equality. + account_types = conn\ + .execute("SELECT * FROM AccountTypes WHERE AccountType = ?", + ["User"]).fetchall() + assert account_types == \ + ["SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"]] + aurweb.config.rehash() @mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection())) @mock.patch.object(sqlite3, "paramstyle", "unsupported") def test_connection_execute_paramstyle_unsupported(): - conn = db.Connection() - with pytest.raises(ValueError, match="unsupported paramstyle"): - conn.execute( - "SELECT * FROM AccountTypes WHERE AccountType = ?", - ["User"] - ).fetchall() + tmpctx, tmp = make_temp_sqlite_config("conf/config") + with tmpctx: + with mock.patch.dict(os.environ, { + "AUR_CONFIG": tmp, + "AUR_CONFIG_DEFAULTS": "conf/config.defaults" + }): + aurweb.config.rehash() + conn = db.Connection() + with pytest.raises(ValueError, match="unsupported paramstyle"): + conn.execute( + "SELECT * FROM AccountTypes WHERE AccountType = ?", + ["User"] + ).fetchall() + aurweb.config.rehash() def test_create_delete(): @@ -186,13 +271,12 @@ def test_create_delete(): assert record is None -@mock.patch("mysql.connector.paramstyle", "qmark") def test_connection_executor_mysql_paramstyle(): executor = db.ConnectionExecutor(None, backend="mysql") - assert executor.paramstyle() == "qmark" + assert executor.paramstyle() == "format" @mock.patch("sqlite3.paramstyle", "pyformat") def test_connection_executor_sqlite_paramstyle(): executor = db.ConnectionExecutor(None, backend="sqlite") - assert executor.paramstyle() == "pyformat" + assert executor.paramstyle() == sqlite3.paramstyle diff --git a/test/test_initdb.py b/test/test_initdb.py index eae33007..c7d29ee2 100644 --- a/test/test_initdb.py +++ b/test/test_initdb.py @@ -1,27 +1,19 @@ -import pytest - import aurweb.config import aurweb.db import aurweb.initdb from aurweb.models.account_type import AccountType -from aurweb.schema import metadata -from aurweb.testing import setup_test_db -@pytest.fixture(autouse=True) -def setup(): - setup_test_db() - - tables = metadata.tables.keys() - for table in tables: - aurweb.db.session.execute(f"DROP TABLE IF EXISTS {table}") +class Args: + use_alembic = True + verbose = True def test_run(): - class Args: - use_alembic = True - verbose = False + from aurweb.schema import metadata + aurweb.db.kill_engine() + metadata.drop_all(aurweb.db.get_engine()) aurweb.initdb.run(Args()) record = aurweb.db.query(AccountType, AccountType.AccountType == "User").first() diff --git a/test/test_package_relation.py b/test/test_package_relation.py index dd0455cd..96932f40 100644 --- a/test/test_package_relation.py +++ b/test/test_package_relation.py @@ -1,6 +1,6 @@ import pytest -from sqlalchemy.exc import IntegrityError +from sqlalchemy.exc import IntegrityError, OperationalError from aurweb.db import create, query from aurweb.models.account_type import AccountType @@ -36,7 +36,7 @@ def setup(): URL="https://test.package") -def test_package_dependencies(): +def test_package_relation(): conflicts = query(RelationType, RelationType.Name == "conflicts").first() pkgrel = create(PackageRelation, Package=package, RelationType=conflicts, @@ -68,10 +68,12 @@ def test_package_dependencies(): assert pkgrel in package.package_relations -def test_package_dependencies_null_package_raises_exception(): +def test_package_relation_null_package_raises_exception(): from aurweb.db import session conflicts = query(RelationType, RelationType.Name == "conflicts").first() + assert conflicts is not None + with pytest.raises(IntegrityError): create(PackageRelation, RelationType=conflicts, @@ -79,7 +81,7 @@ def test_package_dependencies_null_package_raises_exception(): session.rollback() -def test_package_dependencies_null_dependency_type_raises_exception(): +def test_package_relation_null_relation_type_raises_exception(): from aurweb.db import session with pytest.raises(IntegrityError): @@ -89,11 +91,13 @@ def test_package_dependencies_null_dependency_type_raises_exception(): session.rollback() -def test_package_dependencies_null_depname_raises_exception(): +def test_package_relation_null_relname_raises_exception(): from aurweb.db import session - depends = query(RelationType, RelationType.Name == "depends").first() - with pytest.raises(IntegrityError): + depends = query(RelationType, RelationType.Name == "conflicts").first() + assert depends is not None + + with pytest.raises((OperationalError, IntegrityError)): create(PackageRelation, Package=package, RelationType=depends) diff --git a/test/test_session.py b/test/test_session.py index 2877ea7f..c324a739 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -25,7 +25,7 @@ def setup(): ResetKey="testReset", Passwd="testPassword", AccountType=account_type) session = create(Session, UsersID=user.ID, SessionID="testSession", - LastUpdateTS=datetime.utcnow()) + LastUpdateTS=datetime.utcnow().timestamp()) def test_session():