mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
auth_required: allow formattable template tuples
See docstring for updates. template= has been modified. status_code= has been added as an optional template status_code. Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
parent
d7941e6bed
commit
af76e660d0
3 changed files with 70 additions and 14 deletions
|
@ -10,9 +10,10 @@ from starlette.requests import HTTPConnection
|
||||||
|
|
||||||
import aurweb.config
|
import aurweb.config
|
||||||
|
|
||||||
|
from aurweb import l10n
|
||||||
from aurweb.models.session import Session
|
from aurweb.models.session import Session
|
||||||
from aurweb.models.user import User
|
from aurweb.models.user import User
|
||||||
from aurweb.templates import make_context, render_template
|
from aurweb.templates import make_variable_context, render_template
|
||||||
|
|
||||||
|
|
||||||
class AnonymousUser:
|
class AnonymousUser:
|
||||||
|
@ -60,7 +61,8 @@ class BasicAuthBackend(AuthenticationBackend):
|
||||||
|
|
||||||
def auth_required(is_required: bool = True,
|
def auth_required(is_required: bool = True,
|
||||||
redirect: str = "/",
|
redirect: str = "/",
|
||||||
template: tuple = None):
|
template: tuple = None,
|
||||||
|
status_code: HTTPStatus = HTTPStatus.UNAUTHORIZED):
|
||||||
""" Authentication route decorator.
|
""" Authentication route decorator.
|
||||||
|
|
||||||
If redirect is given, the user will be redirected if the auth state
|
If redirect is given, the user will be redirected if the auth state
|
||||||
|
@ -69,26 +71,73 @@ def auth_required(is_required: bool = True,
|
||||||
If template is given, it will be rendered with Unauthorized if
|
If template is given, it will be rendered with Unauthorized if
|
||||||
is_required does not match and take priority over redirect.
|
is_required does not match and take priority over redirect.
|
||||||
|
|
||||||
|
A precondition of this function is that, if template is provided,
|
||||||
|
it **must** match the following format:
|
||||||
|
|
||||||
|
template=("template.html", ["Some Template For", "{}"], ["username"])
|
||||||
|
|
||||||
|
Where `username` is a FastAPI request path parameter, fitting
|
||||||
|
a route like: `/some_route/{username}`.
|
||||||
|
|
||||||
|
If you wish to supply a non-formatted template, just omit any Python
|
||||||
|
format strings (with the '{}' substring). The third tuple element
|
||||||
|
will not be used, and so anything can be supplied.
|
||||||
|
|
||||||
|
template=("template.html", ["Some Page"], None)
|
||||||
|
|
||||||
|
All title shards and format parameters will be translated before
|
||||||
|
applying any format operations.
|
||||||
|
|
||||||
:param is_required: A boolean indicating whether the function requires auth
|
:param is_required: A boolean indicating whether the function requires auth
|
||||||
:param redirect: Path to redirect to if is_required isn't True
|
:param redirect: Path to redirect to if is_required isn't True
|
||||||
:param template: A template tuple: ("template.html", "Template Page")
|
:param template: A three-element template tuple:
|
||||||
|
(path, title_iterable, variable_iterable)
|
||||||
|
:param status_code: An optional status_code for template render.
|
||||||
|
Redirects are always SEE_OTHER.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(request, *args, **kwargs):
|
async def wrapper(request, *args, **kwargs):
|
||||||
if request.user.is_authenticated() != is_required:
|
if request.user.is_authenticated() != is_required:
|
||||||
status_code = int(HTTPStatus.UNAUTHORIZED)
|
|
||||||
url = "/"
|
url = "/"
|
||||||
if redirect:
|
if redirect:
|
||||||
status_code = int(HTTPStatus.SEE_OTHER)
|
|
||||||
url = redirect
|
url = redirect
|
||||||
if template:
|
if template:
|
||||||
path, title = template
|
# template=("template.html",
|
||||||
context = make_context(request, title)
|
# ["Some Title", "someFormatted {}"],
|
||||||
|
# ["variable"])
|
||||||
|
# => render template.html with title:
|
||||||
|
# "Some Title someFormatted variables"
|
||||||
|
path, title_parts, variables = template
|
||||||
|
_ = l10n.get_translator_for_request(request)
|
||||||
|
|
||||||
|
# Step through title_parts; for each part which contains
|
||||||
|
# a '{}' in it, apply .format(var) where var = the current
|
||||||
|
# iteration of variables.
|
||||||
|
#
|
||||||
|
# This implies that len(variables) is equal to
|
||||||
|
# len([part for part in title_parts if '{}' in part])
|
||||||
|
# and this must always be true.
|
||||||
|
#
|
||||||
|
sanitized = []
|
||||||
|
_variables = iter(variables)
|
||||||
|
for part in title_parts:
|
||||||
|
if "{}" in part: # If this part is formattable.
|
||||||
|
key = next(_variables)
|
||||||
|
var = request.path_params.get(key)
|
||||||
|
sanitized.append(_(part.format(var)))
|
||||||
|
else: # Otherwise, just add the translated part.
|
||||||
|
sanitized.append(_(part))
|
||||||
|
|
||||||
|
# Glue all title parts together, separated by spaces.
|
||||||
|
title = " ".join(sanitized)
|
||||||
|
|
||||||
|
context = await make_variable_context(request, title)
|
||||||
return render_template(request, path, context,
|
return render_template(request, path, context,
|
||||||
status_code=int(HTTPStatus.UNAUTHORIZED))
|
status_code=status_code)
|
||||||
return RedirectResponse(url=url, status_code=status_code)
|
return RedirectResponse(url,
|
||||||
|
status_code=int(HTTPStatus.SEE_OTHER))
|
||||||
return await func(request, *args, **kwargs)
|
return await func(request, *args, **kwargs)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
|
@ -555,13 +555,21 @@ async def account_edit_post(request: Request,
|
||||||
return util.migrate_cookies(request, response)
|
return util.migrate_cookies(request, response)
|
||||||
|
|
||||||
|
|
||||||
|
account_template = (
|
||||||
|
"account/show.html",
|
||||||
|
["Account", "{}"],
|
||||||
|
["username"] # Query parameters to replace in the title string.
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/account/{username}")
|
@router.get("/account/{username}")
|
||||||
@auth_required(True, template=("account/show.html", "Accounts"))
|
@auth_required(True, template=account_template,
|
||||||
|
status_code=HTTPStatus.UNAUTHORIZED)
|
||||||
async def account(request: Request, username: str):
|
async def account(request: Request, username: str):
|
||||||
|
_ = l10n.get_translator_for_request(request)
|
||||||
|
context = await make_variable_context(request, _("Account") + username)
|
||||||
|
|
||||||
user = db.query(User, User.Username == username).first()
|
user = db.query(User, User.Username == username).first()
|
||||||
|
|
||||||
context = await make_variable_context(request, "Accounts")
|
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=int(HTTPStatus.NOT_FOUND))
|
raise HTTPException(status_code=int(HTTPStatus.NOT_FOUND))
|
||||||
|
|
||||||
|
|
|
@ -915,7 +915,6 @@ def test_get_account_not_found():
|
||||||
def test_get_account_unauthenticated():
|
def test_get_account_unauthenticated():
|
||||||
with client as request:
|
with client as request:
|
||||||
response = request.get("/account/test", allow_redirects=False)
|
response = request.get("/account/test", allow_redirects=False)
|
||||||
|
|
||||||
assert response.status_code == int(HTTPStatus.UNAUTHORIZED)
|
assert response.status_code == int(HTTPStatus.UNAUTHORIZED)
|
||||||
|
|
||||||
content = response.content.decode()
|
content = response.content.decode()
|
||||||
|
|
Loading…
Add table
Reference in a new issue