diff --git a/aurweb/asgi.py b/aurweb/asgi.py index eb02413b..1be77ff9 100644 --- a/aurweb/asgi.py +++ b/aurweb/asgi.py @@ -14,7 +14,7 @@ from fastapi.responses import RedirectResponse from fastapi.staticfiles import StaticFiles from jinja2 import TemplateNotFound from prometheus_client import multiprocess -from sqlalchemy import and_, or_ +from sqlalchemy import and_ from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.sessions import SessionMiddleware @@ -277,21 +277,18 @@ async def check_terms_of_service(request: Request, call_next: typing.Callable): """This middleware function redirects authenticated users if they have any outstanding Terms to agree to.""" if request.user.is_authenticated() and request.url.path != "/tos": - unaccepted = ( + accepted = ( query(Term) .join(AcceptedTerm) .filter( - or_( - AcceptedTerm.UsersID != request.user.ID, - and_( - AcceptedTerm.UsersID == request.user.ID, - AcceptedTerm.TermsID == Term.ID, - AcceptedTerm.Revision < Term.Revision, - ), - ) + and_( + AcceptedTerm.UsersID == request.user.ID, + AcceptedTerm.TermsID == Term.ID, + AcceptedTerm.Revision >= Term.Revision, + ), ) ) - if query(Term).count() > unaccepted.count(): + if query(Term).count() - accepted.count() > 0: return RedirectResponse("/tos", status_code=int(http.HTTPStatus.SEE_OTHER)) return await util.error_or_result(call_next, request) diff --git a/test/test_accounts_routes.py b/test/test_accounts_routes.py index c9d77c1f..3c481d0a 100644 --- a/test/test_accounts_routes.py +++ b/test/test_accounts_routes.py @@ -1915,6 +1915,12 @@ def test_get_terms_of_service(client: TestClient, user: User): # We accepted the term, there's nothing left to accept. assert response.status_code == int(HTTPStatus.SEE_OTHER) + # Make sure we don't get redirected to /tos when browsing "Home" + with client as request: + request.cookies = cookies + response = request.get("/") + assert response.status_code == int(HTTPStatus.OK) + # Bump the term's revision. with db.begin(): term.Revision = 2