fix: sqlalchemy sessions per request

Best practice for web-apps is to have a session per web request.

Instead of having a per worker-thread, we add a middleware that
generates a unique ID per request, utilizing scoped_sessions
scopefunc (custom function for defining a session scope)
in combination with a ContextVar.
With this we create a new session per request.

Signed-off-by: moson-mo <mo-son@mailbox.org>
This commit is contained in:
moson-mo 2023-07-01 12:16:42 +02:00 committed by moson
parent ec090d7b30
commit f92fc2b035
No known key found for this signature in database
GPG key ID: 4A4760AB4EE15296
8 changed files with 94 additions and 19 deletions

View file

@ -25,7 +25,7 @@ import aurweb.pkgbase.util as pkgbaseutil
from aurweb import aur_logging, prometheus, util from aurweb import aur_logging, prometheus, util
from aurweb.aur_redis import redis_connection from aurweb.aur_redis import redis_connection
from aurweb.auth import BasicAuthBackend from aurweb.auth import BasicAuthBackend
from aurweb.db import get_engine, query from aurweb.db import get_engine, query, set_db_session_context
from aurweb.models import AcceptedTerm, Term from aurweb.models import AcceptedTerm, Term
from aurweb.packages.util import get_pkg_or_base from aurweb.packages.util import get_pkg_or_base
from aurweb.prometheus import instrumentator from aurweb.prometheus import instrumentator
@ -308,3 +308,20 @@ async def id_redirect_middleware(request: Request, call_next: typing.Callable):
# Add application middlewares. # Add application middlewares.
app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend()) app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend())
app.add_middleware(SessionMiddleware, secret_key=session_secret) app.add_middleware(SessionMiddleware, secret_key=session_secret)
# Set context var for database session & remove it after our request
@app.middleware("http")
async def db_session_context(request: Request, call_next: typing.Callable):
# static content won't require a db session
if request.url.path.startswith("/static"):
return await util.error_or_result(call_next, request)
try:
set_db_session_context(hash(request))
response = await util.error_or_result(call_next, request)
finally:
set_db_session_context(None)
return response

View file

@ -1,3 +1,7 @@
from contextvars import ContextVar
from threading import get_ident
from typing import Optional
# Supported database drivers. # Supported database drivers.
DRIVERS = {"mysql": "mysql+mysqldb"} DRIVERS = {"mysql": "mysql+mysqldb"}
@ -13,6 +17,23 @@ class Committer:
self.session.commit() self.session.commit()
db_session_context: ContextVar[Optional[int]] = ContextVar(
"session_id", default=get_ident()
)
def get_db_session_context():
id = db_session_context.get()
return id
def set_db_session_context(session_id: int):
if session_id is None:
get_session().remove()
db_session_context.set(session_id)
def make_random_value(table: str, column: str, length: int): def make_random_value(table: str, column: str, length: int):
"""Generate a unique, random value for a string column in a table. """Generate a unique, random value for a string column in a table.
@ -72,36 +93,39 @@ def name() -> str:
return "db" + sha1 return "db" + sha1
# Module-private global memo used to store SQLAlchemy sessions. # Module-private global memo used to store SQLAlchemy sessions registries.
_sessions = dict() _session_registries = dict()
def get_session(engine=None): def get_session(engine=None):
"""Return aurweb.db's global session.""" """Return aurweb.db's global session."""
dbname = name() dbname = name()
global _sessions global _session_registries
if dbname not in _sessions: if dbname not in _session_registries:
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
if not engine: # pragma: no cover if not engine: # pragma: no cover
engine = get_engine() engine = get_engine()
Session = scoped_session(sessionmaker(autoflush=False, bind=engine)) Session = scoped_session(
_sessions[dbname] = Session() sessionmaker(autoflush=False, bind=engine),
scopefunc=get_db_session_context,
)
_session_registries[dbname] = Session
return _sessions.get(dbname) return _session_registries.get(dbname)
def pop_session(dbname: str) -> None: def pop_session(dbname: str) -> None:
""" """
Pop a Session out of the private _sessions memo. Pop a Session registry out of the private _session_registries memo.
:param dbname: Database name :param dbname: Database name
:raises KeyError: When `dbname` does not exist in the memo :raises KeyError: When `dbname` does not exist in the memo
""" """
global _sessions global _session_registries
_sessions.pop(dbname) _session_registries.pop(dbname)
def refresh(model): def refresh(model):
@ -301,12 +325,14 @@ def get_engine(dbname: str = None, echo: bool = False):
if dbname not in _engines: if dbname not in _engines:
db_backend = aurweb.config.get("database", "backend") db_backend = aurweb.config.get("database", "backend")
connect_args = dict() connect_args = dict()
kwargs = {"echo": echo, "connect_args": connect_args}
is_sqlite = bool(db_backend == "sqlite") is_sqlite = bool(db_backend == "sqlite")
if is_sqlite: # pragma: no cover if is_sqlite: # pragma: no cover
connect_args["check_same_thread"] = False connect_args["check_same_thread"] = False
else:
kwargs["isolation_level"] = "READ_COMMITTED"
kwargs = {"echo": echo, "connect_args": connect_args}
from sqlalchemy import create_engine from sqlalchemy import create_engine
_engines[dbname] = create_engine(get_sqlalchemy_url(), **kwargs) _engines[dbname] = create_engine(get_sqlalchemy_url(), **kwargs)

View file

@ -830,6 +830,7 @@ def test_post_account_edit_type_as_dev(client: TestClient, tu_user: User):
request.cookies = cookies request.cookies = cookies
resp = request.post(endpoint, data=data) resp = request.post(endpoint, data=data)
assert resp.status_code == int(HTTPStatus.OK) assert resp.status_code == int(HTTPStatus.OK)
db.refresh(user2)
assert user2.AccountTypeID == at.DEVELOPER_ID assert user2.AccountTypeID == at.DEVELOPER_ID
@ -850,6 +851,7 @@ def test_post_account_edit_invalid_type_as_tu(client: TestClient, tu_user: User)
request.cookies = cookies request.cookies = cookies
resp = request.post(endpoint, data=data) resp = request.post(endpoint, data=data)
assert resp.status_code == int(HTTPStatus.BAD_REQUEST) assert resp.status_code == int(HTTPStatus.BAD_REQUEST)
db.refresh(user2)
assert user2.AccountTypeID == at.USER_ID assert user2.AccountTypeID == at.USER_ID
errors = get_errors(resp.text) errors = get_errors(resp.text)
@ -1020,6 +1022,7 @@ def test_post_account_edit_inactivity(client: TestClient, user: User):
assert resp.status_code == int(HTTPStatus.OK) assert resp.status_code == int(HTTPStatus.OK)
# Make sure the user record got updated correctly. # Make sure the user record got updated correctly.
db.refresh(user)
assert user.InactivityTS > 0 assert user.InactivityTS > 0
post_data.update({"J": False}) post_data.update({"J": False})
@ -1028,6 +1031,7 @@ def test_post_account_edit_inactivity(client: TestClient, user: User):
resp = request.post(f"/account/{user.Username}/edit", data=post_data) resp = request.post(f"/account/{user.Username}/edit", data=post_data)
assert resp.status_code == int(HTTPStatus.OK) assert resp.status_code == int(HTTPStatus.OK)
db.refresh(user)
assert user.InactivityTS == 0 assert user.InactivityTS == 0
@ -1050,6 +1054,7 @@ def test_post_account_edit_suspended(client: TestClient, user: User):
assert resp.status_code == int(HTTPStatus.OK) assert resp.status_code == int(HTTPStatus.OK)
# Make sure the user record got updated correctly. # Make sure the user record got updated correctly.
db.refresh(user)
assert user.Suspended assert user.Suspended
# Let's make sure the DB got updated properly. # Let's make sure the DB got updated properly.
assert user.session is None assert user.session is None
@ -1207,6 +1212,7 @@ def test_post_account_edit_password(client: TestClient, user: User):
assert response.status_code == int(HTTPStatus.OK) assert response.status_code == int(HTTPStatus.OK)
db.refresh(user)
assert user.valid_password("newPassword") assert user.valid_password("newPassword")
@ -1273,6 +1279,7 @@ def test_post_account_edit_self_type_as_tu(client: TestClient, tu_user: User):
resp = request.post(endpoint, data=data) resp = request.post(endpoint, data=data)
assert resp.status_code == int(HTTPStatus.OK) assert resp.status_code == int(HTTPStatus.OK)
db.refresh(tu_user)
assert tu_user.AccountTypeID == USER_ID assert tu_user.AccountTypeID == USER_ID
@ -1308,6 +1315,7 @@ def test_post_account_edit_other_user_type_as_tu(
assert resp.status_code == int(HTTPStatus.OK) assert resp.status_code == int(HTTPStatus.OK)
# Let's make sure the DB got updated properly. # Let's make sure the DB got updated properly.
db.refresh(user2)
assert user2.AccountTypeID == TRUSTED_USER_ID assert user2.AccountTypeID == TRUSTED_USER_ID
# and also that this got logged out at DEBUG level. # and also that this got logged out at DEBUG level.

View file

@ -1526,6 +1526,7 @@ def test_packages_post_disown_as_maintainer(
errors = get_errors(resp.text) errors = get_errors(resp.text)
expected = "You did not select any packages to disown." expected = "You did not select any packages to disown."
assert errors[0].text.strip() == expected assert errors[0].text.strip() == expected
db.refresh(package)
assert package.PackageBase.Maintainer is not None assert package.PackageBase.Maintainer is not None
# Try to disown `package` without giving the confirm argument. # Try to disown `package` without giving the confirm argument.
@ -1550,6 +1551,7 @@ def test_packages_post_disown_as_maintainer(
data={"action": "disown", "IDs": [package.ID], "confirm": True}, data={"action": "disown", "IDs": [package.ID], "confirm": True},
) )
assert resp.status_code == int(HTTPStatus.BAD_REQUEST) assert resp.status_code == int(HTTPStatus.BAD_REQUEST)
db.refresh(package)
assert package.PackageBase.Maintainer is not None assert package.PackageBase.Maintainer is not None
errors = get_errors(resp.text) errors = get_errors(resp.text)
expected = "You are not allowed to disown one of the packages you selected." expected = "You are not allowed to disown one of the packages you selected."
@ -1563,6 +1565,7 @@ def test_packages_post_disown_as_maintainer(
data={"action": "disown", "IDs": [package.ID], "confirm": True}, data={"action": "disown", "IDs": [package.ID], "confirm": True},
) )
db.get_session().expire_all()
assert package.PackageBase.Maintainer is None assert package.PackageBase.Maintainer is None
successes = get_successes(resp.text) successes = get_successes(resp.text)
expected = "The selected packages have been disowned." expected = "The selected packages have been disowned."
@ -1647,6 +1650,7 @@ def test_packages_post_delete(
# Whoo. Now, let's finally make a valid request as `tu_user` # Whoo. Now, let's finally make a valid request as `tu_user`
# to delete `package`. # to delete `package`.
pkgname = package.PackageBase.Name
with client as request: with client as request:
request.cookies = tu_cookies request.cookies = tu_cookies
resp = request.post( resp = request.post(
@ -1659,7 +1663,7 @@ def test_packages_post_delete(
assert successes[0].text.strip() == expected assert successes[0].text.strip() == expected
# Expect that the package deletion was logged. # Expect that the package deletion was logged.
pkgbases = [package.PackageBase.Name] pkgbases = [pkgname]
expected = ( expected = (
f"Privileged user '{tu_user.Username}' deleted the " f"Privileged user '{tu_user.Username}' deleted the "
f"following package bases: {str(pkgbases)}." f"following package bases: {str(pkgbases)}."

View file

@ -686,6 +686,7 @@ def test_pkgbase_comment_pin_as_co(
assert resp.status_code == int(HTTPStatus.SEE_OTHER) assert resp.status_code == int(HTTPStatus.SEE_OTHER)
# Assert that PinnedTS got set. # Assert that PinnedTS got set.
db.refresh(comment)
assert comment.PinnedTS > 0 assert comment.PinnedTS > 0
# Unpin the comment we just pinned. # Unpin the comment we just pinned.
@ -696,6 +697,7 @@ def test_pkgbase_comment_pin_as_co(
assert resp.status_code == int(HTTPStatus.SEE_OTHER) assert resp.status_code == int(HTTPStatus.SEE_OTHER)
# Let's assert that PinnedTS was unset. # Let's assert that PinnedTS was unset.
db.refresh(comment)
assert comment.PinnedTS == 0 assert comment.PinnedTS == 0
@ -714,6 +716,7 @@ def test_pkgbase_comment_pin(
assert resp.status_code == int(HTTPStatus.SEE_OTHER) assert resp.status_code == int(HTTPStatus.SEE_OTHER)
# Assert that PinnedTS got set. # Assert that PinnedTS got set.
db.refresh(comment)
assert comment.PinnedTS > 0 assert comment.PinnedTS > 0
# Unpin the comment we just pinned. # Unpin the comment we just pinned.
@ -724,6 +727,7 @@ def test_pkgbase_comment_pin(
assert resp.status_code == int(HTTPStatus.SEE_OTHER) assert resp.status_code == int(HTTPStatus.SEE_OTHER)
# Let's assert that PinnedTS was unset. # Let's assert that PinnedTS was unset.
db.refresh(comment)
assert comment.PinnedTS == 0 assert comment.PinnedTS == 0
@ -1038,6 +1042,7 @@ def test_pkgbase_flag(
request.cookies = cookies request.cookies = cookies
resp = request.post(endpoint, data={"comments": "Test"}) resp = request.post(endpoint, data={"comments": "Test"})
assert resp.status_code == int(HTTPStatus.SEE_OTHER) assert resp.status_code == int(HTTPStatus.SEE_OTHER)
db.refresh(pkgbase)
assert pkgbase.Flagger == user assert pkgbase.Flagger == user
assert pkgbase.FlaggerComment == "Test" assert pkgbase.FlaggerComment == "Test"
@ -1075,6 +1080,7 @@ def test_pkgbase_flag(
request.cookies = user2_cookies request.cookies = user2_cookies
resp = request.post(endpoint) resp = request.post(endpoint)
assert resp.status_code == int(HTTPStatus.SEE_OTHER) assert resp.status_code == int(HTTPStatus.SEE_OTHER)
db.refresh(pkgbase)
assert pkgbase.Flagger == user assert pkgbase.Flagger == user
# Now, test that the 'maintainer' user can. # Now, test that the 'maintainer' user can.
@ -1083,6 +1089,7 @@ def test_pkgbase_flag(
request.cookies = maint_cookies request.cookies = maint_cookies
resp = request.post(endpoint) resp = request.post(endpoint)
assert resp.status_code == int(HTTPStatus.SEE_OTHER) assert resp.status_code == int(HTTPStatus.SEE_OTHER)
db.refresh(pkgbase)
assert pkgbase.Flagger is None assert pkgbase.Flagger is None
# Flag it again. # Flag it again.
@ -1096,6 +1103,7 @@ def test_pkgbase_flag(
request.cookies = cookies request.cookies = cookies
resp = request.post(endpoint) resp = request.post(endpoint)
assert resp.status_code == int(HTTPStatus.SEE_OTHER) assert resp.status_code == int(HTTPStatus.SEE_OTHER)
db.refresh(pkgbase)
assert pkgbase.Flagger is None assert pkgbase.Flagger is None
@ -1168,6 +1176,7 @@ def test_pkgbase_vote(client: TestClient, user: User, package: Package):
vote = pkgbase.package_votes.filter(PackageVote.UsersID == user.ID).first() vote = pkgbase.package_votes.filter(PackageVote.UsersID == user.ID).first()
assert vote is not None assert vote is not None
db.refresh(pkgbase)
assert pkgbase.NumVotes == 1 assert pkgbase.NumVotes == 1
# Remove vote. # Remove vote.
@ -1179,6 +1188,7 @@ def test_pkgbase_vote(client: TestClient, user: User, package: Package):
vote = pkgbase.package_votes.filter(PackageVote.UsersID == user.ID).first() vote = pkgbase.package_votes.filter(PackageVote.UsersID == user.ID).first()
assert vote is None assert vote is None
db.refresh(pkgbase)
assert pkgbase.NumVotes == 0 assert pkgbase.NumVotes == 0
@ -1590,9 +1600,9 @@ def test_pkgbase_merge_post(
assert resp.status_code == int(HTTPStatus.SEE_OTHER) assert resp.status_code == int(HTTPStatus.SEE_OTHER)
# Save these relationships for later comparison. # Save these relationships for later comparison.
comments = package.PackageBase.comments.all() comments = [row.__dict__ for row in package.PackageBase.comments.all()]
notifs = package.PackageBase.notifications.all() notifs = [row.__dict__ for row in package.PackageBase.notifications.all()]
votes = package.PackageBase.package_votes.all() votes = [row.__dict__ for row in package.PackageBase.package_votes.all()]
# Merge the package into target. # Merge the package into target.
endpoint = f"/pkgbase/{package.PackageBase.Name}/merge" endpoint = f"/pkgbase/{package.PackageBase.Name}/merge"
@ -1610,9 +1620,13 @@ def test_pkgbase_merge_post(
# Assert that the original comments, notifs and votes we setup # Assert that the original comments, notifs and votes we setup
# got migrated to target as intended. # got migrated to target as intended.
assert comments == target.comments.all() db.get_session().refresh(target)
assert notifs == target.notifications.all() assert len(comments) == target.comments.count()
assert votes == target.package_votes.all() assert comments[0]["PackageBaseID"] != target.ID
assert len(notifs) == target.notifications.count()
assert notifs[0]["PackageBaseID"] != target.ID
assert len(votes) == target.package_votes.count()
assert votes[0]["PackageBaseID"] != target.ID
# ...and that the package got deleted. # ...and that the package got deleted.
package = db.query(Package).filter(Package.Name == pkgname).first() package = db.query(Package).filter(Package.Name == pkgname).first()

View file

@ -649,6 +649,7 @@ def test_orphan_request(
assert resp.headers.get("location") == f"/pkgbase/{pkgbase.Name}" assert resp.headers.get("location") == f"/pkgbase/{pkgbase.Name}"
# We should have unset the maintainer. # We should have unset the maintainer.
db.refresh(pkgbase)
assert pkgbase.Maintainer is None assert pkgbase.Maintainer is None
# We should have removed the comaintainers. # We should have removed the comaintainers.
@ -748,6 +749,7 @@ def test_orphan_as_maintainer(client: TestClient, auser: User, pkgbase: PackageB
# As the pkgbase maintainer, disowning the package just ends up # As the pkgbase maintainer, disowning the package just ends up
# either promoting the lowest priority comaintainer or removing # either promoting the lowest priority comaintainer or removing
# the associated maintainer relationship altogether. # the associated maintainer relationship altogether.
db.refresh(pkgbase)
assert pkgbase.Maintainer is None assert pkgbase.Maintainer is None
@ -1044,6 +1046,7 @@ def test_requests_close_post(client: TestClient, user: User, pkgreq: PackageRequ
resp = request.post(f"/requests/{pkgreq.ID}/close") resp = request.post(f"/requests/{pkgreq.ID}/close")
assert resp.status_code == int(HTTPStatus.SEE_OTHER) assert resp.status_code == int(HTTPStatus.SEE_OTHER)
db.refresh(pkgreq)
assert pkgreq.Status == REJECTED_ID assert pkgreq.Status == REJECTED_ID
assert pkgreq.Closer == user assert pkgreq.Closer == user
assert pkgreq.ClosureComment == str() assert pkgreq.ClosureComment == str()
@ -1060,6 +1063,7 @@ def test_requests_close_post_rejected(
) )
assert resp.status_code == int(HTTPStatus.SEE_OTHER) assert resp.status_code == int(HTTPStatus.SEE_OTHER)
db.refresh(pkgreq)
assert pkgreq.Status == REJECTED_ID assert pkgreq.Status == REJECTED_ID
assert pkgreq.Closer == user assert pkgreq.Closer == user
assert pkgreq.ClosureComment == str() assert pkgreq.ClosureComment == str()

View file

@ -102,6 +102,7 @@ def test_user_language(client: TestClient, user: User):
req.cookies = {"AURSID": sid} req.cookies = {"AURSID": sid}
response = req.post("/language", data=post_data) response = req.post("/language", data=post_data)
assert response.status_code == int(HTTPStatus.SEE_OTHER) assert response.status_code == int(HTTPStatus.SEE_OTHER)
db.refresh(user)
assert user.LangPreference == "de" assert user.LangPreference == "de"

View file

@ -764,6 +764,7 @@ def test_tu_proposal_vote(client, proposal):
assert response.status_code == int(HTTPStatus.OK) assert response.status_code == int(HTTPStatus.OK)
# Check that the proposal record got updated. # Check that the proposal record got updated.
db.refresh(voteinfo)
assert voteinfo.Yes == yes + 1 assert voteinfo.Yes == yes + 1
# Check that the new TUVote exists. # Check that the new TUVote exists.