mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
TestClient changes were reverted with 0.37.2: https://github.com/encode/starlette/pull/2525 https://github.com/encode/starlette/releases/tag/0.37.2 Signed-off-by: moson <moson@archlinux.org>
218 lines
6.4 KiB
Python
218 lines
6.4 KiB
Python
import math
|
|
import re
|
|
import secrets
|
|
import shlex
|
|
import string
|
|
from datetime import datetime
|
|
from hashlib import sha1
|
|
from http import HTTPStatus
|
|
from subprocess import PIPE, Popen
|
|
from typing import Callable, Iterable, Tuple, Union
|
|
from urllib.parse import urlparse
|
|
|
|
import fastapi
|
|
import pygit2
|
|
from email_validator import EmailSyntaxError, validate_email
|
|
from fastapi.responses import JSONResponse
|
|
from sqlalchemy.orm import Query
|
|
|
|
import aurweb.config
|
|
from aurweb import aur_logging, defaults
|
|
|
|
# Module-level logger obtained from aur_logging, keyed by this module's name.
logger = aur_logging.get_logger(__name__)
|
|
|
|
|
|
def make_random_string(length: int) -> str:
    """
    Build a random string of lowercase ASCII letters and digits.

    Characters are picked one at a time with `secrets.choice`.

    :param length: Number of characters to generate
    :return: Random string of `length` characters
    """
    pool = string.ascii_lowercase + string.digits
    picked = (secrets.choice(pool) for _ in range(length))
    return "".join(picked)
|
|
|
|
|
|
def make_nonce(length: int = 8):
    """
    Generate a single random hexadecimal nonce of `length` characters.

    `secrets.token_hex(nbytes)` yields two hex characters per byte, so
    we request half of `length` rounded up and then cut the result
    down to `length` (which drops one character when length is odd).

    :param length: Number of hex characters wanted (default: 8)
    :return: Random hex string of `length` characters
    """
    nbytes = math.ceil(length / 2)
    hex_digits = secrets.token_hex(nbytes)
    return hex_digits[:length]
|
|
|
|
|
|
def valid_username(username):
    """
    Check a username against the configured length limits and format.

    The length bounds are read from the "options" config section
    (`username_min_len` and `username_max_len`). The format is: one or
    more alphanumeric characters, an optional single separator of
    '.', '-' or '_', followed by one or more alphanumeric characters.

    :param username: Username string to validate
    :return: True if the username is acceptable, False otherwise
    """
    min_len = aurweb.config.getint("options", "username_min_len")
    max_len = aurweb.config.getint("options", "username_max_len")
    if not (min_len <= len(username) <= max_len):
        return False

    # fullmatch is used instead of match with `^...$`: `$` also matches
    # just before a trailing newline, which would let "name\n" through.
    # The result is reduced to a bool so both return paths agree.
    return re.fullmatch(r"[a-zA-Z0-9]+[.\-_]?[a-zA-Z0-9]+", username) is not None
|
|
|
|
|
|
def valid_email(email):
    """
    Tell whether `email` passes email_validator's syntax validation.

    Deliverability checking is switched off.

    :param email: Email address to validate
    :return: False if validate_email raises EmailSyntaxError, else True
    """
    try:
        validate_email(email, check_deliverability=False)
    except EmailSyntaxError:
        is_valid = False
    else:
        is_valid = True
    return is_valid
|
|
|
|
|
|
def valid_homepage(homepage):
    """
    Tell whether `homepage` is an http(s) URL with a network location.

    :param homepage: URL string to validate
    :return: True if the scheme is http or https and a netloc is
             present; False otherwise, including when urlparse
             raises ValueError
    """
    try:
        url = urlparse(homepage)
    except ValueError:
        return False

    if url.scheme not in ("http", "https"):
        return False
    return bool(url.netloc)
|
|
|
|
|
|
def valid_password(password):
    """
    Tell whether `password` meets the configured minimum length.

    The minimum is read from `passwd_min_len` in the "options" config
    section.

    :param password: Password to check
    :return: True if len(password) is at least the configured minimum
    """
    required = aurweb.config.getint("options", "passwd_min_len")
    return required <= len(password)
|
|
|
|
|
|
def valid_pgp_fingerprint(fp):
    """
    Tell whether `fp` is exactly 40 ASCII hexadecimal digits.

    A strict full-string match is used rather than `int(fp, 16)`,
    because int() also accepts a leading sign, surrounding whitespace,
    a `0x` prefix and `_` digit separators. None of those belong in a
    fingerprint, yet they could pad a shorter value to 40 characters.

    :param fp: Fingerprint string to validate
    :return: True if fp consists solely of 40 hex digits, else False
    """
    return re.fullmatch(r"[0-9A-Fa-f]{40}", fp) is not None
|
|
|
|
|
|
def jsonify(obj):
    """
    Convert `obj` to a JSON-friendly value when a conversion is needed.

    datetime instances become their integer timestamp; anything else
    is handed back untouched.

    :param obj: Value to convert
    :return: int timestamp for datetimes, otherwise `obj` itself
    """
    if not isinstance(obj, datetime):
        return obj
    return int(obj.timestamp())
|
|
|
|
|
|
def get_ssh_fingerprints():
    """
    Fetch the "fingerprints" section of the aurweb configuration.

    :return: The section as returned by aurweb.config.get_section, or
             an empty dict when that value is falsy
    """
    section = aurweb.config.get_section("fingerprints")
    if not section:
        return {}
    return section
|
|
|
|
|
|
def apply_all(iterable: Iterable, fn: Callable):
    """
    Call `fn` once on every item of `iterable`, in order.

    Return values of `fn` are discarded.

    :param iterable: Items to pass to `fn`
    :param fn: Callable taking a single item
    :return: The same `iterable` object that was passed in
    """
    # map() is lazy; draining it is what actually invokes fn per item.
    for _ in map(fn, iterable):
        pass
    return iterable
|
|
|
|
|
|
def sanitize_params(offset_str: str, per_page_str: str) -> Tuple[int, int]:
    """
    Convert raw offset and per-page parameters into usable integers.

    Each value is parsed with int() exactly once. A value that cannot
    be parsed falls back to its default: `defaults.O` for the offset,
    `defaults.PP` for the per-page count. This includes None, which
    makes int() raise TypeError rather than ValueError. The same
    fallback applies to a negative offset or a non-positive per-page
    count.

    :param offset_str: Raw offset value
    :param per_page_str: Raw results-per-page value
    :return: (offset, per_page) tuple of integers
    """
    try:
        offset = int(offset_str)
    except (TypeError, ValueError):
        offset = defaults.O
    else:
        if offset < 0:
            offset = defaults.O

    try:
        per_page = int(per_page_str)
    except (TypeError, ValueError):
        per_page = defaults.PP
    else:
        if per_page <= 0:
            per_page = defaults.PP

    return offset, per_page
|
|
|
|
|
|
def strtobool(value: Union[str, bool]) -> bool:
    """
    Interpret `value` as a boolean flag.

    Falsy values are False. Anything else is stringified and
    lowercased, and is True only if it is one of "y", "yes", "t",
    "true", "on" or "1".

    :param value: String or bool to interpret
    :return: The interpreted boolean
    """
    truthy = {"y", "yes", "t", "true", "on", "1"}
    return bool(value) and str(value).lower() in truthy
|
|
|
|
|
|
def file_hash(filepath: str, hash_function: Callable) -> str:
    """
    Hash the contents of the file at `filepath` with `hash_function`.

    The file is opened in binary mode and read in full. Its bytes are
    handed to `hash_function`, whose result must provide a
    `hexdigest()` method -- e.g. hashlib.sha1, hashlib.md5, etc.

    :param filepath: Path to file you want to hash
    :param hash_function: hashlib-style hash constructor
    :return: hash_function(filepath_content).hexdigest()
    """
    with open(filepath, mode="rb") as stream:
        content = stream.read()
    return hash_function(content).hexdigest()
|
|
|
|
|
|
def git_search(repo: pygit2.Repository, commit_hash: str) -> int:
    """
    Find the shortest prefix length of `commit_hash` contained in `repo`.

    Prefix lengths are tried in increasing order starting at 12. The
    first one for which `commit_hash[:length] in repo` holds is
    returned. If none matches, the result is the full length of
    `commit_hash`, or 12 when the hash is not longer than that.

    :param repo: pygit2.Repository instance
    :param commit_hash: Full length commit hash
    :return: Shortest matching prefix length found
    """
    shortest = 12
    full_len = len(commit_hash)
    for candidate in range(shortest, full_len):
        if commit_hash[:candidate] in repo:
            return candidate
    # Nothing matched (or the hash is too short to search at all).
    return max(shortest, full_len)
|
|
|
|
|
|
async def error_or_result(next: Callable, *args, **kwargs) -> fastapi.Response:
    """
    Await `next` and hand back its response, shielding RuntimeError.

    When a RuntimeError escapes from next(...), it is logged and
    converted into a JSONResponse with a 500 status whose body carries
    the exception's message under the "error" key. Other exception
    types are not handled here.

    :param next: Callable of the next fastapi route callback
    :param *args: Variable number of arguments passed to the endpoint
    :param **kwargs: Optional kwargs to pass to the endpoint
    :return: next(...) retval; if a RuntimeError is raised: a 500 response
    """
    try:
        return await next(*args, **kwargs)
    except RuntimeError as exc:
        logger.error(f"RuntimeError: {exc}")
        return JSONResponse(
            {"error": str(exc)},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
        )
|
|
|
|
|
|
def parse_ssh_key(string: str) -> Tuple[str, str]:
    """
    Parse a single SSH public key into its key type and key data.

    The input is split on whitespace; the first two fields are taken
    as the key type prefix and the key itself. The prefix must be one
    of the space-separated values of `valid-keytypes` in the "auth"
    config section. Finally, "<prefix> <key>" is piped to
    `ssh-keygen -l -f -`, and a non-zero exit status rejects the key.

    :param string: SSH public key text
    :return: (prefix, key) tuple
    :raises ValueError: If fewer than two fields are present, the
                        prefix is not allowed, or ssh-keygen fails
    """
    message = "The SSH public key is invalid."

    # str.split() with no separator already trims the ends and
    # collapses runs of whitespace.
    fields = string.split()
    if len(fields) < 2:
        raise ValueError(message)

    prefix, key = fields[0], fields[1]
    allowed = aurweb.config.get("auth", "valid-keytypes").split(" ")
    if prefix not in allowed:
        raise ValueError(message)

    keygen = Popen(
        ["ssh-keygen", "-l", "-f", "-"],
        stdin=PIPE,
        stdout=PIPE,
        stderr=PIPE,
    )
    # Only the exit status matters; the captured output is discarded.
    keygen.communicate(f"{prefix} {key}".encode())
    if keygen.returncode != 0:
        raise ValueError(message)

    return prefix, key
|
|
|
|
|
|
def parse_ssh_keys(string: str) -> set[Tuple[str, str]]:
    """
    Parse a block of SSH public keys, one per line.

    Lines that are empty after stripping are skipped. Every remaining
    line is parsed with parse_ssh_key, and duplicates collapse because
    the results are collected into a set.

    :param string: Text containing one SSH public key per line
    :return: Set of (prefix, key) tuples
    """
    lines = string.strip().splitlines(keepends=True)
    return {parse_ssh_key(line) for line in lines if line.strip()}
|
|
|
|
|
|
def shell_exec(cmdline: str, cwd: str) -> Tuple[int, str, str]:
    """
    Run a command line in `cwd` and capture its result.

    `cmdline` is tokenized with shlex.split and executed directly
    through Popen, with stdout and stderr captured.

    :param cmdline: Command line to execute
    :param cwd: Working directory for the child process
    :return: (return code, stripped stdout text, stripped stderr text)
    """
    proc = Popen(shlex.split(cmdline), cwd=cwd, stdout=PIPE, stderr=PIPE)
    streams = proc.communicate()
    stdout, stderr = (raw.decode().strip() for raw in streams)
    return proc.returncode, stdout, stderr
|
|
|
|
|
|
def hash_query(query: Query):
    """
    Produce a SHA-1 hex digest identifying `query`.

    The query's statement is compiled with `literal_binds` enabled,
    converted to a string, and that text is hashed.

    :param query: SQLAlchemy Query to hash
    :return: Hexadecimal SHA-1 digest of the compiled statement text
    """
    compiled = query.statement.compile(compile_kwargs={"literal_binds": True})
    sql_text = str(compiled)
    return sha1(sql_text.encode()).hexdigest()
|
|
|
|
|
|
def get_client_ip(request: fastapi.Request) -> str:
    """
    Look up the client's IP address for a Request.

    :param request: Incoming request
    :return: `request.client.host`, or the literal 'testclient' when
             request.client is falsy (e.g. None)
    """
    client = request.client
    if client:
        return client.host
    return "testclient"
|