feat(auth): add requires_{auth,guest} decorators

These new decorators are meant to be used without any arguments
and provide aliases to auth_required:
- `auth_required(True) -> requires_auth`
- `auth_required(False) -> requires_guest`

These decorators should be used without arguments, e.g.:

    @router.get("/")
    @requires_guest
    async def my_route(request: Request):
        return HTMLResponse()

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2022-01-02 16:14:15 -08:00
parent 3e048e9675
commit 51b60f4210
No known key found for this signature in database
GPG key ID: F7E46DED420788F3
8 changed files with 82 additions and 58 deletions

View file

@ -2,6 +2,7 @@ import functools
from datetime import datetime from datetime import datetime
from http import HTTPStatus from http import HTTPStatus
from typing import Callable
import fastapi import fastapi
@ -129,10 +130,15 @@ class BasicAuthBackend(AuthenticationBackend):
return (AuthCredentials(["authenticated"]), user) return (AuthCredentials(["authenticated"]), user)
def auth_required(auth_goal: bool = True): def _auth_required(auth_goal: bool = True):
""" Enforce a user's authentication status, bringing them to the login page """
Enforce a user's authentication status, bringing them to the login page
or homepage if their authentication status does not match the goal. or homepage if their authentication status does not match the goal.
NOTE: This function should not need to be used in downstream code.
See `requires_auth` and `requires_guest` for decorators meant to be
used on routes (they're a bit more implicitly understandable).
:param auth_goal: Whether authentication is required or entirely disallowed :param auth_goal: Whether authentication is required or entirely disallowed
for a user to perform this request. for a user to perform this request.
:return: Return the FastAPI function this decorator wraps. :return: Return the FastAPI function this decorator wraps.
@ -167,6 +173,24 @@ def auth_required(auth_goal: bool = True):
return decorator return decorator
def requires_auth(func: Callable) -> Callable:
""" Require an authenticated session for a particular route. """
@functools.wraps(func)
async def wrapper(*args, **kwargs):
return await _auth_required(True)(func)(*args, **kwargs)
return wrapper
def requires_guest(func: Callable) -> Callable:
""" Require a guest (unauthenticated) session for a particular route. """
@functools.wraps(func)
async def wrapper(*args, **kwargs):
return await _auth_required(False)(func)(*args, **kwargs)
return wrapper
def account_type_required(one_of: set): def account_type_required(one_of: set):
""" A decorator that can be used on FastAPI routes to dictate """ A decorator that can be used on FastAPI routes to dictate
that a user belongs to one of the types defined in one_of. that a user belongs to one of the types defined in one_of.

View file

@ -10,7 +10,7 @@ from sqlalchemy import and_, or_
import aurweb.config import aurweb.config
from aurweb import cookies, db, l10n, logging, models, util from aurweb import cookies, db, l10n, logging, models, util
from aurweb.auth import account_type_required, auth_required from aurweb.auth import account_type_required, requires_auth, requires_guest
from aurweb.captcha import get_captcha_salts from aurweb.captcha import get_captcha_salts
from aurweb.exceptions import ValidationError from aurweb.exceptions import ValidationError
from aurweb.l10n import get_translator_for_request from aurweb.l10n import get_translator_for_request
@ -27,14 +27,14 @@ logger = logging.get_logger(__name__)
@router.get("/passreset", response_class=HTMLResponse) @router.get("/passreset", response_class=HTMLResponse)
@auth_required(False) @requires_guest
async def passreset(request: Request): async def passreset(request: Request):
context = await make_variable_context(request, "Password Reset") context = await make_variable_context(request, "Password Reset")
return render_template(request, "passreset.html", context) return render_template(request, "passreset.html", context)
@router.post("/passreset", response_class=HTMLResponse) @router.post("/passreset", response_class=HTMLResponse)
@auth_required(False) @requires_guest
async def passreset_post(request: Request, async def passreset_post(request: Request,
user: str = Form(...), user: str = Form(...),
resetkey: str = Form(default=None), resetkey: str = Form(default=None),
@ -224,7 +224,7 @@ def make_account_form_context(context: dict,
@router.get("/register", response_class=HTMLResponse) @router.get("/register", response_class=HTMLResponse)
@auth_required(False) @requires_guest
async def account_register(request: Request, async def account_register(request: Request,
U: str = Form(default=str()), # Username U: str = Form(default=str()), # Username
E: str = Form(default=str()), # Email E: str = Form(default=str()), # Email
@ -250,7 +250,7 @@ async def account_register(request: Request,
@router.post("/register", response_class=HTMLResponse) @router.post("/register", response_class=HTMLResponse)
@auth_required(False) @requires_guest
async def account_register_post(request: Request, async def account_register_post(request: Request,
U: str = Form(default=str()), # Username U: str = Form(default=str()), # Username
E: str = Form(default=str()), # Email E: str = Form(default=str()), # Email
@ -348,7 +348,7 @@ def cannot_edit(request: Request, user: models.User) \
@router.get("/account/{username}/edit", response_class=HTMLResponse) @router.get("/account/{username}/edit", response_class=HTMLResponse)
@auth_required() @requires_auth
async def account_edit(request: Request, username: str): async def account_edit(request: Request, username: str):
user = db.query(models.User, models.User.Username == username).first() user = db.query(models.User, models.User.Username == username).first()
@ -364,7 +364,7 @@ async def account_edit(request: Request, username: str):
@router.post("/account/{username}/edit", response_class=HTMLResponse) @router.post("/account/{username}/edit", response_class=HTMLResponse)
@auth_required() @requires_auth
async def account_edit_post(request: Request, async def account_edit_post(request: Request,
username: str, username: str,
U: str = Form(default=str()), # Username U: str = Form(default=str()), # Username
@ -461,7 +461,7 @@ async def account(request: Request, username: str):
@router.get("/account/{username}/comments") @router.get("/account/{username}/comments")
@auth_required() @requires_auth
async def account_comments(request: Request, username: str): async def account_comments(request: Request, username: str):
user = get_user_by_name(username) user = get_user_by_name(username)
context = make_context(request, "Accounts") context = make_context(request, "Accounts")
@ -472,7 +472,7 @@ async def account_comments(request: Request, username: str):
@router.get("/accounts") @router.get("/accounts")
@auth_required() @requires_auth
@account_type_required({at.TRUSTED_USER, @account_type_required({at.TRUSTED_USER,
at.DEVELOPER, at.DEVELOPER,
at.TRUSTED_USER_AND_DEV}) at.TRUSTED_USER_AND_DEV})
@ -482,7 +482,7 @@ async def accounts(request: Request):
@router.post("/accounts") @router.post("/accounts")
@auth_required() @requires_auth
@account_type_required({at.TRUSTED_USER, @account_type_required({at.TRUSTED_USER,
at.DEVELOPER, at.DEVELOPER,
at.TRUSTED_USER_AND_DEV}) at.TRUSTED_USER_AND_DEV})
@ -567,7 +567,7 @@ def render_terms_of_service(request: Request,
@router.get("/tos") @router.get("/tos")
@auth_required() @requires_auth
async def terms_of_service(request: Request): async def terms_of_service(request: Request):
# Query the database for terms that were previously accepted, # Query the database for terms that were previously accepted,
# but now have a bumped Revision that needs to be accepted. # but now have a bumped Revision that needs to be accepted.
@ -591,7 +591,7 @@ async def terms_of_service(request: Request):
@router.post("/tos") @router.post("/tos")
@auth_required() @requires_auth
async def terms_of_service_post(request: Request, async def terms_of_service_post(request: Request,
accept: bool = Form(default=False)): accept: bool = Form(default=False)):
# Query the database for terms that were previously accepted, # Query the database for terms that were previously accepted,

View file

@ -7,7 +7,7 @@ from fastapi.responses import HTMLResponse, RedirectResponse
import aurweb.config import aurweb.config
from aurweb import cookies, db from aurweb import cookies, db
from aurweb.auth import auth_required from aurweb.auth import requires_auth, requires_guest
from aurweb.l10n import get_translator_for_request from aurweb.l10n import get_translator_for_request
from aurweb.models import User from aurweb.models import User
from aurweb.templates import make_variable_context, render_template from aurweb.templates import make_variable_context, render_template
@ -29,7 +29,7 @@ async def login_get(request: Request, next: str = "/"):
@router.post("/login", response_class=HTMLResponse) @router.post("/login", response_class=HTMLResponse)
@auth_required(False) @requires_guest
async def login_post(request: Request, async def login_post(request: Request,
next: str = Form(...), next: str = Form(...),
user: str = Form(default=str()), user: str = Form(default=str()),
@ -81,7 +81,7 @@ async def login_post(request: Request,
@router.post("/logout") @router.post("/logout")
@auth_required() @requires_auth
async def logout(request: Request, next: str = Form(default="/")): async def logout(request: Request, next: str = Form(default="/")):
if request.user.is_authenticated(): if request.user.is_authenticated():
request.user.logout(request) request.user.logout(request)

View file

@ -7,7 +7,7 @@ from fastapi import APIRouter, Form, Request, Response
import aurweb.filters # noqa: F401 import aurweb.filters # noqa: F401
from aurweb import config, db, defaults, logging, models, util from aurweb import config, db, defaults, logging, models, util
from aurweb.auth import auth_required, creds from aurweb.auth import creds, requires_auth
from aurweb.exceptions import InvariantError from aurweb.exceptions import InvariantError
from aurweb.models.relation_type import CONFLICTS_ID, PROVIDES_ID, REPLACES_ID from aurweb.models.relation_type import CONFLICTS_ID, PROVIDES_ID, REPLACES_ID
from aurweb.packages import util as pkgutil from aurweb.packages import util as pkgutil
@ -406,7 +406,7 @@ PACKAGE_ACTIONS = {
@router.post("/packages") @router.post("/packages")
@auth_required() @requires_auth
async def packages_post(request: Request, async def packages_post(request: Request,
IDs: List[int] = Form(default=[]), IDs: List[int] = Form(default=[]),
action: str = Form(default=str()), action: str = Form(default=str()),

View file

@ -6,7 +6,7 @@ from fastapi.responses import JSONResponse, RedirectResponse
from sqlalchemy import and_ from sqlalchemy import and_
from aurweb import config, db, l10n, logging, templates, util from aurweb import config, db, l10n, logging, templates, util
from aurweb.auth import auth_required, creds from aurweb.auth import creds, requires_auth
from aurweb.exceptions import InvariantError, ValidationError from aurweb.exceptions import InvariantError, ValidationError
from aurweb.models import PackageBase from aurweb.models import PackageBase
from aurweb.models.package_comment import PackageComment from aurweb.models.package_comment import PackageComment
@ -116,7 +116,7 @@ async def pkgbase_keywords(request: Request, name: str,
@router.get("/pkgbase/{name}/flag") @router.get("/pkgbase/{name}/flag")
@auth_required() @requires_auth
async def pkgbase_flag_get(request: Request, name: str): async def pkgbase_flag_get(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
@ -131,7 +131,7 @@ async def pkgbase_flag_get(request: Request, name: str):
@router.post("/pkgbase/{name}/flag") @router.post("/pkgbase/{name}/flag")
@auth_required() @requires_auth
async def pkgbase_flag_post(request: Request, name: str, async def pkgbase_flag_post(request: Request, name: str,
comments: str = Form(default=str())): comments: str = Form(default=str())):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
@ -157,7 +157,7 @@ async def pkgbase_flag_post(request: Request, name: str,
@router.post("/pkgbase/{name}/comments") @router.post("/pkgbase/{name}/comments")
@auth_required() @requires_auth
async def pkgbase_comments_post( async def pkgbase_comments_post(
request: Request, name: str, request: Request, name: str,
comment: str = Form(default=str()), comment: str = Form(default=str()),
@ -189,7 +189,7 @@ async def pkgbase_comments_post(
@router.get("/pkgbase/{name}/comments/{id}/form") @router.get("/pkgbase/{name}/comments/{id}/form")
@auth_required() @requires_auth
async def pkgbase_comment_form(request: Request, name: str, id: int, async def pkgbase_comment_form(request: Request, name: str, id: int,
next: str = Query(default=None)): next: str = Query(default=None)):
""" """
@ -229,7 +229,7 @@ async def pkgbase_comment_form(request: Request, name: str, id: int,
@router.get("/pkgbase/{name}/comments/{id}/edit") @router.get("/pkgbase/{name}/comments/{id}/edit")
@auth_required() @requires_auth
async def pkgbase_comment_edit(request: Request, name: str, id: int, async def pkgbase_comment_edit(request: Request, name: str, id: int,
next: str = Form(default=None)): next: str = Form(default=None)):
""" """
@ -253,7 +253,7 @@ async def pkgbase_comment_edit(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/comments/{id}") @router.post("/pkgbase/{name}/comments/{id}")
@auth_required() @requires_auth
async def pkgbase_comment_post( async def pkgbase_comment_post(
request: Request, name: str, id: int, request: Request, name: str, id: int,
comment: str = Form(default=str()), comment: str = Form(default=str()),
@ -293,7 +293,7 @@ async def pkgbase_comment_post(
@router.post("/pkgbase/{name}/comments/{id}/pin") @router.post("/pkgbase/{name}/comments/{id}/pin")
@auth_required() @requires_auth
async def pkgbase_comment_pin(request: Request, name: str, id: int, async def pkgbase_comment_pin(request: Request, name: str, id: int,
next: str = Form(default=None)): next: str = Form(default=None)):
""" """
@ -327,7 +327,7 @@ async def pkgbase_comment_pin(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/comments/{id}/unpin") @router.post("/pkgbase/{name}/comments/{id}/unpin")
@auth_required() @requires_auth
async def pkgbase_comment_unpin(request: Request, name: str, id: int, async def pkgbase_comment_unpin(request: Request, name: str, id: int,
next: str = Form(default=None)): next: str = Form(default=None)):
""" """
@ -360,7 +360,7 @@ async def pkgbase_comment_unpin(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/comments/{id}/delete") @router.post("/pkgbase/{name}/comments/{id}/delete")
@auth_required() @requires_auth
async def pkgbase_comment_delete(request: Request, name: str, id: int, async def pkgbase_comment_delete(request: Request, name: str, id: int,
next: str = Form(default=None)): next: str = Form(default=None)):
""" """
@ -399,7 +399,7 @@ async def pkgbase_comment_delete(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/comments/{id}/undelete") @router.post("/pkgbase/{name}/comments/{id}/undelete")
@auth_required() @requires_auth
async def pkgbase_comment_undelete(request: Request, name: str, id: int, async def pkgbase_comment_undelete(request: Request, name: str, id: int,
next: str = Form(default=None)): next: str = Form(default=None)):
""" """
@ -437,7 +437,7 @@ async def pkgbase_comment_undelete(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/vote") @router.post("/pkgbase/{name}/vote")
@auth_required() @requires_auth
async def pkgbase_vote(request: Request, name: str): async def pkgbase_vote(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
@ -461,7 +461,7 @@ async def pkgbase_vote(request: Request, name: str):
@router.post("/pkgbase/{name}/unvote") @router.post("/pkgbase/{name}/unvote")
@auth_required() @requires_auth
async def pkgbase_unvote(request: Request, name: str): async def pkgbase_unvote(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
@ -481,7 +481,7 @@ async def pkgbase_unvote(request: Request, name: str):
@router.post("/pkgbase/{name}/notify") @router.post("/pkgbase/{name}/notify")
@auth_required() @requires_auth
async def pkgbase_notify(request: Request, name: str): async def pkgbase_notify(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
actions.pkgbase_notify_instance(request, pkgbase) actions.pkgbase_notify_instance(request, pkgbase)
@ -490,7 +490,7 @@ async def pkgbase_notify(request: Request, name: str):
@router.post("/pkgbase/{name}/unnotify") @router.post("/pkgbase/{name}/unnotify")
@auth_required() @requires_auth
async def pkgbase_unnotify(request: Request, name: str): async def pkgbase_unnotify(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
actions.pkgbase_unnotify_instance(request, pkgbase) actions.pkgbase_unnotify_instance(request, pkgbase)
@ -499,7 +499,7 @@ async def pkgbase_unnotify(request: Request, name: str):
@router.post("/pkgbase/{name}/unflag") @router.post("/pkgbase/{name}/unflag")
@auth_required() @requires_auth
async def pkgbase_unflag(request: Request, name: str): async def pkgbase_unflag(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
actions.pkgbase_unflag_instance(request, pkgbase) actions.pkgbase_unflag_instance(request, pkgbase)
@ -508,7 +508,7 @@ async def pkgbase_unflag(request: Request, name: str):
@router.get("/pkgbase/{name}/disown") @router.get("/pkgbase/{name}/disown")
@auth_required() @requires_auth
async def pkgbase_disown_get(request: Request, name: str): async def pkgbase_disown_get(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
@ -524,7 +524,7 @@ async def pkgbase_disown_get(request: Request, name: str):
@router.post("/pkgbase/{name}/disown") @router.post("/pkgbase/{name}/disown")
@auth_required() @requires_auth
async def pkgbase_disown_post(request: Request, name: str, async def pkgbase_disown_post(request: Request, name: str,
comments: str = Form(default=str()), comments: str = Form(default=str()),
confirm: bool = Form(default=False)): confirm: bool = Form(default=False)):
@ -559,7 +559,7 @@ async def pkgbase_disown_post(request: Request, name: str,
@router.post("/pkgbase/{name}/adopt") @router.post("/pkgbase/{name}/adopt")
@auth_required() @requires_auth
async def pkgbase_adopt_post(request: Request, name: str): async def pkgbase_adopt_post(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
@ -575,7 +575,7 @@ async def pkgbase_adopt_post(request: Request, name: str):
@router.get("/pkgbase/{name}/comaintainers") @router.get("/pkgbase/{name}/comaintainers")
@auth_required() @requires_auth
async def pkgbase_comaintainers(request: Request, name: str) -> Response: async def pkgbase_comaintainers(request: Request, name: str) -> Response:
# Get the PackageBase. # Get the PackageBase.
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
@ -601,7 +601,7 @@ async def pkgbase_comaintainers(request: Request, name: str) -> Response:
@router.post("/pkgbase/{name}/comaintainers") @router.post("/pkgbase/{name}/comaintainers")
@auth_required() @requires_auth
async def pkgbase_comaintainers_post(request: Request, name: str, async def pkgbase_comaintainers_post(request: Request, name: str,
users: str = Form(default=str())) \ users: str = Form(default=str())) \
-> Response: -> Response:
@ -643,7 +643,7 @@ async def pkgbase_comaintainers_post(request: Request, name: str,
@router.get("/pkgbase/{name}/request") @router.get("/pkgbase/{name}/request")
@auth_required() @requires_auth
async def pkgbase_request(request: Request, name: str): async def pkgbase_request(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
context = await make_variable_context(request, "Submit Request") context = await make_variable_context(request, "Submit Request")
@ -652,7 +652,7 @@ async def pkgbase_request(request: Request, name: str):
@router.post("/pkgbase/{name}/request") @router.post("/pkgbase/{name}/request")
@auth_required() @requires_auth
async def pkgbase_request_post(request: Request, name: str, async def pkgbase_request_post(request: Request, name: str,
type: str = Form(...), type: str = Form(...),
merge_into: str = Form(default=None), merge_into: str = Form(default=None),
@ -732,7 +732,7 @@ async def pkgbase_request_post(request: Request, name: str,
@router.get("/pkgbase/{name}/delete") @router.get("/pkgbase/{name}/delete")
@auth_required() @requires_auth
async def pkgbase_delete_get(request: Request, name: str): async def pkgbase_delete_get(request: Request, name: str):
if not request.user.has_credential(creds.PKGBASE_DELETE): if not request.user.has_credential(creds.PKGBASE_DELETE):
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}",
@ -744,7 +744,7 @@ async def pkgbase_delete_get(request: Request, name: str):
@router.post("/pkgbase/{name}/delete") @router.post("/pkgbase/{name}/delete")
@auth_required() @requires_auth
async def pkgbase_delete_post(request: Request, name: str, async def pkgbase_delete_post(request: Request, name: str,
confirm: bool = Form(default=False), confirm: bool = Form(default=False),
comments: str = Form(default=str())): comments: str = Form(default=str())):
@ -779,7 +779,7 @@ async def pkgbase_delete_post(request: Request, name: str,
@router.get("/pkgbase/{name}/merge") @router.get("/pkgbase/{name}/merge")
@auth_required() @requires_auth
async def pkgbase_merge_get(request: Request, name: str, async def pkgbase_merge_get(request: Request, name: str,
into: str = Query(default=str()), into: str = Query(default=str()),
next: str = Query(default=str())): next: str = Query(default=str())):
@ -810,7 +810,7 @@ async def pkgbase_merge_get(request: Request, name: str,
@router.post("/pkgbase/{name}/merge") @router.post("/pkgbase/{name}/merge")
@auth_required() @requires_auth
async def pkgbase_merge_post(request: Request, name: str, async def pkgbase_merge_post(request: Request, name: str,
into: str = Form(default=str()), into: str = Form(default=str()),
comments: str = Form(default=str()), comments: str = Form(default=str()),

View file

@ -6,7 +6,7 @@ from fastapi.responses import RedirectResponse
from sqlalchemy import case from sqlalchemy import case
from aurweb import db, defaults, util from aurweb import db, defaults, util
from aurweb.auth import auth_required, creds from aurweb.auth import creds, requires_auth
from aurweb.models import PackageRequest, User from aurweb.models import PackageRequest, User
from aurweb.models.package_request import PENDING_ID, REJECTED_ID from aurweb.models.package_request import PENDING_ID, REJECTED_ID
from aurweb.requests.util import get_pkgreq_by_id from aurweb.requests.util import get_pkgreq_by_id
@ -17,7 +17,7 @@ router = APIRouter()
@router.get("/requests") @router.get("/requests")
@auth_required() @requires_auth
async def requests(request: Request, async def requests(request: Request,
O: int = Query(default=defaults.O), O: int = Query(default=defaults.O),
PP: int = Query(default=defaults.PP)): PP: int = Query(default=defaults.PP)):
@ -50,7 +50,7 @@ async def requests(request: Request,
@router.get("/requests/{id}/close") @router.get("/requests/{id}/close")
@auth_required() @requires_auth
async def request_close(request: Request, id: int): async def request_close(request: Request, id: int):
pkgreq = get_pkgreq_by_id(id) pkgreq = get_pkgreq_by_id(id)
@ -64,7 +64,7 @@ async def request_close(request: Request, id: int):
@router.post("/requests/{id}/close") @router.post("/requests/{id}/close")
@auth_required() @requires_auth
async def request_close_post(request: Request, id: int, async def request_close_post(request: Request, id: int,
comments: str = Form(default=str())): comments: str = Form(default=str())):
pkgreq = get_pkgreq_by_id(id) pkgreq = get_pkgreq_by_id(id)

View file

@ -10,7 +10,7 @@ from fastapi.responses import RedirectResponse, Response
from sqlalchemy import and_, or_ from sqlalchemy import and_, or_
from aurweb import db, l10n, logging, models from aurweb import db, l10n, logging, models
from aurweb.auth import account_type_required, auth_required from aurweb.auth import account_type_required, requires_auth
from aurweb.models.account_type import DEVELOPER, TRUSTED_USER, TRUSTED_USER_AND_DEV from aurweb.models.account_type import DEVELOPER, TRUSTED_USER, TRUSTED_USER_AND_DEV
from aurweb.templates import make_context, make_variable_context, render_template from aurweb.templates import make_context, make_variable_context, render_template
@ -41,7 +41,7 @@ ADDVOTE_SPECIFICS = {
@router.get("/tu") @router.get("/tu")
@auth_required() @requires_auth
@account_type_required(REQUIRED_TYPES) @account_type_required(REQUIRED_TYPES)
async def trusted_user(request: Request, async def trusted_user(request: Request,
coff: int = 0, # current offset coff: int = 0, # current offset
@ -147,7 +147,7 @@ def render_proposal(request: Request,
@router.get("/tu/{proposal}") @router.get("/tu/{proposal}")
@auth_required() @requires_auth
@account_type_required(REQUIRED_TYPES) @account_type_required(REQUIRED_TYPES)
async def trusted_user_proposal(request: Request, proposal: int): async def trusted_user_proposal(request: Request, proposal: int):
context = await make_variable_context(request, "Trusted User") context = await make_variable_context(request, "Trusted User")
@ -176,7 +176,7 @@ async def trusted_user_proposal(request: Request, proposal: int):
@router.post("/tu/{proposal}") @router.post("/tu/{proposal}")
@auth_required() @requires_auth
@account_type_required(REQUIRED_TYPES) @account_type_required(REQUIRED_TYPES)
async def trusted_user_proposal_post(request: Request, async def trusted_user_proposal_post(request: Request,
proposal: int, proposal: int,
@ -227,7 +227,7 @@ async def trusted_user_proposal_post(request: Request,
@router.get("/addvote") @router.get("/addvote")
@auth_required() @requires_auth
@account_type_required({TRUSTED_USER, TRUSTED_USER_AND_DEV}) @account_type_required({TRUSTED_USER, TRUSTED_USER_AND_DEV})
async def trusted_user_addvote(request: Request, async def trusted_user_addvote(request: Request,
user: str = str(), user: str = str(),
@ -247,7 +247,7 @@ async def trusted_user_addvote(request: Request,
@router.post("/addvote") @router.post("/addvote")
@auth_required() @requires_auth
@account_type_required({TRUSTED_USER, TRUSTED_USER_AND_DEV}) @account_type_required({TRUSTED_USER, TRUSTED_USER_AND_DEV})
async def trusted_user_addvote_post(request: Request, async def trusted_user_addvote_post(request: Request,
user: str = Form(default=str()), user: str = Form(default=str()),

View file

@ -7,7 +7,7 @@ from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb import config, db from aurweb import config, db
from aurweb.auth import AnonymousUser, BasicAuthBackend, account_type_required, auth_required from aurweb.auth import AnonymousUser, BasicAuthBackend, _auth_required, account_type_required
from aurweb.models.account_type import USER, USER_ID from aurweb.models.account_type import USER, USER_ID
from aurweb.models.session import Session from aurweb.models.session import Session
from aurweb.models.user import User from aurweb.models.user import User
@ -105,7 +105,7 @@ async def test_auth_required_redirection_bad_referrer():
pass pass
# Get down to the nitty gritty internal wrapper. # Get down to the nitty gritty internal wrapper.
bad_referrer_route = auth_required()(bad_referrer_route) bad_referrer_route = _auth_required()(bad_referrer_route)
# Execute the route with a "./blahblahblah" Referer, which does not # Execute the route with a "./blahblahblah" Referer, which does not
# match aur_location; `./` has been used as a prefix to attempt to # match aur_location; `./` has been used as a prefix to attempt to