From 3558c6ae5ce64c4f84e6d6713a12e0cb17006d12 Mon Sep 17 00:00:00 2001 From: moson Date: Thu, 30 Nov 2023 14:44:00 +0100 Subject: [PATCH] fix: sqlalchemy sessions per request Best practice for web-apps is to have a session per web request. Instead of having a per worker-thread, we add a middleware that generates a unique ID per request, utilizing scoped_sessions scopefunc (custom function for defining a session scope) in combination with a ContextVar. With this we create a new session per request. Signed-off-by: moson --- aurweb/asgi.py | 19 +++++++++- aurweb/db.py | 48 ++++++++++++++++++++------ test/test_accounts_routes.py | 8 +++++ test/test_package_maintainer_routes.py | 1 + test/test_packages_routes.py | 6 +++- test/test_pkgbase_routes.py | 26 ++++++++++---- test/test_requests.py | 4 +++ test/test_routes.py | 1 + 8 files changed, 94 insertions(+), 19 deletions(-) diff --git a/aurweb/asgi.py b/aurweb/asgi.py index 9b6ffcb3..55a29e2d 100644 --- a/aurweb/asgi.py +++ b/aurweb/asgi.py @@ -25,7 +25,7 @@ import aurweb.pkgbase.util as pkgbaseutil from aurweb import aur_logging, prometheus, util from aurweb.aur_redis import redis_connection from aurweb.auth import BasicAuthBackend -from aurweb.db import get_engine, query +from aurweb.db import get_engine, query, set_db_session_context from aurweb.models import AcceptedTerm, Term from aurweb.packages.util import get_pkg_or_base from aurweb.prometheus import instrumentator @@ -308,3 +308,20 @@ async def id_redirect_middleware(request: Request, call_next: typing.Callable): # Add application middlewares. app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend()) app.add_middleware(SessionMiddleware, secret_key=session_secret) + + +# Set context var for database session & remove it after our request +@app.middleware("http") +async def db_session_context(request: Request, call_next: typing.Callable): + # static content won't require a db session + if request.url.path.startswith("/static"): + return await util.error_or_result(call_next, request) + + try: + set_db_session_context(hash(request)) + response = await util.error_or_result(call_next, request) + + finally: + set_db_session_context(None) + + return response diff --git a/aurweb/db.py b/aurweb/db.py index 6b704f9f..7d70c13c 100644 --- a/aurweb/db.py +++ b/aurweb/db.py @@ -1,3 +1,7 @@ +from contextvars import ContextVar +from threading import get_ident +from typing import Optional + from sqlalchemy.orm import Session # Supported database drivers. @@ -15,6 +19,23 @@ class Committer: self.session.commit() +db_session_context: ContextVar[Optional[int]] = ContextVar( + "session_id", default=get_ident() +) + + +def get_db_session_context(): + id = db_session_context.get() + return id + + +def set_db_session_context(session_id: int): + if session_id is None: + get_session().remove() + + db_session_context.set(session_id) + + def make_random_value(table: str, column: str, length: int): """Generate a unique, random value for a string column in a table. @@ -74,36 +95,39 @@ def name() -> str: return "db" + sha1 -# Module-private global memo used to store SQLAlchemy sessions. -_sessions = dict() +# Module-private global memo used to store SQLAlchemy sessions registries. +_session_registries = dict() def get_session(engine=None) -> Session: """Return aurweb.db's global session.""" dbname = name() - global _sessions - if dbname not in _sessions: + global _session_registries + if dbname not in _session_registries: from sqlalchemy.orm import scoped_session, sessionmaker if not engine: # pragma: no cover engine = get_engine() - Session = scoped_session(sessionmaker(autoflush=False, bind=engine)) - _sessions[dbname] = Session() + Session = scoped_session( + sessionmaker(autoflush=False, bind=engine), + scopefunc=get_db_session_context, + ) + _session_registries[dbname] = Session - return _sessions.get(dbname) + return _session_registries.get(dbname) def pop_session(dbname: str) -> None: """ - Pop a Session out of the private _sessions memo. + Pop a Session registry out of the private _session_registries memo. :param dbname: Database name :raises KeyError: When `dbname` does not exist in the memo """ - global _sessions - _sessions.pop(dbname) + global _session_registries + _session_registries.pop(dbname) def refresh(model): @@ -302,12 +326,14 @@ def get_engine(dbname: str = None, echo: bool = False): if dbname not in _engines: db_backend = aurweb.config.get("database", "backend") connect_args = dict() + kwargs = {"echo": echo, "connect_args": connect_args} is_sqlite = bool(db_backend == "sqlite") if is_sqlite: # pragma: no cover connect_args["check_same_thread"] = False + else: + kwargs["isolation_level"] = "READ_COMMITTED" - kwargs = {"echo": echo, "connect_args": connect_args} from sqlalchemy import create_engine _engines[dbname] = create_engine(get_sqlalchemy_url(), **kwargs) diff --git a/test/test_accounts_routes.py b/test/test_accounts_routes.py index a9cb6f7d..7b5d4cfd 100644 --- a/test/test_accounts_routes.py +++ b/test/test_accounts_routes.py @@ -830,6 +830,7 @@ def test_post_account_edit_type_as_dev(client: TestClient, pm_user: User): request.cookies = cookies resp = request.post(endpoint, data=data) assert resp.status_code == int(HTTPStatus.OK) + db.refresh(user2) assert user2.AccountTypeID == at.DEVELOPER_ID @@ -850,6 +851,7 @@ def test_post_account_edit_invalid_type_as_pm(client: TestClient, pm_user: User) request.cookies = cookies resp = request.post(endpoint, data=data) assert resp.status_code == int(HTTPStatus.BAD_REQUEST) + db.refresh(user2) assert user2.AccountTypeID == at.USER_ID errors = get_errors(resp.text) @@ -1020,6 +1022,7 @@ def test_post_account_edit_inactivity(client: TestClient, user: User): assert resp.status_code == int(HTTPStatus.OK) # Make sure the user record got updated correctly. + db.refresh(user) assert user.InactivityTS > 0 post_data.update({"J": False}) @@ -1028,6 +1031,7 @@ def test_post_account_edit_inactivity(client: TestClient, user: User): resp = request.post(f"/account/{user.Username}/edit", data=post_data) assert resp.status_code == int(HTTPStatus.OK) + db.refresh(user) assert user.InactivityTS == 0 @@ -1050,6 +1054,7 @@ def test_post_account_edit_suspended(client: TestClient, user: User): assert resp.status_code == int(HTTPStatus.OK) # Make sure the user record got updated correctly. + db.refresh(user) assert user.Suspended # Let's make sure the DB got updated properly. assert user.session is None @@ -1207,6 +1212,7 @@ def test_post_account_edit_password(client: TestClient, user: User): assert response.status_code == int(HTTPStatus.OK) + db.refresh(user) assert user.valid_password("newPassword") @@ -1273,6 +1279,7 @@ def test_post_account_edit_self_type_as_pm(client: TestClient, pm_user: User): resp = request.post(endpoint, data=data) assert resp.status_code == int(HTTPStatus.OK) + db.refresh(pm_user) assert pm_user.AccountTypeID == USER_ID @@ -1308,6 +1315,7 @@ def test_post_account_edit_other_user_type_as_pm( assert resp.status_code == int(HTTPStatus.OK) # Let's make sure the DB got updated properly. + db.refresh(user2) assert user2.AccountTypeID == PACKAGE_MAINTAINER_ID # and also that this got logged out at DEBUG level. diff --git a/test/test_package_maintainer_routes.py b/test/test_package_maintainer_routes.py index 6dd1ad88..6761650a 100644 --- a/test/test_package_maintainer_routes.py +++ b/test/test_package_maintainer_routes.py @@ -768,6 +768,7 @@ def test_pm_proposal_vote(client, proposal): assert response.status_code == int(HTTPStatus.OK) # Check that the proposal record got updated. + db.refresh(voteinfo) assert voteinfo.Yes == yes + 1 # Check that the new PMVote exists. diff --git a/test/test_packages_routes.py b/test/test_packages_routes.py index 58b2b1e6..1ed05e4a 100644 --- a/test/test_packages_routes.py +++ b/test/test_packages_routes.py @@ -1531,6 +1531,7 @@ def test_packages_post_disown_as_maintainer( errors = get_errors(resp.text) expected = "You did not select any packages to disown." assert errors[0].text.strip() == expected + db.refresh(package) assert package.PackageBase.Maintainer is not None # Try to disown `package` without giving the confirm argument. @@ -1555,6 +1556,7 @@ def test_packages_post_disown_as_maintainer( data={"action": "disown", "IDs": [package.ID], "confirm": True}, ) assert resp.status_code == int(HTTPStatus.BAD_REQUEST) + db.refresh(package) assert package.PackageBase.Maintainer is not None errors = get_errors(resp.text) expected = "You are not allowed to disown one of the packages you selected." @@ -1568,6 +1570,7 @@ def test_packages_post_disown_as_maintainer( data={"action": "disown", "IDs": [package.ID], "confirm": True}, ) + db.get_session().expire_all() assert package.PackageBase.Maintainer is None successes = get_successes(resp.text) expected = "The selected packages have been disowned." @@ -1652,6 +1655,7 @@ def test_packages_post_delete( # Whoo. Now, let's finally make a valid request as `pm_user` # to delete `package`. + pkgname = package.PackageBase.Name with client as request: request.cookies = pm_cookies resp = request.post( @@ -1664,7 +1668,7 @@ def test_packages_post_delete( assert successes[0].text.strip() == expected # Expect that the package deletion was logged. - pkgbases = [package.PackageBase.Name] + pkgbases = [pkgname] expected = ( f"Privileged user '{pm_user.Username}' deleted the " f"following package bases: {str(pkgbases)}." diff --git a/test/test_pkgbase_routes.py b/test/test_pkgbase_routes.py index b17a371e..8ae91735 100644 --- a/test/test_pkgbase_routes.py +++ b/test/test_pkgbase_routes.py @@ -688,6 +688,7 @@ def test_pkgbase_comment_pin_as_co( assert resp.status_code == int(HTTPStatus.SEE_OTHER) # Assert that PinnedTS got set. + db.refresh(comment) assert comment.PinnedTS > 0 # Unpin the comment we just pinned. @@ -698,6 +699,7 @@ def test_pkgbase_comment_pin_as_co( assert resp.status_code == int(HTTPStatus.SEE_OTHER) # Let's assert that PinnedTS was unset. + db.refresh(comment) assert comment.PinnedTS == 0 @@ -716,6 +718,7 @@ def test_pkgbase_comment_pin( assert resp.status_code == int(HTTPStatus.SEE_OTHER) # Assert that PinnedTS got set. + db.refresh(comment) assert comment.PinnedTS > 0 # Unpin the comment we just pinned. @@ -726,6 +729,7 @@ def test_pkgbase_comment_pin( assert resp.status_code == int(HTTPStatus.SEE_OTHER) # Let's assert that PinnedTS was unset. + db.refresh(comment) assert comment.PinnedTS == 0 @@ -1040,6 +1044,7 @@ def test_pkgbase_flag( request.cookies = cookies resp = request.post(endpoint, data={"comments": "Test"}) assert resp.status_code == int(HTTPStatus.SEE_OTHER) + db.refresh(pkgbase) assert pkgbase.Flagger == user assert pkgbase.FlaggerComment == "Test" @@ -1077,6 +1082,7 @@ def test_pkgbase_flag( request.cookies = user2_cookies resp = request.post(endpoint) assert resp.status_code == int(HTTPStatus.SEE_OTHER) + db.refresh(pkgbase) assert pkgbase.Flagger == user # Now, test that the 'maintainer' user can. @@ -1085,6 +1091,7 @@ def test_pkgbase_flag( request.cookies = maint_cookies resp = request.post(endpoint) assert resp.status_code == int(HTTPStatus.SEE_OTHER) + db.refresh(pkgbase) assert pkgbase.Flagger is None # Flag it again. @@ -1098,6 +1105,7 @@ def test_pkgbase_flag( request.cookies = cookies resp = request.post(endpoint) assert resp.status_code == int(HTTPStatus.SEE_OTHER) + db.refresh(pkgbase) assert pkgbase.Flagger is None @@ -1170,6 +1178,7 @@ def test_pkgbase_vote(client: TestClient, user: User, package: Package): vote = pkgbase.package_votes.filter(PackageVote.UsersID == user.ID).first() assert vote is not None + db.refresh(pkgbase) assert pkgbase.NumVotes == 1 # Remove vote. @@ -1181,6 +1190,7 @@ def test_pkgbase_vote(client: TestClient, user: User, package: Package): vote = pkgbase.package_votes.filter(PackageVote.UsersID == user.ID).first() assert vote is None + db.refresh(pkgbase) assert pkgbase.NumVotes == 0 @@ -1592,9 +1602,9 @@ def test_pkgbase_merge_post( assert resp.status_code == int(HTTPStatus.SEE_OTHER) # Save these relationships for later comparison. - comments = package.PackageBase.comments.all() - notifs = package.PackageBase.notifications.all() - votes = package.PackageBase.package_votes.all() + comments = [row.__dict__ for row in package.PackageBase.comments.all()] + notifs = [row.__dict__ for row in package.PackageBase.notifications.all()] + votes = [row.__dict__ for row in package.PackageBase.package_votes.all()] # Merge the package into target. endpoint = f"/pkgbase/{package.PackageBase.Name}/merge" @@ -1612,9 +1622,13 @@ def test_pkgbase_merge_post( # Assert that the original comments, notifs and votes we setup # got migrated to target as intended. - assert comments == target.comments.all() - assert notifs == target.notifications.all() - assert votes == target.package_votes.all() + db.get_session().refresh(target) + assert len(comments) == target.comments.count() + assert comments[0]["PackageBaseID"] != target.ID + assert len(notifs) == target.notifications.count() + assert notifs[0]["PackageBaseID"] != target.ID + assert len(votes) == target.package_votes.count() + assert votes[0]["PackageBaseID"] != target.ID # ...and that the package got deleted. package = db.query(Package).filter(Package.Name == pkgname).first() diff --git a/test/test_requests.py b/test/test_requests.py index c118ce0b..8e4a12c7 100644 --- a/test/test_requests.py +++ b/test/test_requests.py @@ -649,6 +649,7 @@ def test_orphan_request( assert resp.headers.get("location") == f"/pkgbase/{pkgbase.Name}" # We should have unset the maintainer. + db.refresh(pkgbase) assert pkgbase.Maintainer is None # We should have removed the comaintainers. @@ -748,6 +749,7 @@ def test_orphan_as_maintainer(client: TestClient, auser: User, pkgbase: PackageB # As the pkgbase maintainer, disowning the package just ends up # either promoting the lowest priority comaintainer or removing # the associated maintainer relationship altogether. + db.refresh(pkgbase) assert pkgbase.Maintainer is None @@ -1044,6 +1046,7 @@ def test_requests_close_post(client: TestClient, user: User, pkgreq: PackageRequ resp = request.post(f"/requests/{pkgreq.ID}/close") assert resp.status_code == int(HTTPStatus.SEE_OTHER) + db.refresh(pkgreq) assert pkgreq.Status == REJECTED_ID assert pkgreq.Closer == user assert pkgreq.ClosureComment == str() @@ -1060,6 +1063,7 @@ def test_requests_close_post_rejected( ) assert resp.status_code == int(HTTPStatus.SEE_OTHER) + db.refresh(pkgreq) assert pkgreq.Status == REJECTED_ID assert pkgreq.Closer == user assert pkgreq.ClosureComment == str() diff --git a/test/test_routes.py b/test/test_routes.py index c104211e..aa64ed75 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -102,6 +102,7 @@ def test_user_language(client: TestClient, user: User): req.cookies = {"AURSID": sid} response = req.post("/language", data=post_data) assert response.status_code == int(HTTPStatus.SEE_OTHER) + db.refresh(user) assert user.LangPreference == "de"