auth_required: allow formattable template tuples

See docstring for updates.

template= has been modified.
status_code= has been added as an optional template status_code.

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2021-06-21 21:35:05 -07:00
parent d7941e6bed
commit af76e660d0
3 changed files with 70 additions and 14 deletions

View file

@ -10,9 +10,10 @@ from starlette.requests import HTTPConnection
import aurweb.config import aurweb.config
from aurweb import l10n
from aurweb.models.session import Session from aurweb.models.session import Session
from aurweb.models.user import User from aurweb.models.user import User
from aurweb.templates import make_context, render_template from aurweb.templates import make_variable_context, render_template
class AnonymousUser: class AnonymousUser:
@ -60,7 +61,8 @@ class BasicAuthBackend(AuthenticationBackend):
def auth_required(is_required: bool = True, def auth_required(is_required: bool = True,
redirect: str = "/", redirect: str = "/",
template: tuple = None): template: tuple = None,
status_code: HTTPStatus = HTTPStatus.UNAUTHORIZED):
""" Authentication route decorator. """ Authentication route decorator.
If redirect is given, the user will be redirected if the auth state If redirect is given, the user will be redirected if the auth state
@ -69,26 +71,73 @@ def auth_required(is_required: bool = True,
If template is given, it will be rendered with Unauthorized if If template is given, it will be rendered with Unauthorized if
is_required does not match and take priority over redirect. is_required does not match and take priority over redirect.
A precondition of this function is that, if template is provided,
it **must** match the following format:
template=("template.html", ["Some Template For", "{}"], ["username"])
Where `username` is a FastAPI request path parameter, fitting
a route like: `/some_route/{username}`.
If you wish to supply a non-formatted template, just omit any Python
format strings (with the '{}' substring). The third tuple element
will not be used, and so anything can be supplied.
template=("template.html", ["Some Page"], None)
All title shards and format parameters will be translated before
applying any format operations.
:param is_required: A boolean indicating whether the function requires auth :param is_required: A boolean indicating whether the function requires auth
:param redirect: Path to redirect to if is_required isn't True :param redirect: Path to redirect to if is_required isn't True
:param template: A template tuple: ("template.html", "Template Page") :param template: A three-element template tuple:
(path, title_iterable, variable_iterable)
:param status_code: An optional status_code for template render.
Redirects are always SEE_OTHER.
""" """
def decorator(func): def decorator(func):
@functools.wraps(func) @functools.wraps(func)
async def wrapper(request, *args, **kwargs): async def wrapper(request, *args, **kwargs):
if request.user.is_authenticated() != is_required: if request.user.is_authenticated() != is_required:
status_code = int(HTTPStatus.UNAUTHORIZED)
url = "/" url = "/"
if redirect: if redirect:
status_code = int(HTTPStatus.SEE_OTHER)
url = redirect url = redirect
if template: if template:
path, title = template # template=("template.html",
context = make_context(request, title) # ["Some Title", "someFormatted {}"],
# ["variable"])
# => render template.html with title:
# "Some Title someFormatted variables"
path, title_parts, variables = template
_ = l10n.get_translator_for_request(request)
# Step through title_parts; for each part which contains
# a '{}' in it, apply .format(var) where var = the current
# iteration of variables.
#
# This implies that len(variables) is equal to
# len([part for part in title_parts if '{}' in part])
# and this must always be true.
#
sanitized = []
_variables = iter(variables)
for part in title_parts:
if "{}" in part: # If this part is formattable.
key = next(_variables)
var = request.path_params.get(key)
sanitized.append(_(part.format(var)))
else: # Otherwise, just add the translated part.
sanitized.append(_(part))
# Glue all title parts together, separated by spaces.
title = " ".join(sanitized)
context = await make_variable_context(request, title)
return render_template(request, path, context, return render_template(request, path, context,
status_code=int(HTTPStatus.UNAUTHORIZED)) status_code=status_code)
return RedirectResponse(url=url, status_code=status_code) return RedirectResponse(url,
status_code=int(HTTPStatus.SEE_OTHER))
return await func(request, *args, **kwargs) return await func(request, *args, **kwargs)
return wrapper return wrapper

View file

@ -555,13 +555,21 @@ async def account_edit_post(request: Request,
return util.migrate_cookies(request, response) return util.migrate_cookies(request, response)
account_template = (
"account/show.html",
["Account", "{}"],
["username"] # Query parameters to replace in the title string.
)
@router.get("/account/{username}") @router.get("/account/{username}")
@auth_required(True, template=("account/show.html", "Accounts")) @auth_required(True, template=account_template,
status_code=HTTPStatus.UNAUTHORIZED)
async def account(request: Request, username: str): async def account(request: Request, username: str):
_ = l10n.get_translator_for_request(request)
context = await make_variable_context(request, _("Account") + username)
user = db.query(User, User.Username == username).first() user = db.query(User, User.Username == username).first()
context = await make_variable_context(request, "Accounts")
if not user: if not user:
raise HTTPException(status_code=int(HTTPStatus.NOT_FOUND)) raise HTTPException(status_code=int(HTTPStatus.NOT_FOUND))

View file

@ -915,7 +915,6 @@ def test_get_account_not_found():
def test_get_account_unauthenticated(): def test_get_account_unauthenticated():
with client as request: with client as request:
response = request.get("/account/test", allow_redirects=False) response = request.get("/account/test", allow_redirects=False)
assert response.status_code == int(HTTPStatus.UNAUTHORIZED) assert response.status_code == int(HTTPStatus.UNAUTHORIZED)
content = response.content.decode() content = response.content.decode()