diff --git a/aurweb/asgi.py b/aurweb/asgi.py index 00d7c595..b6e15582 100644 --- a/aurweb/asgi.py +++ b/aurweb/asgi.py @@ -7,30 +7,36 @@ from starlette.middleware.sessions import SessionMiddleware import aurweb.config +from aurweb.db import get_engine from aurweb.routers import html, sso routes = set() # Setup the FastAPI app. app = FastAPI() -app.mount("/static/css", - StaticFiles(directory="web/html/css"), - name="static_css") -app.mount("/static/js", - StaticFiles(directory="web/html/js"), - name="static_js") -app.mount("/static/images", - StaticFiles(directory="web/html/images"), - name="static_images") -session_secret = aurweb.config.get("fastapi", "session_secret") -if not session_secret: - raise Exception("[fastapi] session_secret must not be empty") -app.add_middleware(SessionMiddleware, secret_key=session_secret) +@app.on_event("startup") +async def app_startup(): + session_secret = aurweb.config.get("fastapi", "session_secret") + if not session_secret: + raise Exception("[fastapi] session_secret must not be empty") -app.include_router(sso.router) -app.include_router(html.router) + app.mount("/static/css", + StaticFiles(directory="web/html/css"), + name="static_css") + app.mount("/static/js", + StaticFiles(directory="web/html/js"), + name="static_js") + app.mount("/static/images", + StaticFiles(directory="web/html/images"), + name="static_images") + + app.add_middleware(SessionMiddleware, secret_key=session_secret) + app.include_router(sso.router) + app.include_router(html.router) + + get_engine() # NOTE: Always keep this dictionary updated with all routes # that the application contains. We use this to check for diff --git a/aurweb/config.py b/aurweb/config.py index 020c3b80..49a2765a 100644 --- a/aurweb/config.py +++ b/aurweb/config.py @@ -25,6 +25,13 @@ def _get_parser(): return _parser +def rehash(): + """ Globally rehash the configuration parser. """ + global _parser + _parser = None + _get_parser() + + def get(section, option): return _get_parser().get(section, option) diff --git a/aurweb/db.py b/aurweb/db.py index 04b40f43..7993dfdb 100644 --- a/aurweb/db.py +++ b/aurweb/db.py @@ -1,19 +1,15 @@ import math -try: - import mysql.connector -except ImportError: - pass - -try: - import sqlite3 -except ImportError: - pass - import aurweb.config engine = None # See get_engine +# ORM Session class. +Session = None + +# Global ORM Session object. +session = None + def get_sqlalchemy_url(): """ @@ -49,14 +45,15 @@ def get_engine(): `engine` global variable for the next calls. """ from sqlalchemy import create_engine - global engine + from sqlalchemy.orm import sessionmaker + + global engine, session, Session + if engine is None: - connect_args = dict() - if aurweb.config.get("database", "backend") == "sqlite": - # check_same_thread is for a SQLite technicality - # https://fastapi.tiangolo.com/tutorial/sql-databases/#note - connect_args["check_same_thread"] = False - engine = create_engine(get_sqlalchemy_url(), connect_args=connect_args) + engine = create_engine(get_sqlalchemy_url(), + # check_same_thread is for a SQLite technicality + # https://fastapi.tiangolo.com/tutorial/sql-databases/#note + connect_args={"check_same_thread": False}) Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) session = Session() @@ -82,6 +79,7 @@ class Connection: aur_db_backend = aurweb.config.get('database', 'backend') if aur_db_backend == 'mysql': + import mysql.connector aur_db_host = aurweb.config.get('database', 'host') aur_db_name = aurweb.config.get('database', 'name') aur_db_user = aurweb.config.get('database', 'user') @@ -95,6 +93,7 @@ class Connection: buffered=True) self._paramstyle = mysql.connector.paramstyle elif aur_db_backend == 'sqlite': + import sqlite3 aur_db_name = aurweb.config.get('database', 'name') self._conn = sqlite3.connect(aur_db_name) self._conn.create_function("POWER", 2, math.pow) diff --git a/test/test_asgi.py b/test/test_asgi.py new file mode 100644 index 00000000..79b34daf --- /dev/null +++ b/test/test_asgi.py @@ -0,0 +1,29 @@ +import http +import os + +from unittest import mock + +import pytest + +from fastapi import HTTPException + +import aurweb.asgi +import aurweb.config + + +@pytest.mark.asyncio +async def test_asgi_startup_exception(monkeypatch): + with mock.patch.dict(os.environ, {"AUR_CONFIG": "conf/config.defaults"}): + aurweb.config.rehash() + with pytest.raises(Exception): + await aurweb.asgi.app_startup() + aurweb.config.rehash() + + +@pytest.mark.asyncio +async def test_asgi_http_exception_handler(): + exc = HTTPException(status_code=422, detail="EXCEPTION!") + phrase = http.HTTPStatus(exc.status_code).phrase + response = await aurweb.asgi.http_exception_handler(None, exc) + assert response.body.decode() == \ + f"
{exc.detail}
" diff --git a/test/test_config.py b/test/test_config.py new file mode 100644 index 00000000..4f10b60d --- /dev/null +++ b/test/test_config.py @@ -0,0 +1,13 @@ +from aurweb import config + + +def test_get(): + assert config.get("options", "disable_http_login") == "0" + + +def test_getboolean(): + assert not config.getboolean("options", "disable_http_login") + + +def test_getint(): + assert config.getint("options", "disable_http_login") == 0 diff --git a/test/test_db.py b/test/test_db.py new file mode 100644 index 00000000..0a134541 --- /dev/null +++ b/test/test_db.py @@ -0,0 +1,174 @@ +import os +import re +import sqlite3 +import tempfile + +from unittest import mock + +import mysql.connector +import pytest + +import aurweb.config + +from aurweb import db +from aurweb.testing import setup_test_db + + +class DBCursor: + """ A fake database cursor object used in tests. """ + items = [] + + def execute(self, *args, **kwargs): + self.items = list(args) + return self + + def fetchall(self): + return self.items + + +class DBConnection: + """ A fake database connection object used in tests. """ + @staticmethod + def cursor(): + return DBCursor() + + @staticmethod + def create_function(name, num_args, func): + pass + + +@pytest.fixture(autouse=True) +def setup_db(): + setup_test_db() + + +def test_sqlalchemy_sqlite_url(): + with mock.patch.dict(os.environ, {"AUR_CONFIG": "conf/config.dev"}): + aurweb.config.rehash() + assert db.get_sqlalchemy_url() + aurweb.config.rehash() + + +def test_sqlalchemy_mysql_url(): + with mock.patch.dict(os.environ, {"AUR_CONFIG": "conf/config.defaults"}): + aurweb.config.rehash() + assert db.get_sqlalchemy_url() + aurweb.config.rehash() + + +def make_temp_config(backend): + if not os.path.isdir("/tmp"): + os.mkdir("/tmp") + tmpdir = tempfile.mkdtemp() + tmp = os.path.join(tmpdir, "config.tmp") + with open("conf/config") as f: + config = re.sub(r'backend = sqlite', f'backend = {backend}', f.read()) + with open(tmp, "w") as o: + o.write(config) + return (tmpdir, tmp) + + +def test_sqlalchemy_unknown_backend(): + tmpdir, tmp = make_temp_config("blah") + + with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): + aurweb.config.rehash() + with pytest.raises(ValueError): + db.get_sqlalchemy_url() + aurweb.config.rehash() + + os.remove(tmp) + os.removedirs(tmpdir) + + +def test_db_connects_without_fail(): + db.connect() + assert db.engine is not None + + +def test_connection_class_without_fail(): + conn = db.Connection() + + cur = conn.execute( + "SELECT AccountType FROM AccountTypes WHERE ID = ?", (1,)) + account_type = cur.fetchone()[0] + + assert account_type == "User" + + +def test_connection_class_unsupported_backend(): + tmpdir, tmp = make_temp_config("blah") + + with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): + aurweb.config.rehash() + with pytest.raises(ValueError): + db.Connection() + aurweb.config.rehash() + + os.remove(tmp) + os.removedirs(tmpdir) + + +@mock.patch("mysql.connector.connect", mock.MagicMock(return_value=True)) +@mock.patch.object(mysql.connector, "paramstyle", "qmark") +def test_connection_mysql(): + tmpdir, tmp = make_temp_config("mysql") + with mock.patch.dict(os.environ, { + "AUR_CONFIG": tmp, + "AUR_CONFIG_DEFAULTS": "conf/config.defaults" + }): + aurweb.config.rehash() + db.Connection() + aurweb.config.rehash() + + os.remove(tmp) + os.removedirs(tmpdir) + + +@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection())) +@mock.patch.object(sqlite3, "paramstyle", "qmark") +def test_connection_sqlite(): + db.Connection() + + +@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection())) +@mock.patch.object(sqlite3, "paramstyle", "format") +def test_connection_execute_paramstyle_format(): + conn = db.Connection() + + # First, test ? to %s format replacement. + account_types = conn\ + .execute("SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"])\ + .fetchall() + assert account_types == \ + ["SELECT * FROM AccountTypes WHERE AccountType = %s", ["User"]] + + # Test other format replacement. + account_types = conn\ + .execute("SELECT * FROM AccountTypes WHERE AccountType = %", ["User"])\ + .fetchall() + assert account_types == \ + ["SELECT * FROM AccountTypes WHERE AccountType = %%", ["User"]] + + +@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection())) +@mock.patch.object(sqlite3, "paramstyle", "qmark") +def test_connection_execute_paramstyle_qmark(): + conn = db.Connection() + # We don't modify anything when using qmark, so test equality. + account_types = conn\ + .execute("SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"])\ + .fetchall() + assert account_types == \ + ["SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"]] + + +@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection())) +@mock.patch.object(sqlite3, "paramstyle", "unsupported") +def test_connection_execute_paramstyle_unsupported(): + conn = db.Connection() + with pytest.raises(ValueError, match="unsupported paramstyle"): + conn.execute( + "SELECT * FROM AccountTypes WHERE AccountType = ?", + ["User"] + ).fetchall()