mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
fix: Fix middleware checking for accepted terms
The current query is a bit mixed up. The intention was to return the number of unaccepted records. Now it does also count all records that were accepted by some other user though. Let's check the total number of terms vs. the number of accepted records (by our user) instead. Signed-off-by: moson <moson@archlinux.org>
This commit is contained in:
parent
5729d6787f
commit
bc03d8b8f2
2 changed files with 14 additions and 11 deletions
|
@ -14,7 +14,7 @@ from fastapi.responses import RedirectResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from jinja2 import TemplateNotFound
|
from jinja2 import TemplateNotFound
|
||||||
from prometheus_client import multiprocess
|
from prometheus_client import multiprocess
|
||||||
from sqlalchemy import and_, or_
|
from sqlalchemy import and_
|
||||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
|
@ -277,21 +277,18 @@ async def check_terms_of_service(request: Request, call_next: typing.Callable):
|
||||||
"""This middleware function redirects authenticated users if they
|
"""This middleware function redirects authenticated users if they
|
||||||
have any outstanding Terms to agree to."""
|
have any outstanding Terms to agree to."""
|
||||||
if request.user.is_authenticated() and request.url.path != "/tos":
|
if request.user.is_authenticated() and request.url.path != "/tos":
|
||||||
unaccepted = (
|
accepted = (
|
||||||
query(Term)
|
query(Term)
|
||||||
.join(AcceptedTerm)
|
.join(AcceptedTerm)
|
||||||
.filter(
|
.filter(
|
||||||
or_(
|
|
||||||
AcceptedTerm.UsersID != request.user.ID,
|
|
||||||
and_(
|
and_(
|
||||||
AcceptedTerm.UsersID == request.user.ID,
|
AcceptedTerm.UsersID == request.user.ID,
|
||||||
AcceptedTerm.TermsID == Term.ID,
|
AcceptedTerm.TermsID == Term.ID,
|
||||||
AcceptedTerm.Revision < Term.Revision,
|
AcceptedTerm.Revision >= Term.Revision,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
if query(Term).count() - accepted.count() > 0:
|
||||||
if query(Term).count() > unaccepted.count():
|
|
||||||
return RedirectResponse("/tos", status_code=int(http.HTTPStatus.SEE_OTHER))
|
return RedirectResponse("/tos", status_code=int(http.HTTPStatus.SEE_OTHER))
|
||||||
|
|
||||||
return await util.error_or_result(call_next, request)
|
return await util.error_or_result(call_next, request)
|
||||||
|
|
|
@ -1915,6 +1915,12 @@ def test_get_terms_of_service(client: TestClient, user: User):
|
||||||
# We accepted the term, there's nothing left to accept.
|
# We accepted the term, there's nothing left to accept.
|
||||||
assert response.status_code == int(HTTPStatus.SEE_OTHER)
|
assert response.status_code == int(HTTPStatus.SEE_OTHER)
|
||||||
|
|
||||||
|
# Make sure we don't get redirected to /tos when browsing "Home"
|
||||||
|
with client as request:
|
||||||
|
request.cookies = cookies
|
||||||
|
response = request.get("/")
|
||||||
|
assert response.status_code == int(HTTPStatus.OK)
|
||||||
|
|
||||||
# Bump the term's revision.
|
# Bump the term's revision.
|
||||||
with db.begin():
|
with db.begin():
|
||||||
term.Revision = 2
|
term.Revision = 2
|
||||||
|
|
Loading…
Add table
Reference in a new issue