diff --git a/aurweb/auth/__init__.py b/aurweb/auth/__init__.py index 82192cc2..7aa4b526 100644 --- a/aurweb/auth/__init__.py +++ b/aurweb/auth/__init__.py @@ -1,5 +1,4 @@ import functools -import re from datetime import datetime from http import HTTPStatus @@ -122,17 +121,12 @@ class BasicAuthBackend(AuthenticationBackend): def auth_required(is_required: bool = True, - login: bool = True, - redirect: str = "/", template: tuple = None, status_code: HTTPStatus = HTTPStatus.UNAUTHORIZED): """ Authentication route decorator. - If redirect is given, the user will be redirected if the auth state - does not match is_required. - If template is given, it will be rendered with Unauthorized if - is_required does not match and take priority over redirect. + is_required does not match. A precondition of this function is that, if template is provided, it **must** match the following format: @@ -152,8 +146,6 @@ def auth_required(is_required: bool = True, applying any format operations. :param is_required: A boolean indicating whether the function requires auth - :param login: Redirect to `/login`, passing `next=` - :param redirect: Path to redirect to if is_required isn't True :param template: A three-element template tuple: (path, title_iterable, variable_iterable) :param status_code: An optional status_code for template render. @@ -166,14 +158,17 @@ def auth_required(is_required: bool = True, if request.user.is_authenticated() != is_required: url = "/" - if redirect: - path_params_expr = re.compile(r'\{(\w+)\}') - match = re.findall(path_params_expr, redirect) - args = {k: request.path_params.get(k) for k in match} - url = redirect.format(**args) + if is_required: + if request.method == "GET": + url = request.url.path + elif request.method == "POST" and (referer := request.headers.get("Referer")): + aur = aurweb.config.get("options", "aur_location") + "/" + if not referer.startswith(aur): + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, + detail=_("Bad Referer header.")) + url = referer[len(aur) - 1:] - if login: - url = "/login?" + util.urlencode({"next": url}) + url = "/login?" + util.urlencode({"next": url}) if template: # template=("template.html", diff --git a/aurweb/routers/accounts.py b/aurweb/routers/accounts.py index 360857e8..dade92bb 100644 --- a/aurweb/routers/accounts.py +++ b/aurweb/routers/accounts.py @@ -27,14 +27,14 @@ logger = logging.get_logger(__name__) @router.get("/passreset", response_class=HTMLResponse) -@auth_required(False, login=False) +@auth_required(False) async def passreset(request: Request): context = await make_variable_context(request, "Password Reset") return render_template(request, "passreset.html", context) @router.post("/passreset", response_class=HTMLResponse) -@auth_required(False, login=False) +@auth_required(False) async def passreset_post(request: Request, user: str = Form(...), resetkey: str = Form(default=None), @@ -226,7 +226,7 @@ def make_account_form_context(context: dict, @router.get("/register", response_class=HTMLResponse) -@auth_required(False, login=False) +@auth_required(False) async def account_register(request: Request, U: str = Form(default=str()), # Username E: str = Form(default=str()), # Email @@ -252,7 +252,7 @@ async def account_register(request: Request, @router.post("/register", response_class=HTMLResponse) -@auth_required(False, login=False) +@auth_required(False) async def account_register_post(request: Request, U: str = Form(default=str()), # Username E: str = Form(default=str()), # Email @@ -340,7 +340,7 @@ def cannot_edit(request, user): @router.get("/account/{username}/edit", response_class=HTMLResponse) -@auth_required(True, redirect="/account/{username}") +@auth_required(True) async def account_edit(request: Request, username: str): user = db.query(models.User, models.User.Username == username).first() @@ -356,7 +356,7 @@ async def account_edit(request: Request, username: str): @router.post("/account/{username}/edit", response_class=HTMLResponse) -@auth_required(True, redirect="/account/{username}") +@auth_required(True) async def account_edit_post(request: Request, username: str, U: str = Form(default=str()), # Username @@ -443,7 +443,7 @@ async def account(request: Request, username: str): @router.get("/account/{username}/comments") -@auth_required(redirect="/account/{username}/comments") +@auth_required() async def account_comments(request: Request, username: str): user = get_user_by_name(username) context = make_context(request, "Accounts") @@ -454,7 +454,7 @@ async def account_comments(request: Request, username: str): @router.get("/accounts") -@auth_required(True, redirect="/accounts") +@auth_required(True) @account_type_required({at.TRUSTED_USER, at.DEVELOPER, at.TRUSTED_USER_AND_DEV}) @@ -464,7 +464,7 @@ async def accounts(request: Request): @router.post("/accounts") -@auth_required(True, redirect="/accounts") +@auth_required(True) @account_type_required({at.TRUSTED_USER, at.DEVELOPER, at.TRUSTED_USER_AND_DEV}) @@ -548,7 +548,7 @@ def render_terms_of_service(request: Request, @router.get("/tos") -@auth_required(True, redirect="/tos") +@auth_required(True) async def terms_of_service(request: Request): # Query the database for terms that were previously accepted, # but now have a bumped Revision that needs to be accepted. @@ -572,7 +572,7 @@ async def terms_of_service(request: Request): @router.post("/tos") -@auth_required(True, redirect="/tos") +@auth_required(True) async def terms_of_service_post(request: Request, accept: bool = Form(default=False)): # Query the database for terms that were previously accepted, diff --git a/aurweb/routers/auth.py b/aurweb/routers/auth.py index 1e0b026a..74763667 100644 --- a/aurweb/routers/auth.py +++ b/aurweb/routers/auth.py @@ -29,7 +29,7 @@ async def login_get(request: Request, next: str = "/"): @router.post("/login", response_class=HTMLResponse) -@auth_required(False, login=False) +@auth_required(False) async def login_post(request: Request, next: str = Form(...), user: str = Form(default=str()), diff --git a/aurweb/routers/packages.py b/aurweb/routers/packages.py index 2bf04949..4a2cdce3 100644 --- a/aurweb/routers/packages.py +++ b/aurweb/routers/packages.py @@ -295,7 +295,7 @@ async def package_base_voters(request: Request, name: str) -> Response: @router.post("/pkgbase/{name}/comments") -@auth_required(True, redirect="/pkgbase/{name}/comments") +@auth_required(True) async def pkgbase_comments_post( request: Request, name: str, comment: str = Form(default=str()), @@ -327,7 +327,7 @@ async def pkgbase_comments_post( @router.get("/pkgbase/{name}/comments/{id}/form") -@auth_required(True, login=False) +@auth_required(True) async def pkgbase_comment_form(request: Request, name: str, id: int, next: str = Query(default=None)): """ Produce a comment form for comment {id}. """ @@ -353,7 +353,7 @@ async def pkgbase_comment_form(request: Request, name: str, id: int, @router.post("/pkgbase/{name}/comments/{id}") -@auth_required(True, redirect="/pkgbase/{name}/comments/{id}") +@auth_required(True) async def pkgbase_comment_post( request: Request, name: str, id: int, comment: str = Form(default=str()), @@ -392,7 +392,7 @@ async def pkgbase_comment_post( @router.get("/pkgbase/{name}/comments/{id}/edit") -@auth_required(True, redirect="/pkgbase/{name}/comments/{id}/edit") +@auth_required(True) async def pkgbase_comment_edit(request: Request, name: str, id: int, next: str = Form(default=None)): pkgbase = get_pkg_or_base(name, models.PackageBase) @@ -407,7 +407,7 @@ async def pkgbase_comment_edit(request: Request, name: str, id: int, @router.post("/pkgbase/{name}/comments/{id}/delete") -@auth_required(True, redirect="/pkgbase/{name}/comments/{id}/delete") +@auth_required(True) async def pkgbase_comment_delete(request: Request, name: str, id: int, next: str = Form(default=None)): pkgbase = get_pkg_or_base(name, models.PackageBase) @@ -433,7 +433,7 @@ async def pkgbase_comment_delete(request: Request, name: str, id: int, @router.post("/pkgbase/{name}/comments/{id}/undelete") -@auth_required(True, redirect="/pkgbase/{name}/comments/{id}/undelete") +@auth_required(True) async def pkgbase_comment_undelete(request: Request, name: str, id: int, next: str = Form(default=None)): pkgbase = get_pkg_or_base(name, models.PackageBase) @@ -458,7 +458,7 @@ async def pkgbase_comment_undelete(request: Request, name: str, id: int, @router.post("/pkgbase/{name}/comments/{id}/pin") -@auth_required(True, redirect="/pkgbase/{name}/comments/{id}/pin") +@auth_required(True) async def pkgbase_comment_pin(request: Request, name: str, id: int, next: str = Form(default=None)): pkgbase = get_pkg_or_base(name, models.PackageBase) @@ -483,7 +483,7 @@ async def pkgbase_comment_pin(request: Request, name: str, id: int, @router.post("/pkgbase/{name}/comments/{id}/unpin") -@auth_required(True, redirect="/pkgbase/{name}/comments/{id}/unpin") +@auth_required(True) async def pkgbase_comment_unpin(request: Request, name: str, id: int, next: str = Form(default=None)): pkgbase = get_pkg_or_base(name, models.PackageBase) @@ -507,7 +507,7 @@ async def pkgbase_comment_unpin(request: Request, name: str, id: int, @router.get("/pkgbase/{name}/comaintainers") -@auth_required(True, redirect="/pkgbase/{name}/comaintainers") +@auth_required(True) async def package_base_comaintainers(request: Request, name: str) -> Response: # Get the PackageBase. pkgbase = get_pkg_or_base(name, models.PackageBase) @@ -532,7 +532,7 @@ async def package_base_comaintainers(request: Request, name: str) -> Response: @router.post("/pkgbase/{name}/comaintainers") -@auth_required(True, redirect="/pkgbase/{name}/comaintainers") +@auth_required(True) async def package_base_comaintainers_post( request: Request, name: str, users: str = Form(default=str())) -> Response: @@ -584,7 +584,7 @@ async def package_base_comaintainers_post( @router.get("/requests") -@auth_required(True, redirect="/requests") +@auth_required(True) async def requests(request: Request, O: int = Query(default=defaults.O), PP: int = Query(default=defaults.PP)): @@ -618,7 +618,7 @@ async def requests(request: Request, @router.get("/pkgbase/{name}/request") -@auth_required(True, redirect="/pkgbase/{name}/request") +@auth_required(True) async def package_request(request: Request, name: str): pkgbase = get_pkg_or_base(name, models.PackageBase) context = await make_variable_context(request, "Submit Request") @@ -627,7 +627,7 @@ async def package_request(request: Request, name: str): @router.post("/pkgbase/{name}/request") -@auth_required(True, redirect="/pkgbase/{name}/request") +@auth_required(True) async def pkgbase_request_post(request: Request, name: str, type: str = Form(...), merge_into: str = Form(default=None), @@ -699,7 +699,7 @@ async def pkgbase_request_post(request: Request, name: str, @router.get("/requests/{id}/close") -@auth_required(True, redirect="/requests/{id}/close") +@auth_required(True) async def requests_close(request: Request, id: int): pkgreq = get_pkgreq_by_id(id) if not request.user.is_elevated() and request.user != pkgreq.User: @@ -712,7 +712,7 @@ async def requests_close(request: Request, id: int): @router.post("/requests/{id}/close") -@auth_required(True, redirect="/requests/{id}/close") +@auth_required(True) async def requests_close_post(request: Request, id: int, reason: int = Form(default=0), comments: str = Form(default=str())): @@ -775,7 +775,7 @@ async def pkgbase_keywords(request: Request, name: str, @router.get("/pkgbase/{name}/flag") -@auth_required(True, redirect="/pkgbase/{name}/flag") +@auth_required(True) async def pkgbase_flag_get(request: Request, name: str): pkgbase = get_pkg_or_base(name, models.PackageBase) @@ -790,7 +790,7 @@ async def pkgbase_flag_get(request: Request, name: str): @router.post("/pkgbase/{name}/flag") -@auth_required(True, redirect="/pkgbase/{name}/flag") +@auth_required(True) async def pkgbase_flag_post(request: Request, name: str, comments: str = Form(default=str())): pkgbase = get_pkg_or_base(name, models.PackageBase) @@ -839,7 +839,7 @@ def pkgbase_unflag_instance(request: Request, pkgbase: models.PackageBase): @router.post("/pkgbase/{name}/unflag") -@auth_required(True, redirect="/pkgbase/{name}") +@auth_required(True) async def pkgbase_unflag(request: Request, name: str): pkgbase = get_pkg_or_base(name, models.PackageBase) pkgbase_unflag_instance(request, pkgbase) @@ -860,7 +860,7 @@ def pkgbase_notify_instance(request: Request, pkgbase: models.PackageBase): @router.post("/pkgbase/{name}/notify") -@auth_required(True, redirect="/pkgbase/{name}") +@auth_required(True) async def pkgbase_notify(request: Request, name: str): pkgbase = get_pkg_or_base(name, models.PackageBase) pkgbase_notify_instance(request, pkgbase) @@ -879,7 +879,7 @@ def pkgbase_unnotify_instance(request: Request, pkgbase: models.PackageBase): @router.post("/pkgbase/{name}/unnotify") -@auth_required(True, redirect="/pkgbase/{name}") +@auth_required(True) async def pkgbase_unnotify(request: Request, name: str): pkgbase = get_pkg_or_base(name, models.PackageBase) pkgbase_unnotify_instance(request, pkgbase) @@ -888,7 +888,7 @@ async def pkgbase_unnotify(request: Request, name: str): @router.post("/pkgbase/{name}/vote") -@auth_required(True, redirect="/pkgbase/{name}") +@auth_required(True) async def pkgbase_vote(request: Request, name: str): pkgbase = get_pkg_or_base(name, models.PackageBase) @@ -912,7 +912,7 @@ async def pkgbase_vote(request: Request, name: str): @router.post("/pkgbase/{name}/unvote") -@auth_required(True, redirect="/pkgbase/{name}") +@auth_required(True) async def pkgbase_unvote(request: Request, name: str): pkgbase = get_pkg_or_base(name, models.PackageBase) @@ -954,7 +954,7 @@ def pkgbase_disown_instance(request: Request, pkgbase: models.PackageBase): @router.get("/pkgbase/{name}/disown") -@auth_required(True, redirect="/pkgbase/{name}/disown") +@auth_required(True) async def pkgbase_disown_get(request: Request, name: str): pkgbase = get_pkg_or_base(name, models.PackageBase) @@ -970,7 +970,7 @@ async def pkgbase_disown_get(request: Request, name: str): @router.post("/pkgbase/{name}/disown") -@auth_required(True, redirect="/pkgbase/{name}/disown") +@auth_required(True) async def pkgbase_disown_post(request: Request, name: str, confirm: bool = Form(default=False)): pkgbase = get_pkg_or_base(name, models.PackageBase) @@ -1003,7 +1003,7 @@ def pkgbase_adopt_instance(request: Request, pkgbase: models.PackageBase): @router.post("/pkgbase/{name}/adopt") -@auth_required(True, redirect="/pkgbase/{name}") +@auth_required(True) async def pkgbase_adopt_post(request: Request, name: str): pkgbase = get_pkg_or_base(name, models.PackageBase) @@ -1019,7 +1019,7 @@ async def pkgbase_adopt_post(request: Request, name: str): @router.get("/pkgbase/{name}/delete") -@auth_required(True, redirect="/pkgbase/{name}/delete") +@auth_required(True) async def pkgbase_delete_get(request: Request, name: str): if not request.user.has_credential(creds.PKGBASE_DELETE): return RedirectResponse(f"/pkgbase/{name}", @@ -1031,7 +1031,7 @@ async def pkgbase_delete_get(request: Request, name: str): @router.post("/pkgbase/{name}/delete") -@auth_required(True, redirect="/pkgbase/{name}/delete") +@auth_required(True) async def pkgbase_delete_post(request: Request, name: str, confirm: bool = Form(default=False)): pkgbase = get_pkg_or_base(name, models.PackageBase) @@ -1279,7 +1279,7 @@ PACKAGE_ACTIONS = { @router.post("/packages") -@auth_required(redirect="/packages") +@auth_required() async def packages_post(request: Request, IDs: List[int] = Form(default=[]), action: str = Form(default=str()), @@ -1311,7 +1311,7 @@ async def packages_post(request: Request, @router.get("/pkgbase/{name}/merge") -@auth_required(redirect="/pkgbase/{name}/merge") +@auth_required() async def pkgbase_merge_get(request: Request, name: str, into: str = Query(default=str()), next: str = Query(default=str())): @@ -1423,7 +1423,7 @@ def pkgbase_merge_instance(request: Request, pkgbase: models.PackageBase, @router.post("/pkgbase/{name}/merge") -@auth_required(redirect="/pkgbase/{name}/merge") +@auth_required() async def pkgbase_merge_post(request: Request, name: str, into: str = Form(default=str()), confirm: bool = Form(default=False), diff --git a/aurweb/routers/trusted_user.py b/aurweb/routers/trusted_user.py index f0cea61e..09de58fe 100644 --- a/aurweb/routers/trusted_user.py +++ b/aurweb/routers/trusted_user.py @@ -41,7 +41,7 @@ ADDVOTE_SPECIFICS = { @router.get("/tu") -@auth_required(True, redirect="/tu") +@auth_required(True) @account_type_required(REQUIRED_TYPES) async def trusted_user(request: Request, coff: int = 0, # current offset @@ -147,7 +147,7 @@ def render_proposal(request: Request, @router.get("/tu/{proposal}") -@auth_required(True, redirect="/tu/{proposal}") +@auth_required(True) @account_type_required(REQUIRED_TYPES) async def trusted_user_proposal(request: Request, proposal: int): context = await make_variable_context(request, "Trusted User") @@ -176,7 +176,7 @@ async def trusted_user_proposal(request: Request, proposal: int): @router.post("/tu/{proposal}") -@auth_required(True, redirect="/tu/{proposal}") +@auth_required(True) @account_type_required(REQUIRED_TYPES) async def trusted_user_proposal_post(request: Request, proposal: int, @@ -227,7 +227,7 @@ async def trusted_user_proposal_post(request: Request, @router.get("/addvote") -@auth_required(True, redirect="/addvote") +@auth_required(True) @account_type_required({TRUSTED_USER, TRUSTED_USER_AND_DEV}) async def trusted_user_addvote(request: Request, user: str = str(), @@ -247,7 +247,7 @@ async def trusted_user_addvote(request: Request, @router.post("/addvote") -@auth_required(True, redirect="/addvote") +@auth_required(True) @account_type_required({TRUSTED_USER, TRUSTED_USER_AND_DEV}) async def trusted_user_addvote_post(request: Request, user: str = Form(default=str()), diff --git a/test/test_trusted_user_routes.py b/test/test_trusted_user_routes.py index 43a3443b..ac7f82d5 100644 --- a/test/test_trusted_user_routes.py +++ b/test/test_trusted_user_routes.py @@ -9,7 +9,7 @@ import pytest from fastapi.testclient import TestClient -from aurweb import db, util +from aurweb import config, db, util from aurweb.models.account_type import AccountType from aurweb.models.tu_vote import TUVote from aurweb.models.tu_voteinfo import TUVoteInfo @@ -124,8 +124,9 @@ def proposal(user, tu_user): def test_tu_index_guest(client): + headers = {"referer": config.get("options", "aur_location") + "/tu"} with client as request: - response = request.get("/tu", allow_redirects=False) + response = request.get("/tu", allow_redirects=False, headers=headers) assert response.status_code == int(HTTPStatus.SEE_OTHER) params = util.urlencode({"next": "/tu"})