diff --git a/aurweb/auth.py b/aurweb/auth.py index 9f56f90f..52a4260c 100644 --- a/aurweb/auth.py +++ b/aurweb/auth.py @@ -1,4 +1,5 @@ import functools +import re from datetime import datetime from http import HTTPStatus @@ -121,6 +122,7 @@ class BasicAuthBackend(AuthenticationBackend): def auth_required(is_required: bool = True, + login: bool = True, redirect: str = "/", template: tuple = None, status_code: HTTPStatus = HTTPStatus.UNAUTHORIZED): @@ -150,6 +152,7 @@ def auth_required(is_required: bool = True, applying any format operations. :param is_required: A boolean indicating whether the function requires auth + :param login: Redirect to `/login`, passing `next=` :param redirect: Path to redirect to if is_required isn't True :param template: A three-element template tuple: (path, title_iterable, variable_iterable) @@ -162,8 +165,16 @@ def auth_required(is_required: bool = True, async def wrapper(request, *args, **kwargs): if request.user.is_authenticated() != is_required: url = "/" + if redirect: - url = redirect + path_params_expr = re.compile(r'\{(\w+)\}') + match = re.findall(path_params_expr, redirect) + args = {k: request.path_params.get(k) for k in match} + url = redirect.format(**args) + + if login: + url = "/login?" + util.urlencode({"next": url}) + if template: # template=("template.html", # ["Some Title", "someFormatted {}"], diff --git a/aurweb/routers/accounts.py b/aurweb/routers/accounts.py index 3c799938..fc1c5242 100644 --- a/aurweb/routers/accounts.py +++ b/aurweb/routers/accounts.py @@ -30,14 +30,14 @@ logger = logging.getLogger(__name__) @router.get("/passreset", response_class=HTMLResponse) -@auth_required(False) +@auth_required(False, login=False) async def passreset(request: Request): context = await make_variable_context(request, "Password Reset") return render_template(request, "passreset.html", context) @router.post("/passreset", response_class=HTMLResponse) -@auth_required(False) +@auth_required(False, login=False) async def passreset_post(request: Request, user: str = Form(...), resetkey: str = Form(default=None), @@ -315,7 +315,7 @@ def make_account_form_context(context: dict, @router.get("/register", response_class=HTMLResponse) -@auth_required(False) +@auth_required(False, login=False) async def account_register(request: Request, U: str = Form(default=str()), # Username E: str = Form(default=str()), # Email @@ -341,7 +341,7 @@ async def account_register(request: Request, @router.post("/register", response_class=HTMLResponse) -@auth_required(False) +@auth_required(False, login=False) async def account_register_post(request: Request, U: str = Form(default=str()), # Username E: str = Form(default=str()), # Email @@ -432,7 +432,7 @@ def cannot_edit(request, user): @router.get("/account/{username}/edit", response_class=HTMLResponse) -@auth_required(True) +@auth_required(True, redirect="/account/{username}") async def account_edit(request: Request, username: str): user = db.query(User, User.Username == username).first() @@ -448,7 +448,7 @@ async def account_edit(request: Request, @router.post("/account/{username}/edit", response_class=HTMLResponse) -@auth_required(True) +@auth_required(True, redirect="/account/{username}") async def account_edit_post(request: Request, username: str, U: str = Form(default=str()), # Username @@ -594,7 +594,7 @@ async def account(request: Request, username: str): @router.get("/accounts/") -@auth_required(True) +@auth_required(True, redirect="/accounts/") @account_type_required({TRUSTED_USER, DEVELOPER, TRUSTED_USER_AND_DEV}) async def accounts(request: Request): context = make_context(request, "Accounts") @@ -602,7 +602,7 @@ async def accounts(request: Request): @router.post("/accounts/") -@auth_required(True) +@auth_required(True, redirect="/accounts/") @account_type_required({TRUSTED_USER, DEVELOPER, TRUSTED_USER_AND_DEV}) async def accounts_post(request: Request, O: int = Form(default=0), # Offset @@ -688,7 +688,7 @@ def render_terms_of_service(request: Request, @router.get("/tos") -@auth_required(True, redirect="/") +@auth_required(True, redirect="/tos") async def terms_of_service(request: Request): # Query the database for terms that were previously accepted, # but now have a bumped Revision that needs to be accepted. @@ -709,7 +709,7 @@ async def terms_of_service(request: Request): @router.post("/tos") -@auth_required(True, redirect="/") +@auth_required(True, redirect="/tos") async def terms_of_service_post(request: Request, accept: bool = Form(default=False)): # Query the database for terms that were previously accepted, diff --git a/aurweb/routers/packages.py b/aurweb/routers/packages.py index a3effb36..ee6d71ba 100644 --- a/aurweb/routers/packages.py +++ b/aurweb/routers/packages.py @@ -222,7 +222,7 @@ async def package_base_voters(request: Request, name: str) -> Response: @router.post("/pkgbase/{name}/comments") -@auth_required(True) +@auth_required(True, redirect="/pkgbase/{name}/comments") async def pkgbase_comments_post( request: Request, name: str, comment: str = Form(default=str()), @@ -254,7 +254,7 @@ async def pkgbase_comments_post( @router.get("/pkgbase/{name}/comments/{id}/form") -@auth_required(True) +@auth_required(True, login=False) async def pkgbase_comment_form(request: Request, name: str, id: int): """ Produce a comment form for comment {id}. """ pkgbase = get_pkg_or_base(name, PackageBase) @@ -274,7 +274,7 @@ async def pkgbase_comment_form(request: Request, name: str, id: int): @router.post("/pkgbase/{name}/comments/{id}") -@auth_required(True) +@auth_required(True, redirect="/pkgbase/{name}/comments/{id}") async def pkgbase_comment_post( request: Request, name: str, id: int, comment: str = Form(default=str()), @@ -309,7 +309,7 @@ async def pkgbase_comment_post( @router.post("/pkgbase/{name}/comments/{id}/delete") -@auth_required(True) +@auth_required(True, redirect="/pkgbase/{name}/comments/{id}/delete") async def pkgbase_comment_delete(request: Request, name: str, id: int): pkgbase = get_pkg_or_base(name, PackageBase) comment = get_pkgbase_comment(pkgbase, id) @@ -332,7 +332,7 @@ async def pkgbase_comment_delete(request: Request, name: str, id: int): @router.post("/pkgbase/{name}/comments/{id}/undelete") -@auth_required(True) +@auth_required(True, redirect="/pkgbase/{name}/comments/{id}/undelete") async def pkgbase_comment_undelete(request: Request, name: str, id: int): pkgbase = get_pkg_or_base(name, PackageBase) comment = get_pkgbase_comment(pkgbase, id) @@ -354,7 +354,7 @@ async def pkgbase_comment_undelete(request: Request, name: str, id: int): @router.post("/pkgbase/{name}/comments/{id}/pin") -@auth_required(True) +@auth_required(True, redirect="/pkgbase/{name}/comments/{id}/pin") async def pkgbase_comment_pin(request: Request, name: str, id: int): pkgbase = get_pkg_or_base(name, PackageBase) comment = get_pkgbase_comment(pkgbase, id) @@ -376,7 +376,7 @@ async def pkgbase_comment_pin(request: Request, name: str, id: int): @router.post("/pkgbase/{name}/comments/{id}/unpin") -@auth_required(True) +@auth_required(True, redirect="/pkgbase/{name}/comments/{id}/unpin") async def pkgbase_comment_unpin(request: Request, name: str, id: int): pkgbase = get_pkg_or_base(name, PackageBase) comment = get_pkgbase_comment(pkgbase, id) @@ -397,7 +397,7 @@ async def pkgbase_comment_unpin(request: Request, name: str, id: int): @router.get("/pkgbase/{name}/comaintainers") -@auth_required(True) +@auth_required(True, redirect="/pkgbase/{name}/comaintainers") async def package_base_comaintainers(request: Request, name: str) -> Response: # Get the PackageBase. pkgbase = get_pkg_or_base(name, PackageBase) @@ -444,7 +444,7 @@ def remove_users(pkgbase, usernames): @router.post("/pkgbase/{name}/comaintainers") -@auth_required(True) +@auth_required(True, redirect="/pkgbase/{name}/comaintainers") async def package_base_comaintainers_post( request: Request, name: str, users: str = Form(default=str())) -> Response: @@ -539,7 +539,7 @@ async def package_base_comaintainers_post( @router.get("/requests") -@auth_required(True, redirect="/") +@auth_required(True, redirect="/requests") async def requests(request: Request, O: int = Query(default=defaults.O), PP: int = Query(default=defaults.PP)): @@ -571,7 +571,7 @@ async def requests(request: Request, @router.get("/pkgbase/{name}/request") -@auth_required(True) +@auth_required(True, redirect="/pkgbase/{name}") async def package_request(request: Request, name: str): context = make_context(request, "Submit Request") @@ -585,7 +585,7 @@ async def package_request(request: Request, name: str): @router.post("/pkgbase/{name}/request") -@auth_required(True) +@auth_required(True, redirect="/pkgbase/{name}/request") async def pkgbase_request_post(request: Request, name: str, type: str = Form(...), merge_into: str = Form(default=None), @@ -654,7 +654,7 @@ async def pkgbase_request_post(request: Request, name: str, @router.get("/requests/{id}/close") -@auth_required(True) +@auth_required(True, redirect="/requests/{id}/close") async def requests_close(request: Request, id: int): pkgreq = db.query(PackageRequest).filter(PackageRequest.ID == id).first() if not request.user.is_elevated() and request.user != pkgreq.User: @@ -667,7 +667,7 @@ async def requests_close(request: Request, id: int): @router.post("/requests/{id}/close") -@auth_required(True) +@auth_required(True, redirect="/requests/{id}/close") async def requests_close_post(request: Request, id: int, reason: int = Form(default=0), comments: str = Form(default=str())): diff --git a/aurweb/routers/trusted_user.py b/aurweb/routers/trusted_user.py index a977b31a..b897a635 100644 --- a/aurweb/routers/trusted_user.py +++ b/aurweb/routers/trusted_user.py @@ -45,7 +45,7 @@ ADDVOTE_SPECIFICS = { @router.get("/tu") -@auth_required(True, redirect="/") +@auth_required(True, redirect="/tu") @account_type_required(REQUIRED_TYPES) async def trusted_user(request: Request, coff: int = 0, # current offset @@ -149,7 +149,7 @@ def render_proposal(request: Request, @router.get("/tu/{proposal}") -@auth_required(True, redirect="/") +@auth_required(True, redirect="/tu/{proposal}") @account_type_required(REQUIRED_TYPES) async def trusted_user_proposal(request: Request, proposal: int): context = await make_variable_context(request, "Trusted User") @@ -175,7 +175,7 @@ async def trusted_user_proposal(request: Request, proposal: int): @router.post("/tu/{proposal}") -@auth_required(True, redirect="/") +@auth_required(True, redirect="/tu/{proposal}") @account_type_required(REQUIRED_TYPES) async def trusted_user_proposal_post(request: Request, proposal: int, @@ -223,7 +223,7 @@ async def trusted_user_proposal_post(request: Request, @router.get("/addvote") -@auth_required(True) +@auth_required(True, redirect="/addvote") @account_type_required({"Trusted User", "Trusted User & Developer"}) async def trusted_user_addvote(request: Request, user: str = str(), @@ -243,7 +243,7 @@ async def trusted_user_addvote(request: Request, @router.post("/addvote") -@auth_required(True) +@auth_required(True, redirect="/addvote") @account_type_required({TRUSTED_USER, TRUSTED_USER_AND_DEV}) async def trusted_user_addvote_post(request: Request, user: str = Form(default=str()), diff --git a/templates/partials/packages/actions.html b/templates/partials/packages/actions.html index a54d4c90..f1863663 100644 --- a/templates/partials/packages/actions.html +++ b/templates/partials/packages/actions.html @@ -23,99 +23,68 @@ {{ "Search wiki" | tr }} - {% if not request.user.is_authenticated() %} - {% if not out_of_date %} + {% if not out_of_date %}
  • {{ "Flag package out-of-date" | tr }}
  • - {% else %} -
  • - - {% set ood_ts = result.OutOfDateTS | dt | as_timezone(timezone) %} - {{ - "Flagged out-of-date (%s)" - | tr | format(ood_ts.strftime("%Y-%m-%d")) - }} - -
  • - {% endif %} -
  • - - {{ "Vote for this package" | tr }} - -
  • -
  • - - {{ "Enable notifications" | tr }} - -
  • {% else %} - {% if not out_of_date %} -
  • - - {{ "Flag package out-of-date" | tr }} - -
  • - {% else %} -
  • - - {% set ood_ts = result.OutOfDateTS | dt | as_timezone(timezone) %} - {{ - "Flagged out-of-date (%s)" - | tr | format(ood_ts.strftime("%Y-%m-%d")) - }} - -
  • -
  • -
    - + + {% set ood_ts = result.OutOfDateTS | dt | as_timezone(timezone) %} + {{ + "Flagged out-of-date (%s)" + | tr | format(ood_ts.strftime("%Y-%m-%d")) + }} + +
  • +
  • + + + +
  • + {% endif %} +
  • + {% if not voted %} +
    + +
    + {% else %} +
    + +
    + {% endif %} +
  • +
  • + {% if notified %} +
    +
    -
  • - {% endif %} -
  • - {% if not voted %} -
    + {% else %} + + name="do_Notify" + value="{{ 'Enable notifications' | tr }}" + />
    - {% else %} -
    - -
    - {% endif %} -
  • -
  • - {% if notified %} -
    - -
    - {% else %} -
    - -
    - {% endif %} -
  • - - {% endif %} + {% endif %} + {% if request.user.has_credential('CRED_PKGBASE_EDIT_COMAINTAINERS', approved=[pkgbase.Maintainer]) %}
  • @@ -132,15 +101,9 @@
  • {% endif %}
  • - {% if not request.user.is_authenticated() %} - - {{ "Submit Request" | tr }} - - {% else %} {{ "Submit Request" | tr }} - {% endif %}
  • {% if request.user.has_credential("CRED_PKGBASE_DELETE") %}
  • diff --git a/test/test_trusted_user_routes.py b/test/test_trusted_user_routes.py index 67181db3..0579247e 100644 --- a/test/test_trusted_user_routes.py +++ b/test/test_trusted_user_routes.py @@ -9,7 +9,7 @@ import pytest from fastapi.testclient import TestClient -from aurweb import db +from aurweb import db, util from aurweb.models.account_type import AccountType from aurweb.models.tu_vote import TUVote from aurweb.models.tu_voteinfo import TUVoteInfo @@ -128,7 +128,9 @@ def test_tu_index_guest(client): with client as request: response = request.get("/tu", allow_redirects=False) assert response.status_code == int(HTTPStatus.SEE_OTHER) - assert response.headers.get("location") == "/" + + params = util.urlencode({"next": "/tu"}) + assert response.headers.get("location") == f"/login?{params}" def test_tu_index_unauthorized(client, user):