mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
add aurweb.db.session
+ Added Session class and global session object to aurweb.db, these are sessions created by sqlalchemy ORM's sessionmaker and will allow us to use declarative/imperative models. Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
parent
7c65604dad
commit
4238a9fc68
6 changed files with 260 additions and 32 deletions
|
@ -7,30 +7,36 @@ from starlette.middleware.sessions import SessionMiddleware
|
||||||
|
|
||||||
import aurweb.config
|
import aurweb.config
|
||||||
|
|
||||||
|
from aurweb.db import get_engine
|
||||||
from aurweb.routers import html, sso
|
from aurweb.routers import html, sso
|
||||||
|
|
||||||
routes = set()
|
routes = set()
|
||||||
|
|
||||||
# Setup the FastAPI app.
|
# Setup the FastAPI app.
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.mount("/static/css",
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def app_startup():
|
||||||
|
session_secret = aurweb.config.get("fastapi", "session_secret")
|
||||||
|
if not session_secret:
|
||||||
|
raise Exception("[fastapi] session_secret must not be empty")
|
||||||
|
|
||||||
|
app.mount("/static/css",
|
||||||
StaticFiles(directory="web/html/css"),
|
StaticFiles(directory="web/html/css"),
|
||||||
name="static_css")
|
name="static_css")
|
||||||
app.mount("/static/js",
|
app.mount("/static/js",
|
||||||
StaticFiles(directory="web/html/js"),
|
StaticFiles(directory="web/html/js"),
|
||||||
name="static_js")
|
name="static_js")
|
||||||
app.mount("/static/images",
|
app.mount("/static/images",
|
||||||
StaticFiles(directory="web/html/images"),
|
StaticFiles(directory="web/html/images"),
|
||||||
name="static_images")
|
name="static_images")
|
||||||
|
|
||||||
session_secret = aurweb.config.get("fastapi", "session_secret")
|
app.add_middleware(SessionMiddleware, secret_key=session_secret)
|
||||||
if not session_secret:
|
app.include_router(sso.router)
|
||||||
raise Exception("[fastapi] session_secret must not be empty")
|
app.include_router(html.router)
|
||||||
|
|
||||||
app.add_middleware(SessionMiddleware, secret_key=session_secret)
|
get_engine()
|
||||||
|
|
||||||
app.include_router(sso.router)
|
|
||||||
app.include_router(html.router)
|
|
||||||
|
|
||||||
# NOTE: Always keep this dictionary updated with all routes
|
# NOTE: Always keep this dictionary updated with all routes
|
||||||
# that the application contains. We use this to check for
|
# that the application contains. We use this to check for
|
||||||
|
|
|
@ -25,6 +25,13 @@ def _get_parser():
|
||||||
return _parser
|
return _parser
|
||||||
|
|
||||||
|
|
||||||
|
def rehash():
|
||||||
|
""" Globally rehash the configuration parser. """
|
||||||
|
global _parser
|
||||||
|
_parser = None
|
||||||
|
_get_parser()
|
||||||
|
|
||||||
|
|
||||||
def get(section, option):
|
def get(section, option):
|
||||||
return _get_parser().get(section, option)
|
return _get_parser().get(section, option)
|
||||||
|
|
||||||
|
|
29
aurweb/db.py
29
aurweb/db.py
|
@ -1,19 +1,15 @@
|
||||||
import math
|
import math
|
||||||
|
|
||||||
try:
|
|
||||||
import mysql.connector
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
import sqlite3
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
import aurweb.config
|
import aurweb.config
|
||||||
|
|
||||||
engine = None # See get_engine
|
engine = None # See get_engine
|
||||||
|
|
||||||
|
# ORM Session class.
|
||||||
|
Session = None
|
||||||
|
|
||||||
|
# Global ORM Session object.
|
||||||
|
session = None
|
||||||
|
|
||||||
|
|
||||||
def get_sqlalchemy_url():
|
def get_sqlalchemy_url():
|
||||||
"""
|
"""
|
||||||
|
@ -49,14 +45,15 @@ def get_engine():
|
||||||
`engine` global variable for the next calls.
|
`engine` global variable for the next calls.
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
global engine
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
global engine, session, Session
|
||||||
|
|
||||||
if engine is None:
|
if engine is None:
|
||||||
connect_args = dict()
|
engine = create_engine(get_sqlalchemy_url(),
|
||||||
if aurweb.config.get("database", "backend") == "sqlite":
|
|
||||||
# check_same_thread is for a SQLite technicality
|
# check_same_thread is for a SQLite technicality
|
||||||
# https://fastapi.tiangolo.com/tutorial/sql-databases/#note
|
# https://fastapi.tiangolo.com/tutorial/sql-databases/#note
|
||||||
connect_args["check_same_thread"] = False
|
connect_args={"check_same_thread": False})
|
||||||
engine = create_engine(get_sqlalchemy_url(), connect_args=connect_args)
|
|
||||||
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
session = Session()
|
session = Session()
|
||||||
|
|
||||||
|
@ -82,6 +79,7 @@ class Connection:
|
||||||
aur_db_backend = aurweb.config.get('database', 'backend')
|
aur_db_backend = aurweb.config.get('database', 'backend')
|
||||||
|
|
||||||
if aur_db_backend == 'mysql':
|
if aur_db_backend == 'mysql':
|
||||||
|
import mysql.connector
|
||||||
aur_db_host = aurweb.config.get('database', 'host')
|
aur_db_host = aurweb.config.get('database', 'host')
|
||||||
aur_db_name = aurweb.config.get('database', 'name')
|
aur_db_name = aurweb.config.get('database', 'name')
|
||||||
aur_db_user = aurweb.config.get('database', 'user')
|
aur_db_user = aurweb.config.get('database', 'user')
|
||||||
|
@ -95,6 +93,7 @@ class Connection:
|
||||||
buffered=True)
|
buffered=True)
|
||||||
self._paramstyle = mysql.connector.paramstyle
|
self._paramstyle = mysql.connector.paramstyle
|
||||||
elif aur_db_backend == 'sqlite':
|
elif aur_db_backend == 'sqlite':
|
||||||
|
import sqlite3
|
||||||
aur_db_name = aurweb.config.get('database', 'name')
|
aur_db_name = aurweb.config.get('database', 'name')
|
||||||
self._conn = sqlite3.connect(aur_db_name)
|
self._conn = sqlite3.connect(aur_db_name)
|
||||||
self._conn.create_function("POWER", 2, math.pow)
|
self._conn.create_function("POWER", 2, math.pow)
|
||||||
|
|
29
test/test_asgi.py
Normal file
29
test/test_asgi.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
import http
|
||||||
|
import os
|
||||||
|
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
import aurweb.asgi
|
||||||
|
import aurweb.config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asgi_startup_exception(monkeypatch):
|
||||||
|
with mock.patch.dict(os.environ, {"AUR_CONFIG": "conf/config.defaults"}):
|
||||||
|
aurweb.config.rehash()
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
await aurweb.asgi.app_startup()
|
||||||
|
aurweb.config.rehash()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asgi_http_exception_handler():
|
||||||
|
exc = HTTPException(status_code=422, detail="EXCEPTION!")
|
||||||
|
phrase = http.HTTPStatus(exc.status_code).phrase
|
||||||
|
response = await aurweb.asgi.http_exception_handler(None, exc)
|
||||||
|
assert response.body.decode() == \
|
||||||
|
f"<h1>{exc.status_code} {phrase}</h1><p>{exc.detail}</p>"
|
13
test/test_config.py
Normal file
13
test/test_config.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
from aurweb import config
|
||||||
|
|
||||||
|
|
||||||
|
def test_get():
|
||||||
|
assert config.get("options", "disable_http_login") == "0"
|
||||||
|
|
||||||
|
|
||||||
|
def test_getboolean():
|
||||||
|
assert not config.getboolean("options", "disable_http_login")
|
||||||
|
|
||||||
|
|
||||||
|
def test_getint():
|
||||||
|
assert config.getint("options", "disable_http_login") == 0
|
174
test/test_db.py
Normal file
174
test/test_db.py
Normal file
|
@ -0,0 +1,174 @@
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sqlite3
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import mysql.connector
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import aurweb.config
|
||||||
|
|
||||||
|
from aurweb import db
|
||||||
|
from aurweb.testing import setup_test_db
|
||||||
|
|
||||||
|
|
||||||
|
class DBCursor:
|
||||||
|
""" A fake database cursor object used in tests. """
|
||||||
|
items = []
|
||||||
|
|
||||||
|
def execute(self, *args, **kwargs):
|
||||||
|
self.items = list(args)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def fetchall(self):
|
||||||
|
return self.items
|
||||||
|
|
||||||
|
|
||||||
|
class DBConnection:
|
||||||
|
""" A fake database connection object used in tests. """
|
||||||
|
@staticmethod
|
||||||
|
def cursor():
|
||||||
|
return DBCursor()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_function(name, num_args, func):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_db():
|
||||||
|
setup_test_db()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqlalchemy_sqlite_url():
|
||||||
|
with mock.patch.dict(os.environ, {"AUR_CONFIG": "conf/config.dev"}):
|
||||||
|
aurweb.config.rehash()
|
||||||
|
assert db.get_sqlalchemy_url()
|
||||||
|
aurweb.config.rehash()
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqlalchemy_mysql_url():
|
||||||
|
with mock.patch.dict(os.environ, {"AUR_CONFIG": "conf/config.defaults"}):
|
||||||
|
aurweb.config.rehash()
|
||||||
|
assert db.get_sqlalchemy_url()
|
||||||
|
aurweb.config.rehash()
|
||||||
|
|
||||||
|
|
||||||
|
def make_temp_config(backend):
|
||||||
|
if not os.path.isdir("/tmp"):
|
||||||
|
os.mkdir("/tmp")
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
tmp = os.path.join(tmpdir, "config.tmp")
|
||||||
|
with open("conf/config") as f:
|
||||||
|
config = re.sub(r'backend = sqlite', f'backend = {backend}', f.read())
|
||||||
|
with open(tmp, "w") as o:
|
||||||
|
o.write(config)
|
||||||
|
return (tmpdir, tmp)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqlalchemy_unknown_backend():
|
||||||
|
tmpdir, tmp = make_temp_config("blah")
|
||||||
|
|
||||||
|
with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
|
||||||
|
aurweb.config.rehash()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
db.get_sqlalchemy_url()
|
||||||
|
aurweb.config.rehash()
|
||||||
|
|
||||||
|
os.remove(tmp)
|
||||||
|
os.removedirs(tmpdir)
|
||||||
|
|
||||||
|
|
||||||
|
def test_db_connects_without_fail():
|
||||||
|
db.connect()
|
||||||
|
assert db.engine is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_connection_class_without_fail():
|
||||||
|
conn = db.Connection()
|
||||||
|
|
||||||
|
cur = conn.execute(
|
||||||
|
"SELECT AccountType FROM AccountTypes WHERE ID = ?", (1,))
|
||||||
|
account_type = cur.fetchone()[0]
|
||||||
|
|
||||||
|
assert account_type == "User"
|
||||||
|
|
||||||
|
|
||||||
|
def test_connection_class_unsupported_backend():
|
||||||
|
tmpdir, tmp = make_temp_config("blah")
|
||||||
|
|
||||||
|
with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
|
||||||
|
aurweb.config.rehash()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
db.Connection()
|
||||||
|
aurweb.config.rehash()
|
||||||
|
|
||||||
|
os.remove(tmp)
|
||||||
|
os.removedirs(tmpdir)
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("mysql.connector.connect", mock.MagicMock(return_value=True))
|
||||||
|
@mock.patch.object(mysql.connector, "paramstyle", "qmark")
|
||||||
|
def test_connection_mysql():
|
||||||
|
tmpdir, tmp = make_temp_config("mysql")
|
||||||
|
with mock.patch.dict(os.environ, {
|
||||||
|
"AUR_CONFIG": tmp,
|
||||||
|
"AUR_CONFIG_DEFAULTS": "conf/config.defaults"
|
||||||
|
}):
|
||||||
|
aurweb.config.rehash()
|
||||||
|
db.Connection()
|
||||||
|
aurweb.config.rehash()
|
||||||
|
|
||||||
|
os.remove(tmp)
|
||||||
|
os.removedirs(tmpdir)
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection()))
|
||||||
|
@mock.patch.object(sqlite3, "paramstyle", "qmark")
|
||||||
|
def test_connection_sqlite():
|
||||||
|
db.Connection()
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection()))
|
||||||
|
@mock.patch.object(sqlite3, "paramstyle", "format")
|
||||||
|
def test_connection_execute_paramstyle_format():
|
||||||
|
conn = db.Connection()
|
||||||
|
|
||||||
|
# First, test ? to %s format replacement.
|
||||||
|
account_types = conn\
|
||||||
|
.execute("SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"])\
|
||||||
|
.fetchall()
|
||||||
|
assert account_types == \
|
||||||
|
["SELECT * FROM AccountTypes WHERE AccountType = %s", ["User"]]
|
||||||
|
|
||||||
|
# Test other format replacement.
|
||||||
|
account_types = conn\
|
||||||
|
.execute("SELECT * FROM AccountTypes WHERE AccountType = %", ["User"])\
|
||||||
|
.fetchall()
|
||||||
|
assert account_types == \
|
||||||
|
["SELECT * FROM AccountTypes WHERE AccountType = %%", ["User"]]
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection()))
|
||||||
|
@mock.patch.object(sqlite3, "paramstyle", "qmark")
|
||||||
|
def test_connection_execute_paramstyle_qmark():
|
||||||
|
conn = db.Connection()
|
||||||
|
# We don't modify anything when using qmark, so test equality.
|
||||||
|
account_types = conn\
|
||||||
|
.execute("SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"])\
|
||||||
|
.fetchall()
|
||||||
|
assert account_types == \
|
||||||
|
["SELECT * FROM AccountTypes WHERE AccountType = ?", ["User"]]
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection()))
|
||||||
|
@mock.patch.object(sqlite3, "paramstyle", "unsupported")
|
||||||
|
def test_connection_execute_paramstyle_unsupported():
|
||||||
|
conn = db.Connection()
|
||||||
|
with pytest.raises(ValueError, match="unsupported paramstyle"):
|
||||||
|
conn.execute(
|
||||||
|
"SELECT * FROM AccountTypes WHERE AccountType = ?",
|
||||||
|
["User"]
|
||||||
|
).fetchall()
|
Loading…
Add table
Reference in a new issue