From bc03d8b8f20ac0a1e6a2b03069632c8a064332f0 Mon Sep 17 00:00:00 2001 From: moson Date: Thu, 20 Jul 2023 18:21:05 +0200 Subject: [PATCH] fix: Fix middleware checking for accepted terms The current query is a bit mixed up. The intention was to return the number of unaccepted records. Now it does also count all records that were accepted by some other user though. Let's check the total number of terms vs. the number of accepted records (by our user) instead. Signed-off-by: moson --- aurweb/asgi.py | 19 ++++++++----------- test/test_accounts_routes.py | 6 ++++++ 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/aurweb/asgi.py b/aurweb/asgi.py index eb02413b..1be77ff9 100644 --- a/aurweb/asgi.py +++ b/aurweb/asgi.py @@ -14,7 +14,7 @@ from fastapi.responses import RedirectResponse from fastapi.staticfiles import StaticFiles from jinja2 import TemplateNotFound from prometheus_client import multiprocess -from sqlalchemy import and_, or_ +from sqlalchemy import and_ from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.sessions import SessionMiddleware @@ -277,21 +277,18 @@ async def check_terms_of_service(request: Request, call_next: typing.Callable): """This middleware function redirects authenticated users if they have any outstanding Terms to agree to.""" if request.user.is_authenticated() and request.url.path != "/tos": - unaccepted = ( + accepted = ( query(Term) .join(AcceptedTerm) .filter( - or_( - AcceptedTerm.UsersID != request.user.ID, - and_( - AcceptedTerm.UsersID == request.user.ID, - AcceptedTerm.TermsID == Term.ID, - AcceptedTerm.Revision < Term.Revision, - ), - ) + and_( + AcceptedTerm.UsersID == request.user.ID, + AcceptedTerm.TermsID == Term.ID, + AcceptedTerm.Revision >= Term.Revision, + ), ) ) - if query(Term).count() > unaccepted.count(): + if query(Term).count() - accepted.count() > 0: return RedirectResponse("/tos", status_code=int(http.HTTPStatus.SEE_OTHER)) return await util.error_or_result(call_next, request) diff --git a/test/test_accounts_routes.py b/test/test_accounts_routes.py index c9d77c1f..3c481d0a 100644 --- a/test/test_accounts_routes.py +++ b/test/test_accounts_routes.py @@ -1915,6 +1915,12 @@ def test_get_terms_of_service(client: TestClient, user: User): # We accepted the term, there's nothing left to accept. assert response.status_code == int(HTTPStatus.SEE_OTHER) + # Make sure we don't get redirected to /tos when browsing "Home" + with client as request: + request.cookies = cookies + response = request.get("/") + assert response.status_code == int(HTTPStatus.OK) + # Bump the term's revision. with db.begin(): term.Revision = 2