diff --git a/aurweb/exceptions.py b/aurweb/exceptions.py index 82628b0a..31212676 100644 --- a/aurweb/exceptions.py +++ b/aurweb/exceptions.py @@ -1,3 +1,6 @@ +from typing import Any + + class AurwebException(Exception): pass @@ -77,3 +80,9 @@ class InvalidArgumentsException(AurwebException): class RPCError(AurwebException): pass + + +class ValidationError(AurwebException): + def __init__(self, data: Any, *args, **kwargs): + super().__init__(*args, **kwargs) + self.data = data diff --git a/aurweb/routers/accounts.py b/aurweb/routers/accounts.py index aca322b5..47483acc 100644 --- a/aurweb/routers/accounts.py +++ b/aurweb/routers/accounts.py @@ -6,20 +6,20 @@ from http import HTTPStatus from fastapi import APIRouter, Form, Request from fastapi.responses import HTMLResponse, RedirectResponse -from sqlalchemy import and_, func, or_ +from sqlalchemy import and_, or_ import aurweb.config -from aurweb import cookies, db, l10n, logging, models, time, util +from aurweb import cookies, db, l10n, logging, models, util from aurweb.auth import account_type_required, auth_required -from aurweb.captcha import get_captcha_answer, get_captcha_salts, get_captcha_token +from aurweb.captcha import get_captcha_salts +from aurweb.exceptions import ValidationError from aurweb.l10n import get_translator_for_request -from aurweb.models import account_type -from aurweb.models.account_type import (DEVELOPER, DEVELOPER_ID, TRUSTED_USER, TRUSTED_USER_AND_DEV, TRUSTED_USER_AND_DEV_ID, - TRUSTED_USER_ID, USER_ID) +from aurweb.models import account_type as at from aurweb.models.ssh_pub_key import get_fingerprint from aurweb.scripts.notify import ResetKeyNotification, WelcomeNotification from aurweb.templates import make_context, make_variable_context, render_template +from aurweb.users import validate from aurweb.users.util import get_user_by_name router = APIRouter() @@ -126,146 +126,31 @@ def process_account_form(request: Request, user: models.User, args: dict): # Get a local translator. _ = get_translator_for_request(request) - host = request.client.host - ban = db.query(models.Ban, models.Ban.IPAddress == host).first() - if ban: - return (False, [ - "Account registration has been disabled for your " - "IP address, probably due to sustained spam attacks. " - "Sorry for the inconvenience." - ]) + checks = [ + validate.is_banned, + validate.invalid_user_password, + validate.invalid_fields, + validate.invalid_suspend_permission, + validate.invalid_username, + validate.invalid_password, + validate.invalid_email, + validate.invalid_backup_email, + validate.invalid_homepage, + validate.invalid_pgp_key, + validate.invalid_ssh_pubkey, + validate.invalid_language, + validate.invalid_timezone, + validate.username_in_use, + validate.email_in_use, + validate.invalid_account_type, + validate.invalid_captcha + ] - if request.user.is_authenticated(): - if not request.user.valid_password(args.get("passwd", None)): - return (False, ["Invalid password."]) - - email = args.get("E", None) - username = args.get("U", None) - - if not email or not username: - return (False, ["Missing a required field."]) - - inactive = args.get("J", False) - if not request.user.is_elevated() and inactive != bool(user.InactivityTS): - return (False, ["You do not have permission to suspend accounts."]) - - username_min_len = aurweb.config.getint("options", "username_min_len") - username_max_len = aurweb.config.getint("options", "username_max_len") - if not util.valid_username(args.get("U")): - return (False, [ - "The username is invalid.", - [ - _("It must be between %s and %s characters long") % ( - username_min_len, username_max_len), - "Start and end with a letter or number", - "Can contain only one period, underscore or hyphen.", - ] - ]) - - password = args.get("P", None) - if password: - confirmation = args.get("C", None) - if not util.valid_password(password): - return (False, [ - _("Your password must be at least %s characters.") % ( - username_min_len) - ]) - elif not confirmation: - return (False, ["Please confirm your new password."]) - elif password != confirmation: - return (False, ["Password fields do not match."]) - - backup_email = args.get("BE", None) - homepage = args.get("HP", None) - pgp_key = args.get("K", None) - ssh_pubkey = args.get("PK", None) - language = args.get("L", None) - timezone = args.get("TZ", None) - - def username_exists(username): - return and_(models.User.ID != user.ID, - func.lower(models.User.Username) == username.lower()) - - def email_exists(email): - return and_(models.User.ID != user.ID, - func.lower(models.User.Email) == email.lower()) - - if not util.valid_email(email): - return (False, ["The email address is invalid."]) - elif backup_email and not util.valid_email(backup_email): - return (False, ["The backup email address is invalid."]) - elif homepage and not util.valid_homepage(homepage): - return (False, [ - "The home page is invalid, please specify the full HTTP(s) URL."]) - elif pgp_key and not util.valid_pgp_fingerprint(pgp_key): - return (False, ["The PGP key fingerprint is invalid."]) - elif ssh_pubkey and not util.valid_ssh_pubkey(ssh_pubkey): - return (False, ["The SSH public key is invalid."]) - elif language and language not in l10n.SUPPORTED_LANGUAGES: - return (False, ["Language is not currently supported."]) - elif timezone and timezone not in time.SUPPORTED_TIMEZONES: - return (False, ["Timezone is not currently supported."]) - elif db.query(models.User, username_exists(username)).first(): - # If the username already exists... - return (False, [ - _("The username, %s%s%s, is already in use.") % ( - "", username, "") - ]) - elif db.query(models.User, email_exists(email)).first(): - # If the email already exists... - return (False, [ - _("The address, %s%s%s, is already in use.") % ( - "", email, "") - ]) - - def ssh_fingerprint_exists(fingerprint): - return and_(models.SSHPubKey.UserID != user.ID, - models.SSHPubKey.Fingerprint == fingerprint) - - if ssh_pubkey: - fingerprint = get_fingerprint(ssh_pubkey.strip().rstrip()) - if fingerprint is None: - return (False, ["The SSH public key is invalid."]) - - if db.query(models.SSHPubKey, - ssh_fingerprint_exists(fingerprint)).first(): - return (False, [ - _("The SSH public key, %s%s%s, is already in use.") % ( - "", fingerprint, "") - ]) - - T = int(args.get("T", user.AccountTypeID)) - if T != user.AccountTypeID: - if T not in account_type.ACCOUNT_TYPE_NAME: - return (False, - ["Invalid account type provided."]) - elif not request.user.is_elevated(): - return (False, - ["You do not have permission to change account types."]) - - credential_checks = { - DEVELOPER_ID: request.user.is_developer, - TRUSTED_USER_AND_DEV_ID: request.user.is_developer, - TRUSTED_USER_ID: request.user.is_elevated, - USER_ID: request.user.is_elevated - } - credential_check = credential_checks.get(T) - - if not credential_check(): - name = account_type.ACCOUNT_TYPE_NAME.get(T) - error = _("You do not have permission to change " - "this user's account type to %s.") % name - return (False, [error]) - - captcha_salt = args.get("captcha_salt", None) - if captcha_salt and captcha_salt not in get_captcha_salts(): - return (False, ["This CAPTCHA has expired. Please try again."]) - - captcha = args.get("captcha", None) - if captcha: - answer = get_captcha_answer(get_captcha_token(captcha_salt)) - if captcha != answer: - return (False, ["The entered CAPTCHA answer is invalid."]) + try: + for check in checks: + check(**args, request=request, user=user, _=_) + except ValidationError as exc: + return (False, exc.data) return (True, []) @@ -286,16 +171,16 @@ def make_account_form_context(context: dict, context = copy.copy(context) context["account_types"] = [ - (USER_ID, "Normal User"), - (TRUSTED_USER_ID, TRUSTED_USER) + (at.USER_ID, "Normal User"), + (at.TRUSTED_USER_ID, at.TRUSTED_USER) ] user_account_type_id = context.get("account_types")[0][0] if request.user.has_credential("CRED_ACCOUNT_EDIT_DEV"): - context["account_types"].append((DEVELOPER_ID, DEVELOPER)) - context["account_types"].append((TRUSTED_USER_AND_DEV_ID, - TRUSTED_USER_AND_DEV)) + context["account_types"].append((at.DEVELOPER_ID, at.DEVELOPER)) + context["account_types"].append((at.TRUSTED_USER_AND_DEV_ID, + at.TRUSTED_USER_AND_DEV)) if request.user.is_authenticated(): context["username"] = args.get("U", user.Username) @@ -389,12 +274,10 @@ async def account_register_post(request: Request, captcha: str = Form(default=None), captcha_salt: str = Form(...)): context = await make_variable_context(request, "Register") - args = dict(await request.form()) + context = make_account_form_context(context, request, None, args) - ok, errors = process_account_form(request, request.user, args) - if not ok: # If the field values given do not meet the requirements, # return HTTP 400 with an error. @@ -636,9 +519,9 @@ async def account_comments(request: Request, username: str): @router.get("/accounts") @auth_required(True, redirect="/accounts") -@account_type_required({account_type.TRUSTED_USER, - account_type.DEVELOPER, - account_type.TRUSTED_USER_AND_DEV}) +@account_type_required({at.TRUSTED_USER, + at.DEVELOPER, + at.TRUSTED_USER_AND_DEV}) async def accounts(request: Request): context = make_context(request, "Accounts") return render_template(request, "account/search.html", context) @@ -646,9 +529,9 @@ async def accounts(request: Request): @router.post("/accounts") @auth_required(True, redirect="/accounts") -@account_type_required({account_type.TRUSTED_USER, - account_type.DEVELOPER, - account_type.TRUSTED_USER_AND_DEV}) +@account_type_required({at.TRUSTED_USER, + at.DEVELOPER, + at.TRUSTED_USER_AND_DEV}) async def accounts_post(request: Request, O: int = Form(default=0), # Offset SB: str = Form(default=str()), # Sort By @@ -680,10 +563,10 @@ async def accounts_post(request: Request, # Convert parameter T to an AccountType ID. account_types = { - "u": account_type.USER_ID, - "t": account_type.TRUSTED_USER_ID, - "d": account_type.DEVELOPER_ID, - "td": account_type.TRUSTED_USER_AND_DEV_ID + "u": at.USER_ID, + "t": at.TRUSTED_USER_ID, + "d": at.DEVELOPER_ID, + "td": at.TRUSTED_USER_AND_DEV_ID } account_type_id = account_types.get(T, None) diff --git a/aurweb/users/validate.py b/aurweb/users/validate.py new file mode 100644 index 00000000..4959e316 --- /dev/null +++ b/aurweb/users/validate.py @@ -0,0 +1,204 @@ +""" +Validation functions for account registration and edit fields. +Each of these functions extracts a subset of keyword arguments +out of form data from /account/register or /account/{username}/edit. + +All functions in this module raise aurweb.exceptions.ValidationError +when encountering invalid criteria and return silently otherwise. +""" +from typing import List, Optional, Tuple + +from fastapi import Request +from sqlalchemy import and_ + +from aurweb import config, db, l10n, models, time, util +from aurweb.captcha import get_captcha_answer, get_captcha_salts, get_captcha_token +from aurweb.exceptions import ValidationError +from aurweb.models import account_type as at +from aurweb.models.account_type import ACCOUNT_TYPE_NAME +from aurweb.models.ssh_pub_key import get_fingerprint + + +def invalid_fields(E: str = str(), U: str = str(), **kwargs) \ + -> Optional[Tuple[bool, List[str]]]: + if not E or not U: + raise ValidationError(["Missing a required field."]) + + +def invalid_suspend_permission(request: Request = None, + user: models.User = None, + J: bool = False, + **kwargs) \ + -> Optional[Tuple[bool, List[str]]]: + if not request.user.is_elevated() and J != bool(user.InactivityTS): + raise ValidationError([ + "You do not have permission to suspend accounts."]) + + +def invalid_username(request: Request = None, U: str = str(), _=None, + **kwargs): + if not util.valid_username(U): + username_min_len = config.getint("options", "username_min_len") + username_max_len = config.getint("options", "username_max_len") + raise ValidationError([ + "The username is invalid.", + [ + _("It must be between %s and %s characters long") % ( + username_min_len, username_max_len), + "Start and end with a letter or number", + "Can contain only one period, underscore or hyphen.", + ] + ]) + + +def invalid_password(P: str = str(), C: str = str(), + _: l10n.Translator = None, **kwargs) -> None: + if P: + if not util.valid_password(P): + username_min_len = config.getint( + "options", "username_min_len") + raise ValidationError([ + _("Your password must be at least %s characters.") % ( + username_min_len) + ]) + elif not C: + raise ValidationError(["Please confirm your new password."]) + elif P != C: + raise ValidationError(["Password fields do not match."]) + + +def is_banned(request: Request = None, **kwargs) -> None: + host = request.client.host + exists = db.query(models.Ban, models.Ban.IPAddress == host).exists() + if db.query(exists).scalar(): + raise ValidationError([ + "Account registration has been disabled for your " + "IP address, probably due to sustained spam attacks. " + "Sorry for the inconvenience." + ]) + + +def invalid_user_password(request: Request = None, passwd: str = str(), + **kwargs) -> None: + if request.user.is_authenticated(): + if not request.user.valid_password(passwd): + raise ValidationError(["Invalid password."]) + + +def invalid_email(E: str = str(), **kwargs) -> None: + if not util.valid_email(E): + raise ValidationError(["The email address is invalid."]) + + +def invalid_backup_email(BE: str = str(), **kwargs) -> None: + if BE and not util.valid_email(BE): + raise ValidationError(["The backup email address is invalid."]) + + +def invalid_homepage(HP: str = str(), **kwargs) -> None: + if HP and not util.valid_homepage(HP): + raise ValidationError([ + "The home page is invalid, please specify the full HTTP(s) URL."]) + + +def invalid_pgp_key(K: str = str(), **kwargs) -> None: + if K and not util.valid_pgp_fingerprint(K): + raise ValidationError(["The PGP key fingerprint is invalid."]) + + +def invalid_ssh_pubkey(PK: str = str(), user: models.User = None, + _: l10n.Translator = None, **kwargs) -> None: + if PK: + invalid_exc = ValidationError(["The SSH public key is invalid."]) + if not util.valid_ssh_pubkey(PK): + raise invalid_exc + + fingerprint = get_fingerprint(PK.strip().rstrip()) + if not fingerprint: + raise invalid_exc + + exists = db.query(models.SSHPubKey).filter( + and_(models.SSHPubKey.UserID != user.ID, + models.SSHPubKey.Fingerprint == fingerprint) + ).exists() + if db.query(exists).scalar(): + raise ValidationError([ + _("The SSH public key, %s%s%s, is already in use.") % ( + "", fingerprint, "") + ]) + + +def invalid_language(L: str = str(), **kwargs) -> None: + if L and L not in l10n.SUPPORTED_LANGUAGES: + raise ValidationError(["Language is not currently supported."]) + + +def invalid_timezone(TZ: str = str(), **kwargs) -> None: + if TZ and TZ not in time.SUPPORTED_TIMEZONES: + raise ValidationError(["Timezone is not currently supported."]) + + +def username_in_use(U: str = str(), user: models.User = None, + _: l10n.Translator = None, **kwargs) -> None: + exists = db.query(models.User).filter( + and_(models.User.ID != user.ID, + models.User.Username == U) + ).exists() + if db.query(exists).scalar(): + # If the username already exists... + raise ValidationError([ + _("The username, %s%s%s, is already in use.") % ( + "", U, "") + ]) + + +def email_in_use(E: str = str(), user: models.User = None, + _: l10n.Translator = None, **kwargs) -> None: + exists = db.query(models.User).filter( + and_(models.User.ID != user.ID, + models.User.Email == E) + ).exists() + if db.query(exists).scalar(): + # If the email already exists... + raise ValidationError([ + _("The address, %s%s%s, is already in use.") % ( + "", E, "") + ]) + + +def invalid_account_type(T: int = None, request: Request = None, + user: models.User = None, + _: l10n.Translator = None, + **kwargs) -> None: + if T is not None and (T := int(T)) != user.AccountTypeID: + if T not in ACCOUNT_TYPE_NAME: + raise ValidationError(["Invalid account type provided."]) + elif not request.user.is_elevated(): + raise ValidationError([ + "You do not have permission to change account types."]) + + credential_checks = { + at.USER_ID: request.user.is_trusted_user, + at.TRUSTED_USER_ID: request.user.is_trusted_user, + at.DEVELOPER_ID: lambda: request.user.is_developer(), + at.TRUSTED_USER_AND_DEV_ID: (lambda: request.user.is_trusted_user() + and request.user.is_developer()) + } + credential_check = credential_checks.get(T) + + if not credential_check(): + name = ACCOUNT_TYPE_NAME.get(T) + error = _("You do not have permission to change " + "this user's account type to %s.") % name + raise ValidationError([error]) + + +def invalid_captcha(captcha_salt: str = None, captcha: str = None, **kwargs) \ + -> None: + if captcha_salt and captcha_salt not in get_captcha_salts(): + raise ValidationError(["This CAPTCHA has expired. Please try again."]) + + if captcha: + answer = get_captcha_answer(get_captcha_token(captcha_salt)) + if captcha != answer: + raise ValidationError(["The entered CAPTCHA answer is invalid."]) diff --git a/test/test_accounts_routes.py b/test/test_accounts_routes.py index e828f70f..be929e97 100644 --- a/test/test_accounts_routes.py +++ b/test/test_accounts_routes.py @@ -1035,12 +1035,25 @@ def test_post_account_edit_account_types(): # Make sure it got changed to USER_ID as we intended. assert user.AccountTypeID == USER_ID - # Change user to a Developer. + # Change user to a TU & Dev, which can change themselves to a Developer. with db.begin(): - user.AccountTypeID = DEVELOPER_ID + user.AccountTypeID = TRUSTED_USER_AND_DEV_ID - # As a developer, we can absolutely change all account types. - # For example, from DEVELOPER_ID to TRUSTED_USER_AND_DEV_ID: + # As a TU & Dev, we can absolutely change all account types. + # For example, from TRUSTED_USER_AND_DEV_ID to DEVELOPER_ID: + post_data = { + "U": user.Username, + "E": user.Email, + "T": DEVELOPER_ID, + "passwd": "testPassword" + } + with client as request: + resp = request.post(endpoint, data=post_data, cookies=cookies) + assert resp.status_code == int(HTTPStatus.OK) + assert user.AccountTypeID == DEVELOPER_ID + + # But we can't change a user to a Trusted User & Developer when + # we're just a Developer. post_data = { "U": user.Username, "E": user.Email, @@ -1049,8 +1062,8 @@ def test_post_account_edit_account_types(): } with client as request: resp = request.post(endpoint, data=post_data, cookies=cookies) - assert resp.status_code == int(HTTPStatus.OK) - assert user.AccountTypeID == TRUSTED_USER_AND_DEV_ID + assert resp.status_code == int(HTTPStatus.BAD_REQUEST) + assert user.AccountTypeID == DEVELOPER_ID def test_get_account():