diff --git a/aurweb/auth/__init__.py b/aurweb/auth/__init__.py index 18356ac2..b6dd6e3f 100644 --- a/aurweb/auth/__init__.py +++ b/aurweb/auth/__init__.py @@ -120,35 +120,39 @@ class BasicAuthBackend(AuthenticationBackend): return (AuthCredentials(["authenticated"]), user) -def auth_required(is_required: bool = True): - """ Authentication route decorator. +def auth_required(auth_goal: bool = True): + """ Enforce a user's authentication status, bringing them to the login page + or homepage if their authentication status does not match the goal. - :param is_required: A boolean indicating whether the function requires auth - :param status_code: An optional status_code for template render. - Redirects are always SEE_OTHER. + :param auth_goal: Whether authentication is required or entirely disallowed + for a user to perform this request. + :return: Return the FastAPI function this decorator wraps. """ def decorator(func): @functools.wraps(func) async def wrapper(request, *args, **kwargs): - if request.user.is_authenticated() != is_required: - url = "/" + if request.user.is_authenticated() == auth_goal: + return await func(request, *args, **kwargs) - if is_required: - if request.method == "GET": - url = request.url.path - elif request.method == "POST" and (referer := request.headers.get("Referer")): - aur = aurweb.config.get("options", "aur_location") + "/" - if not referer.startswith(aur): - _ = l10n.get_translator_for_request(request) - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, - detail=_("Bad Referer header.")) - url = referer[len(aur) - 1:] + url = "/" + if auth_goal is False: + return RedirectResponse(url, status_code=int(HTTPStatus.SEE_OTHER)) - url = "/login?" + util.urlencode({"next": url}) - return RedirectResponse(url, - status_code=int(HTTPStatus.SEE_OTHER)) - return await func(request, *args, **kwargs) + # Use the request path when the user can visit a page directly but + # is not authenticated and use the Referer header if visiting the + # page itself is not directly possible (e.g. submitting a form). + if request.method in ("GET", "HEAD"): + url = request.url.path + elif (referer := request.headers.get("Referer")): + aur = aurweb.config.get("options", "aur_location") + "/" + if not referer.startswith(aur): + _ = l10n.get_translator_for_request(request) + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, + detail=_("Bad Referer header.")) + url = referer[len(aur) - 1:] + url = "/login?" + util.urlencode({"next": url}) + return RedirectResponse(url, status_code=int(HTTPStatus.SEE_OTHER)) return wrapper return decorator