fix: Fix middleware checking for accepted terms

The current query is a bit mixed up. The intention was to return the
number of unaccepted records. Now it does also count all records
that were accepted by some other user though.

Let's check the total number of terms vs. the number of accepted
records (by our user) instead.

Signed-off-by: moson <moson@archlinux.org>
This commit is contained in:
moson 2023-07-20 18:21:05 +02:00
parent 5729d6787f
commit bc03d8b8f2
No known key found for this signature in database
GPG key ID: 4A4760AB4EE15296
2 changed files with 14 additions and 11 deletions

View file

@ -14,7 +14,7 @@ from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
from jinja2 import TemplateNotFound
from prometheus_client import multiprocess
from sqlalchemy import and_, or_
from sqlalchemy import and_
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.sessions import SessionMiddleware
@ -277,21 +277,18 @@ async def check_terms_of_service(request: Request, call_next: typing.Callable):
"""This middleware function redirects authenticated users if they
have any outstanding Terms to agree to."""
if request.user.is_authenticated() and request.url.path != "/tos":
unaccepted = (
accepted = (
query(Term)
.join(AcceptedTerm)
.filter(
or_(
AcceptedTerm.UsersID != request.user.ID,
and_(
AcceptedTerm.UsersID == request.user.ID,
AcceptedTerm.TermsID == Term.ID,
AcceptedTerm.Revision < Term.Revision,
),
)
and_(
AcceptedTerm.UsersID == request.user.ID,
AcceptedTerm.TermsID == Term.ID,
AcceptedTerm.Revision >= Term.Revision,
),
)
)
if query(Term).count() > unaccepted.count():
if query(Term).count() - accepted.count() > 0:
return RedirectResponse("/tos", status_code=int(http.HTTPStatus.SEE_OTHER))
return await util.error_or_result(call_next, request)

View file

@ -1915,6 +1915,12 @@ def test_get_terms_of_service(client: TestClient, user: User):
# We accepted the term, there's nothing left to accept.
assert response.status_code == int(HTTPStatus.SEE_OTHER)
# Make sure we don't get redirected to /tos when browsing "Home"
with client as request:
request.cookies = cookies
response = request.get("/")
assert response.status_code == int(HTTPStatus.OK)
# Bump the term's revision.
with db.begin():
term.Revision = 2