mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
change(fastapi): decouple update logic from account edit
Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
parent
94972841d6
commit
303585cdbf
3 changed files with 128 additions and 74 deletions
|
@ -1,7 +1,6 @@
|
|||
import copy
|
||||
import typing
|
||||
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Form, Request
|
||||
|
@ -19,7 +18,7 @@ from aurweb.models import account_type as at
|
|||
from aurweb.models.ssh_pub_key import get_fingerprint
|
||||
from aurweb.scripts.notify import ResetKeyNotification, WelcomeNotification
|
||||
from aurweb.templates import make_context, make_variable_context, render_template
|
||||
from aurweb.users import validate
|
||||
from aurweb.users import update, validate
|
||||
from aurweb.users.util import get_user_by_name
|
||||
|
||||
router = APIRouter()
|
||||
|
@ -405,79 +404,17 @@ async def account_edit_post(request: Request,
|
|||
return render_template(request, "account/edit.html", context,
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
# Set all updated fields as needed.
|
||||
with db.begin():
|
||||
user.Username = U or user.Username
|
||||
user.Email = E or user.Email
|
||||
user.HideEmail = bool(H)
|
||||
user.BackupEmail = BE or user.BackupEmail
|
||||
user.RealName = R or user.RealName
|
||||
user.Homepage = HP or user.Homepage
|
||||
user.IRCNick = I or user.IRCNick
|
||||
user.PGPKey = K or user.PGPKey
|
||||
user.Suspended = J
|
||||
user.InactivityTS = int(datetime.utcnow().timestamp()) * int(J)
|
||||
updates = [
|
||||
update.simple,
|
||||
update.language,
|
||||
update.timezone,
|
||||
update.ssh_pubkey,
|
||||
update.account_type,
|
||||
update.password
|
||||
]
|
||||
|
||||
# If we update the language, update the cookie as well.
|
||||
if L and L != user.LangPreference:
|
||||
request.cookies["AURLANG"] = L
|
||||
with db.begin():
|
||||
user.LangPreference = L
|
||||
context["language"] = L
|
||||
|
||||
# If we update the timezone, also update the cookie.
|
||||
if TZ and TZ != user.Timezone:
|
||||
with db.begin():
|
||||
user.Timezone = TZ
|
||||
request.cookies["AURTZ"] = TZ
|
||||
context["timezone"] = TZ
|
||||
|
||||
with db.begin():
|
||||
user.CommentNotify = bool(CN)
|
||||
user.UpdateNotify = bool(UN)
|
||||
user.OwnershipNotify = bool(ON)
|
||||
|
||||
# If a PK is given, compare it against the target user's PK.
|
||||
with db.begin():
|
||||
if PK:
|
||||
# Get the second token in the public key, which is the actual key.
|
||||
pubkey = PK.strip().rstrip()
|
||||
parts = pubkey.split(" ")
|
||||
if len(parts) == 3:
|
||||
# Remove the host part.
|
||||
pubkey = parts[0] + " " + parts[1]
|
||||
fingerprint = get_fingerprint(pubkey)
|
||||
if not user.ssh_pub_key:
|
||||
# No public key exists, create one.
|
||||
user.ssh_pub_key = models.SSHPubKey(UserID=user.ID,
|
||||
PubKey=pubkey,
|
||||
Fingerprint=fingerprint)
|
||||
elif user.ssh_pub_key.PubKey != pubkey:
|
||||
# A public key already exists, update it.
|
||||
user.ssh_pub_key.PubKey = pubkey
|
||||
user.ssh_pub_key.Fingerprint = fingerprint
|
||||
elif user.ssh_pub_key:
|
||||
# Else, if the user has a public key already, delete it.
|
||||
db.delete(user.ssh_pub_key)
|
||||
|
||||
if T and T != user.AccountTypeID:
|
||||
with db.begin():
|
||||
user.AccountTypeID = T
|
||||
|
||||
if P and not user.valid_password(P):
|
||||
# Remove the fields we consumed for passwords.
|
||||
context["P"] = context["C"] = str()
|
||||
|
||||
# If a password was given and it doesn't match the user's, update it.
|
||||
with db.begin():
|
||||
user.update_password(P)
|
||||
|
||||
if user == request.user:
|
||||
remember_me = request.cookies.get("AURREMEMBER", False)
|
||||
|
||||
# If the target user is the request user, login with
|
||||
# the updated password to update the Session record.
|
||||
user.login(request, P, cookies.timeout(remember_me))
|
||||
for f in updates:
|
||||
f(**args, request=request, user=user, context=context)
|
||||
|
||||
if not errors:
|
||||
context["complete"] = True
|
||||
|
|
110
aurweb/users/update.py
Normal file
110
aurweb/users/update.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from aurweb import cookies, db, models
|
||||
from aurweb.models.ssh_pub_key import get_fingerprint
|
||||
from aurweb.util import strtobool
|
||||
|
||||
|
||||
def simple(U: str = str(), E: str = str(), H: bool = False,
|
||||
BE: str = str(), R: str = str(), HP: str = str(),
|
||||
I: str = str(), K: str = str(), J: bool = False,
|
||||
CN: bool = False, UN: bool = False, ON: bool = False,
|
||||
user: models.User = None,
|
||||
**kwargs) -> None:
|
||||
now = int(datetime.utcnow().timestamp())
|
||||
with db.begin():
|
||||
user.Username = U or user.Username
|
||||
user.Email = E or user.Email
|
||||
user.HideEmail = strtobool(H)
|
||||
user.BackupEmail = BE or user.BackupEmail
|
||||
user.RealName = R or user.RealName
|
||||
user.Homepage = HP or user.Homepage
|
||||
user.IRCNick = I or user.IRCNick
|
||||
user.PGPKey = K or user.PGPKey
|
||||
user.Suspended = strtobool(J)
|
||||
user.InactivityTS = now * int(strtobool(J))
|
||||
user.CommentNotify = strtobool(CN)
|
||||
user.UpdateNotify = strtobool(UN)
|
||||
user.OwnershipNotify = strtobool(ON)
|
||||
|
||||
|
||||
def language(L: str = str(),
|
||||
request: Request = None,
|
||||
user: models.User = None,
|
||||
context: Dict[str, Any] = {},
|
||||
**kwargs) -> None:
|
||||
if L and L != user.LangPreference:
|
||||
with db.begin():
|
||||
user.LangPreference = L
|
||||
context["language"] = L
|
||||
|
||||
|
||||
def timezone(TZ: str = str(),
|
||||
request: Request = None,
|
||||
user: models.User = None,
|
||||
context: Dict[str, Any] = {},
|
||||
**kwargs) -> None:
|
||||
if TZ and TZ != user.Timezone:
|
||||
with db.begin():
|
||||
user.Timezone = TZ
|
||||
context["language"] = TZ
|
||||
|
||||
|
||||
def ssh_pubkey(PK: str = str(),
|
||||
user: models.User = None,
|
||||
**kwargs) -> None:
|
||||
# If a PK is given, compare it against the target user's PK.
|
||||
if PK:
|
||||
# Get the second token in the public key, which is the actual key.
|
||||
pubkey = PK.strip().rstrip()
|
||||
parts = pubkey.split(" ")
|
||||
if len(parts) == 3:
|
||||
# Remove the host part.
|
||||
pubkey = parts[0] + " " + parts[1]
|
||||
fingerprint = get_fingerprint(pubkey)
|
||||
if not user.ssh_pub_key:
|
||||
# No public key exists, create one.
|
||||
with db.begin():
|
||||
db.create(models.SSHPubKey, UserID=user.ID,
|
||||
PubKey=pubkey, Fingerprint=fingerprint)
|
||||
elif user.ssh_pub_key.PubKey != pubkey:
|
||||
# A public key already exists, update it.
|
||||
with db.begin():
|
||||
user.ssh_pub_key.PubKey = pubkey
|
||||
user.ssh_pub_key.Fingerprint = fingerprint
|
||||
elif user.ssh_pub_key:
|
||||
# Else, if the user has a public key already, delete it.
|
||||
with db.begin():
|
||||
db.delete(user.ssh_pub_key)
|
||||
|
||||
|
||||
def account_type(T: int = None,
|
||||
user: models.User = None,
|
||||
**kwargs) -> None:
|
||||
if T is not None and (T := int(T)) != user.AccountTypeID:
|
||||
with db.begin():
|
||||
user.AccountTypeID = T
|
||||
|
||||
|
||||
def password(P: str = str(),
|
||||
request: Request = None,
|
||||
user: models.User = None,
|
||||
context: Dict[str, Any] = {},
|
||||
**kwargs) -> None:
|
||||
if P and not user.valid_password(P):
|
||||
# Remove the fields we consumed for passwords.
|
||||
context["P"] = context["C"] = str()
|
||||
|
||||
# If a password was given and it doesn't match the user's, update it.
|
||||
with db.begin():
|
||||
user.update_password(P)
|
||||
|
||||
if user == request.user:
|
||||
remember_me = request.cookies.get("AURREMEMBER", False)
|
||||
|
||||
# If the target user is the request user, login with
|
||||
# the updated password to update the Session record.
|
||||
user.login(request, P, cookies.timeout(remember_me))
|
|
@ -7,6 +7,7 @@ import secrets
|
|||
import string
|
||||
|
||||
from datetime import datetime
|
||||
from distutils.util import strtobool as _strtobool
|
||||
from typing import Any, Callable, Dict, Iterable, Tuple
|
||||
from urllib.parse import urlencode, urlparse
|
||||
from zoneinfo import ZoneInfo
|
||||
|
@ -170,3 +171,9 @@ def sanitize_params(offset: str, per_page: str) -> Tuple[int, int]:
|
|||
per_page = defaults.PP
|
||||
|
||||
return (offset, per_page)
|
||||
|
||||
|
||||
def strtobool(value: str) -> bool:
|
||||
if isinstance(value, str):
|
||||
return _strtobool(value)
|
||||
return value
|
||||
|
|
Loading…
Add table
Reference in a new issue