diff --git a/aurweb/exceptions.py b/aurweb/exceptions.py
index 82628b0a..31212676 100644
--- a/aurweb/exceptions.py
+++ b/aurweb/exceptions.py
@@ -1,3 +1,6 @@
+from typing import Any
+
+
class AurwebException(Exception):
pass
@@ -77,3 +80,9 @@ class InvalidArgumentsException(AurwebException):
class RPCError(AurwebException):
pass
+
+
+class ValidationError(AurwebException):
+ def __init__(self, data: Any, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.data = data
diff --git a/aurweb/routers/accounts.py b/aurweb/routers/accounts.py
index aca322b5..47483acc 100644
--- a/aurweb/routers/accounts.py
+++ b/aurweb/routers/accounts.py
@@ -6,20 +6,20 @@ from http import HTTPStatus
from fastapi import APIRouter, Form, Request
from fastapi.responses import HTMLResponse, RedirectResponse
-from sqlalchemy import and_, func, or_
+from sqlalchemy import and_, or_
import aurweb.config
-from aurweb import cookies, db, l10n, logging, models, time, util
+from aurweb import cookies, db, l10n, logging, models, util
from aurweb.auth import account_type_required, auth_required
-from aurweb.captcha import get_captcha_answer, get_captcha_salts, get_captcha_token
+from aurweb.captcha import get_captcha_salts
+from aurweb.exceptions import ValidationError
from aurweb.l10n import get_translator_for_request
-from aurweb.models import account_type
-from aurweb.models.account_type import (DEVELOPER, DEVELOPER_ID, TRUSTED_USER, TRUSTED_USER_AND_DEV, TRUSTED_USER_AND_DEV_ID,
- TRUSTED_USER_ID, USER_ID)
+from aurweb.models import account_type as at
from aurweb.models.ssh_pub_key import get_fingerprint
from aurweb.scripts.notify import ResetKeyNotification, WelcomeNotification
from aurweb.templates import make_context, make_variable_context, render_template
+from aurweb.users import validate
from aurweb.users.util import get_user_by_name
router = APIRouter()
@@ -126,146 +126,31 @@ def process_account_form(request: Request, user: models.User, args: dict):
# Get a local translator.
_ = get_translator_for_request(request)
- host = request.client.host
- ban = db.query(models.Ban, models.Ban.IPAddress == host).first()
- if ban:
- return (False, [
- "Account registration has been disabled for your "
- "IP address, probably due to sustained spam attacks. "
- "Sorry for the inconvenience."
- ])
+ checks = [
+ validate.is_banned,
+ validate.invalid_user_password,
+ validate.invalid_fields,
+ validate.invalid_suspend_permission,
+ validate.invalid_username,
+ validate.invalid_password,
+ validate.invalid_email,
+ validate.invalid_backup_email,
+ validate.invalid_homepage,
+ validate.invalid_pgp_key,
+ validate.invalid_ssh_pubkey,
+ validate.invalid_language,
+ validate.invalid_timezone,
+ validate.username_in_use,
+ validate.email_in_use,
+ validate.invalid_account_type,
+ validate.invalid_captcha
+ ]
- if request.user.is_authenticated():
- if not request.user.valid_password(args.get("passwd", None)):
- return (False, ["Invalid password."])
-
- email = args.get("E", None)
- username = args.get("U", None)
-
- if not email or not username:
- return (False, ["Missing a required field."])
-
- inactive = args.get("J", False)
- if not request.user.is_elevated() and inactive != bool(user.InactivityTS):
- return (False, ["You do not have permission to suspend accounts."])
-
- username_min_len = aurweb.config.getint("options", "username_min_len")
- username_max_len = aurweb.config.getint("options", "username_max_len")
- if not util.valid_username(args.get("U")):
- return (False, [
- "The username is invalid.",
- [
- _("It must be between %s and %s characters long") % (
- username_min_len, username_max_len),
- "Start and end with a letter or number",
- "Can contain only one period, underscore or hyphen.",
- ]
- ])
-
- password = args.get("P", None)
- if password:
- confirmation = args.get("C", None)
- if not util.valid_password(password):
- return (False, [
- _("Your password must be at least %s characters.") % (
- username_min_len)
- ])
- elif not confirmation:
- return (False, ["Please confirm your new password."])
- elif password != confirmation:
- return (False, ["Password fields do not match."])
-
- backup_email = args.get("BE", None)
- homepage = args.get("HP", None)
- pgp_key = args.get("K", None)
- ssh_pubkey = args.get("PK", None)
- language = args.get("L", None)
- timezone = args.get("TZ", None)
-
- def username_exists(username):
- return and_(models.User.ID != user.ID,
- func.lower(models.User.Username) == username.lower())
-
- def email_exists(email):
- return and_(models.User.ID != user.ID,
- func.lower(models.User.Email) == email.lower())
-
- if not util.valid_email(email):
- return (False, ["The email address is invalid."])
- elif backup_email and not util.valid_email(backup_email):
- return (False, ["The backup email address is invalid."])
- elif homepage and not util.valid_homepage(homepage):
- return (False, [
- "The home page is invalid, please specify the full HTTP(s) URL."])
- elif pgp_key and not util.valid_pgp_fingerprint(pgp_key):
- return (False, ["The PGP key fingerprint is invalid."])
- elif ssh_pubkey and not util.valid_ssh_pubkey(ssh_pubkey):
- return (False, ["The SSH public key is invalid."])
- elif language and language not in l10n.SUPPORTED_LANGUAGES:
- return (False, ["Language is not currently supported."])
- elif timezone and timezone not in time.SUPPORTED_TIMEZONES:
- return (False, ["Timezone is not currently supported."])
- elif db.query(models.User, username_exists(username)).first():
- # If the username already exists...
- return (False, [
- _("The username, %s%s%s, is already in use.") % (
- "", username, "")
- ])
- elif db.query(models.User, email_exists(email)).first():
- # If the email already exists...
- return (False, [
- _("The address, %s%s%s, is already in use.") % (
- "", email, "")
- ])
-
- def ssh_fingerprint_exists(fingerprint):
- return and_(models.SSHPubKey.UserID != user.ID,
- models.SSHPubKey.Fingerprint == fingerprint)
-
- if ssh_pubkey:
- fingerprint = get_fingerprint(ssh_pubkey.strip().rstrip())
- if fingerprint is None:
- return (False, ["The SSH public key is invalid."])
-
- if db.query(models.SSHPubKey,
- ssh_fingerprint_exists(fingerprint)).first():
- return (False, [
- _("The SSH public key, %s%s%s, is already in use.") % (
- "", fingerprint, "")
- ])
-
- T = int(args.get("T", user.AccountTypeID))
- if T != user.AccountTypeID:
- if T not in account_type.ACCOUNT_TYPE_NAME:
- return (False,
- ["Invalid account type provided."])
- elif not request.user.is_elevated():
- return (False,
- ["You do not have permission to change account types."])
-
- credential_checks = {
- DEVELOPER_ID: request.user.is_developer,
- TRUSTED_USER_AND_DEV_ID: request.user.is_developer,
- TRUSTED_USER_ID: request.user.is_elevated,
- USER_ID: request.user.is_elevated
- }
- credential_check = credential_checks.get(T)
-
- if not credential_check():
- name = account_type.ACCOUNT_TYPE_NAME.get(T)
- error = _("You do not have permission to change "
- "this user's account type to %s.") % name
- return (False, [error])
-
- captcha_salt = args.get("captcha_salt", None)
- if captcha_salt and captcha_salt not in get_captcha_salts():
- return (False, ["This CAPTCHA has expired. Please try again."])
-
- captcha = args.get("captcha", None)
- if captcha:
- answer = get_captcha_answer(get_captcha_token(captcha_salt))
- if captcha != answer:
- return (False, ["The entered CAPTCHA answer is invalid."])
+ try:
+ for check in checks:
+ check(**args, request=request, user=user, _=_)
+ except ValidationError as exc:
+ return (False, exc.data)
return (True, [])
@@ -286,16 +171,16 @@ def make_account_form_context(context: dict,
context = copy.copy(context)
context["account_types"] = [
- (USER_ID, "Normal User"),
- (TRUSTED_USER_ID, TRUSTED_USER)
+ (at.USER_ID, "Normal User"),
+ (at.TRUSTED_USER_ID, at.TRUSTED_USER)
]
user_account_type_id = context.get("account_types")[0][0]
if request.user.has_credential("CRED_ACCOUNT_EDIT_DEV"):
- context["account_types"].append((DEVELOPER_ID, DEVELOPER))
- context["account_types"].append((TRUSTED_USER_AND_DEV_ID,
- TRUSTED_USER_AND_DEV))
+ context["account_types"].append((at.DEVELOPER_ID, at.DEVELOPER))
+ context["account_types"].append((at.TRUSTED_USER_AND_DEV_ID,
+ at.TRUSTED_USER_AND_DEV))
if request.user.is_authenticated():
context["username"] = args.get("U", user.Username)
@@ -389,12 +274,10 @@ async def account_register_post(request: Request,
captcha: str = Form(default=None),
captcha_salt: str = Form(...)):
context = await make_variable_context(request, "Register")
-
args = dict(await request.form())
+
context = make_account_form_context(context, request, None, args)
-
ok, errors = process_account_form(request, request.user, args)
-
if not ok:
# If the field values given do not meet the requirements,
# return HTTP 400 with an error.
@@ -636,9 +519,9 @@ async def account_comments(request: Request, username: str):
@router.get("/accounts")
@auth_required(True, redirect="/accounts")
-@account_type_required({account_type.TRUSTED_USER,
- account_type.DEVELOPER,
- account_type.TRUSTED_USER_AND_DEV})
+@account_type_required({at.TRUSTED_USER,
+ at.DEVELOPER,
+ at.TRUSTED_USER_AND_DEV})
async def accounts(request: Request):
context = make_context(request, "Accounts")
return render_template(request, "account/search.html", context)
@@ -646,9 +529,9 @@ async def accounts(request: Request):
@router.post("/accounts")
@auth_required(True, redirect="/accounts")
-@account_type_required({account_type.TRUSTED_USER,
- account_type.DEVELOPER,
- account_type.TRUSTED_USER_AND_DEV})
+@account_type_required({at.TRUSTED_USER,
+ at.DEVELOPER,
+ at.TRUSTED_USER_AND_DEV})
async def accounts_post(request: Request,
O: int = Form(default=0), # Offset
SB: str = Form(default=str()), # Sort By
@@ -680,10 +563,10 @@ async def accounts_post(request: Request,
# Convert parameter T to an AccountType ID.
account_types = {
- "u": account_type.USER_ID,
- "t": account_type.TRUSTED_USER_ID,
- "d": account_type.DEVELOPER_ID,
- "td": account_type.TRUSTED_USER_AND_DEV_ID
+ "u": at.USER_ID,
+ "t": at.TRUSTED_USER_ID,
+ "d": at.DEVELOPER_ID,
+ "td": at.TRUSTED_USER_AND_DEV_ID
}
account_type_id = account_types.get(T, None)
diff --git a/aurweb/users/validate.py b/aurweb/users/validate.py
new file mode 100644
index 00000000..4959e316
--- /dev/null
+++ b/aurweb/users/validate.py
@@ -0,0 +1,204 @@
+"""
+Validation functions for account registration and edit fields.
+Each of these functions extracts a subset of keyword arguments
+out of form data from /account/register or /account/{username}/edit.
+
+All functions in this module raise aurweb.exceptions.ValidationError
+when encountering invalid criteria and return silently otherwise.
+"""
+from typing import List, Optional, Tuple
+
+from fastapi import Request
+from sqlalchemy import and_
+
+from aurweb import config, db, l10n, models, time, util
+from aurweb.captcha import get_captcha_answer, get_captcha_salts, get_captcha_token
+from aurweb.exceptions import ValidationError
+from aurweb.models import account_type as at
+from aurweb.models.account_type import ACCOUNT_TYPE_NAME
+from aurweb.models.ssh_pub_key import get_fingerprint
+
+
+def invalid_fields(E: str = str(), U: str = str(), **kwargs) \
+ -> Optional[Tuple[bool, List[str]]]:
+ if not E or not U:
+ raise ValidationError(["Missing a required field."])
+
+
+def invalid_suspend_permission(request: Request = None,
+ user: models.User = None,
+ J: bool = False,
+ **kwargs) \
+ -> Optional[Tuple[bool, List[str]]]:
+ if not request.user.is_elevated() and J != bool(user.InactivityTS):
+ raise ValidationError([
+ "You do not have permission to suspend accounts."])
+
+
+def invalid_username(request: Request = None, U: str = str(), _=None,
+ **kwargs):
+ if not util.valid_username(U):
+ username_min_len = config.getint("options", "username_min_len")
+ username_max_len = config.getint("options", "username_max_len")
+ raise ValidationError([
+ "The username is invalid.",
+ [
+ _("It must be between %s and %s characters long") % (
+ username_min_len, username_max_len),
+ "Start and end with a letter or number",
+ "Can contain only one period, underscore or hyphen.",
+ ]
+ ])
+
+
+def invalid_password(P: str = str(), C: str = str(),
+ _: l10n.Translator = None, **kwargs) -> None:
+ if P:
+ if not util.valid_password(P):
+ username_min_len = config.getint(
+ "options", "username_min_len")
+ raise ValidationError([
+ _("Your password must be at least %s characters.") % (
+ username_min_len)
+ ])
+ elif not C:
+ raise ValidationError(["Please confirm your new password."])
+ elif P != C:
+ raise ValidationError(["Password fields do not match."])
+
+
+def is_banned(request: Request = None, **kwargs) -> None:
+ host = request.client.host
+ exists = db.query(models.Ban, models.Ban.IPAddress == host).exists()
+ if db.query(exists).scalar():
+ raise ValidationError([
+ "Account registration has been disabled for your "
+ "IP address, probably due to sustained spam attacks. "
+ "Sorry for the inconvenience."
+ ])
+
+
+def invalid_user_password(request: Request = None, passwd: str = str(),
+ **kwargs) -> None:
+ if request.user.is_authenticated():
+ if not request.user.valid_password(passwd):
+ raise ValidationError(["Invalid password."])
+
+
+def invalid_email(E: str = str(), **kwargs) -> None:
+ if not util.valid_email(E):
+ raise ValidationError(["The email address is invalid."])
+
+
+def invalid_backup_email(BE: str = str(), **kwargs) -> None:
+ if BE and not util.valid_email(BE):
+ raise ValidationError(["The backup email address is invalid."])
+
+
+def invalid_homepage(HP: str = str(), **kwargs) -> None:
+ if HP and not util.valid_homepage(HP):
+ raise ValidationError([
+ "The home page is invalid, please specify the full HTTP(s) URL."])
+
+
+def invalid_pgp_key(K: str = str(), **kwargs) -> None:
+ if K and not util.valid_pgp_fingerprint(K):
+ raise ValidationError(["The PGP key fingerprint is invalid."])
+
+
+def invalid_ssh_pubkey(PK: str = str(), user: models.User = None,
+ _: l10n.Translator = None, **kwargs) -> None:
+ if PK:
+ invalid_exc = ValidationError(["The SSH public key is invalid."])
+ if not util.valid_ssh_pubkey(PK):
+ raise invalid_exc
+
+ fingerprint = get_fingerprint(PK.strip().rstrip())
+ if not fingerprint:
+ raise invalid_exc
+
+ exists = db.query(models.SSHPubKey).filter(
+ and_(models.SSHPubKey.UserID != user.ID,
+ models.SSHPubKey.Fingerprint == fingerprint)
+ ).exists()
+ if db.query(exists).scalar():
+ raise ValidationError([
+ _("The SSH public key, %s%s%s, is already in use.") % (
+ "", fingerprint, "")
+ ])
+
+
+def invalid_language(L: str = str(), **kwargs) -> None:
+ if L and L not in l10n.SUPPORTED_LANGUAGES:
+ raise ValidationError(["Language is not currently supported."])
+
+
+def invalid_timezone(TZ: str = str(), **kwargs) -> None:
+ if TZ and TZ not in time.SUPPORTED_TIMEZONES:
+ raise ValidationError(["Timezone is not currently supported."])
+
+
+def username_in_use(U: str = str(), user: models.User = None,
+ _: l10n.Translator = None, **kwargs) -> None:
+ exists = db.query(models.User).filter(
+ and_(models.User.ID != user.ID,
+ models.User.Username == U)
+ ).exists()
+ if db.query(exists).scalar():
+ # If the username already exists...
+ raise ValidationError([
+ _("The username, %s%s%s, is already in use.") % (
+ "", U, "")
+ ])
+
+
+def email_in_use(E: str = str(), user: models.User = None,
+ _: l10n.Translator = None, **kwargs) -> None:
+ exists = db.query(models.User).filter(
+ and_(models.User.ID != user.ID,
+ models.User.Email == E)
+ ).exists()
+ if db.query(exists).scalar():
+ # If the email already exists...
+ raise ValidationError([
+ _("The address, %s%s%s, is already in use.") % (
+ "", E, "")
+ ])
+
+
+def invalid_account_type(T: int = None, request: Request = None,
+ user: models.User = None,
+ _: l10n.Translator = None,
+ **kwargs) -> None:
+ if T is not None and (T := int(T)) != user.AccountTypeID:
+ if T not in ACCOUNT_TYPE_NAME:
+ raise ValidationError(["Invalid account type provided."])
+ elif not request.user.is_elevated():
+ raise ValidationError([
+ "You do not have permission to change account types."])
+
+ credential_checks = {
+ at.USER_ID: request.user.is_trusted_user,
+ at.TRUSTED_USER_ID: request.user.is_trusted_user,
+ at.DEVELOPER_ID: lambda: request.user.is_developer(),
+ at.TRUSTED_USER_AND_DEV_ID: (lambda: request.user.is_trusted_user()
+ and request.user.is_developer())
+ }
+ credential_check = credential_checks.get(T)
+
+ if not credential_check():
+ name = ACCOUNT_TYPE_NAME.get(T)
+ error = _("You do not have permission to change "
+ "this user's account type to %s.") % name
+ raise ValidationError([error])
+
+
+def invalid_captcha(captcha_salt: str = None, captcha: str = None, **kwargs) \
+ -> None:
+ if captcha_salt and captcha_salt not in get_captcha_salts():
+ raise ValidationError(["This CAPTCHA has expired. Please try again."])
+
+ if captcha:
+ answer = get_captcha_answer(get_captcha_token(captcha_salt))
+ if captcha != answer:
+ raise ValidationError(["The entered CAPTCHA answer is invalid."])
diff --git a/test/test_accounts_routes.py b/test/test_accounts_routes.py
index e828f70f..be929e97 100644
--- a/test/test_accounts_routes.py
+++ b/test/test_accounts_routes.py
@@ -1035,12 +1035,25 @@ def test_post_account_edit_account_types():
# Make sure it got changed to USER_ID as we intended.
assert user.AccountTypeID == USER_ID
- # Change user to a Developer.
+ # Change user to a TU & Dev, which can change themselves to a Developer.
with db.begin():
- user.AccountTypeID = DEVELOPER_ID
+ user.AccountTypeID = TRUSTED_USER_AND_DEV_ID
- # As a developer, we can absolutely change all account types.
- # For example, from DEVELOPER_ID to TRUSTED_USER_AND_DEV_ID:
+ # As a TU & Dev, we can absolutely change all account types.
+ # For example, from TRUSTED_USER_AND_DEV_ID to DEVELOPER_ID:
+ post_data = {
+ "U": user.Username,
+ "E": user.Email,
+ "T": DEVELOPER_ID,
+ "passwd": "testPassword"
+ }
+ with client as request:
+ resp = request.post(endpoint, data=post_data, cookies=cookies)
+ assert resp.status_code == int(HTTPStatus.OK)
+ assert user.AccountTypeID == DEVELOPER_ID
+
+ # But we can't change a user to a Trusted User & Developer when
+ # we're just a Developer.
post_data = {
"U": user.Username,
"E": user.Email,
@@ -1049,8 +1062,8 @@ def test_post_account_edit_account_types():
}
with client as request:
resp = request.post(endpoint, data=post_data, cookies=cookies)
- assert resp.status_code == int(HTTPStatus.OK)
- assert user.AccountTypeID == TRUSTED_USER_AND_DEV_ID
+ assert resp.status_code == int(HTTPStatus.BAD_REQUEST)
+ assert user.AccountTypeID == DEVELOPER_ID
def test_get_account():