style: Run pre-commit

This commit is contained in:
Joakim Saario 2022-08-21 22:08:29 +02:00
parent b47882b114
commit 9c6c13b78a
No known key found for this signature in database
GPG key ID: D8B76D271B7BD453
235 changed files with 7180 additions and 5628 deletions

1
.gitignore vendored
View file

@ -1,3 +1,4 @@
/data/
__pycache__/ __pycache__/
*.py[cod] *.py[cod]
.vim/ .vim/

View file

@ -5,4 +5,3 @@ host = https://www.transifex.com
file_filter = po/<lang>.po file_filter = po/<lang>.po
source_file = po/aurweb.pot source_file = po/aurweb.pot
source_lang = en source_lang = en

View file

@ -6,11 +6,9 @@ import re
import sys import sys
import traceback import traceback
import typing import typing
from urllib.parse import quote_plus from urllib.parse import quote_plus
import requests import requests
from fastapi import FastAPI, HTTPException, Request, Response from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
@ -26,7 +24,6 @@ import aurweb.config
import aurweb.filters # noqa: F401 import aurweb.filters # noqa: F401
import aurweb.logging import aurweb.logging
import aurweb.pkgbase.util as pkgbaseutil import aurweb.pkgbase.util as pkgbaseutil
from aurweb import logging, prometheus, util from aurweb import logging, prometheus, util
from aurweb.auth import BasicAuthBackend from aurweb.auth import BasicAuthBackend
from aurweb.db import get_engine, query from aurweb.db import get_engine, query
@ -60,33 +57,33 @@ async def app_startup():
# provided by the user. Docker uses .env's TEST_RECURSION_LIMIT # provided by the user. Docker uses .env's TEST_RECURSION_LIMIT
# when running test suites. # when running test suites.
# TODO: Find a proper fix to this issue. # TODO: Find a proper fix to this issue.
recursion_limit = int(os.environ.get( recursion_limit = int(
"TEST_RECURSION_LIMIT", sys.getrecursionlimit() + 1000)) os.environ.get("TEST_RECURSION_LIMIT", sys.getrecursionlimit() + 1000)
)
sys.setrecursionlimit(recursion_limit) sys.setrecursionlimit(recursion_limit)
backend = aurweb.config.get("database", "backend") backend = aurweb.config.get("database", "backend")
if backend not in aurweb.db.DRIVERS: if backend not in aurweb.db.DRIVERS:
raise ValueError( raise ValueError(
f"The configured database backend ({backend}) is unsupported. " f"The configured database backend ({backend}) is unsupported. "
f"Supported backends: {str(aurweb.db.DRIVERS.keys())}") f"Supported backends: {str(aurweb.db.DRIVERS.keys())}"
)
session_secret = aurweb.config.get("fastapi", "session_secret") session_secret = aurweb.config.get("fastapi", "session_secret")
if not session_secret: if not session_secret:
raise Exception("[fastapi] session_secret must not be empty") raise Exception("[fastapi] session_secret must not be empty")
if not os.environ.get("PROMETHEUS_MULTIPROC_DIR", None): if not os.environ.get("PROMETHEUS_MULTIPROC_DIR", None):
logger.warning("$PROMETHEUS_MULTIPROC_DIR is not set, the /metrics " logger.warning(
"endpoint is disabled.") "$PROMETHEUS_MULTIPROC_DIR is not set, the /metrics "
"endpoint is disabled."
)
app.mount("/static/css", app.mount("/static/css", StaticFiles(directory="web/html/css"), name="static_css")
StaticFiles(directory="web/html/css"), app.mount("/static/js", StaticFiles(directory="web/html/js"), name="static_js")
name="static_css") app.mount(
app.mount("/static/js", "/static/images", StaticFiles(directory="web/html/images"), name="static_images"
StaticFiles(directory="web/html/js"), )
name="static_js")
app.mount("/static/images",
StaticFiles(directory="web/html/images"),
name="static_images")
# Add application middlewares. # Add application middlewares.
app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend()) app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend())
@ -95,6 +92,7 @@ async def app_startup():
# Add application routes. # Add application routes.
def add_router(module): def add_router(module):
app.include_router(module.router) app.include_router(module.router)
util.apply_all(APP_ROUTES, add_router) util.apply_all(APP_ROUTES, add_router)
# Initialize the database engine and ORM. # Initialize the database engine and ORM.
@ -102,8 +100,8 @@ async def app_startup():
def child_exit(server, worker): # pragma: no cover def child_exit(server, worker): # pragma: no cover
""" This function is required for gunicorn customization """This function is required for gunicorn customization
of prometheus multiprocessing. """ of prometheus multiprocessing."""
multiprocess.mark_process_dead(worker.pid) multiprocess.mark_process_dead(worker.pid)
@ -177,9 +175,7 @@ async def internal_server_error(request: Request, exc: Exception) -> Response:
else: else:
# post # post
form_data = str(dict(request.state.form_data)) form_data = str(dict(request.state.form_data))
desc = desc + [ desc = desc + [f"- Data: `{form_data}`"] + ["", f"```{tb}```"]
f"- Data: `{form_data}`"
] + ["", f"```{tb}```"]
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
data = { data = {
@ -191,11 +187,12 @@ async def internal_server_error(request: Request, exc: Exception) -> Response:
logger.info(endp) logger.info(endp)
resp = requests.post(endp, json=data, headers=headers) resp = requests.post(endp, json=data, headers=headers)
if resp.status_code != http.HTTPStatus.CREATED: if resp.status_code != http.HTTPStatus.CREATED:
logger.error( logger.error(f"Unable to report exception to {repo}: {resp.text}")
f"Unable to report exception to {repo}: {resp.text}")
else: else:
logger.warning("Unable to report an exception found due to " logger.warning(
"unset notifications.error-{{project,token}}") "Unable to report an exception found due to "
"unset notifications.error-{{project,token}}"
)
# Log details about the exception traceback. # Log details about the exception traceback.
logger.error(f"FATAL[{tb_id}]: An unexpected exception has occurred.") logger.error(f"FATAL[{tb_id}]: An unexpected exception has occurred.")
@ -203,14 +200,17 @@ async def internal_server_error(request: Request, exc: Exception) -> Response:
else: else:
retval = retval.decode() retval = retval.decode()
return render_template(request, "errors/500.html", context, return render_template(
status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR) request,
"errors/500.html",
context,
status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR,
)
@app.exception_handler(StarletteHTTPException) @app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) \ async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
-> Response: """Handle an HTTPException thrown in a route."""
""" Handle an HTTPException thrown in a route. """
phrase = http.HTTPStatus(exc.status_code).phrase phrase = http.HTTPStatus(exc.status_code).phrase
context = make_context(request, phrase) context = make_context(request, phrase)
context["exc"] = exc context["exc"] = exc
@ -228,16 +228,16 @@ async def http_exception_handler(request: Request, exc: HTTPException) \
pass pass
try: try:
return render_template(request, f"errors/{exc.status_code}.html", return render_template(
context, exc.status_code) request, f"errors/{exc.status_code}.html", context, exc.status_code
)
except TemplateNotFound: except TemplateNotFound:
return render_template(request, "errors/detail.html", return render_template(request, "errors/detail.html", context, exc.status_code)
context, exc.status_code)
@app.middleware("http") @app.middleware("http")
async def add_security_headers(request: Request, call_next: typing.Callable): async def add_security_headers(request: Request, call_next: typing.Callable):
""" This middleware adds the CSP, XCTO, XFO and RP security """This middleware adds the CSP, XCTO, XFO and RP security
headers to the HTTP response associated with request. headers to the HTTP response associated with request.
CSP: Content-Security-Policy CSP: Content-Security-Policy
@ -254,7 +254,7 @@ async def add_security_headers(request: Request, call_next: typing.Callable):
nonce = request.user.nonce nonce = request.user.nonce
csp = "default-src 'self'; " csp = "default-src 'self'; "
script_hosts = [] script_hosts = []
csp += f"script-src 'self' 'nonce-{nonce}' " + ' '.join(script_hosts) csp += f"script-src 'self' 'nonce-{nonce}' " + " ".join(script_hosts)
# It's fine if css is inlined. # It's fine if css is inlined.
csp += "; style-src 'self' 'unsafe-inline'" csp += "; style-src 'self' 'unsafe-inline'"
response.headers["Content-Security-Policy"] = csp response.headers["Content-Security-Policy"] = csp
@ -276,17 +276,25 @@ async def add_security_headers(request: Request, call_next: typing.Callable):
@app.middleware("http") @app.middleware("http")
async def check_terms_of_service(request: Request, call_next: typing.Callable): async def check_terms_of_service(request: Request, call_next: typing.Callable):
""" This middleware function redirects authenticated users if they """This middleware function redirects authenticated users if they
have any outstanding Terms to agree to. """ have any outstanding Terms to agree to."""
if request.user.is_authenticated() and request.url.path != "/tos": if request.user.is_authenticated() and request.url.path != "/tos":
unaccepted = query(Term).join(AcceptedTerm).filter( unaccepted = (
or_(AcceptedTerm.UsersID != request.user.ID, query(Term)
and_(AcceptedTerm.UsersID == request.user.ID, .join(AcceptedTerm)
AcceptedTerm.TermsID == Term.ID, .filter(
AcceptedTerm.Revision < Term.Revision))) or_(
AcceptedTerm.UsersID != request.user.ID,
and_(
AcceptedTerm.UsersID == request.user.ID,
AcceptedTerm.TermsID == Term.ID,
AcceptedTerm.Revision < Term.Revision,
),
)
)
)
if query(Term).count() > unaccepted.count(): if query(Term).count() > unaccepted.count():
return RedirectResponse( return RedirectResponse("/tos", status_code=int(http.HTTPStatus.SEE_OTHER))
"/tos", status_code=int(http.HTTPStatus.SEE_OTHER))
return await util.error_or_result(call_next, request) return await util.error_or_result(call_next, request)
@ -301,9 +309,9 @@ async def id_redirect_middleware(request: Request, call_next: typing.Callable):
for k, v in request.query_params.items(): for k, v in request.query_params.items():
if k != "id": if k != "id":
qs.append(f"{k}={quote_plus(str(v))}") qs.append(f"{k}={quote_plus(str(v))}")
qs = str() if not qs else '?' + '&'.join(qs) qs = str() if not qs else "?" + "&".join(qs)
path = request.url.path.rstrip('/') path = request.url.path.rstrip("/")
return RedirectResponse(f"{path}/{id}{qs}") return RedirectResponse(f"{path}/{id}{qs}")
return await util.error_or_result(call_next, request) return await util.error_or_result(call_next, request)

View file

@ -1,25 +1,22 @@
import functools import functools
from http import HTTPStatus from http import HTTPStatus
from typing import Callable from typing import Callable
import fastapi import fastapi
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from starlette.authentication import AuthCredentials, AuthenticationBackend from starlette.authentication import AuthCredentials, AuthenticationBackend
from starlette.requests import HTTPConnection from starlette.requests import HTTPConnection
import aurweb.config import aurweb.config
from aurweb import db, filters, l10n, time, util from aurweb import db, filters, l10n, time, util
from aurweb.models import Session, User from aurweb.models import Session, User
from aurweb.models.account_type import ACCOUNT_TYPE_ID from aurweb.models.account_type import ACCOUNT_TYPE_ID
class StubQuery: class StubQuery:
""" Acts as a stubbed version of an orm.Query. Typically used """Acts as a stubbed version of an orm.Query. Typically used
to masquerade fake records for an AnonymousUser. """ to masquerade fake records for an AnonymousUser."""
def filter(self, *args): def filter(self, *args):
return StubQuery() return StubQuery()
@ -29,19 +26,21 @@ class StubQuery:
class AnonymousUser: class AnonymousUser:
""" A stubbed User class used when an unauthenticated User """A stubbed User class used when an unauthenticated User
makes a request against FastAPI. """ makes a request against FastAPI."""
# Stub attributes used to mimic a real user. # Stub attributes used to mimic a real user.
ID = 0 ID = 0
Username = "N/A" Username = "N/A"
Email = "N/A" Email = "N/A"
class AccountType: class AccountType:
""" A stubbed AccountType static class. In here, we use an ID """A stubbed AccountType static class. In here, we use an ID
and AccountType which do not exist in our constant records. and AccountType which do not exist in our constant records.
All records primary keys (AccountType.ID) should be non-zero, All records primary keys (AccountType.ID) should be non-zero,
so using a zero here means that we'll never match against a so using a zero here means that we'll never match against a
real AccountType. """ real AccountType."""
ID = 0 ID = 0
AccountType = "Anonymous" AccountType = "Anonymous"
@ -104,11 +103,11 @@ class BasicAuthBackend(AuthenticationBackend):
return unauthenticated return unauthenticated
timeout = aurweb.config.getint("options", "login_timeout") timeout = aurweb.config.getint("options", "login_timeout")
remembered = ("AURREMEMBER" in conn.cookies remembered = "AURREMEMBER" in conn.cookies and bool(
and bool(conn.cookies.get("AURREMEMBER"))) conn.cookies.get("AURREMEMBER")
)
if remembered: if remembered:
timeout = aurweb.config.getint("options", timeout = aurweb.config.getint("options", "persistent_cookie_timeout")
"persistent_cookie_timeout")
# If no session with sid and a LastUpdateTS now or later exists. # If no session with sid and a LastUpdateTS now or later exists.
now_ts = time.utcnow() now_ts = time.utcnow()
@ -160,40 +159,45 @@ def _auth_required(auth_goal: bool = True):
# page itself is not directly possible (e.g. submitting a form). # page itself is not directly possible (e.g. submitting a form).
if request.method in ("GET", "HEAD"): if request.method in ("GET", "HEAD"):
url = request.url.path url = request.url.path
elif (referer := request.headers.get("Referer")): elif referer := request.headers.get("Referer"):
aur = aurweb.config.get("options", "aur_location") + "/" aur = aurweb.config.get("options", "aur_location") + "/"
if not referer.startswith(aur): if not referer.startswith(aur):
_ = l10n.get_translator_for_request(request) _ = l10n.get_translator_for_request(request)
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, raise HTTPException(
detail=_("Bad Referer header.")) status_code=HTTPStatus.BAD_REQUEST,
url = referer[len(aur) - 1:] detail=_("Bad Referer header."),
)
url = referer[len(aur) - 1 :]
url = "/login?" + filters.urlencode({"next": url}) url = "/login?" + filters.urlencode({"next": url})
return RedirectResponse(url, status_code=int(HTTPStatus.SEE_OTHER)) return RedirectResponse(url, status_code=int(HTTPStatus.SEE_OTHER))
return wrapper return wrapper
return decorator return decorator
def requires_auth(func: Callable) -> Callable: def requires_auth(func: Callable) -> Callable:
""" Require an authenticated session for a particular route. """ """Require an authenticated session for a particular route."""
@functools.wraps(func) @functools.wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
return await _auth_required(True)(func)(*args, **kwargs) return await _auth_required(True)(func)(*args, **kwargs)
return wrapper return wrapper
def requires_guest(func: Callable) -> Callable: def requires_guest(func: Callable) -> Callable:
""" Require a guest (unauthenticated) session for a particular route. """ """Require a guest (unauthenticated) session for a particular route."""
@functools.wraps(func) @functools.wraps(func)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
return await _auth_required(False)(func)(*args, **kwargs) return await _auth_required(False)(func)(*args, **kwargs)
return wrapper return wrapper
def account_type_required(one_of: set): def account_type_required(one_of: set):
""" A decorator that can be used on FastAPI routes to dictate """A decorator that can be used on FastAPI routes to dictate
that a user belongs to one of the types defined in one_of. that a user belongs to one of the types defined in one_of.
This decorator should be run after an @auth_required(True) is This decorator should be run after an @auth_required(True) is
@ -211,18 +215,15 @@ def account_type_required(one_of: set):
:return: Return the FastAPI function this decorator wraps. :return: Return the FastAPI function this decorator wraps.
""" """
# Convert any account type string constants to their integer IDs. # Convert any account type string constants to their integer IDs.
one_of = { one_of = {ACCOUNT_TYPE_ID[atype] for atype in one_of if isinstance(atype, str)}
ACCOUNT_TYPE_ID[atype]
for atype in one_of
if isinstance(atype, str)
}
def decorator(func): def decorator(func):
@functools.wraps(func) @functools.wraps(func)
async def wrapper(request: fastapi.Request, *args, **kwargs): async def wrapper(request: fastapi.Request, *args, **kwargs):
if request.user.AccountTypeID not in one_of: if request.user.AccountTypeID not in one_of:
return RedirectResponse("/", return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER))
status_code=int(HTTPStatus.SEE_OTHER))
return await func(request, *args, **kwargs) return await func(request, *args, **kwargs)
return wrapper return wrapper
return decorator return decorator

View file

@ -1,4 +1,9 @@
from aurweb.models.account_type import DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID, TRUSTED_USER_ID, USER_ID from aurweb.models.account_type import (
DEVELOPER_ID,
TRUSTED_USER_AND_DEV_ID,
TRUSTED_USER_ID,
USER_ID,
)
from aurweb.models.user import User from aurweb.models.user import User
ACCOUNT_CHANGE_TYPE = 1 ACCOUNT_CHANGE_TYPE = 1
@ -30,7 +35,9 @@ TU_LIST_VOTES = 20
TU_VOTE = 21 TU_VOTE = 21
PKGBASE_MERGE = 29 PKGBASE_MERGE = 29
user_developer_or_trusted_user = set([USER_ID, TRUSTED_USER_ID, DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID]) user_developer_or_trusted_user = set(
[USER_ID, TRUSTED_USER_ID, DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID]
)
trusted_user_or_dev = set([TRUSTED_USER_ID, DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID]) trusted_user_or_dev = set([TRUSTED_USER_ID, DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID])
developer = set([DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID]) developer = set([DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID])
trusted_user = set([TRUSTED_USER_ID, TRUSTED_USER_AND_DEV_ID]) trusted_user = set([TRUSTED_USER_ID, TRUSTED_USER_AND_DEV_ID])
@ -67,9 +74,7 @@ cred_filters = {
} }
def has_credential(user: User, def has_credential(user: User, credential: int, approved: list = tuple()):
credential: int,
approved: list = tuple()):
if user in approved: if user in approved:
return True return True

View file

@ -6,16 +6,16 @@ class Benchmark:
self.start() self.start()
def _timestamp(self) -> float: def _timestamp(self) -> float:
""" Generate a timestamp. """ """Generate a timestamp."""
return float(datetime.utcnow().timestamp()) return float(datetime.utcnow().timestamp())
def start(self) -> int: def start(self) -> int:
""" Start a benchmark. """ """Start a benchmark."""
self.current = self._timestamp() self.current = self._timestamp()
return self.current return self.current
def end(self): def end(self):
""" Return the diff between now - start(). """ """Return the diff between now - start()."""
n = self._timestamp() - self.current n = self._timestamp() - self.current
self.current = float(0) self.current = float(0)
return n return n

View file

@ -2,9 +2,10 @@ from redis import Redis
from sqlalchemy import orm from sqlalchemy import orm
async def db_count_cache(redis: Redis, key: str, query: orm.Query, async def db_count_cache(
expire: int = None) -> int: redis: Redis, key: str, query: orm.Query, expire: int = None
""" Store and retrieve a query.count() via redis cache. ) -> int:
"""Store and retrieve a query.count() via redis cache.
:param redis: Redis handle :param redis: Redis handle
:param key: Redis key :param key: Redis key

View file

@ -9,7 +9,7 @@ from aurweb.templates import register_filter
def get_captcha_salts(): def get_captcha_salts():
""" Produce salts based on the current user count. """ """Produce salts based on the current user count."""
count = query(User).count() count = query(User).count()
salts = [] salts = []
for i in range(0, 6): for i in range(0, 6):
@ -18,19 +18,19 @@ def get_captcha_salts():
def get_captcha_token(salt): def get_captcha_token(salt):
""" Produce a token for the CAPTCHA salt. """ """Produce a token for the CAPTCHA salt."""
return hashlib.md5(salt.encode()).hexdigest()[:3] return hashlib.md5(salt.encode()).hexdigest()[:3]
def get_captcha_challenge(salt): def get_captcha_challenge(salt):
""" Get a CAPTCHA challenge string (shell command) for a salt. """ """Get a CAPTCHA challenge string (shell command) for a salt."""
token = get_captcha_token(salt) token = get_captcha_token(salt)
return f"LC_ALL=C pacman -V|sed -r 's#[0-9]+#{token}#g'|md5sum|cut -c1-6" return f"LC_ALL=C pacman -V|sed -r 's#[0-9]+#{token}#g'|md5sum|cut -c1-6"
def get_captcha_answer(token): def get_captcha_answer(token):
""" Compute the answer via md5 of the real template text, return the """Compute the answer via md5 of the real template text, return the
first six digits of the hexadecimal hash. """ first six digits of the hexadecimal hash."""
text = r""" text = r"""
.--. Pacman v%s.%s.%s - libalpm v%s.%s.%s .--. Pacman v%s.%s.%s - libalpm v%s.%s.%s
/ _.-' .-. .-. .-. Copyright (C) %s-%s Pacman Development Team / _.-' .-. .-. .-. Copyright (C) %s-%s Pacman Development Team
@ -38,14 +38,16 @@ def get_captcha_answer(token):
'--' '--'
This program may be freely redistributed under This program may be freely redistributed under
the terms of the GNU General Public License. the terms of the GNU General Public License.
""" % tuple([token] * 10) """ % tuple(
[token] * 10
)
return hashlib.md5((text + "\n").encode()).hexdigest()[:6] return hashlib.md5((text + "\n").encode()).hexdigest()[:6]
@register_filter("captcha_salt") @register_filter("captcha_salt")
@pass_context @pass_context
def captcha_salt_filter(context): def captcha_salt_filter(context):
""" Returns the most recent CAPTCHA salt in the list of salts. """ """Returns the most recent CAPTCHA salt in the list of salts."""
salts = get_captcha_salts() salts = get_captcha_salts()
return salts[0] return salts[0]
@ -53,5 +55,5 @@ def captcha_salt_filter(context):
@register_filter("captcha_cmdline") @register_filter("captcha_cmdline")
@pass_context @pass_context
def captcha_cmdline_filter(context, salt): def captcha_cmdline_filter(context, salt):
""" Returns a CAPTCHA challenge for a given salt. """ """Returns a CAPTCHA challenge for a given salt."""
return get_captcha_challenge(salt) return get_captcha_challenge(salt)

View file

@ -1,6 +1,5 @@
import configparser import configparser
import os import os
from typing import Any from typing import Any
# Publicly visible version of aurweb. This is used to display # Publicly visible version of aurweb. This is used to display
@ -15,8 +14,8 @@ def _get_parser():
global _parser global _parser
if not _parser: if not _parser:
path = os.environ.get('AUR_CONFIG', '/etc/aurweb/config') path = os.environ.get("AUR_CONFIG", "/etc/aurweb/config")
defaults = os.environ.get('AUR_CONFIG_DEFAULTS', path + '.defaults') defaults = os.environ.get("AUR_CONFIG_DEFAULTS", path + ".defaults")
_parser = configparser.RawConfigParser() _parser = configparser.RawConfigParser()
_parser.optionxform = lambda option: option _parser.optionxform = lambda option: option
@ -29,7 +28,7 @@ def _get_parser():
def rehash(): def rehash():
""" Globally rehash the configuration parser. """ """Globally rehash the configuration parser."""
global _parser global _parser
_parser = None _parser = None
_get_parser() _get_parser()

View file

@ -5,7 +5,7 @@ from aurweb import config
def samesite() -> str: def samesite() -> str:
""" Produce cookie SameSite value. """Produce cookie SameSite value.
Currently this is hard-coded to return "lax" Currently this is hard-coded to return "lax"
@ -15,7 +15,7 @@ def samesite() -> str:
def timeout(extended: bool) -> int: def timeout(extended: bool) -> int:
""" Produce a session timeout based on `remember_me`. """Produce a session timeout based on `remember_me`.
This method returns one of AUR_CONFIG's options.persistent_cookie_timeout This method returns one of AUR_CONFIG's options.persistent_cookie_timeout
and options.login_timeout based on the `extended` argument. and options.login_timeout based on the `extended` argument.
@ -35,10 +35,14 @@ def timeout(extended: bool) -> int:
return timeout return timeout
def update_response_cookies(request: Request, response: Response, def update_response_cookies(
aurtz: str = None, aurlang: str = None, request: Request,
aursid: str = None) -> Response: response: Response,
""" Update session cookies. This method is particularly useful aurtz: str = None,
aurlang: str = None,
aursid: str = None,
) -> Response:
"""Update session cookies. This method is particularly useful
when updating a cookie which was already set. when updating a cookie which was already set.
The AURSID cookie's expiration is based on the AURREMEMBER cookie, The AURSID cookie's expiration is based on the AURREMEMBER cookie,
@ -53,14 +57,21 @@ def update_response_cookies(request: Request, response: Response,
""" """
secure = config.getboolean("options", "disable_http_login") secure = config.getboolean("options", "disable_http_login")
if aurtz: if aurtz:
response.set_cookie("AURTZ", aurtz, secure=secure, httponly=secure, response.set_cookie(
samesite=samesite()) "AURTZ", aurtz, secure=secure, httponly=secure, samesite=samesite()
)
if aurlang: if aurlang:
response.set_cookie("AURLANG", aurlang, secure=secure, httponly=secure, response.set_cookie(
samesite=samesite()) "AURLANG", aurlang, secure=secure, httponly=secure, samesite=samesite()
)
if aursid: if aursid:
remember_me = bool(request.cookies.get("AURREMEMBER", False)) remember_me = bool(request.cookies.get("AURREMEMBER", False))
response.set_cookie("AURSID", aursid, secure=secure, httponly=secure, response.set_cookie(
max_age=timeout(remember_me), "AURSID",
samesite=samesite()) aursid,
secure=secure,
httponly=secure,
max_age=timeout(remember_me),
samesite=samesite(),
)
return response return response

View file

@ -1,15 +1,14 @@
# Supported database drivers. # Supported database drivers.
DRIVERS = { DRIVERS = {"mysql": "mysql+mysqldb"}
"mysql": "mysql+mysqldb"
}
def make_random_value(table: str, column: str, length: int): def make_random_value(table: str, column: str, length: int):
""" Generate a unique, random value for a string column in a table. """Generate a unique, random value for a string column in a table.
:return: A unique string that is not in the database :return: A unique string that is not in the database
""" """
import aurweb.util import aurweb.util
string = aurweb.util.make_random_string(length) string = aurweb.util.make_random_string(length)
while query(table).filter(column == string).first(): while query(table).filter(column == string).first():
string = aurweb.util.make_random_string(length) string = aurweb.util.make_random_string(length)
@ -37,8 +36,7 @@ def test_name() -> str:
import aurweb.config import aurweb.config
db = os.environ.get("PYTEST_CURRENT_TEST", db = os.environ.get("PYTEST_CURRENT_TEST", aurweb.config.get("database", "name"))
aurweb.config.get("database", "name"))
return db.split(":")[0] return db.split(":")[0]
@ -57,6 +55,7 @@ def name() -> str:
return dbname return dbname
import hashlib import hashlib
sha1 = hashlib.sha1(dbname.encode()).hexdigest() sha1 = hashlib.sha1(dbname.encode()).hexdigest()
return "db" + sha1 return "db" + sha1
@ -67,7 +66,7 @@ _sessions = dict()
def get_session(engine=None): def get_session(engine=None):
""" Return aurweb.db's global session. """ """Return aurweb.db's global session."""
dbname = name() dbname = name()
global _sessions global _sessions
@ -78,7 +77,8 @@ def get_session(engine=None):
engine = get_engine() engine = get_engine()
Session = scoped_session( Session = scoped_session(
sessionmaker(autocommit=True, autoflush=False, bind=engine)) sessionmaker(autocommit=True, autoflush=False, bind=engine)
)
_sessions[dbname] = Session() _sessions[dbname] = Session()
return _sessions.get(dbname) return _sessions.get(dbname)
@ -138,25 +138,26 @@ def delete(model) -> None:
def delete_all(iterable) -> None: def delete_all(iterable) -> None:
""" Delete each instance found in `iterable`. """ """Delete each instance found in `iterable`."""
import aurweb.util import aurweb.util
session_ = get_session() session_ = get_session()
aurweb.util.apply_all(iterable, session_.delete) aurweb.util.apply_all(iterable, session_.delete)
def rollback() -> None: def rollback() -> None:
""" Rollback the database session. """ """Rollback the database session."""
get_session().rollback() get_session().rollback()
def add(model): def add(model):
""" Add `model` to the database session. """ """Add `model` to the database session."""
get_session().add(model) get_session().add(model)
return model return model
def begin(): def begin():
""" Begin an SQLAlchemy SessionTransaction. """ """Begin an SQLAlchemy SessionTransaction."""
return get_session().begin() return get_session().begin()
@ -167,62 +168,62 @@ def get_sqlalchemy_url():
:return: sqlalchemy.engine.url.URL :return: sqlalchemy.engine.url.URL
""" """
import sqlalchemy import sqlalchemy
from sqlalchemy.engine.url import URL from sqlalchemy.engine.url import URL
import aurweb.config import aurweb.config
constructor = URL constructor = URL
parts = sqlalchemy.__version__.split('.') parts = sqlalchemy.__version__.split(".")
major = int(parts[0]) major = int(parts[0])
minor = int(parts[1]) minor = int(parts[1])
if major == 1 and minor >= 4: # pragma: no cover if major == 1 and minor >= 4: # pragma: no cover
constructor = URL.create constructor = URL.create
aur_db_backend = aurweb.config.get('database', 'backend') aur_db_backend = aurweb.config.get("database", "backend")
if aur_db_backend == 'mysql': if aur_db_backend == "mysql":
param_query = {} param_query = {}
port = aurweb.config.get_with_fallback("database", "port", None) port = aurweb.config.get_with_fallback("database", "port", None)
if not port: if not port:
param_query["unix_socket"] = aurweb.config.get( param_query["unix_socket"] = aurweb.config.get("database", "socket")
"database", "socket")
return constructor( return constructor(
DRIVERS.get(aur_db_backend), DRIVERS.get(aur_db_backend),
username=aurweb.config.get('database', 'user'), username=aurweb.config.get("database", "user"),
password=aurweb.config.get_with_fallback('database', 'password', password=aurweb.config.get_with_fallback(
fallback=None), "database", "password", fallback=None
host=aurweb.config.get('database', 'host'), ),
host=aurweb.config.get("database", "host"),
database=name(), database=name(),
port=port, port=port,
query=param_query query=param_query,
) )
elif aur_db_backend == 'sqlite': elif aur_db_backend == "sqlite":
return constructor( return constructor(
'sqlite', "sqlite",
database=aurweb.config.get('database', 'name'), database=aurweb.config.get("database", "name"),
) )
else: else:
raise ValueError('unsupported database backend') raise ValueError("unsupported database backend")
def sqlite_regexp(regex, item) -> bool: # pragma: no cover def sqlite_regexp(regex, item) -> bool: # pragma: no cover
""" Method which mimics SQL's REGEXP for SQLite. """ """Method which mimics SQL's REGEXP for SQLite."""
import re import re
return bool(re.search(regex, str(item))) return bool(re.search(regex, str(item)))
def setup_sqlite(engine) -> None: # pragma: no cover def setup_sqlite(engine) -> None: # pragma: no cover
""" Perform setup for an SQLite engine. """ """Perform setup for an SQLite engine."""
from sqlalchemy import event from sqlalchemy import event
@event.listens_for(engine, "connect") @event.listens_for(engine, "connect")
def do_begin(conn, record): def do_begin(conn, record):
import functools import functools
create_deterministic_function = functools.partial( create_deterministic_function = functools.partial(
conn.create_function, conn.create_function, deterministic=True
deterministic=True
) )
create_deterministic_function("REGEXP", 2, sqlite_regexp) create_deterministic_function("REGEXP", 2, sqlite_regexp)
@ -256,11 +257,9 @@ def get_engine(dbname: str = None, echo: bool = False):
if is_sqlite: # pragma: no cover if is_sqlite: # pragma: no cover
connect_args["check_same_thread"] = False connect_args["check_same_thread"] = False
kwargs = { kwargs = {"echo": echo, "connect_args": connect_args}
"echo": echo,
"connect_args": connect_args
}
from sqlalchemy import create_engine from sqlalchemy import create_engine
_engines[dbname] = create_engine(get_sqlalchemy_url(), **kwargs) _engines[dbname] = create_engine(get_sqlalchemy_url(), **kwargs)
if is_sqlite: # pragma: no cover if is_sqlite: # pragma: no cover
@ -281,7 +280,7 @@ def pop_engine(dbname: str) -> None:
def kill_engine() -> None: def kill_engine() -> None:
""" Close the current session and dispose of the engine. """ """Close the current session and dispose of the engine."""
dbname = name() dbname = name()
session = get_session() session = get_session()
@ -317,6 +316,7 @@ class ConnectionExecutor:
self._paramstyle = "format" self._paramstyle = "format"
elif backend == "sqlite": elif backend == "sqlite":
import sqlite3 import sqlite3
self._paramstyle = sqlite3.paramstyle self._paramstyle = sqlite3.paramstyle
def paramstyle(self): def paramstyle(self):
@ -325,12 +325,12 @@ class ConnectionExecutor:
def execute(self, query, params=()): # pragma: no cover def execute(self, query, params=()): # pragma: no cover
# TODO: SQLite support has been removed in FastAPI. It remains # TODO: SQLite support has been removed in FastAPI. It remains
# here to fund its support for PHP until it is removed. # here to fund its support for PHP until it is removed.
if self._paramstyle in ('format', 'pyformat'): if self._paramstyle in ("format", "pyformat"):
query = query.replace('%', '%%').replace('?', '%s') query = query.replace("%", "%%").replace("?", "%s")
elif self._paramstyle == 'qmark': elif self._paramstyle == "qmark":
pass pass
else: else:
raise ValueError('unsupported paramstyle') raise ValueError("unsupported paramstyle")
cur = self._conn.cursor() cur = self._conn.cursor()
cur.execute(query, params) cur.execute(query, params)
@ -350,32 +350,35 @@ class Connection:
def __init__(self): def __init__(self):
import aurweb.config import aurweb.config
aur_db_backend = aurweb.config.get('database', 'backend')
if aur_db_backend == 'mysql': aur_db_backend = aurweb.config.get("database", "backend")
if aur_db_backend == "mysql":
import MySQLdb import MySQLdb
aur_db_host = aurweb.config.get('database', 'host')
aur_db_host = aurweb.config.get("database", "host")
aur_db_name = name() aur_db_name = name()
aur_db_user = aurweb.config.get('database', 'user') aur_db_user = aurweb.config.get("database", "user")
aur_db_pass = aurweb.config.get_with_fallback( aur_db_pass = aurweb.config.get_with_fallback("database", "password", str())
'database', 'password', str()) aur_db_socket = aurweb.config.get("database", "socket")
aur_db_socket = aurweb.config.get('database', 'socket') self._conn = MySQLdb.connect(
self._conn = MySQLdb.connect(host=aur_db_host, host=aur_db_host,
user=aur_db_user, user=aur_db_user,
passwd=aur_db_pass, passwd=aur_db_pass,
db=aur_db_name, db=aur_db_name,
unix_socket=aur_db_socket) unix_socket=aur_db_socket,
elif aur_db_backend == 'sqlite': # pragma: no cover )
elif aur_db_backend == "sqlite": # pragma: no cover
# TODO: SQLite support has been removed in FastAPI. It remains # TODO: SQLite support has been removed in FastAPI. It remains
# here to fund its support for PHP until it is removed. # here to fund its support for PHP until it is removed.
import math import math
import sqlite3 import sqlite3
aur_db_name = aurweb.config.get('database', 'name') aur_db_name = aurweb.config.get("database", "name")
self._conn = sqlite3.connect(aur_db_name) self._conn = sqlite3.connect(aur_db_name)
self._conn.create_function("POWER", 2, math.pow) self._conn.create_function("POWER", 2, math.pow)
else: else:
raise ValueError('unsupported database backend') raise ValueError("unsupported database backend")
self._conn = ConnectionExecutor(self._conn, aur_db_backend) self._conn = ConnectionExecutor(self._conn, aur_db_backend)

View file

@ -17,8 +17,8 @@ RPC_SEARCH_BY = "name-desc"
def fallback_pp(per_page: int) -> int: def fallback_pp(per_page: int) -> int:
""" If `per_page` is a valid value in PP_WHITELIST, return it. """If `per_page` is a valid value in PP_WHITELIST, return it.
Otherwise, return defaults.PP. """ Otherwise, return defaults.PP."""
if per_page not in PP_WHITELIST: if per_page not in PP_WHITELIST:
return PP return PP
return per_page return per_page

View file

@ -1,5 +1,4 @@
import functools import functools
from typing import Any, Callable from typing import Any, Callable
import fastapi import fastapi
@ -19,61 +18,61 @@ class BannedException(AurwebException):
class PermissionDeniedException(AurwebException): class PermissionDeniedException(AurwebException):
def __init__(self, user): def __init__(self, user):
msg = 'permission denied: {:s}'.format(user) msg = "permission denied: {:s}".format(user)
super(PermissionDeniedException, self).__init__(msg) super(PermissionDeniedException, self).__init__(msg)
class BrokenUpdateHookException(AurwebException): class BrokenUpdateHookException(AurwebException):
def __init__(self, cmd): def __init__(self, cmd):
msg = 'broken update hook: {:s}'.format(cmd) msg = "broken update hook: {:s}".format(cmd)
super(BrokenUpdateHookException, self).__init__(msg) super(BrokenUpdateHookException, self).__init__(msg)
class InvalidUserException(AurwebException): class InvalidUserException(AurwebException):
def __init__(self, user): def __init__(self, user):
msg = 'unknown user: {:s}'.format(user) msg = "unknown user: {:s}".format(user)
super(InvalidUserException, self).__init__(msg) super(InvalidUserException, self).__init__(msg)
class InvalidPackageBaseException(AurwebException): class InvalidPackageBaseException(AurwebException):
def __init__(self, pkgbase): def __init__(self, pkgbase):
msg = 'package base not found: {:s}'.format(pkgbase) msg = "package base not found: {:s}".format(pkgbase)
super(InvalidPackageBaseException, self).__init__(msg) super(InvalidPackageBaseException, self).__init__(msg)
class InvalidRepositoryNameException(AurwebException): class InvalidRepositoryNameException(AurwebException):
def __init__(self, pkgbase): def __init__(self, pkgbase):
msg = 'invalid repository name: {:s}'.format(pkgbase) msg = "invalid repository name: {:s}".format(pkgbase)
super(InvalidRepositoryNameException, self).__init__(msg) super(InvalidRepositoryNameException, self).__init__(msg)
class PackageBaseExistsException(AurwebException): class PackageBaseExistsException(AurwebException):
def __init__(self, pkgbase): def __init__(self, pkgbase):
msg = 'package base already exists: {:s}'.format(pkgbase) msg = "package base already exists: {:s}".format(pkgbase)
super(PackageBaseExistsException, self).__init__(msg) super(PackageBaseExistsException, self).__init__(msg)
class InvalidReasonException(AurwebException): class InvalidReasonException(AurwebException):
def __init__(self, reason): def __init__(self, reason):
msg = 'invalid reason: {:s}'.format(reason) msg = "invalid reason: {:s}".format(reason)
super(InvalidReasonException, self).__init__(msg) super(InvalidReasonException, self).__init__(msg)
class InvalidCommentException(AurwebException): class InvalidCommentException(AurwebException):
def __init__(self, comment): def __init__(self, comment):
msg = 'comment is too short: {:s}'.format(comment) msg = "comment is too short: {:s}".format(comment)
super(InvalidCommentException, self).__init__(msg) super(InvalidCommentException, self).__init__(msg)
class AlreadyVotedException(AurwebException): class AlreadyVotedException(AurwebException):
def __init__(self, comment): def __init__(self, comment):
msg = 'already voted for package base: {:s}'.format(comment) msg = "already voted for package base: {:s}".format(comment)
super(AlreadyVotedException, self).__init__(msg) super(AlreadyVotedException, self).__init__(msg)
class NotVotedException(AurwebException): class NotVotedException(AurwebException):
def __init__(self, comment): def __init__(self, comment):
msg = 'missing vote for package base: {:s}'.format(comment) msg = "missing vote for package base: {:s}".format(comment)
super(NotVotedException, self).__init__(msg) super(NotVotedException, self).__init__(msg)
@ -109,4 +108,5 @@ def handle_form_exceptions(route: Callable) -> fastapi.Response:
async def wrapper(request: fastapi.Request, *args, **kwargs): async def wrapper(request: fastapi.Request, *args, **kwargs):
request.state.form_data = await request.form() request.state.form_data = await request.form()
return await route(request, *args, **kwargs) return await route(request, *args, **kwargs)
return wrapper return wrapper

View file

@ -1,6 +1,5 @@
import copy import copy
import math import math
from datetime import datetime from datetime import datetime
from typing import Any, Union from typing import Any, Union
from urllib.parse import quote_plus, urlencode from urllib.parse import quote_plus, urlencode
@ -8,19 +7,16 @@ from zoneinfo import ZoneInfo
import fastapi import fastapi
import paginate import paginate
from jinja2 import pass_context from jinja2 import pass_context
import aurweb.models import aurweb.models
from aurweb import config, l10n from aurweb import config, l10n
from aurweb.templates import register_filter, register_function from aurweb.templates import register_filter, register_function
@register_filter("pager_nav") @register_filter("pager_nav")
@pass_context @pass_context
def pager_nav(context: dict[str, Any], def pager_nav(context: dict[str, Any], page: int, total: int, prefix: str) -> str:
page: int, total: int, prefix: str) -> str:
page = int(page) # Make sure this is an int. page = int(page) # Make sure this is an int.
pp = context.get("PP", 50) pp = context.get("PP", 50)
@ -43,10 +39,9 @@ def pager_nav(context: dict[str, Any],
return f"{prefix}?{qs}" return f"{prefix}?{qs}"
# Use the paginate module to produce our linkage. # Use the paginate module to produce our linkage.
pager = paginate.Page([], page=page + 1, pager = paginate.Page(
items_per_page=pp, [], page=page + 1, items_per_page=pp, item_count=total, url_maker=create_url
item_count=total, )
url_maker=create_url)
return pager.pager( return pager.pager(
link_attr={"class": "page"}, link_attr={"class": "page"},
@ -56,7 +51,8 @@ def pager_nav(context: dict[str, Any],
symbol_first="« First", symbol_first="« First",
symbol_previous=" Previous", symbol_previous=" Previous",
symbol_next="Next ", symbol_next="Next ",
symbol_last="Last »") symbol_last="Last »",
)
@register_function("config_getint") @register_function("config_getint")
@ -72,16 +68,15 @@ def do_round(f: float) -> int:
@register_filter("tr") @register_filter("tr")
@pass_context @pass_context
def tr(context: dict[str, Any], value: str): def tr(context: dict[str, Any], value: str):
""" A translation filter; example: {{ "Hello" | tr("de") }}. """ """A translation filter; example: {{ "Hello" | tr("de") }}."""
_ = l10n.get_translator_for_request(context.get("request")) _ = l10n.get_translator_for_request(context.get("request"))
return _(value) return _(value)
@register_filter("tn") @register_filter("tn")
@pass_context @pass_context
def tn(context: dict[str, Any], count: int, def tn(context: dict[str, Any], count: int, singular: str, plural: str) -> str:
singular: str, plural: str) -> str: """A singular and plural translation filter.
""" A singular and plural translation filter.
Example: Example:
{{ some_integer | tn("singular %d", "plural %d") }} {{ some_integer | tn("singular %d", "plural %d") }}
@ -108,7 +103,7 @@ def as_timezone(dt: datetime, timezone: str):
@register_filter("extend_query") @register_filter("extend_query")
def extend_query(query: dict[str, Any], *additions) -> dict[str, Any]: def extend_query(query: dict[str, Any], *additions) -> dict[str, Any]:
""" Add additional key value pairs to query. """ """Add additional key value pairs to query."""
q = copy.copy(query) q = copy.copy(query)
for k, v in list(additions): for k, v in list(additions):
q[k] = v q[k] = v
@ -123,19 +118,19 @@ def to_qs(query: dict[str, Any]) -> str:
@register_filter("get_vote") @register_filter("get_vote")
def get_vote(voteinfo, request: fastapi.Request): def get_vote(voteinfo, request: fastapi.Request):
from aurweb.models import TUVote from aurweb.models import TUVote
return voteinfo.tu_votes.filter(TUVote.User == request.user).first() return voteinfo.tu_votes.filter(TUVote.User == request.user).first()
@register_filter("number_format") @register_filter("number_format")
def number_format(value: float, places: int): def number_format(value: float, places: int):
""" A converter function similar to PHP's number_format. """ """A converter function similar to PHP's number_format."""
return f"{value:.{places}f}" return f"{value:.{places}f}"
@register_filter("account_url") @register_filter("account_url")
@pass_context @pass_context
def account_url(context: dict[str, Any], def account_url(context: dict[str, Any], user: "aurweb.models.user.User") -> str:
user: "aurweb.models.user.User") -> str:
base = aurweb.config.get("options", "aur_location") base = aurweb.config.get("options", "aur_location")
return f"{base}/account/{user.Username}" return f"{base}/account/{user.Username}"
@ -152,8 +147,7 @@ def ceil(*args, **kwargs) -> int:
@register_function("date_strftime") @register_function("date_strftime")
@pass_context @pass_context
def date_strftime(context: dict[str, Any], dt: Union[int, datetime], fmt: str) \ def date_strftime(context: dict[str, Any], dt: Union[int, datetime], fmt: str) -> str:
-> str:
if isinstance(dt, int): if isinstance(dt, int):
dt = timestamp_to_datetime(dt) dt = timestamp_to_datetime(dt)
tz = context.get("timezone") tz = context.get("timezone")

View file

@ -9,12 +9,12 @@ import aurweb.db
def format_command(env_vars, command, ssh_opts, ssh_key): def format_command(env_vars, command, ssh_opts, ssh_key):
environment = '' environment = ""
for key, var in env_vars.items(): for key, var in env_vars.items():
environment += '{}={} '.format(key, shlex.quote(var)) environment += "{}={} ".format(key, shlex.quote(var))
command = shlex.quote(command) command = shlex.quote(command)
command = '{}{}'.format(environment, command) command = "{}{}".format(environment, command)
# The command is being substituted into an authorized_keys line below, # The command is being substituted into an authorized_keys line below,
# so we need to escape the double quotes. # so we need to escape the double quotes.
@ -24,10 +24,10 @@ def format_command(env_vars, command, ssh_opts, ssh_key):
def main(): def main():
valid_keytypes = aurweb.config.get('auth', 'valid-keytypes').split() valid_keytypes = aurweb.config.get("auth", "valid-keytypes").split()
username_regex = aurweb.config.get('auth', 'username-regex') username_regex = aurweb.config.get("auth", "username-regex")
git_serve_cmd = aurweb.config.get('auth', 'git-serve-cmd') git_serve_cmd = aurweb.config.get("auth", "git-serve-cmd")
ssh_opts = aurweb.config.get('auth', 'ssh-options') ssh_opts = aurweb.config.get("auth", "ssh-options")
keytype = sys.argv[1] keytype = sys.argv[1]
keytext = sys.argv[2] keytext = sys.argv[2]
@ -36,11 +36,13 @@ def main():
conn = aurweb.db.Connection() conn = aurweb.db.Connection()
cur = conn.execute("SELECT Users.Username, Users.AccountTypeID FROM Users " cur = conn.execute(
"INNER JOIN SSHPubKeys ON SSHPubKeys.UserID = Users.ID " "SELECT Users.Username, Users.AccountTypeID FROM Users "
"WHERE SSHPubKeys.PubKey = ? AND Users.Suspended = 0 " "INNER JOIN SSHPubKeys ON SSHPubKeys.UserID = Users.ID "
"AND NOT Users.Passwd = ''", "WHERE SSHPubKeys.PubKey = ? AND Users.Suspended = 0 "
(keytype + " " + keytext,)) "AND NOT Users.Passwd = ''",
(keytype + " " + keytext,),
)
row = cur.fetchone() row = cur.fetchone()
if not row or cur.fetchone(): if not row or cur.fetchone():
@ -51,13 +53,13 @@ def main():
exit(1) exit(1)
env_vars = { env_vars = {
'AUR_USER': user, "AUR_USER": user,
'AUR_PRIVILEGED': '1' if account_type > 1 else '0', "AUR_PRIVILEGED": "1" if account_type > 1 else "0",
} }
key = keytype + ' ' + keytext key = keytype + " " + keytext
print(format_command(env_vars, git_serve_cmd, ssh_opts, key)) print(format_command(env_vars, git_serve_cmd, ssh_opts, key))
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -11,16 +11,16 @@ import aurweb.config
import aurweb.db import aurweb.db
import aurweb.exceptions import aurweb.exceptions
notify_cmd = aurweb.config.get('notifications', 'notify-cmd') notify_cmd = aurweb.config.get("notifications", "notify-cmd")
repo_path = aurweb.config.get('serve', 'repo-path') repo_path = aurweb.config.get("serve", "repo-path")
repo_regex = aurweb.config.get('serve', 'repo-regex') repo_regex = aurweb.config.get("serve", "repo-regex")
git_shell_cmd = aurweb.config.get('serve', 'git-shell-cmd') git_shell_cmd = aurweb.config.get("serve", "git-shell-cmd")
git_update_cmd = aurweb.config.get('serve', 'git-update-cmd') git_update_cmd = aurweb.config.get("serve", "git-update-cmd")
ssh_cmdline = aurweb.config.get('serve', 'ssh-cmdline') ssh_cmdline = aurweb.config.get("serve", "ssh-cmdline")
enable_maintenance = aurweb.config.getboolean('options', 'enable-maintenance') enable_maintenance = aurweb.config.getboolean("options", "enable-maintenance")
maintenance_exc = aurweb.config.get('options', 'maintenance-exceptions').split() maintenance_exc = aurweb.config.get("options", "maintenance-exceptions").split()
def pkgbase_from_name(pkgbase): def pkgbase_from_name(pkgbase):
@ -43,10 +43,12 @@ def list_repos(user):
if userid == 0: if userid == 0:
raise aurweb.exceptions.InvalidUserException(user) raise aurweb.exceptions.InvalidUserException(user)
cur = conn.execute("SELECT Name, PackagerUID FROM PackageBases " + cur = conn.execute(
"WHERE MaintainerUID = ?", [userid]) "SELECT Name, PackagerUID FROM PackageBases " + "WHERE MaintainerUID = ?",
[userid],
)
for row in cur: for row in cur:
print((' ' if row[1] else '*') + row[0]) print((" " if row[1] else "*") + row[0])
conn.close() conn.close()
@ -64,15 +66,18 @@ def create_pkgbase(pkgbase, user):
raise aurweb.exceptions.InvalidUserException(user) raise aurweb.exceptions.InvalidUserException(user)
now = int(time.time()) now = int(time.time())
cur = conn.execute("INSERT INTO PackageBases (Name, SubmittedTS, " + cur = conn.execute(
"ModifiedTS, SubmitterUID, MaintainerUID, " + "INSERT INTO PackageBases (Name, SubmittedTS, "
"FlaggerComment) VALUES (?, ?, ?, ?, ?, '')", + "ModifiedTS, SubmitterUID, MaintainerUID, "
[pkgbase, now, now, userid, userid]) + "FlaggerComment) VALUES (?, ?, ?, ?, ?, '')",
[pkgbase, now, now, userid, userid],
)
pkgbase_id = cur.lastrowid pkgbase_id = cur.lastrowid
cur = conn.execute("INSERT INTO PackageNotifications " + cur = conn.execute(
"(PackageBaseID, UserID) VALUES (?, ?)", "INSERT INTO PackageNotifications " + "(PackageBaseID, UserID) VALUES (?, ?)",
[pkgbase_id, userid]) [pkgbase_id, userid],
)
conn.commit() conn.commit()
conn.close() conn.close()
@ -85,8 +90,10 @@ def pkgbase_adopt(pkgbase, user, privileged):
conn = aurweb.db.Connection() conn = aurweb.db.Connection()
cur = conn.execute("SELECT ID FROM PackageBases WHERE ID = ? AND " + cur = conn.execute(
"MaintainerUID IS NULL", [pkgbase_id]) "SELECT ID FROM PackageBases WHERE ID = ? AND " + "MaintainerUID IS NULL",
[pkgbase_id],
)
if not privileged and not cur.fetchone(): if not privileged and not cur.fetchone():
raise aurweb.exceptions.PermissionDeniedException(user) raise aurweb.exceptions.PermissionDeniedException(user)
@ -95,19 +102,25 @@ def pkgbase_adopt(pkgbase, user, privileged):
if userid == 0: if userid == 0:
raise aurweb.exceptions.InvalidUserException(user) raise aurweb.exceptions.InvalidUserException(user)
cur = conn.execute("UPDATE PackageBases SET MaintainerUID = ? " + cur = conn.execute(
"WHERE ID = ?", [userid, pkgbase_id]) "UPDATE PackageBases SET MaintainerUID = ? " + "WHERE ID = ?",
[userid, pkgbase_id],
)
cur = conn.execute("SELECT COUNT(*) FROM PackageNotifications WHERE " + cur = conn.execute(
"PackageBaseID = ? AND UserID = ?", "SELECT COUNT(*) FROM PackageNotifications WHERE "
[pkgbase_id, userid]) + "PackageBaseID = ? AND UserID = ?",
[pkgbase_id, userid],
)
if cur.fetchone()[0] == 0: if cur.fetchone()[0] == 0:
cur = conn.execute("INSERT INTO PackageNotifications " + cur = conn.execute(
"(PackageBaseID, UserID) VALUES (?, ?)", "INSERT INTO PackageNotifications "
[pkgbase_id, userid]) + "(PackageBaseID, UserID) VALUES (?, ?)",
[pkgbase_id, userid],
)
conn.commit() conn.commit()
subprocess.Popen((notify_cmd, 'adopt', str(userid), str(pkgbase_id))) subprocess.Popen((notify_cmd, "adopt", str(userid), str(pkgbase_id)))
conn.close() conn.close()
@ -115,13 +128,16 @@ def pkgbase_adopt(pkgbase, user, privileged):
def pkgbase_get_comaintainers(pkgbase): def pkgbase_get_comaintainers(pkgbase):
conn = aurweb.db.Connection() conn = aurweb.db.Connection()
cur = conn.execute("SELECT UserName FROM PackageComaintainers " + cur = conn.execute(
"INNER JOIN Users " + "SELECT UserName FROM PackageComaintainers "
"ON Users.ID = PackageComaintainers.UsersID " + + "INNER JOIN Users "
"INNER JOIN PackageBases " + + "ON Users.ID = PackageComaintainers.UsersID "
"ON PackageBases.ID = PackageComaintainers.PackageBaseID " + + "INNER JOIN PackageBases "
"WHERE PackageBases.Name = ? " + + "ON PackageBases.ID = PackageComaintainers.PackageBaseID "
"ORDER BY Priority ASC", [pkgbase]) + "WHERE PackageBases.Name = ? "
+ "ORDER BY Priority ASC",
[pkgbase],
)
return [row[0] for row in cur.fetchall()] return [row[0] for row in cur.fetchall()]
@ -140,8 +156,7 @@ def pkgbase_set_comaintainers(pkgbase, userlist, user, privileged):
uids_old = set() uids_old = set()
for olduser in userlist_old: for olduser in userlist_old:
cur = conn.execute("SELECT ID FROM Users WHERE Username = ?", cur = conn.execute("SELECT ID FROM Users WHERE Username = ?", [olduser])
[olduser])
userid = cur.fetchone()[0] userid = cur.fetchone()[0]
if userid == 0: if userid == 0:
raise aurweb.exceptions.InvalidUserException(user) raise aurweb.exceptions.InvalidUserException(user)
@ -149,8 +164,7 @@ def pkgbase_set_comaintainers(pkgbase, userlist, user, privileged):
uids_new = set() uids_new = set()
for newuser in userlist: for newuser in userlist:
cur = conn.execute("SELECT ID FROM Users WHERE Username = ?", cur = conn.execute("SELECT ID FROM Users WHERE Username = ?", [newuser])
[newuser])
userid = cur.fetchone()[0] userid = cur.fetchone()[0]
if userid == 0: if userid == 0:
raise aurweb.exceptions.InvalidUserException(user) raise aurweb.exceptions.InvalidUserException(user)
@ -162,24 +176,33 @@ def pkgbase_set_comaintainers(pkgbase, userlist, user, privileged):
i = 1 i = 1
for userid in uids_new: for userid in uids_new:
if userid in uids_add: if userid in uids_add:
cur = conn.execute("INSERT INTO PackageComaintainers " + cur = conn.execute(
"(PackageBaseID, UsersID, Priority) " + "INSERT INTO PackageComaintainers "
"VALUES (?, ?, ?)", [pkgbase_id, userid, i]) + "(PackageBaseID, UsersID, Priority) "
subprocess.Popen((notify_cmd, 'comaintainer-add', str(userid), + "VALUES (?, ?, ?)",
str(pkgbase_id))) [pkgbase_id, userid, i],
)
subprocess.Popen(
(notify_cmd, "comaintainer-add", str(userid), str(pkgbase_id))
)
else: else:
cur = conn.execute("UPDATE PackageComaintainers " + cur = conn.execute(
"SET Priority = ? " + "UPDATE PackageComaintainers "
"WHERE PackageBaseID = ? AND UsersID = ?", + "SET Priority = ? "
[i, pkgbase_id, userid]) + "WHERE PackageBaseID = ? AND UsersID = ?",
[i, pkgbase_id, userid],
)
i += 1 i += 1
for userid in uids_rem: for userid in uids_rem:
cur = conn.execute("DELETE FROM PackageComaintainers " + cur = conn.execute(
"WHERE PackageBaseID = ? AND UsersID = ?", "DELETE FROM PackageComaintainers "
[pkgbase_id, userid]) + "WHERE PackageBaseID = ? AND UsersID = ?",
subprocess.Popen((notify_cmd, 'comaintainer-remove', [pkgbase_id, userid],
str(userid), str(pkgbase_id))) )
subprocess.Popen(
(notify_cmd, "comaintainer-remove", str(userid), str(pkgbase_id))
)
conn.commit() conn.commit()
conn.close() conn.close()
@ -188,18 +211,21 @@ def pkgbase_set_comaintainers(pkgbase, userlist, user, privileged):
def pkgreq_by_pkgbase(pkgbase_id, reqtype): def pkgreq_by_pkgbase(pkgbase_id, reqtype):
conn = aurweb.db.Connection() conn = aurweb.db.Connection()
cur = conn.execute("SELECT PackageRequests.ID FROM PackageRequests " + cur = conn.execute(
"INNER JOIN RequestTypes ON " + "SELECT PackageRequests.ID FROM PackageRequests "
"RequestTypes.ID = PackageRequests.ReqTypeID " + + "INNER JOIN RequestTypes ON "
"WHERE PackageRequests.Status = 0 " + + "RequestTypes.ID = PackageRequests.ReqTypeID "
"AND PackageRequests.PackageBaseID = ? " + + "WHERE PackageRequests.Status = 0 "
"AND RequestTypes.Name = ?", [pkgbase_id, reqtype]) + "AND PackageRequests.PackageBaseID = ? "
+ "AND RequestTypes.Name = ?",
[pkgbase_id, reqtype],
)
return [row[0] for row in cur.fetchall()] return [row[0] for row in cur.fetchall()]
def pkgreq_close(reqid, user, reason, comments, autoclose=False): def pkgreq_close(reqid, user, reason, comments, autoclose=False):
statusmap = {'accepted': 2, 'rejected': 3} statusmap = {"accepted": 2, "rejected": 3}
if reason not in statusmap: if reason not in statusmap:
raise aurweb.exceptions.InvalidReasonException(reason) raise aurweb.exceptions.InvalidReasonException(reason)
status = statusmap[reason] status = statusmap[reason]
@ -215,16 +241,20 @@ def pkgreq_close(reqid, user, reason, comments, autoclose=False):
raise aurweb.exceptions.InvalidUserException(user) raise aurweb.exceptions.InvalidUserException(user)
now = int(time.time()) now = int(time.time())
conn.execute("UPDATE PackageRequests SET Status = ?, ClosedTS = ?, " + conn.execute(
"ClosedUID = ?, ClosureComment = ? " + "UPDATE PackageRequests SET Status = ?, ClosedTS = ?, "
"WHERE ID = ?", [status, now, userid, comments, reqid]) + "ClosedUID = ?, ClosureComment = ? "
+ "WHERE ID = ?",
[status, now, userid, comments, reqid],
)
conn.commit() conn.commit()
conn.close() conn.close()
if not userid: if not userid:
userid = 0 userid = 0
subprocess.Popen((notify_cmd, 'request-close', str(userid), str(reqid), subprocess.Popen(
reason)).wait() (notify_cmd, "request-close", str(userid), str(reqid), reason)
).wait()
def pkgbase_disown(pkgbase, user, privileged): def pkgbase_disown(pkgbase, user, privileged):
@ -239,9 +269,9 @@ def pkgbase_disown(pkgbase, user, privileged):
# TODO: Support disowning package bases via package request. # TODO: Support disowning package bases via package request.
# Scan through pending orphan requests and close them. # Scan through pending orphan requests and close them.
comment = 'The user {:s} disowned the package.'.format(user) comment = "The user {:s} disowned the package.".format(user)
for reqid in pkgreq_by_pkgbase(pkgbase_id, 'orphan'): for reqid in pkgreq_by_pkgbase(pkgbase_id, "orphan"):
pkgreq_close(reqid, user, 'accepted', comment, True) pkgreq_close(reqid, user, "accepted", comment, True)
comaintainers = [] comaintainers = []
new_maintainer_userid = None new_maintainer_userid = None
@ -254,14 +284,17 @@ def pkgbase_disown(pkgbase, user, privileged):
comaintainers = pkgbase_get_comaintainers(pkgbase) comaintainers = pkgbase_get_comaintainers(pkgbase)
if len(comaintainers) > 0: if len(comaintainers) > 0:
new_maintainer = comaintainers[0] new_maintainer = comaintainers[0]
cur = conn.execute("SELECT ID FROM Users WHERE Username = ?", cur = conn.execute(
[new_maintainer]) "SELECT ID FROM Users WHERE Username = ?", [new_maintainer]
)
new_maintainer_userid = cur.fetchone()[0] new_maintainer_userid = cur.fetchone()[0]
comaintainers.remove(new_maintainer) comaintainers.remove(new_maintainer)
pkgbase_set_comaintainers(pkgbase, comaintainers, user, privileged) pkgbase_set_comaintainers(pkgbase, comaintainers, user, privileged)
cur = conn.execute("UPDATE PackageBases SET MaintainerUID = ? " + cur = conn.execute(
"WHERE ID = ?", [new_maintainer_userid, pkgbase_id]) "UPDATE PackageBases SET MaintainerUID = ? " + "WHERE ID = ?",
[new_maintainer_userid, pkgbase_id],
)
conn.commit() conn.commit()
@ -270,7 +303,7 @@ def pkgbase_disown(pkgbase, user, privileged):
if userid == 0: if userid == 0:
raise aurweb.exceptions.InvalidUserException(user) raise aurweb.exceptions.InvalidUserException(user)
subprocess.Popen((notify_cmd, 'disown', str(userid), str(pkgbase_id))) subprocess.Popen((notify_cmd, "disown", str(userid), str(pkgbase_id)))
conn.close() conn.close()
@ -290,14 +323,16 @@ def pkgbase_flag(pkgbase, user, comment):
raise aurweb.exceptions.InvalidUserException(user) raise aurweb.exceptions.InvalidUserException(user)
now = int(time.time()) now = int(time.time())
conn.execute("UPDATE PackageBases SET " + conn.execute(
"OutOfDateTS = ?, FlaggerUID = ?, FlaggerComment = ? " + "UPDATE PackageBases SET "
"WHERE ID = ? AND OutOfDateTS IS NULL", + "OutOfDateTS = ?, FlaggerUID = ?, FlaggerComment = ? "
[now, userid, comment, pkgbase_id]) + "WHERE ID = ? AND OutOfDateTS IS NULL",
[now, userid, comment, pkgbase_id],
)
conn.commit() conn.commit()
subprocess.Popen((notify_cmd, 'flag', str(userid), str(pkgbase_id))) subprocess.Popen((notify_cmd, "flag", str(userid), str(pkgbase_id)))
def pkgbase_unflag(pkgbase, user): def pkgbase_unflag(pkgbase, user):
@ -313,12 +348,15 @@ def pkgbase_unflag(pkgbase, user):
raise aurweb.exceptions.InvalidUserException(user) raise aurweb.exceptions.InvalidUserException(user)
if user in pkgbase_get_comaintainers(pkgbase): if user in pkgbase_get_comaintainers(pkgbase):
conn.execute("UPDATE PackageBases SET OutOfDateTS = NULL " + conn.execute(
"WHERE ID = ?", [pkgbase_id]) "UPDATE PackageBases SET OutOfDateTS = NULL " + "WHERE ID = ?", [pkgbase_id]
)
else: else:
conn.execute("UPDATE PackageBases SET OutOfDateTS = NULL " + conn.execute(
"WHERE ID = ? AND (MaintainerUID = ? OR FlaggerUID = ?)", "UPDATE PackageBases SET OutOfDateTS = NULL "
[pkgbase_id, userid, userid]) + "WHERE ID = ? AND (MaintainerUID = ? OR FlaggerUID = ?)",
[pkgbase_id, userid, userid],
)
conn.commit() conn.commit()
@ -335,17 +373,24 @@ def pkgbase_vote(pkgbase, user):
if userid == 0: if userid == 0:
raise aurweb.exceptions.InvalidUserException(user) raise aurweb.exceptions.InvalidUserException(user)
cur = conn.execute("SELECT COUNT(*) FROM PackageVotes " + cur = conn.execute(
"WHERE UsersID = ? AND PackageBaseID = ?", "SELECT COUNT(*) FROM PackageVotes "
[userid, pkgbase_id]) + "WHERE UsersID = ? AND PackageBaseID = ?",
[userid, pkgbase_id],
)
if cur.fetchone()[0] > 0: if cur.fetchone()[0] > 0:
raise aurweb.exceptions.AlreadyVotedException(pkgbase) raise aurweb.exceptions.AlreadyVotedException(pkgbase)
now = int(time.time()) now = int(time.time())
conn.execute("INSERT INTO PackageVotes (UsersID, PackageBaseID, VoteTS) " + conn.execute(
"VALUES (?, ?, ?)", [userid, pkgbase_id, now]) "INSERT INTO PackageVotes (UsersID, PackageBaseID, VoteTS) "
conn.execute("UPDATE PackageBases SET NumVotes = NumVotes + 1 " + + "VALUES (?, ?, ?)",
"WHERE ID = ?", [pkgbase_id]) [userid, pkgbase_id, now],
)
conn.execute(
"UPDATE PackageBases SET NumVotes = NumVotes + 1 " + "WHERE ID = ?",
[pkgbase_id],
)
conn.commit() conn.commit()
@ -361,16 +406,22 @@ def pkgbase_unvote(pkgbase, user):
if userid == 0: if userid == 0:
raise aurweb.exceptions.InvalidUserException(user) raise aurweb.exceptions.InvalidUserException(user)
cur = conn.execute("SELECT COUNT(*) FROM PackageVotes " + cur = conn.execute(
"WHERE UsersID = ? AND PackageBaseID = ?", "SELECT COUNT(*) FROM PackageVotes "
[userid, pkgbase_id]) + "WHERE UsersID = ? AND PackageBaseID = ?",
[userid, pkgbase_id],
)
if cur.fetchone()[0] == 0: if cur.fetchone()[0] == 0:
raise aurweb.exceptions.NotVotedException(pkgbase) raise aurweb.exceptions.NotVotedException(pkgbase)
conn.execute("DELETE FROM PackageVotes WHERE UsersID = ? AND " + conn.execute(
"PackageBaseID = ?", [userid, pkgbase_id]) "DELETE FROM PackageVotes WHERE UsersID = ? AND " + "PackageBaseID = ?",
conn.execute("UPDATE PackageBases SET NumVotes = NumVotes - 1 " + [userid, pkgbase_id],
"WHERE ID = ?", [pkgbase_id]) )
conn.execute(
"UPDATE PackageBases SET NumVotes = NumVotes - 1 " + "WHERE ID = ?",
[pkgbase_id],
)
conn.commit() conn.commit()
@ -381,11 +432,12 @@ def pkgbase_set_keywords(pkgbase, keywords):
conn = aurweb.db.Connection() conn = aurweb.db.Connection()
conn.execute("DELETE FROM PackageKeywords WHERE PackageBaseID = ?", conn.execute("DELETE FROM PackageKeywords WHERE PackageBaseID = ?", [pkgbase_id])
[pkgbase_id])
for keyword in keywords: for keyword in keywords:
conn.execute("INSERT INTO PackageKeywords (PackageBaseID, Keyword) " + conn.execute(
"VALUES (?, ?)", [pkgbase_id, keyword]) "INSERT INTO PackageKeywords (PackageBaseID, Keyword) " + "VALUES (?, ?)",
[pkgbase_id, keyword],
)
conn.commit() conn.commit()
conn.close() conn.close()
@ -394,24 +446,30 @@ def pkgbase_set_keywords(pkgbase, keywords):
def pkgbase_has_write_access(pkgbase, user): def pkgbase_has_write_access(pkgbase, user):
conn = aurweb.db.Connection() conn = aurweb.db.Connection()
cur = conn.execute("SELECT COUNT(*) FROM PackageBases " + cur = conn.execute(
"LEFT JOIN PackageComaintainers " + "SELECT COUNT(*) FROM PackageBases "
"ON PackageComaintainers.PackageBaseID = PackageBases.ID " + + "LEFT JOIN PackageComaintainers "
"INNER JOIN Users " + + "ON PackageComaintainers.PackageBaseID = PackageBases.ID "
"ON Users.ID = PackageBases.MaintainerUID " + + "INNER JOIN Users "
"OR PackageBases.MaintainerUID IS NULL " + + "ON Users.ID = PackageBases.MaintainerUID "
"OR Users.ID = PackageComaintainers.UsersID " + + "OR PackageBases.MaintainerUID IS NULL "
"WHERE Name = ? AND Username = ?", [pkgbase, user]) + "OR Users.ID = PackageComaintainers.UsersID "
+ "WHERE Name = ? AND Username = ?",
[pkgbase, user],
)
return cur.fetchone()[0] > 0 return cur.fetchone()[0] > 0
def pkgbase_has_full_access(pkgbase, user): def pkgbase_has_full_access(pkgbase, user):
conn = aurweb.db.Connection() conn = aurweb.db.Connection()
cur = conn.execute("SELECT COUNT(*) FROM PackageBases " + cur = conn.execute(
"INNER JOIN Users " + "SELECT COUNT(*) FROM PackageBases "
"ON Users.ID = PackageBases.MaintainerUID " + + "INNER JOIN Users "
"WHERE Name = ? AND Username = ?", [pkgbase, user]) + "ON Users.ID = PackageBases.MaintainerUID "
+ "WHERE Name = ? AND Username = ?",
[pkgbase, user],
)
return cur.fetchone()[0] > 0 return cur.fetchone()[0] > 0
@ -419,9 +477,11 @@ def log_ssh_login(user, remote_addr):
conn = aurweb.db.Connection() conn = aurweb.db.Connection()
now = int(time.time()) now = int(time.time())
conn.execute("UPDATE Users SET LastSSHLogin = ?, " + conn.execute(
"LastSSHLoginIPAddress = ? WHERE Username = ?", "UPDATE Users SET LastSSHLogin = ?, "
[now, remote_addr, user]) + "LastSSHLoginIPAddress = ? WHERE Username = ?",
[now, remote_addr, user],
)
conn.commit() conn.commit()
conn.close() conn.close()
@ -430,8 +490,7 @@ def log_ssh_login(user, remote_addr):
def bans_match(remote_addr): def bans_match(remote_addr):
conn = aurweb.db.Connection() conn = aurweb.db.Connection()
cur = conn.execute("SELECT COUNT(*) FROM Bans WHERE IPAddress = ?", cur = conn.execute("SELECT COUNT(*) FROM Bans WHERE IPAddress = ?", [remote_addr])
[remote_addr])
return cur.fetchone()[0] > 0 return cur.fetchone()[0] > 0
@ -458,13 +517,13 @@ def usage(cmds):
def checkarg_atleast(cmdargv, *argdesc): def checkarg_atleast(cmdargv, *argdesc):
if len(cmdargv) - 1 < len(argdesc): if len(cmdargv) - 1 < len(argdesc):
msg = 'missing {:s}'.format(argdesc[len(cmdargv) - 1]) msg = "missing {:s}".format(argdesc[len(cmdargv) - 1])
raise aurweb.exceptions.InvalidArgumentsException(msg) raise aurweb.exceptions.InvalidArgumentsException(msg)
def checkarg_atmost(cmdargv, *argdesc): def checkarg_atmost(cmdargv, *argdesc):
if len(cmdargv) - 1 > len(argdesc): if len(cmdargv) - 1 > len(argdesc):
raise aurweb.exceptions.InvalidArgumentsException('too many arguments') raise aurweb.exceptions.InvalidArgumentsException("too many arguments")
def checkarg(cmdargv, *argdesc): def checkarg(cmdargv, *argdesc):
@ -480,23 +539,23 @@ def serve(action, cmdargv, user, privileged, remote_addr): # noqa: C901
raise aurweb.exceptions.BannedException raise aurweb.exceptions.BannedException
log_ssh_login(user, remote_addr) log_ssh_login(user, remote_addr)
if action == 'git' and cmdargv[1] in ('upload-pack', 'receive-pack'): if action == "git" and cmdargv[1] in ("upload-pack", "receive-pack"):
action = action + '-' + cmdargv[1] action = action + "-" + cmdargv[1]
del cmdargv[1] del cmdargv[1]
if action == 'git-upload-pack' or action == 'git-receive-pack': if action == "git-upload-pack" or action == "git-receive-pack":
checkarg(cmdargv, 'path') checkarg(cmdargv, "path")
path = cmdargv[1].rstrip('/') path = cmdargv[1].rstrip("/")
if not path.startswith('/'): if not path.startswith("/"):
path = '/' + path path = "/" + path
if not path.endswith('.git'): if not path.endswith(".git"):
path = path + '.git' path = path + ".git"
pkgbase = path[1:-4] pkgbase = path[1:-4]
if not re.match(repo_regex, pkgbase): if not re.match(repo_regex, pkgbase):
raise aurweb.exceptions.InvalidRepositoryNameException(pkgbase) raise aurweb.exceptions.InvalidRepositoryNameException(pkgbase)
if action == 'git-receive-pack' and pkgbase_exists(pkgbase): if action == "git-receive-pack" and pkgbase_exists(pkgbase):
if not privileged and not pkgbase_has_write_access(pkgbase, user): if not privileged and not pkgbase_has_write_access(pkgbase, user):
raise aurweb.exceptions.PermissionDeniedException(user) raise aurweb.exceptions.PermissionDeniedException(user)
@ -507,65 +566,67 @@ def serve(action, cmdargv, user, privileged, remote_addr): # noqa: C901
os.environ["AUR_PKGBASE"] = pkgbase os.environ["AUR_PKGBASE"] = pkgbase
os.environ["GIT_NAMESPACE"] = pkgbase os.environ["GIT_NAMESPACE"] = pkgbase
cmd = action + " '" + repo_path + "'" cmd = action + " '" + repo_path + "'"
os.execl(git_shell_cmd, git_shell_cmd, '-c', cmd) os.execl(git_shell_cmd, git_shell_cmd, "-c", cmd)
elif action == 'set-keywords': elif action == "set-keywords":
checkarg_atleast(cmdargv, 'repository name') checkarg_atleast(cmdargv, "repository name")
pkgbase_set_keywords(cmdargv[1], cmdargv[2:]) pkgbase_set_keywords(cmdargv[1], cmdargv[2:])
elif action == 'list-repos': elif action == "list-repos":
checkarg(cmdargv) checkarg(cmdargv)
list_repos(user) list_repos(user)
elif action == 'setup-repo': elif action == "setup-repo":
checkarg(cmdargv, 'repository name') checkarg(cmdargv, "repository name")
warn('{:s} is deprecated. ' warn(
'Use `git push` to create new repositories.'.format(action)) "{:s} is deprecated. "
"Use `git push` to create new repositories.".format(action)
)
create_pkgbase(cmdargv[1], user) create_pkgbase(cmdargv[1], user)
elif action == 'restore': elif action == "restore":
checkarg(cmdargv, 'repository name') checkarg(cmdargv, "repository name")
pkgbase = cmdargv[1] pkgbase = cmdargv[1]
create_pkgbase(pkgbase, user) create_pkgbase(pkgbase, user)
os.environ["AUR_USER"] = user os.environ["AUR_USER"] = user
os.environ["AUR_PKGBASE"] = pkgbase os.environ["AUR_PKGBASE"] = pkgbase
os.execl(git_update_cmd, git_update_cmd, 'restore') os.execl(git_update_cmd, git_update_cmd, "restore")
elif action == 'adopt': elif action == "adopt":
checkarg(cmdargv, 'repository name') checkarg(cmdargv, "repository name")
pkgbase = cmdargv[1] pkgbase = cmdargv[1]
pkgbase_adopt(pkgbase, user, privileged) pkgbase_adopt(pkgbase, user, privileged)
elif action == 'disown': elif action == "disown":
checkarg(cmdargv, 'repository name') checkarg(cmdargv, "repository name")
pkgbase = cmdargv[1] pkgbase = cmdargv[1]
pkgbase_disown(pkgbase, user, privileged) pkgbase_disown(pkgbase, user, privileged)
elif action == 'flag': elif action == "flag":
checkarg(cmdargv, 'repository name', 'comment') checkarg(cmdargv, "repository name", "comment")
pkgbase = cmdargv[1] pkgbase = cmdargv[1]
comment = cmdargv[2] comment = cmdargv[2]
pkgbase_flag(pkgbase, user, comment) pkgbase_flag(pkgbase, user, comment)
elif action == 'unflag': elif action == "unflag":
checkarg(cmdargv, 'repository name') checkarg(cmdargv, "repository name")
pkgbase = cmdargv[1] pkgbase = cmdargv[1]
pkgbase_unflag(pkgbase, user) pkgbase_unflag(pkgbase, user)
elif action == 'vote': elif action == "vote":
checkarg(cmdargv, 'repository name') checkarg(cmdargv, "repository name")
pkgbase = cmdargv[1] pkgbase = cmdargv[1]
pkgbase_vote(pkgbase, user) pkgbase_vote(pkgbase, user)
elif action == 'unvote': elif action == "unvote":
checkarg(cmdargv, 'repository name') checkarg(cmdargv, "repository name")
pkgbase = cmdargv[1] pkgbase = cmdargv[1]
pkgbase_unvote(pkgbase, user) pkgbase_unvote(pkgbase, user)
elif action == 'set-comaintainers': elif action == "set-comaintainers":
checkarg_atleast(cmdargv, 'repository name') checkarg_atleast(cmdargv, "repository name")
pkgbase = cmdargv[1] pkgbase = cmdargv[1]
userlist = cmdargv[2:] userlist = cmdargv[2:]
pkgbase_set_comaintainers(pkgbase, userlist, user, privileged) pkgbase_set_comaintainers(pkgbase, userlist, user, privileged)
elif action == 'help': elif action == "help":
cmds = { cmds = {
"adopt <name>": "Adopt a package base.", "adopt <name>": "Adopt a package base.",
"disown <name>": "Disown a package base.", "disown <name>": "Disown a package base.",
@ -584,21 +645,21 @@ def serve(action, cmdargv, user, privileged, remote_addr): # noqa: C901
} }
usage(cmds) usage(cmds)
else: else:
msg = 'invalid command: {:s}'.format(action) msg = "invalid command: {:s}".format(action)
raise aurweb.exceptions.InvalidArgumentsException(msg) raise aurweb.exceptions.InvalidArgumentsException(msg)
def main(): def main():
user = os.environ.get('AUR_USER') user = os.environ.get("AUR_USER")
privileged = (os.environ.get('AUR_PRIVILEGED', '0') == '1') privileged = os.environ.get("AUR_PRIVILEGED", "0") == "1"
ssh_cmd = os.environ.get('SSH_ORIGINAL_COMMAND') ssh_cmd = os.environ.get("SSH_ORIGINAL_COMMAND")
ssh_client = os.environ.get('SSH_CLIENT') ssh_client = os.environ.get("SSH_CLIENT")
if not ssh_cmd: if not ssh_cmd:
die_with_help("Interactive shell is disabled.") die_with_help("Interactive shell is disabled.")
cmdargv = shlex.split(ssh_cmd) cmdargv = shlex.split(ssh_cmd)
action = cmdargv[0] action = cmdargv[0]
remote_addr = ssh_client.split(' ')[0] if ssh_client else None remote_addr = ssh_client.split(" ")[0] if ssh_client else None
try: try:
serve(action, cmdargv, user, privileged, remote_addr) serve(action, cmdargv, user, privileged, remote_addr)
@ -607,10 +668,10 @@ def main():
except aurweb.exceptions.BannedException: except aurweb.exceptions.BannedException:
die("The SSH interface is disabled for your IP address.") die("The SSH interface is disabled for your IP address.")
except aurweb.exceptions.InvalidArgumentsException as e: except aurweb.exceptions.InvalidArgumentsException as e:
die_with_help('{:s}: {}'.format(action, e)) die_with_help("{:s}: {}".format(action, e))
except aurweb.exceptions.AurwebException as e: except aurweb.exceptions.AurwebException as e:
die('{:s}: {}'.format(action, e)) die("{:s}: {}".format(action, e))
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -13,23 +13,23 @@ import srcinfo.utils
import aurweb.config import aurweb.config
import aurweb.db import aurweb.db
notify_cmd = aurweb.config.get('notifications', 'notify-cmd') notify_cmd = aurweb.config.get("notifications", "notify-cmd")
repo_path = aurweb.config.get('serve', 'repo-path') repo_path = aurweb.config.get("serve", "repo-path")
repo_regex = aurweb.config.get('serve', 'repo-regex') repo_regex = aurweb.config.get("serve", "repo-regex")
max_blob_size = aurweb.config.getint('update', 'max-blob-size') max_blob_size = aurweb.config.getint("update", "max-blob-size")
def size_humanize(num): def size_humanize(num):
for unit in ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB']: for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB"]:
if abs(num) < 2048.0: if abs(num) < 2048.0:
if isinstance(num, int): if isinstance(num, int):
return "{}{}".format(num, unit) return "{}{}".format(num, unit)
else: else:
return "{:.2f}{}".format(num, unit) return "{:.2f}{}".format(num, unit)
num /= 1024.0 num /= 1024.0
return "{:.2f}{}".format(num, 'YiB') return "{:.2f}{}".format(num, "YiB")
def extract_arch_fields(pkginfo, field): def extract_arch_fields(pkginfo, field):
@ -39,18 +39,18 @@ def extract_arch_fields(pkginfo, field):
for val in pkginfo[field]: for val in pkginfo[field]:
values.append({"value": val, "arch": None}) values.append({"value": val, "arch": None})
for arch in pkginfo['arch']: for arch in pkginfo["arch"]:
if field + '_' + arch in pkginfo: if field + "_" + arch in pkginfo:
for val in pkginfo[field + '_' + arch]: for val in pkginfo[field + "_" + arch]:
values.append({"value": val, "arch": arch}) values.append({"value": val, "arch": arch})
return values return values
def parse_dep(depstring): def parse_dep(depstring):
dep, _, desc = depstring.partition(': ') dep, _, desc = depstring.partition(": ")
depname = re.sub(r'(<|=|>).*', '', dep) depname = re.sub(r"(<|=|>).*", "", dep)
depcond = dep[len(depname):] depcond = dep[len(depname) :]
return (depname, desc, depcond) return (depname, desc, depcond)
@ -60,15 +60,18 @@ def create_pkgbase(conn, pkgbase, user):
userid = cur.fetchone()[0] userid = cur.fetchone()[0]
now = int(time.time()) now = int(time.time())
cur = conn.execute("INSERT INTO PackageBases (Name, SubmittedTS, " + cur = conn.execute(
"ModifiedTS, SubmitterUID, MaintainerUID, " + "INSERT INTO PackageBases (Name, SubmittedTS, "
"FlaggerComment) VALUES (?, ?, ?, ?, ?, '')", + "ModifiedTS, SubmitterUID, MaintainerUID, "
[pkgbase, now, now, userid, userid]) + "FlaggerComment) VALUES (?, ?, ?, ?, ?, '')",
[pkgbase, now, now, userid, userid],
)
pkgbase_id = cur.lastrowid pkgbase_id = cur.lastrowid
cur = conn.execute("INSERT INTO PackageNotifications " + cur = conn.execute(
"(PackageBaseID, UserID) VALUES (?, ?)", "INSERT INTO PackageNotifications " + "(PackageBaseID, UserID) VALUES (?, ?)",
[pkgbase_id, userid]) [pkgbase_id, userid],
)
conn.commit() conn.commit()
@ -77,9 +80,10 @@ def create_pkgbase(conn, pkgbase, user):
def save_metadata(metadata, conn, user): # noqa: C901 def save_metadata(metadata, conn, user): # noqa: C901
# Obtain package base ID and previous maintainer. # Obtain package base ID and previous maintainer.
pkgbase = metadata['pkgbase'] pkgbase = metadata["pkgbase"]
cur = conn.execute("SELECT ID, MaintainerUID FROM PackageBases " cur = conn.execute(
"WHERE Name = ?", [pkgbase]) "SELECT ID, MaintainerUID FROM PackageBases " "WHERE Name = ?", [pkgbase]
)
(pkgbase_id, maintainer_uid) = cur.fetchone() (pkgbase_id, maintainer_uid) = cur.fetchone()
was_orphan = not maintainer_uid was_orphan = not maintainer_uid
@ -89,119 +93,142 @@ def save_metadata(metadata, conn, user): # noqa: C901
# Update package base details and delete current packages. # Update package base details and delete current packages.
now = int(time.time()) now = int(time.time())
conn.execute("UPDATE PackageBases SET ModifiedTS = ?, " + conn.execute(
"PackagerUID = ?, OutOfDateTS = NULL WHERE ID = ?", "UPDATE PackageBases SET ModifiedTS = ?, "
[now, user_id, pkgbase_id]) + "PackagerUID = ?, OutOfDateTS = NULL WHERE ID = ?",
conn.execute("UPDATE PackageBases SET MaintainerUID = ? " + [now, user_id, pkgbase_id],
"WHERE ID = ? AND MaintainerUID IS NULL", )
[user_id, pkgbase_id]) conn.execute(
for table in ('Sources', 'Depends', 'Relations', 'Licenses', 'Groups'): "UPDATE PackageBases SET MaintainerUID = ? "
conn.execute("DELETE FROM Package" + table + " WHERE EXISTS (" + + "WHERE ID = ? AND MaintainerUID IS NULL",
"SELECT * FROM Packages " + [user_id, pkgbase_id],
"WHERE Packages.PackageBaseID = ? AND " + )
"Package" + table + ".PackageID = Packages.ID)", for table in ("Sources", "Depends", "Relations", "Licenses", "Groups"):
[pkgbase_id]) conn.execute(
"DELETE FROM Package"
+ table
+ " WHERE EXISTS ("
+ "SELECT * FROM Packages "
+ "WHERE Packages.PackageBaseID = ? AND "
+ "Package"
+ table
+ ".PackageID = Packages.ID)",
[pkgbase_id],
)
conn.execute("DELETE FROM Packages WHERE PackageBaseID = ?", [pkgbase_id]) conn.execute("DELETE FROM Packages WHERE PackageBaseID = ?", [pkgbase_id])
for pkgname in srcinfo.utils.get_package_names(metadata): for pkgname in srcinfo.utils.get_package_names(metadata):
pkginfo = srcinfo.utils.get_merged_package(pkgname, metadata) pkginfo = srcinfo.utils.get_merged_package(pkgname, metadata)
if 'epoch' in pkginfo and int(pkginfo['epoch']) > 0: if "epoch" in pkginfo and int(pkginfo["epoch"]) > 0:
ver = '{:d}:{:s}-{:s}'.format(int(pkginfo['epoch']), ver = "{:d}:{:s}-{:s}".format(
pkginfo['pkgver'], int(pkginfo["epoch"]), pkginfo["pkgver"], pkginfo["pkgrel"]
pkginfo['pkgrel']) )
else: else:
ver = '{:s}-{:s}'.format(pkginfo['pkgver'], pkginfo['pkgrel']) ver = "{:s}-{:s}".format(pkginfo["pkgver"], pkginfo["pkgrel"])
for field in ('pkgdesc', 'url'): for field in ("pkgdesc", "url"):
if field not in pkginfo: if field not in pkginfo:
pkginfo[field] = None pkginfo[field] = None
# Create a new package. # Create a new package.
cur = conn.execute("INSERT INTO Packages (PackageBaseID, Name, " + cur = conn.execute(
"Version, Description, URL) " + "INSERT INTO Packages (PackageBaseID, Name, "
"VALUES (?, ?, ?, ?, ?)", + "Version, Description, URL) "
[pkgbase_id, pkginfo['pkgname'], ver, + "VALUES (?, ?, ?, ?, ?)",
pkginfo['pkgdesc'], pkginfo['url']]) [pkgbase_id, pkginfo["pkgname"], ver, pkginfo["pkgdesc"], pkginfo["url"]],
)
conn.commit() conn.commit()
pkgid = cur.lastrowid pkgid = cur.lastrowid
# Add package sources. # Add package sources.
for source_info in extract_arch_fields(pkginfo, 'source'): for source_info in extract_arch_fields(pkginfo, "source"):
conn.execute("INSERT INTO PackageSources (PackageID, Source, " + conn.execute(
"SourceArch) VALUES (?, ?, ?)", "INSERT INTO PackageSources (PackageID, Source, "
[pkgid, source_info['value'], source_info['arch']]) + "SourceArch) VALUES (?, ?, ?)",
[pkgid, source_info["value"], source_info["arch"]],
)
# Add package dependencies. # Add package dependencies.
for deptype in ('depends', 'makedepends', for deptype in ("depends", "makedepends", "checkdepends", "optdepends"):
'checkdepends', 'optdepends'): cur = conn.execute(
cur = conn.execute("SELECT ID FROM DependencyTypes WHERE Name = ?", "SELECT ID FROM DependencyTypes WHERE Name = ?", [deptype]
[deptype]) )
deptypeid = cur.fetchone()[0] deptypeid = cur.fetchone()[0]
for dep_info in extract_arch_fields(pkginfo, deptype): for dep_info in extract_arch_fields(pkginfo, deptype):
depname, depdesc, depcond = parse_dep(dep_info['value']) depname, depdesc, depcond = parse_dep(dep_info["value"])
deparch = dep_info['arch'] deparch = dep_info["arch"]
conn.execute("INSERT INTO PackageDepends (PackageID, " + conn.execute(
"DepTypeID, DepName, DepDesc, DepCondition, " + "INSERT INTO PackageDepends (PackageID, "
"DepArch) VALUES (?, ?, ?, ?, ?, ?)", + "DepTypeID, DepName, DepDesc, DepCondition, "
[pkgid, deptypeid, depname, depdesc, depcond, + "DepArch) VALUES (?, ?, ?, ?, ?, ?)",
deparch]) [pkgid, deptypeid, depname, depdesc, depcond, deparch],
)
# Add package relations (conflicts, provides, replaces). # Add package relations (conflicts, provides, replaces).
for reltype in ('conflicts', 'provides', 'replaces'): for reltype in ("conflicts", "provides", "replaces"):
cur = conn.execute("SELECT ID FROM RelationTypes WHERE Name = ?", cur = conn.execute("SELECT ID FROM RelationTypes WHERE Name = ?", [reltype])
[reltype])
reltypeid = cur.fetchone()[0] reltypeid = cur.fetchone()[0]
for rel_info in extract_arch_fields(pkginfo, reltype): for rel_info in extract_arch_fields(pkginfo, reltype):
relname, _, relcond = parse_dep(rel_info['value']) relname, _, relcond = parse_dep(rel_info["value"])
relarch = rel_info['arch'] relarch = rel_info["arch"]
conn.execute("INSERT INTO PackageRelations (PackageID, " + conn.execute(
"RelTypeID, RelName, RelCondition, RelArch) " + "INSERT INTO PackageRelations (PackageID, "
"VALUES (?, ?, ?, ?, ?)", + "RelTypeID, RelName, RelCondition, RelArch) "
[pkgid, reltypeid, relname, relcond, relarch]) + "VALUES (?, ?, ?, ?, ?)",
[pkgid, reltypeid, relname, relcond, relarch],
)
# Add package licenses. # Add package licenses.
if 'license' in pkginfo: if "license" in pkginfo:
for license in pkginfo['license']: for license in pkginfo["license"]:
cur = conn.execute("SELECT ID FROM Licenses WHERE Name = ?", cur = conn.execute("SELECT ID FROM Licenses WHERE Name = ?", [license])
[license])
row = cur.fetchone() row = cur.fetchone()
if row: if row:
licenseid = row[0] licenseid = row[0]
else: else:
cur = conn.execute("INSERT INTO Licenses (Name) " + cur = conn.execute(
"VALUES (?)", [license]) "INSERT INTO Licenses (Name) " + "VALUES (?)", [license]
)
conn.commit() conn.commit()
licenseid = cur.lastrowid licenseid = cur.lastrowid
conn.execute("INSERT INTO PackageLicenses (PackageID, " + conn.execute(
"LicenseID) VALUES (?, ?)", "INSERT INTO PackageLicenses (PackageID, "
[pkgid, licenseid]) + "LicenseID) VALUES (?, ?)",
[pkgid, licenseid],
)
# Add package groups. # Add package groups.
if 'groups' in pkginfo: if "groups" in pkginfo:
for group in pkginfo['groups']: for group in pkginfo["groups"]:
cur = conn.execute("SELECT ID FROM `Groups` WHERE Name = ?", cur = conn.execute("SELECT ID FROM `Groups` WHERE Name = ?", [group])
[group])
row = cur.fetchone() row = cur.fetchone()
if row: if row:
groupid = row[0] groupid = row[0]
else: else:
cur = conn.execute("INSERT INTO `Groups` (Name) VALUES (?)", cur = conn.execute(
[group]) "INSERT INTO `Groups` (Name) VALUES (?)", [group]
)
conn.commit() conn.commit()
groupid = cur.lastrowid groupid = cur.lastrowid
conn.execute("INSERT INTO PackageGroups (PackageID, " conn.execute(
"GroupID) VALUES (?, ?)", [pkgid, groupid]) "INSERT INTO PackageGroups (PackageID, " "GroupID) VALUES (?, ?)",
[pkgid, groupid],
)
# Add user to notification list on adoption. # Add user to notification list on adoption.
if was_orphan: if was_orphan:
cur = conn.execute("SELECT COUNT(*) FROM PackageNotifications WHERE " + cur = conn.execute(
"PackageBaseID = ? AND UserID = ?", "SELECT COUNT(*) FROM PackageNotifications WHERE "
[pkgbase_id, user_id]) + "PackageBaseID = ? AND UserID = ?",
[pkgbase_id, user_id],
)
if cur.fetchone()[0] == 0: if cur.fetchone()[0] == 0:
conn.execute("INSERT INTO PackageNotifications " + conn.execute(
"(PackageBaseID, UserID) VALUES (?, ?)", "INSERT INTO PackageNotifications "
[pkgbase_id, user_id]) + "(PackageBaseID, UserID) VALUES (?, ?)",
[pkgbase_id, user_id],
)
conn.commit() conn.commit()
@ -212,7 +239,7 @@ def update_notify(conn, user, pkgbase_id):
user_id = int(cur.fetchone()[0]) user_id = int(cur.fetchone()[0])
# Execute the notification script. # Execute the notification script.
subprocess.Popen((notify_cmd, 'update', str(user_id), str(pkgbase_id))) subprocess.Popen((notify_cmd, "update", str(user_id), str(pkgbase_id)))
def die(msg): def die(msg):
@ -225,8 +252,7 @@ def warn(msg):
def die_commit(msg, commit): def die_commit(msg, commit):
sys.stderr.write("error: The following error " + sys.stderr.write("error: The following error " + "occurred when parsing commit\n")
"occurred when parsing commit\n")
sys.stderr.write("error: {:s}:\n".format(commit)) sys.stderr.write("error: {:s}:\n".format(commit))
sys.stderr.write("error: {:s}\n".format(msg)) sys.stderr.write("error: {:s}\n".format(msg))
exit(1) exit(1)
@ -237,16 +263,15 @@ def main(): # noqa: C901
user = os.environ.get("AUR_USER") user = os.environ.get("AUR_USER")
pkgbase = os.environ.get("AUR_PKGBASE") pkgbase = os.environ.get("AUR_PKGBASE")
privileged = (os.environ.get("AUR_PRIVILEGED", '0') == '1') privileged = os.environ.get("AUR_PRIVILEGED", "0") == "1"
allow_overwrite = (os.environ.get("AUR_OVERWRITE", '0') == '1') and privileged allow_overwrite = (os.environ.get("AUR_OVERWRITE", "0") == "1") and privileged
warn_or_die = warn if privileged else die warn_or_die = warn if privileged else die
if len(sys.argv) == 2 and sys.argv[1] == "restore": if len(sys.argv) == 2 and sys.argv[1] == "restore":
if 'refs/heads/' + pkgbase not in repo.listall_references(): if "refs/heads/" + pkgbase not in repo.listall_references():
die('{:s}: repository not found: {:s}'.format(sys.argv[1], die("{:s}: repository not found: {:s}".format(sys.argv[1], pkgbase))
pkgbase))
refname = "refs/heads/master" refname = "refs/heads/master"
branchref = 'refs/heads/' + pkgbase branchref = "refs/heads/" + pkgbase
sha1_old = sha1_new = repo.lookup_reference(branchref).target sha1_old = sha1_new = repo.lookup_reference(branchref).target
elif len(sys.argv) == 4: elif len(sys.argv) == 4:
refname, sha1_old, sha1_new = sys.argv[1:4] refname, sha1_old, sha1_new = sys.argv[1:4]
@ -272,7 +297,7 @@ def main(): # noqa: C901
# Validate all new commits. # Validate all new commits.
for commit in walker: for commit in walker:
for fname in ('.SRCINFO', 'PKGBUILD'): for fname in (".SRCINFO", "PKGBUILD"):
if fname not in commit.tree: if fname not in commit.tree:
die_commit("missing {:s}".format(fname), str(commit.id)) die_commit("missing {:s}".format(fname), str(commit.id))
@ -280,99 +305,115 @@ def main(): # noqa: C901
blob = repo[treeobj.id] blob = repo[treeobj.id]
if isinstance(blob, pygit2.Tree): if isinstance(blob, pygit2.Tree):
die_commit("the repository must not contain subdirectories", die_commit(
str(commit.id)) "the repository must not contain subdirectories", str(commit.id)
)
if not isinstance(blob, pygit2.Blob): if not isinstance(blob, pygit2.Blob):
die_commit("not a blob object: {:s}".format(treeobj), die_commit("not a blob object: {:s}".format(treeobj), str(commit.id))
str(commit.id))
if blob.size > max_blob_size: if blob.size > max_blob_size:
die_commit("maximum blob size ({:s}) exceeded".format( die_commit(
size_humanize(max_blob_size)), str(commit.id)) "maximum blob size ({:s}) exceeded".format(
size_humanize(max_blob_size)
),
str(commit.id),
)
metadata_raw = repo[commit.tree['.SRCINFO'].id].data.decode() metadata_raw = repo[commit.tree[".SRCINFO"].id].data.decode()
(metadata, errors) = srcinfo.parse.parse_srcinfo(metadata_raw) (metadata, errors) = srcinfo.parse.parse_srcinfo(metadata_raw)
if errors: if errors:
sys.stderr.write("error: The following errors occurred " sys.stderr.write(
"when parsing .SRCINFO in commit\n") "error: The following errors occurred "
"when parsing .SRCINFO in commit\n"
)
sys.stderr.write("error: {:s}:\n".format(str(commit.id))) sys.stderr.write("error: {:s}:\n".format(str(commit.id)))
for error in errors: for error in errors:
for err in error['error']: for err in error["error"]:
sys.stderr.write("error: line {:d}: {:s}\n".format( sys.stderr.write(
error['line'], err)) "error: line {:d}: {:s}\n".format(error["line"], err)
)
exit(1) exit(1)
try: try:
metadata_pkgbase = metadata['pkgbase'] metadata_pkgbase = metadata["pkgbase"]
except KeyError: except KeyError:
die_commit('invalid .SRCINFO, does not contain a pkgbase (is the file empty?)', die_commit(
str(commit.id)) "invalid .SRCINFO, does not contain a pkgbase (is the file empty?)",
str(commit.id),
)
if not re.match(repo_regex, metadata_pkgbase): if not re.match(repo_regex, metadata_pkgbase):
die_commit('invalid pkgbase: {:s}'.format(metadata_pkgbase), die_commit("invalid pkgbase: {:s}".format(metadata_pkgbase), str(commit.id))
str(commit.id))
if not metadata['packages']: if not metadata["packages"]:
die_commit('missing pkgname entry', str(commit.id)) die_commit("missing pkgname entry", str(commit.id))
for pkgname in set(metadata['packages'].keys()): for pkgname in set(metadata["packages"].keys()):
pkginfo = srcinfo.utils.get_merged_package(pkgname, metadata) pkginfo = srcinfo.utils.get_merged_package(pkgname, metadata)
for field in ('pkgver', 'pkgrel', 'pkgname'): for field in ("pkgver", "pkgrel", "pkgname"):
if field not in pkginfo: if field not in pkginfo:
die_commit('missing mandatory field: {:s}'.format(field), die_commit(
str(commit.id)) "missing mandatory field: {:s}".format(field), str(commit.id)
)
if 'epoch' in pkginfo and not pkginfo['epoch'].isdigit(): if "epoch" in pkginfo and not pkginfo["epoch"].isdigit():
die_commit('invalid epoch: {:s}'.format(pkginfo['epoch']), die_commit(
str(commit.id)) "invalid epoch: {:s}".format(pkginfo["epoch"]), str(commit.id)
)
if not re.match(r'[a-z0-9][a-z0-9\.+_-]*$', pkginfo['pkgname']): if not re.match(r"[a-z0-9][a-z0-9\.+_-]*$", pkginfo["pkgname"]):
die_commit('invalid package name: {:s}'.format( die_commit(
pkginfo['pkgname']), str(commit.id)) "invalid package name: {:s}".format(pkginfo["pkgname"]),
str(commit.id),
)
max_len = {'pkgname': 255, 'pkgdesc': 255, 'url': 8000} max_len = {"pkgname": 255, "pkgdesc": 255, "url": 8000}
for field in max_len.keys(): for field in max_len.keys():
if field in pkginfo and len(pkginfo[field]) > max_len[field]: if field in pkginfo and len(pkginfo[field]) > max_len[field]:
die_commit('{:s} field too long: {:s}'.format(field, die_commit(
pkginfo[field]), str(commit.id)) "{:s} field too long: {:s}".format(field, pkginfo[field]),
str(commit.id),
)
for field in ('install', 'changelog'): for field in ("install", "changelog"):
if field in pkginfo and not pkginfo[field] in commit.tree: if field in pkginfo and not pkginfo[field] in commit.tree:
die_commit('missing {:s} file: {:s}'.format(field, die_commit(
pkginfo[field]), str(commit.id)) "missing {:s} file: {:s}".format(field, pkginfo[field]),
str(commit.id),
)
for field in extract_arch_fields(pkginfo, 'source'): for field in extract_arch_fields(pkginfo, "source"):
fname = field['value'] fname = field["value"]
if len(fname) > 8000: if len(fname) > 8000:
die_commit('source entry too long: {:s}'.format(fname), die_commit(
str(commit.id)) "source entry too long: {:s}".format(fname), str(commit.id)
)
if "://" in fname or "lp:" in fname: if "://" in fname or "lp:" in fname:
continue continue
if fname not in commit.tree: if fname not in commit.tree:
die_commit('missing source file: {:s}'.format(fname), die_commit(
str(commit.id)) "missing source file: {:s}".format(fname), str(commit.id)
)
# Display a warning if .SRCINFO is unchanged. # Display a warning if .SRCINFO is unchanged.
if sha1_old not in ("0000000000000000000000000000000000000000", sha1_new): if sha1_old not in ("0000000000000000000000000000000000000000", sha1_new):
srcinfo_id_old = repo[sha1_old].tree['.SRCINFO'].id srcinfo_id_old = repo[sha1_old].tree[".SRCINFO"].id
srcinfo_id_new = repo[sha1_new].tree['.SRCINFO'].id srcinfo_id_new = repo[sha1_new].tree[".SRCINFO"].id
if srcinfo_id_old == srcinfo_id_new: if srcinfo_id_old == srcinfo_id_new:
warn(".SRCINFO unchanged. " warn(".SRCINFO unchanged. " "The package database will not be updated!")
"The package database will not be updated!")
# Read .SRCINFO from the HEAD commit. # Read .SRCINFO from the HEAD commit.
metadata_raw = repo[repo[sha1_new].tree['.SRCINFO'].id].data.decode() metadata_raw = repo[repo[sha1_new].tree[".SRCINFO"].id].data.decode()
(metadata, errors) = srcinfo.parse.parse_srcinfo(metadata_raw) (metadata, errors) = srcinfo.parse.parse_srcinfo(metadata_raw)
# Ensure that the package base name matches the repository name. # Ensure that the package base name matches the repository name.
metadata_pkgbase = metadata['pkgbase'] metadata_pkgbase = metadata["pkgbase"]
if metadata_pkgbase != pkgbase: if metadata_pkgbase != pkgbase:
die('invalid pkgbase: {:s}, expected {:s}'.format(metadata_pkgbase, die("invalid pkgbase: {:s}, expected {:s}".format(metadata_pkgbase, pkgbase))
pkgbase))
# Ensure that packages are neither blacklisted nor overwritten. # Ensure that packages are neither blacklisted nor overwritten.
pkgbase = metadata['pkgbase'] pkgbase = metadata["pkgbase"]
cur = conn.execute("SELECT ID FROM PackageBases WHERE Name = ?", [pkgbase]) cur = conn.execute("SELECT ID FROM PackageBases WHERE Name = ?", [pkgbase])
row = cur.fetchone() row = cur.fetchone()
pkgbase_id = row[0] if row else 0 pkgbase_id = row[0] if row else 0
@ -385,18 +426,23 @@ def main(): # noqa: C901
for pkgname in srcinfo.utils.get_package_names(metadata): for pkgname in srcinfo.utils.get_package_names(metadata):
pkginfo = srcinfo.utils.get_merged_package(pkgname, metadata) pkginfo = srcinfo.utils.get_merged_package(pkgname, metadata)
pkgname = pkginfo['pkgname'] pkgname = pkginfo["pkgname"]
if pkgname in blacklist: if pkgname in blacklist:
warn_or_die('package is blacklisted: {:s}'.format(pkgname)) warn_or_die("package is blacklisted: {:s}".format(pkgname))
if pkgname in providers: if pkgname in providers:
warn_or_die('package already provided by [{:s}]: {:s}'.format( warn_or_die(
providers[pkgname], pkgname)) "package already provided by [{:s}]: {:s}".format(
providers[pkgname], pkgname
)
)
cur = conn.execute("SELECT COUNT(*) FROM Packages WHERE Name = ? " + cur = conn.execute(
"AND PackageBaseID <> ?", [pkgname, pkgbase_id]) "SELECT COUNT(*) FROM Packages WHERE Name = ? " + "AND PackageBaseID <> ?",
[pkgname, pkgbase_id],
)
if cur.fetchone()[0] > 0: if cur.fetchone()[0] > 0:
die('cannot overwrite package: {:s}'.format(pkgname)) die("cannot overwrite package: {:s}".format(pkgname))
# Create a new package base if it does not exist yet. # Create a new package base if it does not exist yet.
if pkgbase_id == 0: if pkgbase_id == 0:
@ -407,7 +453,7 @@ def main(): # noqa: C901
# Create (or update) a branch with the name of the package base for better # Create (or update) a branch with the name of the package base for better
# accessibility. # accessibility.
branchref = 'refs/heads/' + pkgbase branchref = "refs/heads/" + pkgbase
repo.create_reference(branchref, sha1_new, True) repo.create_reference(branchref, sha1_new, True)
# Work around a Git bug: The HEAD ref is not updated when using # Work around a Git bug: The HEAD ref is not updated when using
@ -415,7 +461,7 @@ def main(): # noqa: C901
# mainline. See # mainline. See
# http://git.661346.n2.nabble.com/PATCH-receive-pack-Create-a-HEAD-ref-for-ref-namespace-td7632149.html # http://git.661346.n2.nabble.com/PATCH-receive-pack-Create-a-HEAD-ref-for-ref-namespace-td7632149.html
# for details. # for details.
headref = 'refs/namespaces/' + pkgbase + '/HEAD' headref = "refs/namespaces/" + pkgbase + "/HEAD"
repo.create_reference(headref, sha1_new, True) repo.create_reference(headref, sha1_new, True)
# Send package update notifications. # Send package update notifications.
@ -426,5 +472,5 @@ def main(): # noqa: C901
conn.close() conn.close()
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -9,28 +9,40 @@ import aurweb.schema
def feed_initial_data(conn): def feed_initial_data(conn):
conn.execute(aurweb.schema.AccountTypes.insert(), [ conn.execute(
{'ID': 1, 'AccountType': 'User'}, aurweb.schema.AccountTypes.insert(),
{'ID': 2, 'AccountType': 'Trusted User'}, [
{'ID': 3, 'AccountType': 'Developer'}, {"ID": 1, "AccountType": "User"},
{'ID': 4, 'AccountType': 'Trusted User & Developer'}, {"ID": 2, "AccountType": "Trusted User"},
]) {"ID": 3, "AccountType": "Developer"},
conn.execute(aurweb.schema.DependencyTypes.insert(), [ {"ID": 4, "AccountType": "Trusted User & Developer"},
{'ID': 1, 'Name': 'depends'}, ],
{'ID': 2, 'Name': 'makedepends'}, )
{'ID': 3, 'Name': 'checkdepends'}, conn.execute(
{'ID': 4, 'Name': 'optdepends'}, aurweb.schema.DependencyTypes.insert(),
]) [
conn.execute(aurweb.schema.RelationTypes.insert(), [ {"ID": 1, "Name": "depends"},
{'ID': 1, 'Name': 'conflicts'}, {"ID": 2, "Name": "makedepends"},
{'ID': 2, 'Name': 'provides'}, {"ID": 3, "Name": "checkdepends"},
{'ID': 3, 'Name': 'replaces'}, {"ID": 4, "Name": "optdepends"},
]) ],
conn.execute(aurweb.schema.RequestTypes.insert(), [ )
{'ID': 1, 'Name': 'deletion'}, conn.execute(
{'ID': 2, 'Name': 'orphan'}, aurweb.schema.RelationTypes.insert(),
{'ID': 3, 'Name': 'merge'}, [
]) {"ID": 1, "Name": "conflicts"},
{"ID": 2, "Name": "provides"},
{"ID": 3, "Name": "replaces"},
],
)
conn.execute(
aurweb.schema.RequestTypes.insert(),
[
{"ID": 1, "Name": "deletion"},
{"ID": 2, "Name": "orphan"},
{"ID": 3, "Name": "merge"},
],
)
def run(args): def run(args):
@ -40,8 +52,8 @@ def run(args):
# the last step and leave the database in an inconsistent state. The # the last step and leave the database in an inconsistent state. The
# configuration is loaded lazily, so we query it to force its loading. # configuration is loaded lazily, so we query it to force its loading.
if args.use_alembic: if args.use_alembic:
alembic_config = alembic.config.Config('alembic.ini') alembic_config = alembic.config.Config("alembic.ini")
alembic_config.get_main_option('script_location') alembic_config.get_main_option("script_location")
alembic_config.attributes["configure_logger"] = False alembic_config.attributes["configure_logger"] = False
engine = aurweb.db.get_engine(echo=(args.verbose >= 1)) engine = aurweb.db.get_engine(echo=(args.verbose >= 1))
@ -51,17 +63,21 @@ def run(args):
conn.close() conn.close()
if args.use_alembic: if args.use_alembic:
alembic.command.stamp(alembic_config, 'head') alembic.command.stamp(alembic_config, "head")
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog='python -m aurweb.initdb', prog="python -m aurweb.initdb", description="Initialize the aurweb database."
description='Initialize the aurweb database.') )
parser.add_argument('-v', '--verbose', action='count', default=0, parser.add_argument(
help='increase verbosity') "-v", "--verbose", action="count", default=0, help="increase verbosity"
parser.add_argument('--no-alembic', )
help='disable Alembic migrations support', parser.add_argument(
dest='use_alembic', action='store_false') "--no-alembic",
help="disable Alembic migrations support",
dest="use_alembic",
action="store_false",
)
args = parser.parse_args() args = parser.parse_args()
run(args) run(args)

View file

@ -1,43 +1,44 @@
import gettext import gettext
from collections import OrderedDict from collections import OrderedDict
from fastapi import Request from fastapi import Request
import aurweb.config import aurweb.config
SUPPORTED_LANGUAGES = OrderedDict({ SUPPORTED_LANGUAGES = OrderedDict(
"ar": "العربية", {
"ast": "Asturianu", "ar": "العربية",
"ca": "Català", "ast": "Asturianu",
"cs": "Český", "ca": "Català",
"da": "Dansk", "cs": "Český",
"de": "Deutsch", "da": "Dansk",
"el": "Ελληνικά", "de": "Deutsch",
"en": "English", "el": "Ελληνικά",
"es": "Español", "en": "English",
"es_419": "Español (Latinoamérica)", "es": "Español",
"fi": "Suomi", "es_419": "Español (Latinoamérica)",
"fr": "Français", "fi": "Suomi",
"he": "עברית", "fr": "Français",
"hr": "Hrvatski", "he": "עברית",
"hu": "Magyar", "hr": "Hrvatski",
"it": "Italiano", "hu": "Magyar",
"ja": "日本語", "it": "Italiano",
"nb": "Norsk", "ja": "日本語",
"nl": "Nederlands", "nb": "Norsk",
"pl": "Polski", "nl": "Nederlands",
"pt_BR": "Português (Brasil)", "pl": "Polski",
"pt_PT": "Português (Portugal)", "pt_BR": "Português (Brasil)",
"ro": "Română", "pt_PT": "Português (Portugal)",
"ru": "Русский", "ro": "Română",
"sk": "Slovenčina", "ru": "Русский",
"sr": "Srpski", "sk": "Slovenčina",
"tr": "Türkçe", "sr": "Srpski",
"uk": "Українська", "tr": "Türkçe",
"zh_CN": "简体中文", "uk": "Українська",
"zh_TW": "正體中文" "zh_CN": "简体中文",
}) "zh_TW": "正體中文",
}
)
RIGHT_TO_LEFT_LANGUAGES = ("he", "ar") RIGHT_TO_LEFT_LANGUAGES = ("he", "ar")
@ -45,15 +46,14 @@ RIGHT_TO_LEFT_LANGUAGES = ("he", "ar")
class Translator: class Translator:
def __init__(self): def __init__(self):
self._localedir = aurweb.config.get('options', 'localedir') self._localedir = aurweb.config.get("options", "localedir")
self._translator = {} self._translator = {}
def get_translator(self, lang: str): def get_translator(self, lang: str):
if lang not in self._translator: if lang not in self._translator:
self._translator[lang] = gettext.translation("aurweb", self._translator[lang] = gettext.translation(
self._localedir, "aurweb", self._localedir, languages=[lang], fallback=True
languages=[lang], )
fallback=True)
return self._translator.get(lang) return self._translator.get(lang)
def translate(self, s: str, lang: str): def translate(self, s: str, lang: str):

View file

@ -15,7 +15,7 @@ logging.getLogger("root").addHandler(logging.NullHandler())
def get_logger(name: str) -> logging.Logger: def get_logger(name: str) -> logging.Logger:
""" A logging.getLogger wrapper. Importing this function and """A logging.getLogger wrapper. Importing this function and
using it to get a module-local logger ensures that logging.conf using it to get a module-local logger ensures that logging.conf
initialization is performed wherever loggers are used. initialization is performed wherever loggers are used.

View file

@ -13,12 +13,16 @@ class AcceptedTerm(Base):
__mapper_args__ = {"primary_key": [__table__.c.TermsID]} __mapper_args__ = {"primary_key": [__table__.c.TermsID]}
User = relationship( User = relationship(
_User, backref=backref("accepted_terms", lazy="dynamic"), _User,
foreign_keys=[__table__.c.UsersID]) backref=backref("accepted_terms", lazy="dynamic"),
foreign_keys=[__table__.c.UsersID],
)
Term = relationship( Term = relationship(
_Term, backref=backref("accepted_terms", lazy="dynamic"), _Term,
foreign_keys=[__table__.c.TermsID]) backref=backref("accepted_terms", lazy="dynamic"),
foreign_keys=[__table__.c.TermsID],
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -27,10 +31,12 @@ class AcceptedTerm(Base):
raise IntegrityError( raise IntegrityError(
statement="Foreign key UsersID cannot be null.", statement="Foreign key UsersID cannot be null.",
orig="AcceptedTerms.UserID", orig="AcceptedTerms.UserID",
params=("NULL")) params=("NULL"),
)
if not self.Term and not self.TermsID: if not self.Term and not self.TermsID:
raise IntegrityError( raise IntegrityError(
statement="Foreign key TermID cannot be null.", statement="Foreign key TermID cannot be null.",
orig="AcceptedTerms.TermID", orig="AcceptedTerms.TermID",
params=("NULL")) params=("NULL"),
)

View file

@ -16,7 +16,7 @@ ACCOUNT_TYPE_ID = {
USER: USER_ID, USER: USER_ID,
TRUSTED_USER: TRUSTED_USER_ID, TRUSTED_USER: TRUSTED_USER_ID,
DEVELOPER: DEVELOPER_ID, DEVELOPER: DEVELOPER_ID,
TRUSTED_USER_AND_DEV: TRUSTED_USER_AND_DEV_ID TRUSTED_USER_AND_DEV: TRUSTED_USER_AND_DEV_ID,
} }
# Reversed ACCOUNT_TYPE_ID mapping. # Reversed ACCOUNT_TYPE_ID mapping.
@ -24,7 +24,8 @@ ACCOUNT_TYPE_NAME = {v: k for k, v in ACCOUNT_TYPE_ID.items()}
class AccountType(Base): class AccountType(Base):
""" An ORM model of a single AccountTypes record. """ """An ORM model of a single AccountTypes record."""
__table__ = schema.AccountTypes __table__ = schema.AccountTypes
__tablename__ = __table__.name __tablename__ = __table__.name
__mapper_args__ = {"primary_key": [__table__.c.ID]} __mapper_args__ = {"primary_key": [__table__.c.ID]}
@ -36,5 +37,4 @@ class AccountType(Base):
return str(self.AccountType) return str(self.AccountType)
def __repr__(self): def __repr__(self):
return "<AccountType(ID='%s', AccountType='%s')>" % ( return "<AccountType(ID='%s', AccountType='%s')>" % (self.ID, str(self))
self.ID, str(self))

View file

@ -16,10 +16,12 @@ class ApiRateLimit(Base):
raise IntegrityError( raise IntegrityError(
statement="Column Requests cannot be null.", statement="Column Requests cannot be null.",
orig="ApiRateLimit.Requests", orig="ApiRateLimit.Requests",
params=("NULL")) params=("NULL"),
)
if self.WindowStart is None: if self.WindowStart is None:
raise IntegrityError( raise IntegrityError(
statement="Column WindowStart cannot be null.", statement="Column WindowStart cannot be null.",
orig="ApiRateLimit.WindowStart", orig="ApiRateLimit.WindowStart",
params=("NULL")) params=("NULL"),
)

View file

@ -6,26 +6,19 @@ from aurweb import util
def to_dict(model): def to_dict(model):
return { return {c.name: getattr(model, c.name) for c in model.__table__.columns}
c.name: getattr(model, c.name)
for c in model.__table__.columns
}
def to_json(model, indent: int = None): def to_json(model, indent: int = None):
return json.dumps({ return json.dumps(
k: util.jsonify(v) {k: util.jsonify(v) for k, v in to_dict(model).items()}, indent=indent
for k, v in to_dict(model).items() )
}, indent=indent)
Base = declarative_base() Base = declarative_base()
# Setup __table_args__ applicable to every table. # Setup __table_args__ applicable to every table.
Base.__table_args__ = { Base.__table_args__ = {"autoload": False, "extend_existing": True}
"autoload": False,
"extend_existing": True
}
# Setup Base.as_dict and Base.json. # Setup Base.as_dict and Base.json.
# #

View file

@ -15,4 +15,5 @@ class Group(Base):
raise IntegrityError( raise IntegrityError(
statement="Column Name cannot be null.", statement="Column Name cannot be null.",
orig="Groups.Name", orig="Groups.Name",
params=("NULL")) params=("NULL"),
)

View file

@ -16,4 +16,5 @@ class License(Base):
raise IntegrityError( raise IntegrityError(
statement="Column Name cannot be null.", statement="Column Name cannot be null.",
orig="Licenses.Name", orig="Licenses.Name",
params=("NULL")) params=("NULL"),
)

View file

@ -21,16 +21,19 @@ class OfficialProvider(Base):
raise IntegrityError( raise IntegrityError(
statement="Column Name cannot be null.", statement="Column Name cannot be null.",
orig="OfficialProviders.Name", orig="OfficialProviders.Name",
params=("NULL")) params=("NULL"),
)
if not self.Repo: if not self.Repo:
raise IntegrityError( raise IntegrityError(
statement="Column Repo cannot be null.", statement="Column Repo cannot be null.",
orig="OfficialProviders.Repo", orig="OfficialProviders.Repo",
params=("NULL")) params=("NULL"),
)
if not self.Provides: if not self.Provides:
raise IntegrityError( raise IntegrityError(
statement="Column Provides cannot be null.", statement="Column Provides cannot be null.",
orig="OfficialProviders.Provides", orig="OfficialProviders.Provides",
params=("NULL")) params=("NULL"),
)

View file

@ -12,9 +12,10 @@ class Package(Base):
__mapper_args__ = {"primary_key": [__table__.c.ID]} __mapper_args__ = {"primary_key": [__table__.c.ID]}
PackageBase = relationship( PackageBase = relationship(
_PackageBase, backref=backref("packages", lazy="dynamic", _PackageBase,
cascade="all, delete"), backref=backref("packages", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID]) foreign_keys=[__table__.c.PackageBaseID],
)
# No Package instances are official packages. # No Package instances are official packages.
is_official = False is_official = False
@ -26,10 +27,12 @@ class Package(Base):
raise IntegrityError( raise IntegrityError(
statement="Foreign key PackageBaseID cannot be null.", statement="Foreign key PackageBaseID cannot be null.",
orig="Packages.PackageBaseID", orig="Packages.PackageBaseID",
params=("NULL")) params=("NULL"),
)
if self.Name is None: if self.Name is None:
raise IntegrityError( raise IntegrityError(
statement="Column Name cannot be null.", statement="Column Name cannot be null.",
orig="Packages.Name", orig="Packages.Name",
params=("NULL")) params=("NULL"),
)

View file

@ -12,20 +12,28 @@ class PackageBase(Base):
__mapper_args__ = {"primary_key": [__table__.c.ID]} __mapper_args__ = {"primary_key": [__table__.c.ID]}
Flagger = relationship( Flagger = relationship(
_User, backref=backref("flagged_bases", lazy="dynamic"), _User,
foreign_keys=[__table__.c.FlaggerUID]) backref=backref("flagged_bases", lazy="dynamic"),
foreign_keys=[__table__.c.FlaggerUID],
)
Submitter = relationship( Submitter = relationship(
_User, backref=backref("submitted_bases", lazy="dynamic"), _User,
foreign_keys=[__table__.c.SubmitterUID]) backref=backref("submitted_bases", lazy="dynamic"),
foreign_keys=[__table__.c.SubmitterUID],
)
Maintainer = relationship( Maintainer = relationship(
_User, backref=backref("maintained_bases", lazy="dynamic"), _User,
foreign_keys=[__table__.c.MaintainerUID]) backref=backref("maintained_bases", lazy="dynamic"),
foreign_keys=[__table__.c.MaintainerUID],
)
Packager = relationship( Packager = relationship(
_User, backref=backref("package_bases", lazy="dynamic"), _User,
foreign_keys=[__table__.c.PackagerUID]) backref=backref("package_bases", lazy="dynamic"),
foreign_keys=[__table__.c.PackagerUID],
)
# A set used to check for floatable values. # A set used to check for floatable values.
TO_FLOAT = {"Popularity"} TO_FLOAT = {"Popularity"}
@ -37,7 +45,8 @@ class PackageBase(Base):
raise IntegrityError( raise IntegrityError(
statement="Column Name cannot be null.", statement="Column Name cannot be null.",
orig="PackageBases.Name", orig="PackageBases.Name",
params=("NULL")) params=("NULL"),
)
# If no SubmittedTS/ModifiedTS is provided on creation, set them # If no SubmittedTS/ModifiedTS is provided on creation, set them
# here to the current utc timestamp. # here to the current utc timestamp.

View file

@ -16,4 +16,5 @@ class PackageBlacklist(Base):
raise IntegrityError( raise IntegrityError(
statement="Column Name cannot be null.", statement="Column Name cannot be null.",
orig="PackageBlacklist.Name", orig="PackageBlacklist.Name",
params=("NULL")) params=("NULL"),
)

View file

@ -10,19 +10,19 @@ from aurweb.models.user import User as _User
class PackageComaintainer(Base): class PackageComaintainer(Base):
__table__ = schema.PackageComaintainers __table__ = schema.PackageComaintainers
__tablename__ = __table__.name __tablename__ = __table__.name
__mapper_args__ = { __mapper_args__ = {"primary_key": [__table__.c.UsersID, __table__.c.PackageBaseID]}
"primary_key": [__table__.c.UsersID, __table__.c.PackageBaseID]
}
User = relationship( User = relationship(
_User, backref=backref("comaintained", lazy="dynamic", _User,
cascade="all, delete"), backref=backref("comaintained", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.UsersID]) foreign_keys=[__table__.c.UsersID],
)
PackageBase = relationship( PackageBase = relationship(
_PackageBase, backref=backref("comaintainers", lazy="dynamic", _PackageBase,
cascade="all, delete"), backref=backref("comaintainers", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID]) foreign_keys=[__table__.c.PackageBaseID],
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -31,16 +31,19 @@ class PackageComaintainer(Base):
raise IntegrityError( raise IntegrityError(
statement="Foreign key UsersID cannot be null.", statement="Foreign key UsersID cannot be null.",
orig="PackageComaintainers.UsersID", orig="PackageComaintainers.UsersID",
params=("NULL")) params=("NULL"),
)
if not self.PackageBase and not self.PackageBaseID: if not self.PackageBase and not self.PackageBaseID:
raise IntegrityError( raise IntegrityError(
statement="Foreign key PackageBaseID cannot be null.", statement="Foreign key PackageBaseID cannot be null.",
orig="PackageComaintainers.PackageBaseID", orig="PackageComaintainers.PackageBaseID",
params=("NULL")) params=("NULL"),
)
if not self.Priority: if not self.Priority:
raise IntegrityError( raise IntegrityError(
statement="Column Priority cannot be null.", statement="Column Priority cannot be null.",
orig="PackageComaintainers.Priority", orig="PackageComaintainers.Priority",
params=("NULL")) params=("NULL"),
)

View file

@ -13,21 +13,28 @@ class PackageComment(Base):
__mapper_args__ = {"primary_key": [__table__.c.ID]} __mapper_args__ = {"primary_key": [__table__.c.ID]}
PackageBase = relationship( PackageBase = relationship(
_PackageBase, backref=backref("comments", lazy="dynamic", _PackageBase,
cascade="all, delete"), backref=backref("comments", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID]) foreign_keys=[__table__.c.PackageBaseID],
)
User = relationship( User = relationship(
_User, backref=backref("package_comments", lazy="dynamic"), _User,
foreign_keys=[__table__.c.UsersID]) backref=backref("package_comments", lazy="dynamic"),
foreign_keys=[__table__.c.UsersID],
)
Editor = relationship( Editor = relationship(
_User, backref=backref("edited_comments", lazy="dynamic"), _User,
foreign_keys=[__table__.c.EditedUsersID]) backref=backref("edited_comments", lazy="dynamic"),
foreign_keys=[__table__.c.EditedUsersID],
)
Deleter = relationship( Deleter = relationship(
_User, backref=backref("deleted_comments", lazy="dynamic"), _User,
foreign_keys=[__table__.c.DelUsersID]) backref=backref("deleted_comments", lazy="dynamic"),
foreign_keys=[__table__.c.DelUsersID],
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -36,27 +43,31 @@ class PackageComment(Base):
raise IntegrityError( raise IntegrityError(
statement="Foreign key PackageBaseID cannot be null.", statement="Foreign key PackageBaseID cannot be null.",
orig="PackageComments.PackageBaseID", orig="PackageComments.PackageBaseID",
params=("NULL")) params=("NULL"),
)
if not self.User and not self.UsersID: if not self.User and not self.UsersID:
raise IntegrityError( raise IntegrityError(
statement="Foreign key UsersID cannot be null.", statement="Foreign key UsersID cannot be null.",
orig="PackageComments.UsersID", orig="PackageComments.UsersID",
params=("NULL")) params=("NULL"),
)
if self.Comments is None: if self.Comments is None:
raise IntegrityError( raise IntegrityError(
statement="Column Comments cannot be null.", statement="Column Comments cannot be null.",
orig="PackageComments.Comments", orig="PackageComments.Comments",
params=("NULL")) params=("NULL"),
)
if self.RenderedComment is None: if self.RenderedComment is None:
self.RenderedComment = str() self.RenderedComment = str()
def maintainers(self): def maintainers(self):
return list(filter( return list(
lambda e: e is not None, filter(
[self.PackageBase.Maintainer] + [ lambda e: e is not None,
c.User for c in self.PackageBase.comaintainers [self.PackageBase.Maintainer]
] + [c.User for c in self.PackageBase.comaintainers],
)) )
)

View file

@ -22,14 +22,16 @@ class PackageDependency(Base):
} }
Package = relationship( Package = relationship(
_Package, backref=backref("package_dependencies", lazy="dynamic", _Package,
cascade="all, delete"), backref=backref("package_dependencies", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageID]) foreign_keys=[__table__.c.PackageID],
)
DependencyType = relationship( DependencyType = relationship(
_DependencyType, _DependencyType,
backref=backref("package_dependencies", lazy="dynamic"), backref=backref("package_dependencies", lazy="dynamic"),
foreign_keys=[__table__.c.DepTypeID]) foreign_keys=[__table__.c.DepTypeID],
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -38,43 +40,58 @@ class PackageDependency(Base):
raise IntegrityError( raise IntegrityError(
statement="Foreign key PackageID cannot be null.", statement="Foreign key PackageID cannot be null.",
orig="PackageDependencies.PackageID", orig="PackageDependencies.PackageID",
params=("NULL")) params=("NULL"),
)
if not self.DependencyType and not self.DepTypeID: if not self.DependencyType and not self.DepTypeID:
raise IntegrityError( raise IntegrityError(
statement="Foreign key DepTypeID cannot be null.", statement="Foreign key DepTypeID cannot be null.",
orig="PackageDependencies.DepTypeID", orig="PackageDependencies.DepTypeID",
params=("NULL")) params=("NULL"),
)
if self.DepName is None: if self.DepName is None:
raise IntegrityError( raise IntegrityError(
statement="Column DepName cannot be null.", statement="Column DepName cannot be null.",
orig="PackageDependencies.DepName", orig="PackageDependencies.DepName",
params=("NULL")) params=("NULL"),
)
def is_package(self) -> bool: def is_package(self) -> bool:
pkg = db.query(_Package).filter(_Package.Name == self.DepName).exists() pkg = db.query(_Package).filter(_Package.Name == self.DepName).exists()
official = db.query(_OfficialProvider).filter( official = (
_OfficialProvider.Name == self.DepName).exists() db.query(_OfficialProvider)
.filter(_OfficialProvider.Name == self.DepName)
.exists()
)
return db.query(pkg).scalar() or db.query(official).scalar() return db.query(pkg).scalar() or db.query(official).scalar()
def provides(self) -> list[PackageRelation]: def provides(self) -> list[PackageRelation]:
from aurweb.models.relation_type import PROVIDES_ID from aurweb.models.relation_type import PROVIDES_ID
rels = db.query(PackageRelation).join(_Package).filter( rels = (
and_(PackageRelation.RelTypeID == PROVIDES_ID, db.query(PackageRelation)
PackageRelation.RelName == self.DepName) .join(_Package)
).with_entities( .filter(
_Package.Name, and_(
literal(False).label("is_official") PackageRelation.RelTypeID == PROVIDES_ID,
).order_by(_Package.Name.asc()) PackageRelation.RelName == self.DepName,
)
)
.with_entities(_Package.Name, literal(False).label("is_official"))
.order_by(_Package.Name.asc())
)
official_rels = db.query(_OfficialProvider).filter( official_rels = (
and_(_OfficialProvider.Provides == self.DepName, db.query(_OfficialProvider)
_OfficialProvider.Name != self.DepName) .filter(
).with_entities( and_(
_OfficialProvider.Name, _OfficialProvider.Provides == self.DepName,
literal(True).label("is_official") _OfficialProvider.Name != self.DepName,
).order_by(_OfficialProvider.Name.asc()) )
)
.with_entities(_OfficialProvider.Name, literal(True).label("is_official"))
.order_by(_OfficialProvider.Name.asc())
)
return rels.union(official_rels).all() return rels.union(official_rels).all()

View file

@ -10,19 +10,19 @@ from aurweb.models.package import Package as _Package
class PackageGroup(Base): class PackageGroup(Base):
__table__ = schema.PackageGroups __table__ = schema.PackageGroups
__tablename__ = __table__.name __tablename__ = __table__.name
__mapper_args__ = { __mapper_args__ = {"primary_key": [__table__.c.PackageID, __table__.c.GroupID]}
"primary_key": [__table__.c.PackageID, __table__.c.GroupID]
}
Package = relationship( Package = relationship(
_Package, backref=backref("package_groups", lazy="dynamic", _Package,
cascade="all, delete"), backref=backref("package_groups", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageID]) foreign_keys=[__table__.c.PackageID],
)
Group = relationship( Group = relationship(
_Group, backref=backref("package_groups", lazy="dynamic", _Group,
cascade="all, delete"), backref=backref("package_groups", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.GroupID]) foreign_keys=[__table__.c.GroupID],
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -31,10 +31,12 @@ class PackageGroup(Base):
raise IntegrityError( raise IntegrityError(
statement="Primary key PackageID cannot be null.", statement="Primary key PackageID cannot be null.",
orig="PackageGroups.PackageID", orig="PackageGroups.PackageID",
params=("NULL")) params=("NULL"),
)
if not self.Group and not self.GroupID: if not self.Group and not self.GroupID:
raise IntegrityError( raise IntegrityError(
statement="Primary key GroupID cannot be null.", statement="Primary key GroupID cannot be null.",
orig="PackageGroups.GroupID", orig="PackageGroups.GroupID",
params=("NULL")) params=("NULL"),
)

View file

@ -9,14 +9,13 @@ from aurweb.models.package_base import PackageBase as _PackageBase
class PackageKeyword(Base): class PackageKeyword(Base):
__table__ = schema.PackageKeywords __table__ = schema.PackageKeywords
__tablename__ = __table__.name __tablename__ = __table__.name
__mapper_args__ = { __mapper_args__ = {"primary_key": [__table__.c.PackageBaseID, __table__.c.Keyword]}
"primary_key": [__table__.c.PackageBaseID, __table__.c.Keyword]
}
PackageBase = relationship( PackageBase = relationship(
_PackageBase, backref=backref("keywords", lazy="dynamic", _PackageBase,
cascade="all, delete"), backref=backref("keywords", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID]) foreign_keys=[__table__.c.PackageBaseID],
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -25,4 +24,5 @@ class PackageKeyword(Base):
raise IntegrityError( raise IntegrityError(
statement="Primary key PackageBaseID cannot be null.", statement="Primary key PackageBaseID cannot be null.",
orig="PackageKeywords.PackageBaseID", orig="PackageKeywords.PackageBaseID",
params=("NULL")) params=("NULL"),
)

View file

@ -10,19 +10,19 @@ from aurweb.models.package import Package as _Package
class PackageLicense(Base): class PackageLicense(Base):
__table__ = schema.PackageLicenses __table__ = schema.PackageLicenses
__tablename__ = __table__.name __tablename__ = __table__.name
__mapper_args__ = { __mapper_args__ = {"primary_key": [__table__.c.PackageID, __table__.c.LicenseID]}
"primary_key": [__table__.c.PackageID, __table__.c.LicenseID]
}
Package = relationship( Package = relationship(
_Package, backref=backref("package_licenses", lazy="dynamic", _Package,
cascade="all, delete"), backref=backref("package_licenses", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageID]) foreign_keys=[__table__.c.PackageID],
)
License = relationship( License = relationship(
_License, backref=backref("package_licenses", lazy="dynamic", _License,
cascade="all, delete"), backref=backref("package_licenses", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.LicenseID]) foreign_keys=[__table__.c.LicenseID],
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -31,10 +31,12 @@ class PackageLicense(Base):
raise IntegrityError( raise IntegrityError(
statement="Primary key PackageID cannot be null.", statement="Primary key PackageID cannot be null.",
orig="PackageLicenses.PackageID", orig="PackageLicenses.PackageID",
params=("NULL")) params=("NULL"),
)
if not self.License and not self.LicenseID: if not self.License and not self.LicenseID:
raise IntegrityError( raise IntegrityError(
statement="Primary key LicenseID cannot be null.", statement="Primary key LicenseID cannot be null.",
orig="PackageLicenses.LicenseID", orig="PackageLicenses.LicenseID",
params=("NULL")) params=("NULL"),
)

View file

@ -10,20 +10,19 @@ from aurweb.models.user import User as _User
class PackageNotification(Base): class PackageNotification(Base):
__table__ = schema.PackageNotifications __table__ = schema.PackageNotifications
__tablename__ = __table__.name __tablename__ = __table__.name
__mapper_args__ = { __mapper_args__ = {"primary_key": [__table__.c.UserID, __table__.c.PackageBaseID]}
"primary_key": [__table__.c.UserID, __table__.c.PackageBaseID]
}
User = relationship( User = relationship(
_User, backref=backref("notifications", lazy="dynamic", _User,
cascade="all, delete"), backref=backref("notifications", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.UserID]) foreign_keys=[__table__.c.UserID],
)
PackageBase = relationship( PackageBase = relationship(
_PackageBase, _PackageBase,
backref=backref("notifications", lazy="dynamic", backref=backref("notifications", lazy="dynamic", cascade="all, delete"),
cascade="all, delete"), foreign_keys=[__table__.c.PackageBaseID],
foreign_keys=[__table__.c.PackageBaseID]) )
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -32,10 +31,12 @@ class PackageNotification(Base):
raise IntegrityError( raise IntegrityError(
statement="Foreign key UserID cannot be null.", statement="Foreign key UserID cannot be null.",
orig="PackageNotifications.UserID", orig="PackageNotifications.UserID",
params=("NULL")) params=("NULL"),
)
if not self.PackageBase and not self.PackageBaseID: if not self.PackageBase and not self.PackageBaseID:
raise IntegrityError( raise IntegrityError(
statement="Foreign key PackageBaseID cannot be null.", statement="Foreign key PackageBaseID cannot be null.",
orig="PackageNotifications.PackageBaseID", orig="PackageNotifications.PackageBaseID",
params=("NULL")) params=("NULL"),
)

View file

@ -19,13 +19,16 @@ class PackageRelation(Base):
} }
Package = relationship( Package = relationship(
_Package, backref=backref("package_relations", lazy="dynamic", _Package,
cascade="all, delete"), backref=backref("package_relations", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageID]) foreign_keys=[__table__.c.PackageID],
)
RelationType = relationship( RelationType = relationship(
_RelationType, backref=backref("package_relations", lazy="dynamic"), _RelationType,
foreign_keys=[__table__.c.RelTypeID]) backref=backref("package_relations", lazy="dynamic"),
foreign_keys=[__table__.c.RelTypeID],
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -34,16 +37,19 @@ class PackageRelation(Base):
raise IntegrityError( raise IntegrityError(
statement="Foreign key PackageID cannot be null.", statement="Foreign key PackageID cannot be null.",
orig="PackageRelations.PackageID", orig="PackageRelations.PackageID",
params=("NULL")) params=("NULL"),
)
if not self.RelationType and not self.RelTypeID: if not self.RelationType and not self.RelTypeID:
raise IntegrityError( raise IntegrityError(
statement="Foreign key RelTypeID cannot be null.", statement="Foreign key RelTypeID cannot be null.",
orig="PackageRelations.RelTypeID", orig="PackageRelations.RelTypeID",
params=("NULL")) params=("NULL"),
)
if not self.RelName: if not self.RelName:
raise IntegrityError( raise IntegrityError(
statement="Column RelName cannot be null.", statement="Column RelName cannot be null.",
orig="PackageRelations.RelName", orig="PackageRelations.RelName",
params=("NULL")) params=("NULL"),
)

View file

@ -25,26 +25,34 @@ class PackageRequest(Base):
__mapper_args__ = {"primary_key": [__table__.c.ID]} __mapper_args__ = {"primary_key": [__table__.c.ID]}
RequestType = relationship( RequestType = relationship(
_RequestType, backref=backref("package_requests", lazy="dynamic"), _RequestType,
foreign_keys=[__table__.c.ReqTypeID]) backref=backref("package_requests", lazy="dynamic"),
foreign_keys=[__table__.c.ReqTypeID],
)
User = relationship( User = relationship(
_User, backref=backref("package_requests", lazy="dynamic"), _User,
foreign_keys=[__table__.c.UsersID]) backref=backref("package_requests", lazy="dynamic"),
foreign_keys=[__table__.c.UsersID],
)
PackageBase = relationship( PackageBase = relationship(
_PackageBase, backref=backref("requests", lazy="dynamic"), _PackageBase,
foreign_keys=[__table__.c.PackageBaseID]) backref=backref("requests", lazy="dynamic"),
foreign_keys=[__table__.c.PackageBaseID],
)
Closer = relationship( Closer = relationship(
_User, backref=backref("closed_requests", lazy="dynamic"), _User,
foreign_keys=[__table__.c.ClosedUID]) backref=backref("closed_requests", lazy="dynamic"),
foreign_keys=[__table__.c.ClosedUID],
)
STATUS_DISPLAY = { STATUS_DISPLAY = {
PENDING_ID: PENDING, PENDING_ID: PENDING,
CLOSED_ID: CLOSED, CLOSED_ID: CLOSED,
ACCEPTED_ID: ACCEPTED, ACCEPTED_ID: ACCEPTED,
REJECTED_ID: REJECTED REJECTED_ID: REJECTED,
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -54,38 +62,44 @@ class PackageRequest(Base):
raise IntegrityError( raise IntegrityError(
statement="Foreign key ReqTypeID cannot be null.", statement="Foreign key ReqTypeID cannot be null.",
orig="PackageRequests.ReqTypeID", orig="PackageRequests.ReqTypeID",
params=("NULL")) params=("NULL"),
)
if not self.PackageBase and not self.PackageBaseID: if not self.PackageBase and not self.PackageBaseID:
raise IntegrityError( raise IntegrityError(
statement="Foreign key PackageBaseID cannot be null.", statement="Foreign key PackageBaseID cannot be null.",
orig="PackageRequests.PackageBaseID", orig="PackageRequests.PackageBaseID",
params=("NULL")) params=("NULL"),
)
if not self.PackageBaseName: if not self.PackageBaseName:
raise IntegrityError( raise IntegrityError(
statement="Column PackageBaseName cannot be null.", statement="Column PackageBaseName cannot be null.",
orig="PackageRequests.PackageBaseName", orig="PackageRequests.PackageBaseName",
params=("NULL")) params=("NULL"),
)
if not self.User and not self.UsersID: if not self.User and not self.UsersID:
raise IntegrityError( raise IntegrityError(
statement="Foreign key UsersID cannot be null.", statement="Foreign key UsersID cannot be null.",
orig="PackageRequests.UsersID", orig="PackageRequests.UsersID",
params=("NULL")) params=("NULL"),
)
if self.Comments is None: if self.Comments is None:
raise IntegrityError( raise IntegrityError(
statement="Column Comments cannot be null.", statement="Column Comments cannot be null.",
orig="PackageRequests.Comments", orig="PackageRequests.Comments",
params=("NULL")) params=("NULL"),
)
if self.ClosureComment is None: if self.ClosureComment is None:
raise IntegrityError( raise IntegrityError(
statement="Column ClosureComment cannot be null.", statement="Column ClosureComment cannot be null.",
orig="PackageRequests.ClosureComment", orig="PackageRequests.ClosureComment",
params=("NULL")) params=("NULL"),
)
def status_display(self) -> str: def status_display(self) -> str:
""" Return a display string for the Status column. """ """Return a display string for the Status column."""
return self.STATUS_DISPLAY[self.Status] return self.STATUS_DISPLAY[self.Status]

View file

@ -9,17 +9,13 @@ from aurweb.models.package import Package as _Package
class PackageSource(Base): class PackageSource(Base):
__table__ = schema.PackageSources __table__ = schema.PackageSources
__tablename__ = __table__.name __tablename__ = __table__.name
__mapper_args__ = { __mapper_args__ = {"primary_key": [__table__.c.PackageID, __table__.c.Source]}
"primary_key": [
__table__.c.PackageID,
__table__.c.Source
]
}
Package = relationship( Package = relationship(
_Package, backref=backref("package_sources", lazy="dynamic", _Package,
cascade="all, delete"), backref=backref("package_sources", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageID]) foreign_keys=[__table__.c.PackageID],
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -28,7 +24,8 @@ class PackageSource(Base):
raise IntegrityError( raise IntegrityError(
statement="Foreign key PackageID cannot be null.", statement="Foreign key PackageID cannot be null.",
orig="PackageSources.PackageID", orig="PackageSources.PackageID",
params=("NULL")) params=("NULL"),
)
if not self.Source: if not self.Source:
self.Source = "/dev/null" self.Source = "/dev/null"

View file

@ -10,18 +10,19 @@ from aurweb.models.user import User as _User
class PackageVote(Base): class PackageVote(Base):
__table__ = schema.PackageVotes __table__ = schema.PackageVotes
__tablename__ = __table__.name __tablename__ = __table__.name
__mapper_args__ = { __mapper_args__ = {"primary_key": [__table__.c.UsersID, __table__.c.PackageBaseID]}
"primary_key": [__table__.c.UsersID, __table__.c.PackageBaseID]
}
User = relationship( User = relationship(
_User, backref=backref("package_votes", lazy="dynamic"), _User,
foreign_keys=[__table__.c.UsersID]) backref=backref("package_votes", lazy="dynamic"),
foreign_keys=[__table__.c.UsersID],
)
PackageBase = relationship( PackageBase = relationship(
_PackageBase, backref=backref("package_votes", lazy="dynamic", _PackageBase,
cascade="all, delete"), backref=backref("package_votes", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID]) foreign_keys=[__table__.c.PackageBaseID],
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -30,16 +31,19 @@ class PackageVote(Base):
raise IntegrityError( raise IntegrityError(
statement="Foreign key UsersID cannot be null.", statement="Foreign key UsersID cannot be null.",
orig="PackageVotes.UsersID", orig="PackageVotes.UsersID",
params=("NULL")) params=("NULL"),
)
if not self.PackageBase and not self.PackageBaseID: if not self.PackageBase and not self.PackageBaseID:
raise IntegrityError( raise IntegrityError(
statement="Foreign key PackageBaseID cannot be null.", statement="Foreign key PackageBaseID cannot be null.",
orig="PackageVotes.PackageBaseID", orig="PackageVotes.PackageBaseID",
params=("NULL")) params=("NULL"),
)
if not self.VoteTS: if not self.VoteTS:
raise IntegrityError( raise IntegrityError(
statement="Column VoteTS cannot be null.", statement="Column VoteTS cannot be null.",
orig="PackageVotes.VoteTS", orig="PackageVotes.VoteTS",
params=("NULL")) params=("NULL"),
)

View file

@ -16,5 +16,5 @@ class RequestType(Base):
__mapper_args__ = {"primary_key": [__table__.c.ID]} __mapper_args__ = {"primary_key": [__table__.c.ID]}
def name_display(self) -> str: def name_display(self) -> str:
""" Return the Name column with its first char capitalized. """ """Return the Name column with its first char capitalized."""
return self.Name.title() return self.Name.title()

View file

@ -12,8 +12,10 @@ class Session(Base):
__mapper_args__ = {"primary_key": [__table__.c.UsersID]} __mapper_args__ = {"primary_key": [__table__.c.UsersID]}
User = relationship( User = relationship(
_User, backref=backref("session", uselist=False), _User,
foreign_keys=[__table__.c.UsersID]) backref=backref("session", uselist=False),
foreign_keys=[__table__.c.UsersID],
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -29,10 +31,13 @@ class Session(Base):
user_exists = db.query(_User).filter(_User.ID == uid).exists() user_exists = db.query(_User).filter(_User.ID == uid).exists()
if not db.query(user_exists).scalar(): if not db.query(user_exists).scalar():
raise IntegrityError( raise IntegrityError(
statement=("Foreign key UsersID cannot be null and " statement=(
"must be a valid user's ID."), "Foreign key UsersID cannot be null and "
"must be a valid user's ID."
),
orig="Sessions.UsersID", orig="Sessions.UsersID",
params=("NULL")) params=("NULL"),
)
def generate_unique_sid(): def generate_unique_sid():

View file

@ -12,16 +12,17 @@ class SSHPubKey(Base):
__mapper_args__ = {"primary_key": [__table__.c.Fingerprint]} __mapper_args__ = {"primary_key": [__table__.c.Fingerprint]}
User = relationship( User = relationship(
"User", backref=backref("ssh_pub_keys", lazy="dynamic"), "User",
foreign_keys=[__table__.c.UserID]) backref=backref("ssh_pub_keys", lazy="dynamic"),
foreign_keys=[__table__.c.UserID],
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
def get_fingerprint(pubkey: str) -> str: def get_fingerprint(pubkey: str) -> str:
proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE, proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE, stderr=PIPE)
stderr=PIPE)
out, _ = proc.communicate(pubkey.encode()) out, _ = proc.communicate(pubkey.encode())
if proc.returncode: if proc.returncode:
raise ValueError("The SSH public key is invalid.") raise ValueError("The SSH public key is invalid.")

View file

@ -16,10 +16,12 @@ class Term(Base):
raise IntegrityError( raise IntegrityError(
statement="Column Description cannot be null.", statement="Column Description cannot be null.",
orig="Terms.Description", orig="Terms.Description",
params=("NULL")) params=("NULL"),
)
if not self.URL: if not self.URL:
raise IntegrityError( raise IntegrityError(
statement="Column URL cannot be null.", statement="Column URL cannot be null.",
orig="Terms.URL", orig="Terms.URL",
params=("NULL")) params=("NULL"),
)

View file

@ -10,17 +10,19 @@ from aurweb.models.user import User as _User
class TUVote(Base): class TUVote(Base):
__table__ = schema.TU_Votes __table__ = schema.TU_Votes
__tablename__ = __table__.name __tablename__ = __table__.name
__mapper_args__ = { __mapper_args__ = {"primary_key": [__table__.c.VoteID, __table__.c.UserID]}
"primary_key": [__table__.c.VoteID, __table__.c.UserID]
}
VoteInfo = relationship( VoteInfo = relationship(
_TUVoteInfo, backref=backref("tu_votes", lazy="dynamic"), _TUVoteInfo,
foreign_keys=[__table__.c.VoteID]) backref=backref("tu_votes", lazy="dynamic"),
foreign_keys=[__table__.c.VoteID],
)
User = relationship( User = relationship(
_User, backref=backref("tu_votes", lazy="dynamic"), _User,
foreign_keys=[__table__.c.UserID]) backref=backref("tu_votes", lazy="dynamic"),
foreign_keys=[__table__.c.UserID],
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@ -29,10 +31,12 @@ class TUVote(Base):
raise IntegrityError( raise IntegrityError(
statement="Foreign key VoteID cannot be null.", statement="Foreign key VoteID cannot be null.",
orig="TU_Votes.VoteID", orig="TU_Votes.VoteID",
params=("NULL")) params=("NULL"),
)
if not self.User and not self.UserID: if not self.User and not self.UserID:
raise IntegrityError( raise IntegrityError(
statement="Foreign key UserID cannot be null.", statement="Foreign key UserID cannot be null.",
orig="TU_Votes.UserID", orig="TU_Votes.UserID",
params=("NULL")) params=("NULL"),
)

View file

@ -14,8 +14,10 @@ class TUVoteInfo(Base):
__mapper_args__ = {"primary_key": [__table__.c.ID]} __mapper_args__ = {"primary_key": [__table__.c.ID]}
Submitter = relationship( Submitter = relationship(
_User, backref=backref("tu_voteinfo_set", lazy="dynamic"), _User,
foreign_keys=[__table__.c.SubmitterID]) backref=backref("tu_voteinfo_set", lazy="dynamic"),
foreign_keys=[__table__.c.SubmitterID],
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
# Default Quorum, Yes, No and Abstain columns to 0. # Default Quorum, Yes, No and Abstain columns to 0.
@ -29,40 +31,45 @@ class TUVoteInfo(Base):
raise IntegrityError( raise IntegrityError(
statement="Column Agenda cannot be null.", statement="Column Agenda cannot be null.",
orig="TU_VoteInfo.Agenda", orig="TU_VoteInfo.Agenda",
params=("NULL")) params=("NULL"),
)
if self.User is None: if self.User is None:
raise IntegrityError( raise IntegrityError(
statement="Column User cannot be null.", statement="Column User cannot be null.",
orig="TU_VoteInfo.User", orig="TU_VoteInfo.User",
params=("NULL")) params=("NULL"),
)
if self.Submitted is None: if self.Submitted is None:
raise IntegrityError( raise IntegrityError(
statement="Column Submitted cannot be null.", statement="Column Submitted cannot be null.",
orig="TU_VoteInfo.Submitted", orig="TU_VoteInfo.Submitted",
params=("NULL")) params=("NULL"),
)
if self.End is None: if self.End is None:
raise IntegrityError( raise IntegrityError(
statement="Column End cannot be null.", statement="Column End cannot be null.",
orig="TU_VoteInfo.End", orig="TU_VoteInfo.End",
params=("NULL")) params=("NULL"),
)
if not self.Submitter: if not self.Submitter:
raise IntegrityError( raise IntegrityError(
statement="Foreign key SubmitterID cannot be null.", statement="Foreign key SubmitterID cannot be null.",
orig="TU_VoteInfo.SubmitterID", orig="TU_VoteInfo.SubmitterID",
params=("NULL")) params=("NULL"),
)
def __setattr__(self, key: str, value: typing.Any): def __setattr__(self, key: str, value: typing.Any):
""" Customize setattr to stringify any Quorum keys given. """ """Customize setattr to stringify any Quorum keys given."""
if key == "Quorum": if key == "Quorum":
value = str(value) value = str(value)
return super().__setattr__(key, value) return super().__setattr__(key, value)
def __getattribute__(self, key: str): def __getattribute__(self, key: str):
""" Customize getattr to floatify any fetched Quorum values. """ """Customize getattr to floatify any fetched Quorum values."""
attr = super().__getattribute__(key) attr = super().__getattribute__(key)
if key == "Quorum": if key == "Quorum":
return float(attr) return float(attr)

View file

@ -1,9 +1,7 @@
import hashlib import hashlib
from typing import Set from typing import Set
import bcrypt import bcrypt
from fastapi import Request from fastapi import Request
from sqlalchemy import or_ from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@ -12,7 +10,6 @@ from sqlalchemy.orm import backref, relationship
import aurweb.config import aurweb.config
import aurweb.models.account_type import aurweb.models.account_type
import aurweb.schema import aurweb.schema
from aurweb import db, logging, schema, time, util from aurweb import db, logging, schema, time, util
from aurweb.models.account_type import AccountType as _AccountType from aurweb.models.account_type import AccountType as _AccountType
from aurweb.models.ban import is_banned from aurweb.models.ban import is_banned
@ -24,7 +21,8 @@ SALT_ROUNDS_DEFAULT = 12
class User(Base): class User(Base):
""" An ORM model of a single Users record. """ """An ORM model of a single Users record."""
__table__ = schema.Users __table__ = schema.Users
__tablename__ = __table__.name __tablename__ = __table__.name
__mapper_args__ = {"primary_key": [__table__.c.ID]} __mapper_args__ = {"primary_key": [__table__.c.ID]}
@ -33,7 +31,8 @@ class User(Base):
_AccountType, _AccountType,
backref=backref("users", lazy="dynamic"), backref=backref("users", lazy="dynamic"),
foreign_keys=[__table__.c.AccountTypeID], foreign_keys=[__table__.c.AccountTypeID],
uselist=False) uselist=False,
)
# High-level variables used to track authentication (not in DB). # High-level variables used to track authentication (not in DB).
authenticated = False authenticated = False
@ -41,50 +40,50 @@ class User(Base):
# Make this static to the class just in case SQLAlchemy ever # Make this static to the class just in case SQLAlchemy ever
# does something to bypass our constructor. # does something to bypass our constructor.
salt_rounds = aurweb.config.getint("options", "salt_rounds", salt_rounds = aurweb.config.getint("options", "salt_rounds", SALT_ROUNDS_DEFAULT)
SALT_ROUNDS_DEFAULT)
def __init__(self, Passwd: str = str(), **kwargs): def __init__(self, Passwd: str = str(), **kwargs):
super().__init__(**kwargs, Passwd=str()) super().__init__(**kwargs, Passwd=str())
# Run this again in the constructor in case we rehashed config. # Run this again in the constructor in case we rehashed config.
self.salt_rounds = aurweb.config.getint("options", "salt_rounds", self.salt_rounds = aurweb.config.getint(
SALT_ROUNDS_DEFAULT) "options", "salt_rounds", SALT_ROUNDS_DEFAULT
)
if Passwd: if Passwd:
self.update_password(Passwd) self.update_password(Passwd)
def update_password(self, password): def update_password(self, password):
self.Passwd = bcrypt.hashpw( self.Passwd = bcrypt.hashpw(
password.encode(), password.encode(), bcrypt.gensalt(rounds=self.salt_rounds)
bcrypt.gensalt(rounds=self.salt_rounds)).decode() ).decode()
@staticmethod @staticmethod
def minimum_passwd_length(): def minimum_passwd_length():
return aurweb.config.getint("options", "passwd_min_len") return aurweb.config.getint("options", "passwd_min_len")
def is_authenticated(self): def is_authenticated(self):
""" Return internal authenticated state. """ """Return internal authenticated state."""
return self.authenticated return self.authenticated
def valid_password(self, password: str): def valid_password(self, password: str):
""" Check authentication against a given password. """ """Check authentication against a given password."""
if password is None: if password is None:
return False return False
password_is_valid = False password_is_valid = False
try: try:
password_is_valid = bcrypt.checkpw(password.encode(), password_is_valid = bcrypt.checkpw(password.encode(), self.Passwd.encode())
self.Passwd.encode())
except ValueError: except ValueError:
pass pass
# If our Salt column is not empty, we're using a legacy password. # If our Salt column is not empty, we're using a legacy password.
if not password_is_valid and self.Salt != str(): if not password_is_valid and self.Salt != str():
# Try to login with legacy method. # Try to login with legacy method.
password_is_valid = hashlib.md5( password_is_valid = (
f"{self.Salt}{password}".encode() hashlib.md5(f"{self.Salt}{password}".encode()).hexdigest()
).hexdigest() == self.Passwd == self.Passwd
)
# We got here, we passed the legacy authentication. # We got here, we passed the legacy authentication.
# Update the password to our modern hash style. # Update the password to our modern hash style.
@ -96,9 +95,8 @@ class User(Base):
def _login_approved(self, request: Request): def _login_approved(self, request: Request):
return not is_banned(request) and not self.Suspended return not is_banned(request) and not self.Suspended
def login(self, request: Request, password: str, def login(self, request: Request, password: str, session_time: int = 0) -> str:
session_time: int = 0) -> str: """Login and authenticate a request."""
""" Login and authenticate a request. """
from aurweb import db from aurweb import db
from aurweb.models.session import Session, generate_unique_sid from aurweb.models.session import Session, generate_unique_sid
@ -127,9 +125,9 @@ class User(Base):
self.LastLoginIPAddress = request.client.host self.LastLoginIPAddress = request.client.host
if not self.session: if not self.session:
sid = generate_unique_sid() sid = generate_unique_sid()
self.session = db.create(Session, User=self, self.session = db.create(
SessionID=sid, Session, User=self, SessionID=sid, LastUpdateTS=now_ts
LastUpdateTS=now_ts) )
else: else:
last_updated = self.session.LastUpdateTS last_updated = self.session.LastUpdateTS
if last_updated and last_updated < now_ts: if last_updated and last_updated < now_ts:
@ -148,9 +146,9 @@ class User(Base):
return self.session.SessionID return self.session.SessionID
def has_credential(self, credential: Set[int], def has_credential(self, credential: Set[int], approved: list["User"] = list()):
approved: list["User"] = list()):
from aurweb.auth.creds import has_credential from aurweb.auth.creds import has_credential
return has_credential(self, credential, approved) return has_credential(self, credential, approved)
def logout(self, request: Request): def logout(self, request: Request):
@ -162,18 +160,18 @@ class User(Base):
def is_trusted_user(self): def is_trusted_user(self):
return self.AccountType.ID in { return self.AccountType.ID in {
aurweb.models.account_type.TRUSTED_USER_ID, aurweb.models.account_type.TRUSTED_USER_ID,
aurweb.models.account_type.TRUSTED_USER_AND_DEV_ID aurweb.models.account_type.TRUSTED_USER_AND_DEV_ID,
} }
def is_developer(self): def is_developer(self):
return self.AccountType.ID in { return self.AccountType.ID in {
aurweb.models.account_type.DEVELOPER_ID, aurweb.models.account_type.DEVELOPER_ID,
aurweb.models.account_type.TRUSTED_USER_AND_DEV_ID aurweb.models.account_type.TRUSTED_USER_AND_DEV_ID,
} }
def is_elevated(self): def is_elevated(self):
""" A User is 'elevated' when they have either a """A User is 'elevated' when they have either a
Trusted User or Developer AccountType. """ Trusted User or Developer AccountType."""
return self.AccountType.ID in { return self.AccountType.ID in {
aurweb.models.account_type.TRUSTED_USER_ID, aurweb.models.account_type.TRUSTED_USER_ID,
aurweb.models.account_type.DEVELOPER_ID, aurweb.models.account_type.DEVELOPER_ID,
@ -196,18 +194,22 @@ class User(Base):
:return: Boolean indicating whether `self` can edit `target` :return: Boolean indicating whether `self` can edit `target`
""" """
from aurweb.auth import creds from aurweb.auth import creds
has_cred = self.has_credential(creds.ACCOUNT_EDIT, approved=[target]) has_cred = self.has_credential(creds.ACCOUNT_EDIT, approved=[target])
return has_cred and self.AccountTypeID >= target.AccountTypeID return has_cred and self.AccountTypeID >= target.AccountTypeID
def voted_for(self, package) -> bool: def voted_for(self, package) -> bool:
""" Has this User voted for package? """ """Has this User voted for package?"""
from aurweb.models.package_vote import PackageVote from aurweb.models.package_vote import PackageVote
return bool(package.PackageBase.package_votes.filter(
PackageVote.UsersID == self.ID return bool(
).scalar()) package.PackageBase.package_votes.filter(
PackageVote.UsersID == self.ID
).scalar()
)
def notified(self, package) -> bool: def notified(self, package) -> bool:
""" Is this User being notified about package (or package base)? """Is this User being notified about package (or package base)?
:param package: Package or PackageBase instance :param package: Package or PackageBase instance
:return: Boolean indicating state of package notification :return: Boolean indicating state of package notification
@ -225,12 +227,14 @@ class User(Base):
# Run an exists() query where a pkgbase-related # Run an exists() query where a pkgbase-related
# PackageNotification exists for self (a user). # PackageNotification exists for self (a user).
return bool(db.query( return bool(
query.filter(PackageNotification.UserID == self.ID).exists() db.query(
).scalar()) query.filter(PackageNotification.UserID == self.ID).exists()
).scalar()
)
def packages(self): def packages(self):
""" Returns an ORM query to Package objects owned by this user. """Returns an ORM query to Package objects owned by this user.
This should really be replaced with an internal ORM join This should really be replaced with an internal ORM join
configured for the User model. This has not been done yet configured for the User model. This has not been done yet
@ -241,16 +245,24 @@ class User(Base):
""" """
from aurweb.models.package import Package from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase from aurweb.models.package_base import PackageBase
return db.query(Package).join(PackageBase).filter(
or_( return (
PackageBase.PackagerUID == self.ID, db.query(Package)
PackageBase.MaintainerUID == self.ID .join(PackageBase)
.filter(
or_(
PackageBase.PackagerUID == self.ID,
PackageBase.MaintainerUID == self.ID,
)
) )
) )
def __repr__(self): def __repr__(self):
return "<User(ID='%s', AccountType='%s', Username='%s')>" % ( return "<User(ID='%s', AccountType='%s', Username='%s')>" % (
self.ID, str(self.AccountType), self.Username) self.ID,
str(self.AccountType),
self.Username,
)
def __str__(self) -> str: def __str__(self) -> str:
return self.Username return self.Username

View file

@ -7,46 +7,55 @@ from aurweb import config, db, l10n, time, util
from aurweb.exceptions import InvariantError from aurweb.exceptions import InvariantError
from aurweb.models import PackageBase, PackageRequest, User from aurweb.models import PackageBase, PackageRequest, User
from aurweb.models.package_request import ACCEPTED_ID, PENDING_ID, REJECTED_ID from aurweb.models.package_request import ACCEPTED_ID, PENDING_ID, REJECTED_ID
from aurweb.models.request_type import DELETION, DELETION_ID, MERGE, MERGE_ID, ORPHAN, ORPHAN_ID from aurweb.models.request_type import (
DELETION,
DELETION_ID,
MERGE,
MERGE_ID,
ORPHAN,
ORPHAN_ID,
)
from aurweb.scripts import notify from aurweb.scripts import notify
class ClosureFactory: class ClosureFactory:
""" A factory class used to autogenerate closure comments. """ """A factory class used to autogenerate closure comments."""
REQTYPE_NAMES = { REQTYPE_NAMES = {DELETION_ID: DELETION, MERGE_ID: MERGE, ORPHAN_ID: ORPHAN}
DELETION_ID: DELETION,
MERGE_ID: MERGE,
ORPHAN_ID: ORPHAN
}
def _deletion_closure(self, requester: User, def _deletion_closure(
pkgbase: PackageBase, self, requester: User, pkgbase: PackageBase, target: PackageBase = None
target: PackageBase = None): ):
return (f"[Autogenerated] Accepted deletion for {pkgbase.Name}.") return f"[Autogenerated] Accepted deletion for {pkgbase.Name}."
def _merge_closure(self, requester: User, def _merge_closure(
pkgbase: PackageBase, self, requester: User, pkgbase: PackageBase, target: PackageBase = None
target: PackageBase = None): ):
return (f"[Autogenerated] Accepted merge for {pkgbase.Name} " return (
f"into {target.Name}.") f"[Autogenerated] Accepted merge for {pkgbase.Name} " f"into {target.Name}."
)
def _orphan_closure(self, requester: User, def _orphan_closure(
pkgbase: PackageBase, self, requester: User, pkgbase: PackageBase, target: PackageBase = None
target: PackageBase = None): ):
return (f"[Autogenerated] Accepted orphan for {pkgbase.Name}.") return f"[Autogenerated] Accepted orphan for {pkgbase.Name}."
def _rejected_merge_closure(self, requester: User, def _rejected_merge_closure(
pkgbase: PackageBase, self, requester: User, pkgbase: PackageBase, target: PackageBase = None
target: PackageBase = None): ):
return (f"[Autogenerated] Another request to merge {pkgbase.Name} " return (
f"into {target.Name} has rendered this request invalid.") f"[Autogenerated] Another request to merge {pkgbase.Name} "
f"into {target.Name} has rendered this request invalid."
)
def get_closure(self, reqtype_id: int, def get_closure(
requester: User, self,
pkgbase: PackageBase, reqtype_id: int,
target: PackageBase = None, requester: User,
status: int = ACCEPTED_ID) -> str: pkgbase: PackageBase,
target: PackageBase = None,
status: int = ACCEPTED_ID,
) -> str:
""" """
Return a closure comment handled by this class. Return a closure comment handled by this class.
@ -69,8 +78,9 @@ class ClosureFactory:
return handler(requester, pkgbase, target) return handler(requester, pkgbase, target)
def update_closure_comment(pkgbase: PackageBase, reqtype_id: int, def update_closure_comment(
comments: str, target: PackageBase = None) -> None: pkgbase: PackageBase, reqtype_id: int, comments: str, target: PackageBase = None
) -> None:
""" """
Update all pending requests related to `pkgbase` with a closure comment. Update all pending requests related to `pkgbase` with a closure comment.
@ -90,8 +100,10 @@ def update_closure_comment(pkgbase: PackageBase, reqtype_id: int,
return return
query = pkgbase.requests.filter( query = pkgbase.requests.filter(
and_(PackageRequest.ReqTypeID == reqtype_id, and_(
PackageRequest.Status == PENDING_ID)) PackageRequest.ReqTypeID == reqtype_id, PackageRequest.Status == PENDING_ID
)
)
if reqtype_id == MERGE_ID: if reqtype_id == MERGE_ID:
query = query.filter(PackageRequest.MergeBaseName == target.Name) query = query.filter(PackageRequest.MergeBaseName == target.Name)
@ -100,9 +112,8 @@ def update_closure_comment(pkgbase: PackageBase, reqtype_id: int,
def verify_orphan_request(user: User, pkgbase: PackageBase): def verify_orphan_request(user: User, pkgbase: PackageBase):
""" Verify that an undue orphan request exists in `requests`. """ """Verify that an undue orphan request exists in `requests`."""
requests = pkgbase.requests.filter( requests = pkgbase.requests.filter(PackageRequest.ReqTypeID == ORPHAN_ID)
PackageRequest.ReqTypeID == ORPHAN_ID)
for pkgreq in requests: for pkgreq in requests:
idle_time = config.getint("options", "request_idle_time") idle_time = config.getint("options", "request_idle_time")
time_delta = time.utcnow() - pkgreq.RequestTS time_delta = time.utcnow() - pkgreq.RequestTS
@ -115,9 +126,13 @@ def verify_orphan_request(user: User, pkgbase: PackageBase):
return False return False
def close_pkgreq(pkgreq: PackageRequest, closer: User, def close_pkgreq(
pkgbase: PackageBase, target: Optional[PackageBase], pkgreq: PackageRequest,
status: int) -> None: closer: User,
pkgbase: PackageBase,
target: Optional[PackageBase],
status: int,
) -> None:
""" """
Close a package request with `pkgreq`.Status == `status`. Close a package request with `pkgreq`.Status == `status`.
@ -130,16 +145,15 @@ def close_pkgreq(pkgreq: PackageRequest, closer: User,
now = time.utcnow() now = time.utcnow()
pkgreq.Status = status pkgreq.Status = status
pkgreq.Closer = closer pkgreq.Closer = closer
pkgreq.ClosureComment = ( pkgreq.ClosureComment = pkgreq.ClosureComment or ClosureFactory().get_closure(
pkgreq.ClosureComment or ClosureFactory().get_closure( pkgreq.ReqTypeID, closer, pkgbase, target, status
pkgreq.ReqTypeID, closer, pkgbase, target, status)
) )
pkgreq.ClosedTS = now pkgreq.ClosedTS = now
def handle_request(request: Request, reqtype_id: int, def handle_request(
pkgbase: PackageBase, request: Request, reqtype_id: int, pkgbase: PackageBase, target: PackageBase = None
target: PackageBase = None) -> list[notify.Notification]: ) -> list[notify.Notification]:
""" """
Handle package requests before performing an action. Handle package requests before performing an action.
@ -165,17 +179,20 @@ def handle_request(request: Request, reqtype_id: int,
if reqtype_id == ORPHAN_ID: if reqtype_id == ORPHAN_ID:
if not verify_orphan_request(request.user, pkgbase): if not verify_orphan_request(request.user, pkgbase):
_ = l10n.get_translator_for_request(request) _ = l10n.get_translator_for_request(request)
raise InvariantError(_( raise InvariantError(
"No due existing orphan requests to accept for %s." _("No due existing orphan requests to accept for %s.") % pkgbase.Name
) % pkgbase.Name) )
# Produce a base query for requests related to `pkgbase`, based # Produce a base query for requests related to `pkgbase`, based
# on ReqTypeID matching `reqtype_id`, pending status and a correct # on ReqTypeID matching `reqtype_id`, pending status and a correct
# PackagBaseName column. # PackagBaseName column.
query: orm.Query = pkgbase.requests.filter( query: orm.Query = pkgbase.requests.filter(
and_(PackageRequest.ReqTypeID == reqtype_id, and_(
PackageRequest.Status == PENDING_ID, PackageRequest.ReqTypeID == reqtype_id,
PackageRequest.PackageBaseName == pkgbase.Name)) PackageRequest.Status == PENDING_ID,
PackageRequest.PackageBaseName == pkgbase.Name,
)
)
# Build a query for records we should accept. For merge requests, # Build a query for records we should accept. For merge requests,
# this is specific to a matching MergeBaseName. For others, this # this is specific to a matching MergeBaseName. For others, this
@ -183,8 +200,7 @@ def handle_request(request: Request, reqtype_id: int,
accept_query: orm.Query = query accept_query: orm.Query = query
if target: if target:
# If a `target` was supplied, filter by MergeBaseName # If a `target` was supplied, filter by MergeBaseName
accept_query = query.filter( accept_query = query.filter(PackageRequest.MergeBaseName == target.Name)
PackageRequest.MergeBaseName == target.Name)
# Build an accept list out of `accept_query`. # Build an accept list out of `accept_query`.
to_accept: list[PackageRequest] = accept_query.all() to_accept: list[PackageRequest] = accept_query.all()
@ -203,14 +219,16 @@ def handle_request(request: Request, reqtype_id: int,
if not to_accept: if not to_accept:
utcnow = time.utcnow() utcnow = time.utcnow()
with db.begin(): with db.begin():
pkgreq = db.create(PackageRequest, pkgreq = db.create(
ReqTypeID=reqtype_id, PackageRequest,
RequestTS=utcnow, ReqTypeID=reqtype_id,
User=request.user, RequestTS=utcnow,
PackageBase=pkgbase, User=request.user,
PackageBaseName=pkgbase.Name, PackageBase=pkgbase,
Comments="Autogenerated by aurweb.", PackageBaseName=pkgbase.Name,
ClosureComment=str()) Comments="Autogenerated by aurweb.",
ClosureComment=str(),
)
# If it's a merge request, set MergeBaseName to `target`.Name. # If it's a merge request, set MergeBaseName to `target`.Name.
if pkgreq.ReqTypeID == MERGE_ID: if pkgreq.ReqTypeID == MERGE_ID:
@ -222,15 +240,20 @@ def handle_request(request: Request, reqtype_id: int,
# Update requests with their new status and closures. # Update requests with their new status and closures.
with db.begin(): with db.begin():
util.apply_all(to_accept, lambda p: close_pkgreq( util.apply_all(
p, request.user, pkgbase, target, ACCEPTED_ID)) to_accept,
util.apply_all(to_reject, lambda p: close_pkgreq( lambda p: close_pkgreq(p, request.user, pkgbase, target, ACCEPTED_ID),
p, request.user, pkgbase, target, REJECTED_ID)) )
util.apply_all(
to_reject,
lambda p: close_pkgreq(p, request.user, pkgbase, target, REJECTED_ID),
)
# Create RequestCloseNotifications for all requests involved. # Create RequestCloseNotifications for all requests involved.
for pkgreq in (to_accept + to_reject): for pkgreq in to_accept + to_reject:
notif = notify.RequestCloseNotification( notif = notify.RequestCloseNotification(
request.user.ID, pkgreq.ID, pkgreq.status_display()) request.user.ID, pkgreq.ID, pkgreq.status_display()
)
notifs.append(notif) notifs.append(notif)
# Return notifications to the caller for sending. # Return notifications to the caller for sending.

View file

@ -4,7 +4,12 @@ from sqlalchemy import and_, case, or_, orm
from aurweb import db, models from aurweb import db, models
from aurweb.models import Package, PackageBase, User from aurweb.models import Package, PackageBase, User
from aurweb.models.dependency_type import CHECKDEPENDS_ID, DEPENDS_ID, MAKEDEPENDS_ID, OPTDEPENDS_ID from aurweb.models.dependency_type import (
CHECKDEPENDS_ID,
DEPENDS_ID,
MAKEDEPENDS_ID,
OPTDEPENDS_ID,
)
from aurweb.models.package_comaintainer import PackageComaintainer from aurweb.models.package_comaintainer import PackageComaintainer
from aurweb.models.package_keyword import PackageKeyword from aurweb.models.package_keyword import PackageKeyword
from aurweb.models.package_notification import PackageNotification from aurweb.models.package_notification import PackageNotification

View file

@ -3,7 +3,6 @@ from http import HTTPStatus
from typing import Tuple, Union from typing import Tuple, Union
import orjson import orjson
from fastapi import HTTPException from fastapi import HTTPException
from sqlalchemy import orm from sqlalchemy import orm
@ -43,10 +42,10 @@ def dep_optdepends_extra(dep: models.PackageDependency) -> str:
@register_filter("dep_extra") @register_filter("dep_extra")
def dep_extra(dep: models.PackageDependency) -> str: def dep_extra(dep: models.PackageDependency) -> str:
""" Some dependency types have extra text added to their """Some dependency types have extra text added to their
display. This function provides that output. However, it display. This function provides that output. However, it
**assumes** that the dep passed is bound to a valid one **assumes** that the dep passed is bound to a valid one
of: depends, makedepends, checkdepends or optdepends. """ of: depends, makedepends, checkdepends or optdepends."""
f = globals().get(f"dep_{dep.DependencyType.Name}_extra") f = globals().get(f"dep_{dep.DependencyType.Name}_extra")
return f(dep) return f(dep)
@ -61,13 +60,13 @@ def dep_extra_desc(dep: models.PackageDependency) -> str:
@register_filter("pkgname_link") @register_filter("pkgname_link")
def pkgname_link(pkgname: str) -> str: def pkgname_link(pkgname: str) -> str:
record = db.query(Package).filter( record = db.query(Package).filter(Package.Name == pkgname).exists()
Package.Name == pkgname).exists()
if db.query(record).scalar(): if db.query(record).scalar():
return f"/packages/{pkgname}" return f"/packages/{pkgname}"
official = db.query(OfficialProvider).filter( official = (
OfficialProvider.Name == pkgname).exists() db.query(OfficialProvider).filter(OfficialProvider.Name == pkgname).exists()
)
if db.query(official).scalar(): if db.query(official).scalar():
base = "/".join([OFFICIAL_BASE, "packages"]) base = "/".join([OFFICIAL_BASE, "packages"])
return f"{base}/?q={pkgname}" return f"{base}/?q={pkgname}"
@ -83,17 +82,15 @@ def package_link(package: Union[Package, OfficialProvider]) -> str:
@register_filter("provides_markup") @register_filter("provides_markup")
def provides_markup(provides: Providers) -> str: def provides_markup(provides: Providers) -> str:
return ", ".join([ return ", ".join(
f'<a href="{package_link(pkg)}">{pkg.Name}</a>' [f'<a href="{package_link(pkg)}">{pkg.Name}</a>' for pkg in provides]
for pkg in provides )
])
def get_pkg_or_base( def get_pkg_or_base(
name: str, name: str, cls: Union[models.Package, models.PackageBase] = models.PackageBase
cls: Union[models.Package, models.PackageBase] = models.PackageBase) \ ) -> Union[models.Package, models.PackageBase]:
-> Union[models.Package, models.PackageBase]: """Get a PackageBase instance by its name or raise a 404 if
""" Get a PackageBase instance by its name or raise a 404 if
it can't be found in the database. it can't be found in the database.
:param name: {Package,PackageBase}.Name :param name: {Package,PackageBase}.Name
@ -109,8 +106,7 @@ def get_pkg_or_base(
return instance return instance
def get_pkgbase_comment(pkgbase: models.PackageBase, id: int) \ def get_pkgbase_comment(pkgbase: models.PackageBase, id: int) -> models.PackageComment:
-> models.PackageComment:
comment = pkgbase.comments.filter(models.PackageComment.ID == id).first() comment = pkgbase.comments.filter(models.PackageComment.ID == id).first()
if not comment: if not comment:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
@ -122,9 +118,8 @@ def out_of_date(packages: orm.Query) -> orm.Query:
return packages.filter(models.PackageBase.OutOfDateTS.isnot(None)) return packages.filter(models.PackageBase.OutOfDateTS.isnot(None))
def updated_packages(limit: int = 0, def updated_packages(limit: int = 0, cache_ttl: int = 600) -> list[models.Package]:
cache_ttl: int = 600) -> list[models.Package]: """Return a list of valid Package objects ordered by their
""" Return a list of valid Package objects ordered by their
ModifiedTS column in descending order from cache, after setting ModifiedTS column in descending order from cache, after setting
the cache when no key yet exists. the cache when no key yet exists.
@ -139,10 +134,11 @@ def updated_packages(limit: int = 0,
return orjson.loads(packages) return orjson.loads(packages)
with db.begin(): with db.begin():
query = db.query(models.Package).join(models.PackageBase).filter( query = (
models.PackageBase.PackagerUID.isnot(None) db.query(models.Package)
).order_by( .join(models.PackageBase)
models.PackageBase.ModifiedTS.desc() .filter(models.PackageBase.PackagerUID.isnot(None))
.order_by(models.PackageBase.ModifiedTS.desc())
) )
if limit: if limit:
@ -152,13 +148,13 @@ def updated_packages(limit: int = 0,
for pkg in query: for pkg in query:
# For each Package returned by the query, append a dict # For each Package returned by the query, append a dict
# containing Package columns we're interested in. # containing Package columns we're interested in.
packages.append({ packages.append(
"Name": pkg.Name, {
"Version": pkg.Version, "Name": pkg.Name,
"PackageBase": { "Version": pkg.Version,
"ModifiedTS": pkg.PackageBase.ModifiedTS "PackageBase": {"ModifiedTS": pkg.PackageBase.ModifiedTS},
} }
}) )
# Store the JSON serialization of the package_updates key into Redis. # Store the JSON serialization of the package_updates key into Redis.
redis.set("package_updates", orjson.dumps(packages)) redis.set("package_updates", orjson.dumps(packages))
@ -168,9 +164,8 @@ def updated_packages(limit: int = 0,
return packages return packages
def query_voted(query: list[models.Package], def query_voted(query: list[models.Package], user: models.User) -> dict[int, bool]:
user: models.User) -> dict[int, bool]: """Produce a dictionary of package base ID keys to boolean values,
""" Produce a dictionary of package base ID keys to boolean values,
which indicate whether or not the package base has a vote record which indicate whether or not the package base has a vote record
related to user. related to user.
@ -180,20 +175,18 @@ def query_voted(query: list[models.Package],
""" """
output = defaultdict(bool) output = defaultdict(bool)
query_set = {pkg.PackageBaseID for pkg in query} query_set = {pkg.PackageBaseID for pkg in query}
voted = db.query(models.PackageVote).join( voted = (
models.PackageBase, db.query(models.PackageVote)
models.PackageBase.ID.in_(query_set) .join(models.PackageBase, models.PackageBase.ID.in_(query_set))
).filter( .filter(models.PackageVote.UsersID == user.ID)
models.PackageVote.UsersID == user.ID
) )
for vote in voted: for vote in voted:
output[vote.PackageBase.ID] = True output[vote.PackageBase.ID] = True
return output return output
def query_notified(query: list[models.Package], def query_notified(query: list[models.Package], user: models.User) -> dict[int, bool]:
user: models.User) -> dict[int, bool]: """Produce a dictionary of package base ID keys to boolean values,
""" Produce a dictionary of package base ID keys to boolean values,
which indicate whether or not the package base has a notification which indicate whether or not the package base has a notification
record related to user. record related to user.
@ -203,19 +196,17 @@ def query_notified(query: list[models.Package],
""" """
output = defaultdict(bool) output = defaultdict(bool)
query_set = {pkg.PackageBaseID for pkg in query} query_set = {pkg.PackageBaseID for pkg in query}
notified = db.query(models.PackageNotification).join( notified = (
models.PackageBase, db.query(models.PackageNotification)
models.PackageBase.ID.in_(query_set) .join(models.PackageBase, models.PackageBase.ID.in_(query_set))
).filter( .filter(models.PackageNotification.UserID == user.ID)
models.PackageNotification.UserID == user.ID
) )
for notif in notified: for notif in notified:
output[notif.PackageBase.ID] = True output[notif.PackageBase.ID] = True
return output return output
def pkg_required(pkgname: str, provides: list[str]) \ def pkg_required(pkgname: str, provides: list[str]) -> list[PackageDependency]:
-> list[PackageDependency]:
""" """
Get dependencies that match a string in `[pkgname] + provides`. Get dependencies that match a string in `[pkgname] + provides`.
@ -225,9 +216,12 @@ def pkg_required(pkgname: str, provides: list[str]) \
:return: List of PackageDependency instances :return: List of PackageDependency instances
""" """
targets = set([pkgname] + provides) targets = set([pkgname] + provides)
query = db.query(PackageDependency).join(Package).filter( query = (
PackageDependency.DepName.in_(targets) db.query(PackageDependency)
).order_by(Package.Name.asc()) .join(Package)
.filter(PackageDependency.DepName.in_(targets))
.order_by(Package.Name.asc())
)
return query return query

View file

@ -14,15 +14,15 @@ logger = logging.get_logger(__name__)
def pkgbase_notify_instance(request: Request, pkgbase: PackageBase) -> None: def pkgbase_notify_instance(request: Request, pkgbase: PackageBase) -> None:
notif = db.query(pkgbase.notifications.filter( notif = db.query(
PackageNotification.UserID == request.user.ID pkgbase.notifications.filter(
).exists()).scalar() PackageNotification.UserID == request.user.ID
).exists()
).scalar()
has_cred = request.user.has_credential(creds.PKGBASE_NOTIFY) has_cred = request.user.has_credential(creds.PKGBASE_NOTIFY)
if has_cred and not notif: if has_cred and not notif:
with db.begin(): with db.begin():
db.create(PackageNotification, db.create(PackageNotification, PackageBase=pkgbase, User=request.user)
PackageBase=pkgbase,
User=request.user)
def pkgbase_unnotify_instance(request: Request, pkgbase: PackageBase) -> None: def pkgbase_unnotify_instance(request: Request, pkgbase: PackageBase) -> None:
@ -36,8 +36,11 @@ def pkgbase_unnotify_instance(request: Request, pkgbase: PackageBase) -> None:
def pkgbase_unflag_instance(request: Request, pkgbase: PackageBase) -> None: def pkgbase_unflag_instance(request: Request, pkgbase: PackageBase) -> None:
has_cred = request.user.has_credential(creds.PKGBASE_UNFLAG, approved=[ has_cred = request.user.has_credential(
pkgbase.Flagger, pkgbase.Maintainer] + [c.User for c in pkgbase.comaintainers]) creds.PKGBASE_UNFLAG,
approved=[pkgbase.Flagger, pkgbase.Maintainer]
+ [c.User for c in pkgbase.comaintainers],
)
if has_cred: if has_cred:
with db.begin(): with db.begin():
pkgbase.OutOfDateTS = None pkgbase.OutOfDateTS = None
@ -93,9 +96,9 @@ def pkgbase_adopt_instance(request: Request, pkgbase: PackageBase) -> None:
notif.send() notif.send()
def pkgbase_delete_instance(request: Request, pkgbase: PackageBase, def pkgbase_delete_instance(
comments: str = str()) \ request: Request, pkgbase: PackageBase, comments: str = str()
-> list[notify.Notification]: ) -> list[notify.Notification]:
notifs = handle_request(request, DELETION_ID, pkgbase) + [ notifs = handle_request(request, DELETION_ID, pkgbase) + [
notify.DeleteNotification(request.user.ID, pkgbase.ID) notify.DeleteNotification(request.user.ID, pkgbase.ID)
] ]
@ -107,8 +110,9 @@ def pkgbase_delete_instance(request: Request, pkgbase: PackageBase,
return notifs return notifs
def pkgbase_merge_instance(request: Request, pkgbase: PackageBase, def pkgbase_merge_instance(
target: PackageBase, comments: str = str()) -> None: request: Request, pkgbase: PackageBase, target: PackageBase, comments: str = str()
) -> None:
pkgbasename = str(pkgbase.Name) pkgbasename = str(pkgbase.Name)
# Create notifications. # Create notifications.
@ -144,8 +148,10 @@ def pkgbase_merge_instance(request: Request, pkgbase: PackageBase,
db.delete(pkgbase) db.delete(pkgbase)
# Log this out for accountability purposes. # Log this out for accountability purposes.
logger.info(f"Trusted User '{request.user.Username}' merged " logger.info(
f"'{pkgbasename}' into '{target.Name}'.") f"Trusted User '{request.user.Username}' merged "
f"'{pkgbasename}' into '{target.Name}'."
)
# Send notifications. # Send notifications.
util.apply_all(notifs, lambda n: n.send()) util.apply_all(notifs, lambda n: n.send())

View file

@ -10,19 +10,23 @@ from aurweb.models.package_comment import PackageComment
from aurweb.models.package_request import PENDING_ID, PackageRequest from aurweb.models.package_request import PENDING_ID, PackageRequest
from aurweb.models.package_vote import PackageVote from aurweb.models.package_vote import PackageVote
from aurweb.scripts import notify from aurweb.scripts import notify
from aurweb.templates import make_context as _make_context from aurweb.templates import (
from aurweb.templates import make_variable_context as _make_variable_context make_context as _make_context,
make_variable_context as _make_variable_context,
)
async def make_variable_context(request: Request, pkgbase: PackageBase) \ async def make_variable_context(
-> dict[str, Any]: request: Request, pkgbase: PackageBase
) -> dict[str, Any]:
ctx = await _make_variable_context(request, pkgbase.Name) ctx = await _make_variable_context(request, pkgbase.Name)
return make_context(request, pkgbase, ctx) return make_context(request, pkgbase, ctx)
def make_context(request: Request, pkgbase: PackageBase, def make_context(
context: dict[str, Any] = None) -> dict[str, Any]: request: Request, pkgbase: PackageBase, context: dict[str, Any] = None
""" Make a basic context for package or pkgbase. ) -> dict[str, Any]:
"""Make a basic context for package or pkgbase.
:param request: FastAPI request :param request: FastAPI request
:param pkgbase: PackageBase instance :param pkgbase: PackageBase instance
@ -34,14 +38,16 @@ def make_context(request: Request, pkgbase: PackageBase,
# Per page and offset. # Per page and offset.
offset, per_page = util.sanitize_params( offset, per_page = util.sanitize_params(
request.query_params.get("O", defaults.O), request.query_params.get("O", defaults.O),
request.query_params.get("PP", defaults.COMMENTS_PER_PAGE)) request.query_params.get("PP", defaults.COMMENTS_PER_PAGE),
)
context["O"] = offset context["O"] = offset
context["PP"] = per_page context["PP"] = per_page
context["git_clone_uri_anon"] = config.get("options", "git_clone_uri_anon") context["git_clone_uri_anon"] = config.get("options", "git_clone_uri_anon")
context["git_clone_uri_priv"] = config.get("options", "git_clone_uri_priv") context["git_clone_uri_priv"] = config.get("options", "git_clone_uri_priv")
context["pkgbase"] = pkgbase context["pkgbase"] = pkgbase
context["comaintainers"] = [ context["comaintainers"] = [
c.User for c in pkgbase.comaintainers.order_by( c.User
for c in pkgbase.comaintainers.order_by(
PackageComaintainer.Priority.asc() PackageComaintainer.Priority.asc()
).all() ).all()
] ]
@ -53,9 +59,11 @@ def make_context(request: Request, pkgbase: PackageBase,
context["comments_total"] = pkgbase.comments.order_by( context["comments_total"] = pkgbase.comments.order_by(
PackageComment.CommentTS.desc() PackageComment.CommentTS.desc()
).count() ).count()
context["comments"] = pkgbase.comments.order_by( context["comments"] = (
PackageComment.CommentTS.desc() pkgbase.comments.order_by(PackageComment.CommentTS.desc())
).limit(per_page).offset(offset) .limit(per_page)
.offset(offset)
)
context["pinned_comments"] = pkgbase.comments.filter( context["pinned_comments"] = pkgbase.comments.filter(
PackageComment.PinnedTS != 0 PackageComment.PinnedTS != 0
).order_by(PackageComment.CommentTS.desc()) ).order_by(PackageComment.CommentTS.desc())
@ -70,15 +78,15 @@ def make_context(request: Request, pkgbase: PackageBase,
).scalar() ).scalar()
context["requests"] = pkgbase.requests.filter( context["requests"] = pkgbase.requests.filter(
and_(PackageRequest.Status == PENDING_ID, and_(PackageRequest.Status == PENDING_ID, PackageRequest.ClosedTS.is_(None))
PackageRequest.ClosedTS.is_(None))
).count() ).count()
return context return context
def remove_comaintainer(comaint: PackageComaintainer) \ def remove_comaintainer(
-> notify.ComaintainerRemoveNotification: comaint: PackageComaintainer,
) -> notify.ComaintainerRemoveNotification:
""" """
Remove a PackageComaintainer. Remove a PackageComaintainer.
@ -107,9 +115,9 @@ def remove_comaintainers(pkgbase: PackageBase, usernames: list[str]) -> None:
""" """
notifications = [] notifications = []
with db.begin(): with db.begin():
comaintainers = pkgbase.comaintainers.join(User).filter( comaintainers = (
User.Username.in_(usernames) pkgbase.comaintainers.join(User).filter(User.Username.in_(usernames)).all()
).all() )
notifications = [ notifications = [
notify.ComaintainerRemoveNotification(co.User.ID, pkgbase.ID) notify.ComaintainerRemoveNotification(co.User.ID, pkgbase.ID)
for co in comaintainers for co in comaintainers
@ -133,23 +141,23 @@ def latest_priority(pkgbase: PackageBase) -> int:
""" """
# Order comaintainers related to pkgbase by Priority DESC. # Order comaintainers related to pkgbase by Priority DESC.
record = pkgbase.comaintainers.order_by( record = pkgbase.comaintainers.order_by(PackageComaintainer.Priority.desc()).first()
PackageComaintainer.Priority.desc()).first()
# Use Priority column if record exists, otherwise 0. # Use Priority column if record exists, otherwise 0.
return record.Priority if record else 0 return record.Priority if record else 0
class NoopComaintainerNotification: class NoopComaintainerNotification:
""" A noop notification stub used as an error-state return value. """ """A noop notification stub used as an error-state return value."""
def send(self) -> None: def send(self) -> None:
""" noop """ """noop"""
return return
def add_comaintainer(pkgbase: PackageBase, comaintainer: User) \ def add_comaintainer(
-> notify.ComaintainerAddNotification: pkgbase: PackageBase, comaintainer: User
) -> notify.ComaintainerAddNotification:
""" """
Add a new comaintainer to `pkgbase`. Add a new comaintainer to `pkgbase`.
@ -165,14 +173,19 @@ def add_comaintainer(pkgbase: PackageBase, comaintainer: User) \
new_prio = latest_priority(pkgbase) + 1 new_prio = latest_priority(pkgbase) + 1
with db.begin(): with db.begin():
db.create(PackageComaintainer, PackageBase=pkgbase, db.create(
User=comaintainer, Priority=new_prio) PackageComaintainer,
PackageBase=pkgbase,
User=comaintainer,
Priority=new_prio,
)
return notify.ComaintainerAddNotification(comaintainer.ID, pkgbase.ID) return notify.ComaintainerAddNotification(comaintainer.ID, pkgbase.ID)
def add_comaintainers(request: Request, pkgbase: PackageBase, def add_comaintainers(
usernames: list[str]) -> None: request: Request, pkgbase: PackageBase, usernames: list[str]
) -> None:
""" """
Add comaintainers to `pkgbase`. Add comaintainers to `pkgbase`.
@ -216,7 +229,6 @@ def rotate_comaintainers(pkgbase: PackageBase) -> None:
:param pkgbase: PackageBase instance :param pkgbase: PackageBase instance
""" """
comaintainers = pkgbase.comaintainers.order_by( comaintainers = pkgbase.comaintainers.order_by(PackageComaintainer.Priority.asc())
PackageComaintainer.Priority.asc())
for i, comaint in enumerate(comaintainers): for i, comaint in enumerate(comaintainers):
comaint.Priority = i + 1 comaint.Priority = i + 1

View file

@ -5,9 +5,13 @@ from aurweb.exceptions import ValidationError
from aurweb.models import PackageBase from aurweb.models import PackageBase
def request(pkgbase: PackageBase, def request(
type: str, comments: str, merge_into: str, pkgbase: PackageBase,
context: dict[str, Any]) -> None: type: str,
comments: str,
merge_into: str,
context: dict[str, Any],
) -> None:
if not comments: if not comments:
raise ValidationError(["The comment field must not be empty."]) raise ValidationError(["The comment field must not be empty."])
@ -15,21 +19,16 @@ def request(pkgbase: PackageBase,
# Perform merge-related checks. # Perform merge-related checks.
if not merge_into: if not merge_into:
# TODO: This error needs to be translated. # TODO: This error needs to be translated.
raise ValidationError( raise ValidationError(['The "Merge into" field must not be empty.'])
['The "Merge into" field must not be empty.'])
target = db.query(PackageBase).filter( target = db.query(PackageBase).filter(PackageBase.Name == merge_into).first()
PackageBase.Name == merge_into
).first()
if not target: if not target:
# TODO: This error needs to be translated. # TODO: This error needs to be translated.
raise ValidationError([ raise ValidationError(
"The package base you want to merge into does not exist." ["The package base you want to merge into does not exist."]
]) )
db.refresh(target) db.refresh(target)
if target.ID == pkgbase.ID: if target.ID == pkgbase.ID:
# TODO: This error needs to be translated. # TODO: This error needs to be translated.
raise ValidationError([ raise ValidationError(["You cannot merge a package base into itself."])
"You cannot merge a package base into itself."
])

View file

@ -19,8 +19,9 @@ def instrumentator():
# Their license is included in LICENSES/starlette_exporter. # Their license is included in LICENSES/starlette_exporter.
# The code has been modified to remove child route checks # The code has been modified to remove child route checks
# (since we don't have any) and to stay within an 80-width limit. # (since we don't have any) and to stay within an 80-width limit.
def get_matching_route_path(scope: dict[Any, Any], routes: list[Route], def get_matching_route_path(
route_name: Optional[str] = None) -> str: scope: dict[Any, Any], routes: list[Route], route_name: Optional[str] = None
) -> str:
""" """
Find a matching route and return its original path string Find a matching route and return its original path string
@ -34,7 +35,7 @@ def get_matching_route_path(scope: dict[Any, Any], routes: list[Route],
if match == Match.FULL: if match == Match.FULL:
route_name = route.path route_name = route.path
''' """
# This path exists in the original function's code, but we # This path exists in the original function's code, but we
# don't need it (currently), so it's been removed to avoid # don't need it (currently), so it's been removed to avoid
# useless test coverage. # useless test coverage.
@ -47,7 +48,7 @@ def get_matching_route_path(scope: dict[Any, Any], routes: list[Route],
route_name = None route_name = None
else: else:
route_name += child_route_name route_name += child_route_name
''' """
return route_name return route_name
elif match == Match.PARTIAL and route_name is None: elif match == Match.PARTIAL and route_name is None:
@ -55,9 +56,11 @@ def get_matching_route_path(scope: dict[Any, Any], routes: list[Route],
def http_requests_total() -> Callable[[Info], None]: def http_requests_total() -> Callable[[Info], None]:
metric = Counter("http_requests_total", metric = Counter(
"Number of HTTP requests.", "http_requests_total",
labelnames=("method", "path", "status")) "Number of HTTP requests.",
labelnames=("method", "path", "status"),
)
def instrumentation(info: Info) -> None: def instrumentation(info: Info) -> None:
if info.request.method.lower() in ("head", "options"): # pragma: no cover if info.request.method.lower() in ("head", "options"): # pragma: no cover
@ -79,13 +82,13 @@ def http_requests_total() -> Callable[[Info], None]:
if hasattr(app, "root_path"): if hasattr(app, "root_path"):
app_root_path = getattr(app, "root_path") app_root_path = getattr(app, "root_path")
if root_path.startswith(app_root_path): if root_path.startswith(app_root_path):
root_path = root_path[len(app_root_path):] root_path = root_path[len(app_root_path) :]
base_scope = { base_scope = {
"type": scope.get("type"), "type": scope.get("type"),
"path": root_path + scope.get("path"), "path": root_path + scope.get("path"),
"path_params": scope.get("path_params", {}), "path_params": scope.get("path_params", {}),
"method": scope.get("method") "method": scope.get("method"),
} }
method = scope.get("method") method = scope.get("method")
@ -102,7 +105,8 @@ def http_api_requests_total() -> Callable[[Info], None]:
metric = Counter( metric = Counter(
"http_api_requests", "http_api_requests",
"Number of times an RPC API type has been requested.", "Number of times an RPC API type has been requested.",
labelnames=("type", "status")) labelnames=("type", "status"),
)
def instrumentation(info: Info) -> None: def instrumentation(info: Info) -> None:
if info.request.method.lower() in ("head", "options"): # pragma: no cover if info.request.method.lower() in ("head", "options"): # pragma: no cover

View file

@ -38,8 +38,7 @@ def _update_ratelimit_db(request: Request):
now = time.utcnow() now = time.utcnow()
time_to_delete = now - window_length time_to_delete = now - window_length
records = db.query(ApiRateLimit).filter( records = db.query(ApiRateLimit).filter(ApiRateLimit.WindowStart < time_to_delete)
ApiRateLimit.WindowStart < time_to_delete)
with db.begin(): with db.begin():
db.delete_all(records) db.delete_all(records)
@ -47,9 +46,7 @@ def _update_ratelimit_db(request: Request):
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first() record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
with db.begin(): with db.begin():
if not record: if not record:
record = db.create(ApiRateLimit, record = db.create(ApiRateLimit, WindowStart=now, IP=host, Requests=1)
WindowStart=now,
IP=host, Requests=1)
else: else:
record.Requests += 1 record.Requests += 1
@ -58,7 +55,7 @@ def _update_ratelimit_db(request: Request):
def update_ratelimit(request: Request, pipeline: Pipeline): def update_ratelimit(request: Request, pipeline: Pipeline):
""" Update the ratelimit stored in Redis or the database depending """Update the ratelimit stored in Redis or the database depending
on AUR_CONFIG's [options] cache setting. on AUR_CONFIG's [options] cache setting.
This Redis-capable function is slightly different than most. If Redis This Redis-capable function is slightly different than most. If Redis
@ -75,7 +72,7 @@ def update_ratelimit(request: Request, pipeline: Pipeline):
def check_ratelimit(request: Request): def check_ratelimit(request: Request):
""" Increment and check to see if request has exceeded their rate limit. """Increment and check to see if request has exceeded their rate limit.
:param request: FastAPI request :param request: FastAPI request
:returns: True if the request host has exceeded the rate limit else False :returns: True if the request host has exceeded the rate limit else False

View file

@ -1,9 +1,7 @@
import fakeredis import fakeredis
from redis import ConnectionPool, Redis from redis import ConnectionPool, Redis
import aurweb.config import aurweb.config
from aurweb import logging from aurweb import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -11,7 +9,7 @@ pool = None
class FakeConnectionPool: class FakeConnectionPool:
""" A fake ConnectionPool class which holds an internal reference """A fake ConnectionPool class which holds an internal reference
to a fakeredis handle. to a fakeredis handle.
We normally deal with Redis by keeping its ConnectionPool globally We normally deal with Redis by keeping its ConnectionPool globally

View file

@ -3,7 +3,18 @@ API routers for FastAPI.
See https://fastapi.tiangolo.com/tutorial/bigger-applications/ See https://fastapi.tiangolo.com/tutorial/bigger-applications/
""" """
from . import accounts, auth, html, packages, pkgbase, requests, rpc, rss, sso, trusted_user from . import (
accounts,
auth,
html,
packages,
pkgbase,
requests,
rpc,
rss,
sso,
trusted_user,
)
""" """
aurweb application routes. This constant can be any iterable aurweb application routes. This constant can be any iterable

View file

@ -1,6 +1,5 @@
import copy import copy
import typing import typing
from http import HTTPStatus from http import HTTPStatus
from typing import Any from typing import Any
@ -9,7 +8,6 @@ from fastapi.responses import HTMLResponse, RedirectResponse
from sqlalchemy import and_, or_ from sqlalchemy import and_, or_
import aurweb.config import aurweb.config
from aurweb import cookies, db, l10n, logging, models, util from aurweb import cookies, db, l10n, logging, models, util
from aurweb.auth import account_type_required, requires_auth, requires_guest from aurweb.auth import account_type_required, requires_auth, requires_guest
from aurweb.captcha import get_captcha_salts from aurweb.captcha import get_captcha_salts
@ -37,21 +35,23 @@ async def passreset(request: Request):
@router.post("/passreset", response_class=HTMLResponse) @router.post("/passreset", response_class=HTMLResponse)
@handle_form_exceptions @handle_form_exceptions
@requires_guest @requires_guest
async def passreset_post(request: Request, async def passreset_post(
user: str = Form(...), request: Request,
resetkey: str = Form(default=None), user: str = Form(...),
password: str = Form(default=None), resetkey: str = Form(default=None),
confirm: str = Form(default=None)): password: str = Form(default=None),
confirm: str = Form(default=None),
):
context = await make_variable_context(request, "Password Reset") context = await make_variable_context(request, "Password Reset")
# The user parameter being required, we can match against # The user parameter being required, we can match against
criteria = or_(models.User.Username == user, models.User.Email == user) criteria = or_(models.User.Username == user, models.User.Email == user)
db_user = db.query(models.User, db_user = db.query(models.User, and_(criteria, models.User.Suspended == 0)).first()
and_(criteria, models.User.Suspended == 0)).first()
if db_user is None: if db_user is None:
context["errors"] = ["Invalid e-mail."] context["errors"] = ["Invalid e-mail."]
return render_template(request, "passreset.html", context, return render_template(
status_code=HTTPStatus.NOT_FOUND) request, "passreset.html", context, status_code=HTTPStatus.NOT_FOUND
)
db.refresh(db_user) db.refresh(db_user)
if resetkey: if resetkey:
@ -59,29 +59,34 @@ async def passreset_post(request: Request,
if not db_user.ResetKey or resetkey != db_user.ResetKey: if not db_user.ResetKey or resetkey != db_user.ResetKey:
context["errors"] = ["Invalid e-mail."] context["errors"] = ["Invalid e-mail."]
return render_template(request, "passreset.html", context, return render_template(
status_code=HTTPStatus.NOT_FOUND) request, "passreset.html", context, status_code=HTTPStatus.NOT_FOUND
)
if not user or not password: if not user or not password:
context["errors"] = ["Missing a required field."] context["errors"] = ["Missing a required field."]
return render_template(request, "passreset.html", context, return render_template(
status_code=HTTPStatus.BAD_REQUEST) request, "passreset.html", context, status_code=HTTPStatus.BAD_REQUEST
)
if password != confirm: if password != confirm:
# If the provided password does not match the provided confirm. # If the provided password does not match the provided confirm.
context["errors"] = ["Password fields do not match."] context["errors"] = ["Password fields do not match."]
return render_template(request, "passreset.html", context, return render_template(
status_code=HTTPStatus.BAD_REQUEST) request, "passreset.html", context, status_code=HTTPStatus.BAD_REQUEST
)
if len(password) < models.User.minimum_passwd_length(): if len(password) < models.User.minimum_passwd_length():
# Translate the error here, which simplifies error output # Translate the error here, which simplifies error output
# in the jinja2 template. # in the jinja2 template.
_ = get_translator_for_request(request) _ = get_translator_for_request(request)
context["errors"] = [_( context["errors"] = [
"Your password must be at least %s characters.") % ( _("Your password must be at least %s characters.")
str(models.User.minimum_passwd_length()))] % (str(models.User.minimum_passwd_length()))
return render_template(request, "passreset.html", context, ]
status_code=HTTPStatus.BAD_REQUEST) return render_template(
request, "passreset.html", context, status_code=HTTPStatus.BAD_REQUEST
)
# We got to this point; everything matched up. Update the password # We got to this point; everything matched up. Update the password
# and remove the ResetKey. # and remove the ResetKey.
@ -92,8 +97,9 @@ async def passreset_post(request: Request,
db_user.update_password(password) db_user.update_password(password)
# Render ?step=complete. # Render ?step=complete.
return RedirectResponse(url="/passreset?step=complete", return RedirectResponse(
status_code=HTTPStatus.SEE_OTHER) url="/passreset?step=complete", status_code=HTTPStatus.SEE_OTHER
)
# If we got here, we continue with issuing a resetkey for the user. # If we got here, we continue with issuing a resetkey for the user.
resetkey = generate_resetkey() resetkey = generate_resetkey()
@ -103,13 +109,13 @@ async def passreset_post(request: Request,
ResetKeyNotification(db_user.ID).send() ResetKeyNotification(db_user.ID).send()
# Render ?step=confirm. # Render ?step=confirm.
return RedirectResponse(url="/passreset?step=confirm", return RedirectResponse(
status_code=HTTPStatus.SEE_OTHER) url="/passreset?step=confirm", status_code=HTTPStatus.SEE_OTHER
)
def process_account_form(request: Request, user: models.User, def process_account_form(request: Request, user: models.User, args: dict[str, Any]):
args: dict[str, Any]): """Process an account form. All fields are optional and only checks
""" Process an account form. All fields are optional and only checks
requirements in the case they are present. requirements in the case they are present.
``` ```
@ -146,7 +152,7 @@ def process_account_form(request: Request, user: models.User,
validate.username_in_use, validate.username_in_use,
validate.email_in_use, validate.email_in_use,
validate.invalid_account_type, validate.invalid_account_type,
validate.invalid_captcha validate.invalid_captcha,
] ]
try: try:
@ -158,11 +164,10 @@ def process_account_form(request: Request, user: models.User,
return (True, []) return (True, [])
def make_account_form_context(context: dict, def make_account_form_context(
request: Request, context: dict, request: Request, user: models.User, args: dict
user: models.User, ):
args: dict): """Modify a FastAPI context and add attributes for the account form.
""" Modify a FastAPI context and add attributes for the account form.
:param context: FastAPI context :param context: FastAPI context
:param request: FastAPI request :param request: FastAPI request
@ -173,15 +178,17 @@ def make_account_form_context(context: dict,
# Do not modify the original context. # Do not modify the original context.
context = copy.copy(context) context = copy.copy(context)
context["account_types"] = list(filter( context["account_types"] = list(
lambda e: request.user.AccountTypeID >= e[0], filter(
[ lambda e: request.user.AccountTypeID >= e[0],
(at.USER_ID, f"Normal {at.USER}"), [
(at.TRUSTED_USER_ID, at.TRUSTED_USER), (at.USER_ID, f"Normal {at.USER}"),
(at.DEVELOPER_ID, at.DEVELOPER), (at.TRUSTED_USER_ID, at.TRUSTED_USER),
(at.TRUSTED_USER_AND_DEV_ID, at.TRUSTED_USER_AND_DEV) (at.DEVELOPER_ID, at.DEVELOPER),
] (at.TRUSTED_USER_AND_DEV_ID, at.TRUSTED_USER_AND_DEV),
)) ],
)
)
if request.user.is_authenticated(): if request.user.is_authenticated():
context["username"] = args.get("U", user.Username) context["username"] = args.get("U", user.Username)
@ -229,24 +236,24 @@ def make_account_form_context(context: dict,
@router.get("/register", response_class=HTMLResponse) @router.get("/register", response_class=HTMLResponse)
@requires_guest @requires_guest
async def account_register(request: Request, async def account_register(
U: str = Form(default=str()), # Username request: Request,
E: str = Form(default=str()), # Email U: str = Form(default=str()), # Username
H: str = Form(default=False), # Hide Email E: str = Form(default=str()), # Email
BE: str = Form(default=None), # Backup Email H: str = Form(default=False), # Hide Email
R: str = Form(default=None), # Real Name BE: str = Form(default=None), # Backup Email
HP: str = Form(default=None), # Homepage R: str = Form(default=None), # Real Name
I: str = Form(default=None), # IRC Nick HP: str = Form(default=None), # Homepage
K: str = Form(default=None), # PGP Key FP I: str = Form(default=None), # IRC Nick
L: str = Form(default=aurweb.config.get( K: str = Form(default=None), # PGP Key FP
"options", "default_lang")), L: str = Form(default=aurweb.config.get("options", "default_lang")),
TZ: str = Form(default=aurweb.config.get( TZ: str = Form(default=aurweb.config.get("options", "default_timezone")),
"options", "default_timezone")), PK: str = Form(default=None),
PK: str = Form(default=None), CN: bool = Form(default=False), # Comment Notify
CN: bool = Form(default=False), # Comment Notify CU: bool = Form(default=False), # Update Notify
CU: bool = Form(default=False), # Update Notify CO: bool = Form(default=False), # Owner Notify
CO: bool = Form(default=False), # Owner Notify captcha: str = Form(default=str()),
captcha: str = Form(default=str())): ):
context = await make_variable_context(request, "Register") context = await make_variable_context(request, "Register")
context["captcha_salt"] = get_captcha_salts()[0] context["captcha_salt"] = get_captcha_salts()[0]
context = make_account_form_context(context, request, None, dict()) context = make_account_form_context(context, request, None, dict())
@ -256,32 +263,32 @@ async def account_register(request: Request,
@router.post("/register", response_class=HTMLResponse) @router.post("/register", response_class=HTMLResponse)
@handle_form_exceptions @handle_form_exceptions
@requires_guest @requires_guest
async def account_register_post(request: Request, async def account_register_post(
U: str = Form(default=str()), # Username request: Request,
E: str = Form(default=str()), # Email U: str = Form(default=str()), # Username
H: str = Form(default=False), # Hide Email E: str = Form(default=str()), # Email
BE: str = Form(default=None), # Backup Email H: str = Form(default=False), # Hide Email
R: str = Form(default=''), # Real Name BE: str = Form(default=None), # Backup Email
HP: str = Form(default=None), # Homepage R: str = Form(default=""), # Real Name
I: str = Form(default=None), # IRC Nick HP: str = Form(default=None), # Homepage
K: str = Form(default=None), # PGP Key I: str = Form(default=None), # IRC Nick
L: str = Form(default=aurweb.config.get( K: str = Form(default=None), # PGP Key
"options", "default_lang")), L: str = Form(default=aurweb.config.get("options", "default_lang")),
TZ: str = Form(default=aurweb.config.get( TZ: str = Form(default=aurweb.config.get("options", "default_timezone")),
"options", "default_timezone")), PK: str = Form(default=str()), # SSH PubKey
PK: str = Form(default=str()), # SSH PubKey CN: bool = Form(default=False),
CN: bool = Form(default=False), UN: bool = Form(default=False),
UN: bool = Form(default=False), ON: bool = Form(default=False),
ON: bool = Form(default=False), captcha: str = Form(default=None),
captcha: str = Form(default=None), captcha_salt: str = Form(...),
captcha_salt: str = Form(...)): ):
context = await make_variable_context(request, "Register") context = await make_variable_context(request, "Register")
args = dict(await request.form()) args = dict(await request.form())
args["K"] = args.get("K", str()).replace(" ", "") args["K"] = args.get("K", str()).replace(" ", "")
K = args.get("K") K = args.get("K")
# Force "H" into a boolean. # Force "H" into a boolean.
args["H"] = H = (args.get("H", str()) == "on") args["H"] = H = args.get("H", str()) == "on"
context = make_account_form_context(context, request, None, args) context = make_account_form_context(context, request, None, args)
ok, errors = process_account_form(request, request.user, args) ok, errors = process_account_form(request, request.user, args)
@ -289,30 +296,45 @@ async def account_register_post(request: Request,
# If the field values given do not meet the requirements, # If the field values given do not meet the requirements,
# return HTTP 400 with an error. # return HTTP 400 with an error.
context["errors"] = errors context["errors"] = errors
return render_template(request, "register.html", context, return render_template(
status_code=HTTPStatus.BAD_REQUEST) request, "register.html", context, status_code=HTTPStatus.BAD_REQUEST
)
if not captcha: if not captcha:
context["errors"] = ["The CAPTCHA is missing."] context["errors"] = ["The CAPTCHA is missing."]
return render_template(request, "register.html", context, return render_template(
status_code=HTTPStatus.BAD_REQUEST) request, "register.html", context, status_code=HTTPStatus.BAD_REQUEST
)
# Create a user with no password with a resetkey, then send # Create a user with no password with a resetkey, then send
# an email off about it. # an email off about it.
resetkey = generate_resetkey() resetkey = generate_resetkey()
# By default, we grab the User account type to associate with. # By default, we grab the User account type to associate with.
atype = db.query(models.AccountType, atype = db.query(
models.AccountType.AccountType == "User").first() models.AccountType, models.AccountType.AccountType == "User"
).first()
# Create a user given all parameters available. # Create a user given all parameters available.
with db.begin(): with db.begin():
user = db.create(models.User, Username=U, user = db.create(
Email=E, HideEmail=H, BackupEmail=BE, models.User,
RealName=R, Homepage=HP, IRCNick=I, PGPKey=K, Username=U,
LangPreference=L, Timezone=TZ, CommentNotify=CN, Email=E,
UpdateNotify=UN, OwnershipNotify=ON, HideEmail=H,
ResetKey=resetkey, AccountType=atype) BackupEmail=BE,
RealName=R,
Homepage=HP,
IRCNick=I,
PGPKey=K,
LangPreference=L,
Timezone=TZ,
CommentNotify=CN,
UpdateNotify=UN,
OwnershipNotify=ON,
ResetKey=resetkey,
AccountType=atype,
)
# If a PK was given and either one does not exist or the given # If a PK was given and either one does not exist or the given
# PK mismatches the existing user's SSHPubKey.PubKey. # PK mismatches the existing user's SSHPubKey.PubKey.
@ -323,8 +345,9 @@ async def account_register_post(request: Request,
pk = " ".join(k) pk = " ".join(k)
fprint = get_fingerprint(pk) fprint = get_fingerprint(pk)
with db.begin(): with db.begin():
db.create(models.SSHPubKey, UserID=user.ID, db.create(
PubKey=pk, Fingerprint=fprint) models.SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fprint
)
# Send a reset key notification to the new user. # Send a reset key notification to the new user.
WelcomeNotification(user.ID).send() WelcomeNotification(user.ID).send()
@ -334,8 +357,9 @@ async def account_register_post(request: Request,
return render_template(request, "register.html", context) return render_template(request, "register.html", context)
def cannot_edit(request: Request, user: models.User) \ def cannot_edit(
-> typing.Optional[RedirectResponse]: request: Request, user: models.User
) -> typing.Optional[RedirectResponse]:
""" """
Decide if `request.user` cannot edit `user`. Decide if `request.user` cannot edit `user`.
@ -373,31 +397,30 @@ async def account_edit(request: Request, username: str):
@router.post("/account/{username}/edit", response_class=HTMLResponse) @router.post("/account/{username}/edit", response_class=HTMLResponse)
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def account_edit_post(request: Request, async def account_edit_post(
username: str, request: Request,
U: str = Form(default=str()), # Username username: str,
J: bool = Form(default=False), U: str = Form(default=str()), # Username
E: str = Form(default=str()), # Email J: bool = Form(default=False),
H: str = Form(default=False), # Hide Email E: str = Form(default=str()), # Email
BE: str = Form(default=None), # Backup Email H: str = Form(default=False), # Hide Email
R: str = Form(default=None), # Real Name BE: str = Form(default=None), # Backup Email
HP: str = Form(default=None), # Homepage R: str = Form(default=None), # Real Name
I: str = Form(default=None), # IRC Nick HP: str = Form(default=None), # Homepage
K: str = Form(default=None), # PGP Key I: str = Form(default=None), # IRC Nick
L: str = Form(aurweb.config.get( K: str = Form(default=None), # PGP Key
"options", "default_lang")), L: str = Form(aurweb.config.get("options", "default_lang")),
TZ: str = Form(aurweb.config.get( TZ: str = Form(aurweb.config.get("options", "default_timezone")),
"options", "default_timezone")), P: str = Form(default=str()), # New Password
P: str = Form(default=str()), # New Password C: str = Form(default=None), # Password Confirm
C: str = Form(default=None), # Password Confirm PK: str = Form(default=None), # PubKey
PK: str = Form(default=None), # PubKey CN: bool = Form(default=False), # Comment Notify
CN: bool = Form(default=False), # Comment Notify UN: bool = Form(default=False), # Update Notify
UN: bool = Form(default=False), # Update Notify ON: bool = Form(default=False), # Owner Notify
ON: bool = Form(default=False), # Owner Notify T: int = Form(default=None),
T: int = Form(default=None), passwd: str = Form(default=str()),
passwd: str = Form(default=str())): ):
user = db.query(models.User).filter( user = db.query(models.User).filter(models.User.Username == username).first()
models.User.Username == username).first()
response = cannot_edit(request, user) response = cannot_edit(request, user)
if response: if response:
return response return response
@ -416,13 +439,15 @@ async def account_edit_post(request: Request,
if not passwd: if not passwd:
context["errors"] = ["Invalid password."] context["errors"] = ["Invalid password."]
return render_template(request, "account/edit.html", context, return render_template(
status_code=HTTPStatus.BAD_REQUEST) request, "account/edit.html", context, status_code=HTTPStatus.BAD_REQUEST
)
if not ok: if not ok:
context["errors"] = errors context["errors"] = errors
return render_template(request, "account/edit.html", context, return render_template(
status_code=HTTPStatus.BAD_REQUEST) request, "account/edit.html", context, status_code=HTTPStatus.BAD_REQUEST
)
updates = [ updates = [
update.simple, update.simple,
@ -430,7 +455,7 @@ async def account_edit_post(request: Request,
update.timezone, update.timezone,
update.ssh_pubkey, update.ssh_pubkey,
update.account_type, update.account_type,
update.password update.password,
] ]
for f in updates: for f in updates:
@ -441,18 +466,17 @@ async def account_edit_post(request: Request,
# Update cookies with requests, in case they were changed. # Update cookies with requests, in case they were changed.
response = render_template(request, "account/edit.html", context) response = render_template(request, "account/edit.html", context)
return cookies.update_response_cookies(request, response, return cookies.update_response_cookies(request, response, aurtz=TZ, aurlang=L)
aurtz=TZ, aurlang=L)
@router.get("/account/{username}") @router.get("/account/{username}")
async def account(request: Request, username: str): async def account(request: Request, username: str):
_ = l10n.get_translator_for_request(request) _ = l10n.get_translator_for_request(request)
context = await make_variable_context( context = await make_variable_context(request, _("Account") + " " + username)
request, _("Account") + " " + username)
if not request.user.is_authenticated(): if not request.user.is_authenticated():
return render_template(request, "account/show.html", context, return render_template(
status_code=HTTPStatus.UNAUTHORIZED) request, "account/show.html", context, status_code=HTTPStatus.UNAUTHORIZED
)
# Get related User record, if possible. # Get related User record, if possible.
user = get_user_by_name(username) user = get_user_by_name(username)
@ -460,11 +484,10 @@ async def account(request: Request, username: str):
# Format PGPKey for display with a space between each 4 characters. # Format PGPKey for display with a space between each 4 characters.
k = user.PGPKey or str() k = user.PGPKey or str()
context["pgp_key"] = " ".join([k[i:i + 4] for i in range(0, len(k), 4)]) context["pgp_key"] = " ".join([k[i : i + 4] for i in range(0, len(k), 4)])
login_ts = None login_ts = None
session = db.query(models.Session).filter( session = db.query(models.Session).filter(models.Session.UsersID == user.ID).first()
models.Session.UsersID == user.ID).first()
if session: if session:
login_ts = user.session.LastUpdateTS login_ts = user.session.LastUpdateTS
context["login_ts"] = login_ts context["login_ts"] = login_ts
@ -480,15 +503,14 @@ async def account_comments(request: Request, username: str):
context = make_context(request, "Accounts") context = make_context(request, "Accounts")
context["username"] = username context["username"] = username
context["comments"] = user.package_comments.order_by( context["comments"] = user.package_comments.order_by(
models.PackageComment.CommentTS.desc()) models.PackageComment.CommentTS.desc()
)
return render_template(request, "account/comments.html", context) return render_template(request, "account/comments.html", context)
@router.get("/accounts") @router.get("/accounts")
@requires_auth @requires_auth
@account_type_required({at.TRUSTED_USER, @account_type_required({at.TRUSTED_USER, at.DEVELOPER, at.TRUSTED_USER_AND_DEV})
at.DEVELOPER,
at.TRUSTED_USER_AND_DEV})
async def accounts(request: Request): async def accounts(request: Request):
context = make_context(request, "Accounts") context = make_context(request, "Accounts")
return render_template(request, "account/search.html", context) return render_template(request, "account/search.html", context)
@ -497,19 +519,19 @@ async def accounts(request: Request):
@router.post("/accounts") @router.post("/accounts")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@account_type_required({at.TRUSTED_USER, @account_type_required({at.TRUSTED_USER, at.DEVELOPER, at.TRUSTED_USER_AND_DEV})
at.DEVELOPER, async def accounts_post(
at.TRUSTED_USER_AND_DEV}) request: Request,
async def accounts_post(request: Request, O: int = Form(default=0), # Offset
O: int = Form(default=0), # Offset SB: str = Form(default=str()), # Sort By
SB: str = Form(default=str()), # Sort By U: str = Form(default=str()), # Username
U: str = Form(default=str()), # Username T: str = Form(default=str()), # Account Type
T: str = Form(default=str()), # Account Type S: bool = Form(default=False), # Suspended
S: bool = Form(default=False), # Suspended E: str = Form(default=str()), # Email
E: str = Form(default=str()), # Email R: str = Form(default=str()), # Real Name
R: str = Form(default=str()), # Real Name I: str = Form(default=str()), # IRC Nick
I: str = Form(default=str()), # IRC Nick K: str = Form(default=str()),
K: str = Form(default=str())): # PGP Key ): # PGP Key
context = await make_variable_context(request, "Accounts") context = await make_variable_context(request, "Accounts")
context["pp"] = pp = 50 # Hits per page. context["pp"] = pp = 50 # Hits per page.
@ -534,7 +556,7 @@ async def accounts_post(request: Request,
"u": at.USER_ID, "u": at.USER_ID,
"t": at.TRUSTED_USER_ID, "t": at.TRUSTED_USER_ID,
"d": at.DEVELOPER_ID, "d": at.DEVELOPER_ID,
"td": at.TRUSTED_USER_AND_DEV_ID "td": at.TRUSTED_USER_AND_DEV_ID,
} }
account_type_id = account_types.get(T, None) account_type_id = account_types.get(T, None)
@ -545,7 +567,8 @@ async def accounts_post(request: Request,
# Populate this list with any additional statements to # Populate this list with any additional statements to
# be ANDed together. # be ANDed together.
statements = [ statements = [
v for k, v in [ v
for k, v in [
(account_type_id is not None, models.AccountType.ID == account_type_id), (account_type_id is not None, models.AccountType.ID == account_type_id),
(bool(U), models.User.Username.like(f"%{U}%")), (bool(U), models.User.Username.like(f"%{U}%")),
(bool(S), models.User.Suspended == S), (bool(S), models.User.Suspended == S),
@ -553,7 +576,8 @@ async def accounts_post(request: Request,
(bool(R), models.User.RealName.like(f"%{R}%")), (bool(R), models.User.RealName.like(f"%{R}%")),
(bool(I), models.User.IRCNick.like(f"%{I}%")), (bool(I), models.User.IRCNick.like(f"%{I}%")),
(bool(K), models.User.PGPKey.like(f"%{K}%")), (bool(K), models.User.PGPKey.like(f"%{K}%")),
] if k ]
if k
] ]
# Filter the query by coe-mbining all statements added above into # Filter the query by coe-mbining all statements added above into
@ -571,9 +595,7 @@ async def accounts_post(request: Request,
return render_template(request, "account/index.html", context) return render_template(request, "account/index.html", context)
def render_terms_of_service(request: Request, def render_terms_of_service(request: Request, context: dict, terms: typing.Iterable):
context: dict,
terms: typing.Iterable):
if not terms: if not terms:
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)
context["unaccepted_terms"] = terms context["unaccepted_terms"] = terms
@ -585,14 +607,21 @@ def render_terms_of_service(request: Request,
async def terms_of_service(request: Request): async def terms_of_service(request: Request):
# Query the database for terms that were previously accepted, # Query the database for terms that were previously accepted,
# but now have a bumped Revision that needs to be accepted. # but now have a bumped Revision that needs to be accepted.
diffs = db.query(models.Term).join(models.AcceptedTerm).filter( diffs = (
models.AcceptedTerm.Revision < models.Term.Revision).all() db.query(models.Term)
.join(models.AcceptedTerm)
.filter(models.AcceptedTerm.Revision < models.Term.Revision)
.all()
)
# Query the database for any terms that have not yet been accepted. # Query the database for any terms that have not yet been accepted.
unaccepted = db.query(models.Term).filter( unaccepted = (
~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID))).all() db.query(models.Term)
.filter(~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID)))
.all()
)
for record in (diffs + unaccepted): for record in diffs + unaccepted:
db.refresh(record) db.refresh(record)
# Translate the 'Terms of Service' part of our page title. # Translate the 'Terms of Service' part of our page title.
@ -607,16 +636,22 @@ async def terms_of_service(request: Request):
@router.post("/tos") @router.post("/tos")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def terms_of_service_post(request: Request, async def terms_of_service_post(request: Request, accept: bool = Form(default=False)):
accept: bool = Form(default=False)):
# Query the database for terms that were previously accepted, # Query the database for terms that were previously accepted,
# but now have a bumped Revision that needs to be accepted. # but now have a bumped Revision that needs to be accepted.
diffs = db.query(models.Term).join(models.AcceptedTerm).filter( diffs = (
models.AcceptedTerm.Revision < models.Term.Revision).all() db.query(models.Term)
.join(models.AcceptedTerm)
.filter(models.AcceptedTerm.Revision < models.Term.Revision)
.all()
)
# Query the database for any terms that have not yet been accepted. # Query the database for any terms that have not yet been accepted.
unaccepted = db.query(models.Term).filter( unaccepted = (
~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID))).all() db.query(models.Term)
.filter(~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID)))
.all()
)
if not accept: if not accept:
# Translate the 'Terms of Service' part of our page title. # Translate the 'Terms of Service' part of our page title.
@ -628,7 +663,8 @@ async def terms_of_service_post(request: Request,
# them instead of reiterating the process in terms_of_service. # them instead of reiterating the process in terms_of_service.
accept_needed = sorted(unaccepted + diffs) accept_needed = sorted(unaccepted + diffs)
return render_terms_of_service( return render_terms_of_service(
request, context, util.apply_all(accept_needed, db.refresh)) request, context, util.apply_all(accept_needed, db.refresh)
)
with db.begin(): with db.begin():
# For each term we found, query for the matching accepted term # For each term we found, query for the matching accepted term
@ -636,13 +672,18 @@ async def terms_of_service_post(request: Request,
for term in diffs: for term in diffs:
db.refresh(term) db.refresh(term)
accepted_term = request.user.accepted_terms.filter( accepted_term = request.user.accepted_terms.filter(
models.AcceptedTerm.TermsID == term.ID).first() models.AcceptedTerm.TermsID == term.ID
).first()
accepted_term.Revision = term.Revision accepted_term.Revision = term.Revision
# For each term that was never accepted, accept it! # For each term that was never accepted, accept it!
for term in unaccepted: for term in unaccepted:
db.refresh(term) db.refresh(term)
db.create(models.AcceptedTerm, User=request.user, db.create(
Term=term, Revision=term.Revision) models.AcceptedTerm,
User=request.user,
Term=term,
Revision=term.Revision,
)
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)

View file

@ -5,7 +5,6 @@ from fastapi.responses import HTMLResponse, RedirectResponse
from sqlalchemy import or_ from sqlalchemy import or_
import aurweb.config import aurweb.config
from aurweb import cookies, db from aurweb import cookies, db
from aurweb.auth import requires_auth, requires_guest from aurweb.auth import requires_auth, requires_guest
from aurweb.exceptions import handle_form_exceptions from aurweb.exceptions import handle_form_exceptions
@ -17,7 +16,7 @@ router = APIRouter()
async def login_template(request: Request, next: str, errors: list = None): async def login_template(request: Request, next: str, errors: list = None):
""" Provide login-specific template context to render_template. """ """Provide login-specific template context to render_template."""
context = await make_variable_context(request, "Login", next) context = await make_variable_context(request, "Login", next)
context["errors"] = errors context["errors"] = errors
context["url_base"] = f"{request.url.scheme}://{request.url.netloc}" context["url_base"] = f"{request.url.scheme}://{request.url.netloc}"
@ -32,55 +31,73 @@ async def login_get(request: Request, next: str = "/"):
@router.post("/login", response_class=HTMLResponse) @router.post("/login", response_class=HTMLResponse)
@handle_form_exceptions @handle_form_exceptions
@requires_guest @requires_guest
async def login_post(request: Request, async def login_post(
next: str = Form(...), request: Request,
user: str = Form(default=str()), next: str = Form(...),
passwd: str = Form(default=str()), user: str = Form(default=str()),
remember_me: bool = Form(default=False)): passwd: str = Form(default=str()),
remember_me: bool = Form(default=False),
):
# TODO: Once the Origin header gets broader adoption, this code can be # TODO: Once the Origin header gets broader adoption, this code can be
# slightly simplified to use it. # slightly simplified to use it.
login_path = aurweb.config.get("options", "aur_location") + "/login" login_path = aurweb.config.get("options", "aur_location") + "/login"
referer = request.headers.get("Referer") referer = request.headers.get("Referer")
if not referer or not referer.startswith(login_path): if not referer or not referer.startswith(login_path):
_ = get_translator_for_request(request) _ = get_translator_for_request(request)
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, raise HTTPException(
detail=_("Bad Referer header.")) status_code=HTTPStatus.BAD_REQUEST, detail=_("Bad Referer header.")
)
with db.begin(): with db.begin():
user = db.query(User).filter( user = (
or_(User.Username == user, User.Email == user) db.query(User)
).first() .filter(or_(User.Username == user, User.Email == user))
.first()
)
if not user: if not user:
return await login_template(request, next, return await login_template(request, next, errors=["Bad username or password."])
errors=["Bad username or password."])
if user.Suspended: if user.Suspended:
return await login_template(request, next, return await login_template(request, next, errors=["Account Suspended"])
errors=["Account Suspended"])
cookie_timeout = cookies.timeout(remember_me) cookie_timeout = cookies.timeout(remember_me)
sid = user.login(request, passwd, cookie_timeout) sid = user.login(request, passwd, cookie_timeout)
if not sid: if not sid:
return await login_template(request, next, return await login_template(request, next, errors=["Bad username or password."])
errors=["Bad username or password."])
response = RedirectResponse(url=next, response = RedirectResponse(url=next, status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
secure = aurweb.config.getboolean("options", "disable_http_login") secure = aurweb.config.getboolean("options", "disable_http_login")
response.set_cookie("AURSID", sid, max_age=cookie_timeout, response.set_cookie(
secure=secure, httponly=secure, "AURSID",
samesite=cookies.samesite()) sid,
response.set_cookie("AURTZ", user.Timezone, max_age=cookie_timeout,
secure=secure, httponly=secure, secure=secure,
samesite=cookies.samesite()) httponly=secure,
response.set_cookie("AURLANG", user.LangPreference, samesite=cookies.samesite(),
secure=secure, httponly=secure, )
samesite=cookies.samesite()) response.set_cookie(
response.set_cookie("AURREMEMBER", remember_me, "AURTZ",
secure=secure, httponly=secure, user.Timezone,
samesite=cookies.samesite()) secure=secure,
httponly=secure,
samesite=cookies.samesite(),
)
response.set_cookie(
"AURLANG",
user.LangPreference,
secure=secure,
httponly=secure,
samesite=cookies.samesite(),
)
response.set_cookie(
"AURREMEMBER",
remember_me,
secure=secure,
httponly=secure,
samesite=cookies.samesite(),
)
return response return response
@ -93,8 +110,7 @@ async def logout(request: Request, next: str = Form(default="/")):
# Use 303 since we may be handling a post request, that'll get it # Use 303 since we may be handling a post request, that'll get it
# to redirect to a get request. # to redirect to a get request.
response = RedirectResponse(url=next, response = RedirectResponse(url=next, status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
response.delete_cookie("AURSID") response.delete_cookie("AURSID")
response.delete_cookie("AURTZ") response.delete_cookie("AURTZ")
return response return response

View file

@ -2,17 +2,20 @@
decorators in some way; more complex routes should be defined in their decorators in some way; more complex routes should be defined in their
own modules and imported here. """ own modules and imported here. """
import os import os
from http import HTTPStatus from http import HTTPStatus
from fastapi import APIRouter, Form, HTTPException, Request, Response from fastapi import APIRouter, Form, HTTPException, Request, Response
from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.responses import HTMLResponse, RedirectResponse
from prometheus_client import CONTENT_TYPE_LATEST, CollectorRegistry, generate_latest, multiprocess from prometheus_client import (
CONTENT_TYPE_LATEST,
CollectorRegistry,
generate_latest,
multiprocess,
)
from sqlalchemy import and_, case, or_ from sqlalchemy import and_, case, or_
import aurweb.config import aurweb.config
import aurweb.models.package_request import aurweb.models.package_request
from aurweb import cookies, db, logging, models, time, util from aurweb import cookies, db, logging, models, time, util
from aurweb.cache import db_count_cache from aurweb.cache import db_count_cache
from aurweb.exceptions import handle_form_exceptions from aurweb.exceptions import handle_form_exceptions
@ -27,17 +30,19 @@ router = APIRouter()
@router.get("/favicon.ico") @router.get("/favicon.ico")
async def favicon(request: Request): async def favicon(request: Request):
""" Some browsers attempt to find a website's favicon via root uri at """Some browsers attempt to find a website's favicon via root uri at
/favicon.ico, so provide a redirection here to our static icon. """ /favicon.ico, so provide a redirection here to our static icon."""
return RedirectResponse("/static/images/favicon.ico") return RedirectResponse("/static/images/favicon.ico")
@router.post("/language", response_class=RedirectResponse) @router.post("/language", response_class=RedirectResponse)
@handle_form_exceptions @handle_form_exceptions
async def language(request: Request, async def language(
set_lang: str = Form(...), request: Request,
next: str = Form(...), set_lang: str = Form(...),
q: str = Form(default=None)): next: str = Form(...),
q: str = Form(default=None),
):
""" """
A POST route used to set a session's language. A POST route used to set a session's language.
@ -45,7 +50,7 @@ async def language(request: Request,
setting the language on any page, we want to preserve query setting the language on any page, we want to preserve query
parameters across the redirect. parameters across the redirect.
""" """
if next[0] != '/': if next[0] != "/":
return HTMLResponse(b"Invalid 'next' parameter.", status_code=400) return HTMLResponse(b"Invalid 'next' parameter.", status_code=400)
query_string = "?" + q if q else str() query_string = "?" + q if q else str()
@ -56,20 +61,21 @@ async def language(request: Request,
request.user.LangPreference = set_lang request.user.LangPreference = set_lang
# In any case, set the response's AURLANG cookie that never expires. # In any case, set the response's AURLANG cookie that never expires.
response = RedirectResponse(url=f"{next}{query_string}", response = RedirectResponse(
status_code=HTTPStatus.SEE_OTHER) url=f"{next}{query_string}", status_code=HTTPStatus.SEE_OTHER
)
secure = aurweb.config.getboolean("options", "disable_http_login") secure = aurweb.config.getboolean("options", "disable_http_login")
response.set_cookie("AURLANG", set_lang, response.set_cookie(
secure=secure, httponly=secure, "AURLANG", set_lang, secure=secure, httponly=secure, samesite=cookies.samesite()
samesite=cookies.samesite()) )
return response return response
@router.get("/", response_class=HTMLResponse) @router.get("/", response_class=HTMLResponse)
async def index(request: Request): async def index(request: Request):
""" Homepage route. """ """Homepage route."""
context = make_context(request, "Home") context = make_context(request, "Home")
context['ssh_fingerprints'] = util.get_ssh_fingerprints() context["ssh_fingerprints"] = util.get_ssh_fingerprints()
bases = db.query(models.PackageBase) bases = db.query(models.PackageBase)
@ -79,24 +85,33 @@ async def index(request: Request):
# Package statistics. # Package statistics.
query = bases.filter(models.PackageBase.PackagerUID.isnot(None)) query = bases.filter(models.PackageBase.PackagerUID.isnot(None))
context["package_count"] = await db_count_cache( context["package_count"] = await db_count_cache(
redis, "package_count", query, expire=cache_expire) redis, "package_count", query, expire=cache_expire
)
query = bases.filter( query = bases.filter(
and_(models.PackageBase.MaintainerUID.is_(None), and_(
models.PackageBase.PackagerUID.isnot(None)) models.PackageBase.MaintainerUID.is_(None),
models.PackageBase.PackagerUID.isnot(None),
)
) )
context["orphan_count"] = await db_count_cache( context["orphan_count"] = await db_count_cache(
redis, "orphan_count", query, expire=cache_expire) redis, "orphan_count", query, expire=cache_expire
)
query = db.query(models.User) query = db.query(models.User)
context["user_count"] = await db_count_cache( context["user_count"] = await db_count_cache(
redis, "user_count", query, expire=cache_expire) redis, "user_count", query, expire=cache_expire
)
query = query.filter( query = query.filter(
or_(models.User.AccountTypeID == TRUSTED_USER_ID, or_(
models.User.AccountTypeID == TRUSTED_USER_AND_DEV_ID)) models.User.AccountTypeID == TRUSTED_USER_ID,
models.User.AccountTypeID == TRUSTED_USER_AND_DEV_ID,
)
)
context["trusted_user_count"] = await db_count_cache( context["trusted_user_count"] = await db_count_cache(
redis, "trusted_user_count", query, expire=cache_expire) redis, "trusted_user_count", query, expire=cache_expire
)
# Current timestamp. # Current timestamp.
now = time.utcnow() now = time.utcnow()
@ -106,31 +121,40 @@ async def index(request: Request):
one_hour = 3600 one_hour = 3600
updated = bases.filter( updated = bases.filter(
and_(models.PackageBase.ModifiedTS - models.PackageBase.SubmittedTS >= one_hour, and_(
models.PackageBase.PackagerUID.isnot(None)) models.PackageBase.ModifiedTS - models.PackageBase.SubmittedTS >= one_hour,
models.PackageBase.PackagerUID.isnot(None),
)
) )
query = bases.filter( query = bases.filter(
and_(models.PackageBase.SubmittedTS >= seven_days_ago, and_(
models.PackageBase.PackagerUID.isnot(None)) models.PackageBase.SubmittedTS >= seven_days_ago,
models.PackageBase.PackagerUID.isnot(None),
)
) )
context["seven_days_old_added"] = await db_count_cache( context["seven_days_old_added"] = await db_count_cache(
redis, "seven_days_old_added", query, expire=cache_expire) redis, "seven_days_old_added", query, expire=cache_expire
)
query = updated.filter(models.PackageBase.ModifiedTS >= seven_days_ago) query = updated.filter(models.PackageBase.ModifiedTS >= seven_days_ago)
context["seven_days_old_updated"] = await db_count_cache( context["seven_days_old_updated"] = await db_count_cache(
redis, "seven_days_old_updated", query, expire=cache_expire) redis, "seven_days_old_updated", query, expire=cache_expire
)
year = seven_days * 52 # Fifty two weeks worth: one year. year = seven_days * 52 # Fifty two weeks worth: one year.
year_ago = now - year year_ago = now - year
query = updated.filter(models.PackageBase.ModifiedTS >= year_ago) query = updated.filter(models.PackageBase.ModifiedTS >= year_ago)
context["year_old_updated"] = await db_count_cache( context["year_old_updated"] = await db_count_cache(
redis, "year_old_updated", query, expire=cache_expire) redis, "year_old_updated", query, expire=cache_expire
)
query = bases.filter( query = bases.filter(
models.PackageBase.ModifiedTS - models.PackageBase.SubmittedTS < 3600) models.PackageBase.ModifiedTS - models.PackageBase.SubmittedTS < 3600
)
context["never_updated"] = await db_count_cache( context["never_updated"] = await db_count_cache(
redis, "never_updated", query, expire=cache_expire) redis, "never_updated", query, expire=cache_expire
)
# Get the 15 most recently updated packages. # Get the 15 most recently updated packages.
context["package_updates"] = updated_packages(15, cache_expire) context["package_updates"] = updated_packages(15, cache_expire)
@ -140,78 +164,92 @@ async def index(request: Request):
# the dashboard display. # the dashboard display.
packages = db.query(models.Package).join(models.PackageBase) packages = db.query(models.Package).join(models.PackageBase)
maintained = packages.join( maintained = (
models.PackageComaintainer, packages.join(
models.PackageComaintainer.PackageBaseID == models.PackageBase.ID, models.PackageComaintainer,
isouter=True models.PackageComaintainer.PackageBaseID == models.PackageBase.ID,
).join( isouter=True,
models.User, )
or_(models.PackageBase.MaintainerUID == models.User.ID, .join(
models.PackageComaintainer.UsersID == models.User.ID) models.User,
).filter( or_(
models.User.ID == request.user.ID models.PackageBase.MaintainerUID == models.User.ID,
models.PackageComaintainer.UsersID == models.User.ID,
),
)
.filter(models.User.ID == request.user.ID)
) )
# Packages maintained by the user that have been flagged. # Packages maintained by the user that have been flagged.
context["flagged_packages"] = maintained.filter( context["flagged_packages"] = (
models.PackageBase.OutOfDateTS.isnot(None) maintained.filter(models.PackageBase.OutOfDateTS.isnot(None))
).order_by( .order_by(models.PackageBase.ModifiedTS.desc(), models.Package.Name.asc())
models.PackageBase.ModifiedTS.desc(), models.Package.Name.asc() .limit(50)
).limit(50).all() .all()
)
# Flagged packages that request.user has voted for. # Flagged packages that request.user has voted for.
context["flagged_packages_voted"] = query_voted( context["flagged_packages_voted"] = query_voted(
context.get("flagged_packages"), request.user) context.get("flagged_packages"), request.user
)
# Flagged packages that request.user is being notified about. # Flagged packages that request.user is being notified about.
context["flagged_packages_notified"] = query_notified( context["flagged_packages_notified"] = query_notified(
context.get("flagged_packages"), request.user) context.get("flagged_packages"), request.user
)
archive_time = aurweb.config.getint('options', 'request_archive_time') archive_time = aurweb.config.getint("options", "request_archive_time")
start = now - archive_time start = now - archive_time
# Package requests created by request.user. # Package requests created by request.user.
context["package_requests"] = request.user.package_requests.filter( context["package_requests"] = (
models.PackageRequest.RequestTS >= start request.user.package_requests.filter(
).order_by( models.PackageRequest.RequestTS >= start
# Order primarily by the Status column being PENDING_ID, )
# and secondarily by RequestTS; both in descending order. .order_by(
case([(models.PackageRequest.Status == PENDING_ID, 1)], # Order primarily by the Status column being PENDING_ID,
else_=0).desc(), # and secondarily by RequestTS; both in descending order.
models.PackageRequest.RequestTS.desc() case([(models.PackageRequest.Status == PENDING_ID, 1)], else_=0).desc(),
).limit(50).all() models.PackageRequest.RequestTS.desc(),
)
.limit(50)
.all()
)
# Packages that the request user maintains or comaintains. # Packages that the request user maintains or comaintains.
context["packages"] = maintained.filter( context["packages"] = (
models.User.ID == models.PackageBase.MaintainerUID maintained.filter(models.User.ID == models.PackageBase.MaintainerUID)
).order_by( .order_by(models.PackageBase.ModifiedTS.desc(), models.Package.Name.desc())
models.PackageBase.ModifiedTS.desc(), models.Package.Name.desc() .limit(50)
).limit(50).all() .all()
)
# Packages that request.user has voted for. # Packages that request.user has voted for.
context["packages_voted"] = query_voted( context["packages_voted"] = query_voted(context.get("packages"), request.user)
context.get("packages"), request.user)
# Packages that request.user is being notified about. # Packages that request.user is being notified about.
context["packages_notified"] = query_notified( context["packages_notified"] = query_notified(
context.get("packages"), request.user) context.get("packages"), request.user
)
# Any packages that the request user comaintains. # Any packages that the request user comaintains.
context["comaintained"] = packages.join( context["comaintained"] = (
models.PackageComaintainer packages.join(models.PackageComaintainer)
).filter( .filter(models.PackageComaintainer.UsersID == request.user.ID)
models.PackageComaintainer.UsersID == request.user.ID .order_by(models.PackageBase.ModifiedTS.desc(), models.Package.Name.desc())
).order_by( .limit(50)
models.PackageBase.ModifiedTS.desc(), models.Package.Name.desc() .all()
).limit(50).all() )
# Comaintained packages that request.user has voted for. # Comaintained packages that request.user has voted for.
context["comaintained_voted"] = query_voted( context["comaintained_voted"] = query_voted(
context.get("comaintained"), request.user) context.get("comaintained"), request.user
)
# Comaintained packages that request.user is being notified about. # Comaintained packages that request.user is being notified about.
context["comaintained_notified"] = query_notified( context["comaintained_notified"] = query_notified(
context.get("comaintained"), request.user) context.get("comaintained"), request.user
)
return render_template(request, "index.html", context) return render_template(request, "index.html", context)
@ -232,16 +270,15 @@ async def archive_sha256(request: Request, archive: str):
@router.get("/metrics") @router.get("/metrics")
async def metrics(request: Request): async def metrics(request: Request):
if not os.environ.get("PROMETHEUS_MULTIPROC_DIR", None): if not os.environ.get("PROMETHEUS_MULTIPROC_DIR", None):
return Response("Prometheus metrics are not enabled.", return Response(
status_code=HTTPStatus.SERVICE_UNAVAILABLE) "Prometheus metrics are not enabled.",
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
)
registry = CollectorRegistry() registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry) multiprocess.MultiProcessCollector(registry)
data = generate_latest(registry) data = generate_latest(registry)
headers = { headers = {"Content-Type": CONTENT_TYPE_LATEST, "Content-Length": str(len(data))}
"Content-Type": CONTENT_TYPE_LATEST,
"Content-Length": str(len(data))
}
return Response(data, headers=headers) return Response(data, headers=headers)

View file

@ -5,7 +5,6 @@ from typing import Any
from fastapi import APIRouter, Form, Query, Request, Response from fastapi import APIRouter, Form, Query, Request, Response
import aurweb.filters # noqa: F401 import aurweb.filters # noqa: F401
from aurweb import config, db, defaults, logging, models, util from aurweb import config, db, defaults, logging, models, util
from aurweb.auth import creds, requires_auth from aurweb.auth import creds, requires_auth
from aurweb.exceptions import InvariantError, handle_form_exceptions from aurweb.exceptions import InvariantError, handle_form_exceptions
@ -13,23 +12,24 @@ from aurweb.models.relation_type import CONFLICTS_ID, PROVIDES_ID, REPLACES_ID
from aurweb.packages import util as pkgutil from aurweb.packages import util as pkgutil
from aurweb.packages.search import PackageSearch from aurweb.packages.search import PackageSearch
from aurweb.packages.util import get_pkg_or_base from aurweb.packages.util import get_pkg_or_base
from aurweb.pkgbase import actions as pkgbase_actions from aurweb.pkgbase import actions as pkgbase_actions, util as pkgbaseutil
from aurweb.pkgbase import util as pkgbaseutil
from aurweb.templates import make_context, make_variable_context, render_template from aurweb.templates import make_context, make_variable_context, render_template
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
router = APIRouter() router = APIRouter()
async def packages_get(request: Request, context: dict[str, Any], async def packages_get(
status_code: HTTPStatus = HTTPStatus.OK): request: Request, context: dict[str, Any], status_code: HTTPStatus = HTTPStatus.OK
):
# Query parameters used in this request. # Query parameters used in this request.
context["q"] = dict(request.query_params) context["q"] = dict(request.query_params)
# Per page and offset. # Per page and offset.
offset, per_page = util.sanitize_params( offset, per_page = util.sanitize_params(
request.query_params.get("O", defaults.O), request.query_params.get("O", defaults.O),
request.query_params.get("PP", defaults.PP)) request.query_params.get("PP", defaults.PP),
)
context["O"] = offset context["O"] = offset
# Limit PP to options.max_search_results # Limit PP to options.max_search_results
@ -82,8 +82,7 @@ async def packages_get(request: Request, context: dict[str, Any],
if submit == "Orphans": if submit == "Orphans":
# If the user clicked the "Orphans" button, we only want # If the user clicked the "Orphans" button, we only want
# orphaned packages. # orphaned packages.
search.query = search.query.filter( search.query = search.query.filter(models.PackageBase.MaintainerUID.is_(None))
models.PackageBase.MaintainerUID.is_(None))
# Collect search result count here; we've applied our keywords. # Collect search result count here; we've applied our keywords.
# Including more query operations below, like ordering, will # Including more query operations below, like ordering, will
@ -94,26 +93,31 @@ async def packages_get(request: Request, context: dict[str, Any],
search.sort_by(sort_by, sort_order) search.sort_by(sort_by, sort_order)
# Insert search results into the context. # Insert search results into the context.
results = search.results().with_entities( results = (
models.Package.ID, search.results()
models.Package.Name, .with_entities(
models.Package.PackageBaseID, models.Package.ID,
models.Package.Version, models.Package.Name,
models.Package.Description, models.Package.PackageBaseID,
models.PackageBase.Popularity, models.Package.Version,
models.PackageBase.NumVotes, models.Package.Description,
models.PackageBase.OutOfDateTS, models.PackageBase.Popularity,
models.User.Username.label("Maintainer"), models.PackageBase.NumVotes,
models.PackageVote.PackageBaseID.label("Voted"), models.PackageBase.OutOfDateTS,
models.PackageNotification.PackageBaseID.label("Notify") models.User.Username.label("Maintainer"),
).group_by(models.Package.Name) models.PackageVote.PackageBaseID.label("Voted"),
models.PackageNotification.PackageBaseID.label("Notify"),
)
.group_by(models.Package.Name)
)
packages = results.limit(per_page).offset(offset) packages = results.limit(per_page).offset(offset)
context["packages"] = packages context["packages"] = packages
context["packages_count"] = num_packages context["packages_count"] = num_packages
return render_template(request, "packages/index.html", context, return render_template(
status_code=status_code) request, "packages/index.html", context, status_code=status_code
)
@router.get("/packages") @router.get("/packages")
@ -123,9 +127,12 @@ async def packages(request: Request) -> Response:
@router.get("/packages/{name}") @router.get("/packages/{name}")
async def package(request: Request, name: str, async def package(
all_deps: bool = Query(default=False), request: Request,
all_reqs: bool = Query(default=False)) -> Response: name: str,
all_deps: bool = Query(default=False),
all_reqs: bool = Query(default=False),
) -> Response:
""" """
Get a package by name. Get a package by name.
@ -156,26 +163,21 @@ async def package(request: Request, name: str,
# Add our base information. # Add our base information.
context = await pkgbaseutil.make_variable_context(request, pkgbase) context = await pkgbaseutil.make_variable_context(request, pkgbase)
context.update( context.update({"all_deps": all_deps, "all_reqs": all_reqs})
{
"all_deps": all_deps,
"all_reqs": all_reqs
}
)
context["package"] = pkg context["package"] = pkg
# Package sources. # Package sources.
context["sources"] = pkg.package_sources.order_by( context["sources"] = pkg.package_sources.order_by(
models.PackageSource.Source.asc()).all() models.PackageSource.Source.asc()
).all()
# Listing metadata. # Listing metadata.
context["max_listing"] = max_listing = 20 context["max_listing"] = max_listing = 20
# Package dependencies. # Package dependencies.
deps = pkg.package_dependencies.order_by( deps = pkg.package_dependencies.order_by(
models.PackageDependency.DepTypeID.asc(), models.PackageDependency.DepTypeID.asc(), models.PackageDependency.DepName.asc()
models.PackageDependency.DepName.asc()
) )
context["depends_count"] = deps.count() context["depends_count"] = deps.count()
if not all_deps: if not all_deps:
@ -183,8 +185,7 @@ async def package(request: Request, name: str,
context["dependencies"] = deps.all() context["dependencies"] = deps.all()
# Package requirements (other packages depend on this one). # Package requirements (other packages depend on this one).
reqs = pkgutil.pkg_required( reqs = pkgutil.pkg_required(pkg.Name, [p.RelName for p in rels_data.get("p", [])])
pkg.Name, [p.RelName for p in rels_data.get("p", [])])
context["reqs_count"] = reqs.count() context["reqs_count"] = reqs.count()
if not all_reqs: if not all_reqs:
reqs = reqs.limit(max_listing) reqs = reqs.limit(max_listing)
@ -210,8 +211,7 @@ async def package(request: Request, name: str,
return render_template(request, "packages/show.html", context) return render_template(request, "packages/show.html", context)
async def packages_unflag(request: Request, package_ids: list[int] = [], async def packages_unflag(request: Request, package_ids: list[int] = [], **kwargs):
**kwargs):
if not package_ids: if not package_ids:
return (False, ["You did not select any packages to unflag."]) return (False, ["You did not select any packages to unflag."])
@ -220,11 +220,11 @@ async def packages_unflag(request: Request, package_ids: list[int] = [],
bases = set() bases = set()
package_ids = set(package_ids) # Convert this to a set for O(1). package_ids = set(package_ids) # Convert this to a set for O(1).
packages = db.query(models.Package).filter( packages = db.query(models.Package).filter(models.Package.ID.in_(package_ids)).all()
models.Package.ID.in_(package_ids)).all()
for pkg in packages: for pkg in packages:
has_cred = request.user.has_credential( has_cred = request.user.has_credential(
creds.PKGBASE_UNFLAG, approved=[pkg.PackageBase.Flagger]) creds.PKGBASE_UNFLAG, approved=[pkg.PackageBase.Flagger]
)
if not has_cred: if not has_cred:
return (False, ["You did not select any packages to unflag."]) return (False, ["You did not select any packages to unflag."])
@ -236,20 +236,17 @@ async def packages_unflag(request: Request, package_ids: list[int] = [],
return (True, ["The selected packages have been unflagged."]) return (True, ["The selected packages have been unflagged."])
async def packages_notify(request: Request, package_ids: list[int] = [], async def packages_notify(request: Request, package_ids: list[int] = [], **kwargs):
**kwargs):
# In cases where we encounter errors with the request, we'll # In cases where we encounter errors with the request, we'll
# use this error tuple as a return value. # use this error tuple as a return value.
# TODO: This error does not yet have a translation. # TODO: This error does not yet have a translation.
error_tuple = (False, error_tuple = (False, ["You did not select any packages to be notified about."])
["You did not select any packages to be notified about."])
if not package_ids: if not package_ids:
return error_tuple return error_tuple
bases = set() bases = set()
package_ids = set(package_ids) package_ids = set(package_ids)
packages = db.query(models.Package).filter( packages = db.query(models.Package).filter(models.Package.ID.in_(package_ids)).all()
models.Package.ID.in_(package_ids)).all()
for pkg in packages: for pkg in packages:
if pkg.PackageBase not in bases: if pkg.PackageBase not in bases:
@ -257,9 +254,11 @@ async def packages_notify(request: Request, package_ids: list[int] = [],
# Perform some checks on what the user selected for notify. # Perform some checks on what the user selected for notify.
for pkgbase in bases: for pkgbase in bases:
notif = db.query(pkgbase.notifications.filter( notif = db.query(
models.PackageNotification.UserID == request.user.ID pkgbase.notifications.filter(
).exists()).scalar() models.PackageNotification.UserID == request.user.ID
).exists()
).scalar()
has_cred = request.user.has_credential(creds.PKGBASE_NOTIFY) has_cred = request.user.has_credential(creds.PKGBASE_NOTIFY)
# If the request user either does not have credentials # If the request user either does not have credentials
@ -275,23 +274,20 @@ async def packages_notify(request: Request, package_ids: list[int] = [],
return (True, ["The selected packages' notifications have been enabled."]) return (True, ["The selected packages' notifications have been enabled."])
async def packages_unnotify(request: Request, package_ids: list[int] = [], async def packages_unnotify(request: Request, package_ids: list[int] = [], **kwargs):
**kwargs):
if not package_ids: if not package_ids:
# TODO: This error does not yet have a translation. # TODO: This error does not yet have a translation.
return (False, return (False, ["You did not select any packages for notification removal."])
["You did not select any packages for notification removal."])
# TODO: This error does not yet have a translation. # TODO: This error does not yet have a translation.
error_tuple = ( error_tuple = (
False, False,
["A package you selected does not have notifications enabled."] ["A package you selected does not have notifications enabled."],
) )
bases = set() bases = set()
package_ids = set(package_ids) package_ids = set(package_ids)
packages = db.query(models.Package).filter( packages = db.query(models.Package).filter(models.Package.ID.in_(package_ids)).all()
models.Package.ID.in_(package_ids)).all()
for pkg in packages: for pkg in packages:
if pkg.PackageBase not in bases: if pkg.PackageBase not in bases:
@ -299,9 +295,11 @@ async def packages_unnotify(request: Request, package_ids: list[int] = [],
# Perform some checks on what the user selected for notify. # Perform some checks on what the user selected for notify.
for pkgbase in bases: for pkgbase in bases:
notif = db.query(pkgbase.notifications.filter( notif = db.query(
models.PackageNotification.UserID == request.user.ID pkgbase.notifications.filter(
).exists()).scalar() models.PackageNotification.UserID == request.user.ID
).exists()
).scalar()
if not notif: if not notif:
return error_tuple return error_tuple
@ -312,19 +310,24 @@ async def packages_unnotify(request: Request, package_ids: list[int] = [],
return (True, ["The selected packages' notifications have been removed."]) return (True, ["The selected packages' notifications have been removed."])
async def packages_adopt(request: Request, package_ids: list[int] = [], async def packages_adopt(
confirm: bool = False, **kwargs): request: Request, package_ids: list[int] = [], confirm: bool = False, **kwargs
):
if not package_ids: if not package_ids:
return (False, ["You did not select any packages to adopt."]) return (False, ["You did not select any packages to adopt."])
if not confirm: if not confirm:
return (False, ["The selected packages have not been adopted, " return (
"check the confirmation checkbox."]) False,
[
"The selected packages have not been adopted, "
"check the confirmation checkbox."
],
)
bases = set() bases = set()
package_ids = set(package_ids) package_ids = set(package_ids)
packages = db.query(models.Package).filter( packages = db.query(models.Package).filter(models.Package.ID.in_(package_ids)).all()
models.Package.ID.in_(package_ids)).all()
for pkg in packages: for pkg in packages:
if pkg.PackageBase not in bases: if pkg.PackageBase not in bases:
@ -335,8 +338,10 @@ async def packages_adopt(request: Request, package_ids: list[int] = [],
has_cred = request.user.has_credential(creds.PKGBASE_ADOPT) has_cred = request.user.has_credential(creds.PKGBASE_ADOPT)
if not (has_cred or not pkgbase.Maintainer): if not (has_cred or not pkgbase.Maintainer):
# TODO: This error needs to be translated. # TODO: This error needs to be translated.
return (False, ["You are not allowed to adopt one of the " return (
"packages you selected."]) False,
["You are not allowed to adopt one of the " "packages you selected."],
)
# Now, really adopt the bases. # Now, really adopt the bases.
for pkgbase in bases: for pkgbase in bases:
@ -345,8 +350,7 @@ async def packages_adopt(request: Request, package_ids: list[int] = [],
return (True, ["The selected packages have been adopted."]) return (True, ["The selected packages have been adopted."])
def disown_all(request: Request, pkgbases: list[models.PackageBase]) \ def disown_all(request: Request, pkgbases: list[models.PackageBase]) -> list[str]:
-> list[str]:
errors = [] errors = []
for pkgbase in pkgbases: for pkgbase in pkgbases:
try: try:
@ -356,19 +360,24 @@ def disown_all(request: Request, pkgbases: list[models.PackageBase]) \
return errors return errors
async def packages_disown(request: Request, package_ids: list[int] = [], async def packages_disown(
confirm: bool = False, **kwargs): request: Request, package_ids: list[int] = [], confirm: bool = False, **kwargs
):
if not package_ids: if not package_ids:
return (False, ["You did not select any packages to disown."]) return (False, ["You did not select any packages to disown."])
if not confirm: if not confirm:
return (False, ["The selected packages have not been disowned, " return (
"check the confirmation checkbox."]) False,
[
"The selected packages have not been disowned, "
"check the confirmation checkbox."
],
)
bases = set() bases = set()
package_ids = set(package_ids) package_ids = set(package_ids)
packages = db.query(models.Package).filter( packages = db.query(models.Package).filter(models.Package.ID.in_(package_ids)).all()
models.Package.ID.in_(package_ids)).all()
for pkg in packages: for pkg in packages:
if pkg.PackageBase not in bases: if pkg.PackageBase not in bases:
@ -376,12 +385,15 @@ async def packages_disown(request: Request, package_ids: list[int] = [],
# Check that the user has credentials for every package they selected. # Check that the user has credentials for every package they selected.
for pkgbase in bases: for pkgbase in bases:
has_cred = request.user.has_credential(creds.PKGBASE_DISOWN, has_cred = request.user.has_credential(
approved=[pkgbase.Maintainer]) creds.PKGBASE_DISOWN, approved=[pkgbase.Maintainer]
)
if not has_cred: if not has_cred:
# TODO: This error needs to be translated. # TODO: This error needs to be translated.
return (False, ["You are not allowed to disown one " return (
"of the packages you selected."]) False,
["You are not allowed to disown one " "of the packages you selected."],
)
# Now, disown all the bases if we can. # Now, disown all the bases if we can.
if errors := disown_all(request, bases): if errors := disown_all(request, bases):
@ -390,23 +402,31 @@ async def packages_disown(request: Request, package_ids: list[int] = [],
return (True, ["The selected packages have been disowned."]) return (True, ["The selected packages have been disowned."])
async def packages_delete(request: Request, package_ids: list[int] = [], async def packages_delete(
confirm: bool = False, merge_into: str = str(), request: Request,
**kwargs): package_ids: list[int] = [],
confirm: bool = False,
merge_into: str = str(),
**kwargs,
):
if not package_ids: if not package_ids:
return (False, ["You did not select any packages to delete."]) return (False, ["You did not select any packages to delete."])
if not confirm: if not confirm:
return (False, ["The selected packages have not been deleted, " return (
"check the confirmation checkbox."]) False,
[
"The selected packages have not been deleted, "
"check the confirmation checkbox."
],
)
if not request.user.has_credential(creds.PKGBASE_DELETE): if not request.user.has_credential(creds.PKGBASE_DELETE):
return (False, ["You do not have permission to delete packages."]) return (False, ["You do not have permission to delete packages."])
# set-ify package_ids and query the database for related records. # set-ify package_ids and query the database for related records.
package_ids = set(package_ids) package_ids = set(package_ids)
packages = db.query(models.Package).filter( packages = db.query(models.Package).filter(models.Package.ID.in_(package_ids)).all()
models.Package.ID.in_(package_ids)).all()
if len(packages) != len(package_ids): if len(packages) != len(package_ids):
# Let the user know there was an issue with their input: they have # Let the user know there was an issue with their input: they have
@ -422,12 +442,15 @@ async def packages_delete(request: Request, package_ids: list[int] = [],
notifs += pkgbase_actions.pkgbase_delete_instance(request, pkgbase) notifs += pkgbase_actions.pkgbase_delete_instance(request, pkgbase)
# Log out the fact that this happened for accountability. # Log out the fact that this happened for accountability.
logger.info(f"Privileged user '{request.user.Username}' deleted the " logger.info(
f"following package bases: {str(deleted_bases)}.") f"Privileged user '{request.user.Username}' deleted the "
f"following package bases: {str(deleted_bases)}."
)
util.apply_all(notifs, lambda n: n.send()) util.apply_all(notifs, lambda n: n.send())
return (True, ["The selected packages have been deleted."]) return (True, ["The selected packages have been deleted."])
# A mapping of action string -> callback functions used within the # A mapping of action string -> callback functions used within the
# `packages_post` route below. We expect any action callback to # `packages_post` route below. We expect any action callback to
# return a tuple in the format: (succeeded: bool, message: list[str]). # return a tuple in the format: (succeeded: bool, message: list[str]).
@ -444,10 +467,12 @@ PACKAGE_ACTIONS = {
@router.post("/packages") @router.post("/packages")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def packages_post(request: Request, async def packages_post(
IDs: list[int] = Form(default=[]), request: Request,
action: str = Form(default=str()), IDs: list[int] = Form(default=[]),
confirm: bool = Form(default=False)): action: str = Form(default=str()),
confirm: bool = Form(default=False),
):
# If an invalid action is specified, just render GET /packages # If an invalid action is specified, just render GET /packages
# with an BAD_REQUEST status_code. # with an BAD_REQUEST status_code.

View file

@ -16,9 +16,7 @@ from aurweb.models.package_vote import PackageVote
from aurweb.models.request_type import DELETION_ID, MERGE_ID, ORPHAN_ID from aurweb.models.request_type import DELETION_ID, MERGE_ID, ORPHAN_ID
from aurweb.packages.requests import update_closure_comment from aurweb.packages.requests import update_closure_comment
from aurweb.packages.util import get_pkg_or_base, get_pkgbase_comment from aurweb.packages.util import get_pkg_or_base, get_pkgbase_comment
from aurweb.pkgbase import actions from aurweb.pkgbase import actions, util as pkgbaseutil, validate
from aurweb.pkgbase import util as pkgbaseutil
from aurweb.pkgbase import validate
from aurweb.scripts import notify, popupdate from aurweb.scripts import notify, popupdate
from aurweb.scripts.rendercomment import update_comment_render_fastapi from aurweb.scripts.rendercomment import update_comment_render_fastapi
from aurweb.templates import make_variable_context, render_template from aurweb.templates import make_variable_context, render_template
@ -44,8 +42,9 @@ async def pkgbase(request: Request, name: str) -> Response:
packages = pkgbase.packages.all() packages = pkgbase.packages.all()
pkg = packages[0] pkg = packages[0]
if len(packages) == 1 and pkg.Name == pkgbase.Name: if len(packages) == 1 and pkg.Name == pkgbase.Name:
return RedirectResponse(f"/packages/{pkg.Name}", return RedirectResponse(
status_code=int(HTTPStatus.SEE_OTHER)) f"/packages/{pkg.Name}", status_code=int(HTTPStatus.SEE_OTHER)
)
# Add our base information. # Add our base information.
context = pkgbaseutil.make_context(request, pkgbase) context = pkgbaseutil.make_context(request, pkgbase)
@ -69,8 +68,7 @@ async def pkgbase_voters(request: Request, name: str) -> Response:
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
if not request.user.has_credential(creds.PKGBASE_LIST_VOTERS): if not request.user.has_credential(creds.PKGBASE_LIST_VOTERS):
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
context = templates.make_context(request, "Voters") context = templates.make_context(request, "Voters")
context["pkgbase"] = pkgbase context["pkgbase"] = pkgbase
@ -82,8 +80,7 @@ async def pkgbase_flag_comment(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
if pkgbase.OutOfDateTS is None: if pkgbase.OutOfDateTS is None:
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
context = templates.make_context(request, "Flag Comment") context = templates.make_context(request, "Flag Comment")
context["pkgbase"] = pkgbase context["pkgbase"] = pkgbase
@ -92,13 +89,15 @@ async def pkgbase_flag_comment(request: Request, name: str):
@router.post("/pkgbase/{name}/keywords") @router.post("/pkgbase/{name}/keywords")
@handle_form_exceptions @handle_form_exceptions
async def pkgbase_keywords(request: Request, name: str, async def pkgbase_keywords(
keywords: str = Form(default=str())): request: Request, name: str, keywords: str = Form(default=str())
):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
approved = [pkgbase.Maintainer] + [c.User for c in pkgbase.comaintainers] approved = [pkgbase.Maintainer] + [c.User for c in pkgbase.comaintainers]
has_cred = creds.has_credential(request.user, creds.PKGBASE_SET_KEYWORDS, has_cred = creds.has_credential(
approved=approved) request.user, creds.PKGBASE_SET_KEYWORDS, approved=approved
)
if not has_cred: if not has_cred:
return Response(status_code=HTTPStatus.UNAUTHORIZED) return Response(status_code=HTTPStatus.UNAUTHORIZED)
@ -108,15 +107,14 @@ async def pkgbase_keywords(request: Request, name: str,
# Delete all keywords which are not supplied by the user. # Delete all keywords which are not supplied by the user.
with db.begin(): with db.begin():
other_keywords = pkgbase.keywords.filter( other_keywords = pkgbase.keywords.filter(~PackageKeyword.Keyword.in_(keywords))
~PackageKeyword.Keyword.in_(keywords)) other_keyword_strings = set(kwd.Keyword.lower() for kwd in other_keywords)
other_keyword_strings = set(
kwd.Keyword.lower() for kwd in other_keywords)
existing_keywords = set( existing_keywords = set(
kwd.Keyword.lower() for kwd in kwd.Keyword.lower()
pkgbase.keywords.filter( for kwd in pkgbase.keywords.filter(
~PackageKeyword.Keyword.in_(other_keyword_strings)) ~PackageKeyword.Keyword.in_(other_keyword_strings)
)
) )
db.delete_all(other_keywords) db.delete_all(other_keywords)
@ -124,8 +122,7 @@ async def pkgbase_keywords(request: Request, name: str,
for keyword in new_keywords: for keyword in new_keywords:
db.create(PackageKeyword, PackageBase=pkgbase, Keyword=keyword) db.create(PackageKeyword, PackageBase=pkgbase, Keyword=keyword)
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
@router.get("/pkgbase/{name}/flag") @router.get("/pkgbase/{name}/flag")
@ -135,8 +132,7 @@ async def pkgbase_flag_get(request: Request, name: str):
has_cred = request.user.has_credential(creds.PKGBASE_FLAG) has_cred = request.user.has_credential(creds.PKGBASE_FLAG)
if not has_cred or pkgbase.OutOfDateTS is not None: if not has_cred or pkgbase.OutOfDateTS is not None:
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
context = templates.make_context(request, "Flag Package Out-Of-Date") context = templates.make_context(request, "Flag Package Out-Of-Date")
context["pkgbase"] = pkgbase context["pkgbase"] = pkgbase
@ -146,17 +142,20 @@ async def pkgbase_flag_get(request: Request, name: str):
@router.post("/pkgbase/{name}/flag") @router.post("/pkgbase/{name}/flag")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def pkgbase_flag_post(request: Request, name: str, async def pkgbase_flag_post(
comments: str = Form(default=str())): request: Request, name: str, comments: str = Form(default=str())
):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
if not comments: if not comments:
context = templates.make_context(request, "Flag Package Out-Of-Date") context = templates.make_context(request, "Flag Package Out-Of-Date")
context["pkgbase"] = pkgbase context["pkgbase"] = pkgbase
context["errors"] = ["The selected packages have not been flagged, " context["errors"] = [
"please enter a comment."] "The selected packages have not been flagged, " "please enter a comment."
return render_template(request, "pkgbase/flag.html", context, ]
status_code=HTTPStatus.BAD_REQUEST) return render_template(
request, "pkgbase/flag.html", context, status_code=HTTPStatus.BAD_REQUEST
)
has_cred = request.user.has_credential(creds.PKGBASE_FLAG) has_cred = request.user.has_credential(creds.PKGBASE_FLAG)
if has_cred and not pkgbase.OutOfDateTS: if has_cred and not pkgbase.OutOfDateTS:
@ -168,18 +167,19 @@ async def pkgbase_flag_post(request: Request, name: str,
notify.FlagNotification(request.user.ID, pkgbase.ID).send() notify.FlagNotification(request.user.ID, pkgbase.ID).send()
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
@router.post("/pkgbase/{name}/comments") @router.post("/pkgbase/{name}/comments")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def pkgbase_comments_post( async def pkgbase_comments_post(
request: Request, name: str, request: Request,
comment: str = Form(default=str()), name: str,
enable_notifications: bool = Form(default=False)): comment: str = Form(default=str()),
""" Add a new comment via POST request. """ enable_notifications: bool = Form(default=False),
):
"""Add a new comment via POST request."""
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
if not comment: if not comment:
@ -189,29 +189,34 @@ async def pkgbase_comments_post(
# update the db record. # update the db record.
now = time.utcnow() now = time.utcnow()
with db.begin(): with db.begin():
comment = db.create(PackageComment, User=request.user, comment = db.create(
PackageBase=pkgbase, PackageComment,
Comments=comment, RenderedComment=str(), User=request.user,
CommentTS=now) PackageBase=pkgbase,
Comments=comment,
RenderedComment=str(),
CommentTS=now,
)
if enable_notifications and not request.user.notified(pkgbase): if enable_notifications and not request.user.notified(pkgbase):
db.create(PackageNotification, db.create(PackageNotification, User=request.user, PackageBase=pkgbase)
User=request.user,
PackageBase=pkgbase)
update_comment_render_fastapi(comment) update_comment_render_fastapi(comment)
notif = notify.CommentNotification(request.user.ID, pkgbase.ID, comment.ID) notif = notify.CommentNotification(request.user.ID, pkgbase.ID, comment.ID)
notif.send() notif.send()
# Redirect to the pkgbase page. # Redirect to the pkgbase page.
return RedirectResponse(f"/pkgbase/{pkgbase.Name}#comment-{comment.ID}", return RedirectResponse(
status_code=HTTPStatus.SEE_OTHER) f"/pkgbase/{pkgbase.Name}#comment-{comment.ID}",
status_code=HTTPStatus.SEE_OTHER,
)
@router.get("/pkgbase/{name}/comments/{id}/form") @router.get("/pkgbase/{name}/comments/{id}/form")
@requires_auth @requires_auth
async def pkgbase_comment_form(request: Request, name: str, id: int, async def pkgbase_comment_form(
next: str = Query(default=None)): request: Request, name: str, id: int, next: str = Query(default=None)
):
""" """
Produce a comment form for comment {id}. Produce a comment form for comment {id}.
@ -244,14 +249,16 @@ async def pkgbase_comment_form(request: Request, name: str, id: int,
context["next"] = next context["next"] = next
form = templates.render_raw_template( form = templates.render_raw_template(
request, "partials/packages/comment_form.html", context) request, "partials/packages/comment_form.html", context
)
return JSONResponse({"form": form}) return JSONResponse({"form": form})
@router.get("/pkgbase/{name}/comments/{id}/edit") @router.get("/pkgbase/{name}/comments/{id}/edit")
@requires_auth @requires_auth
async def pkgbase_comment_edit(request: Request, name: str, id: int, async def pkgbase_comment_edit(
next: str = Form(default=None)): request: Request, name: str, id: int, next: str = Form(default=None)
):
""" """
Render the non-javascript edit form. Render the non-javascript edit form.
@ -276,11 +283,14 @@ async def pkgbase_comment_edit(request: Request, name: str, id: int,
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def pkgbase_comment_post( async def pkgbase_comment_post(
request: Request, name: str, id: int, request: Request,
comment: str = Form(default=str()), name: str,
enable_notifications: bool = Form(default=False), id: int,
next: str = Form(default=None)): comment: str = Form(default=str()),
""" Edit an existing comment. """ enable_notifications: bool = Form(default=False),
next: str = Form(default=None),
):
"""Edit an existing comment."""
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
db_comment = get_pkgbase_comment(pkgbase, id) db_comment = get_pkgbase_comment(pkgbase, id)
@ -302,24 +312,24 @@ async def pkgbase_comment_post(
PackageNotification.PackageBaseID == pkgbase.ID PackageNotification.PackageBaseID == pkgbase.ID
).first() ).first()
if enable_notifications and not db_notif: if enable_notifications and not db_notif:
db.create(PackageNotification, db.create(PackageNotification, User=request.user, PackageBase=pkgbase)
User=request.user,
PackageBase=pkgbase)
update_comment_render_fastapi(db_comment) update_comment_render_fastapi(db_comment)
if not next: if not next:
next = f"/pkgbase/{pkgbase.Name}" next = f"/pkgbase/{pkgbase.Name}"
# Redirect to the pkgbase page anchored to the updated comment. # Redirect to the pkgbase page anchored to the updated comment.
return RedirectResponse(f"{next}#comment-{db_comment.ID}", return RedirectResponse(
status_code=HTTPStatus.SEE_OTHER) f"{next}#comment-{db_comment.ID}", status_code=HTTPStatus.SEE_OTHER
)
@router.post("/pkgbase/{name}/comments/{id}/pin") @router.post("/pkgbase/{name}/comments/{id}/pin")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def pkgbase_comment_pin(request: Request, name: str, id: int, async def pkgbase_comment_pin(
next: str = Form(default=None)): request: Request, name: str, id: int, next: str = Form(default=None)
):
""" """
Pin a comment. Pin a comment.
@ -332,13 +342,15 @@ async def pkgbase_comment_pin(request: Request, name: str, id: int,
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
comment = get_pkgbase_comment(pkgbase, id) comment = get_pkgbase_comment(pkgbase, id)
has_cred = request.user.has_credential(creds.COMMENT_PIN, has_cred = request.user.has_credential(
approved=comment.maintainers()) creds.COMMENT_PIN, approved=comment.maintainers()
)
if not has_cred: if not has_cred:
_ = l10n.get_translator_for_request(request) _ = l10n.get_translator_for_request(request)
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED, status_code=HTTPStatus.UNAUTHORIZED,
detail=_("You are not allowed to pin this comment.")) detail=_("You are not allowed to pin this comment."),
)
now = time.utcnow() now = time.utcnow()
with db.begin(): with db.begin():
@ -353,8 +365,9 @@ async def pkgbase_comment_pin(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/comments/{id}/unpin") @router.post("/pkgbase/{name}/comments/{id}/unpin")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def pkgbase_comment_unpin(request: Request, name: str, id: int, async def pkgbase_comment_unpin(
next: str = Form(default=None)): request: Request, name: str, id: int, next: str = Form(default=None)
):
""" """
Unpin a comment. Unpin a comment.
@ -367,13 +380,15 @@ async def pkgbase_comment_unpin(request: Request, name: str, id: int,
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
comment = get_pkgbase_comment(pkgbase, id) comment = get_pkgbase_comment(pkgbase, id)
has_cred = request.user.has_credential(creds.COMMENT_PIN, has_cred = request.user.has_credential(
approved=comment.maintainers()) creds.COMMENT_PIN, approved=comment.maintainers()
)
if not has_cred: if not has_cred:
_ = l10n.get_translator_for_request(request) _ = l10n.get_translator_for_request(request)
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED, status_code=HTTPStatus.UNAUTHORIZED,
detail=_("You are not allowed to unpin this comment.")) detail=_("You are not allowed to unpin this comment."),
)
with db.begin(): with db.begin():
comment.PinnedTS = 0 comment.PinnedTS = 0
@ -387,8 +402,9 @@ async def pkgbase_comment_unpin(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/comments/{id}/delete") @router.post("/pkgbase/{name}/comments/{id}/delete")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def pkgbase_comment_delete(request: Request, name: str, id: int, async def pkgbase_comment_delete(
next: str = Form(default=None)): request: Request, name: str, id: int, next: str = Form(default=None)
):
""" """
Delete a comment. Delete a comment.
@ -405,13 +421,13 @@ async def pkgbase_comment_delete(request: Request, name: str, id: int,
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
comment = get_pkgbase_comment(pkgbase, id) comment = get_pkgbase_comment(pkgbase, id)
authorized = request.user.has_credential(creds.COMMENT_DELETE, authorized = request.user.has_credential(creds.COMMENT_DELETE, [comment.User])
[comment.User])
if not authorized: if not authorized:
_ = l10n.get_translator_for_request(request) _ = l10n.get_translator_for_request(request)
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED, status_code=HTTPStatus.UNAUTHORIZED,
detail=_("You are not allowed to delete this comment.")) detail=_("You are not allowed to delete this comment."),
)
now = time.utcnow() now = time.utcnow()
with db.begin(): with db.begin():
@ -427,8 +443,9 @@ async def pkgbase_comment_delete(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/comments/{id}/undelete") @router.post("/pkgbase/{name}/comments/{id}/undelete")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def pkgbase_comment_undelete(request: Request, name: str, id: int, async def pkgbase_comment_undelete(
next: str = Form(default=None)): request: Request, name: str, id: int, next: str = Form(default=None)
):
""" """
Undelete a comment. Undelete a comment.
@ -445,13 +462,15 @@ async def pkgbase_comment_undelete(request: Request, name: str, id: int,
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
comment = get_pkgbase_comment(pkgbase, id) comment = get_pkgbase_comment(pkgbase, id)
has_cred = request.user.has_credential(creds.COMMENT_UNDELETE, has_cred = request.user.has_credential(
approved=[comment.User]) creds.COMMENT_UNDELETE, approved=[comment.User]
)
if not has_cred: if not has_cred:
_ = l10n.get_translator_for_request(request) _ = l10n.get_translator_for_request(request)
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED, status_code=HTTPStatus.UNAUTHORIZED,
detail=_("You are not allowed to undelete this comment.")) detail=_("You are not allowed to undelete this comment."),
)
with db.begin(): with db.begin():
comment.Deleter = None comment.Deleter = None
@ -469,23 +488,17 @@ async def pkgbase_comment_undelete(request: Request, name: str, id: int,
async def pkgbase_vote(request: Request, name: str): async def pkgbase_vote(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
vote = pkgbase.package_votes.filter( vote = pkgbase.package_votes.filter(PackageVote.UsersID == request.user.ID).first()
PackageVote.UsersID == request.user.ID
).first()
has_cred = request.user.has_credential(creds.PKGBASE_VOTE) has_cred = request.user.has_credential(creds.PKGBASE_VOTE)
if has_cred and not vote: if has_cred and not vote:
now = time.utcnow() now = time.utcnow()
with db.begin(): with db.begin():
db.create(PackageVote, db.create(PackageVote, User=request.user, PackageBase=pkgbase, VoteTS=now)
User=request.user,
PackageBase=pkgbase,
VoteTS=now)
# Update NumVotes/Popularity. # Update NumVotes/Popularity.
popupdate.run_single(pkgbase) popupdate.run_single(pkgbase)
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
@router.post("/pkgbase/{name}/unvote") @router.post("/pkgbase/{name}/unvote")
@ -494,9 +507,7 @@ async def pkgbase_vote(request: Request, name: str):
async def pkgbase_unvote(request: Request, name: str): async def pkgbase_unvote(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
vote = pkgbase.package_votes.filter( vote = pkgbase.package_votes.filter(PackageVote.UsersID == request.user.ID).first()
PackageVote.UsersID == request.user.ID
).first()
has_cred = request.user.has_credential(creds.PKGBASE_VOTE) has_cred = request.user.has_credential(creds.PKGBASE_VOTE)
if has_cred and vote: if has_cred and vote:
with db.begin(): with db.begin():
@ -505,8 +516,7 @@ async def pkgbase_unvote(request: Request, name: str):
# Update NumVotes/Popularity. # Update NumVotes/Popularity.
popupdate.run_single(pkgbase) popupdate.run_single(pkgbase)
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
@router.post("/pkgbase/{name}/notify") @router.post("/pkgbase/{name}/notify")
@ -515,8 +525,7 @@ async def pkgbase_unvote(request: Request, name: str):
async def pkgbase_notify(request: Request, name: str): async def pkgbase_notify(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
actions.pkgbase_notify_instance(request, pkgbase) actions.pkgbase_notify_instance(request, pkgbase)
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
@router.post("/pkgbase/{name}/unnotify") @router.post("/pkgbase/{name}/unnotify")
@ -525,8 +534,7 @@ async def pkgbase_notify(request: Request, name: str):
async def pkgbase_unnotify(request: Request, name: str): async def pkgbase_unnotify(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
actions.pkgbase_unnotify_instance(request, pkgbase) actions.pkgbase_unnotify_instance(request, pkgbase)
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
@router.post("/pkgbase/{name}/unflag") @router.post("/pkgbase/{name}/unflag")
@ -535,20 +543,19 @@ async def pkgbase_unnotify(request: Request, name: str):
async def pkgbase_unflag(request: Request, name: str): async def pkgbase_unflag(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
actions.pkgbase_unflag_instance(request, pkgbase) actions.pkgbase_unflag_instance(request, pkgbase)
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
@router.get("/pkgbase/{name}/disown") @router.get("/pkgbase/{name}/disown")
@requires_auth @requires_auth
async def pkgbase_disown_get(request: Request, name: str, async def pkgbase_disown_get(
next: str = Query(default=str())): request: Request, name: str, next: str = Query(default=str())
):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
comaints = {c.User for c in pkgbase.comaintainers} comaints = {c.User for c in pkgbase.comaintainers}
approved = [pkgbase.Maintainer] + list(comaints) approved = [pkgbase.Maintainer] + list(comaints)
has_cred = request.user.has_credential(creds.PKGBASE_DISOWN, has_cred = request.user.has_credential(creds.PKGBASE_DISOWN, approved=approved)
approved=approved)
if not has_cred: if not has_cred:
return RedirectResponse(f"/pkgbase/{name}", HTTPStatus.SEE_OTHER) return RedirectResponse(f"/pkgbase/{name}", HTTPStatus.SEE_OTHER)
@ -563,27 +570,33 @@ async def pkgbase_disown_get(request: Request, name: str,
@router.post("/pkgbase/{name}/disown") @router.post("/pkgbase/{name}/disown")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def pkgbase_disown_post(request: Request, name: str, async def pkgbase_disown_post(
comments: str = Form(default=str()), request: Request,
confirm: bool = Form(default=False), name: str,
next: str = Form(default=str())): comments: str = Form(default=str()),
confirm: bool = Form(default=False),
next: str = Form(default=str()),
):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
comaints = {c.User for c in pkgbase.comaintainers} comaints = {c.User for c in pkgbase.comaintainers}
approved = [pkgbase.Maintainer] + list(comaints) approved = [pkgbase.Maintainer] + list(comaints)
has_cred = request.user.has_credential(creds.PKGBASE_DISOWN, has_cred = request.user.has_credential(creds.PKGBASE_DISOWN, approved=approved)
approved=approved)
if not has_cred: if not has_cred:
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", HTTPStatus.SEE_OTHER)
HTTPStatus.SEE_OTHER)
context = templates.make_context(request, "Disown Package") context = templates.make_context(request, "Disown Package")
context["pkgbase"] = pkgbase context["pkgbase"] = pkgbase
if not confirm: if not confirm:
context["errors"] = [("The selected packages have not been disowned, " context["errors"] = [
"check the confirmation checkbox.")] (
return render_template(request, "pkgbase/disown.html", context, "The selected packages have not been disowned, "
status_code=HTTPStatus.BAD_REQUEST) "check the confirmation checkbox."
)
]
return render_template(
request, "pkgbase/disown.html", context, status_code=HTTPStatus.BAD_REQUEST
)
if request.user != pkgbase.Maintainer and request.user not in comaints: if request.user != pkgbase.Maintainer and request.user not in comaints:
with db.begin(): with db.begin():
@ -593,8 +606,9 @@ async def pkgbase_disown_post(request: Request, name: str,
actions.pkgbase_disown_instance(request, pkgbase) actions.pkgbase_disown_instance(request, pkgbase)
except InvariantError as exc: except InvariantError as exc:
context["errors"] = [str(exc)] context["errors"] = [str(exc)]
return render_template(request, "pkgbase/disown.html", context, return render_template(
status_code=HTTPStatus.BAD_REQUEST) request, "pkgbase/disown.html", context, status_code=HTTPStatus.BAD_REQUEST
)
if not next: if not next:
next = f"/pkgbase/{name}" next = f"/pkgbase/{name}"
@ -615,8 +629,7 @@ async def pkgbase_adopt_post(request: Request, name: str):
# if no maintainer currently exists. # if no maintainer currently exists.
actions.pkgbase_adopt_instance(request, pkgbase) actions.pkgbase_adopt_instance(request, pkgbase)
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
@router.get("/pkgbase/{name}/comaintainers") @router.get("/pkgbase/{name}/comaintainers")
@ -627,20 +640,20 @@ async def pkgbase_comaintainers(request: Request, name: str) -> Response:
# Unauthorized users (Non-TU/Dev and not the pkgbase maintainer) # Unauthorized users (Non-TU/Dev and not the pkgbase maintainer)
# get redirected to the package base's page. # get redirected to the package base's page.
has_creds = request.user.has_credential(creds.PKGBASE_EDIT_COMAINTAINERS, has_creds = request.user.has_credential(
approved=[pkgbase.Maintainer]) creds.PKGBASE_EDIT_COMAINTAINERS, approved=[pkgbase.Maintainer]
)
if not has_creds: if not has_creds:
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
# Add our base information. # Add our base information.
context = templates.make_context(request, "Manage Co-maintainers") context = templates.make_context(request, "Manage Co-maintainers")
context.update({ context.update(
"pkgbase": pkgbase, {
"comaintainers": [ "pkgbase": pkgbase,
c.User.Username for c in pkgbase.comaintainers "comaintainers": [c.User.Username for c in pkgbase.comaintainers],
] }
}) )
return render_template(request, "pkgbase/comaintainers.html", context) return render_template(request, "pkgbase/comaintainers.html", context)
@ -648,50 +661,52 @@ async def pkgbase_comaintainers(request: Request, name: str) -> Response:
@router.post("/pkgbase/{name}/comaintainers") @router.post("/pkgbase/{name}/comaintainers")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def pkgbase_comaintainers_post(request: Request, name: str, async def pkgbase_comaintainers_post(
users: str = Form(default=str())) \ request: Request, name: str, users: str = Form(default=str())
-> Response: ) -> Response:
# Get the PackageBase. # Get the PackageBase.
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
# Unauthorized users (Non-TU/Dev and not the pkgbase maintainer) # Unauthorized users (Non-TU/Dev and not the pkgbase maintainer)
# get redirected to the package base's page. # get redirected to the package base's page.
has_creds = request.user.has_credential(creds.PKGBASE_EDIT_COMAINTAINERS, has_creds = request.user.has_credential(
approved=[pkgbase.Maintainer]) creds.PKGBASE_EDIT_COMAINTAINERS, approved=[pkgbase.Maintainer]
)
if not has_creds: if not has_creds:
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
users = {e.strip() for e in users.split("\n") if bool(e.strip())} users = {e.strip() for e in users.split("\n") if bool(e.strip())}
records = {c.User.Username for c in pkgbase.comaintainers} records = {c.User.Username for c in pkgbase.comaintainers}
users_to_rm = records.difference(users) users_to_rm = records.difference(users)
pkgbaseutil.remove_comaintainers(pkgbase, users_to_rm) pkgbaseutil.remove_comaintainers(pkgbase, users_to_rm)
logger.debug(f"{request.user} removed comaintainers from " logger.debug(
f"{pkgbase.Name}: {users_to_rm}") f"{request.user} removed comaintainers from " f"{pkgbase.Name}: {users_to_rm}"
)
users_to_add = users.difference(records) users_to_add = users.difference(records)
error = pkgbaseutil.add_comaintainers(request, pkgbase, users_to_add) error = pkgbaseutil.add_comaintainers(request, pkgbase, users_to_add)
if error: if error:
context = templates.make_context(request, "Manage Co-maintainers") context = templates.make_context(request, "Manage Co-maintainers")
context["pkgbase"] = pkgbase context["pkgbase"] = pkgbase
context["comaintainers"] = [ context["comaintainers"] = [c.User.Username for c in pkgbase.comaintainers]
c.User.Username for c in pkgbase.comaintainers
]
context["errors"] = [error] context["errors"] = [error]
return render_template(request, "pkgbase/comaintainers.html", context) return render_template(request, "pkgbase/comaintainers.html", context)
logger.debug(f"{request.user} added comaintainers to " logger.debug(
f"{pkgbase.Name}: {users_to_add}") f"{request.user} added comaintainers to " f"{pkgbase.Name}: {users_to_add}"
)
return RedirectResponse(f"/pkgbase/{pkgbase.Name}", return RedirectResponse(
status_code=HTTPStatus.SEE_OTHER) f"/pkgbase/{pkgbase.Name}", status_code=HTTPStatus.SEE_OTHER
)
@router.get("/pkgbase/{name}/request") @router.get("/pkgbase/{name}/request")
@requires_auth @requires_auth
async def pkgbase_request(request: Request, name: str, async def pkgbase_request(
next: str = Query(default=str())): request: Request, name: str, next: str = Query(default=str())
):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
context = await make_variable_context(request, "Submit Request") context = await make_variable_context(request, "Submit Request")
context["pkgbase"] = pkgbase context["pkgbase"] = pkgbase
@ -702,28 +717,28 @@ async def pkgbase_request(request: Request, name: str,
@router.post("/pkgbase/{name}/request") @router.post("/pkgbase/{name}/request")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def pkgbase_request_post(request: Request, name: str, async def pkgbase_request_post(
type: str = Form(...), request: Request,
merge_into: str = Form(default=None), name: str,
comments: str = Form(default=str()), type: str = Form(...),
next: str = Form(default=str())): merge_into: str = Form(default=None),
comments: str = Form(default=str()),
next: str = Form(default=str()),
):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
# Create our render context. # Create our render context.
context = await make_variable_context(request, "Submit Request") context = await make_variable_context(request, "Submit Request")
context["pkgbase"] = pkgbase context["pkgbase"] = pkgbase
types = { types = {"deletion": DELETION_ID, "merge": MERGE_ID, "orphan": ORPHAN_ID}
"deletion": DELETION_ID,
"merge": MERGE_ID,
"orphan": ORPHAN_ID
}
if type not in types: if type not in types:
# In the case that someone crafted a POST request with an invalid # In the case that someone crafted a POST request with an invalid
# type, just return them to the request form with BAD_REQUEST status. # type, just return them to the request form with BAD_REQUEST status.
return render_template(request, "pkgbase/request.html", context, return render_template(
status_code=HTTPStatus.BAD_REQUEST) request, "pkgbase/request.html", context, status_code=HTTPStatus.BAD_REQUEST
)
try: try:
validate.request(pkgbase, type, comments, merge_into, context) validate.request(pkgbase, type, comments, merge_into, context)
@ -735,20 +750,26 @@ async def pkgbase_request_post(request: Request, name: str,
# All good. Create a new PackageRequest based on the given type. # All good. Create a new PackageRequest based on the given type.
now = time.utcnow() now = time.utcnow()
with db.begin(): with db.begin():
pkgreq = db.create(PackageRequest, pkgreq = db.create(
ReqTypeID=types.get(type), PackageRequest,
User=request.user, ReqTypeID=types.get(type),
RequestTS=now, User=request.user,
PackageBase=pkgbase, RequestTS=now,
PackageBaseName=pkgbase.Name, PackageBase=pkgbase,
MergeBaseName=merge_into, PackageBaseName=pkgbase.Name,
Comments=comments, MergeBaseName=merge_into,
ClosureComment=str()) Comments=comments,
ClosureComment=str(),
)
# Prepare notification object. # Prepare notification object.
notif = notify.RequestOpenNotification( notif = notify.RequestOpenNotification(
request.user.ID, pkgreq.ID, type, request.user.ID,
pkgreq.PackageBase.ID, merge_into=merge_into or None) pkgreq.ID,
type,
pkgreq.PackageBase.ID,
merge_into=merge_into or None,
)
# Send the notification now that we're out of the DB scope. # Send the notification now that we're out of the DB scope.
notif.send() notif.send()
@ -767,13 +788,13 @@ async def pkgbase_request_post(request: Request, name: str,
pkgbase.Maintainer = None pkgbase.Maintainer = None
pkgreq.Status = ACCEPTED_ID pkgreq.Status = ACCEPTED_ID
notif = notify.RequestCloseNotification( notif = notify.RequestCloseNotification(
request.user.ID, pkgreq.ID, pkgreq.status_display()) request.user.ID, pkgreq.ID, pkgreq.status_display()
)
notif.send() notif.send()
logger.debug(f"New request #{pkgreq.ID} is marked for auto-orphan.") logger.debug(f"New request #{pkgreq.ID} is marked for auto-orphan.")
elif type == "deletion" and is_maintainer and outdated: elif type == "deletion" and is_maintainer and outdated:
# This request should be auto-accepted. # This request should be auto-accepted.
notifs = actions.pkgbase_delete_instance( notifs = actions.pkgbase_delete_instance(request, pkgbase, comments=comments)
request, pkgbase, comments=comments)
util.apply_all(notifs, lambda n: n.send()) util.apply_all(notifs, lambda n: n.send())
logger.debug(f"New request #{pkgreq.ID} is marked for auto-deletion.") logger.debug(f"New request #{pkgreq.ID} is marked for auto-deletion.")
@ -783,11 +804,11 @@ async def pkgbase_request_post(request: Request, name: str,
@router.get("/pkgbase/{name}/delete") @router.get("/pkgbase/{name}/delete")
@requires_auth @requires_auth
async def pkgbase_delete_get(request: Request, name: str, async def pkgbase_delete_get(
next: str = Query(default=str())): request: Request, name: str, next: str = Query(default=str())
):
if not request.user.has_credential(creds.PKGBASE_DELETE): if not request.user.has_credential(creds.PKGBASE_DELETE):
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
context = templates.make_context(request, "Package Deletion") context = templates.make_context(request, "Package Deletion")
context["pkgbase"] = get_pkg_or_base(name, PackageBase) context["pkgbase"] = get_pkg_or_base(name, PackageBase)
@ -798,53 +819,60 @@ async def pkgbase_delete_get(request: Request, name: str,
@router.post("/pkgbase/{name}/delete") @router.post("/pkgbase/{name}/delete")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def pkgbase_delete_post(request: Request, name: str, async def pkgbase_delete_post(
confirm: bool = Form(default=False), request: Request,
comments: str = Form(default=str()), name: str,
next: str = Form(default="/packages")): confirm: bool = Form(default=False),
comments: str = Form(default=str()),
next: str = Form(default="/packages"),
):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
if not request.user.has_credential(creds.PKGBASE_DELETE): if not request.user.has_credential(creds.PKGBASE_DELETE):
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
status_code=HTTPStatus.SEE_OTHER)
if not confirm: if not confirm:
context = templates.make_context(request, "Package Deletion") context = templates.make_context(request, "Package Deletion")
context["pkgbase"] = pkgbase context["pkgbase"] = pkgbase
context["errors"] = [("The selected packages have not been deleted, " context["errors"] = [
"check the confirmation checkbox.")] (
return render_template(request, "pkgbase/delete.html", context, "The selected packages have not been deleted, "
status_code=HTTPStatus.BAD_REQUEST) "check the confirmation checkbox."
)
]
return render_template(
request, "pkgbase/delete.html", context, status_code=HTTPStatus.BAD_REQUEST
)
if comments: if comments:
# Update any existing deletion requests' ClosureComment. # Update any existing deletion requests' ClosureComment.
with db.begin(): with db.begin():
requests = pkgbase.requests.filter( requests = pkgbase.requests.filter(
and_(PackageRequest.Status == PENDING_ID, and_(
PackageRequest.ReqTypeID == DELETION_ID) PackageRequest.Status == PENDING_ID,
PackageRequest.ReqTypeID == DELETION_ID,
)
) )
for pkgreq in requests: for pkgreq in requests:
pkgreq.ClosureComment = comments pkgreq.ClosureComment = comments
notifs = actions.pkgbase_delete_instance( notifs = actions.pkgbase_delete_instance(request, pkgbase, comments=comments)
request, pkgbase, comments=comments)
util.apply_all(notifs, lambda n: n.send()) util.apply_all(notifs, lambda n: n.send())
return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER) return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER)
@router.get("/pkgbase/{name}/merge") @router.get("/pkgbase/{name}/merge")
@requires_auth @requires_auth
async def pkgbase_merge_get(request: Request, name: str, async def pkgbase_merge_get(
into: str = Query(default=str()), request: Request,
next: str = Query(default=str())): name: str,
into: str = Query(default=str()),
next: str = Query(default=str()),
):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
context = templates.make_context(request, "Package Merging") context = templates.make_context(request, "Package Merging")
context.update({ context.update({"pkgbase": pkgbase, "into": into, "next": next})
"pkgbase": pkgbase,
"into": into,
"next": next
})
status_code = HTTPStatus.OK status_code = HTTPStatus.OK
# TODO: Lookup errors from credential instead of hardcoding them. # TODO: Lookup errors from credential instead of hardcoding them.
@ -852,51 +880,58 @@ async def pkgbase_merge_get(request: Request, name: str,
# Perhaps additionally: bad_credential_status_code(creds.PKGBASE_MERGE). # Perhaps additionally: bad_credential_status_code(creds.PKGBASE_MERGE).
# Don't take these examples verbatim. We should find good naming. # Don't take these examples verbatim. We should find good naming.
if not request.user.has_credential(creds.PKGBASE_MERGE): if not request.user.has_credential(creds.PKGBASE_MERGE):
context["errors"] = [ context["errors"] = ["Only Trusted Users and Developers can merge packages."]
"Only Trusted Users and Developers can merge packages."]
status_code = HTTPStatus.UNAUTHORIZED status_code = HTTPStatus.UNAUTHORIZED
return render_template(request, "pkgbase/merge.html", context, return render_template(
status_code=status_code) request, "pkgbase/merge.html", context, status_code=status_code
)
@router.post("/pkgbase/{name}/merge") @router.post("/pkgbase/{name}/merge")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def pkgbase_merge_post(request: Request, name: str, async def pkgbase_merge_post(
into: str = Form(default=str()), request: Request,
comments: str = Form(default=str()), name: str,
confirm: bool = Form(default=False), into: str = Form(default=str()),
next: str = Form(default=str())): comments: str = Form(default=str()),
confirm: bool = Form(default=False),
next: str = Form(default=str()),
):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, PackageBase)
context = await make_variable_context(request, "Package Merging") context = await make_variable_context(request, "Package Merging")
context["pkgbase"] = pkgbase context["pkgbase"] = pkgbase
# TODO: Lookup errors from credential instead of hardcoding them. # TODO: Lookup errors from credential instead of hardcoding them.
if not request.user.has_credential(creds.PKGBASE_MERGE): if not request.user.has_credential(creds.PKGBASE_MERGE):
context["errors"] = [ context["errors"] = ["Only Trusted Users and Developers can merge packages."]
"Only Trusted Users and Developers can merge packages."] return render_template(
return render_template(request, "pkgbase/merge.html", context, request, "pkgbase/merge.html", context, status_code=HTTPStatus.UNAUTHORIZED
status_code=HTTPStatus.UNAUTHORIZED) )
if not confirm: if not confirm:
context["errors"] = ["The selected packages have not been deleted, " context["errors"] = [
"check the confirmation checkbox."] "The selected packages have not been deleted, "
return render_template(request, "pkgbase/merge.html", context, "check the confirmation checkbox."
status_code=HTTPStatus.BAD_REQUEST) ]
return render_template(
request, "pkgbase/merge.html", context, status_code=HTTPStatus.BAD_REQUEST
)
try: try:
target = get_pkg_or_base(into, PackageBase) target = get_pkg_or_base(into, PackageBase)
except HTTPException: except HTTPException:
context["errors"] = [ context["errors"] = ["Cannot find package to merge votes and comments into."]
"Cannot find package to merge votes and comments into."] return render_template(
return render_template(request, "pkgbase/merge.html", context, request, "pkgbase/merge.html", context, status_code=HTTPStatus.BAD_REQUEST
status_code=HTTPStatus.BAD_REQUEST) )
if pkgbase == target: if pkgbase == target:
context["errors"] = ["Cannot merge a package base with itself."] context["errors"] = ["Cannot merge a package base with itself."]
return render_template(request, "pkgbase/merge.html", context, return render_template(
status_code=HTTPStatus.BAD_REQUEST) request, "pkgbase/merge.html", context, status_code=HTTPStatus.BAD_REQUEST
)
with db.begin(): with db.begin():
update_closure_comment(pkgbase, MERGE_ID, comments, target=target) update_closure_comment(pkgbase, MERGE_ID, comments, target=target)

View file

@ -18,9 +18,11 @@ router = APIRouter()
@router.get("/requests") @router.get("/requests")
@requires_auth @requires_auth
async def requests(request: Request, async def requests(
O: int = Query(default=defaults.O), request: Request,
PP: int = Query(default=defaults.PP)): O: int = Query(default=defaults.O),
PP: int = Query(default=defaults.PP),
):
context = make_context(request, "Requests") context = make_context(request, "Requests")
context["q"] = dict(request.query_params) context["q"] = dict(request.query_params)
@ -30,8 +32,7 @@ async def requests(request: Request,
context["PP"] = PP context["PP"] = PP
# A PackageRequest query, with left inner joined User and RequestType. # A PackageRequest query, with left inner joined User and RequestType.
query = db.query(PackageRequest).join( query = db.query(PackageRequest).join(User, User.ID == PackageRequest.UsersID)
User, User.ID == PackageRequest.UsersID)
# If the request user is not elevated (TU or Dev), then # If the request user is not elevated (TU or Dev), then
# filter PackageRequests which are owned by the request user. # filter PackageRequests which are owned by the request user.
@ -39,12 +40,17 @@ async def requests(request: Request,
query = query.filter(PackageRequest.UsersID == request.user.ID) query = query.filter(PackageRequest.UsersID == request.user.ID)
context["total"] = query.count() context["total"] = query.count()
context["results"] = query.order_by( context["results"] = (
# Order primarily by the Status column being PENDING_ID, query.order_by(
# and secondarily by RequestTS; both in descending order. # Order primarily by the Status column being PENDING_ID,
case([(PackageRequest.Status == PENDING_ID, 1)], else_=0).desc(), # and secondarily by RequestTS; both in descending order.
PackageRequest.RequestTS.desc() case([(PackageRequest.Status == PENDING_ID, 1)], else_=0).desc(),
).limit(PP).offset(O).all() PackageRequest.RequestTS.desc(),
)
.limit(PP)
.offset(O)
.all()
)
return render_template(request, "requests.html", context) return render_template(request, "requests.html", context)
@ -66,8 +72,9 @@ async def request_close(request: Request, id: int):
@router.post("/requests/{id}/close") @router.post("/requests/{id}/close")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def request_close_post(request: Request, id: int, async def request_close_post(
comments: str = Form(default=str())): request: Request, id: int, comments: str = Form(default=str())
):
pkgreq = get_pkgreq_by_id(id) pkgreq = get_pkgreq_by_id(id)
# `pkgreq`.User can close their own request. # `pkgreq`.User can close their own request.
@ -87,7 +94,8 @@ async def request_close_post(request: Request, id: int,
pkgreq.Status = REJECTED_ID pkgreq.Status = REJECTED_ID
notify_ = notify.RequestCloseNotification( notify_ = notify.RequestCloseNotification(
request.user.ID, pkgreq.ID, pkgreq.status_display()) request.user.ID, pkgreq.ID, pkgreq.status_display()
)
notify_.send() notify_.send()
return RedirectResponse("/requests", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse("/requests", status_code=HTTPStatus.SEE_OTHER)

View file

@ -1,12 +1,10 @@
import hashlib import hashlib
import re import re
from http import HTTPStatus from http import HTTPStatus
from typing import Optional from typing import Optional
from urllib.parse import unquote from urllib.parse import unquote
import orjson import orjson
from fastapi import APIRouter, Form, Query, Request, Response from fastapi import APIRouter, Form, Query, Request, Response
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
@ -19,7 +17,7 @@ router = APIRouter()
def parse_args(request: Request): def parse_args(request: Request):
""" Handle legacy logic of 'arg' and 'arg[]' query parameter handling. """Handle legacy logic of 'arg' and 'arg[]' query parameter handling.
When 'arg' appears as the last argument given to the query string, When 'arg' appears as the last argument given to the query string,
that argument is used by itself as one single argument, regardless that argument is used by itself as one single argument, regardless
@ -39,9 +37,7 @@ def parse_args(request: Request):
# Create a list of (key, value) pairs of the given 'arg' and 'arg[]' # Create a list of (key, value) pairs of the given 'arg' and 'arg[]'
# query parameters from last to first. # query parameters from last to first.
query = list(reversed(unquote(request.url.query).split("&"))) query = list(reversed(unquote(request.url.query).split("&")))
parts = [ parts = [e.split("=", 1) for e in query if e.startswith(("arg=", "arg[]="))]
e.split("=", 1) for e in query if e.startswith(("arg=", "arg[]="))
]
args = [] args = []
if parts: if parts:
@ -63,24 +59,28 @@ def parse_args(request: Request):
return args return args
JSONP_EXPR = re.compile(r'^[a-zA-Z0-9()_.]{1,128}$') JSONP_EXPR = re.compile(r"^[a-zA-Z0-9()_.]{1,128}$")
async def rpc_request(request: Request, async def rpc_request(
v: Optional[int] = None, request: Request,
type: Optional[str] = None, v: Optional[int] = None,
by: Optional[str] = defaults.RPC_SEARCH_BY, type: Optional[str] = None,
arg: Optional[str] = None, by: Optional[str] = defaults.RPC_SEARCH_BY,
args: Optional[list[str]] = [], arg: Optional[str] = None,
callback: Optional[str] = None): args: Optional[list[str]] = [],
callback: Optional[str] = None,
):
# Create a handle to our RPC class. # Create a handle to our RPC class.
rpc = RPC(version=v, type=type) rpc = RPC(version=v, type=type)
# If ratelimit was exceeded, return a 429 Too Many Requests. # If ratelimit was exceeded, return a 429 Too Many Requests.
if check_ratelimit(request): if check_ratelimit(request):
return JSONResponse(rpc.error("Rate limit reached"), return JSONResponse(
status_code=int(HTTPStatus.TOO_MANY_REQUESTS)) rpc.error("Rate limit reached"),
status_code=int(HTTPStatus.TOO_MANY_REQUESTS),
)
# If `callback` was provided, produce a text/javascript response # If `callback` was provided, produce a text/javascript response
# valid for the jsonp callback. Otherwise, by default, return # valid for the jsonp callback. Otherwise, by default, return
@ -115,15 +115,11 @@ async def rpc_request(request: Request,
# The ETag header expects quotes to surround any identifier. # The ETag header expects quotes to surround any identifier.
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag
headers = { headers = {"Content-Type": content_type, "ETag": f'"{etag}"'}
"Content-Type": content_type,
"ETag": f'"{etag}"'
}
if_none_match = request.headers.get("If-None-Match", str()) if_none_match = request.headers.get("If-None-Match", str())
if if_none_match and if_none_match.strip("\t\n\r\" ") == etag: if if_none_match and if_none_match.strip('\t\n\r" ') == etag:
return Response(headers=headers, return Response(headers=headers, status_code=int(HTTPStatus.NOT_MODIFIED))
status_code=int(HTTPStatus.NOT_MODIFIED))
if callback: if callback:
content = f"/**/{callback}({content.decode()})" content = f"/**/{callback}({content.decode()})"
@ -135,13 +131,15 @@ async def rpc_request(request: Request,
@router.get("/rpc.php") # Temporary! Remove on 03/04 @router.get("/rpc.php") # Temporary! Remove on 03/04
@router.get("/rpc/") @router.get("/rpc/")
@router.get("/rpc") @router.get("/rpc")
async def rpc(request: Request, async def rpc(
v: Optional[int] = Query(default=None), request: Request,
type: Optional[str] = Query(default=None), v: Optional[int] = Query(default=None),
by: Optional[str] = Query(default=defaults.RPC_SEARCH_BY), type: Optional[str] = Query(default=None),
arg: Optional[str] = Query(default=None), by: Optional[str] = Query(default=defaults.RPC_SEARCH_BY),
args: Optional[list[str]] = Query(default=[], alias="arg[]"), arg: Optional[str] = Query(default=None),
callback: Optional[str] = Query(default=None)): args: Optional[list[str]] = Query(default=[], alias="arg[]"),
callback: Optional[str] = Query(default=None),
):
if not request.url.query: if not request.url.query:
return documentation() return documentation()
return await rpc_request(request, v, type, by, arg, args, callback) return await rpc_request(request, v, type, by, arg, args, callback)
@ -152,11 +150,13 @@ async def rpc(request: Request,
@router.post("/rpc/") @router.post("/rpc/")
@router.post("/rpc") @router.post("/rpc")
@handle_form_exceptions @handle_form_exceptions
async def rpc_post(request: Request, async def rpc_post(
v: Optional[int] = Form(default=None), request: Request,
type: Optional[str] = Form(default=None), v: Optional[int] = Form(default=None),
by: Optional[str] = Form(default=defaults.RPC_SEARCH_BY), type: Optional[str] = Form(default=None),
arg: Optional[str] = Form(default=None), by: Optional[str] = Form(default=defaults.RPC_SEARCH_BY),
args: Optional[list[str]] = Form(default=[], alias="arg[]"), arg: Optional[str] = Form(default=None),
callback: Optional[str] = Form(default=None)): args: Optional[list[str]] = Form(default=[], alias="arg[]"),
callback: Optional[str] = Form(default=None),
):
return await rpc_request(request, v, type, by, arg, args, callback) return await rpc_request(request, v, type, by, arg, args, callback)

View file

@ -10,9 +10,8 @@ from aurweb.models import Package, PackageBase
router = APIRouter() router = APIRouter()
def make_rss_feed(request: Request, packages: list, def make_rss_feed(request: Request, packages: list, date_attr: str):
date_attr: str): """Create an RSS Feed string for some packages.
""" Create an RSS Feed string for some packages.
:param request: A FastAPI request :param request: A FastAPI request
:param packages: A list of packages to add to the RSS feed :param packages: A list of packages to add to the RSS feed
@ -26,10 +25,12 @@ def make_rss_feed(request: Request, packages: list,
base = f"{request.url.scheme}://{request.url.netloc}" base = f"{request.url.scheme}://{request.url.netloc}"
feed.link(href=base, rel="alternate") feed.link(href=base, rel="alternate")
feed.link(href=f"{base}/rss", rel="self") feed.link(href=f"{base}/rss", rel="self")
feed.image(title="AUR Newest Packages", feed.image(
url=f"{base}/static/css/archnavbar/aurlogo.png", title="AUR Newest Packages",
link=base, url=f"{base}/static/css/archnavbar/aurlogo.png",
description="AUR Newest Packages Feed") link=base,
description="AUR Newest Packages Feed",
)
for pkg in packages: for pkg in packages:
entry = feed.add_entry(order="append") entry = feed.add_entry(order="append")
@ -53,8 +54,12 @@ def make_rss_feed(request: Request, packages: list,
@router.get("/rss/") @router.get("/rss/")
async def rss(request: Request): async def rss(request: Request):
packages = db.query(Package).join(PackageBase).order_by( packages = (
PackageBase.SubmittedTS.desc()).limit(100) db.query(Package)
.join(PackageBase)
.order_by(PackageBase.SubmittedTS.desc())
.limit(100)
)
feed = make_rss_feed(request, packages, "SubmittedTS") feed = make_rss_feed(request, packages, "SubmittedTS")
response = Response(feed, media_type="application/rss+xml") response = Response(feed, media_type="application/rss+xml")
@ -69,8 +74,12 @@ async def rss(request: Request):
@router.get("/rss/modified") @router.get("/rss/modified")
async def rss_modified(request: Request): async def rss_modified(request: Request):
packages = db.query(Package).join(PackageBase).order_by( packages = (
PackageBase.ModifiedTS.desc()).limit(100) db.query(Package)
.join(PackageBase)
.order_by(PackageBase.ModifiedTS.desc())
.limit(100)
)
feed = make_rss_feed(request, packages, "ModifiedTS") feed = make_rss_feed(request, packages, "ModifiedTS")
response = Response(feed, media_type="application/rss+xml") response = Response(feed, media_type="application/rss+xml")

View file

@ -1,11 +1,9 @@
import time import time
import uuid import uuid
from http import HTTPStatus from http import HTTPStatus
from urllib.parse import urlencode from urllib.parse import urlencode
import fastapi import fastapi
from authlib.integrations.starlette_client import OAuth, OAuthError from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi import Depends, HTTPException from fastapi import Depends, HTTPException
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
@ -14,7 +12,6 @@ from starlette.requests import Request
import aurweb.config import aurweb.config
import aurweb.db import aurweb.db
from aurweb import util from aurweb import util
from aurweb.l10n import get_translator_for_request from aurweb.l10n import get_translator_for_request
from aurweb.schema import Bans, Sessions, Users from aurweb.schema import Bans, Sessions, Users
@ -43,14 +40,18 @@ async def login(request: Request, redirect: str = None):
The `redirect` argument is a query parameter specifying the post-login The `redirect` argument is a query parameter specifying the post-login
redirect URL. redirect URL.
""" """
authenticate_url = aurweb.config.get("options", "aur_location") + "/sso/authenticate" authenticate_url = (
aurweb.config.get("options", "aur_location") + "/sso/authenticate"
)
if redirect: if redirect:
authenticate_url = authenticate_url + "?" + urlencode([("redirect", redirect)]) authenticate_url = authenticate_url + "?" + urlencode([("redirect", redirect)])
return await oauth.sso.authorize_redirect(request, authenticate_url, prompt="login") return await oauth.sso.authorize_redirect(request, authenticate_url, prompt="login")
def is_account_suspended(conn, user_id): def is_account_suspended(conn, user_id):
row = conn.execute(select([Users.c.Suspended]).where(Users.c.ID == user_id)).fetchone() row = conn.execute(
select([Users.c.Suspended]).where(Users.c.ID == user_id)
).fetchone()
return row is not None and bool(row[0]) return row is not None and bool(row[0])
@ -60,23 +61,27 @@ def open_session(request, conn, user_id):
""" """
if is_account_suspended(conn, user_id): if is_account_suspended(conn, user_id):
_ = get_translator_for_request(request) _ = get_translator_for_request(request)
raise HTTPException(status_code=HTTPStatus.FORBIDDEN, raise HTTPException(
detail=_('Account suspended')) status_code=HTTPStatus.FORBIDDEN, detail=_("Account suspended")
)
# TODO This is a terrible message because it could imply the attempt at # TODO This is a terrible message because it could imply the attempt at
# logging in just caused the suspension. # logging in just caused the suspension.
sid = uuid.uuid4().hex sid = uuid.uuid4().hex
conn.execute(Sessions.insert().values( conn.execute(
UsersID=user_id, Sessions.insert().values(
SessionID=sid, UsersID=user_id,
LastUpdateTS=time.time(), SessionID=sid,
)) LastUpdateTS=time.time(),
)
)
# Update users last login information. # Update users last login information.
conn.execute(Users.update() conn.execute(
.where(Users.c.ID == user_id) Users.update()
.values(LastLogin=int(time.time()), .where(Users.c.ID == user_id)
LastLoginIPAddress=request.client.host)) .values(LastLogin=int(time.time()), LastLoginIPAddress=request.client.host)
)
return sid return sid
@ -98,7 +103,9 @@ def is_aur_url(url):
@router.get("/sso/authenticate") @router.get("/sso/authenticate")
async def authenticate(request: Request, redirect: str = None, conn=Depends(aurweb.db.connect)): async def authenticate(
request: Request, redirect: str = None, conn=Depends(aurweb.db.connect)
):
""" """
Receive an OpenID Connect ID token, validate it, then process it to create Receive an OpenID Connect ID token, validate it, then process it to create
an new AUR session. an new AUR session.
@ -107,9 +114,12 @@ async def authenticate(request: Request, redirect: str = None, conn=Depends(aurw
_ = get_translator_for_request(request) _ = get_translator_for_request(request)
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, status_code=HTTPStatus.FORBIDDEN,
detail=_('The login form is currently disabled for your IP address, ' detail=_(
'probably due to sustained spam attacks. Sorry for the ' "The login form is currently disabled for your IP address, "
'inconvenience.')) "probably due to sustained spam attacks. Sorry for the "
"inconvenience."
),
)
try: try:
token = await oauth.sso.authorize_access_token(request) token = await oauth.sso.authorize_access_token(request)
@ -120,30 +130,41 @@ async def authenticate(request: Request, redirect: str = None, conn=Depends(aurw
_ = get_translator_for_request(request) _ = get_translator_for_request(request)
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail=_('Bad OAuth token. Please retry logging in from the start.')) detail=_("Bad OAuth token. Please retry logging in from the start."),
)
sub = user.get("sub") # this is the SSO account ID in JWT terminology sub = user.get("sub") # this is the SSO account ID in JWT terminology
if not sub: if not sub:
_ = get_translator_for_request(request) _ = get_translator_for_request(request)
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, raise HTTPException(
detail=_("JWT is missing its `sub` field.")) status_code=HTTPStatus.BAD_REQUEST,
detail=_("JWT is missing its `sub` field."),
)
aur_accounts = conn.execute(select([Users.c.ID]).where(Users.c.SSOAccountID == sub)) \ aur_accounts = conn.execute(
.fetchall() select([Users.c.ID]).where(Users.c.SSOAccountID == sub)
).fetchall()
if not aur_accounts: if not aur_accounts:
return "Sorry, we dont seem to know you Sir " + sub return "Sorry, we dont seem to know you Sir " + sub
elif len(aur_accounts) == 1: elif len(aur_accounts) == 1:
sid = open_session(request, conn, aur_accounts[0][Users.c.ID]) sid = open_session(request, conn, aur_accounts[0][Users.c.ID])
response = RedirectResponse(redirect if redirect and is_aur_url(redirect) else "/") response = RedirectResponse(
redirect if redirect and is_aur_url(redirect) else "/"
)
secure_cookies = aurweb.config.getboolean("options", "disable_http_login") secure_cookies = aurweb.config.getboolean("options", "disable_http_login")
response.set_cookie(key="AURSID", value=sid, httponly=True, response.set_cookie(
secure=secure_cookies) key="AURSID", value=sid, httponly=True, secure=secure_cookies
)
if "id_token" in token: if "id_token" in token:
# We save the id_token for the SSO logout. Its not too important # We save the id_token for the SSO logout. Its not too important
# though, so if we cant find it, we can live without it. # though, so if we cant find it, we can live without it.
response.set_cookie(key="SSO_ID_TOKEN", value=token["id_token"], response.set_cookie(
path="/sso/", httponly=True, key="SSO_ID_TOKEN",
secure=secure_cookies) value=token["id_token"],
path="/sso/",
httponly=True,
secure=secure_cookies,
)
return util.add_samesite_fields(response, "strict") return util.add_samesite_fields(response, "strict")
else: else:
# Weve got a severe integrity violation. # Weve got a severe integrity violation.
@ -165,8 +186,12 @@ async def logout(request: Request):
return RedirectResponse("/") return RedirectResponse("/")
metadata = await oauth.sso.load_server_metadata() metadata = await oauth.sso.load_server_metadata()
query = urlencode({'post_logout_redirect_uri': aurweb.config.get('options', 'aur_location'), query = urlencode(
'id_token_hint': id_token}) {
response = RedirectResponse(metadata["end_session_endpoint"] + '?' + query) "post_logout_redirect_uri": aurweb.config.get("options", "aur_location"),
"id_token_hint": id_token,
}
)
response = RedirectResponse(metadata["end_session_endpoint"] + "?" + query)
response.delete_cookie("SSO_ID_TOKEN", path="/sso/") response.delete_cookie("SSO_ID_TOKEN", path="/sso/")
return response return response

View file

@ -1,6 +1,5 @@
import html import html
import typing import typing
from http import HTTPStatus from http import HTTPStatus
from typing import Any from typing import Any
@ -30,33 +29,36 @@ ADDVOTE_SPECIFICS = {
"add_tu": (7 * 24 * 60 * 60, 0.66), "add_tu": (7 * 24 * 60 * 60, 0.66),
"remove_tu": (7 * 24 * 60 * 60, 0.75), "remove_tu": (7 * 24 * 60 * 60, 0.75),
"remove_inactive_tu": (5 * 24 * 60 * 60, 0.66), "remove_inactive_tu": (5 * 24 * 60 * 60, 0.66),
"bylaws": (7 * 24 * 60 * 60, 0.75) "bylaws": (7 * 24 * 60 * 60, 0.75),
} }
def populate_trusted_user_counts(context: dict[str, Any]) -> None: def populate_trusted_user_counts(context: dict[str, Any]) -> None:
tu_query = db.query(User).filter( tu_query = db.query(User).filter(
or_(User.AccountTypeID == TRUSTED_USER_ID, or_(
User.AccountTypeID == TRUSTED_USER_AND_DEV_ID) User.AccountTypeID == TRUSTED_USER_ID,
User.AccountTypeID == TRUSTED_USER_AND_DEV_ID,
)
) )
context["trusted_user_count"] = tu_query.count() context["trusted_user_count"] = tu_query.count()
# In case any records have a None InactivityTS. # In case any records have a None InactivityTS.
active_tu_query = tu_query.filter( active_tu_query = tu_query.filter(
or_(User.InactivityTS.is_(None), or_(User.InactivityTS.is_(None), User.InactivityTS == 0)
User.InactivityTS == 0)
) )
context["active_trusted_user_count"] = active_tu_query.count() context["active_trusted_user_count"] = active_tu_query.count()
@router.get("/tu") @router.get("/tu")
@requires_auth @requires_auth
async def trusted_user(request: Request, async def trusted_user(
coff: int = 0, # current offset request: Request,
cby: str = "desc", # current by coff: int = 0, # current offset
poff: int = 0, # past offset cby: str = "desc", # current by
pby: str = "desc"): # past by poff: int = 0, # past offset
""" Proposal listings. """ pby: str = "desc",
): # past by
"""Proposal listings."""
if not request.user.has_credential(creds.TU_LIST_VOTES): if not request.user.has_credential(creds.TU_LIST_VOTES):
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)
@ -81,40 +83,47 @@ async def trusted_user(request: Request,
past_by = "desc" past_by = "desc"
context["past_by"] = past_by context["past_by"] = past_by
current_votes = db.query(models.TUVoteInfo).filter( current_votes = (
models.TUVoteInfo.End > ts).order_by( db.query(models.TUVoteInfo)
models.TUVoteInfo.Submitted.desc()) .filter(models.TUVoteInfo.End > ts)
.order_by(models.TUVoteInfo.Submitted.desc())
)
context["current_votes_count"] = current_votes.count() context["current_votes_count"] = current_votes.count()
current_votes = current_votes.limit(pp).offset(current_off) current_votes = current_votes.limit(pp).offset(current_off)
context["current_votes"] = reversed(current_votes.all()) \ context["current_votes"] = (
if current_by == "asc" else current_votes.all() reversed(current_votes.all()) if current_by == "asc" else current_votes.all()
)
context["current_off"] = current_off context["current_off"] = current_off
past_votes = db.query(models.TUVoteInfo).filter( past_votes = (
models.TUVoteInfo.End <= ts).order_by( db.query(models.TUVoteInfo)
models.TUVoteInfo.Submitted.desc()) .filter(models.TUVoteInfo.End <= ts)
.order_by(models.TUVoteInfo.Submitted.desc())
)
context["past_votes_count"] = past_votes.count() context["past_votes_count"] = past_votes.count()
past_votes = past_votes.limit(pp).offset(past_off) past_votes = past_votes.limit(pp).offset(past_off)
context["past_votes"] = reversed(past_votes.all()) \ context["past_votes"] = (
if past_by == "asc" else past_votes.all() reversed(past_votes.all()) if past_by == "asc" else past_votes.all()
)
context["past_off"] = past_off context["past_off"] = past_off
last_vote = func.max(models.TUVote.VoteID).label("LastVote") last_vote = func.max(models.TUVote.VoteID).label("LastVote")
last_votes_by_tu = db.query(models.TUVote).join(models.User).join( last_votes_by_tu = (
models.TUVoteInfo, db.query(models.TUVote)
models.TUVoteInfo.ID == models.TUVote.VoteID .join(models.User)
).filter( .join(models.TUVoteInfo, models.TUVoteInfo.ID == models.TUVote.VoteID)
and_(models.TUVote.VoteID == models.TUVoteInfo.ID, .filter(
models.User.ID == models.TUVote.UserID, and_(
models.TUVoteInfo.End < ts, models.TUVote.VoteID == models.TUVoteInfo.ID,
or_(models.User.AccountTypeID == 2, models.User.ID == models.TUVote.UserID,
models.User.AccountTypeID == 4)) models.TUVoteInfo.End < ts,
).with_entities( or_(models.User.AccountTypeID == 2, models.User.AccountTypeID == 4),
models.TUVote.UserID, )
last_vote, )
models.User.Username .with_entities(models.TUVote.UserID, last_vote, models.User.Username)
).group_by(models.TUVote.UserID).order_by( .group_by(models.TUVote.UserID)
last_vote.desc(), models.User.Username.asc()) .order_by(last_vote.desc(), models.User.Username.asc())
)
context["last_votes_by_tu"] = last_votes_by_tu.all() context["last_votes_by_tu"] = last_votes_by_tu.all()
context["current_by_next"] = "asc" if current_by == "desc" else "desc" context["current_by_next"] = "asc" if current_by == "desc" else "desc"
@ -126,18 +135,22 @@ async def trusted_user(request: Request,
"coff": current_off, "coff": current_off,
"cby": current_by, "cby": current_by,
"poff": past_off, "poff": past_off,
"pby": past_by "pby": past_by,
} }
return render_template(request, "tu/index.html", context) return render_template(request, "tu/index.html", context)
def render_proposal(request: Request, context: dict, proposal: int, def render_proposal(
voteinfo: models.TUVoteInfo, request: Request,
voters: typing.Iterable[models.User], context: dict,
vote: models.TUVote, proposal: int,
status_code: HTTPStatus = HTTPStatus.OK): voteinfo: models.TUVoteInfo,
""" Render a single TU proposal. """ voters: typing.Iterable[models.User],
vote: models.TUVote,
status_code: HTTPStatus = HTTPStatus.OK,
):
"""Render a single TU proposal."""
context["proposal"] = proposal context["proposal"] = proposal
context["voteinfo"] = voteinfo context["voteinfo"] = voteinfo
context["voters"] = voters.all() context["voters"] = voters.all()
@ -146,8 +159,9 @@ def render_proposal(request: Request, context: dict, proposal: int,
participation = (total / voteinfo.ActiveTUs) if voteinfo.ActiveTUs else 0 participation = (total / voteinfo.ActiveTUs) if voteinfo.ActiveTUs else 0
context["participation"] = participation context["participation"] = participation
accepted = (voteinfo.Yes > voteinfo.ActiveTUs / 2) or \ accepted = (voteinfo.Yes > voteinfo.ActiveTUs / 2) or (
(participation > voteinfo.Quorum and voteinfo.Yes > voteinfo.No) participation > voteinfo.Quorum and voteinfo.Yes > voteinfo.No
)
context["accepted"] = accepted context["accepted"] = accepted
can_vote = voters.filter(models.TUVote.User == request.user).first() is None can_vote = voters.filter(models.TUVote.User == request.user).first() is None
@ -159,8 +173,7 @@ def render_proposal(request: Request, context: dict, proposal: int,
context["vote"] = vote context["vote"] = vote
context["has_voted"] = vote is not None context["has_voted"] = vote is not None
return render_template(request, "tu/show.html", context, return render_template(request, "tu/show.html", context, status_code=status_code)
status_code=status_code)
@router.get("/tu/{proposal}") @router.get("/tu/{proposal}")
@ -172,16 +185,27 @@ async def trusted_user_proposal(request: Request, proposal: int):
context = await make_variable_context(request, "Trusted User") context = await make_variable_context(request, "Trusted User")
proposal = int(proposal) proposal = int(proposal)
voteinfo = db.query(models.TUVoteInfo).filter( voteinfo = (
models.TUVoteInfo.ID == proposal).first() db.query(models.TUVoteInfo).filter(models.TUVoteInfo.ID == proposal).first()
)
if not voteinfo: if not voteinfo:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
voters = db.query(models.User).join(models.TUVote).filter( voters = (
models.TUVote.VoteID == voteinfo.ID) db.query(models.User)
vote = db.query(models.TUVote).filter( .join(models.TUVote)
and_(models.TUVote.UserID == request.user.ID, .filter(models.TUVote.VoteID == voteinfo.ID)
models.TUVote.VoteID == voteinfo.ID)).first() )
vote = (
db.query(models.TUVote)
.filter(
and_(
models.TUVote.UserID == request.user.ID,
models.TUVote.VoteID == voteinfo.ID,
)
)
.first()
)
if not request.user.has_credential(creds.TU_VOTE): if not request.user.has_credential(creds.TU_VOTE):
context["error"] = "Only Trusted Users are allowed to vote." context["error"] = "Only Trusted Users are allowed to vote."
if voteinfo.User == request.user.Username: if voteinfo.User == request.user.Username:
@ -196,24 +220,36 @@ async def trusted_user_proposal(request: Request, proposal: int):
@router.post("/tu/{proposal}") @router.post("/tu/{proposal}")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def trusted_user_proposal_post(request: Request, proposal: int, async def trusted_user_proposal_post(
decision: str = Form(...)): request: Request, proposal: int, decision: str = Form(...)
):
if not request.user.has_credential(creds.TU_LIST_VOTES): if not request.user.has_credential(creds.TU_LIST_VOTES):
return RedirectResponse("/tu", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse("/tu", status_code=HTTPStatus.SEE_OTHER)
context = await make_variable_context(request, "Trusted User") context = await make_variable_context(request, "Trusted User")
proposal = int(proposal) # Make sure it's an int. proposal = int(proposal) # Make sure it's an int.
voteinfo = db.query(models.TUVoteInfo).filter( voteinfo = (
models.TUVoteInfo.ID == proposal).first() db.query(models.TUVoteInfo).filter(models.TUVoteInfo.ID == proposal).first()
)
if not voteinfo: if not voteinfo:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
voters = db.query(models.User).join(models.TUVote).filter( voters = (
models.TUVote.VoteID == voteinfo.ID) db.query(models.User)
vote = db.query(models.TUVote).filter( .join(models.TUVote)
and_(models.TUVote.UserID == request.user.ID, .filter(models.TUVote.VoteID == voteinfo.ID)
models.TUVote.VoteID == voteinfo.ID)).first() )
vote = (
db.query(models.TUVote)
.filter(
and_(
models.TUVote.UserID == request.user.ID,
models.TUVote.VoteID == voteinfo.ID,
)
)
.first()
)
status_code = HTTPStatus.OK status_code = HTTPStatus.OK
if not request.user.has_credential(creds.TU_VOTE): if not request.user.has_credential(creds.TU_VOTE):
@ -227,16 +263,15 @@ async def trusted_user_proposal_post(request: Request, proposal: int,
status_code = HTTPStatus.BAD_REQUEST status_code = HTTPStatus.BAD_REQUEST
if status_code != HTTPStatus.OK: if status_code != HTTPStatus.OK:
return render_proposal(request, context, proposal, return render_proposal(
voteinfo, voters, vote, request, context, proposal, voteinfo, voters, vote, status_code=status_code
status_code=status_code) )
if decision in {"Yes", "No", "Abstain"}: if decision in {"Yes", "No", "Abstain"}:
# Increment whichever decision was given to us. # Increment whichever decision was given to us.
setattr(voteinfo, decision, getattr(voteinfo, decision) + 1) setattr(voteinfo, decision, getattr(voteinfo, decision) + 1)
else: else:
return Response("Invalid 'decision' value.", return Response("Invalid 'decision' value.", status_code=HTTPStatus.BAD_REQUEST)
status_code=HTTPStatus.BAD_REQUEST)
with db.begin(): with db.begin():
vote = db.create(models.TUVote, User=request.user, VoteInfo=voteinfo) vote = db.create(models.TUVote, User=request.user, VoteInfo=voteinfo)
@ -247,8 +282,9 @@ async def trusted_user_proposal_post(request: Request, proposal: int,
@router.get("/addvote") @router.get("/addvote")
@requires_auth @requires_auth
async def trusted_user_addvote(request: Request, user: str = str(), async def trusted_user_addvote(
type: str = "add_tu", agenda: str = str()): request: Request, user: str = str(), type: str = "add_tu", agenda: str = str()
):
if not request.user.has_credential(creds.TU_ADD_VOTE): if not request.user.has_credential(creds.TU_ADD_VOTE):
return RedirectResponse("/tu", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse("/tu", status_code=HTTPStatus.SEE_OTHER)
@ -268,10 +304,12 @@ async def trusted_user_addvote(request: Request, user: str = str(),
@router.post("/addvote") @router.post("/addvote")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def trusted_user_addvote_post(request: Request, async def trusted_user_addvote_post(
user: str = Form(default=str()), request: Request,
type: str = Form(default=str()), user: str = Form(default=str()),
agenda: str = Form(default=str())): type: str = Form(default=str()),
agenda: str = Form(default=str()),
):
if not request.user.has_credential(creds.TU_ADD_VOTE): if not request.user.has_credential(creds.TU_ADD_VOTE):
return RedirectResponse("/tu", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse("/tu", status_code=HTTPStatus.SEE_OTHER)
@ -283,26 +321,29 @@ async def trusted_user_addvote_post(request: Request,
context["agenda"] = agenda context["agenda"] = agenda
def render_addvote(context, status_code): def render_addvote(context, status_code):
""" Simplify render_template a bit for this test. """ """Simplify render_template a bit for this test."""
return render_template(request, "addvote.html", context, status_code) return render_template(request, "addvote.html", context, status_code)
# Alright, get some database records, if we can. # Alright, get some database records, if we can.
if type != "bylaws": if type != "bylaws":
user_record = db.query(models.User).filter( user_record = db.query(models.User).filter(models.User.Username == user).first()
models.User.Username == user).first()
if user_record is None: if user_record is None:
context["error"] = "Username does not exist." context["error"] = "Username does not exist."
return render_addvote(context, HTTPStatus.NOT_FOUND) return render_addvote(context, HTTPStatus.NOT_FOUND)
utcnow = time.utcnow() utcnow = time.utcnow()
voteinfo = db.query(models.TUVoteInfo).filter( voteinfo = (
and_(models.TUVoteInfo.User == user, db.query(models.TUVoteInfo)
models.TUVoteInfo.End > utcnow)).count() .filter(
and_(models.TUVoteInfo.User == user, models.TUVoteInfo.End > utcnow)
)
.count()
)
if voteinfo: if voteinfo:
_ = l10n.get_translator_for_request(request) _ = l10n.get_translator_for_request(request)
context["error"] = _( context["error"] = _("%s already has proposal running for them.") % (
"%s already has proposal running for them.") % ( html.escape(user),
html.escape(user),) )
return render_addvote(context, HTTPStatus.BAD_REQUEST) return render_addvote(context, HTTPStatus.BAD_REQUEST)
if type not in ADDVOTE_SPECIFICS: if type not in ADDVOTE_SPECIFICS:
@ -323,16 +364,27 @@ async def trusted_user_addvote_post(request: Request,
# Create a new TUVoteInfo (proposal)! # Create a new TUVoteInfo (proposal)!
with db.begin(): with db.begin():
active_tus = db.query(User).filter( active_tus = (
and_(User.Suspended == 0, db.query(User)
User.InactivityTS.isnot(None), .filter(
User.AccountTypeID.in_(types)) and_(
).count() User.Suspended == 0,
voteinfo = db.create(models.TUVoteInfo, User=user, User.InactivityTS.isnot(None),
Agenda=html.escape(agenda), User.AccountTypeID.in_(types),
Submitted=timestamp, End=(timestamp + duration), )
Quorum=quorum, ActiveTUs=active_tus, )
Submitter=request.user) .count()
)
voteinfo = db.create(
models.TUVoteInfo,
User=user,
Agenda=html.escape(agenda),
Submitted=timestamp,
End=(timestamp + duration),
Quorum=quorum,
ActiveTUs=active_tus,
Submitter=request.user,
)
# Redirect to the new proposal. # Redirect to the new proposal.
endpoint = f"/tu/{voteinfo.ID}" endpoint = f"/tu/{voteinfo.ID}"

View file

@ -1,5 +1,4 @@
import os import os
from collections import defaultdict from collections import defaultdict
from typing import Any, Callable, NewType, Union from typing import Any, Callable, NewType, Union
@ -7,7 +6,6 @@ from fastapi.responses import HTMLResponse
from sqlalchemy import and_, literal, orm from sqlalchemy import and_, literal, orm
import aurweb.config as config import aurweb.config as config
from aurweb import db, defaults, models from aurweb import db, defaults, models
from aurweb.exceptions import RPCError from aurweb.exceptions import RPCError
from aurweb.filters import number_format from aurweb.filters import number_format
@ -23,8 +21,7 @@ TYPE_MAPPING = {
"replaces": "Replaces", "replaces": "Replaces",
} }
DataGenerator = NewType("DataGenerator", DataGenerator = NewType("DataGenerator", Callable[[models.Package], dict[str, Any]])
Callable[[models.Package], dict[str, Any]])
def documentation(): def documentation():
@ -40,7 +37,7 @@ def documentation():
class RPC: class RPC:
""" RPC API handler class. """RPC API handler class.
There are various pieces to RPC's process, and encapsulating them There are various pieces to RPC's process, and encapsulating them
inside of a class means that external users do not abuse the inside of a class means that external users do not abuse the
@ -66,17 +63,25 @@ class RPC:
# A set of RPC types supported by this API. # A set of RPC types supported by this API.
EXPOSED_TYPES = { EXPOSED_TYPES = {
"info", "multiinfo", "info",
"search", "msearch", "multiinfo",
"suggest", "suggest-pkgbase" "search",
"msearch",
"suggest",
"suggest-pkgbase",
} }
# A mapping of type aliases. # A mapping of type aliases.
TYPE_ALIASES = {"info": "multiinfo"} TYPE_ALIASES = {"info": "multiinfo"}
EXPOSED_BYS = { EXPOSED_BYS = {
"name-desc", "name", "maintainer", "name-desc",
"depends", "makedepends", "optdepends", "checkdepends" "name",
"maintainer",
"depends",
"makedepends",
"optdepends",
"checkdepends",
} }
# A mapping of by aliases. # A mapping of by aliases.
@ -92,7 +97,7 @@ class RPC:
"results": [], "results": [],
"resultcount": 0, "resultcount": 0,
"type": "error", "type": "error",
"error": message "error": message,
} }
def _verify_inputs(self, by: str = [], args: list[str] = []) -> None: def _verify_inputs(self, by: str = [], args: list[str] = []) -> None:
@ -116,7 +121,7 @@ class RPC:
raise RPCError("No request type/data specified.") raise RPCError("No request type/data specified.")
def _get_json_data(self, package: models.Package) -> dict[str, Any]: def _get_json_data(self, package: models.Package) -> dict[str, Any]:
""" Produce dictionary data of one Package that can be JSON-serialized. """Produce dictionary data of one Package that can be JSON-serialized.
:param package: Package instance :param package: Package instance
:returns: JSON-serializable dictionary :returns: JSON-serializable dictionary
@ -143,7 +148,7 @@ class RPC:
"Popularity": pop, "Popularity": pop,
"OutOfDate": package.OutOfDateTS, "OutOfDate": package.OutOfDateTS,
"FirstSubmitted": package.SubmittedTS, "FirstSubmitted": package.SubmittedTS,
"LastModified": package.ModifiedTS "LastModified": package.ModifiedTS,
} }
def _get_info_json_data(self, package: models.Package) -> dict[str, Any]: def _get_info_json_data(self, package: models.Package) -> dict[str, Any]:
@ -151,10 +156,7 @@ class RPC:
# All info results have _at least_ an empty list of # All info results have _at least_ an empty list of
# License and Keywords. # License and Keywords.
data.update({ data.update({"License": [], "Keywords": []})
"License": [],
"Keywords": []
})
# If we actually got extra_info records, update data with # If we actually got extra_info records, update data with
# them for this particular package. # them for this particular package.
@ -163,9 +165,9 @@ class RPC:
return data return data
def _assemble_json_data(self, packages: list[models.Package], def _assemble_json_data(
data_generator: DataGenerator) \ self, packages: list[models.Package], data_generator: DataGenerator
-> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
""" """
Assemble JSON data out of a list of packages. Assemble JSON data out of a list of packages.
@ -175,7 +177,7 @@ class RPC:
return [data_generator(pkg) for pkg in packages] return [data_generator(pkg) for pkg in packages]
def _entities(self, query: orm.Query) -> orm.Query: def _entities(self, query: orm.Query) -> orm.Query:
""" Select specific RPC columns on `query`. """ """Select specific RPC columns on `query`."""
return query.with_entities( return query.with_entities(
models.Package.ID, models.Package.ID,
models.Package.Name, models.Package.Name,
@ -192,16 +194,22 @@ class RPC:
models.User.Username.label("Maintainer"), models.User.Username.label("Maintainer"),
).group_by(models.Package.ID) ).group_by(models.Package.ID)
def _handle_multiinfo_type(self, args: list[str] = [], **kwargs) \ def _handle_multiinfo_type(
-> list[dict[str, Any]]: self, args: list[str] = [], **kwargs
) -> list[dict[str, Any]]:
self._enforce_args(args) self._enforce_args(args)
args = set(args) args = set(args)
packages = db.query(models.Package).join(models.PackageBase).join( packages = (
models.User, db.query(models.Package)
models.User.ID == models.PackageBase.MaintainerUID, .join(models.PackageBase)
isouter=True .join(
).filter(models.Package.Name.in_(args)) models.User,
models.User.ID == models.PackageBase.MaintainerUID,
isouter=True,
)
.filter(models.Package.Name.in_(args))
)
max_results = config.getint("options", "max_rpc_results") max_results = config.getint("options", "max_rpc_results")
packages = self._entities(packages).limit(max_results + 1) packages = self._entities(packages).limit(max_results + 1)
@ -217,65 +225,75 @@ class RPC:
subqueries = [ subqueries = [
# PackageDependency # PackageDependency
db.query( db.query(models.PackageDependency)
models.PackageDependency .join(models.DependencyType)
).join(models.DependencyType).filter( .filter(models.PackageDependency.PackageID.in_(ids))
models.PackageDependency.PackageID.in_(ids) .with_entities(
).with_entities(
models.PackageDependency.PackageID.label("ID"), models.PackageDependency.PackageID.label("ID"),
models.DependencyType.Name.label("Type"), models.DependencyType.Name.label("Type"),
models.PackageDependency.DepName.label("Name"), models.PackageDependency.DepName.label("Name"),
models.PackageDependency.DepCondition.label("Cond") models.PackageDependency.DepCondition.label("Cond"),
).distinct().order_by("Name"), )
.distinct()
.order_by("Name"),
# PackageRelation # PackageRelation
db.query( db.query(models.PackageRelation)
models.PackageRelation .join(models.RelationType)
).join(models.RelationType).filter( .filter(models.PackageRelation.PackageID.in_(ids))
models.PackageRelation.PackageID.in_(ids) .with_entities(
).with_entities(
models.PackageRelation.PackageID.label("ID"), models.PackageRelation.PackageID.label("ID"),
models.RelationType.Name.label("Type"), models.RelationType.Name.label("Type"),
models.PackageRelation.RelName.label("Name"), models.PackageRelation.RelName.label("Name"),
models.PackageRelation.RelCondition.label("Cond") models.PackageRelation.RelCondition.label("Cond"),
).distinct().order_by("Name"), )
.distinct()
.order_by("Name"),
# Groups # Groups
db.query(models.PackageGroup).join( db.query(models.PackageGroup)
.join(
models.Group, models.Group,
and_(models.PackageGroup.GroupID == models.Group.ID, and_(
models.PackageGroup.PackageID.in_(ids)) models.PackageGroup.GroupID == models.Group.ID,
).with_entities( models.PackageGroup.PackageID.in_(ids),
),
)
.with_entities(
models.PackageGroup.PackageID.label("ID"), models.PackageGroup.PackageID.label("ID"),
literal("Groups").label("Type"), literal("Groups").label("Type"),
models.Group.Name.label("Name"), models.Group.Name.label("Name"),
literal(str()).label("Cond") literal(str()).label("Cond"),
).distinct().order_by("Name"), )
.distinct()
.order_by("Name"),
# Licenses # Licenses
db.query(models.PackageLicense).join( db.query(models.PackageLicense)
models.License, .join(models.License, models.PackageLicense.LicenseID == models.License.ID)
models.PackageLicense.LicenseID == models.License.ID .filter(models.PackageLicense.PackageID.in_(ids))
).filter( .with_entities(
models.PackageLicense.PackageID.in_(ids)
).with_entities(
models.PackageLicense.PackageID.label("ID"), models.PackageLicense.PackageID.label("ID"),
literal("License").label("Type"), literal("License").label("Type"),
models.License.Name.label("Name"), models.License.Name.label("Name"),
literal(str()).label("Cond") literal(str()).label("Cond"),
).distinct().order_by("Name"), )
.distinct()
.order_by("Name"),
# Keywords # Keywords
db.query(models.PackageKeyword).join( db.query(models.PackageKeyword)
.join(
models.Package, models.Package,
and_(Package.PackageBaseID == PackageKeyword.PackageBaseID, and_(
Package.ID.in_(ids)) Package.PackageBaseID == PackageKeyword.PackageBaseID,
).with_entities( Package.ID.in_(ids),
),
)
.with_entities(
models.Package.ID.label("ID"), models.Package.ID.label("ID"),
literal("Keywords").label("Type"), literal("Keywords").label("Type"),
models.PackageKeyword.Keyword.label("Name"), models.PackageKeyword.Keyword.label("Name"),
literal(str()).label("Cond") literal(str()).label("Cond"),
).distinct().order_by("Name") )
.distinct()
.order_by("Name"),
] ]
# Union all subqueries together. # Union all subqueries together.
@ -295,8 +313,9 @@ class RPC:
return self._assemble_json_data(packages, self._get_info_json_data) return self._assemble_json_data(packages, self._get_info_json_data)
def _handle_search_type(self, by: str = defaults.RPC_SEARCH_BY, def _handle_search_type(
args: list[str] = []) -> list[dict[str, Any]]: self, by: str = defaults.RPC_SEARCH_BY, args: list[str] = []
) -> list[dict[str, Any]]:
# If `by` isn't maintainer and we don't have any args, raise an error. # If `by` isn't maintainer and we don't have any args, raise an error.
# In maintainer's case, return all orphans if there are no args, # In maintainer's case, return all orphans if there are no args,
# so we need args to pass through to the handler without errors. # so we need args to pass through to the handler without errors.
@ -318,50 +337,64 @@ class RPC:
return self._assemble_json_data(results, self._get_json_data) return self._assemble_json_data(results, self._get_json_data)
def _handle_msearch_type(self, args: list[str] = [], **kwargs)\ def _handle_msearch_type(
-> list[dict[str, Any]]: self, args: list[str] = [], **kwargs
) -> list[dict[str, Any]]:
return self._handle_search_type(by="m", args=args) return self._handle_search_type(by="m", args=args)
def _handle_suggest_type(self, args: list[str] = [], **kwargs)\ def _handle_suggest_type(self, args: list[str] = [], **kwargs) -> list[str]:
-> list[str]:
if not args: if not args:
return [] return []
arg = args[0] arg = args[0]
packages = db.query(models.Package.Name).join( packages = (
models.PackageBase db.query(models.Package.Name)
).filter( .join(models.PackageBase)
and_(models.PackageBase.PackagerUID.isnot(None), .filter(
models.Package.Name.like(f"{arg}%")) and_(
).order_by(models.Package.Name.asc()).limit(20) models.PackageBase.PackagerUID.isnot(None),
models.Package.Name.like(f"{arg}%"),
)
)
.order_by(models.Package.Name.asc())
.limit(20)
)
return [pkg.Name for pkg in packages] return [pkg.Name for pkg in packages]
def _handle_suggest_pkgbase_type(self, args: list[str] = [], **kwargs)\ def _handle_suggest_pkgbase_type(self, args: list[str] = [], **kwargs) -> list[str]:
-> list[str]:
if not args: if not args:
return [] return []
arg = args[0] arg = args[0]
packages = db.query(models.PackageBase.Name).filter( packages = (
and_(models.PackageBase.PackagerUID.isnot(None), db.query(models.PackageBase.Name)
models.PackageBase.Name.like(f"{arg}%")) .filter(
).order_by(models.PackageBase.Name.asc()).limit(20) and_(
models.PackageBase.PackagerUID.isnot(None),
models.PackageBase.Name.like(f"{arg}%"),
)
)
.order_by(models.PackageBase.Name.asc())
.limit(20)
)
return [pkg.Name for pkg in packages] return [pkg.Name for pkg in packages]
def _is_suggestion(self) -> bool: def _is_suggestion(self) -> bool:
return self.type.startswith("suggest") return self.type.startswith("suggest")
def _handle_callback(self, by: str, args: list[str])\ def _handle_callback(
-> Union[list[dict[str, Any]], list[str]]: self, by: str, args: list[str]
) -> Union[list[dict[str, Any]], list[str]]:
# Get a handle to our callback and trap an RPCError with # Get a handle to our callback and trap an RPCError with
# an empty list of results based on callback's execution. # an empty list of results based on callback's execution.
callback = getattr(self, f"_handle_{self.type.replace('-', '_')}_type") callback = getattr(self, f"_handle_{self.type.replace('-', '_')}_type")
results = callback(by=by, args=args) results = callback(by=by, args=args)
return results return results
def handle(self, by: str = defaults.RPC_SEARCH_BY, args: list[str] = [])\ def handle(
-> Union[list[dict[str, Any]], dict[str, Any]]: self, by: str = defaults.RPC_SEARCH_BY, args: list[str] = []
""" Request entrypoint. A router should pass v, type and args ) -> Union[list[dict[str, Any]], dict[str, Any]]:
"""Request entrypoint. A router should pass v, type and args
to this function and expect an output dictionary to be returned. to this function and expect an output dictionary to be returned.
:param v: RPC version argument :param v: RPC version argument
@ -392,8 +425,5 @@ class RPC:
return results return results
# Return JSON output. # Return JSON output.
data.update({ data.update({"resultcount": len(results), "results": results})
"resultcount": len(results),
"results": results
})
return data return data

View file

@ -6,7 +6,18 @@ usually be automatically generated. See `migrations/README` for details.
""" """
from sqlalchemy import CHAR, TIMESTAMP, Column, ForeignKey, Index, MetaData, String, Table, Text, text from sqlalchemy import (
CHAR,
TIMESTAMP,
Column,
ForeignKey,
Index,
MetaData,
String,
Table,
Text,
text,
)
from sqlalchemy.dialects.mysql import BIGINT, DECIMAL, INTEGER, TINYINT from sqlalchemy.dialects.mysql import BIGINT, DECIMAL, INTEGER, TINYINT
from sqlalchemy.ext.compiler import compiles from sqlalchemy.ext.compiler import compiles
@ -15,13 +26,13 @@ import aurweb.config
db_backend = aurweb.config.get("database", "backend") db_backend = aurweb.config.get("database", "backend")
@compiles(TINYINT, 'sqlite') @compiles(TINYINT, "sqlite")
def compile_tinyint_sqlite(type_, compiler, **kw): # pragma: no cover def compile_tinyint_sqlite(type_, compiler, **kw): # pragma: no cover
"""TINYINT is not supported on SQLite. Substitute it with INTEGER.""" """TINYINT is not supported on SQLite. Substitute it with INTEGER."""
return 'INTEGER' return "INTEGER"
@compiles(BIGINT, 'sqlite') @compiles(BIGINT, "sqlite")
def compile_bigint_sqlite(type_, compiler, **kw): # pragma: no cover def compile_bigint_sqlite(type_, compiler, **kw): # pragma: no cover
""" """
For SQLite's AUTOINCREMENT to work on BIGINT columns, we need to map BIGINT For SQLite's AUTOINCREMENT to work on BIGINT columns, we need to map BIGINT
@ -29,429 +40,567 @@ def compile_bigint_sqlite(type_, compiler, **kw): # pragma: no cover
See https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#allowing-autoincrement-behavior-sqlalchemy-types-other-than-integer-integer See https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#allowing-autoincrement-behavior-sqlalchemy-types-other-than-integer-integer
""" # noqa: E501 """ # noqa: E501
return 'INTEGER' return "INTEGER"
metadata = MetaData() metadata = MetaData()
# Define the Account Types for the AUR. # Define the Account Types for the AUR.
AccountTypes = Table( AccountTypes = Table(
'AccountTypes', metadata, "AccountTypes",
Column('ID', TINYINT(unsigned=True), primary_key=True), metadata,
Column('AccountType', String(32), nullable=False, server_default=text("''")), Column("ID", TINYINT(unsigned=True), primary_key=True),
mysql_engine='InnoDB', Column("AccountType", String(32), nullable=False, server_default=text("''")),
mysql_charset='utf8mb4', mysql_engine="InnoDB",
mysql_collate='utf8mb4_general_ci' mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# User information for each user regardless of type. # User information for each user regardless of type.
Users = Table( Users = Table(
'Users', metadata, "Users",
Column('ID', INTEGER(unsigned=True), primary_key=True), metadata,
Column('AccountTypeID', ForeignKey('AccountTypes.ID', ondelete="NO ACTION"), nullable=False, server_default=text("1")), Column("ID", INTEGER(unsigned=True), primary_key=True),
Column('Suspended', TINYINT(unsigned=True), nullable=False, server_default=text("0")), Column(
Column('Username', String(32), nullable=False, unique=True), "AccountTypeID",
Column('Email', String(254), nullable=False, unique=True), ForeignKey("AccountTypes.ID", ondelete="NO ACTION"),
Column('BackupEmail', String(254)), nullable=False,
Column('HideEmail', TINYINT(unsigned=True), nullable=False, server_default=text("0")), server_default=text("1"),
Column('Passwd', String(255), nullable=False), ),
Column('Salt', CHAR(32), nullable=False, server_default=text("''")), Column(
Column('ResetKey', CHAR(32), nullable=False, server_default=text("''")), "Suspended", TINYINT(unsigned=True), nullable=False, server_default=text("0")
Column('RealName', String(64), nullable=False, server_default=text("''")), ),
Column('LangPreference', String(6), nullable=False, server_default=text("'en'")), Column("Username", String(32), nullable=False, unique=True),
Column('Timezone', String(32), nullable=False, server_default=text("'UTC'")), Column("Email", String(254), nullable=False, unique=True),
Column('Homepage', Text), Column("BackupEmail", String(254)),
Column('IRCNick', String(32), nullable=False, server_default=text("''")), Column(
Column('PGPKey', String(40)), "HideEmail", TINYINT(unsigned=True), nullable=False, server_default=text("0")
Column('LastLogin', BIGINT(unsigned=True), nullable=False, server_default=text("0")), ),
Column('LastLoginIPAddress', String(45)), Column("Passwd", String(255), nullable=False),
Column('LastSSHLogin', BIGINT(unsigned=True), nullable=False, server_default=text("0")), Column("Salt", CHAR(32), nullable=False, server_default=text("''")),
Column('LastSSHLoginIPAddress', String(45)), Column("ResetKey", CHAR(32), nullable=False, server_default=text("''")),
Column('InactivityTS', BIGINT(unsigned=True), nullable=False, server_default=text("0")), Column("RealName", String(64), nullable=False, server_default=text("''")),
Column('RegistrationTS', TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP")), Column("LangPreference", String(6), nullable=False, server_default=text("'en'")),
Column('CommentNotify', TINYINT(1), nullable=False, server_default=text("1")), Column("Timezone", String(32), nullable=False, server_default=text("'UTC'")),
Column('UpdateNotify', TINYINT(1), nullable=False, server_default=text("0")), Column("Homepage", Text),
Column('OwnershipNotify', TINYINT(1), nullable=False, server_default=text("1")), Column("IRCNick", String(32), nullable=False, server_default=text("''")),
Column('SSOAccountID', String(255), nullable=True, unique=True), Column("PGPKey", String(40)),
Index('UsersAccountTypeID', 'AccountTypeID'), Column(
mysql_engine='InnoDB', "LastLogin", BIGINT(unsigned=True), nullable=False, server_default=text("0")
mysql_charset='utf8mb4', ),
mysql_collate='utf8mb4_general_ci', Column("LastLoginIPAddress", String(45)),
Column(
"LastSSHLogin", BIGINT(unsigned=True), nullable=False, server_default=text("0")
),
Column("LastSSHLoginIPAddress", String(45)),
Column(
"InactivityTS", BIGINT(unsigned=True), nullable=False, server_default=text("0")
),
Column(
"RegistrationTS",
TIMESTAMP,
nullable=False,
server_default=text("CURRENT_TIMESTAMP"),
),
Column("CommentNotify", TINYINT(1), nullable=False, server_default=text("1")),
Column("UpdateNotify", TINYINT(1), nullable=False, server_default=text("0")),
Column("OwnershipNotify", TINYINT(1), nullable=False, server_default=text("1")),
Column("SSOAccountID", String(255), nullable=True, unique=True),
Index("UsersAccountTypeID", "AccountTypeID"),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# SSH public keys used for the aurweb SSH/Git interface. # SSH public keys used for the aurweb SSH/Git interface.
SSHPubKeys = Table( SSHPubKeys = Table(
'SSHPubKeys', metadata, "SSHPubKeys",
Column('UserID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False), metadata,
Column('Fingerprint', String(44), primary_key=True), Column("UserID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
Column('PubKey', String(4096), nullable=False), Column("Fingerprint", String(44), primary_key=True),
mysql_engine='InnoDB', mysql_charset='utf8mb4', mysql_collate='utf8mb4_bin', Column("PubKey", String(4096), nullable=False),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_bin",
) )
# Track Users logging in/out of AUR web site. # Track Users logging in/out of AUR web site.
Sessions = Table( Sessions = Table(
'Sessions', metadata, "Sessions",
Column('UsersID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False), metadata,
Column('SessionID', CHAR(32), nullable=False, unique=True), Column("UsersID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
Column('LastUpdateTS', BIGINT(unsigned=True), nullable=False), Column("SessionID", CHAR(32), nullable=False, unique=True),
mysql_engine='InnoDB', mysql_charset='utf8mb4', mysql_collate='utf8mb4_bin', Column("LastUpdateTS", BIGINT(unsigned=True), nullable=False),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_bin",
) )
# Information on package bases # Information on package bases
PackageBases = Table( PackageBases = Table(
'PackageBases', metadata, "PackageBases",
Column('ID', INTEGER(unsigned=True), primary_key=True), metadata,
Column('Name', String(255), nullable=False, unique=True), Column("ID", INTEGER(unsigned=True), primary_key=True),
Column('NumVotes', INTEGER(unsigned=True), nullable=False, server_default=text("0")), Column("Name", String(255), nullable=False, unique=True),
Column('Popularity', Column(
DECIMAL(10, 6, unsigned=True) "NumVotes", INTEGER(unsigned=True), nullable=False, server_default=text("0")
if db_backend == "mysql" else String(17), ),
nullable=False, server_default=text("0")), Column(
Column('OutOfDateTS', BIGINT(unsigned=True)), "Popularity",
Column('FlaggerComment', Text, nullable=False), DECIMAL(10, 6, unsigned=True) if db_backend == "mysql" else String(17),
Column('SubmittedTS', BIGINT(unsigned=True), nullable=False), nullable=False,
Column('ModifiedTS', BIGINT(unsigned=True), nullable=False), server_default=text("0"),
Column('FlaggerUID', ForeignKey('Users.ID', ondelete='SET NULL')), # who flagged the package out-of-date? ),
Column("OutOfDateTS", BIGINT(unsigned=True)),
Column("FlaggerComment", Text, nullable=False),
Column("SubmittedTS", BIGINT(unsigned=True), nullable=False),
Column("ModifiedTS", BIGINT(unsigned=True), nullable=False),
Column(
"FlaggerUID", ForeignKey("Users.ID", ondelete="SET NULL")
), # who flagged the package out-of-date?
# deleting a user will cause packages to be orphaned, not deleted # deleting a user will cause packages to be orphaned, not deleted
Column('SubmitterUID', ForeignKey('Users.ID', ondelete='SET NULL')), # who submitted it? Column(
Column('MaintainerUID', ForeignKey('Users.ID', ondelete='SET NULL')), # User "SubmitterUID", ForeignKey("Users.ID", ondelete="SET NULL")
Column('PackagerUID', ForeignKey('Users.ID', ondelete='SET NULL')), # Last packager ), # who submitted it?
Index('BasesMaintainerUID', 'MaintainerUID'), Column("MaintainerUID", ForeignKey("Users.ID", ondelete="SET NULL")), # User
Index('BasesNumVotes', 'NumVotes'), Column("PackagerUID", ForeignKey("Users.ID", ondelete="SET NULL")), # Last packager
Index('BasesPackagerUID', 'PackagerUID'), Index("BasesMaintainerUID", "MaintainerUID"),
Index('BasesSubmitterUID', 'SubmitterUID'), Index("BasesNumVotes", "NumVotes"),
mysql_engine='InnoDB', Index("BasesPackagerUID", "PackagerUID"),
mysql_charset='utf8mb4', Index("BasesSubmitterUID", "SubmitterUID"),
mysql_collate='utf8mb4_general_ci', mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Keywords of package bases # Keywords of package bases
PackageKeywords = Table( PackageKeywords = Table(
'PackageKeywords', metadata, "PackageKeywords",
Column('PackageBaseID', ForeignKey('PackageBases.ID', ondelete='CASCADE'), primary_key=True, nullable=True), metadata,
Column('Keyword', String(255), primary_key=True, nullable=False, server_default=text("''")), Column(
mysql_engine='InnoDB', "PackageBaseID",
mysql_charset='utf8mb4', ForeignKey("PackageBases.ID", ondelete="CASCADE"),
mysql_collate='utf8mb4_general_ci', primary_key=True,
nullable=True,
),
Column(
"Keyword",
String(255),
primary_key=True,
nullable=False,
server_default=text("''"),
),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Information about the actual packages # Information about the actual packages
Packages = Table( Packages = Table(
'Packages', metadata, "Packages",
Column('ID', INTEGER(unsigned=True), primary_key=True), metadata,
Column('PackageBaseID', ForeignKey('PackageBases.ID', ondelete='CASCADE'), nullable=False), Column("ID", INTEGER(unsigned=True), primary_key=True),
Column('Name', String(255), nullable=False, unique=True), Column(
Column('Version', String(255), nullable=False, server_default=text("''")), "PackageBaseID",
Column('Description', String(255)), ForeignKey("PackageBases.ID", ondelete="CASCADE"),
Column('URL', String(8000)), nullable=False,
mysql_engine='InnoDB', ),
mysql_charset='utf8mb4', Column("Name", String(255), nullable=False, unique=True),
mysql_collate='utf8mb4_general_ci', Column("Version", String(255), nullable=False, server_default=text("''")),
Column("Description", String(255)),
Column("URL", String(8000)),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Information about licenses # Information about licenses
Licenses = Table( Licenses = Table(
'Licenses', metadata, "Licenses",
Column('ID', INTEGER(unsigned=True), primary_key=True), metadata,
Column('Name', String(255), nullable=False, unique=True), Column("ID", INTEGER(unsigned=True), primary_key=True),
mysql_engine='InnoDB', Column("Name", String(255), nullable=False, unique=True),
mysql_charset='utf8mb4', mysql_engine="InnoDB",
mysql_collate='utf8mb4_general_ci', mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Information about package-license-relations # Information about package-license-relations
PackageLicenses = Table( PackageLicenses = Table(
'PackageLicenses', metadata, "PackageLicenses",
Column('PackageID', ForeignKey('Packages.ID', ondelete='CASCADE'), primary_key=True, nullable=True), metadata,
Column('LicenseID', ForeignKey('Licenses.ID', ondelete='CASCADE'), primary_key=True, nullable=True), Column(
mysql_engine='InnoDB', "PackageID",
ForeignKey("Packages.ID", ondelete="CASCADE"),
primary_key=True,
nullable=True,
),
Column(
"LicenseID",
ForeignKey("Licenses.ID", ondelete="CASCADE"),
primary_key=True,
nullable=True,
),
mysql_engine="InnoDB",
) )
# Information about groups # Information about groups
Groups = Table( Groups = Table(
'Groups', metadata, "Groups",
Column('ID', INTEGER(unsigned=True), primary_key=True), metadata,
Column('Name', String(255), nullable=False, unique=True), Column("ID", INTEGER(unsigned=True), primary_key=True),
mysql_engine='InnoDB', Column("Name", String(255), nullable=False, unique=True),
mysql_charset='utf8mb4', mysql_engine="InnoDB",
mysql_collate='utf8mb4_general_ci', mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Information about package-group-relations # Information about package-group-relations
PackageGroups = Table( PackageGroups = Table(
'PackageGroups', metadata, "PackageGroups",
Column('PackageID', ForeignKey('Packages.ID', ondelete='CASCADE'), primary_key=True, nullable=True), metadata,
Column('GroupID', ForeignKey('Groups.ID', ondelete='CASCADE'), primary_key=True, nullable=True), Column(
mysql_engine='InnoDB', "PackageID",
ForeignKey("Packages.ID", ondelete="CASCADE"),
primary_key=True,
nullable=True,
),
Column(
"GroupID",
ForeignKey("Groups.ID", ondelete="CASCADE"),
primary_key=True,
nullable=True,
),
mysql_engine="InnoDB",
) )
# Define the package dependency types # Define the package dependency types
DependencyTypes = Table( DependencyTypes = Table(
'DependencyTypes', metadata, "DependencyTypes",
Column('ID', TINYINT(unsigned=True), primary_key=True), metadata,
Column('Name', String(32), nullable=False, server_default=text("''")), Column("ID", TINYINT(unsigned=True), primary_key=True),
mysql_engine='InnoDB', Column("Name", String(32), nullable=False, server_default=text("''")),
mysql_charset='utf8mb4', mysql_engine="InnoDB",
mysql_collate='utf8mb4_general_ci', mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Track which dependencies a package has # Track which dependencies a package has
PackageDepends = Table( PackageDepends = Table(
'PackageDepends', metadata, "PackageDepends",
Column('PackageID', ForeignKey('Packages.ID', ondelete='CASCADE'), nullable=False), metadata,
Column('DepTypeID', ForeignKey('DependencyTypes.ID', ondelete="NO ACTION"), nullable=False), Column("PackageID", ForeignKey("Packages.ID", ondelete="CASCADE"), nullable=False),
Column('DepName', String(255), nullable=False), Column(
Column('DepDesc', String(255)), "DepTypeID",
Column('DepCondition', String(255)), ForeignKey("DependencyTypes.ID", ondelete="NO ACTION"),
Column('DepArch', String(255)), nullable=False,
Index('DependsDepName', 'DepName'), ),
Index('DependsPackageID', 'PackageID'), Column("DepName", String(255), nullable=False),
mysql_engine='InnoDB', Column("DepDesc", String(255)),
mysql_charset='utf8mb4', Column("DepCondition", String(255)),
mysql_collate='utf8mb4_general_ci', Column("DepArch", String(255)),
Index("DependsDepName", "DepName"),
Index("DependsPackageID", "PackageID"),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Define the package relation types # Define the package relation types
RelationTypes = Table( RelationTypes = Table(
'RelationTypes', metadata, "RelationTypes",
Column('ID', TINYINT(unsigned=True), primary_key=True), metadata,
Column('Name', String(32), nullable=False, server_default=text("''")), Column("ID", TINYINT(unsigned=True), primary_key=True),
mysql_engine='InnoDB', Column("Name", String(32), nullable=False, server_default=text("''")),
mysql_charset='utf8mb4', mysql_engine="InnoDB",
mysql_collate='utf8mb4_general_ci', mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Track which conflicts, provides and replaces a package has # Track which conflicts, provides and replaces a package has
PackageRelations = Table( PackageRelations = Table(
'PackageRelations', metadata, "PackageRelations",
Column('PackageID', ForeignKey('Packages.ID', ondelete='CASCADE'), nullable=False), metadata,
Column('RelTypeID', ForeignKey('RelationTypes.ID', ondelete="NO ACTION"), nullable=False), Column("PackageID", ForeignKey("Packages.ID", ondelete="CASCADE"), nullable=False),
Column('RelName', String(255), nullable=False), Column(
Column('RelCondition', String(255)), "RelTypeID",
Column('RelArch', String(255)), ForeignKey("RelationTypes.ID", ondelete="NO ACTION"),
Index('RelationsPackageID', 'PackageID'), nullable=False,
Index('RelationsRelName', 'RelName'), ),
mysql_engine='InnoDB', Column("RelName", String(255), nullable=False),
mysql_charset='utf8mb4', Column("RelCondition", String(255)),
mysql_collate='utf8mb4_general_ci', Column("RelArch", String(255)),
Index("RelationsPackageID", "PackageID"),
Index("RelationsRelName", "RelName"),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Track which sources a package has # Track which sources a package has
PackageSources = Table( PackageSources = Table(
'PackageSources', metadata, "PackageSources",
Column('PackageID', ForeignKey('Packages.ID', ondelete='CASCADE'), nullable=False), metadata,
Column('Source', String(8000), nullable=False, server_default=text("'/dev/null'")), Column("PackageID", ForeignKey("Packages.ID", ondelete="CASCADE"), nullable=False),
Column('SourceArch', String(255)), Column("Source", String(8000), nullable=False, server_default=text("'/dev/null'")),
Index('SourcesPackageID', 'PackageID'), Column("SourceArch", String(255)),
mysql_engine='InnoDB', Index("SourcesPackageID", "PackageID"),
mysql_charset='utf8mb4', mysql_engine="InnoDB",
mysql_collate='utf8mb4_general_ci', mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Track votes for packages # Track votes for packages
PackageVotes = Table( PackageVotes = Table(
'PackageVotes', metadata, "PackageVotes",
Column('UsersID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False), metadata,
Column('PackageBaseID', ForeignKey('PackageBases.ID', ondelete='CASCADE'), nullable=False), Column("UsersID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
Column('VoteTS', BIGINT(unsigned=True), nullable=False), Column(
Index('VoteUsersIDPackageID', 'UsersID', 'PackageBaseID', unique=True), "PackageBaseID",
Index('VotesPackageBaseID', 'PackageBaseID'), ForeignKey("PackageBases.ID", ondelete="CASCADE"),
Index('VotesUsersID', 'UsersID'), nullable=False,
mysql_engine='InnoDB', ),
Column("VoteTS", BIGINT(unsigned=True), nullable=False),
Index("VoteUsersIDPackageID", "UsersID", "PackageBaseID", unique=True),
Index("VotesPackageBaseID", "PackageBaseID"),
Index("VotesUsersID", "UsersID"),
mysql_engine="InnoDB",
) )
# Record comments for packages # Record comments for packages
PackageComments = Table( PackageComments = Table(
'PackageComments', metadata, "PackageComments",
Column('ID', BIGINT(unsigned=True), primary_key=True), metadata,
Column('PackageBaseID', ForeignKey('PackageBases.ID', ondelete='CASCADE'), nullable=False), Column("ID", BIGINT(unsigned=True), primary_key=True),
Column('UsersID', ForeignKey('Users.ID', ondelete='SET NULL')), Column(
Column('Comments', Text, nullable=False), "PackageBaseID",
Column('RenderedComment', Text, nullable=False), ForeignKey("PackageBases.ID", ondelete="CASCADE"),
Column('CommentTS', BIGINT(unsigned=True), nullable=False, server_default=text("0")), nullable=False,
Column('EditedTS', BIGINT(unsigned=True)), ),
Column('EditedUsersID', ForeignKey('Users.ID', ondelete='SET NULL')), Column("UsersID", ForeignKey("Users.ID", ondelete="SET NULL")),
Column('DelTS', BIGINT(unsigned=True)), Column("Comments", Text, nullable=False),
Column('DelUsersID', ForeignKey('Users.ID', ondelete='CASCADE')), Column("RenderedComment", Text, nullable=False),
Column('PinnedTS', BIGINT(unsigned=True), nullable=False, server_default=text("0")), Column(
Index('CommentsPackageBaseID', 'PackageBaseID'), "CommentTS", BIGINT(unsigned=True), nullable=False, server_default=text("0")
Index('CommentsUsersID', 'UsersID'), ),
mysql_engine='InnoDB', Column("EditedTS", BIGINT(unsigned=True)),
mysql_charset='utf8mb4', Column("EditedUsersID", ForeignKey("Users.ID", ondelete="SET NULL")),
mysql_collate='utf8mb4_general_ci', Column("DelTS", BIGINT(unsigned=True)),
Column("DelUsersID", ForeignKey("Users.ID", ondelete="CASCADE")),
Column("PinnedTS", BIGINT(unsigned=True), nullable=False, server_default=text("0")),
Index("CommentsPackageBaseID", "PackageBaseID"),
Index("CommentsUsersID", "UsersID"),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Package base co-maintainers # Package base co-maintainers
PackageComaintainers = Table( PackageComaintainers = Table(
'PackageComaintainers', metadata, "PackageComaintainers",
Column('UsersID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False), metadata,
Column('PackageBaseID', ForeignKey('PackageBases.ID', ondelete='CASCADE'), nullable=False), Column("UsersID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
Column('Priority', INTEGER(unsigned=True), nullable=False), Column(
Index('ComaintainersPackageBaseID', 'PackageBaseID'), "PackageBaseID",
Index('ComaintainersUsersID', 'UsersID'), ForeignKey("PackageBases.ID", ondelete="CASCADE"),
mysql_engine='InnoDB', nullable=False,
),
Column("Priority", INTEGER(unsigned=True), nullable=False),
Index("ComaintainersPackageBaseID", "PackageBaseID"),
Index("ComaintainersUsersID", "UsersID"),
mysql_engine="InnoDB",
) )
# Package base notifications # Package base notifications
PackageNotifications = Table( PackageNotifications = Table(
'PackageNotifications', metadata, "PackageNotifications",
Column('PackageBaseID', ForeignKey('PackageBases.ID', ondelete='CASCADE'), nullable=False), metadata,
Column('UserID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False), Column(
Index('NotifyUserIDPkgID', 'UserID', 'PackageBaseID', unique=True), "PackageBaseID",
mysql_engine='InnoDB', ForeignKey("PackageBases.ID", ondelete="CASCADE"),
nullable=False,
),
Column("UserID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
Index("NotifyUserIDPkgID", "UserID", "PackageBaseID", unique=True),
mysql_engine="InnoDB",
) )
# Package name blacklist # Package name blacklist
PackageBlacklist = Table( PackageBlacklist = Table(
'PackageBlacklist', metadata, "PackageBlacklist",
Column('ID', INTEGER(unsigned=True), primary_key=True), metadata,
Column('Name', String(64), nullable=False, unique=True), Column("ID", INTEGER(unsigned=True), primary_key=True),
mysql_engine='InnoDB', Column("Name", String(64), nullable=False, unique=True),
mysql_charset='utf8mb4', mysql_engine="InnoDB",
mysql_collate='utf8mb4_general_ci', mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Providers in the official repositories # Providers in the official repositories
OfficialProviders = Table( OfficialProviders = Table(
'OfficialProviders', metadata, "OfficialProviders",
Column('ID', INTEGER(unsigned=True), primary_key=True), metadata,
Column('Name', String(64), nullable=False), Column("ID", INTEGER(unsigned=True), primary_key=True),
Column('Repo', String(64), nullable=False), Column("Name", String(64), nullable=False),
Column('Provides', String(64), nullable=False), Column("Repo", String(64), nullable=False),
Index('ProviderNameProvides', 'Name', 'Provides', unique=True), Column("Provides", String(64), nullable=False),
mysql_engine='InnoDB', mysql_charset='utf8mb4', mysql_collate='utf8mb4_bin', Index("ProviderNameProvides", "Name", "Provides", unique=True),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_bin",
) )
# Define package request types # Define package request types
RequestTypes = Table( RequestTypes = Table(
'RequestTypes', metadata, "RequestTypes",
Column('ID', TINYINT(unsigned=True), primary_key=True), metadata,
Column('Name', String(32), nullable=False, server_default=text("''")), Column("ID", TINYINT(unsigned=True), primary_key=True),
mysql_engine='InnoDB', Column("Name", String(32), nullable=False, server_default=text("''")),
mysql_charset='utf8mb4', mysql_engine="InnoDB",
mysql_collate='utf8mb4_general_ci', mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Package requests # Package requests
PackageRequests = Table( PackageRequests = Table(
'PackageRequests', metadata, "PackageRequests",
Column('ID', BIGINT(unsigned=True), primary_key=True), metadata,
Column('ReqTypeID', ForeignKey('RequestTypes.ID', ondelete="NO ACTION"), nullable=False), Column("ID", BIGINT(unsigned=True), primary_key=True),
Column('PackageBaseID', ForeignKey('PackageBases.ID', ondelete='SET NULL')), Column(
Column('PackageBaseName', String(255), nullable=False), "ReqTypeID", ForeignKey("RequestTypes.ID", ondelete="NO ACTION"), nullable=False
Column('MergeBaseName', String(255)), ),
Column('UsersID', ForeignKey('Users.ID', ondelete='SET NULL')), Column("PackageBaseID", ForeignKey("PackageBases.ID", ondelete="SET NULL")),
Column('Comments', Text, nullable=False), Column("PackageBaseName", String(255), nullable=False),
Column('ClosureComment', Text, nullable=False), Column("MergeBaseName", String(255)),
Column('RequestTS', BIGINT(unsigned=True), nullable=False, server_default=text("0")), Column("UsersID", ForeignKey("Users.ID", ondelete="SET NULL")),
Column('ClosedTS', BIGINT(unsigned=True)), Column("Comments", Text, nullable=False),
Column('ClosedUID', ForeignKey('Users.ID', ondelete='SET NULL')), Column("ClosureComment", Text, nullable=False),
Column('Status', TINYINT(unsigned=True), nullable=False, server_default=text("0")), Column(
Index('RequestsPackageBaseID', 'PackageBaseID'), "RequestTS", BIGINT(unsigned=True), nullable=False, server_default=text("0")
Index('RequestsUsersID', 'UsersID'), ),
mysql_engine='InnoDB', Column("ClosedTS", BIGINT(unsigned=True)),
mysql_charset='utf8mb4', Column("ClosedUID", ForeignKey("Users.ID", ondelete="SET NULL")),
mysql_collate='utf8mb4_general_ci', Column("Status", TINYINT(unsigned=True), nullable=False, server_default=text("0")),
Index("RequestsPackageBaseID", "PackageBaseID"),
Index("RequestsUsersID", "UsersID"),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Vote information # Vote information
TU_VoteInfo = Table( TU_VoteInfo = Table(
'TU_VoteInfo', metadata, "TU_VoteInfo",
Column('ID', INTEGER(unsigned=True), primary_key=True), metadata,
Column('Agenda', Text, nullable=False), Column("ID", INTEGER(unsigned=True), primary_key=True),
Column('User', String(32), nullable=False), Column("Agenda", Text, nullable=False),
Column('Submitted', BIGINT(unsigned=True), nullable=False), Column("User", String(32), nullable=False),
Column('End', BIGINT(unsigned=True), nullable=False), Column("Submitted", BIGINT(unsigned=True), nullable=False),
Column('Quorum', Column("End", BIGINT(unsigned=True), nullable=False),
DECIMAL(2, 2, unsigned=True) Column(
if db_backend == "mysql" else String(5), "Quorum",
nullable=False), DECIMAL(2, 2, unsigned=True) if db_backend == "mysql" else String(5),
Column('SubmitterID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False), nullable=False,
Column('Yes', INTEGER(unsigned=True), nullable=False, server_default=text("'0'")), ),
Column('No', INTEGER(unsigned=True), nullable=False, server_default=text("'0'")), Column("SubmitterID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
Column('Abstain', INTEGER(unsigned=True), nullable=False, server_default=text("'0'")), Column("Yes", INTEGER(unsigned=True), nullable=False, server_default=text("'0'")),
Column('ActiveTUs', INTEGER(unsigned=True), nullable=False, server_default=text("'0'")), Column("No", INTEGER(unsigned=True), nullable=False, server_default=text("'0'")),
mysql_engine='InnoDB', Column(
mysql_charset='utf8mb4', "Abstain", INTEGER(unsigned=True), nullable=False, server_default=text("'0'")
mysql_collate='utf8mb4_general_ci', ),
Column(
"ActiveTUs", INTEGER(unsigned=True), nullable=False, server_default=text("'0'")
),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Individual vote records # Individual vote records
TU_Votes = Table( TU_Votes = Table(
'TU_Votes', metadata, "TU_Votes",
Column('VoteID', ForeignKey('TU_VoteInfo.ID', ondelete='CASCADE'), nullable=False), metadata,
Column('UserID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False), Column("VoteID", ForeignKey("TU_VoteInfo.ID", ondelete="CASCADE"), nullable=False),
mysql_engine='InnoDB', Column("UserID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
mysql_engine="InnoDB",
) )
# Malicious user banning # Malicious user banning
Bans = Table( Bans = Table(
'Bans', metadata, "Bans",
Column('IPAddress', String(45), primary_key=True), metadata,
Column('BanTS', TIMESTAMP, nullable=False), Column("IPAddress", String(45), primary_key=True),
mysql_engine='InnoDB', Column("BanTS", TIMESTAMP, nullable=False),
mysql_charset='utf8mb4', mysql_engine="InnoDB",
mysql_collate='utf8mb4_general_ci', mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Terms and Conditions # Terms and Conditions
Terms = Table( Terms = Table(
'Terms', metadata, "Terms",
Column('ID', INTEGER(unsigned=True), primary_key=True), metadata,
Column('Description', String(255), nullable=False), Column("ID", INTEGER(unsigned=True), primary_key=True),
Column('URL', String(8000), nullable=False), Column("Description", String(255), nullable=False),
Column('Revision', INTEGER(unsigned=True), nullable=False, server_default=text("1")), Column("URL", String(8000), nullable=False),
mysql_engine='InnoDB', Column(
mysql_charset='utf8mb4', "Revision", INTEGER(unsigned=True), nullable=False, server_default=text("1")
mysql_collate='utf8mb4_general_ci', ),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )
# Terms and Conditions accepted by users # Terms and Conditions accepted by users
AcceptedTerms = Table( AcceptedTerms = Table(
'AcceptedTerms', metadata, "AcceptedTerms",
Column('UsersID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False), metadata,
Column('TermsID', ForeignKey('Terms.ID', ondelete='CASCADE'), nullable=False), Column("UsersID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
Column('Revision', INTEGER(unsigned=True), nullable=False, server_default=text("0")), Column("TermsID", ForeignKey("Terms.ID", ondelete="CASCADE"), nullable=False),
mysql_engine='InnoDB', Column(
"Revision", INTEGER(unsigned=True), nullable=False, server_default=text("0")
),
mysql_engine="InnoDB",
) )
# Rate limits for API # Rate limits for API
ApiRateLimit = Table( ApiRateLimit = Table(
'ApiRateLimit', metadata, "ApiRateLimit",
Column('IP', String(45), primary_key=True, unique=True, default=str()), metadata,
Column('Requests', INTEGER(11), nullable=False), Column("IP", String(45), primary_key=True, unique=True, default=str()),
Column('WindowStart', BIGINT(20), nullable=False), Column("Requests", INTEGER(11), nullable=False),
Index('ApiRateLimitWindowStart', 'WindowStart'), Column("WindowStart", BIGINT(20), nullable=False),
mysql_engine='InnoDB', Index("ApiRateLimitWindowStart", "WindowStart"),
mysql_charset='utf8mb4', mysql_engine="InnoDB",
mysql_collate='utf8mb4_general_ci', mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
) )

View file

@ -11,7 +11,6 @@ import sys
import traceback import traceback
import aurweb.models.account_type as at import aurweb.models.account_type as at
from aurweb import db from aurweb import db
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint
@ -30,8 +29,9 @@ def parse_args():
parser.add_argument("--ssh-pubkey", help="SSH PubKey") parser.add_argument("--ssh-pubkey", help="SSH PubKey")
choices = at.ACCOUNT_TYPE_NAME.values() choices = at.ACCOUNT_TYPE_NAME.values()
parser.add_argument("-t", "--type", help="Account Type", parser.add_argument(
choices=choices, default=at.USER) "-t", "--type", help="Account Type", choices=choices, default=at.USER
)
return parser.parse_args() return parser.parse_args()
@ -40,25 +40,29 @@ def main():
args = parse_args() args = parse_args()
db.get_engine() db.get_engine()
type = db.query(AccountType, type = db.query(AccountType, AccountType.AccountType == args.type).first()
AccountType.AccountType == args.type).first()
with db.begin(): with db.begin():
user = db.create(User, Username=args.username, user = db.create(
Email=args.email, Passwd=args.password, User,
RealName=args.realname, IRCNick=args.ircnick, Username=args.username,
PGPKey=args.pgp_key, AccountType=type) Email=args.email,
Passwd=args.password,
RealName=args.realname,
IRCNick=args.ircnick,
PGPKey=args.pgp_key,
AccountType=type,
)
if args.ssh_pubkey: if args.ssh_pubkey:
pubkey = args.ssh_pubkey.strip() pubkey = args.ssh_pubkey.strip()
# Remove host from the pubkey if it's there. # Remove host from the pubkey if it's there.
pubkey = ' '.join(pubkey.split(' ')[:2]) pubkey = " ".join(pubkey.split(" ")[:2])
with db.begin(): with db.begin():
db.create(SSHPubKey, db.create(
User=user, SSHPubKey, User=user, PubKey=pubkey, Fingerprint=get_fingerprint(pubkey)
PubKey=pubkey, )
Fingerprint=get_fingerprint(pubkey))
print(user.json()) print(user.json())
return 0 return 0

View file

@ -3,11 +3,9 @@
import re import re
import pyalpm import pyalpm
from sqlalchemy import and_ from sqlalchemy import and_
import aurweb.config import aurweb.config
from aurweb import db, util from aurweb import db, util
from aurweb.models import OfficialProvider from aurweb.models import OfficialProvider
@ -18,8 +16,8 @@ def _main(force: bool = False):
repomap = dict() repomap = dict()
db_path = aurweb.config.get("aurblup", "db-path") db_path = aurweb.config.get("aurblup", "db-path")
sync_dbs = aurweb.config.get('aurblup', 'sync-dbs').split(' ') sync_dbs = aurweb.config.get("aurblup", "sync-dbs").split(" ")
server = aurweb.config.get('aurblup', 'server') server = aurweb.config.get("aurblup", "server")
h = pyalpm.Handle("/", db_path) h = pyalpm.Handle("/", db_path)
for sync_db in sync_dbs: for sync_db in sync_dbs:
@ -35,28 +33,35 @@ def _main(force: bool = False):
providers.add((pkg.name, pkg.name)) providers.add((pkg.name, pkg.name))
repomap[(pkg.name, pkg.name)] = repo.name repomap[(pkg.name, pkg.name)] = repo.name
for provision in pkg.provides: for provision in pkg.provides:
provisionname = re.sub(r'(<|=|>).*', '', provision) provisionname = re.sub(r"(<|=|>).*", "", provision)
providers.add((pkg.name, provisionname)) providers.add((pkg.name, provisionname))
repomap[(pkg.name, provisionname)] = repo.name repomap[(pkg.name, provisionname)] = repo.name
with db.begin(): with db.begin():
old_providers = set( old_providers = set(
db.query(OfficialProvider).with_entities( db.query(OfficialProvider)
.with_entities(
OfficialProvider.Name.label("Name"), OfficialProvider.Name.label("Name"),
OfficialProvider.Provides.label("Provides") OfficialProvider.Provides.label("Provides"),
).distinct().order_by("Name").all() )
.distinct()
.order_by("Name")
.all()
) )
for name, provides in old_providers.difference(providers): for name, provides in old_providers.difference(providers):
db.delete_all(db.query(OfficialProvider).filter( db.delete_all(
and_(OfficialProvider.Name == name, db.query(OfficialProvider).filter(
OfficialProvider.Provides == provides) and_(
)) OfficialProvider.Name == name,
OfficialProvider.Provides == provides,
)
)
)
for name, provides in providers.difference(old_providers): for name, provides in providers.difference(old_providers):
repo = repomap.get((name, provides)) repo = repomap.get((name, provides))
db.create(OfficialProvider, Name=name, db.create(OfficialProvider, Name=name, Repo=repo, Provides=provides)
Repo=repo, Provides=provides)
def main(force: bool = False): def main(force: bool = False):
@ -64,5 +69,5 @@ def main(force: bool = False):
_main(force) _main(force)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -50,12 +50,12 @@ def parse_args():
actions = ["get", "set", "unset"] actions = ["get", "set", "unset"]
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="aurweb configuration tool", description="aurweb configuration tool",
formatter_class=lambda prog: fmt_cls(prog=prog, max_help_position=80)) formatter_class=lambda prog: fmt_cls(prog=prog, max_help_position=80),
)
parser.add_argument("action", choices=actions, help="script action") parser.add_argument("action", choices=actions, help="script action")
parser.add_argument("section", help="config section") parser.add_argument("section", help="config section")
parser.add_argument("option", help="config option") parser.add_argument("option", help="config option")
parser.add_argument("value", nargs="?", default=0, parser.add_argument("value", nargs="?", default=0, help="config option value")
help="config option value")
return parser.parse_args() return parser.parse_args()

View file

@ -25,16 +25,13 @@ import os
import shutil import shutil
import sys import sys
import tempfile import tempfile
from collections import defaultdict from collections import defaultdict
from typing import Any from typing import Any
import orjson import orjson
from sqlalchemy import literal, orm from sqlalchemy import literal, orm
import aurweb.config import aurweb.config
from aurweb import db, filters, logging, models, util from aurweb import db, filters, logging, models, util
from aurweb.benchmark import Benchmark from aurweb.benchmark import Benchmark
from aurweb.models import Package, PackageBase, User from aurweb.models import Package, PackageBase, User
@ -90,65 +87,68 @@ def get_extended_dict(query: orm.Query):
def get_extended_fields(): def get_extended_fields():
subqueries = [ subqueries = [
# PackageDependency # PackageDependency
db.query( db.query(models.PackageDependency)
models.PackageDependency .join(models.DependencyType)
).join(models.DependencyType).with_entities( .with_entities(
models.PackageDependency.PackageID.label("ID"), models.PackageDependency.PackageID.label("ID"),
models.DependencyType.Name.label("Type"), models.DependencyType.Name.label("Type"),
models.PackageDependency.DepName.label("Name"), models.PackageDependency.DepName.label("Name"),
models.PackageDependency.DepCondition.label("Cond") models.PackageDependency.DepCondition.label("Cond"),
).distinct().order_by("Name"), )
.distinct()
.order_by("Name"),
# PackageRelation # PackageRelation
db.query( db.query(models.PackageRelation)
models.PackageRelation .join(models.RelationType)
).join(models.RelationType).with_entities( .with_entities(
models.PackageRelation.PackageID.label("ID"), models.PackageRelation.PackageID.label("ID"),
models.RelationType.Name.label("Type"), models.RelationType.Name.label("Type"),
models.PackageRelation.RelName.label("Name"), models.PackageRelation.RelName.label("Name"),
models.PackageRelation.RelCondition.label("Cond") models.PackageRelation.RelCondition.label("Cond"),
).distinct().order_by("Name"), )
.distinct()
.order_by("Name"),
# Groups # Groups
db.query(models.PackageGroup).join( db.query(models.PackageGroup)
models.Group, .join(models.Group, models.PackageGroup.GroupID == models.Group.ID)
models.PackageGroup.GroupID == models.Group.ID .with_entities(
).with_entities(
models.PackageGroup.PackageID.label("ID"), models.PackageGroup.PackageID.label("ID"),
literal("Groups").label("Type"), literal("Groups").label("Type"),
models.Group.Name.label("Name"), models.Group.Name.label("Name"),
literal(str()).label("Cond") literal(str()).label("Cond"),
).distinct().order_by("Name"), )
.distinct()
.order_by("Name"),
# Licenses # Licenses
db.query(models.PackageLicense).join( db.query(models.PackageLicense)
models.License, .join(models.License, models.PackageLicense.LicenseID == models.License.ID)
models.PackageLicense.LicenseID == models.License.ID .with_entities(
).with_entities(
models.PackageLicense.PackageID.label("ID"), models.PackageLicense.PackageID.label("ID"),
literal("License").label("Type"), literal("License").label("Type"),
models.License.Name.label("Name"), models.License.Name.label("Name"),
literal(str()).label("Cond") literal(str()).label("Cond"),
).distinct().order_by("Name"), )
.distinct()
.order_by("Name"),
# Keywords # Keywords
db.query(models.PackageKeyword).join( db.query(models.PackageKeyword)
models.Package, .join(
Package.PackageBaseID == models.PackageKeyword.PackageBaseID models.Package, Package.PackageBaseID == models.PackageKeyword.PackageBaseID
).with_entities( )
.with_entities(
models.Package.ID.label("ID"), models.Package.ID.label("ID"),
literal("Keywords").label("Type"), literal("Keywords").label("Type"),
models.PackageKeyword.Keyword.label("Name"), models.PackageKeyword.Keyword.label("Name"),
literal(str()).label("Cond") literal(str()).label("Cond"),
).distinct().order_by("Name") )
.distinct()
.order_by("Name"),
] ]
query = subqueries[0].union_all(*subqueries[1:]) query = subqueries[0].union_all(*subqueries[1:])
return get_extended_dict(query) return get_extended_dict(query)
EXTENDED_FIELD_HANDLERS = { EXTENDED_FIELD_HANDLERS = {"--extended": get_extended_fields}
"--extended": get_extended_fields
}
def as_dict(package: Package) -> dict[str, Any]: def as_dict(package: Package) -> dict[str, Any]:
@ -181,37 +181,38 @@ def _main():
archivedir = aurweb.config.get("mkpkglists", "archivedir") archivedir = aurweb.config.get("mkpkglists", "archivedir")
os.makedirs(archivedir, exist_ok=True) os.makedirs(archivedir, exist_ok=True)
PACKAGES = aurweb.config.get('mkpkglists', 'packagesfile') PACKAGES = aurweb.config.get("mkpkglists", "packagesfile")
META = aurweb.config.get('mkpkglists', 'packagesmetafile') META = aurweb.config.get("mkpkglists", "packagesmetafile")
META_EXT = aurweb.config.get('mkpkglists', 'packagesmetaextfile') META_EXT = aurweb.config.get("mkpkglists", "packagesmetaextfile")
PKGBASE = aurweb.config.get('mkpkglists', 'pkgbasefile') PKGBASE = aurweb.config.get("mkpkglists", "pkgbasefile")
USERS = aurweb.config.get('mkpkglists', 'userfile') USERS = aurweb.config.get("mkpkglists", "userfile")
bench = Benchmark() bench = Benchmark()
logger.info("Started re-creating archives, wait a while...") logger.info("Started re-creating archives, wait a while...")
query = db.query(Package).join( query = (
PackageBase, db.query(Package)
PackageBase.ID == Package.PackageBaseID .join(PackageBase, PackageBase.ID == Package.PackageBaseID)
).join( .join(User, PackageBase.MaintainerUID == User.ID, isouter=True)
User, .filter(PackageBase.PackagerUID.isnot(None))
PackageBase.MaintainerUID == User.ID, .with_entities(
isouter=True Package.ID,
).filter(PackageBase.PackagerUID.isnot(None)).with_entities( Package.Name,
Package.ID, PackageBase.ID.label("PackageBaseID"),
Package.Name, PackageBase.Name.label("PackageBase"),
PackageBase.ID.label("PackageBaseID"), Package.Version,
PackageBase.Name.label("PackageBase"), Package.Description,
Package.Version, Package.URL,
Package.Description, PackageBase.NumVotes,
Package.URL, PackageBase.Popularity,
PackageBase.NumVotes, PackageBase.OutOfDateTS.label("OutOfDate"),
PackageBase.Popularity, User.Username.label("Maintainer"),
PackageBase.OutOfDateTS.label("OutOfDate"), PackageBase.SubmittedTS.label("FirstSubmitted"),
User.Username.label("Maintainer"), PackageBase.ModifiedTS.label("LastModified"),
PackageBase.SubmittedTS.label("FirstSubmitted"), )
PackageBase.ModifiedTS.label("LastModified") .distinct()
).distinct().order_by("Name") .order_by("Name")
)
# Produce packages-meta-v1.json.gz # Produce packages-meta-v1.json.gz
output = list() output = list()
@ -252,7 +253,7 @@ def _main():
# We stream out package json objects line per line, so # We stream out package json objects line per line, so
# we also need to include the ',' character at the end # we also need to include the ',' character at the end
# of package lines (excluding the last package). # of package lines (excluding the last package).
suffix = b",\n" if i < n else b'\n' suffix = b",\n" if i < n else b"\n"
# Write out to packagesmetafile # Write out to packagesmetafile
output.append(item) output.append(item)
@ -273,8 +274,7 @@ def _main():
util.apply_all(gzips.values(), lambda gz: gz.close()) util.apply_all(gzips.values(), lambda gz: gz.close())
# Produce pkgbase.gz # Produce pkgbase.gz
query = db.query(PackageBase.Name).filter( query = db.query(PackageBase.Name).filter(PackageBase.PackagerUID.isnot(None)).all()
PackageBase.PackagerUID.isnot(None)).all()
tmp_pkgbase = os.path.join(tmpdir, os.path.basename(PKGBASE)) tmp_pkgbase = os.path.join(tmpdir, os.path.basename(PKGBASE))
with gzip.open(tmp_pkgbase, "wt") as f: with gzip.open(tmp_pkgbase, "wt") as f:
f.writelines([f"{base.Name}\n" for i, base in enumerate(query)]) f.writelines([f"{base.Name}\n" for i, base in enumerate(query)])
@ -317,5 +317,5 @@ def main():
_main() _main()
if __name__ == '__main__': if __name__ == "__main__":
main() main()

File diff suppressed because it is too large Load diff

View file

@ -11,8 +11,8 @@ def _main():
limit_to = time.utcnow() - 86400 limit_to = time.utcnow() - 86400
query = db.query(PackageBase).filter( query = db.query(PackageBase).filter(
and_(PackageBase.SubmittedTS < limit_to, and_(PackageBase.SubmittedTS < limit_to, PackageBase.PackagerUID.is_(None))
PackageBase.PackagerUID.is_(None))) )
db.delete_all(query) db.delete_all(query)
@ -22,5 +22,5 @@ def main():
_main() _main()
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -1,8 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from sqlalchemy import and_, func from sqlalchemy import and_, func
from sqlalchemy.sql.functions import coalesce from sqlalchemy.sql.functions import coalesce, sum as _sum
from sqlalchemy.sql.functions import sum as _sum
from aurweb import db, time from aurweb import db, time
from aurweb.models import PackageBase, PackageVote from aurweb.models import PackageBase, PackageVote
@ -20,18 +19,26 @@ def run_variable(pkgbases: list[PackageBase] = []) -> None:
now = time.utcnow() now = time.utcnow()
# NumVotes subquery. # NumVotes subquery.
votes_subq = db.get_session().query( votes_subq = (
func.count("*") db.get_session()
).select_from(PackageVote).filter( .query(func.count("*"))
PackageVote.PackageBaseID == PackageBase.ID .select_from(PackageVote)
.filter(PackageVote.PackageBaseID == PackageBase.ID)
) )
# Popularity subquery. # Popularity subquery.
pop_subq = db.get_session().query( pop_subq = (
coalesce(_sum(func.pow(0.98, (now - PackageVote.VoteTS) / 86400)), 0.0), db.get_session()
).select_from(PackageVote).filter( .query(
and_(PackageVote.PackageBaseID == PackageBase.ID, coalesce(_sum(func.pow(0.98, (now - PackageVote.VoteTS) / 86400)), 0.0),
PackageVote.VoteTS.isnot(None)) )
.select_from(PackageVote)
.filter(
and_(
PackageVote.PackageBaseID == PackageBase.ID,
PackageVote.VoteTS.isnot(None),
)
)
) )
with db.begin(): with db.begin():
@ -42,14 +49,16 @@ def run_variable(pkgbases: list[PackageBase] = []) -> None:
ids = {pkgbase.ID for pkgbase in pkgbases} ids = {pkgbase.ID for pkgbase in pkgbases}
query = query.filter(PackageBase.ID.in_(ids)) query = query.filter(PackageBase.ID.in_(ids))
query.update({ query.update(
"NumVotes": votes_subq.scalar_subquery(), {
"Popularity": pop_subq.scalar_subquery() "NumVotes": votes_subq.scalar_subquery(),
}) "Popularity": pop_subq.scalar_subquery(),
}
)
def run_single(pkgbase: PackageBase) -> None: def run_single(pkgbase: PackageBase) -> None:
""" A single popupdate. The given pkgbase instance will be """A single popupdate. The given pkgbase instance will be
refreshed after the database update is done. refreshed after the database update is done.
NOTE: This function is compatible only with aurweb FastAPI. NOTE: This function is compatible only with aurweb FastAPI.
@ -65,5 +74,5 @@ def main():
run_variable() run_variable()
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -1,7 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys import sys
from urllib.parse import quote_plus from urllib.parse import quote_plus
from xml.etree.ElementTree import Element from xml.etree.ElementTree import Element
@ -10,7 +9,6 @@ import markdown
import pygit2 import pygit2
import aurweb.config import aurweb.config
from aurweb import db, logging, util from aurweb import db, logging, util
from aurweb.models import PackageComment from aurweb.models import PackageComment
@ -25,13 +23,15 @@ class LinkifyExtension(markdown.extensions.Extension):
# Captures http(s) and ftp URLs until the first non URL-ish character. # Captures http(s) and ftp URLs until the first non URL-ish character.
# Excludes trailing punctuation. # Excludes trailing punctuation.
_urlre = (r'(\b(?:https?|ftp):\/\/[\w\/\#~:.?+=&%@!\-;,]+?' _urlre = (
r'(?=[.:?\-;,]*(?:[^\w\/\#~:.?+=&%@!\-;,]|$)))') r"(\b(?:https?|ftp):\/\/[\w\/\#~:.?+=&%@!\-;,]+?"
r"(?=[.:?\-;,]*(?:[^\w\/\#~:.?+=&%@!\-;,]|$)))"
)
def extendMarkdown(self, md): def extendMarkdown(self, md):
processor = markdown.inlinepatterns.AutolinkInlineProcessor(self._urlre, md) processor = markdown.inlinepatterns.AutolinkInlineProcessor(self._urlre, md)
# Register it right after the default <>-link processor (priority 120). # Register it right after the default <>-link processor (priority 120).
md.inlinePatterns.register(processor, 'linkify', 119) md.inlinePatterns.register(processor, "linkify", 119)
class FlysprayLinksInlineProcessor(markdown.inlinepatterns.InlineProcessor): class FlysprayLinksInlineProcessor(markdown.inlinepatterns.InlineProcessor):
@ -43,16 +43,16 @@ class FlysprayLinksInlineProcessor(markdown.inlinepatterns.InlineProcessor):
""" """
def handleMatch(self, m, data): def handleMatch(self, m, data):
el = Element('a') el = Element("a")
el.set('href', f'https://bugs.archlinux.org/task/{m.group(1)}') el.set("href", f"https://bugs.archlinux.org/task/{m.group(1)}")
el.text = markdown.util.AtomicString(m.group(0)) el.text = markdown.util.AtomicString(m.group(0))
return (el, m.start(0), m.end(0)) return (el, m.start(0), m.end(0))
class FlysprayLinksExtension(markdown.extensions.Extension): class FlysprayLinksExtension(markdown.extensions.Extension):
def extendMarkdown(self, md): def extendMarkdown(self, md):
processor = FlysprayLinksInlineProcessor(r'\bFS#(\d+)\b', md) processor = FlysprayLinksInlineProcessor(r"\bFS#(\d+)\b", md)
md.inlinePatterns.register(processor, 'flyspray-links', 118) md.inlinePatterns.register(processor, "flyspray-links", 118)
class GitCommitsInlineProcessor(markdown.inlinepatterns.InlineProcessor): class GitCommitsInlineProcessor(markdown.inlinepatterns.InlineProcessor):
@ -65,10 +65,10 @@ class GitCommitsInlineProcessor(markdown.inlinepatterns.InlineProcessor):
""" """
def __init__(self, md, head): def __init__(self, md, head):
repo_path = aurweb.config.get('serve', 'repo-path') repo_path = aurweb.config.get("serve", "repo-path")
self._repo = pygit2.Repository(repo_path) self._repo = pygit2.Repository(repo_path)
self._head = head self._head = head
super().__init__(r'\b([0-9a-f]{7,40})\b', md) super().__init__(r"\b([0-9a-f]{7,40})\b", md)
def handleMatch(self, m, data): def handleMatch(self, m, data):
oid = m.group(1) oid = m.group(1)
@ -76,13 +76,12 @@ class GitCommitsInlineProcessor(markdown.inlinepatterns.InlineProcessor):
# Unknown OID; preserve the orginal text. # Unknown OID; preserve the orginal text.
return (None, None, None) return (None, None, None)
el = Element('a') el = Element("a")
commit_uri = aurweb.config.get("options", "commit_uri") commit_uri = aurweb.config.get("options", "commit_uri")
prefixlen = util.git_search(self._repo, oid) prefixlen = util.git_search(self._repo, oid)
el.set('href', commit_uri % ( el.set(
quote_plus(self._head), "href", commit_uri % (quote_plus(self._head), quote_plus(oid[:prefixlen]))
quote_plus(oid[:prefixlen]) )
))
el.text = markdown.util.AtomicString(oid[:prefixlen]) el.text = markdown.util.AtomicString(oid[:prefixlen])
return (el, m.start(0), m.end(0)) return (el, m.start(0), m.end(0))
@ -97,7 +96,7 @@ class GitCommitsExtension(markdown.extensions.Extension):
def extendMarkdown(self, md): def extendMarkdown(self, md):
try: try:
processor = GitCommitsInlineProcessor(md, self._head) processor = GitCommitsInlineProcessor(md, self._head)
md.inlinePatterns.register(processor, 'git-commits', 117) md.inlinePatterns.register(processor, "git-commits", 117)
except pygit2.GitError: except pygit2.GitError:
logger.error(f"No git repository found for '{self._head}'.") logger.error(f"No git repository found for '{self._head}'.")
@ -105,16 +104,16 @@ class GitCommitsExtension(markdown.extensions.Extension):
class HeadingTreeprocessor(markdown.treeprocessors.Treeprocessor): class HeadingTreeprocessor(markdown.treeprocessors.Treeprocessor):
def run(self, doc): def run(self, doc):
for elem in doc: for elem in doc:
if elem.tag == 'h1': if elem.tag == "h1":
elem.tag = 'h5' elem.tag = "h5"
elif elem.tag in ['h2', 'h3', 'h4', 'h5']: elif elem.tag in ["h2", "h3", "h4", "h5"]:
elem.tag = 'h6' elem.tag = "h6"
class HeadingExtension(markdown.extensions.Extension): class HeadingExtension(markdown.extensions.Extension):
def extendMarkdown(self, md): def extendMarkdown(self, md):
# Priority doesn't matter since we don't conflict with other processors. # Priority doesn't matter since we don't conflict with other processors.
md.treeprocessors.register(HeadingTreeprocessor(md), 'heading', 30) md.treeprocessors.register(HeadingTreeprocessor(md), "heading", 30)
def save_rendered_comment(comment: PackageComment, html: str): def save_rendered_comment(comment: PackageComment, html: str):
@ -130,16 +129,26 @@ def update_comment_render(comment: PackageComment) -> None:
text = comment.Comments text = comment.Comments
pkgbasename = comment.PackageBase.Name pkgbasename = comment.PackageBase.Name
html = markdown.markdown(text, extensions=[ html = markdown.markdown(
'fenced_code', text,
LinkifyExtension(), extensions=[
FlysprayLinksExtension(), "fenced_code",
GitCommitsExtension(pkgbasename), LinkifyExtension(),
HeadingExtension() FlysprayLinksExtension(),
]) GitCommitsExtension(pkgbasename),
HeadingExtension(),
],
)
allowed_tags = (bleach.sanitizer.ALLOWED_TAGS allowed_tags = bleach.sanitizer.ALLOWED_TAGS + [
+ ['p', 'pre', 'h4', 'h5', 'h6', 'br', 'hr']) "p",
"pre",
"h4",
"h5",
"h6",
"br",
"hr",
]
html = bleach.clean(html, tags=allowed_tags) html = bleach.clean(html, tags=allowed_tags)
save_rendered_comment(comment, html) save_rendered_comment(comment, html)
db.refresh(comment) db.refresh(comment)
@ -148,11 +157,9 @@ def update_comment_render(comment: PackageComment) -> None:
def main(): def main():
db.get_engine() db.get_engine()
comment_id = int(sys.argv[1]) comment_id = int(sys.argv[1])
comment = db.query(PackageComment).filter( comment = db.query(PackageComment).filter(PackageComment.ID == comment_id).first()
PackageComment.ID == comment_id
).first()
update_comment_render(comment) update_comment_render(comment)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -3,12 +3,11 @@
from sqlalchemy import and_ from sqlalchemy import and_
import aurweb.config import aurweb.config
from aurweb import db, time from aurweb import db, time
from aurweb.models import TUVoteInfo from aurweb.models import TUVoteInfo
from aurweb.scripts import notify from aurweb.scripts import notify
notify_cmd = aurweb.config.get('notifications', 'notify-cmd') notify_cmd = aurweb.config.get("notifications", "notify-cmd")
def main(): def main():
@ -23,13 +22,12 @@ def main():
filter_to = now + end filter_to = now + end
query = db.query(TUVoteInfo.ID).filter( query = db.query(TUVoteInfo.ID).filter(
and_(TUVoteInfo.End >= filter_from, and_(TUVoteInfo.End >= filter_from, TUVoteInfo.End <= filter_to)
TUVoteInfo.End <= filter_to)
) )
for voteinfo in query: for voteinfo in query:
notif = notify.TUVoteReminderNotification(voteinfo.ID) notif = notify.TUVoteReminderNotification(voteinfo.ID)
notif.send() notif.send()
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -9,14 +9,16 @@ from aurweb.models import User
def _main(): def _main():
limit_to = time.utcnow() - 86400 * 7 limit_to = time.utcnow() - 86400 * 7
update_ = update(User).where( update_ = (
User.LastLogin < limit_to update(User).where(User.LastLogin < limit_to).values(LastLoginIPAddress=None)
).values(LastLoginIPAddress=None) )
db.get_session().execute(update_) db.get_session().execute(update_)
update_ = update(User).where( update_ = (
User.LastSSHLogin < limit_to update(User)
).values(LastSSHLoginIPAddress=None) .where(User.LastSSHLogin < limit_to)
.values(LastSSHLoginIPAddress=None)
)
db.get_session().execute(update_) db.get_session().execute(update_)
@ -26,5 +28,5 @@ def main():
_main() _main()
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -16,18 +16,16 @@ import subprocess
import sys import sys
import tempfile import tempfile
import time import time
from typing import Iterable from typing import Iterable
import aurweb.config import aurweb.config
import aurweb.schema import aurweb.schema
from aurweb.exceptions import AurwebException from aurweb.exceptions import AurwebException
children = [] children = []
temporary_dir = None temporary_dir = None
verbosity = 0 verbosity = 0
asgi_backend = '' asgi_backend = ""
workers = 1 workers = 1
PHP_BINARY = os.environ.get("PHP_BINARY", "php") PHP_BINARY = os.environ.get("PHP_BINARY", "php")
@ -60,22 +58,21 @@ def validate_php_config() -> None:
:return: None :return: None
""" """
try: try:
proc = subprocess.Popen([PHP_BINARY, "-m"], proc = subprocess.Popen(
stdout=subprocess.PIPE, [PHP_BINARY, "-m"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
stderr=subprocess.PIPE) )
out, _ = proc.communicate() out, _ = proc.communicate()
except FileNotFoundError: except FileNotFoundError:
raise AurwebException(f"Unable to locate the '{PHP_BINARY}' " raise AurwebException(f"Unable to locate the '{PHP_BINARY}' " "executable.")
"executable.")
assert proc.returncode == 0, ("Received non-zero error code " assert proc.returncode == 0, (
f"{proc.returncode} from '{PHP_BINARY}'.") "Received non-zero error code " f"{proc.returncode} from '{PHP_BINARY}'."
)
modules = out.decode().splitlines() modules = out.decode().splitlines()
for module in PHP_MODULES: for module in PHP_MODULES:
if module not in modules: if module not in modules:
raise AurwebException( raise AurwebException(f"PHP does not have the '{module}' module enabled.")
f"PHP does not have the '{module}' module enabled.")
def generate_nginx_config(): def generate_nginx_config():
@ -91,7 +88,8 @@ def generate_nginx_config():
config_path = os.path.join(temporary_dir, "nginx.conf") config_path = os.path.join(temporary_dir, "nginx.conf")
config = open(config_path, "w") config = open(config_path, "w")
# We double nginx's braces because they conflict with Python's f-strings. # We double nginx's braces because they conflict with Python's f-strings.
config.write(f""" config.write(
f"""
events {{}} events {{}}
daemon off; daemon off;
error_log /dev/stderr info; error_log /dev/stderr info;
@ -124,7 +122,8 @@ def generate_nginx_config():
}} }}
}} }}
}} }}
""") """
)
return config_path return config_path
@ -146,20 +145,23 @@ def start():
return return
atexit.register(stop) atexit.register(stop)
if 'AUR_CONFIG' in os.environ: if "AUR_CONFIG" in os.environ:
os.environ['AUR_CONFIG'] = os.path.realpath(os.environ['AUR_CONFIG']) os.environ["AUR_CONFIG"] = os.path.realpath(os.environ["AUR_CONFIG"])
try: try:
terminal_width = os.get_terminal_size().columns terminal_width = os.get_terminal_size().columns
except OSError: except OSError:
terminal_width = 80 terminal_width = 80
print("{ruler}\n" print(
"Spawing PHP and FastAPI, then nginx as a reverse proxy.\n" "{ruler}\n"
"Check out {aur_location}\n" "Spawing PHP and FastAPI, then nginx as a reverse proxy.\n"
"Hit ^C to terminate everything.\n" "Check out {aur_location}\n"
"{ruler}" "Hit ^C to terminate everything.\n"
.format(ruler=("-" * terminal_width), "{ruler}".format(
aur_location=aurweb.config.get('options', 'aur_location'))) ruler=("-" * terminal_width),
aur_location=aurweb.config.get("options", "aur_location"),
)
)
# PHP # PHP
php_address = aurweb.config.get("php", "bind_address") php_address = aurweb.config.get("php", "bind_address")
@ -168,8 +170,9 @@ def start():
spawn_child(["php", "-S", php_address, "-t", htmldir]) spawn_child(["php", "-S", php_address, "-t", htmldir])
# FastAPI # FastAPI
fastapi_host, fastapi_port = aurweb.config.get( fastapi_host, fastapi_port = aurweb.config.get("fastapi", "bind_address").rsplit(
"fastapi", "bind_address").rsplit(":", 1) ":", 1
)
# Logging config. # Logging config.
aurwebdir = aurweb.config.get("options", "aurwebdir") aurwebdir = aurweb.config.get("options", "aurwebdir")
@ -178,20 +181,33 @@ def start():
backend_args = { backend_args = {
"hypercorn": ["-b", f"{fastapi_host}:{fastapi_port}"], "hypercorn": ["-b", f"{fastapi_host}:{fastapi_port}"],
"uvicorn": ["--host", fastapi_host, "--port", fastapi_port], "uvicorn": ["--host", fastapi_host, "--port", fastapi_port],
"gunicorn": ["--bind", f"{fastapi_host}:{fastapi_port}", "gunicorn": [
"-k", "uvicorn.workers.UvicornWorker", "--bind",
"-w", str(workers)] f"{fastapi_host}:{fastapi_port}",
"-k",
"uvicorn.workers.UvicornWorker",
"-w",
str(workers),
],
} }
backend_args = backend_args.get(asgi_backend) backend_args = backend_args.get(asgi_backend)
spawn_child([ spawn_child(
"python", "-m", asgi_backend, [
"--log-config", fastapi_log_config, "python",
] + backend_args + ["aurweb.asgi:app"]) "-m",
asgi_backend,
"--log-config",
fastapi_log_config,
]
+ backend_args
+ ["aurweb.asgi:app"]
)
# nginx # nginx
spawn_child(["nginx", "-p", temporary_dir, "-c", generate_nginx_config()]) spawn_child(["nginx", "-p", temporary_dir, "-c", generate_nginx_config()])
print(f""" print(
f"""
> Started nginx. > Started nginx.
> >
> PHP backend: http://{php_address} > PHP backend: http://{php_address}
@ -201,11 +217,13 @@ def start():
> FastAPI frontend: http://{fastapi_host}:{FASTAPI_NGINX_PORT} > FastAPI frontend: http://{fastapi_host}:{FASTAPI_NGINX_PORT}
> >
> Frontends are hosted via nginx and should be preferred. > Frontends are hosted via nginx and should be preferred.
""") """
)
def _kill_children(children: Iterable, exceptions: list[Exception] = []) \ def _kill_children(
-> list[Exception]: children: Iterable, exceptions: list[Exception] = []
) -> list[Exception]:
""" """
Kill each process found in `children`. Kill each process found in `children`.
@ -223,8 +241,9 @@ def _kill_children(children: Iterable, exceptions: list[Exception] = []) \
return exceptions return exceptions
def _wait_for_children(children: Iterable, exceptions: list[Exception] = []) \ def _wait_for_children(
-> list[Exception]: children: Iterable, exceptions: list[Exception] = []
) -> list[Exception]:
""" """
Wait for each process to end found in `children`. Wait for each process to end found in `children`.
@ -261,21 +280,31 @@ def stop() -> None:
exceptions = _wait_for_children(children, exceptions) exceptions = _wait_for_children(children, exceptions)
children = [] children = []
if exceptions: if exceptions:
raise ProcessExceptions("Errors terminating the child processes:", raise ProcessExceptions("Errors terminating the child processes:", exceptions)
exceptions)
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog='python -m aurweb.spawn', prog="python -m aurweb.spawn", description="Start aurweb's test server."
description='Start aurweb\'s test server.') )
parser.add_argument('-v', '--verbose', action='count', default=0, parser.add_argument(
help='increase verbosity') "-v", "--verbose", action="count", default=0, help="increase verbosity"
choices = ['hypercorn', 'gunicorn', 'uvicorn'] )
parser.add_argument('-b', '--backend', choices=choices, default='uvicorn', choices = ["hypercorn", "gunicorn", "uvicorn"]
help='asgi backend used to launch the python server') parser.add_argument(
parser.add_argument("-w", "--workers", default=1, type=int, "-b",
help="number of workers to use in gunicorn") "--backend",
choices=choices,
default="uvicorn",
help="asgi backend used to launch the python server",
)
parser.add_argument(
"-w",
"--workers",
default=1,
type=int,
help="number of workers to use in gunicorn",
)
args = parser.parse_args() args = parser.parse_args()
try: try:

View file

@ -1,28 +1,27 @@
import copy import copy
import functools import functools
import os import os
from http import HTTPStatus from http import HTTPStatus
from typing import Callable from typing import Callable
import jinja2 import jinja2
from fastapi import Request from fastapi import Request
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
import aurweb.config import aurweb.config
from aurweb import cookies, l10n, time from aurweb import cookies, l10n, time
# Prepare jinja2 objects. # Prepare jinja2 objects.
_loader = jinja2.FileSystemLoader(os.path.join( _loader = jinja2.FileSystemLoader(
aurweb.config.get("options", "aurwebdir"), "templates")) os.path.join(aurweb.config.get("options", "aurwebdir"), "templates")
_env = jinja2.Environment(loader=_loader, autoescape=True, )
extensions=["jinja2.ext.i18n"]) _env = jinja2.Environment(
loader=_loader, autoescape=True, extensions=["jinja2.ext.i18n"]
)
def register_filter(name: str) -> Callable: def register_filter(name: str) -> Callable:
""" A decorator that can be used to register a filter. """A decorator that can be used to register a filter.
Example Example
@register_filter("some_filter") @register_filter("some_filter")
@ -35,31 +34,36 @@ def register_filter(name: str) -> Callable:
:param name: Filter name :param name: Filter name
:return: Callable used for filter :return: Callable used for filter
""" """
def decorator(func): def decorator(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
return func(*args, **kwargs) return func(*args, **kwargs)
_env.filters[name] = wrapper _env.filters[name] = wrapper
return wrapper return wrapper
return decorator return decorator
def register_function(name: str) -> Callable: def register_function(name: str) -> Callable:
""" A decorator that can be used to register a function. """A decorator that can be used to register a function."""
"""
def decorator(func): def decorator(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
return func(*args, **kwargs) return func(*args, **kwargs)
if name in _env.globals: if name in _env.globals:
raise KeyError(f"Jinja already has a function named '{name}'") raise KeyError(f"Jinja already has a function named '{name}'")
_env.globals[name] = wrapper _env.globals[name] = wrapper
return wrapper return wrapper
return decorator return decorator
def make_context(request: Request, title: str, next: str = None): def make_context(request: Request, title: str, next: str = None):
""" Create a context for a jinja2 TemplateResponse. """ """Create a context for a jinja2 TemplateResponse."""
import aurweb.auth.creds import aurweb.auth.creds
commit_url = aurweb.config.get_with_fallback("devel", "commit_url", None) commit_url = aurweb.config.get_with_fallback("devel", "commit_url", None)
@ -85,17 +89,19 @@ def make_context(request: Request, title: str, next: str = None):
"config": aurweb.config, "config": aurweb.config,
"creds": aurweb.auth.creds, "creds": aurweb.auth.creds,
"next": next if next else request.url.path, "next": next if next else request.url.path,
"version": os.environ.get("COMMIT_HASH", aurweb.config.AURWEB_VERSION) "version": os.environ.get("COMMIT_HASH", aurweb.config.AURWEB_VERSION),
} }
async def make_variable_context(request: Request, title: str, next: str = None): async def make_variable_context(request: Request, title: str, next: str = None):
""" Make a context with variables provided by the user """Make a context with variables provided by the user
(query params via GET or form data via POST). """ (query params via GET or form data via POST)."""
context = make_context(request, title, next) context = make_context(request, title, next)
to_copy = dict(request.query_params) \ to_copy = (
if request.method.lower() == "get" \ dict(request.query_params)
if request.method.lower() == "get"
else dict(await request.form()) else dict(await request.form())
)
for k, v in to_copy.items(): for k, v in to_copy.items():
context[k] = v context[k] = v
@ -111,7 +117,7 @@ def base_template(path: str):
def render_raw_template(request: Request, path: str, context: dict): def render_raw_template(request: Request, path: str, context: dict):
""" Render a Jinja2 multi-lingual template with some context. """ """Render a Jinja2 multi-lingual template with some context."""
# Create a deep copy of our jinja2 _environment. The _environment in # Create a deep copy of our jinja2 _environment. The _environment in
# total by itself is 48 bytes large (according to sys.getsizeof). # total by itself is 48 bytes large (according to sys.getsizeof).
# This is done so we can install gettext translations on the template # This is done so we can install gettext translations on the template
@ -126,11 +132,10 @@ def render_raw_template(request: Request, path: str, context: dict):
return template.render(context) return template.render(context)
def render_template(request: Request, def render_template(
path: str, request: Request, path: str, context: dict, status_code: HTTPStatus = HTTPStatus.OK
context: dict, ):
status_code: HTTPStatus = HTTPStatus.OK): """Render a template as an HTMLResponse."""
""" Render a template as an HTMLResponse. """
rendered = render_raw_template(request, path, context) rendered = render_raw_template(request, path, context)
response = HTMLResponse(rendered, status_code=int(status_code)) response = HTMLResponse(rendered, status_code=int(status_code))

View file

@ -1,10 +1,9 @@
import aurweb.db import aurweb.db
from aurweb import models from aurweb import models
def setup_test_db(*args): def setup_test_db(*args):
""" This function is to be used to setup a test database before """This function is to be used to setup a test database before
using it. It takes a variable number of table strings, and for using it. It takes a variable number of table strings, and for
each table in that set of table strings, it deletes all records. each table in that set of table strings, it deletes all records.

View file

@ -17,6 +17,7 @@ class AlpmDatabase:
This class can be used to add or remove packages from a This class can be used to add or remove packages from a
test repository. test repository.
""" """
repo = "test" repo = "test"
def __init__(self, database_root: str): def __init__(self, database_root: str):
@ -35,13 +36,14 @@ class AlpmDatabase:
os.makedirs(pkgdir) os.makedirs(pkgdir)
return pkgdir return pkgdir
def add(self, pkgname: str, pkgver: str, arch: str, def add(
provides: list[str] = []) -> None: self, pkgname: str, pkgver: str, arch: str, provides: list[str] = []
) -> None:
context = { context = {
"pkgname": pkgname, "pkgname": pkgname,
"pkgver": pkgver, "pkgver": pkgver,
"arch": arch, "arch": arch,
"provides": provides "provides": provides,
} }
template = base_template("testing/alpm_package.j2") template = base_template("testing/alpm_package.j2")
pkgdir = self._get_pkgdir(pkgname, pkgver, self.repo) pkgdir = self._get_pkgdir(pkgname, pkgver, self.repo)
@ -76,8 +78,9 @@ class AlpmDatabase:
self.clean() self.clean()
cmdline = ["bash", "-c", "bsdtar -czvf ../test.db *"] cmdline = ["bash", "-c", "bsdtar -czvf ../test.db *"]
proc = subprocess.run(cmdline, cwd=self.repopath) proc = subprocess.run(cmdline, cwd=self.repopath)
assert proc.returncode == 0, \ assert (
f"Bad return code while creating alpm database: {proc.returncode}" proc.returncode == 0
), f"Bad return code while creating alpm database: {proc.returncode}"
# Print out the md5 hash value of the new test.db. # Print out the md5 hash value of the new test.db.
test_db = os.path.join(self.remote, "test.db") test_db = os.path.join(self.remote, "test.db")

View file

@ -5,7 +5,6 @@ import email
import os import os
import re import re
import sys import sys
from typing import TextIO from typing import TextIO
@ -28,6 +27,7 @@ class Email:
print(email.headers) print(email.headers)
""" """
TEST_DIR = "test-emails" TEST_DIR = "test-emails"
def __init__(self, serial: int = 1, autoparse: bool = True): def __init__(self, serial: int = 1, autoparse: bool = True):
@ -61,7 +61,7 @@ class Email:
value = os.environ.get("PYTEST_CURRENT_TEST", "email").split(" ")[0] value = os.environ.get("PYTEST_CURRENT_TEST", "email").split(" ")[0]
if suite: if suite:
value = value.split(":")[0] value = value.split(":")[0]
return re.sub(r'(\/|\.|,|:)', "_", value) return re.sub(r"(\/|\.|,|:)", "_", value)
@staticmethod @staticmethod
def count() -> int: def count() -> int:
@ -159,6 +159,6 @@ class Email:
lines += [ lines += [
f"== Email #{i + 1} ==", f"== Email #{i + 1} ==",
email.glue(), email.glue(),
f"== End of Email #{i + 1}" f"== End of Email #{i + 1}",
] ]
print("\n".join(lines), file=file) print("\n".join(lines), file=file)

View file

@ -1,6 +1,5 @@
import hashlib import hashlib
import os import os
from typing import Callable from typing import Callable
from posix_ipc import O_CREAT, Semaphore from posix_ipc import O_CREAT, Semaphore

View file

@ -1,6 +1,5 @@
import os import os
import shlex import shlex
from subprocess import PIPE, Popen from subprocess import PIPE, Popen
from typing import Tuple from typing import Tuple

View file

@ -6,7 +6,7 @@ parser = etree.HTMLParser()
def parse_root(html: str) -> etree.Element: def parse_root(html: str) -> etree.Element:
""" Parse an lxml.etree.ElementTree root from html content. """Parse an lxml.etree.ElementTree root from html content.
:param html: HTML markup :param html: HTML markup
:return: etree.Element :return: etree.Element

View file

@ -2,7 +2,8 @@ import aurweb.config
class User: class User:
""" A fake User model. """ """A fake User model."""
# Fake columns. # Fake columns.
LangPreference = aurweb.config.get("options", "default_lang") LangPreference = aurweb.config.get("options", "default_lang")
Timezone = aurweb.config.get("options", "default_timezone") Timezone = aurweb.config.get("options", "default_timezone")
@ -15,7 +16,8 @@ class User:
class Client: class Client:
""" A fake FastAPI Request.client object. """ """A fake FastAPI Request.client object."""
# A fake host. # A fake host.
host = "127.0.0.1" host = "127.0.0.1"
@ -25,16 +27,19 @@ class URL:
class Request: class Request:
""" A fake Request object which mimics a FastAPI Request for tests. """ """A fake Request object which mimics a FastAPI Request for tests."""
client = Client() client = Client()
url = URL() url = URL()
def __init__(self, def __init__(
user: User = User(), self,
authenticated: bool = False, user: User = User(),
method: str = "GET", authenticated: bool = False,
headers: dict[str, str] = dict(), method: str = "GET",
cookies: dict[str, str] = dict()) -> "Request": headers: dict[str, str] = dict(),
cookies: dict[str, str] = dict(),
) -> "Request":
self.user = user self.user = user
self.user.authenticated = authenticated self.user.authenticated = authenticated

View file

@ -2,7 +2,7 @@
class FakeSMTP: class FakeSMTP:
""" A fake version of smtplib.SMTP used for testing. """ """A fake version of smtplib.SMTP used for testing."""
starttls_enabled = False starttls_enabled = False
use_ssl = False use_ssl = False
@ -41,5 +41,6 @@ class FakeSMTP:
class FakeSMTP_SSL(FakeSMTP): class FakeSMTP_SSL(FakeSMTP):
""" A fake version of smtplib.SMTP_SSL used for testing. """ """A fake version of smtplib.SMTP_SSL used for testing."""
use_ssl = True use_ssl = True

View file

@ -1,5 +1,4 @@
import zoneinfo import zoneinfo
from collections import OrderedDict from collections import OrderedDict
from datetime import datetime from datetime import datetime
from urllib.parse import unquote from urllib.parse import unquote
@ -11,7 +10,7 @@ import aurweb.config
def tz_offset(name: str): def tz_offset(name: str):
""" Get a timezone offset in the form "+00:00" by its name. """Get a timezone offset in the form "+00:00" by its name.
Example: tz_offset('America/Los_Angeles') Example: tz_offset('America/Los_Angeles')
@ -24,7 +23,7 @@ def tz_offset(name: str):
offset = dt.utcoffset().total_seconds() / 60 / 60 offset = dt.utcoffset().total_seconds() / 60 / 60
# Prefix the offset string with a - or +. # Prefix the offset string with a - or +.
offset_string = '-' if offset < 0 else '+' offset_string = "-" if offset < 0 else "+"
# Remove any negativity from the offset. We want a good offset. :) # Remove any negativity from the offset. We want a good offset. :)
offset = abs(offset) offset = abs(offset)
@ -42,19 +41,25 @@ def tz_offset(name: str):
return offset_string return offset_string
SUPPORTED_TIMEZONES = OrderedDict({ SUPPORTED_TIMEZONES = OrderedDict(
# Flatten out the list of tuples into an OrderedDict. {
timezone: offset for timezone, offset in sorted([ # Flatten out the list of tuples into an OrderedDict.
# Comprehend a list of tuples (timezone, offset display string) timezone: offset
# and sort them by (offset, timezone). for timezone, offset in sorted(
(tz, "(UTC%s) %s" % (tz_offset(tz), tz)) [
for tz in zoneinfo.available_timezones() # Comprehend a list of tuples (timezone, offset display string)
], key=lambda element: (tz_offset(element[0]), element[0])) # and sort them by (offset, timezone).
}) (tz, "(UTC%s) %s" % (tz_offset(tz), tz))
for tz in zoneinfo.available_timezones()
],
key=lambda element: (tz_offset(element[0]), element[0]),
)
}
)
def get_request_timezone(request: Request): def get_request_timezone(request: Request):
""" Get a request's timezone by its AURTZ cookie. We use the """Get a request's timezone by its AURTZ cookie. We use the
configuration's [options] default_timezone otherwise. configuration's [options] default_timezone otherwise.
@param request FastAPI request @param request FastAPI request

View file

@ -8,12 +8,23 @@ from aurweb.models.ssh_pub_key import get_fingerprint
from aurweb.util import strtobool from aurweb.util import strtobool
def simple(U: str = str(), E: str = str(), H: bool = False, def simple(
BE: str = str(), R: str = str(), HP: str = str(), U: str = str(),
I: str = str(), K: str = str(), J: bool = False, E: str = str(),
CN: bool = False, UN: bool = False, ON: bool = False, H: bool = False,
S: bool = False, user: models.User = None, BE: str = str(),
**kwargs) -> None: R: str = str(),
HP: str = str(),
I: str = str(),
K: str = str(),
J: bool = False,
CN: bool = False,
UN: bool = False,
ON: bool = False,
S: bool = False,
user: models.User = None,
**kwargs,
) -> None:
now = time.utcnow() now = time.utcnow()
with db.begin(): with db.begin():
user.Username = U or user.Username user.Username = U or user.Username
@ -31,22 +42,26 @@ def simple(U: str = str(), E: str = str(), H: bool = False,
user.OwnershipNotify = strtobool(ON) user.OwnershipNotify = strtobool(ON)
def language(L: str = str(), def language(
request: Request = None, L: str = str(),
user: models.User = None, request: Request = None,
context: dict[str, Any] = {}, user: models.User = None,
**kwargs) -> None: context: dict[str, Any] = {},
**kwargs,
) -> None:
if L and L != user.LangPreference: if L and L != user.LangPreference:
with db.begin(): with db.begin():
user.LangPreference = L user.LangPreference = L
context["language"] = L context["language"] = L
def timezone(TZ: str = str(), def timezone(
request: Request = None, TZ: str = str(),
user: models.User = None, request: Request = None,
context: dict[str, Any] = {}, user: models.User = None,
**kwargs) -> None: context: dict[str, Any] = {},
**kwargs,
) -> None:
if TZ and TZ != user.Timezone: if TZ and TZ != user.Timezone:
with db.begin(): with db.begin():
user.Timezone = TZ user.Timezone = TZ
@ -67,8 +82,7 @@ def ssh_pubkey(PK: str = str(), user: models.User = None, **kwargs) -> None:
with db.begin(): with db.begin():
# Delete any existing keys we can't find. # Delete any existing keys we can't find.
to_remove = user.ssh_pub_keys.filter( to_remove = user.ssh_pub_keys.filter(~SSHPubKey.Fingerprint.in_(fprints))
~SSHPubKey.Fingerprint.in_(fprints))
db.delete_all(to_remove) db.delete_all(to_remove)
# For each key, if it does not yet exist, create it. # For each key, if it does not yet exist, create it.
@ -79,24 +93,27 @@ def ssh_pubkey(PK: str = str(), user: models.User = None, **kwargs) -> None:
).exists() ).exists()
if not db.query(exists).scalar(): if not db.query(exists).scalar():
# No public key exists, create one. # No public key exists, create one.
db.create(models.SSHPubKey, UserID=user.ID, db.create(
PubKey=" ".join([prefix, key]), models.SSHPubKey,
Fingerprint=fprints[i]) UserID=user.ID,
PubKey=" ".join([prefix, key]),
Fingerprint=fprints[i],
)
def account_type(T: int = None, def account_type(T: int = None, user: models.User = None, **kwargs) -> None:
user: models.User = None,
**kwargs) -> None:
if T is not None and (T := int(T)) != user.AccountTypeID: if T is not None and (T := int(T)) != user.AccountTypeID:
with db.begin(): with db.begin():
user.AccountTypeID = T user.AccountTypeID = T
def password(P: str = str(), def password(
request: Request = None, P: str = str(),
user: models.User = None, request: Request = None,
context: dict[str, Any] = {}, user: models.User = None,
**kwargs) -> None: context: dict[str, Any] = {},
**kwargs,
) -> None:
if P and not user.valid_password(P): if P and not user.valid_password(P):
# Remove the fields we consumed for passwords. # Remove the fields we consumed for passwords.
context["P"] = context["C"] = str() context["P"] = context["C"] = str()

View file

@ -25,42 +25,44 @@ def invalid_fields(E: str = str(), U: str = str(), **kwargs) -> None:
raise ValidationError(["Missing a required field."]) raise ValidationError(["Missing a required field."])
def invalid_suspend_permission(request: Request = None, def invalid_suspend_permission(
user: models.User = None, request: Request = None, user: models.User = None, S: str = "False", **kwargs
S: str = "False", ) -> None:
**kwargs) -> None:
if not request.user.is_elevated() and strtobool(S) != bool(user.Suspended): if not request.user.is_elevated() and strtobool(S) != bool(user.Suspended):
raise ValidationError([ raise ValidationError(["You do not have permission to suspend accounts."])
"You do not have permission to suspend accounts."])
def invalid_username(request: Request = None, U: str = str(), def invalid_username(
_: l10n.Translator = None, request: Request = None, U: str = str(), _: l10n.Translator = None, **kwargs
**kwargs) -> None: ) -> None:
if not util.valid_username(U): if not util.valid_username(U):
username_min_len = config.getint("options", "username_min_len") username_min_len = config.getint("options", "username_min_len")
username_max_len = config.getint("options", "username_max_len") username_max_len = config.getint("options", "username_max_len")
raise ValidationError([ raise ValidationError(
"The username is invalid.",
[ [
_("It must be between %s and %s characters long") % ( "The username is invalid.",
username_min_len, username_max_len), [
"Start and end with a letter or number", _("It must be between %s and %s characters long")
"Can contain only one period, underscore or hyphen.", % (username_min_len, username_max_len),
"Start and end with a letter or number",
"Can contain only one period, underscore or hyphen.",
],
] ]
]) )
def invalid_password(P: str = str(), C: str = str(), def invalid_password(
_: l10n.Translator = None, **kwargs) -> None: P: str = str(), C: str = str(), _: l10n.Translator = None, **kwargs
) -> None:
if P: if P:
if not util.valid_password(P): if not util.valid_password(P):
username_min_len = config.getint( username_min_len = config.getint("options", "username_min_len")
"options", "username_min_len") raise ValidationError(
raise ValidationError([ [
_("Your password must be at least %s characters.") % ( _("Your password must be at least %s characters.")
username_min_len) % (username_min_len)
]) ]
)
elif not C: elif not C:
raise ValidationError(["Please confirm your new password."]) raise ValidationError(["Please confirm your new password."])
elif P != C: elif P != C:
@ -71,15 +73,18 @@ def is_banned(request: Request = None, **kwargs) -> None:
host = request.client.host host = request.client.host
exists = db.query(models.Ban, models.Ban.IPAddress == host).exists() exists = db.query(models.Ban, models.Ban.IPAddress == host).exists()
if db.query(exists).scalar(): if db.query(exists).scalar():
raise ValidationError([ raise ValidationError(
"Account registration has been disabled for your " [
"IP address, probably due to sustained spam attacks. " "Account registration has been disabled for your "
"Sorry for the inconvenience." "IP address, probably due to sustained spam attacks. "
]) "Sorry for the inconvenience."
]
)
def invalid_user_password(request: Request = None, passwd: str = str(), def invalid_user_password(
**kwargs) -> None: request: Request = None, passwd: str = str(), **kwargs
) -> None:
if request.user.is_authenticated(): if request.user.is_authenticated():
if not request.user.valid_password(passwd): if not request.user.valid_password(passwd):
raise ValidationError(["Invalid password."]) raise ValidationError(["Invalid password."])
@ -97,8 +102,9 @@ def invalid_backup_email(BE: str = str(), **kwargs) -> None:
def invalid_homepage(HP: str = str(), **kwargs) -> None: def invalid_homepage(HP: str = str(), **kwargs) -> None:
if HP and not util.valid_homepage(HP): if HP and not util.valid_homepage(HP):
raise ValidationError([ raise ValidationError(
"The home page is invalid, please specify the full HTTP(s) URL."]) ["The home page is invalid, please specify the full HTTP(s) URL."]
)
def invalid_pgp_key(K: str = str(), **kwargs) -> None: def invalid_pgp_key(K: str = str(), **kwargs) -> None:
@ -106,8 +112,9 @@ def invalid_pgp_key(K: str = str(), **kwargs) -> None:
raise ValidationError(["The PGP key fingerprint is invalid."]) raise ValidationError(["The PGP key fingerprint is invalid."])
def invalid_ssh_pubkey(PK: str = str(), user: models.User = None, def invalid_ssh_pubkey(
_: l10n.Translator = None, **kwargs) -> None: PK: str = str(), user: models.User = None, _: l10n.Translator = None, **kwargs
) -> None:
if not PK: if not PK:
return return
@ -119,15 +126,23 @@ def invalid_ssh_pubkey(PK: str = str(), user: models.User = None,
for prefix, key in keys: for prefix, key in keys:
fingerprint = get_fingerprint(f"{prefix} {key}") fingerprint = get_fingerprint(f"{prefix} {key}")
exists = db.query(models.SSHPubKey).filter( exists = (
and_(models.SSHPubKey.UserID != user.ID, db.query(models.SSHPubKey)
models.SSHPubKey.Fingerprint == fingerprint) .filter(
).exists() and_(
models.SSHPubKey.UserID != user.ID,
models.SSHPubKey.Fingerprint == fingerprint,
)
)
.exists()
)
if db.query(exists).scalar(): if db.query(exists).scalar():
raise ValidationError([ raise ValidationError(
_("The SSH public key, %s%s%s, is already in use.") % ( [
"<strong>", fingerprint, "</strong>") _("The SSH public key, %s%s%s, is already in use.")
]) % ("<strong>", fingerprint, "</strong>")
]
)
def invalid_language(L: str = str(), **kwargs) -> None: def invalid_language(L: str = str(), **kwargs) -> None:
@ -140,60 +155,78 @@ def invalid_timezone(TZ: str = str(), **kwargs) -> None:
raise ValidationError(["Timezone is not currently supported."]) raise ValidationError(["Timezone is not currently supported."])
def username_in_use(U: str = str(), user: models.User = None, def username_in_use(
_: l10n.Translator = None, **kwargs) -> None: U: str = str(), user: models.User = None, _: l10n.Translator = None, **kwargs
exists = db.query(models.User).filter( ) -> None:
and_(models.User.ID != user.ID, exists = (
models.User.Username == U) db.query(models.User)
).exists() .filter(and_(models.User.ID != user.ID, models.User.Username == U))
.exists()
)
if db.query(exists).scalar(): if db.query(exists).scalar():
# If the username already exists... # If the username already exists...
raise ValidationError([ raise ValidationError(
_("The username, %s%s%s, is already in use.") % ( [
"<strong>", U, "</strong>") _("The username, %s%s%s, is already in use.")
]) % ("<strong>", U, "</strong>")
]
)
def email_in_use(E: str = str(), user: models.User = None, def email_in_use(
_: l10n.Translator = None, **kwargs) -> None: E: str = str(), user: models.User = None, _: l10n.Translator = None, **kwargs
exists = db.query(models.User).filter( ) -> None:
and_(models.User.ID != user.ID, exists = (
models.User.Email == E) db.query(models.User)
).exists() .filter(and_(models.User.ID != user.ID, models.User.Email == E))
.exists()
)
if db.query(exists).scalar(): if db.query(exists).scalar():
# If the email already exists... # If the email already exists...
raise ValidationError([ raise ValidationError(
_("The address, %s%s%s, is already in use.") % ( [
"<strong>", E, "</strong>") _("The address, %s%s%s, is already in use.")
]) % ("<strong>", E, "</strong>")
]
)
def invalid_account_type(T: int = None, request: Request = None, def invalid_account_type(
user: models.User = None, T: int = None,
_: l10n.Translator = None, request: Request = None,
**kwargs) -> None: user: models.User = None,
_: l10n.Translator = None,
**kwargs,
) -> None:
if T is not None and (T := int(T)) != user.AccountTypeID: if T is not None and (T := int(T)) != user.AccountTypeID:
name = ACCOUNT_TYPE_NAME.get(T, None) name = ACCOUNT_TYPE_NAME.get(T, None)
has_cred = request.user.has_credential(creds.ACCOUNT_CHANGE_TYPE) has_cred = request.user.has_credential(creds.ACCOUNT_CHANGE_TYPE)
if name is None: if name is None:
raise ValidationError(["Invalid account type provided."]) raise ValidationError(["Invalid account type provided."])
elif not has_cred: elif not has_cred:
raise ValidationError([ raise ValidationError(
"You do not have permission to change account types."]) ["You do not have permission to change account types."]
)
elif T > request.user.AccountTypeID: elif T > request.user.AccountTypeID:
# If the chosen account type is higher than the editor's account # If the chosen account type is higher than the editor's account
# type, the editor doesn't have permission to set the new type. # type, the editor doesn't have permission to set the new type.
error = _("You do not have permission to change " error = (
"this user's account type to %s.") % name _(
"You do not have permission to change "
"this user's account type to %s."
)
% name
)
raise ValidationError([error]) raise ValidationError([error])
logger.debug(f"Trusted User '{request.user.Username}' has " logger.debug(
f"modified '{user.Username}' account's type to" f"Trusted User '{request.user.Username}' has "
f" {name}.") f"modified '{user.Username}' account's type to"
f" {name}."
)
def invalid_captcha(captcha_salt: str = None, captcha: str = None, def invalid_captcha(captcha_salt: str = None, captcha: str = None, **kwargs) -> None:
**kwargs) -> None:
if captcha_salt and captcha_salt not in get_captcha_salts(): if captcha_salt and captcha_salt not in get_captcha_salts():
raise ValidationError(["This CAPTCHA has expired. Please try again."]) raise ValidationError(["This CAPTCHA has expired. Please try again."])

View file

@ -2,7 +2,6 @@ import math
import re import re
import secrets import secrets
import string import string
from datetime import datetime from datetime import datetime
from http import HTTPStatus from http import HTTPStatus
from subprocess import PIPE, Popen from subprocess import PIPE, Popen
@ -11,12 +10,10 @@ from urllib.parse import urlparse
import fastapi import fastapi
import pygit2 import pygit2
from email_validator import EmailSyntaxError, validate_email from email_validator import EmailSyntaxError, validate_email
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
import aurweb.config import aurweb.config
from aurweb import defaults, logging from aurweb import defaults, logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -24,15 +21,15 @@ logger = logging.get_logger(__name__)
def make_random_string(length: int) -> str: def make_random_string(length: int) -> str:
alphanumerics = string.ascii_lowercase + string.digits alphanumerics = string.ascii_lowercase + string.digits
return ''.join([secrets.choice(alphanumerics) for i in range(length)]) return "".join([secrets.choice(alphanumerics) for i in range(length)])
def make_nonce(length: int = 8): def make_nonce(length: int = 8):
""" Generate a single random nonce. Here, token_hex generates a hex """Generate a single random nonce. Here, token_hex generates a hex
string of 2 hex characters per byte, where the length give is string of 2 hex characters per byte, where the length give is
nbytes. This means that to get our proper string length, we need to nbytes. This means that to get our proper string length, we need to
cut it in half and truncate off any remaining (in the case that cut it in half and truncate off any remaining (in the case that
length was uneven). """ length was uneven)."""
return secrets.token_hex(math.ceil(length / 2))[:length] return secrets.token_hex(math.ceil(length / 2))[:length]
@ -45,7 +42,7 @@ def valid_username(username):
# Check that username contains: one or more alphanumeric # Check that username contains: one or more alphanumeric
# characters, an optional separator of '.', '-' or '_', followed # characters, an optional separator of '.', '-' or '_', followed
# by alphanumeric characters. # by alphanumeric characters.
return re.match(r'^[a-zA-Z0-9]+[.\-_]?[a-zA-Z0-9]+$', username) return re.match(r"^[a-zA-Z0-9]+[.\-_]?[a-zA-Z0-9]+$", username)
def valid_email(email): def valid_email(email):
@ -82,7 +79,7 @@ def valid_pgp_fingerprint(fp):
def jsonify(obj): def jsonify(obj):
""" Perform a conversion on obj if it's needed. """ """Perform a conversion on obj if it's needed."""
if isinstance(obj, datetime): if isinstance(obj, datetime):
obj = int(obj.timestamp()) obj = int(obj.timestamp())
return obj return obj
@ -151,8 +148,7 @@ def git_search(repo: pygit2.Repository, commit_hash: str) -> int:
return prefixlen return prefixlen
async def error_or_result(next: Callable, *args, **kwargs) \ async def error_or_result(next: Callable, *args, **kwargs) -> fastapi.Response:
-> fastapi.Response:
""" """
Try to return a response from `next`. Try to return a response from `next`.
@ -174,9 +170,9 @@ async def error_or_result(next: Callable, *args, **kwargs) \
def parse_ssh_key(string: str) -> Tuple[str, str]: def parse_ssh_key(string: str) -> Tuple[str, str]:
""" Parse an SSH public key. """ """Parse an SSH public key."""
invalid_exc = ValueError("The SSH public key is invalid.") invalid_exc = ValueError("The SSH public key is invalid.")
parts = re.sub(r'\s\s+', ' ', string.strip()).split() parts = re.sub(r"\s\s+", " ", string.strip()).split()
if len(parts) < 2: if len(parts) < 2:
raise invalid_exc raise invalid_exc
@ -185,8 +181,7 @@ def parse_ssh_key(string: str) -> Tuple[str, str]:
if prefix not in prefixes: if prefix not in prefixes:
raise invalid_exc raise invalid_exc
proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE, proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE, stderr=PIPE)
stderr=PIPE)
out, _ = proc.communicate(f"{prefix} {key}".encode()) out, _ = proc.communicate(f"{prefix} {key}".encode())
if proc.returncode: if proc.returncode:
raise invalid_exc raise invalid_exc
@ -195,5 +190,5 @@ def parse_ssh_key(string: str) -> Tuple[str, str]:
def parse_ssh_keys(string: str) -> list[Tuple[str, str]]: def parse_ssh_keys(string: str) -> list[Tuple[str, str]]:
""" Parse a list of SSH public keys. """ """Parse a list of SSH public keys."""
return [parse_ssh_key(e) for e in string.splitlines()] return [parse_ssh_key(e) for e in string.splitlines()]

View file

@ -108,4 +108,3 @@ The following list of steps describes exactly how this verification works:
- `options.disable_http_login: 1` - `options.disable_http_login: 1`
- `options.login_timeout: <default_provided_in_config.defaults>` - `options.login_timeout: <default_provided_in_config.defaults>`
- `options.persistent_cookie_timeout: <default_provided_in_config.defaults>` - `options.persistent_cookie_timeout: <default_provided_in_config.defaults>`

View file

@ -147,4 +147,3 @@ http {
'' close; '' close;
} }
} }

View file

@ -2,7 +2,6 @@ import logging
import logging.config import logging.config
import sqlalchemy import sqlalchemy
from alembic import context from alembic import context
import aurweb.db import aurweb.db
@ -69,9 +68,7 @@ def run_migrations_online():
) )
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(connection=connection, target_metadata=target_metadata)
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()

View file

@ -15,27 +15,26 @@ import os
import random import random
import sys import sys
import time import time
from datetime import datetime from datetime import datetime
import bcrypt import bcrypt
LOG_LEVEL = logging.DEBUG # logging level. set to logging.INFO to reduce output LOG_LEVEL = logging.DEBUG # logging level. set to logging.INFO to reduce output
SEED_FILE = "/usr/share/dict/words" SEED_FILE = "/usr/share/dict/words"
USER_ID = 5 # Users.ID of first bogus user USER_ID = 5 # Users.ID of first bogus user
PKG_ID = 1 # Packages.ID of first package PKG_ID = 1 # Packages.ID of first package
# how many users to 'register' # how many users to 'register'
MAX_USERS = int(os.environ.get("MAX_USERS", 38000)) MAX_USERS = int(os.environ.get("MAX_USERS", 38000))
MAX_DEVS = .1 # what percentage of MAX_USERS are Developers MAX_DEVS = 0.1 # what percentage of MAX_USERS are Developers
MAX_TUS = .2 # what percentage of MAX_USERS are Trusted Users MAX_TUS = 0.2 # what percentage of MAX_USERS are Trusted Users
# how many packages to load # how many packages to load
MAX_PKGS = int(os.environ.get("MAX_PKGS", 32000)) MAX_PKGS = int(os.environ.get("MAX_PKGS", 32000))
PKG_DEPS = (1, 15) # min/max depends a package has PKG_DEPS = (1, 15) # min/max depends a package has
PKG_RELS = (1, 5) # min/max relations a package has PKG_RELS = (1, 5) # min/max relations a package has
PKG_SRC = (1, 3) # min/max sources a package has PKG_SRC = (1, 3) # min/max sources a package has
PKG_CMNTS = (1, 5) # min/max number of comments a package has PKG_CMNTS = (1, 5) # min/max number of comments a package has
CATEGORIES_COUNT = 17 # the number of categories from aur-schema CATEGORIES_COUNT = 17 # the number of categories from aur-schema
VOTING = (0, .001) # percentage range for package voting VOTING = (0, 0.001) # percentage range for package voting
# number of open trusted user proposals # number of open trusted user proposals
OPEN_PROPOSALS = int(os.environ.get("OPEN_PROPOSALS", 15)) OPEN_PROPOSALS = int(os.environ.get("OPEN_PROPOSALS", 15))
# number of closed trusted user proposals # number of closed trusted user proposals
@ -113,10 +112,10 @@ if not len(contents) - MAX_USERS > MAX_PKGS:
def normalize(unicode_data): def normalize(unicode_data):
""" We only accept ascii for usernames. Also use this to normalize """We only accept ascii for usernames. Also use this to normalize
package names; our database utf8mb4 collations compare with Unicode package names; our database utf8mb4 collations compare with Unicode
Equivalence. """ Equivalence."""
return unicode_data.encode('ascii', 'ignore').decode('ascii') return unicode_data.encode("ascii", "ignore").decode("ascii")
# select random usernames # select random usernames
@ -196,10 +195,12 @@ for u in user_keys:
# "{salt}{username}" # "{salt}{username}"
to_hash = f"{salt}{u}" to_hash = f"{salt}{u}"
h = hashlib.new('md5') h = hashlib.new("md5")
h.update(to_hash.encode()) h.update(to_hash.encode())
s = ("INSERT INTO Users (ID, AccountTypeID, Username, Email, Passwd, Salt)" s = (
" VALUES (%d, %d, '%s', '%s@example.com', '%s', '%s');\n") "INSERT INTO Users (ID, AccountTypeID, Username, Email, Passwd, Salt)"
" VALUES (%d, %d, '%s', '%s@example.com', '%s', '%s');\n"
)
s = s % (seen_users[u], account_type, u, u, h.hexdigest(), salt) s = s % (seen_users[u], account_type, u, u, h.hexdigest(), salt)
out.write(s) out.write(s)
@ -230,13 +231,17 @@ for p in list(seen_pkgs.keys()):
uuid = genUID() # the submitter/user uuid = genUID() # the submitter/user
s = ("INSERT INTO PackageBases (ID, Name, FlaggerComment, SubmittedTS, ModifiedTS, " s = (
"SubmitterUID, MaintainerUID, PackagerUID) VALUES (%d, '%s', '', %d, %d, %d, %s, %s);\n") "INSERT INTO PackageBases (ID, Name, FlaggerComment, SubmittedTS, ModifiedTS, "
"SubmitterUID, MaintainerUID, PackagerUID) VALUES (%d, '%s', '', %d, %d, %d, %s, %s);\n"
)
s = s % (seen_pkgs[p], p, NOW, NOW, uuid, muid, puid) s = s % (seen_pkgs[p], p, NOW, NOW, uuid, muid, puid)
out.write(s) out.write(s)
s = ("INSERT INTO Packages (ID, PackageBaseID, Name, Version) VALUES " s = (
"(%d, %d, '%s', '%s');\n") "INSERT INTO Packages (ID, PackageBaseID, Name, Version) VALUES "
"(%d, %d, '%s', '%s');\n"
)
s = s % (seen_pkgs[p], seen_pkgs[p], p, genVersion()) s = s % (seen_pkgs[p], seen_pkgs[p], p, genVersion())
out.write(s) out.write(s)
@ -247,8 +252,10 @@ for p in list(seen_pkgs.keys()):
num_comments = random.randrange(PKG_CMNTS[0], PKG_CMNTS[1]) num_comments = random.randrange(PKG_CMNTS[0], PKG_CMNTS[1])
for i in range(0, num_comments): for i in range(0, num_comments):
now = NOW + random.randrange(400, 86400 * 3) now = NOW + random.randrange(400, 86400 * 3)
s = ("INSERT INTO PackageComments (PackageBaseID, UsersID," s = (
" Comments, RenderedComment, CommentTS) VALUES (%d, %d, '%s', '', %d);\n") "INSERT INTO PackageComments (PackageBaseID, UsersID,"
" Comments, RenderedComment, CommentTS) VALUES (%d, %d, '%s', '', %d);\n"
)
s = s % (seen_pkgs[p], genUID(), genFortune(), now) s = s % (seen_pkgs[p], genUID(), genFortune(), now)
out.write(s) out.write(s)
@ -258,14 +265,17 @@ utcnow = int(datetime.utcnow().timestamp())
track_votes = {} track_votes = {}
log.debug("Casting votes for packages.") log.debug("Casting votes for packages.")
for u in user_keys: for u in user_keys:
num_votes = random.randrange(int(len(seen_pkgs) * VOTING[0]), num_votes = random.randrange(
int(len(seen_pkgs) * VOTING[1])) int(len(seen_pkgs) * VOTING[0]), int(len(seen_pkgs) * VOTING[1])
)
pkgvote = {} pkgvote = {}
for v in range(num_votes): for v in range(num_votes):
pkg = random.randrange(1, len(seen_pkgs) + 1) pkg = random.randrange(1, len(seen_pkgs) + 1)
if pkg not in pkgvote: if pkg not in pkgvote:
s = ("INSERT INTO PackageVotes (UsersID, PackageBaseID, VoteTS)" s = (
" VALUES (%d, %d, %d);\n") "INSERT INTO PackageVotes (UsersID, PackageBaseID, VoteTS)"
" VALUES (%d, %d, %d);\n"
)
s = s % (seen_users[u], pkg, utcnow) s = s % (seen_users[u], pkg, utcnow)
pkgvote[pkg] = 1 pkgvote[pkg] = 1
if pkg not in track_votes: if pkg not in track_votes:
@ -310,9 +320,12 @@ for p in seen_pkgs_keys:
src_file = user_keys[random.randrange(0, len(user_keys))] src_file = user_keys[random.randrange(0, len(user_keys))]
src = "%s%s.%s/%s/%s-%s.tar.gz" % ( src = "%s%s.%s/%s/%s-%s.tar.gz" % (
RANDOM_URL[random.randrange(0, len(RANDOM_URL))], RANDOM_URL[random.randrange(0, len(RANDOM_URL))],
p, RANDOM_TLDS[random.randrange(0, len(RANDOM_TLDS))], p,
RANDOM_TLDS[random.randrange(0, len(RANDOM_TLDS))],
RANDOM_LOCS[random.randrange(0, len(RANDOM_LOCS))], RANDOM_LOCS[random.randrange(0, len(RANDOM_LOCS))],
src_file, genVersion()) src_file,
genVersion(),
)
s = "INSERT INTO PackageSources(PackageID, Source) VALUES (%d, '%s');\n" s = "INSERT INTO PackageSources(PackageID, Source) VALUES (%d, '%s');\n"
s = s % (seen_pkgs[p], src) s = s % (seen_pkgs[p], src)
out.write(s) out.write(s)
@ -334,8 +347,10 @@ for t in range(0, OPEN_PROPOSALS + CLOSE_PROPOSALS):
else: else:
user = user_keys[random.randrange(0, len(user_keys))] user = user_keys[random.randrange(0, len(user_keys))]
suid = trustedusers[random.randrange(0, len(trustedusers))] suid = trustedusers[random.randrange(0, len(trustedusers))]
s = ("INSERT INTO TU_VoteInfo (Agenda, User, Submitted, End," s = (
" Quorum, SubmitterID) VALUES ('%s', '%s', %d, %d, 0.0, %d);\n") "INSERT INTO TU_VoteInfo (Agenda, User, Submitted, End,"
" Quorum, SubmitterID) VALUES ('%s', '%s', %d, %d, 0.0, %d);\n"
)
s = s % (genFortune(), user, start, end, suid) s = s % (genFortune(), user, start, end, suid)
out.write(s) out.write(s)
count += 1 count += 1

View file

@ -65,4 +65,3 @@
</form> </form>
</div> </div>
{% endblock %} {% endblock %}

View file

@ -79,4 +79,3 @@
</td> </td>
</tr> </tr>
</table> </table>

Some files were not shown because too many files have changed in this diff Show more