From 51b60f4210d3f4c78b5b914e785f7d6fc9da0f38 Mon Sep 17 00:00:00 2001 From: Kevin Morris Date: Sun, 2 Jan 2022 16:14:15 -0800 Subject: [PATCH] feat(auth): add requires_{auth,guest} decorators These new decorators are meant to be used without any arguments and provide aliases to auth_required: - `auth_required(True) -> requires_auth` - `auth_required(False) -> requires_guest` These decorators should be used without arguments, e.g.: @router.get("/") @requires_guest async def my_route(request: Request): return HTMLResponse() Signed-off-by: Kevin Morris --- aurweb/auth/__init__.py | 28 ++++++++++++++++-- aurweb/routers/accounts.py | 24 +++++++-------- aurweb/routers/auth.py | 6 ++-- aurweb/routers/packages.py | 4 +-- aurweb/routers/pkgbase.py | 54 +++++++++++++++++----------------- aurweb/routers/requests.py | 8 ++--- aurweb/routers/trusted_user.py | 12 ++++---- test/test_auth.py | 4 +-- 8 files changed, 82 insertions(+), 58 deletions(-) diff --git a/aurweb/auth/__init__.py b/aurweb/auth/__init__.py index b683b1df..3befd6ee 100644 --- a/aurweb/auth/__init__.py +++ b/aurweb/auth/__init__.py @@ -2,6 +2,7 @@ import functools from datetime import datetime from http import HTTPStatus +from typing import Callable import fastapi @@ -129,10 +130,15 @@ class BasicAuthBackend(AuthenticationBackend): return (AuthCredentials(["authenticated"]), user) -def auth_required(auth_goal: bool = True): - """ Enforce a user's authentication status, bringing them to the login page +def _auth_required(auth_goal: bool = True): + """ + Enforce a user's authentication status, bringing them to the login page or homepage if their authentication status does not match the goal. + NOTE: This function should not need to be used in downstream code. + See `requires_auth` and `requires_guest` for decorators meant to be + used on routes (they're a bit more implicitly understandable). + :param auth_goal: Whether authentication is required or entirely disallowed for a user to perform this request. :return: Return the FastAPI function this decorator wraps. @@ -167,6 +173,24 @@ def auth_required(auth_goal: bool = True): return decorator +def requires_auth(func: Callable) -> Callable: + """ Require an authenticated session for a particular route. """ + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + return await _auth_required(True)(func)(*args, **kwargs) + return wrapper + + +def requires_guest(func: Callable) -> Callable: + """ Require a guest (unauthenticated) session for a particular route. """ + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + return await _auth_required(False)(func)(*args, **kwargs) + return wrapper + + def account_type_required(one_of: set): """ A decorator that can be used on FastAPI routes to dictate that a user belongs to one of the types defined in one_of. diff --git a/aurweb/routers/accounts.py b/aurweb/routers/accounts.py index 6fffd79c..33ca9ed7 100644 --- a/aurweb/routers/accounts.py +++ b/aurweb/routers/accounts.py @@ -10,7 +10,7 @@ from sqlalchemy import and_, or_ import aurweb.config from aurweb import cookies, db, l10n, logging, models, util -from aurweb.auth import account_type_required, auth_required +from aurweb.auth import account_type_required, requires_auth, requires_guest from aurweb.captcha import get_captcha_salts from aurweb.exceptions import ValidationError from aurweb.l10n import get_translator_for_request @@ -27,14 +27,14 @@ logger = logging.get_logger(__name__) @router.get("/passreset", response_class=HTMLResponse) -@auth_required(False) +@requires_guest async def passreset(request: Request): context = await make_variable_context(request, "Password Reset") return render_template(request, "passreset.html", context) @router.post("/passreset", response_class=HTMLResponse) -@auth_required(False) +@requires_guest async def passreset_post(request: Request, user: str = Form(...), resetkey: str = Form(default=None), @@ -224,7 +224,7 @@ def make_account_form_context(context: dict, @router.get("/register", response_class=HTMLResponse) -@auth_required(False) +@requires_guest async def account_register(request: Request, U: str = Form(default=str()), # Username E: str = Form(default=str()), # Email @@ -250,7 +250,7 @@ async def account_register(request: Request, @router.post("/register", response_class=HTMLResponse) -@auth_required(False) +@requires_guest async def account_register_post(request: Request, U: str = Form(default=str()), # Username E: str = Form(default=str()), # Email @@ -348,7 +348,7 @@ def cannot_edit(request: Request, user: models.User) \ @router.get("/account/{username}/edit", response_class=HTMLResponse) -@auth_required() +@requires_auth async def account_edit(request: Request, username: str): user = db.query(models.User, models.User.Username == username).first() @@ -364,7 +364,7 @@ async def account_edit(request: Request, username: str): @router.post("/account/{username}/edit", response_class=HTMLResponse) -@auth_required() +@requires_auth async def account_edit_post(request: Request, username: str, U: str = Form(default=str()), # Username @@ -461,7 +461,7 @@ async def account(request: Request, username: str): @router.get("/account/{username}/comments") -@auth_required() +@requires_auth async def account_comments(request: Request, username: str): user = get_user_by_name(username) context = make_context(request, "Accounts") @@ -472,7 +472,7 @@ async def account_comments(request: Request, username: str): @router.get("/accounts") -@auth_required() +@requires_auth @account_type_required({at.TRUSTED_USER, at.DEVELOPER, at.TRUSTED_USER_AND_DEV}) @@ -482,7 +482,7 @@ async def accounts(request: Request): @router.post("/accounts") -@auth_required() +@requires_auth @account_type_required({at.TRUSTED_USER, at.DEVELOPER, at.TRUSTED_USER_AND_DEV}) @@ -567,7 +567,7 @@ def render_terms_of_service(request: Request, @router.get("/tos") -@auth_required() +@requires_auth async def terms_of_service(request: Request): # Query the database for terms that were previously accepted, # but now have a bumped Revision that needs to be accepted. @@ -591,7 +591,7 @@ async def terms_of_service(request: Request): @router.post("/tos") -@auth_required() +@requires_auth async def terms_of_service_post(request: Request, accept: bool = Form(default=False)): # Query the database for terms that were previously accepted, diff --git a/aurweb/routers/auth.py b/aurweb/routers/auth.py index 8815c896..0b68dac3 100644 --- a/aurweb/routers/auth.py +++ b/aurweb/routers/auth.py @@ -7,7 +7,7 @@ from fastapi.responses import HTMLResponse, RedirectResponse import aurweb.config from aurweb import cookies, db -from aurweb.auth import auth_required +from aurweb.auth import requires_auth, requires_guest from aurweb.l10n import get_translator_for_request from aurweb.models import User from aurweb.templates import make_variable_context, render_template @@ -29,7 +29,7 @@ async def login_get(request: Request, next: str = "/"): @router.post("/login", response_class=HTMLResponse) -@auth_required(False) +@requires_guest async def login_post(request: Request, next: str = Form(...), user: str = Form(default=str()), @@ -81,7 +81,7 @@ async def login_post(request: Request, @router.post("/logout") -@auth_required() +@requires_auth async def logout(request: Request, next: str = Form(default="/")): if request.user.is_authenticated(): request.user.logout(request) diff --git a/aurweb/routers/packages.py b/aurweb/routers/packages.py index 1930a7a2..ca37ff50 100644 --- a/aurweb/routers/packages.py +++ b/aurweb/routers/packages.py @@ -7,7 +7,7 @@ from fastapi import APIRouter, Form, Request, Response import aurweb.filters # noqa: F401 from aurweb import config, db, defaults, logging, models, util -from aurweb.auth import auth_required, creds +from aurweb.auth import creds, requires_auth from aurweb.exceptions import InvariantError from aurweb.models.relation_type import CONFLICTS_ID, PROVIDES_ID, REPLACES_ID from aurweb.packages import util as pkgutil @@ -406,7 +406,7 @@ PACKAGE_ACTIONS = { @router.post("/packages") -@auth_required() +@requires_auth async def packages_post(request: Request, IDs: List[int] = Form(default=[]), action: str = Form(default=str()), diff --git a/aurweb/routers/pkgbase.py b/aurweb/routers/pkgbase.py index e9fdd337..ab7fde88 100644 --- a/aurweb/routers/pkgbase.py +++ b/aurweb/routers/pkgbase.py @@ -6,7 +6,7 @@ from fastapi.responses import JSONResponse, RedirectResponse from sqlalchemy import and_ from aurweb import config, db, l10n, logging, templates, util -from aurweb.auth import auth_required, creds +from aurweb.auth import creds, requires_auth from aurweb.exceptions import InvariantError, ValidationError from aurweb.models import PackageBase from aurweb.models.package_comment import PackageComment @@ -116,7 +116,7 @@ async def pkgbase_keywords(request: Request, name: str, @router.get("/pkgbase/{name}/flag") -@auth_required() +@requires_auth async def pkgbase_flag_get(request: Request, name: str): pkgbase = get_pkg_or_base(name, PackageBase) @@ -131,7 +131,7 @@ async def pkgbase_flag_get(request: Request, name: str): @router.post("/pkgbase/{name}/flag") -@auth_required() +@requires_auth async def pkgbase_flag_post(request: Request, name: str, comments: str = Form(default=str())): pkgbase = get_pkg_or_base(name, PackageBase) @@ -157,7 +157,7 @@ async def pkgbase_flag_post(request: Request, name: str, @router.post("/pkgbase/{name}/comments") -@auth_required() +@requires_auth async def pkgbase_comments_post( request: Request, name: str, comment: str = Form(default=str()), @@ -189,7 +189,7 @@ async def pkgbase_comments_post( @router.get("/pkgbase/{name}/comments/{id}/form") -@auth_required() +@requires_auth async def pkgbase_comment_form(request: Request, name: str, id: int, next: str = Query(default=None)): """ @@ -229,7 +229,7 @@ async def pkgbase_comment_form(request: Request, name: str, id: int, @router.get("/pkgbase/{name}/comments/{id}/edit") -@auth_required() +@requires_auth async def pkgbase_comment_edit(request: Request, name: str, id: int, next: str = Form(default=None)): """ @@ -253,7 +253,7 @@ async def pkgbase_comment_edit(request: Request, name: str, id: int, @router.post("/pkgbase/{name}/comments/{id}") -@auth_required() +@requires_auth async def pkgbase_comment_post( request: Request, name: str, id: int, comment: str = Form(default=str()), @@ -293,7 +293,7 @@ async def pkgbase_comment_post( @router.post("/pkgbase/{name}/comments/{id}/pin") -@auth_required() +@requires_auth async def pkgbase_comment_pin(request: Request, name: str, id: int, next: str = Form(default=None)): """ @@ -327,7 +327,7 @@ async def pkgbase_comment_pin(request: Request, name: str, id: int, @router.post("/pkgbase/{name}/comments/{id}/unpin") -@auth_required() +@requires_auth async def pkgbase_comment_unpin(request: Request, name: str, id: int, next: str = Form(default=None)): """ @@ -360,7 +360,7 @@ async def pkgbase_comment_unpin(request: Request, name: str, id: int, @router.post("/pkgbase/{name}/comments/{id}/delete") -@auth_required() +@requires_auth async def pkgbase_comment_delete(request: Request, name: str, id: int, next: str = Form(default=None)): """ @@ -399,7 +399,7 @@ async def pkgbase_comment_delete(request: Request, name: str, id: int, @router.post("/pkgbase/{name}/comments/{id}/undelete") -@auth_required() +@requires_auth async def pkgbase_comment_undelete(request: Request, name: str, id: int, next: str = Form(default=None)): """ @@ -437,7 +437,7 @@ async def pkgbase_comment_undelete(request: Request, name: str, id: int, @router.post("/pkgbase/{name}/vote") -@auth_required() +@requires_auth async def pkgbase_vote(request: Request, name: str): pkgbase = get_pkg_or_base(name, PackageBase) @@ -461,7 +461,7 @@ async def pkgbase_vote(request: Request, name: str): @router.post("/pkgbase/{name}/unvote") -@auth_required() +@requires_auth async def pkgbase_unvote(request: Request, name: str): pkgbase = get_pkg_or_base(name, PackageBase) @@ -481,7 +481,7 @@ async def pkgbase_unvote(request: Request, name: str): @router.post("/pkgbase/{name}/notify") -@auth_required() +@requires_auth async def pkgbase_notify(request: Request, name: str): pkgbase = get_pkg_or_base(name, PackageBase) actions.pkgbase_notify_instance(request, pkgbase) @@ -490,7 +490,7 @@ async def pkgbase_notify(request: Request, name: str): @router.post("/pkgbase/{name}/unnotify") -@auth_required() +@requires_auth async def pkgbase_unnotify(request: Request, name: str): pkgbase = get_pkg_or_base(name, PackageBase) actions.pkgbase_unnotify_instance(request, pkgbase) @@ -499,7 +499,7 @@ async def pkgbase_unnotify(request: Request, name: str): @router.post("/pkgbase/{name}/unflag") -@auth_required() +@requires_auth async def pkgbase_unflag(request: Request, name: str): pkgbase = get_pkg_or_base(name, PackageBase) actions.pkgbase_unflag_instance(request, pkgbase) @@ -508,7 +508,7 @@ async def pkgbase_unflag(request: Request, name: str): @router.get("/pkgbase/{name}/disown") -@auth_required() +@requires_auth async def pkgbase_disown_get(request: Request, name: str): pkgbase = get_pkg_or_base(name, PackageBase) @@ -524,7 +524,7 @@ async def pkgbase_disown_get(request: Request, name: str): @router.post("/pkgbase/{name}/disown") -@auth_required() +@requires_auth async def pkgbase_disown_post(request: Request, name: str, comments: str = Form(default=str()), confirm: bool = Form(default=False)): @@ -559,7 +559,7 @@ async def pkgbase_disown_post(request: Request, name: str, @router.post("/pkgbase/{name}/adopt") -@auth_required() +@requires_auth async def pkgbase_adopt_post(request: Request, name: str): pkgbase = get_pkg_or_base(name, PackageBase) @@ -575,7 +575,7 @@ async def pkgbase_adopt_post(request: Request, name: str): @router.get("/pkgbase/{name}/comaintainers") -@auth_required() +@requires_auth async def pkgbase_comaintainers(request: Request, name: str) -> Response: # Get the PackageBase. pkgbase = get_pkg_or_base(name, PackageBase) @@ -601,7 +601,7 @@ async def pkgbase_comaintainers(request: Request, name: str) -> Response: @router.post("/pkgbase/{name}/comaintainers") -@auth_required() +@requires_auth async def pkgbase_comaintainers_post(request: Request, name: str, users: str = Form(default=str())) \ -> Response: @@ -643,7 +643,7 @@ async def pkgbase_comaintainers_post(request: Request, name: str, @router.get("/pkgbase/{name}/request") -@auth_required() +@requires_auth async def pkgbase_request(request: Request, name: str): pkgbase = get_pkg_or_base(name, PackageBase) context = await make_variable_context(request, "Submit Request") @@ -652,7 +652,7 @@ async def pkgbase_request(request: Request, name: str): @router.post("/pkgbase/{name}/request") -@auth_required() +@requires_auth async def pkgbase_request_post(request: Request, name: str, type: str = Form(...), merge_into: str = Form(default=None), @@ -732,7 +732,7 @@ async def pkgbase_request_post(request: Request, name: str, @router.get("/pkgbase/{name}/delete") -@auth_required() +@requires_auth async def pkgbase_delete_get(request: Request, name: str): if not request.user.has_credential(creds.PKGBASE_DELETE): return RedirectResponse(f"/pkgbase/{name}", @@ -744,7 +744,7 @@ async def pkgbase_delete_get(request: Request, name: str): @router.post("/pkgbase/{name}/delete") -@auth_required() +@requires_auth async def pkgbase_delete_post(request: Request, name: str, confirm: bool = Form(default=False), comments: str = Form(default=str())): @@ -779,7 +779,7 @@ async def pkgbase_delete_post(request: Request, name: str, @router.get("/pkgbase/{name}/merge") -@auth_required() +@requires_auth async def pkgbase_merge_get(request: Request, name: str, into: str = Query(default=str()), next: str = Query(default=str())): @@ -810,7 +810,7 @@ async def pkgbase_merge_get(request: Request, name: str, @router.post("/pkgbase/{name}/merge") -@auth_required() +@requires_auth async def pkgbase_merge_post(request: Request, name: str, into: str = Form(default=str()), comments: str = Form(default=str()), diff --git a/aurweb/routers/requests.py b/aurweb/routers/requests.py index 2c18c66a..4c976655 100644 --- a/aurweb/routers/requests.py +++ b/aurweb/routers/requests.py @@ -6,7 +6,7 @@ from fastapi.responses import RedirectResponse from sqlalchemy import case from aurweb import db, defaults, util -from aurweb.auth import auth_required, creds +from aurweb.auth import creds, requires_auth from aurweb.models import PackageRequest, User from aurweb.models.package_request import PENDING_ID, REJECTED_ID from aurweb.requests.util import get_pkgreq_by_id @@ -17,7 +17,7 @@ router = APIRouter() @router.get("/requests") -@auth_required() +@requires_auth async def requests(request: Request, O: int = Query(default=defaults.O), PP: int = Query(default=defaults.PP)): @@ -50,7 +50,7 @@ async def requests(request: Request, @router.get("/requests/{id}/close") -@auth_required() +@requires_auth async def request_close(request: Request, id: int): pkgreq = get_pkgreq_by_id(id) @@ -64,7 +64,7 @@ async def request_close(request: Request, id: int): @router.post("/requests/{id}/close") -@auth_required() +@requires_auth async def request_close_post(request: Request, id: int, comments: str = Form(default=str())): pkgreq = get_pkgreq_by_id(id) diff --git a/aurweb/routers/trusted_user.py b/aurweb/routers/trusted_user.py index fac68f04..bfc38bf6 100644 --- a/aurweb/routers/trusted_user.py +++ b/aurweb/routers/trusted_user.py @@ -10,7 +10,7 @@ from fastapi.responses import RedirectResponse, Response from sqlalchemy import and_, or_ from aurweb import db, l10n, logging, models -from aurweb.auth import account_type_required, auth_required +from aurweb.auth import account_type_required, requires_auth from aurweb.models.account_type import DEVELOPER, TRUSTED_USER, TRUSTED_USER_AND_DEV from aurweb.templates import make_context, make_variable_context, render_template @@ -41,7 +41,7 @@ ADDVOTE_SPECIFICS = { @router.get("/tu") -@auth_required() +@requires_auth @account_type_required(REQUIRED_TYPES) async def trusted_user(request: Request, coff: int = 0, # current offset @@ -147,7 +147,7 @@ def render_proposal(request: Request, @router.get("/tu/{proposal}") -@auth_required() +@requires_auth @account_type_required(REQUIRED_TYPES) async def trusted_user_proposal(request: Request, proposal: int): context = await make_variable_context(request, "Trusted User") @@ -176,7 +176,7 @@ async def trusted_user_proposal(request: Request, proposal: int): @router.post("/tu/{proposal}") -@auth_required() +@requires_auth @account_type_required(REQUIRED_TYPES) async def trusted_user_proposal_post(request: Request, proposal: int, @@ -227,7 +227,7 @@ async def trusted_user_proposal_post(request: Request, @router.get("/addvote") -@auth_required() +@requires_auth @account_type_required({TRUSTED_USER, TRUSTED_USER_AND_DEV}) async def trusted_user_addvote(request: Request, user: str = str(), @@ -247,7 +247,7 @@ async def trusted_user_addvote(request: Request, @router.post("/addvote") -@auth_required() +@requires_auth @account_type_required({TRUSTED_USER, TRUSTED_USER_AND_DEV}) async def trusted_user_addvote_post(request: Request, user: str = Form(default=str()), diff --git a/test/test_auth.py b/test/test_auth.py index 0094aa25..8e8b5859 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -7,7 +7,7 @@ from fastapi import HTTPException from sqlalchemy.exc import IntegrityError from aurweb import config, db -from aurweb.auth import AnonymousUser, BasicAuthBackend, account_type_required, auth_required +from aurweb.auth import AnonymousUser, BasicAuthBackend, _auth_required, account_type_required from aurweb.models.account_type import USER, USER_ID from aurweb.models.session import Session from aurweb.models.user import User @@ -105,7 +105,7 @@ async def test_auth_required_redirection_bad_referrer(): pass # Get down to the nitty gritty internal wrapper. - bad_referrer_route = auth_required()(bad_referrer_route) + bad_referrer_route = _auth_required()(bad_referrer_route) # Execute the route with a "./blahblahblah" Referer, which does not # match aur_location; `./` has been used as a prefix to attempt to