mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
add aurweb.db.session
+ Added Session class and global session object to aurweb.db, these are sessions created by sqlalchemy ORM's sessionmaker and will allow us to use declarative/imperative models. Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
parent
7c65604dad
commit
4238a9fc68
6 changed files with 260 additions and 32 deletions
|
@ -7,12 +7,21 @@ from starlette.middleware.sessions import SessionMiddleware
|
|||
|
||||
import aurweb.config
|
||||
|
||||
from aurweb.db import get_engine
|
||||
from aurweb.routers import html, sso
|
||||
|
||||
routes = set()
|
||||
|
||||
# Setup the FastAPI app.
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def app_startup():
|
||||
session_secret = aurweb.config.get("fastapi", "session_secret")
|
||||
if not session_secret:
|
||||
raise Exception("[fastapi] session_secret must not be empty")
|
||||
|
||||
app.mount("/static/css",
|
||||
StaticFiles(directory="web/html/css"),
|
||||
name="static_css")
|
||||
|
@ -23,15 +32,12 @@ app.mount("/static/images",
|
|||
StaticFiles(directory="web/html/images"),
|
||||
name="static_images")
|
||||
|
||||
session_secret = aurweb.config.get("fastapi", "session_secret")
|
||||
if not session_secret:
|
||||
raise Exception("[fastapi] session_secret must not be empty")
|
||||
|
||||
app.add_middleware(SessionMiddleware, secret_key=session_secret)
|
||||
|
||||
app.include_router(sso.router)
|
||||
app.include_router(html.router)
|
||||
|
||||
get_engine()
|
||||
|
||||
# NOTE: Always keep this dictionary updated with all routes
|
||||
# that the application contains. We use this to check for
|
||||
# parameter value verification.
|
||||
|
|
|
@ -25,6 +25,13 @@ def _get_parser():
|
|||
return _parser
|
||||
|
||||
|
||||
def rehash():
|
||||
""" Globally rehash the configuration parser. """
|
||||
global _parser
|
||||
_parser = None
|
||||
_get_parser()
|
||||
|
||||
|
||||
def get(section, option):
|
||||
return _get_parser().get(section, option)
|
||||
|
||||
|
|
29
aurweb/db.py
29
aurweb/db.py
|
@ -1,19 +1,15 @@
|
|||
import math
|
||||
|
||||
try:
|
||||
import mysql.connector
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import sqlite3
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import aurweb.config
|
||||
|
||||
engine = None # See get_engine
|
||||
|
||||
# ORM Session class.
|
||||
Session = None
|
||||
|
||||
# Global ORM Session object.
|
||||
session = None
|
||||
|
||||
|
||||
def get_sqlalchemy_url():
|
||||
"""
|
||||
|
@ -49,14 +45,15 @@ def get_engine():
|
|||
`engine` global variable for the next calls.
|
||||
"""
|
||||
from sqlalchemy import create_engine
|
||||
global engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
global engine, session, Session
|
||||
|
||||
if engine is None:
|
||||
connect_args = dict()
|
||||
if aurweb.config.get("database", "backend") == "sqlite":
|
||||
engine = create_engine(get_sqlalchemy_url(),
|
||||
# check_same_thread is for a SQLite technicality
|
||||
# https://fastapi.tiangolo.com/tutorial/sql-databases/#note
|
||||
connect_args["check_same_thread"] = False
|
||||
engine = create_engine(get_sqlalchemy_url(), connect_args=connect_args)
|
||||
connect_args={"check_same_thread": False})
|
||||
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
session = Session()
|
||||
|
||||
|
@ -82,6 +79,7 @@ class Connection:
|
|||
aur_db_backend = aurweb.config.get('database', 'backend')
|
||||
|
||||
if aur_db_backend == 'mysql':
|
||||
import mysql.connector
|
||||
aur_db_host = aurweb.config.get('database', 'host')
|
||||
aur_db_name = aurweb.config.get('database', 'name')
|
||||
aur_db_user = aurweb.config.get('database', 'user')
|
||||
|
@ -95,6 +93,7 @@ class Connection:
|
|||
buffered=True)
|
||||
self._paramstyle = mysql.connector.paramstyle
|
||||
elif aur_db_backend == 'sqlite':
|
||||
import sqlite3
|
||||
aur_db_name = aurweb.config.get('database', 'name')
|
||||
self._conn = sqlite3.connect(aur_db_name)
|
||||
self._conn.create_function("POWER", 2, math.pow)
|
||||
|
|
29
test/test_asgi.py
Normal file
29
test/test_asgi.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
import http
|
||||
import os
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
import aurweb.asgi
|
||||
import aurweb.config
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asgi_startup_exception(monkeypatch):
|
||||
with mock.patch.dict(os.environ, {"AUR_CONFIG": "conf/config.defaults"}):
|
||||
aurweb.config.rehash()
|
||||
with pytest.raises(Exception):
|
||||
await aurweb.asgi.app_startup()
|
||||
aurweb.config.rehash()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asgi_http_exception_handler():
|
||||
exc = HTTPException(status_code=422, detail="EXCEPTION!")
|
||||
phrase = http.HTTPStatus(exc.status_code).phrase
|
||||
response = await aurweb.asgi.http_exception_handler(None, exc)
|
||||
assert response.body.decode() == \
|
||||
f"<h1>{exc.status_code} {phrase}</h1><p>{exc.detail}</p>"
|
13
test/test_config.py
Normal file
13
test/test_config.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
from aurweb import config
|
||||
|
||||
|
||||
def test_get():
|
||||
assert config.get("options", "disable_http_login") == "0"
|
||||
|
||||
|
||||
def test_getboolean():
|
||||
assert not config.getboolean("options", "disable_http_login")
|
||||
|
||||
|
||||
def test_getint():
|
||||
assert config.getint("options", "disable_http_login") == 0
|
174
test/test_db.py
Normal file
174
test/test_db.py
Normal file
|
@ -0,0 +1,174 @@
|
|||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import tempfile
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import mysql.connector
|
||||
import pytest
|
||||
|
||||
import aurweb.config
|
||||
|
||||
from aurweb import db
|
||||
from aurweb.testing import setup_test_db
|
||||
|
||||
|
||||
class DBCursor:
|
||||
""" A fake database cursor object used in tests. """
|
||||
items = []
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
self.items = list(args)
|
||||
return self
|
||||
|
||||
def fetchall(self):
|
||||
return self.items
|
||||
|
||||
|
||||
class DBConnection:
|
||||
""" A fake database connection object used in tests. """
|
||||
@staticmethod
|
||||
def cursor():
|
||||
return DBCursor()
|
||||
|
||||
@staticmethod
|
||||
def create_function(name, num_args, func):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_db():
|
||||
setup_test_db()
|
||||
|
||||
|
||||
def test_sqlalchemy_sqlite_url():
|
||||
with mock.patch.dict(os.environ, {"AUR_CONFIG": "conf/config.dev"}):
|
||||
aurweb.config.rehash()
|
||||
assert db.get_sqlalchemy_url()
|
||||
aurweb.config.rehash()
|
||||
|
||||
|
||||
def test_sqlalchemy_mysql_url():
|
||||
with mock.patch.dict(os.environ, {"AUR_CONFIG": "conf/config.defaults"}):
|
||||
aurweb.config.rehash()
|
||||
assert db.get_sqlalchemy_url()
|
||||
aurweb.config.rehash()
|
||||
|
||||
|
||||
def make_temp_config(backend):
|
||||
if not os.path.isdir("/tmp"):
|
||||
os.mkdir("/tmp")
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
tmp = os.path.join(tmpdir, "config.tmp")
|
||||
with open("conf/config") as f:
|
||||
config = re.sub(r'backend = sqlite', f'backend = {backend}', f.read())
|
||||
with open(tmp, "w") as o:
|
||||
o.write(config)
|
||||
return (tmpdir, tmp)
|
||||
|
||||
|
||||
def test_sqlalchemy_unknown_backend():
|
||||
tmpdir, tmp = make_temp_config("blah")
|
||||
|
||||
with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
|
||||
aurweb.config.rehash()
|
||||
with pytest.raises(ValueError):
|
||||
db.get_sqlalchemy_url()
|
||||
aurweb.config.rehash()
|
||||
|
||||
os.remove(tmp)
|
||||
os.removedirs(tmpdir)
|
||||
|
||||
|
||||
def test_db_connects_without_fail():
|
||||
db.connect()
|
||||
assert db.engine is not None
|
||||
|
||||
|
||||
def test_connection_class_without_fail():
|
||||
conn = db.Connection()
|
||||
|
||||
cur = conn.execute(
|
||||
"SELECT AccountType FROM AccountTypes WHERE ID = ?", (1,))
|
||||
account_type = cur.fetchone()[0]
|
||||
|
||||
assert account_type == "User"
|
||||
|
||||
|
||||
def test_connection_class_unsupported_backend():
|
||||
tmpdir, tmp = make_temp_config("blah")
|
||||
|
||||
with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
|
||||
aurweb.config.rehash()
|
||||
with pytest.raises(ValueError):
|
||||
db.Connection()
|
||||
aurweb.config.rehash()
|
||||
|
||||
os.remove(tmp)
|
||||
os.removedirs(tmpdir)
|
||||
|
||||
|
||||
@mock.patch("mysql.connector.connect", mock.MagicMock(return_value=True))
|
||||
@mock.patch.object(mysql.connector, "paramstyle", "qmark")
|
||||
def test_connection_mysql():
|
||||
tmpdir, tmp = make_temp_config("mysql")
|
||||
with mock.patch.dict(os.environ, {
|
||||
"AUR_CONFIG": tmp,
|
||||
"AUR_CONFIG_DEFAULTS": "conf/config.defaults"
|
||||
}):
|
||||
aurweb.config.rehash()
|
||||
db.Connection()
|
||||
aurweb.config.rehash()
|
||||
|
||||
os.remove(tmp)
|
||||
os.removedirs(tmpdir)
|
||||
|
||||
|
||||
@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection()))
|
||||
@mock.patch.object(sqlite3, "paramstyle", "qmark")
|
||||
def test_connection_sqlite():
|
||||
db.Connection()
|
||||
|
||||
|
||||
@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection()))
|
||||
@mock.patch.object(sqlite3, "paramstyle", "format")
|
||||
def test_connection_execute_paramstyle_format():
|
||||
conn = db.Connection()
|
||||
|
||||
# First, test ? to %s format replacement.
|
||||
account_types = conn\
|
||||
.execute("SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"])\
|
||||
.fetchall()
|
||||
assert account_types == \
|
||||
["SELECT * FROM AccountTypes WHERE AccountType = %s", ["User"]]
|
||||
|
||||
# Test other format replacement.
|
||||
account_types = conn\
|
||||
.execute("SELECT * FROM AccountTypes WHERE AccountType = %", ["User"])\
|
||||
.fetchall()
|
||||
assert account_types == \
|
||||
["SELECT * FROM AccountTypes WHERE AccountType = %%", ["User"]]
|
||||
|
||||
|
||||
@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection()))
|
||||
@mock.patch.object(sqlite3, "paramstyle", "qmark")
|
||||
def test_connection_execute_paramstyle_qmark():
|
||||
conn = db.Connection()
|
||||
# We don't modify anything when using qmark, so test equality.
|
||||
account_types = conn\
|
||||
.execute("SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"])\
|
||||
.fetchall()
|
||||
assert account_types == \
|
||||
["SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"]]
|
||||
|
||||
|
||||
@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection()))
|
||||
@mock.patch.object(sqlite3, "paramstyle", "unsupported")
|
||||
def test_connection_execute_paramstyle_unsupported():
|
||||
conn = db.Connection()
|
||||
with pytest.raises(ValueError, match="unsupported paramstyle"):
|
||||
conn.execute(
|
||||
"SELECT * FROM AccountTypes WHERE AccountType = ?",
|
||||
["User"]
|
||||
).fetchall()
|
Loading…
Add table
Reference in a new issue