add aurweb.db.session

+ Added Session class and global session object to aurweb.db,
  these are sessions created by sqlalchemy ORM's sessionmaker
  and will allow us to use declarative/imperative models.

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2021-03-29 15:20:23 -07:00
parent 7c65604dad
commit 4238a9fc68
6 changed files with 260 additions and 32 deletions

View file

@ -7,30 +7,36 @@ from starlette.middleware.sessions import SessionMiddleware
import aurweb.config import aurweb.config
from aurweb.db import get_engine
from aurweb.routers import html, sso from aurweb.routers import html, sso
routes = set() routes = set()
# Setup the FastAPI app. # Setup the FastAPI app.
app = FastAPI() app = FastAPI()
app.mount("/static/css",
@app.on_event("startup")
async def app_startup():
session_secret = aurweb.config.get("fastapi", "session_secret")
if not session_secret:
raise Exception("[fastapi] session_secret must not be empty")
app.mount("/static/css",
StaticFiles(directory="web/html/css"), StaticFiles(directory="web/html/css"),
name="static_css") name="static_css")
app.mount("/static/js", app.mount("/static/js",
StaticFiles(directory="web/html/js"), StaticFiles(directory="web/html/js"),
name="static_js") name="static_js")
app.mount("/static/images", app.mount("/static/images",
StaticFiles(directory="web/html/images"), StaticFiles(directory="web/html/images"),
name="static_images") name="static_images")
session_secret = aurweb.config.get("fastapi", "session_secret") app.add_middleware(SessionMiddleware, secret_key=session_secret)
if not session_secret: app.include_router(sso.router)
raise Exception("[fastapi] session_secret must not be empty") app.include_router(html.router)
app.add_middleware(SessionMiddleware, secret_key=session_secret) get_engine()
app.include_router(sso.router)
app.include_router(html.router)
# NOTE: Always keep this dictionary updated with all routes # NOTE: Always keep this dictionary updated with all routes
# that the application contains. We use this to check for # that the application contains. We use this to check for

View file

@ -25,6 +25,13 @@ def _get_parser():
return _parser return _parser
def rehash():
""" Globally rehash the configuration parser. """
global _parser
_parser = None
_get_parser()
def get(section, option): def get(section, option):
return _get_parser().get(section, option) return _get_parser().get(section, option)

View file

@ -1,19 +1,15 @@
import math import math
try:
import mysql.connector
except ImportError:
pass
try:
import sqlite3
except ImportError:
pass
import aurweb.config import aurweb.config
engine = None # See get_engine engine = None # See get_engine
# ORM Session class.
Session = None
# Global ORM Session object.
session = None
def get_sqlalchemy_url(): def get_sqlalchemy_url():
""" """
@ -49,14 +45,15 @@ def get_engine():
`engine` global variable for the next calls. `engine` global variable for the next calls.
""" """
from sqlalchemy import create_engine from sqlalchemy import create_engine
global engine from sqlalchemy.orm import sessionmaker
global engine, session, Session
if engine is None: if engine is None:
connect_args = dict() engine = create_engine(get_sqlalchemy_url(),
if aurweb.config.get("database", "backend") == "sqlite":
# check_same_thread is for a SQLite technicality # check_same_thread is for a SQLite technicality
# https://fastapi.tiangolo.com/tutorial/sql-databases/#note # https://fastapi.tiangolo.com/tutorial/sql-databases/#note
connect_args["check_same_thread"] = False connect_args={"check_same_thread": False})
engine = create_engine(get_sqlalchemy_url(), connect_args=connect_args)
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session() session = Session()
@ -82,6 +79,7 @@ class Connection:
aur_db_backend = aurweb.config.get('database', 'backend') aur_db_backend = aurweb.config.get('database', 'backend')
if aur_db_backend == 'mysql': if aur_db_backend == 'mysql':
import mysql.connector
aur_db_host = aurweb.config.get('database', 'host') aur_db_host = aurweb.config.get('database', 'host')
aur_db_name = aurweb.config.get('database', 'name') aur_db_name = aurweb.config.get('database', 'name')
aur_db_user = aurweb.config.get('database', 'user') aur_db_user = aurweb.config.get('database', 'user')
@ -95,6 +93,7 @@ class Connection:
buffered=True) buffered=True)
self._paramstyle = mysql.connector.paramstyle self._paramstyle = mysql.connector.paramstyle
elif aur_db_backend == 'sqlite': elif aur_db_backend == 'sqlite':
import sqlite3
aur_db_name = aurweb.config.get('database', 'name') aur_db_name = aurweb.config.get('database', 'name')
self._conn = sqlite3.connect(aur_db_name) self._conn = sqlite3.connect(aur_db_name)
self._conn.create_function("POWER", 2, math.pow) self._conn.create_function("POWER", 2, math.pow)

29
test/test_asgi.py Normal file
View file

@ -0,0 +1,29 @@
import http
import os
from unittest import mock
import pytest
from fastapi import HTTPException
import aurweb.asgi
import aurweb.config
@pytest.mark.asyncio
async def test_asgi_startup_exception(monkeypatch):
with mock.patch.dict(os.environ, {"AUR_CONFIG": "conf/config.defaults"}):
aurweb.config.rehash()
with pytest.raises(Exception):
await aurweb.asgi.app_startup()
aurweb.config.rehash()
@pytest.mark.asyncio
async def test_asgi_http_exception_handler():
exc = HTTPException(status_code=422, detail="EXCEPTION!")
phrase = http.HTTPStatus(exc.status_code).phrase
response = await aurweb.asgi.http_exception_handler(None, exc)
assert response.body.decode() == \
f"<h1>{exc.status_code} {phrase}</h1><p>{exc.detail}</p>"

13
test/test_config.py Normal file
View file

@ -0,0 +1,13 @@
from aurweb import config
def test_get():
assert config.get("options", "disable_http_login") == "0"
def test_getboolean():
assert not config.getboolean("options", "disable_http_login")
def test_getint():
assert config.getint("options", "disable_http_login") == 0

174
test/test_db.py Normal file
View file

@ -0,0 +1,174 @@
import os
import re
import sqlite3
import tempfile
from unittest import mock
import mysql.connector
import pytest
import aurweb.config
from aurweb import db
from aurweb.testing import setup_test_db
class DBCursor:
""" A fake database cursor object used in tests. """
items = []
def execute(self, *args, **kwargs):
self.items = list(args)
return self
def fetchall(self):
return self.items
class DBConnection:
""" A fake database connection object used in tests. """
@staticmethod
def cursor():
return DBCursor()
@staticmethod
def create_function(name, num_args, func):
pass
@pytest.fixture(autouse=True)
def setup_db():
setup_test_db()
def test_sqlalchemy_sqlite_url():
with mock.patch.dict(os.environ, {"AUR_CONFIG": "conf/config.dev"}):
aurweb.config.rehash()
assert db.get_sqlalchemy_url()
aurweb.config.rehash()
def test_sqlalchemy_mysql_url():
with mock.patch.dict(os.environ, {"AUR_CONFIG": "conf/config.defaults"}):
aurweb.config.rehash()
assert db.get_sqlalchemy_url()
aurweb.config.rehash()
def make_temp_config(backend):
if not os.path.isdir("/tmp"):
os.mkdir("/tmp")
tmpdir = tempfile.mkdtemp()
tmp = os.path.join(tmpdir, "config.tmp")
with open("conf/config") as f:
config = re.sub(r'backend = sqlite', f'backend = {backend}', f.read())
with open(tmp, "w") as o:
o.write(config)
return (tmpdir, tmp)
def test_sqlalchemy_unknown_backend():
tmpdir, tmp = make_temp_config("blah")
with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
aurweb.config.rehash()
with pytest.raises(ValueError):
db.get_sqlalchemy_url()
aurweb.config.rehash()
os.remove(tmp)
os.removedirs(tmpdir)
def test_db_connects_without_fail():
db.connect()
assert db.engine is not None
def test_connection_class_without_fail():
conn = db.Connection()
cur = conn.execute(
"SELECT AccountType FROM AccountTypes WHERE ID = ?", (1,))
account_type = cur.fetchone()[0]
assert account_type == "User"
def test_connection_class_unsupported_backend():
tmpdir, tmp = make_temp_config("blah")
with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
aurweb.config.rehash()
with pytest.raises(ValueError):
db.Connection()
aurweb.config.rehash()
os.remove(tmp)
os.removedirs(tmpdir)
@mock.patch("mysql.connector.connect", mock.MagicMock(return_value=True))
@mock.patch.object(mysql.connector, "paramstyle", "qmark")
def test_connection_mysql():
tmpdir, tmp = make_temp_config("mysql")
with mock.patch.dict(os.environ, {
"AUR_CONFIG": tmp,
"AUR_CONFIG_DEFAULTS": "conf/config.defaults"
}):
aurweb.config.rehash()
db.Connection()
aurweb.config.rehash()
os.remove(tmp)
os.removedirs(tmpdir)
@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection()))
@mock.patch.object(sqlite3, "paramstyle", "qmark")
def test_connection_sqlite():
db.Connection()
@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection()))
@mock.patch.object(sqlite3, "paramstyle", "format")
def test_connection_execute_paramstyle_format():
conn = db.Connection()
# First, test ? to %s format replacement.
account_types = conn\
.execute("SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"])\
.fetchall()
assert account_types == \
["SELECT * FROM AccountTypes WHERE AccountType = %s", ["User"]]
# Test other format replacement.
account_types = conn\
.execute("SELECT * FROM AccountTypes WHERE AccountType = %", ["User"])\
.fetchall()
assert account_types == \
["SELECT * FROM AccountTypes WHERE AccountType = %%", ["User"]]
@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection()))
@mock.patch.object(sqlite3, "paramstyle", "qmark")
def test_connection_execute_paramstyle_qmark():
conn = db.Connection()
# We don't modify anything when using qmark, so test equality.
account_types = conn\
.execute("SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"])\
.fetchall()
assert account_types == \
["SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"]]
@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection()))
@mock.patch.object(sqlite3, "paramstyle", "unsupported")
def test_connection_execute_paramstyle_unsupported():
conn = db.Connection()
with pytest.raises(ValueError, match="unsupported paramstyle"):
conn.execute(
"SELECT * FROM AccountTypes WHERE AccountType = ?",
["User"]
).fetchall()