fix: Fix middleware checking for accepted terms

The current query is a bit mixed up. The intention was to return the
number of unaccepted records. Now it does also count all records
that were accepted by some other user though.

Let's check the total number of terms vs. the number of accepted
records (by our user) instead.

Signed-off-by: moson <moson@archlinux.org>
This commit is contained in:
moson 2023-07-20 18:21:05 +02:00
parent 5729d6787f
commit bc03d8b8f2
No known key found for this signature in database
GPG key ID: 4A4760AB4EE15296
2 changed files with 14 additions and 11 deletions

View file

@ -14,7 +14,7 @@ from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from jinja2 import TemplateNotFound from jinja2 import TemplateNotFound
from prometheus_client import multiprocess from prometheus_client import multiprocess
from sqlalchemy import and_, or_ from sqlalchemy import and_
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
@ -277,21 +277,18 @@ async def check_terms_of_service(request: Request, call_next: typing.Callable):
"""This middleware function redirects authenticated users if they """This middleware function redirects authenticated users if they
have any outstanding Terms to agree to.""" have any outstanding Terms to agree to."""
if request.user.is_authenticated() and request.url.path != "/tos": if request.user.is_authenticated() and request.url.path != "/tos":
unaccepted = ( accepted = (
query(Term) query(Term)
.join(AcceptedTerm) .join(AcceptedTerm)
.filter( .filter(
or_( and_(
AcceptedTerm.UsersID != request.user.ID, AcceptedTerm.UsersID == request.user.ID,
and_( AcceptedTerm.TermsID == Term.ID,
AcceptedTerm.UsersID == request.user.ID, AcceptedTerm.Revision >= Term.Revision,
AcceptedTerm.TermsID == Term.ID, ),
AcceptedTerm.Revision < Term.Revision,
),
)
) )
) )
if query(Term).count() > unaccepted.count(): if query(Term).count() - accepted.count() > 0:
return RedirectResponse("/tos", status_code=int(http.HTTPStatus.SEE_OTHER)) return RedirectResponse("/tos", status_code=int(http.HTTPStatus.SEE_OTHER))
return await util.error_or_result(call_next, request) return await util.error_or_result(call_next, request)

View file

@ -1915,6 +1915,12 @@ def test_get_terms_of_service(client: TestClient, user: User):
# We accepted the term, there's nothing left to accept. # We accepted the term, there's nothing left to accept.
assert response.status_code == int(HTTPStatus.SEE_OTHER) assert response.status_code == int(HTTPStatus.SEE_OTHER)
# Make sure we don't get redirected to /tos when browsing "Home"
with client as request:
request.cookies = cookies
response = request.get("/")
assert response.status_code == int(HTTPStatus.OK)
# Bump the term's revision. # Bump the term's revision.
with db.begin(): with db.begin():
term.Revision = 2 term.Revision = 2