style: Run pre-commit

This commit is contained in:
Joakim Saario 2022-08-21 22:08:29 +02:00
parent b47882b114
commit 9c6c13b78a
No known key found for this signature in database
GPG key ID: D8B76D271B7BD453
235 changed files with 7180 additions and 5628 deletions

1
.gitignore vendored
View file

@ -1,3 +1,4 @@
/data/
__pycache__/
*.py[cod]
.vim/

View file

@ -5,4 +5,3 @@ host = https://www.transifex.com
file_filter = po/<lang>.po
source_file = po/aurweb.pot
source_lang = en

View file

@ -6,11 +6,9 @@ import re
import sys
import traceback
import typing
from urllib.parse import quote_plus
import requests
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
@ -26,7 +24,6 @@ import aurweb.config
import aurweb.filters # noqa: F401
import aurweb.logging
import aurweb.pkgbase.util as pkgbaseutil
from aurweb import logging, prometheus, util
from aurweb.auth import BasicAuthBackend
from aurweb.db import get_engine, query
@ -60,33 +57,33 @@ async def app_startup():
# provided by the user. Docker uses .env's TEST_RECURSION_LIMIT
# when running test suites.
# TODO: Find a proper fix to this issue.
recursion_limit = int(os.environ.get(
"TEST_RECURSION_LIMIT", sys.getrecursionlimit() + 1000))
recursion_limit = int(
os.environ.get("TEST_RECURSION_LIMIT", sys.getrecursionlimit() + 1000)
)
sys.setrecursionlimit(recursion_limit)
backend = aurweb.config.get("database", "backend")
if backend not in aurweb.db.DRIVERS:
raise ValueError(
f"The configured database backend ({backend}) is unsupported. "
f"Supported backends: {str(aurweb.db.DRIVERS.keys())}")
f"Supported backends: {str(aurweb.db.DRIVERS.keys())}"
)
session_secret = aurweb.config.get("fastapi", "session_secret")
if not session_secret:
raise Exception("[fastapi] session_secret must not be empty")
if not os.environ.get("PROMETHEUS_MULTIPROC_DIR", None):
logger.warning("$PROMETHEUS_MULTIPROC_DIR is not set, the /metrics "
"endpoint is disabled.")
logger.warning(
"$PROMETHEUS_MULTIPROC_DIR is not set, the /metrics "
"endpoint is disabled."
)
app.mount("/static/css",
StaticFiles(directory="web/html/css"),
name="static_css")
app.mount("/static/js",
StaticFiles(directory="web/html/js"),
name="static_js")
app.mount("/static/images",
StaticFiles(directory="web/html/images"),
name="static_images")
app.mount("/static/css", StaticFiles(directory="web/html/css"), name="static_css")
app.mount("/static/js", StaticFiles(directory="web/html/js"), name="static_js")
app.mount(
"/static/images", StaticFiles(directory="web/html/images"), name="static_images"
)
# Add application middlewares.
app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend())
@ -95,6 +92,7 @@ async def app_startup():
# Add application routes.
def add_router(module):
app.include_router(module.router)
util.apply_all(APP_ROUTES, add_router)
# Initialize the database engine and ORM.
@ -177,9 +175,7 @@ async def internal_server_error(request: Request, exc: Exception) -> Response:
else:
# post
form_data = str(dict(request.state.form_data))
desc = desc + [
f"- Data: `{form_data}`"
] + ["", f"```{tb}```"]
desc = desc + [f"- Data: `{form_data}`"] + ["", f"```{tb}```"]
headers = {"Authorization": f"Bearer {token}"}
data = {
@ -191,11 +187,12 @@ async def internal_server_error(request: Request, exc: Exception) -> Response:
logger.info(endp)
resp = requests.post(endp, json=data, headers=headers)
if resp.status_code != http.HTTPStatus.CREATED:
logger.error(
f"Unable to report exception to {repo}: {resp.text}")
logger.error(f"Unable to report exception to {repo}: {resp.text}")
else:
logger.warning("Unable to report an exception found due to "
"unset notifications.error-{{project,token}}")
logger.warning(
"Unable to report an exception found due to "
"unset notifications.error-{{project,token}}"
)
# Log details about the exception traceback.
logger.error(f"FATAL[{tb_id}]: An unexpected exception has occurred.")
@ -203,13 +200,16 @@ async def internal_server_error(request: Request, exc: Exception) -> Response:
else:
retval = retval.decode()
return render_template(request, "errors/500.html", context,
status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR)
return render_template(
request,
"errors/500.html",
context,
status_code=http.HTTPStatus.INTERNAL_SERVER_ERROR,
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) \
-> Response:
async def http_exception_handler(request: Request, exc: HTTPException) -> Response:
"""Handle an HTTPException thrown in a route."""
phrase = http.HTTPStatus(exc.status_code).phrase
context = make_context(request, phrase)
@ -228,11 +228,11 @@ async def http_exception_handler(request: Request, exc: HTTPException) \
pass
try:
return render_template(request, f"errors/{exc.status_code}.html",
context, exc.status_code)
return render_template(
request, f"errors/{exc.status_code}.html", context, exc.status_code
)
except TemplateNotFound:
return render_template(request, "errors/detail.html",
context, exc.status_code)
return render_template(request, "errors/detail.html", context, exc.status_code)
@app.middleware("http")
@ -254,7 +254,7 @@ async def add_security_headers(request: Request, call_next: typing.Callable):
nonce = request.user.nonce
csp = "default-src 'self'; "
script_hosts = []
csp += f"script-src 'self' 'nonce-{nonce}' " + ' '.join(script_hosts)
csp += f"script-src 'self' 'nonce-{nonce}' " + " ".join(script_hosts)
# It's fine if css is inlined.
csp += "; style-src 'self' 'unsafe-inline'"
response.headers["Content-Security-Policy"] = csp
@ -279,14 +279,22 @@ async def check_terms_of_service(request: Request, call_next: typing.Callable):
"""This middleware function redirects authenticated users if they
have any outstanding Terms to agree to."""
if request.user.is_authenticated() and request.url.path != "/tos":
unaccepted = query(Term).join(AcceptedTerm).filter(
or_(AcceptedTerm.UsersID != request.user.ID,
and_(AcceptedTerm.UsersID == request.user.ID,
unaccepted = (
query(Term)
.join(AcceptedTerm)
.filter(
or_(
AcceptedTerm.UsersID != request.user.ID,
and_(
AcceptedTerm.UsersID == request.user.ID,
AcceptedTerm.TermsID == Term.ID,
AcceptedTerm.Revision < Term.Revision)))
AcceptedTerm.Revision < Term.Revision,
),
)
)
)
if query(Term).count() > unaccepted.count():
return RedirectResponse(
"/tos", status_code=int(http.HTTPStatus.SEE_OTHER))
return RedirectResponse("/tos", status_code=int(http.HTTPStatus.SEE_OTHER))
return await util.error_or_result(call_next, request)
@ -301,9 +309,9 @@ async def id_redirect_middleware(request: Request, call_next: typing.Callable):
for k, v in request.query_params.items():
if k != "id":
qs.append(f"{k}={quote_plus(str(v))}")
qs = str() if not qs else '?' + '&'.join(qs)
qs = str() if not qs else "?" + "&".join(qs)
path = request.url.path.rstrip('/')
path = request.url.path.rstrip("/")
return RedirectResponse(f"{path}/{id}{qs}")
return await util.error_or_result(call_next, request)

View file

@ -1,17 +1,14 @@
import functools
from http import HTTPStatus
from typing import Callable
import fastapi
from fastapi import HTTPException
from fastapi.responses import RedirectResponse
from starlette.authentication import AuthCredentials, AuthenticationBackend
from starlette.requests import HTTPConnection
import aurweb.config
from aurweb import db, filters, l10n, time, util
from aurweb.models import Session, User
from aurweb.models.account_type import ACCOUNT_TYPE_ID
@ -31,6 +28,7 @@ class StubQuery:
class AnonymousUser:
"""A stubbed User class used when an unauthenticated User
makes a request against FastAPI."""
# Stub attributes used to mimic a real user.
ID = 0
Username = "N/A"
@ -42,6 +40,7 @@ class AnonymousUser:
All records primary keys (AccountType.ID) should be non-zero,
so using a zero here means that we'll never match against a
real AccountType."""
ID = 0
AccountType = "Anonymous"
@ -104,11 +103,11 @@ class BasicAuthBackend(AuthenticationBackend):
return unauthenticated
timeout = aurweb.config.getint("options", "login_timeout")
remembered = ("AURREMEMBER" in conn.cookies
and bool(conn.cookies.get("AURREMEMBER")))
remembered = "AURREMEMBER" in conn.cookies and bool(
conn.cookies.get("AURREMEMBER")
)
if remembered:
timeout = aurweb.config.getint("options",
"persistent_cookie_timeout")
timeout = aurweb.config.getint("options", "persistent_cookie_timeout")
# If no session with sid and a LastUpdateTS now or later exists.
now_ts = time.utcnow()
@ -160,15 +159,18 @@ def _auth_required(auth_goal: bool = True):
# page itself is not directly possible (e.g. submitting a form).
if request.method in ("GET", "HEAD"):
url = request.url.path
elif (referer := request.headers.get("Referer")):
elif referer := request.headers.get("Referer"):
aur = aurweb.config.get("options", "aur_location") + "/"
if not referer.startswith(aur):
_ = l10n.get_translator_for_request(request)
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST,
detail=_("Bad Referer header."))
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=_("Bad Referer header."),
)
url = referer[len(aur) - 1 :]
url = "/login?" + filters.urlencode({"next": url})
return RedirectResponse(url, status_code=int(HTTPStatus.SEE_OTHER))
return wrapper
return decorator
@ -180,6 +182,7 @@ def requires_auth(func: Callable) -> Callable:
@functools.wraps(func)
async def wrapper(*args, **kwargs):
return await _auth_required(True)(func)(*args, **kwargs)
return wrapper
@ -189,6 +192,7 @@ def requires_guest(func: Callable) -> Callable:
@functools.wraps(func)
async def wrapper(*args, **kwargs):
return await _auth_required(False)(func)(*args, **kwargs)
return wrapper
@ -211,18 +215,15 @@ def account_type_required(one_of: set):
:return: Return the FastAPI function this decorator wraps.
"""
# Convert any account type string constants to their integer IDs.
one_of = {
ACCOUNT_TYPE_ID[atype]
for atype in one_of
if isinstance(atype, str)
}
one_of = {ACCOUNT_TYPE_ID[atype] for atype in one_of if isinstance(atype, str)}
def decorator(func):
@functools.wraps(func)
async def wrapper(request: fastapi.Request, *args, **kwargs):
if request.user.AccountTypeID not in one_of:
return RedirectResponse("/",
status_code=int(HTTPStatus.SEE_OTHER))
return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER))
return await func(request, *args, **kwargs)
return wrapper
return decorator

View file

@ -1,4 +1,9 @@
from aurweb.models.account_type import DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID, TRUSTED_USER_ID, USER_ID
from aurweb.models.account_type import (
DEVELOPER_ID,
TRUSTED_USER_AND_DEV_ID,
TRUSTED_USER_ID,
USER_ID,
)
from aurweb.models.user import User
ACCOUNT_CHANGE_TYPE = 1
@ -30,7 +35,9 @@ TU_LIST_VOTES = 20
TU_VOTE = 21
PKGBASE_MERGE = 29
user_developer_or_trusted_user = set([USER_ID, TRUSTED_USER_ID, DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID])
user_developer_or_trusted_user = set(
[USER_ID, TRUSTED_USER_ID, DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID]
)
trusted_user_or_dev = set([TRUSTED_USER_ID, DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID])
developer = set([DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID])
trusted_user = set([TRUSTED_USER_ID, TRUSTED_USER_AND_DEV_ID])
@ -67,9 +74,7 @@ cred_filters = {
}
def has_credential(user: User,
credential: int,
approved: list = tuple()):
def has_credential(user: User, credential: int, approved: list = tuple()):
if user in approved:
return True

View file

@ -2,8 +2,9 @@ from redis import Redis
from sqlalchemy import orm
async def db_count_cache(redis: Redis, key: str, query: orm.Query,
expire: int = None) -> int:
async def db_count_cache(
redis: Redis, key: str, query: orm.Query, expire: int = None
) -> int:
"""Store and retrieve a query.count() via redis cache.
:param redis: Redis handle

View file

@ -38,7 +38,9 @@ def get_captcha_answer(token):
'--'
This program may be freely redistributed under
the terms of the GNU General Public License.
""" % tuple([token] * 10)
""" % tuple(
[token] * 10
)
return hashlib.md5((text + "\n").encode()).hexdigest()[:6]

View file

@ -1,6 +1,5 @@
import configparser
import os
from typing import Any
# Publicly visible version of aurweb. This is used to display
@ -15,8 +14,8 @@ def _get_parser():
global _parser
if not _parser:
path = os.environ.get('AUR_CONFIG', '/etc/aurweb/config')
defaults = os.environ.get('AUR_CONFIG_DEFAULTS', path + '.defaults')
path = os.environ.get("AUR_CONFIG", "/etc/aurweb/config")
defaults = os.environ.get("AUR_CONFIG_DEFAULTS", path + ".defaults")
_parser = configparser.RawConfigParser()
_parser.optionxform = lambda option: option

View file

@ -35,9 +35,13 @@ def timeout(extended: bool) -> int:
return timeout
def update_response_cookies(request: Request, response: Response,
aurtz: str = None, aurlang: str = None,
aursid: str = None) -> Response:
def update_response_cookies(
request: Request,
response: Response,
aurtz: str = None,
aurlang: str = None,
aursid: str = None,
) -> Response:
"""Update session cookies. This method is particularly useful
when updating a cookie which was already set.
@ -53,14 +57,21 @@ def update_response_cookies(request: Request, response: Response,
"""
secure = config.getboolean("options", "disable_http_login")
if aurtz:
response.set_cookie("AURTZ", aurtz, secure=secure, httponly=secure,
samesite=samesite())
response.set_cookie(
"AURTZ", aurtz, secure=secure, httponly=secure, samesite=samesite()
)
if aurlang:
response.set_cookie("AURLANG", aurlang, secure=secure, httponly=secure,
samesite=samesite())
response.set_cookie(
"AURLANG", aurlang, secure=secure, httponly=secure, samesite=samesite()
)
if aursid:
remember_me = bool(request.cookies.get("AURREMEMBER", False))
response.set_cookie("AURSID", aursid, secure=secure, httponly=secure,
response.set_cookie(
"AURSID",
aursid,
secure=secure,
httponly=secure,
max_age=timeout(remember_me),
samesite=samesite())
samesite=samesite(),
)
return response

View file

@ -1,7 +1,5 @@
# Supported database drivers.
DRIVERS = {
"mysql": "mysql+mysqldb"
}
DRIVERS = {"mysql": "mysql+mysqldb"}
def make_random_value(table: str, column: str, length: int):
@ -10,6 +8,7 @@ def make_random_value(table: str, column: str, length: int):
:return: A unique string that is not in the database
"""
import aurweb.util
string = aurweb.util.make_random_string(length)
while query(table).filter(column == string).first():
string = aurweb.util.make_random_string(length)
@ -37,8 +36,7 @@ def test_name() -> str:
import aurweb.config
db = os.environ.get("PYTEST_CURRENT_TEST",
aurweb.config.get("database", "name"))
db = os.environ.get("PYTEST_CURRENT_TEST", aurweb.config.get("database", "name"))
return db.split(":")[0]
@ -57,6 +55,7 @@ def name() -> str:
return dbname
import hashlib
sha1 = hashlib.sha1(dbname.encode()).hexdigest()
return "db" + sha1
@ -78,7 +77,8 @@ def get_session(engine=None):
engine = get_engine()
Session = scoped_session(
sessionmaker(autocommit=True, autoflush=False, bind=engine))
sessionmaker(autocommit=True, autoflush=False, bind=engine)
)
_sessions[dbname] = Session()
return _sessions.get(dbname)
@ -140,6 +140,7 @@ def delete(model) -> None:
def delete_all(iterable) -> None:
"""Delete each instance found in `iterable`."""
import aurweb.util
session_ = get_session()
aurweb.util.apply_all(iterable, session_.delete)
@ -167,49 +168,49 @@ def get_sqlalchemy_url():
:return: sqlalchemy.engine.url.URL
"""
import sqlalchemy
from sqlalchemy.engine.url import URL
import aurweb.config
constructor = URL
parts = sqlalchemy.__version__.split('.')
parts = sqlalchemy.__version__.split(".")
major = int(parts[0])
minor = int(parts[1])
if major == 1 and minor >= 4: # pragma: no cover
constructor = URL.create
aur_db_backend = aurweb.config.get('database', 'backend')
if aur_db_backend == 'mysql':
aur_db_backend = aurweb.config.get("database", "backend")
if aur_db_backend == "mysql":
param_query = {}
port = aurweb.config.get_with_fallback("database", "port", None)
if not port:
param_query["unix_socket"] = aurweb.config.get(
"database", "socket")
param_query["unix_socket"] = aurweb.config.get("database", "socket")
return constructor(
DRIVERS.get(aur_db_backend),
username=aurweb.config.get('database', 'user'),
password=aurweb.config.get_with_fallback('database', 'password',
fallback=None),
host=aurweb.config.get('database', 'host'),
username=aurweb.config.get("database", "user"),
password=aurweb.config.get_with_fallback(
"database", "password", fallback=None
),
host=aurweb.config.get("database", "host"),
database=name(),
port=port,
query=param_query
query=param_query,
)
elif aur_db_backend == 'sqlite':
elif aur_db_backend == "sqlite":
return constructor(
'sqlite',
database=aurweb.config.get('database', 'name'),
"sqlite",
database=aurweb.config.get("database", "name"),
)
else:
raise ValueError('unsupported database backend')
raise ValueError("unsupported database backend")
def sqlite_regexp(regex, item) -> bool: # pragma: no cover
"""Method which mimics SQL's REGEXP for SQLite."""
import re
return bool(re.search(regex, str(item)))
@ -220,9 +221,9 @@ def setup_sqlite(engine) -> None: # pragma: no cover
@event.listens_for(engine, "connect")
def do_begin(conn, record):
import functools
create_deterministic_function = functools.partial(
conn.create_function,
deterministic=True
conn.create_function, deterministic=True
)
create_deterministic_function("REGEXP", 2, sqlite_regexp)
@ -256,11 +257,9 @@ def get_engine(dbname: str = None, echo: bool = False):
if is_sqlite: # pragma: no cover
connect_args["check_same_thread"] = False
kwargs = {
"echo": echo,
"connect_args": connect_args
}
kwargs = {"echo": echo, "connect_args": connect_args}
from sqlalchemy import create_engine
_engines[dbname] = create_engine(get_sqlalchemy_url(), **kwargs)
if is_sqlite: # pragma: no cover
@ -317,6 +316,7 @@ class ConnectionExecutor:
self._paramstyle = "format"
elif backend == "sqlite":
import sqlite3
self._paramstyle = sqlite3.paramstyle
def paramstyle(self):
@ -325,12 +325,12 @@ class ConnectionExecutor:
def execute(self, query, params=()): # pragma: no cover
# TODO: SQLite support has been removed in FastAPI. It remains
# here to fund its support for PHP until it is removed.
if self._paramstyle in ('format', 'pyformat'):
query = query.replace('%', '%%').replace('?', '%s')
elif self._paramstyle == 'qmark':
if self._paramstyle in ("format", "pyformat"):
query = query.replace("%", "%%").replace("?", "%s")
elif self._paramstyle == "qmark":
pass
else:
raise ValueError('unsupported paramstyle')
raise ValueError("unsupported paramstyle")
cur = self._conn.cursor()
cur.execute(query, params)
@ -350,32 +350,35 @@ class Connection:
def __init__(self):
import aurweb.config
aur_db_backend = aurweb.config.get('database', 'backend')
if aur_db_backend == 'mysql':
aur_db_backend = aurweb.config.get("database", "backend")
if aur_db_backend == "mysql":
import MySQLdb
aur_db_host = aurweb.config.get('database', 'host')
aur_db_host = aurweb.config.get("database", "host")
aur_db_name = name()
aur_db_user = aurweb.config.get('database', 'user')
aur_db_pass = aurweb.config.get_with_fallback(
'database', 'password', str())
aur_db_socket = aurweb.config.get('database', 'socket')
self._conn = MySQLdb.connect(host=aur_db_host,
aur_db_user = aurweb.config.get("database", "user")
aur_db_pass = aurweb.config.get_with_fallback("database", "password", str())
aur_db_socket = aurweb.config.get("database", "socket")
self._conn = MySQLdb.connect(
host=aur_db_host,
user=aur_db_user,
passwd=aur_db_pass,
db=aur_db_name,
unix_socket=aur_db_socket)
elif aur_db_backend == 'sqlite': # pragma: no cover
unix_socket=aur_db_socket,
)
elif aur_db_backend == "sqlite": # pragma: no cover
# TODO: SQLite support has been removed in FastAPI. It remains
# here to fund its support for PHP until it is removed.
import math
import sqlite3
aur_db_name = aurweb.config.get('database', 'name')
aur_db_name = aurweb.config.get("database", "name")
self._conn = sqlite3.connect(aur_db_name)
self._conn.create_function("POWER", 2, math.pow)
else:
raise ValueError('unsupported database backend')
raise ValueError("unsupported database backend")
self._conn = ConnectionExecutor(self._conn, aur_db_backend)

View file

@ -1,5 +1,4 @@
import functools
from typing import Any, Callable
import fastapi
@ -19,61 +18,61 @@ class BannedException(AurwebException):
class PermissionDeniedException(AurwebException):
def __init__(self, user):
msg = 'permission denied: {:s}'.format(user)
msg = "permission denied: {:s}".format(user)
super(PermissionDeniedException, self).__init__(msg)
class BrokenUpdateHookException(AurwebException):
def __init__(self, cmd):
msg = 'broken update hook: {:s}'.format(cmd)
msg = "broken update hook: {:s}".format(cmd)
super(BrokenUpdateHookException, self).__init__(msg)
class InvalidUserException(AurwebException):
def __init__(self, user):
msg = 'unknown user: {:s}'.format(user)
msg = "unknown user: {:s}".format(user)
super(InvalidUserException, self).__init__(msg)
class InvalidPackageBaseException(AurwebException):
def __init__(self, pkgbase):
msg = 'package base not found: {:s}'.format(pkgbase)
msg = "package base not found: {:s}".format(pkgbase)
super(InvalidPackageBaseException, self).__init__(msg)
class InvalidRepositoryNameException(AurwebException):
def __init__(self, pkgbase):
msg = 'invalid repository name: {:s}'.format(pkgbase)
msg = "invalid repository name: {:s}".format(pkgbase)
super(InvalidRepositoryNameException, self).__init__(msg)
class PackageBaseExistsException(AurwebException):
def __init__(self, pkgbase):
msg = 'package base already exists: {:s}'.format(pkgbase)
msg = "package base already exists: {:s}".format(pkgbase)
super(PackageBaseExistsException, self).__init__(msg)
class InvalidReasonException(AurwebException):
def __init__(self, reason):
msg = 'invalid reason: {:s}'.format(reason)
msg = "invalid reason: {:s}".format(reason)
super(InvalidReasonException, self).__init__(msg)
class InvalidCommentException(AurwebException):
def __init__(self, comment):
msg = 'comment is too short: {:s}'.format(comment)
msg = "comment is too short: {:s}".format(comment)
super(InvalidCommentException, self).__init__(msg)
class AlreadyVotedException(AurwebException):
def __init__(self, comment):
msg = 'already voted for package base: {:s}'.format(comment)
msg = "already voted for package base: {:s}".format(comment)
super(AlreadyVotedException, self).__init__(msg)
class NotVotedException(AurwebException):
def __init__(self, comment):
msg = 'missing vote for package base: {:s}'.format(comment)
msg = "missing vote for package base: {:s}".format(comment)
super(NotVotedException, self).__init__(msg)
@ -109,4 +108,5 @@ def handle_form_exceptions(route: Callable) -> fastapi.Response:
async def wrapper(request: fastapi.Request, *args, **kwargs):
request.state.form_data = await request.form()
return await route(request, *args, **kwargs)
return wrapper

View file

@ -1,6 +1,5 @@
import copy
import math
from datetime import datetime
from typing import Any, Union
from urllib.parse import quote_plus, urlencode
@ -8,19 +7,16 @@ from zoneinfo import ZoneInfo
import fastapi
import paginate
from jinja2 import pass_context
import aurweb.models
from aurweb import config, l10n
from aurweb.templates import register_filter, register_function
@register_filter("pager_nav")
@pass_context
def pager_nav(context: dict[str, Any],
page: int, total: int, prefix: str) -> str:
def pager_nav(context: dict[str, Any], page: int, total: int, prefix: str) -> str:
page = int(page) # Make sure this is an int.
pp = context.get("PP", 50)
@ -43,10 +39,9 @@ def pager_nav(context: dict[str, Any],
return f"{prefix}?{qs}"
# Use the paginate module to produce our linkage.
pager = paginate.Page([], page=page + 1,
items_per_page=pp,
item_count=total,
url_maker=create_url)
pager = paginate.Page(
[], page=page + 1, items_per_page=pp, item_count=total, url_maker=create_url
)
return pager.pager(
link_attr={"class": "page"},
@ -56,7 +51,8 @@ def pager_nav(context: dict[str, Any],
symbol_first="« First",
symbol_previous=" Previous",
symbol_next="Next ",
symbol_last="Last »")
symbol_last="Last »",
)
@register_function("config_getint")
@ -79,8 +75,7 @@ def tr(context: dict[str, Any], value: str):
@register_filter("tn")
@pass_context
def tn(context: dict[str, Any], count: int,
singular: str, plural: str) -> str:
def tn(context: dict[str, Any], count: int, singular: str, plural: str) -> str:
"""A singular and plural translation filter.
Example:
@ -123,6 +118,7 @@ def to_qs(query: dict[str, Any]) -> str:
@register_filter("get_vote")
def get_vote(voteinfo, request: fastapi.Request):
from aurweb.models import TUVote
return voteinfo.tu_votes.filter(TUVote.User == request.user).first()
@ -134,8 +130,7 @@ def number_format(value: float, places: int):
@register_filter("account_url")
@pass_context
def account_url(context: dict[str, Any],
user: "aurweb.models.user.User") -> str:
def account_url(context: dict[str, Any], user: "aurweb.models.user.User") -> str:
base = aurweb.config.get("options", "aur_location")
return f"{base}/account/{user.Username}"
@ -152,8 +147,7 @@ def ceil(*args, **kwargs) -> int:
@register_function("date_strftime")
@pass_context
def date_strftime(context: dict[str, Any], dt: Union[int, datetime], fmt: str) \
-> str:
def date_strftime(context: dict[str, Any], dt: Union[int, datetime], fmt: str) -> str:
if isinstance(dt, int):
dt = timestamp_to_datetime(dt)
tz = context.get("timezone")

View file

@ -9,12 +9,12 @@ import aurweb.db
def format_command(env_vars, command, ssh_opts, ssh_key):
environment = ''
environment = ""
for key, var in env_vars.items():
environment += '{}={} '.format(key, shlex.quote(var))
environment += "{}={} ".format(key, shlex.quote(var))
command = shlex.quote(command)
command = '{}{}'.format(environment, command)
command = "{}{}".format(environment, command)
# The command is being substituted into an authorized_keys line below,
# so we need to escape the double quotes.
@ -24,10 +24,10 @@ def format_command(env_vars, command, ssh_opts, ssh_key):
def main():
valid_keytypes = aurweb.config.get('auth', 'valid-keytypes').split()
username_regex = aurweb.config.get('auth', 'username-regex')
git_serve_cmd = aurweb.config.get('auth', 'git-serve-cmd')
ssh_opts = aurweb.config.get('auth', 'ssh-options')
valid_keytypes = aurweb.config.get("auth", "valid-keytypes").split()
username_regex = aurweb.config.get("auth", "username-regex")
git_serve_cmd = aurweb.config.get("auth", "git-serve-cmd")
ssh_opts = aurweb.config.get("auth", "ssh-options")
keytype = sys.argv[1]
keytext = sys.argv[2]
@ -36,11 +36,13 @@ def main():
conn = aurweb.db.Connection()
cur = conn.execute("SELECT Users.Username, Users.AccountTypeID FROM Users "
cur = conn.execute(
"SELECT Users.Username, Users.AccountTypeID FROM Users "
"INNER JOIN SSHPubKeys ON SSHPubKeys.UserID = Users.ID "
"WHERE SSHPubKeys.PubKey = ? AND Users.Suspended = 0 "
"AND NOT Users.Passwd = ''",
(keytype + " " + keytext,))
(keytype + " " + keytext,),
)
row = cur.fetchone()
if not row or cur.fetchone():
@ -51,13 +53,13 @@ def main():
exit(1)
env_vars = {
'AUR_USER': user,
'AUR_PRIVILEGED': '1' if account_type > 1 else '0',
"AUR_USER": user,
"AUR_PRIVILEGED": "1" if account_type > 1 else "0",
}
key = keytype + ' ' + keytext
key = keytype + " " + keytext
print(format_command(env_vars, git_serve_cmd, ssh_opts, key))
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -11,16 +11,16 @@ import aurweb.config
import aurweb.db
import aurweb.exceptions
notify_cmd = aurweb.config.get('notifications', 'notify-cmd')
notify_cmd = aurweb.config.get("notifications", "notify-cmd")
repo_path = aurweb.config.get('serve', 'repo-path')
repo_regex = aurweb.config.get('serve', 'repo-regex')
git_shell_cmd = aurweb.config.get('serve', 'git-shell-cmd')
git_update_cmd = aurweb.config.get('serve', 'git-update-cmd')
ssh_cmdline = aurweb.config.get('serve', 'ssh-cmdline')
repo_path = aurweb.config.get("serve", "repo-path")
repo_regex = aurweb.config.get("serve", "repo-regex")
git_shell_cmd = aurweb.config.get("serve", "git-shell-cmd")
git_update_cmd = aurweb.config.get("serve", "git-update-cmd")
ssh_cmdline = aurweb.config.get("serve", "ssh-cmdline")
enable_maintenance = aurweb.config.getboolean('options', 'enable-maintenance')
maintenance_exc = aurweb.config.get('options', 'maintenance-exceptions').split()
enable_maintenance = aurweb.config.getboolean("options", "enable-maintenance")
maintenance_exc = aurweb.config.get("options", "maintenance-exceptions").split()
def pkgbase_from_name(pkgbase):
@ -43,10 +43,12 @@ def list_repos(user):
if userid == 0:
raise aurweb.exceptions.InvalidUserException(user)
cur = conn.execute("SELECT Name, PackagerUID FROM PackageBases " +
"WHERE MaintainerUID = ?", [userid])
cur = conn.execute(
"SELECT Name, PackagerUID FROM PackageBases " + "WHERE MaintainerUID = ?",
[userid],
)
for row in cur:
print((' ' if row[1] else '*') + row[0])
print((" " if row[1] else "*") + row[0])
conn.close()
@ -64,15 +66,18 @@ def create_pkgbase(pkgbase, user):
raise aurweb.exceptions.InvalidUserException(user)
now = int(time.time())
cur = conn.execute("INSERT INTO PackageBases (Name, SubmittedTS, " +
"ModifiedTS, SubmitterUID, MaintainerUID, " +
"FlaggerComment) VALUES (?, ?, ?, ?, ?, '')",
[pkgbase, now, now, userid, userid])
cur = conn.execute(
"INSERT INTO PackageBases (Name, SubmittedTS, "
+ "ModifiedTS, SubmitterUID, MaintainerUID, "
+ "FlaggerComment) VALUES (?, ?, ?, ?, ?, '')",
[pkgbase, now, now, userid, userid],
)
pkgbase_id = cur.lastrowid
cur = conn.execute("INSERT INTO PackageNotifications " +
"(PackageBaseID, UserID) VALUES (?, ?)",
[pkgbase_id, userid])
cur = conn.execute(
"INSERT INTO PackageNotifications " + "(PackageBaseID, UserID) VALUES (?, ?)",
[pkgbase_id, userid],
)
conn.commit()
conn.close()
@ -85,8 +90,10 @@ def pkgbase_adopt(pkgbase, user, privileged):
conn = aurweb.db.Connection()
cur = conn.execute("SELECT ID FROM PackageBases WHERE ID = ? AND " +
"MaintainerUID IS NULL", [pkgbase_id])
cur = conn.execute(
"SELECT ID FROM PackageBases WHERE ID = ? AND " + "MaintainerUID IS NULL",
[pkgbase_id],
)
if not privileged and not cur.fetchone():
raise aurweb.exceptions.PermissionDeniedException(user)
@ -95,19 +102,25 @@ def pkgbase_adopt(pkgbase, user, privileged):
if userid == 0:
raise aurweb.exceptions.InvalidUserException(user)
cur = conn.execute("UPDATE PackageBases SET MaintainerUID = ? " +
"WHERE ID = ?", [userid, pkgbase_id])
cur = conn.execute(
"UPDATE PackageBases SET MaintainerUID = ? " + "WHERE ID = ?",
[userid, pkgbase_id],
)
cur = conn.execute("SELECT COUNT(*) FROM PackageNotifications WHERE " +
"PackageBaseID = ? AND UserID = ?",
[pkgbase_id, userid])
cur = conn.execute(
"SELECT COUNT(*) FROM PackageNotifications WHERE "
+ "PackageBaseID = ? AND UserID = ?",
[pkgbase_id, userid],
)
if cur.fetchone()[0] == 0:
cur = conn.execute("INSERT INTO PackageNotifications " +
"(PackageBaseID, UserID) VALUES (?, ?)",
[pkgbase_id, userid])
cur = conn.execute(
"INSERT INTO PackageNotifications "
+ "(PackageBaseID, UserID) VALUES (?, ?)",
[pkgbase_id, userid],
)
conn.commit()
subprocess.Popen((notify_cmd, 'adopt', str(userid), str(pkgbase_id)))
subprocess.Popen((notify_cmd, "adopt", str(userid), str(pkgbase_id)))
conn.close()
@ -115,13 +128,16 @@ def pkgbase_adopt(pkgbase, user, privileged):
def pkgbase_get_comaintainers(pkgbase):
conn = aurweb.db.Connection()
cur = conn.execute("SELECT UserName FROM PackageComaintainers " +
"INNER JOIN Users " +
"ON Users.ID = PackageComaintainers.UsersID " +
"INNER JOIN PackageBases " +
"ON PackageBases.ID = PackageComaintainers.PackageBaseID " +
"WHERE PackageBases.Name = ? " +
"ORDER BY Priority ASC", [pkgbase])
cur = conn.execute(
"SELECT UserName FROM PackageComaintainers "
+ "INNER JOIN Users "
+ "ON Users.ID = PackageComaintainers.UsersID "
+ "INNER JOIN PackageBases "
+ "ON PackageBases.ID = PackageComaintainers.PackageBaseID "
+ "WHERE PackageBases.Name = ? "
+ "ORDER BY Priority ASC",
[pkgbase],
)
return [row[0] for row in cur.fetchall()]
@ -140,8 +156,7 @@ def pkgbase_set_comaintainers(pkgbase, userlist, user, privileged):
uids_old = set()
for olduser in userlist_old:
cur = conn.execute("SELECT ID FROM Users WHERE Username = ?",
[olduser])
cur = conn.execute("SELECT ID FROM Users WHERE Username = ?", [olduser])
userid = cur.fetchone()[0]
if userid == 0:
raise aurweb.exceptions.InvalidUserException(user)
@ -149,8 +164,7 @@ def pkgbase_set_comaintainers(pkgbase, userlist, user, privileged):
uids_new = set()
for newuser in userlist:
cur = conn.execute("SELECT ID FROM Users WHERE Username = ?",
[newuser])
cur = conn.execute("SELECT ID FROM Users WHERE Username = ?", [newuser])
userid = cur.fetchone()[0]
if userid == 0:
raise aurweb.exceptions.InvalidUserException(user)
@ -162,24 +176,33 @@ def pkgbase_set_comaintainers(pkgbase, userlist, user, privileged):
i = 1
for userid in uids_new:
if userid in uids_add:
cur = conn.execute("INSERT INTO PackageComaintainers " +
"(PackageBaseID, UsersID, Priority) " +
"VALUES (?, ?, ?)", [pkgbase_id, userid, i])
subprocess.Popen((notify_cmd, 'comaintainer-add', str(userid),
str(pkgbase_id)))
cur = conn.execute(
"INSERT INTO PackageComaintainers "
+ "(PackageBaseID, UsersID, Priority) "
+ "VALUES (?, ?, ?)",
[pkgbase_id, userid, i],
)
subprocess.Popen(
(notify_cmd, "comaintainer-add", str(userid), str(pkgbase_id))
)
else:
cur = conn.execute("UPDATE PackageComaintainers " +
"SET Priority = ? " +
"WHERE PackageBaseID = ? AND UsersID = ?",
[i, pkgbase_id, userid])
cur = conn.execute(
"UPDATE PackageComaintainers "
+ "SET Priority = ? "
+ "WHERE PackageBaseID = ? AND UsersID = ?",
[i, pkgbase_id, userid],
)
i += 1
for userid in uids_rem:
cur = conn.execute("DELETE FROM PackageComaintainers " +
"WHERE PackageBaseID = ? AND UsersID = ?",
[pkgbase_id, userid])
subprocess.Popen((notify_cmd, 'comaintainer-remove',
str(userid), str(pkgbase_id)))
cur = conn.execute(
"DELETE FROM PackageComaintainers "
+ "WHERE PackageBaseID = ? AND UsersID = ?",
[pkgbase_id, userid],
)
subprocess.Popen(
(notify_cmd, "comaintainer-remove", str(userid), str(pkgbase_id))
)
conn.commit()
conn.close()
@ -188,18 +211,21 @@ def pkgbase_set_comaintainers(pkgbase, userlist, user, privileged):
def pkgreq_by_pkgbase(pkgbase_id, reqtype):
conn = aurweb.db.Connection()
cur = conn.execute("SELECT PackageRequests.ID FROM PackageRequests " +
"INNER JOIN RequestTypes ON " +
"RequestTypes.ID = PackageRequests.ReqTypeID " +
"WHERE PackageRequests.Status = 0 " +
"AND PackageRequests.PackageBaseID = ? " +
"AND RequestTypes.Name = ?", [pkgbase_id, reqtype])
cur = conn.execute(
"SELECT PackageRequests.ID FROM PackageRequests "
+ "INNER JOIN RequestTypes ON "
+ "RequestTypes.ID = PackageRequests.ReqTypeID "
+ "WHERE PackageRequests.Status = 0 "
+ "AND PackageRequests.PackageBaseID = ? "
+ "AND RequestTypes.Name = ?",
[pkgbase_id, reqtype],
)
return [row[0] for row in cur.fetchall()]
def pkgreq_close(reqid, user, reason, comments, autoclose=False):
statusmap = {'accepted': 2, 'rejected': 3}
statusmap = {"accepted": 2, "rejected": 3}
if reason not in statusmap:
raise aurweb.exceptions.InvalidReasonException(reason)
status = statusmap[reason]
@ -215,16 +241,20 @@ def pkgreq_close(reqid, user, reason, comments, autoclose=False):
raise aurweb.exceptions.InvalidUserException(user)
now = int(time.time())
conn.execute("UPDATE PackageRequests SET Status = ?, ClosedTS = ?, " +
"ClosedUID = ?, ClosureComment = ? " +
"WHERE ID = ?", [status, now, userid, comments, reqid])
conn.execute(
"UPDATE PackageRequests SET Status = ?, ClosedTS = ?, "
+ "ClosedUID = ?, ClosureComment = ? "
+ "WHERE ID = ?",
[status, now, userid, comments, reqid],
)
conn.commit()
conn.close()
if not userid:
userid = 0
subprocess.Popen((notify_cmd, 'request-close', str(userid), str(reqid),
reason)).wait()
subprocess.Popen(
(notify_cmd, "request-close", str(userid), str(reqid), reason)
).wait()
def pkgbase_disown(pkgbase, user, privileged):
@ -239,9 +269,9 @@ def pkgbase_disown(pkgbase, user, privileged):
# TODO: Support disowning package bases via package request.
# Scan through pending orphan requests and close them.
comment = 'The user {:s} disowned the package.'.format(user)
for reqid in pkgreq_by_pkgbase(pkgbase_id, 'orphan'):
pkgreq_close(reqid, user, 'accepted', comment, True)
comment = "The user {:s} disowned the package.".format(user)
for reqid in pkgreq_by_pkgbase(pkgbase_id, "orphan"):
pkgreq_close(reqid, user, "accepted", comment, True)
comaintainers = []
new_maintainer_userid = None
@ -254,14 +284,17 @@ def pkgbase_disown(pkgbase, user, privileged):
comaintainers = pkgbase_get_comaintainers(pkgbase)
if len(comaintainers) > 0:
new_maintainer = comaintainers[0]
cur = conn.execute("SELECT ID FROM Users WHERE Username = ?",
[new_maintainer])
cur = conn.execute(
"SELECT ID FROM Users WHERE Username = ?", [new_maintainer]
)
new_maintainer_userid = cur.fetchone()[0]
comaintainers.remove(new_maintainer)
pkgbase_set_comaintainers(pkgbase, comaintainers, user, privileged)
cur = conn.execute("UPDATE PackageBases SET MaintainerUID = ? " +
"WHERE ID = ?", [new_maintainer_userid, pkgbase_id])
cur = conn.execute(
"UPDATE PackageBases SET MaintainerUID = ? " + "WHERE ID = ?",
[new_maintainer_userid, pkgbase_id],
)
conn.commit()
@ -270,7 +303,7 @@ def pkgbase_disown(pkgbase, user, privileged):
if userid == 0:
raise aurweb.exceptions.InvalidUserException(user)
subprocess.Popen((notify_cmd, 'disown', str(userid), str(pkgbase_id)))
subprocess.Popen((notify_cmd, "disown", str(userid), str(pkgbase_id)))
conn.close()
@ -290,14 +323,16 @@ def pkgbase_flag(pkgbase, user, comment):
raise aurweb.exceptions.InvalidUserException(user)
now = int(time.time())
conn.execute("UPDATE PackageBases SET " +
"OutOfDateTS = ?, FlaggerUID = ?, FlaggerComment = ? " +
"WHERE ID = ? AND OutOfDateTS IS NULL",
[now, userid, comment, pkgbase_id])
conn.execute(
"UPDATE PackageBases SET "
+ "OutOfDateTS = ?, FlaggerUID = ?, FlaggerComment = ? "
+ "WHERE ID = ? AND OutOfDateTS IS NULL",
[now, userid, comment, pkgbase_id],
)
conn.commit()
subprocess.Popen((notify_cmd, 'flag', str(userid), str(pkgbase_id)))
subprocess.Popen((notify_cmd, "flag", str(userid), str(pkgbase_id)))
def pkgbase_unflag(pkgbase, user):
@ -313,12 +348,15 @@ def pkgbase_unflag(pkgbase, user):
raise aurweb.exceptions.InvalidUserException(user)
if user in pkgbase_get_comaintainers(pkgbase):
conn.execute("UPDATE PackageBases SET OutOfDateTS = NULL " +
"WHERE ID = ?", [pkgbase_id])
conn.execute(
"UPDATE PackageBases SET OutOfDateTS = NULL " + "WHERE ID = ?", [pkgbase_id]
)
else:
conn.execute("UPDATE PackageBases SET OutOfDateTS = NULL " +
"WHERE ID = ? AND (MaintainerUID = ? OR FlaggerUID = ?)",
[pkgbase_id, userid, userid])
conn.execute(
"UPDATE PackageBases SET OutOfDateTS = NULL "
+ "WHERE ID = ? AND (MaintainerUID = ? OR FlaggerUID = ?)",
[pkgbase_id, userid, userid],
)
conn.commit()
@ -335,17 +373,24 @@ def pkgbase_vote(pkgbase, user):
if userid == 0:
raise aurweb.exceptions.InvalidUserException(user)
cur = conn.execute("SELECT COUNT(*) FROM PackageVotes " +
"WHERE UsersID = ? AND PackageBaseID = ?",
[userid, pkgbase_id])
cur = conn.execute(
"SELECT COUNT(*) FROM PackageVotes "
+ "WHERE UsersID = ? AND PackageBaseID = ?",
[userid, pkgbase_id],
)
if cur.fetchone()[0] > 0:
raise aurweb.exceptions.AlreadyVotedException(pkgbase)
now = int(time.time())
conn.execute("INSERT INTO PackageVotes (UsersID, PackageBaseID, VoteTS) " +
"VALUES (?, ?, ?)", [userid, pkgbase_id, now])
conn.execute("UPDATE PackageBases SET NumVotes = NumVotes + 1 " +
"WHERE ID = ?", [pkgbase_id])
conn.execute(
"INSERT INTO PackageVotes (UsersID, PackageBaseID, VoteTS) "
+ "VALUES (?, ?, ?)",
[userid, pkgbase_id, now],
)
conn.execute(
"UPDATE PackageBases SET NumVotes = NumVotes + 1 " + "WHERE ID = ?",
[pkgbase_id],
)
conn.commit()
@ -361,16 +406,22 @@ def pkgbase_unvote(pkgbase, user):
if userid == 0:
raise aurweb.exceptions.InvalidUserException(user)
cur = conn.execute("SELECT COUNT(*) FROM PackageVotes " +
"WHERE UsersID = ? AND PackageBaseID = ?",
[userid, pkgbase_id])
cur = conn.execute(
"SELECT COUNT(*) FROM PackageVotes "
+ "WHERE UsersID = ? AND PackageBaseID = ?",
[userid, pkgbase_id],
)
if cur.fetchone()[0] == 0:
raise aurweb.exceptions.NotVotedException(pkgbase)
conn.execute("DELETE FROM PackageVotes WHERE UsersID = ? AND " +
"PackageBaseID = ?", [userid, pkgbase_id])
conn.execute("UPDATE PackageBases SET NumVotes = NumVotes - 1 " +
"WHERE ID = ?", [pkgbase_id])
conn.execute(
"DELETE FROM PackageVotes WHERE UsersID = ? AND " + "PackageBaseID = ?",
[userid, pkgbase_id],
)
conn.execute(
"UPDATE PackageBases SET NumVotes = NumVotes - 1 " + "WHERE ID = ?",
[pkgbase_id],
)
conn.commit()
@ -381,11 +432,12 @@ def pkgbase_set_keywords(pkgbase, keywords):
conn = aurweb.db.Connection()
conn.execute("DELETE FROM PackageKeywords WHERE PackageBaseID = ?",
[pkgbase_id])
conn.execute("DELETE FROM PackageKeywords WHERE PackageBaseID = ?", [pkgbase_id])
for keyword in keywords:
conn.execute("INSERT INTO PackageKeywords (PackageBaseID, Keyword) " +
"VALUES (?, ?)", [pkgbase_id, keyword])
conn.execute(
"INSERT INTO PackageKeywords (PackageBaseID, Keyword) " + "VALUES (?, ?)",
[pkgbase_id, keyword],
)
conn.commit()
conn.close()
@ -394,24 +446,30 @@ def pkgbase_set_keywords(pkgbase, keywords):
def pkgbase_has_write_access(pkgbase, user):
conn = aurweb.db.Connection()
cur = conn.execute("SELECT COUNT(*) FROM PackageBases " +
"LEFT JOIN PackageComaintainers " +
"ON PackageComaintainers.PackageBaseID = PackageBases.ID " +
"INNER JOIN Users " +
"ON Users.ID = PackageBases.MaintainerUID " +
"OR PackageBases.MaintainerUID IS NULL " +
"OR Users.ID = PackageComaintainers.UsersID " +
"WHERE Name = ? AND Username = ?", [pkgbase, user])
cur = conn.execute(
"SELECT COUNT(*) FROM PackageBases "
+ "LEFT JOIN PackageComaintainers "
+ "ON PackageComaintainers.PackageBaseID = PackageBases.ID "
+ "INNER JOIN Users "
+ "ON Users.ID = PackageBases.MaintainerUID "
+ "OR PackageBases.MaintainerUID IS NULL "
+ "OR Users.ID = PackageComaintainers.UsersID "
+ "WHERE Name = ? AND Username = ?",
[pkgbase, user],
)
return cur.fetchone()[0] > 0
def pkgbase_has_full_access(pkgbase, user):
conn = aurweb.db.Connection()
cur = conn.execute("SELECT COUNT(*) FROM PackageBases " +
"INNER JOIN Users " +
"ON Users.ID = PackageBases.MaintainerUID " +
"WHERE Name = ? AND Username = ?", [pkgbase, user])
cur = conn.execute(
"SELECT COUNT(*) FROM PackageBases "
+ "INNER JOIN Users "
+ "ON Users.ID = PackageBases.MaintainerUID "
+ "WHERE Name = ? AND Username = ?",
[pkgbase, user],
)
return cur.fetchone()[0] > 0
@ -419,9 +477,11 @@ def log_ssh_login(user, remote_addr):
conn = aurweb.db.Connection()
now = int(time.time())
conn.execute("UPDATE Users SET LastSSHLogin = ?, " +
"LastSSHLoginIPAddress = ? WHERE Username = ?",
[now, remote_addr, user])
conn.execute(
"UPDATE Users SET LastSSHLogin = ?, "
+ "LastSSHLoginIPAddress = ? WHERE Username = ?",
[now, remote_addr, user],
)
conn.commit()
conn.close()
@ -430,8 +490,7 @@ def log_ssh_login(user, remote_addr):
def bans_match(remote_addr):
conn = aurweb.db.Connection()
cur = conn.execute("SELECT COUNT(*) FROM Bans WHERE IPAddress = ?",
[remote_addr])
cur = conn.execute("SELECT COUNT(*) FROM Bans WHERE IPAddress = ?", [remote_addr])
return cur.fetchone()[0] > 0
@ -458,13 +517,13 @@ def usage(cmds):
def checkarg_atleast(cmdargv, *argdesc):
if len(cmdargv) - 1 < len(argdesc):
msg = 'missing {:s}'.format(argdesc[len(cmdargv) - 1])
msg = "missing {:s}".format(argdesc[len(cmdargv) - 1])
raise aurweb.exceptions.InvalidArgumentsException(msg)
def checkarg_atmost(cmdargv, *argdesc):
if len(cmdargv) - 1 > len(argdesc):
raise aurweb.exceptions.InvalidArgumentsException('too many arguments')
raise aurweb.exceptions.InvalidArgumentsException("too many arguments")
def checkarg(cmdargv, *argdesc):
@ -480,23 +539,23 @@ def serve(action, cmdargv, user, privileged, remote_addr): # noqa: C901
raise aurweb.exceptions.BannedException
log_ssh_login(user, remote_addr)
if action == 'git' and cmdargv[1] in ('upload-pack', 'receive-pack'):
action = action + '-' + cmdargv[1]
if action == "git" and cmdargv[1] in ("upload-pack", "receive-pack"):
action = action + "-" + cmdargv[1]
del cmdargv[1]
if action == 'git-upload-pack' or action == 'git-receive-pack':
checkarg(cmdargv, 'path')
if action == "git-upload-pack" or action == "git-receive-pack":
checkarg(cmdargv, "path")
path = cmdargv[1].rstrip('/')
if not path.startswith('/'):
path = '/' + path
if not path.endswith('.git'):
path = path + '.git'
path = cmdargv[1].rstrip("/")
if not path.startswith("/"):
path = "/" + path
if not path.endswith(".git"):
path = path + ".git"
pkgbase = path[1:-4]
if not re.match(repo_regex, pkgbase):
raise aurweb.exceptions.InvalidRepositoryNameException(pkgbase)
if action == 'git-receive-pack' and pkgbase_exists(pkgbase):
if action == "git-receive-pack" and pkgbase_exists(pkgbase):
if not privileged and not pkgbase_has_write_access(pkgbase, user):
raise aurweb.exceptions.PermissionDeniedException(user)
@ -507,65 +566,67 @@ def serve(action, cmdargv, user, privileged, remote_addr): # noqa: C901
os.environ["AUR_PKGBASE"] = pkgbase
os.environ["GIT_NAMESPACE"] = pkgbase
cmd = action + " '" + repo_path + "'"
os.execl(git_shell_cmd, git_shell_cmd, '-c', cmd)
elif action == 'set-keywords':
checkarg_atleast(cmdargv, 'repository name')
os.execl(git_shell_cmd, git_shell_cmd, "-c", cmd)
elif action == "set-keywords":
checkarg_atleast(cmdargv, "repository name")
pkgbase_set_keywords(cmdargv[1], cmdargv[2:])
elif action == 'list-repos':
elif action == "list-repos":
checkarg(cmdargv)
list_repos(user)
elif action == 'setup-repo':
checkarg(cmdargv, 'repository name')
warn('{:s} is deprecated. '
'Use `git push` to create new repositories.'.format(action))
elif action == "setup-repo":
checkarg(cmdargv, "repository name")
warn(
"{:s} is deprecated. "
"Use `git push` to create new repositories.".format(action)
)
create_pkgbase(cmdargv[1], user)
elif action == 'restore':
checkarg(cmdargv, 'repository name')
elif action == "restore":
checkarg(cmdargv, "repository name")
pkgbase = cmdargv[1]
create_pkgbase(pkgbase, user)
os.environ["AUR_USER"] = user
os.environ["AUR_PKGBASE"] = pkgbase
os.execl(git_update_cmd, git_update_cmd, 'restore')
elif action == 'adopt':
checkarg(cmdargv, 'repository name')
os.execl(git_update_cmd, git_update_cmd, "restore")
elif action == "adopt":
checkarg(cmdargv, "repository name")
pkgbase = cmdargv[1]
pkgbase_adopt(pkgbase, user, privileged)
elif action == 'disown':
checkarg(cmdargv, 'repository name')
elif action == "disown":
checkarg(cmdargv, "repository name")
pkgbase = cmdargv[1]
pkgbase_disown(pkgbase, user, privileged)
elif action == 'flag':
checkarg(cmdargv, 'repository name', 'comment')
elif action == "flag":
checkarg(cmdargv, "repository name", "comment")
pkgbase = cmdargv[1]
comment = cmdargv[2]
pkgbase_flag(pkgbase, user, comment)
elif action == 'unflag':
checkarg(cmdargv, 'repository name')
elif action == "unflag":
checkarg(cmdargv, "repository name")
pkgbase = cmdargv[1]
pkgbase_unflag(pkgbase, user)
elif action == 'vote':
checkarg(cmdargv, 'repository name')
elif action == "vote":
checkarg(cmdargv, "repository name")
pkgbase = cmdargv[1]
pkgbase_vote(pkgbase, user)
elif action == 'unvote':
checkarg(cmdargv, 'repository name')
elif action == "unvote":
checkarg(cmdargv, "repository name")
pkgbase = cmdargv[1]
pkgbase_unvote(pkgbase, user)
elif action == 'set-comaintainers':
checkarg_atleast(cmdargv, 'repository name')
elif action == "set-comaintainers":
checkarg_atleast(cmdargv, "repository name")
pkgbase = cmdargv[1]
userlist = cmdargv[2:]
pkgbase_set_comaintainers(pkgbase, userlist, user, privileged)
elif action == 'help':
elif action == "help":
cmds = {
"adopt <name>": "Adopt a package base.",
"disown <name>": "Disown a package base.",
@ -584,21 +645,21 @@ def serve(action, cmdargv, user, privileged, remote_addr): # noqa: C901
}
usage(cmds)
else:
msg = 'invalid command: {:s}'.format(action)
msg = "invalid command: {:s}".format(action)
raise aurweb.exceptions.InvalidArgumentsException(msg)
def main():
user = os.environ.get('AUR_USER')
privileged = (os.environ.get('AUR_PRIVILEGED', '0') == '1')
ssh_cmd = os.environ.get('SSH_ORIGINAL_COMMAND')
ssh_client = os.environ.get('SSH_CLIENT')
user = os.environ.get("AUR_USER")
privileged = os.environ.get("AUR_PRIVILEGED", "0") == "1"
ssh_cmd = os.environ.get("SSH_ORIGINAL_COMMAND")
ssh_client = os.environ.get("SSH_CLIENT")
if not ssh_cmd:
die_with_help("Interactive shell is disabled.")
cmdargv = shlex.split(ssh_cmd)
action = cmdargv[0]
remote_addr = ssh_client.split(' ')[0] if ssh_client else None
remote_addr = ssh_client.split(" ")[0] if ssh_client else None
try:
serve(action, cmdargv, user, privileged, remote_addr)
@ -607,10 +668,10 @@ def main():
except aurweb.exceptions.BannedException:
die("The SSH interface is disabled for your IP address.")
except aurweb.exceptions.InvalidArgumentsException as e:
die_with_help('{:s}: {}'.format(action, e))
die_with_help("{:s}: {}".format(action, e))
except aurweb.exceptions.AurwebException as e:
die('{:s}: {}'.format(action, e))
die("{:s}: {}".format(action, e))
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -13,23 +13,23 @@ import srcinfo.utils
import aurweb.config
import aurweb.db
notify_cmd = aurweb.config.get('notifications', 'notify-cmd')
notify_cmd = aurweb.config.get("notifications", "notify-cmd")
repo_path = aurweb.config.get('serve', 'repo-path')
repo_regex = aurweb.config.get('serve', 'repo-regex')
repo_path = aurweb.config.get("serve", "repo-path")
repo_regex = aurweb.config.get("serve", "repo-regex")
max_blob_size = aurweb.config.getint('update', 'max-blob-size')
max_blob_size = aurweb.config.getint("update", "max-blob-size")
def size_humanize(num):
for unit in ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB']:
for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB"]:
if abs(num) < 2048.0:
if isinstance(num, int):
return "{}{}".format(num, unit)
else:
return "{:.2f}{}".format(num, unit)
num /= 1024.0
return "{:.2f}{}".format(num, 'YiB')
return "{:.2f}{}".format(num, "YiB")
def extract_arch_fields(pkginfo, field):
@ -39,17 +39,17 @@ def extract_arch_fields(pkginfo, field):
for val in pkginfo[field]:
values.append({"value": val, "arch": None})
for arch in pkginfo['arch']:
if field + '_' + arch in pkginfo:
for val in pkginfo[field + '_' + arch]:
for arch in pkginfo["arch"]:
if field + "_" + arch in pkginfo:
for val in pkginfo[field + "_" + arch]:
values.append({"value": val, "arch": arch})
return values
def parse_dep(depstring):
dep, _, desc = depstring.partition(': ')
depname = re.sub(r'(<|=|>).*', '', dep)
dep, _, desc = depstring.partition(": ")
depname = re.sub(r"(<|=|>).*", "", dep)
depcond = dep[len(depname) :]
return (depname, desc, depcond)
@ -60,15 +60,18 @@ def create_pkgbase(conn, pkgbase, user):
userid = cur.fetchone()[0]
now = int(time.time())
cur = conn.execute("INSERT INTO PackageBases (Name, SubmittedTS, " +
"ModifiedTS, SubmitterUID, MaintainerUID, " +
"FlaggerComment) VALUES (?, ?, ?, ?, ?, '')",
[pkgbase, now, now, userid, userid])
cur = conn.execute(
"INSERT INTO PackageBases (Name, SubmittedTS, "
+ "ModifiedTS, SubmitterUID, MaintainerUID, "
+ "FlaggerComment) VALUES (?, ?, ?, ?, ?, '')",
[pkgbase, now, now, userid, userid],
)
pkgbase_id = cur.lastrowid
cur = conn.execute("INSERT INTO PackageNotifications " +
"(PackageBaseID, UserID) VALUES (?, ?)",
[pkgbase_id, userid])
cur = conn.execute(
"INSERT INTO PackageNotifications " + "(PackageBaseID, UserID) VALUES (?, ?)",
[pkgbase_id, userid],
)
conn.commit()
@ -77,9 +80,10 @@ def create_pkgbase(conn, pkgbase, user):
def save_metadata(metadata, conn, user): # noqa: C901
# Obtain package base ID and previous maintainer.
pkgbase = metadata['pkgbase']
cur = conn.execute("SELECT ID, MaintainerUID FROM PackageBases "
"WHERE Name = ?", [pkgbase])
pkgbase = metadata["pkgbase"]
cur = conn.execute(
"SELECT ID, MaintainerUID FROM PackageBases " "WHERE Name = ?", [pkgbase]
)
(pkgbase_id, maintainer_uid) = cur.fetchone()
was_orphan = not maintainer_uid
@ -89,119 +93,142 @@ def save_metadata(metadata, conn, user): # noqa: C901
# Update package base details and delete current packages.
now = int(time.time())
conn.execute("UPDATE PackageBases SET ModifiedTS = ?, " +
"PackagerUID = ?, OutOfDateTS = NULL WHERE ID = ?",
[now, user_id, pkgbase_id])
conn.execute("UPDATE PackageBases SET MaintainerUID = ? " +
"WHERE ID = ? AND MaintainerUID IS NULL",
[user_id, pkgbase_id])
for table in ('Sources', 'Depends', 'Relations', 'Licenses', 'Groups'):
conn.execute("DELETE FROM Package" + table + " WHERE EXISTS (" +
"SELECT * FROM Packages " +
"WHERE Packages.PackageBaseID = ? AND " +
"Package" + table + ".PackageID = Packages.ID)",
[pkgbase_id])
conn.execute(
"UPDATE PackageBases SET ModifiedTS = ?, "
+ "PackagerUID = ?, OutOfDateTS = NULL WHERE ID = ?",
[now, user_id, pkgbase_id],
)
conn.execute(
"UPDATE PackageBases SET MaintainerUID = ? "
+ "WHERE ID = ? AND MaintainerUID IS NULL",
[user_id, pkgbase_id],
)
for table in ("Sources", "Depends", "Relations", "Licenses", "Groups"):
conn.execute(
"DELETE FROM Package"
+ table
+ " WHERE EXISTS ("
+ "SELECT * FROM Packages "
+ "WHERE Packages.PackageBaseID = ? AND "
+ "Package"
+ table
+ ".PackageID = Packages.ID)",
[pkgbase_id],
)
conn.execute("DELETE FROM Packages WHERE PackageBaseID = ?", [pkgbase_id])
for pkgname in srcinfo.utils.get_package_names(metadata):
pkginfo = srcinfo.utils.get_merged_package(pkgname, metadata)
if 'epoch' in pkginfo and int(pkginfo['epoch']) > 0:
ver = '{:d}:{:s}-{:s}'.format(int(pkginfo['epoch']),
pkginfo['pkgver'],
pkginfo['pkgrel'])
if "epoch" in pkginfo and int(pkginfo["epoch"]) > 0:
ver = "{:d}:{:s}-{:s}".format(
int(pkginfo["epoch"]), pkginfo["pkgver"], pkginfo["pkgrel"]
)
else:
ver = '{:s}-{:s}'.format(pkginfo['pkgver'], pkginfo['pkgrel'])
ver = "{:s}-{:s}".format(pkginfo["pkgver"], pkginfo["pkgrel"])
for field in ('pkgdesc', 'url'):
for field in ("pkgdesc", "url"):
if field not in pkginfo:
pkginfo[field] = None
# Create a new package.
cur = conn.execute("INSERT INTO Packages (PackageBaseID, Name, " +
"Version, Description, URL) " +
"VALUES (?, ?, ?, ?, ?)",
[pkgbase_id, pkginfo['pkgname'], ver,
pkginfo['pkgdesc'], pkginfo['url']])
cur = conn.execute(
"INSERT INTO Packages (PackageBaseID, Name, "
+ "Version, Description, URL) "
+ "VALUES (?, ?, ?, ?, ?)",
[pkgbase_id, pkginfo["pkgname"], ver, pkginfo["pkgdesc"], pkginfo["url"]],
)
conn.commit()
pkgid = cur.lastrowid
# Add package sources.
for source_info in extract_arch_fields(pkginfo, 'source'):
conn.execute("INSERT INTO PackageSources (PackageID, Source, " +
"SourceArch) VALUES (?, ?, ?)",
[pkgid, source_info['value'], source_info['arch']])
for source_info in extract_arch_fields(pkginfo, "source"):
conn.execute(
"INSERT INTO PackageSources (PackageID, Source, "
+ "SourceArch) VALUES (?, ?, ?)",
[pkgid, source_info["value"], source_info["arch"]],
)
# Add package dependencies.
for deptype in ('depends', 'makedepends',
'checkdepends', 'optdepends'):
cur = conn.execute("SELECT ID FROM DependencyTypes WHERE Name = ?",
[deptype])
for deptype in ("depends", "makedepends", "checkdepends", "optdepends"):
cur = conn.execute(
"SELECT ID FROM DependencyTypes WHERE Name = ?", [deptype]
)
deptypeid = cur.fetchone()[0]
for dep_info in extract_arch_fields(pkginfo, deptype):
depname, depdesc, depcond = parse_dep(dep_info['value'])
deparch = dep_info['arch']
conn.execute("INSERT INTO PackageDepends (PackageID, " +
"DepTypeID, DepName, DepDesc, DepCondition, " +
"DepArch) VALUES (?, ?, ?, ?, ?, ?)",
[pkgid, deptypeid, depname, depdesc, depcond,
deparch])
depname, depdesc, depcond = parse_dep(dep_info["value"])
deparch = dep_info["arch"]
conn.execute(
"INSERT INTO PackageDepends (PackageID, "
+ "DepTypeID, DepName, DepDesc, DepCondition, "
+ "DepArch) VALUES (?, ?, ?, ?, ?, ?)",
[pkgid, deptypeid, depname, depdesc, depcond, deparch],
)
# Add package relations (conflicts, provides, replaces).
for reltype in ('conflicts', 'provides', 'replaces'):
cur = conn.execute("SELECT ID FROM RelationTypes WHERE Name = ?",
[reltype])
for reltype in ("conflicts", "provides", "replaces"):
cur = conn.execute("SELECT ID FROM RelationTypes WHERE Name = ?", [reltype])
reltypeid = cur.fetchone()[0]
for rel_info in extract_arch_fields(pkginfo, reltype):
relname, _, relcond = parse_dep(rel_info['value'])
relarch = rel_info['arch']
conn.execute("INSERT INTO PackageRelations (PackageID, " +
"RelTypeID, RelName, RelCondition, RelArch) " +
"VALUES (?, ?, ?, ?, ?)",
[pkgid, reltypeid, relname, relcond, relarch])
relname, _, relcond = parse_dep(rel_info["value"])
relarch = rel_info["arch"]
conn.execute(
"INSERT INTO PackageRelations (PackageID, "
+ "RelTypeID, RelName, RelCondition, RelArch) "
+ "VALUES (?, ?, ?, ?, ?)",
[pkgid, reltypeid, relname, relcond, relarch],
)
# Add package licenses.
if 'license' in pkginfo:
for license in pkginfo['license']:
cur = conn.execute("SELECT ID FROM Licenses WHERE Name = ?",
[license])
if "license" in pkginfo:
for license in pkginfo["license"]:
cur = conn.execute("SELECT ID FROM Licenses WHERE Name = ?", [license])
row = cur.fetchone()
if row:
licenseid = row[0]
else:
cur = conn.execute("INSERT INTO Licenses (Name) " +
"VALUES (?)", [license])
cur = conn.execute(
"INSERT INTO Licenses (Name) " + "VALUES (?)", [license]
)
conn.commit()
licenseid = cur.lastrowid
conn.execute("INSERT INTO PackageLicenses (PackageID, " +
"LicenseID) VALUES (?, ?)",
[pkgid, licenseid])
conn.execute(
"INSERT INTO PackageLicenses (PackageID, "
+ "LicenseID) VALUES (?, ?)",
[pkgid, licenseid],
)
# Add package groups.
if 'groups' in pkginfo:
for group in pkginfo['groups']:
cur = conn.execute("SELECT ID FROM `Groups` WHERE Name = ?",
[group])
if "groups" in pkginfo:
for group in pkginfo["groups"]:
cur = conn.execute("SELECT ID FROM `Groups` WHERE Name = ?", [group])
row = cur.fetchone()
if row:
groupid = row[0]
else:
cur = conn.execute("INSERT INTO `Groups` (Name) VALUES (?)",
[group])
cur = conn.execute(
"INSERT INTO `Groups` (Name) VALUES (?)", [group]
)
conn.commit()
groupid = cur.lastrowid
conn.execute("INSERT INTO PackageGroups (PackageID, "
"GroupID) VALUES (?, ?)", [pkgid, groupid])
conn.execute(
"INSERT INTO PackageGroups (PackageID, " "GroupID) VALUES (?, ?)",
[pkgid, groupid],
)
# Add user to notification list on adoption.
if was_orphan:
cur = conn.execute("SELECT COUNT(*) FROM PackageNotifications WHERE " +
"PackageBaseID = ? AND UserID = ?",
[pkgbase_id, user_id])
cur = conn.execute(
"SELECT COUNT(*) FROM PackageNotifications WHERE "
+ "PackageBaseID = ? AND UserID = ?",
[pkgbase_id, user_id],
)
if cur.fetchone()[0] == 0:
conn.execute("INSERT INTO PackageNotifications " +
"(PackageBaseID, UserID) VALUES (?, ?)",
[pkgbase_id, user_id])
conn.execute(
"INSERT INTO PackageNotifications "
+ "(PackageBaseID, UserID) VALUES (?, ?)",
[pkgbase_id, user_id],
)
conn.commit()
@ -212,7 +239,7 @@ def update_notify(conn, user, pkgbase_id):
user_id = int(cur.fetchone()[0])
# Execute the notification script.
subprocess.Popen((notify_cmd, 'update', str(user_id), str(pkgbase_id)))
subprocess.Popen((notify_cmd, "update", str(user_id), str(pkgbase_id)))
def die(msg):
@ -225,8 +252,7 @@ def warn(msg):
def die_commit(msg, commit):
sys.stderr.write("error: The following error " +
"occurred when parsing commit\n")
sys.stderr.write("error: The following error " + "occurred when parsing commit\n")
sys.stderr.write("error: {:s}:\n".format(commit))
sys.stderr.write("error: {:s}\n".format(msg))
exit(1)
@ -237,16 +263,15 @@ def main(): # noqa: C901
user = os.environ.get("AUR_USER")
pkgbase = os.environ.get("AUR_PKGBASE")
privileged = (os.environ.get("AUR_PRIVILEGED", '0') == '1')
allow_overwrite = (os.environ.get("AUR_OVERWRITE", '0') == '1') and privileged
privileged = os.environ.get("AUR_PRIVILEGED", "0") == "1"
allow_overwrite = (os.environ.get("AUR_OVERWRITE", "0") == "1") and privileged
warn_or_die = warn if privileged else die
if len(sys.argv) == 2 and sys.argv[1] == "restore":
if 'refs/heads/' + pkgbase not in repo.listall_references():
die('{:s}: repository not found: {:s}'.format(sys.argv[1],
pkgbase))
if "refs/heads/" + pkgbase not in repo.listall_references():
die("{:s}: repository not found: {:s}".format(sys.argv[1], pkgbase))
refname = "refs/heads/master"
branchref = 'refs/heads/' + pkgbase
branchref = "refs/heads/" + pkgbase
sha1_old = sha1_new = repo.lookup_reference(branchref).target
elif len(sys.argv) == 4:
refname, sha1_old, sha1_new = sys.argv[1:4]
@ -272,7 +297,7 @@ def main(): # noqa: C901
# Validate all new commits.
for commit in walker:
for fname in ('.SRCINFO', 'PKGBUILD'):
for fname in (".SRCINFO", "PKGBUILD"):
if fname not in commit.tree:
die_commit("missing {:s}".format(fname), str(commit.id))
@ -280,99 +305,115 @@ def main(): # noqa: C901
blob = repo[treeobj.id]
if isinstance(blob, pygit2.Tree):
die_commit("the repository must not contain subdirectories",
str(commit.id))
die_commit(
"the repository must not contain subdirectories", str(commit.id)
)
if not isinstance(blob, pygit2.Blob):
die_commit("not a blob object: {:s}".format(treeobj),
str(commit.id))
die_commit("not a blob object: {:s}".format(treeobj), str(commit.id))
if blob.size > max_blob_size:
die_commit("maximum blob size ({:s}) exceeded".format(
size_humanize(max_blob_size)), str(commit.id))
die_commit(
"maximum blob size ({:s}) exceeded".format(
size_humanize(max_blob_size)
),
str(commit.id),
)
metadata_raw = repo[commit.tree['.SRCINFO'].id].data.decode()
metadata_raw = repo[commit.tree[".SRCINFO"].id].data.decode()
(metadata, errors) = srcinfo.parse.parse_srcinfo(metadata_raw)
if errors:
sys.stderr.write("error: The following errors occurred "
"when parsing .SRCINFO in commit\n")
sys.stderr.write(
"error: The following errors occurred "
"when parsing .SRCINFO in commit\n"
)
sys.stderr.write("error: {:s}:\n".format(str(commit.id)))
for error in errors:
for err in error['error']:
sys.stderr.write("error: line {:d}: {:s}\n".format(
error['line'], err))
for err in error["error"]:
sys.stderr.write(
"error: line {:d}: {:s}\n".format(error["line"], err)
)
exit(1)
try:
metadata_pkgbase = metadata['pkgbase']
metadata_pkgbase = metadata["pkgbase"]
except KeyError:
die_commit('invalid .SRCINFO, does not contain a pkgbase (is the file empty?)',
str(commit.id))
die_commit(
"invalid .SRCINFO, does not contain a pkgbase (is the file empty?)",
str(commit.id),
)
if not re.match(repo_regex, metadata_pkgbase):
die_commit('invalid pkgbase: {:s}'.format(metadata_pkgbase),
str(commit.id))
die_commit("invalid pkgbase: {:s}".format(metadata_pkgbase), str(commit.id))
if not metadata['packages']:
die_commit('missing pkgname entry', str(commit.id))
if not metadata["packages"]:
die_commit("missing pkgname entry", str(commit.id))
for pkgname in set(metadata['packages'].keys()):
for pkgname in set(metadata["packages"].keys()):
pkginfo = srcinfo.utils.get_merged_package(pkgname, metadata)
for field in ('pkgver', 'pkgrel', 'pkgname'):
for field in ("pkgver", "pkgrel", "pkgname"):
if field not in pkginfo:
die_commit('missing mandatory field: {:s}'.format(field),
str(commit.id))
die_commit(
"missing mandatory field: {:s}".format(field), str(commit.id)
)
if 'epoch' in pkginfo and not pkginfo['epoch'].isdigit():
die_commit('invalid epoch: {:s}'.format(pkginfo['epoch']),
str(commit.id))
if "epoch" in pkginfo and not pkginfo["epoch"].isdigit():
die_commit(
"invalid epoch: {:s}".format(pkginfo["epoch"]), str(commit.id)
)
if not re.match(r'[a-z0-9][a-z0-9\.+_-]*$', pkginfo['pkgname']):
die_commit('invalid package name: {:s}'.format(
pkginfo['pkgname']), str(commit.id))
if not re.match(r"[a-z0-9][a-z0-9\.+_-]*$", pkginfo["pkgname"]):
die_commit(
"invalid package name: {:s}".format(pkginfo["pkgname"]),
str(commit.id),
)
max_len = {'pkgname': 255, 'pkgdesc': 255, 'url': 8000}
max_len = {"pkgname": 255, "pkgdesc": 255, "url": 8000}
for field in max_len.keys():
if field in pkginfo and len(pkginfo[field]) > max_len[field]:
die_commit('{:s} field too long: {:s}'.format(field,
pkginfo[field]), str(commit.id))
die_commit(
"{:s} field too long: {:s}".format(field, pkginfo[field]),
str(commit.id),
)
for field in ('install', 'changelog'):
for field in ("install", "changelog"):
if field in pkginfo and not pkginfo[field] in commit.tree:
die_commit('missing {:s} file: {:s}'.format(field,
pkginfo[field]), str(commit.id))
die_commit(
"missing {:s} file: {:s}".format(field, pkginfo[field]),
str(commit.id),
)
for field in extract_arch_fields(pkginfo, 'source'):
fname = field['value']
for field in extract_arch_fields(pkginfo, "source"):
fname = field["value"]
if len(fname) > 8000:
die_commit('source entry too long: {:s}'.format(fname),
str(commit.id))
die_commit(
"source entry too long: {:s}".format(fname), str(commit.id)
)
if "://" in fname or "lp:" in fname:
continue
if fname not in commit.tree:
die_commit('missing source file: {:s}'.format(fname),
str(commit.id))
die_commit(
"missing source file: {:s}".format(fname), str(commit.id)
)
# Display a warning if .SRCINFO is unchanged.
if sha1_old not in ("0000000000000000000000000000000000000000", sha1_new):
srcinfo_id_old = repo[sha1_old].tree['.SRCINFO'].id
srcinfo_id_new = repo[sha1_new].tree['.SRCINFO'].id
srcinfo_id_old = repo[sha1_old].tree[".SRCINFO"].id
srcinfo_id_new = repo[sha1_new].tree[".SRCINFO"].id
if srcinfo_id_old == srcinfo_id_new:
warn(".SRCINFO unchanged. "
"The package database will not be updated!")
warn(".SRCINFO unchanged. " "The package database will not be updated!")
# Read .SRCINFO from the HEAD commit.
metadata_raw = repo[repo[sha1_new].tree['.SRCINFO'].id].data.decode()
metadata_raw = repo[repo[sha1_new].tree[".SRCINFO"].id].data.decode()
(metadata, errors) = srcinfo.parse.parse_srcinfo(metadata_raw)
# Ensure that the package base name matches the repository name.
metadata_pkgbase = metadata['pkgbase']
metadata_pkgbase = metadata["pkgbase"]
if metadata_pkgbase != pkgbase:
die('invalid pkgbase: {:s}, expected {:s}'.format(metadata_pkgbase,
pkgbase))
die("invalid pkgbase: {:s}, expected {:s}".format(metadata_pkgbase, pkgbase))
# Ensure that packages are neither blacklisted nor overwritten.
pkgbase = metadata['pkgbase']
pkgbase = metadata["pkgbase"]
cur = conn.execute("SELECT ID FROM PackageBases WHERE Name = ?", [pkgbase])
row = cur.fetchone()
pkgbase_id = row[0] if row else 0
@ -385,18 +426,23 @@ def main(): # noqa: C901
for pkgname in srcinfo.utils.get_package_names(metadata):
pkginfo = srcinfo.utils.get_merged_package(pkgname, metadata)
pkgname = pkginfo['pkgname']
pkgname = pkginfo["pkgname"]
if pkgname in blacklist:
warn_or_die('package is blacklisted: {:s}'.format(pkgname))
warn_or_die("package is blacklisted: {:s}".format(pkgname))
if pkgname in providers:
warn_or_die('package already provided by [{:s}]: {:s}'.format(
providers[pkgname], pkgname))
warn_or_die(
"package already provided by [{:s}]: {:s}".format(
providers[pkgname], pkgname
)
)
cur = conn.execute("SELECT COUNT(*) FROM Packages WHERE Name = ? " +
"AND PackageBaseID <> ?", [pkgname, pkgbase_id])
cur = conn.execute(
"SELECT COUNT(*) FROM Packages WHERE Name = ? " + "AND PackageBaseID <> ?",
[pkgname, pkgbase_id],
)
if cur.fetchone()[0] > 0:
die('cannot overwrite package: {:s}'.format(pkgname))
die("cannot overwrite package: {:s}".format(pkgname))
# Create a new package base if it does not exist yet.
if pkgbase_id == 0:
@ -407,7 +453,7 @@ def main(): # noqa: C901
# Create (or update) a branch with the name of the package base for better
# accessibility.
branchref = 'refs/heads/' + pkgbase
branchref = "refs/heads/" + pkgbase
repo.create_reference(branchref, sha1_new, True)
# Work around a Git bug: The HEAD ref is not updated when using
@ -415,7 +461,7 @@ def main(): # noqa: C901
# mainline. See
# http://git.661346.n2.nabble.com/PATCH-receive-pack-Create-a-HEAD-ref-for-ref-namespace-td7632149.html
# for details.
headref = 'refs/namespaces/' + pkgbase + '/HEAD'
headref = "refs/namespaces/" + pkgbase + "/HEAD"
repo.create_reference(headref, sha1_new, True)
# Send package update notifications.
@ -426,5 +472,5 @@ def main(): # noqa: C901
conn.close()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -9,28 +9,40 @@ import aurweb.schema
def feed_initial_data(conn):
conn.execute(aurweb.schema.AccountTypes.insert(), [
{'ID': 1, 'AccountType': 'User'},
{'ID': 2, 'AccountType': 'Trusted User'},
{'ID': 3, 'AccountType': 'Developer'},
{'ID': 4, 'AccountType': 'Trusted User & Developer'},
])
conn.execute(aurweb.schema.DependencyTypes.insert(), [
{'ID': 1, 'Name': 'depends'},
{'ID': 2, 'Name': 'makedepends'},
{'ID': 3, 'Name': 'checkdepends'},
{'ID': 4, 'Name': 'optdepends'},
])
conn.execute(aurweb.schema.RelationTypes.insert(), [
{'ID': 1, 'Name': 'conflicts'},
{'ID': 2, 'Name': 'provides'},
{'ID': 3, 'Name': 'replaces'},
])
conn.execute(aurweb.schema.RequestTypes.insert(), [
{'ID': 1, 'Name': 'deletion'},
{'ID': 2, 'Name': 'orphan'},
{'ID': 3, 'Name': 'merge'},
])
conn.execute(
aurweb.schema.AccountTypes.insert(),
[
{"ID": 1, "AccountType": "User"},
{"ID": 2, "AccountType": "Trusted User"},
{"ID": 3, "AccountType": "Developer"},
{"ID": 4, "AccountType": "Trusted User & Developer"},
],
)
conn.execute(
aurweb.schema.DependencyTypes.insert(),
[
{"ID": 1, "Name": "depends"},
{"ID": 2, "Name": "makedepends"},
{"ID": 3, "Name": "checkdepends"},
{"ID": 4, "Name": "optdepends"},
],
)
conn.execute(
aurweb.schema.RelationTypes.insert(),
[
{"ID": 1, "Name": "conflicts"},
{"ID": 2, "Name": "provides"},
{"ID": 3, "Name": "replaces"},
],
)
conn.execute(
aurweb.schema.RequestTypes.insert(),
[
{"ID": 1, "Name": "deletion"},
{"ID": 2, "Name": "orphan"},
{"ID": 3, "Name": "merge"},
],
)
def run(args):
@ -40,8 +52,8 @@ def run(args):
# the last step and leave the database in an inconsistent state. The
# configuration is loaded lazily, so we query it to force its loading.
if args.use_alembic:
alembic_config = alembic.config.Config('alembic.ini')
alembic_config.get_main_option('script_location')
alembic_config = alembic.config.Config("alembic.ini")
alembic_config.get_main_option("script_location")
alembic_config.attributes["configure_logger"] = False
engine = aurweb.db.get_engine(echo=(args.verbose >= 1))
@ -51,17 +63,21 @@ def run(args):
conn.close()
if args.use_alembic:
alembic.command.stamp(alembic_config, 'head')
alembic.command.stamp(alembic_config, "head")
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog='python -m aurweb.initdb',
description='Initialize the aurweb database.')
parser.add_argument('-v', '--verbose', action='count', default=0,
help='increase verbosity')
parser.add_argument('--no-alembic',
help='disable Alembic migrations support',
dest='use_alembic', action='store_false')
prog="python -m aurweb.initdb", description="Initialize the aurweb database."
)
parser.add_argument(
"-v", "--verbose", action="count", default=0, help="increase verbosity"
)
parser.add_argument(
"--no-alembic",
help="disable Alembic migrations support",
dest="use_alembic",
action="store_false",
)
args = parser.parse_args()
run(args)

View file

@ -1,12 +1,12 @@
import gettext
from collections import OrderedDict
from fastapi import Request
import aurweb.config
SUPPORTED_LANGUAGES = OrderedDict({
SUPPORTED_LANGUAGES = OrderedDict(
{
"ar": "العربية",
"ast": "Asturianu",
"ca": "Català",
@ -36,8 +36,9 @@ SUPPORTED_LANGUAGES = OrderedDict({
"tr": "Türkçe",
"uk": "Українська",
"zh_CN": "简体中文",
"zh_TW": "正體中文"
})
"zh_TW": "正體中文",
}
)
RIGHT_TO_LEFT_LANGUAGES = ("he", "ar")
@ -45,15 +46,14 @@ RIGHT_TO_LEFT_LANGUAGES = ("he", "ar")
class Translator:
def __init__(self):
self._localedir = aurweb.config.get('options', 'localedir')
self._localedir = aurweb.config.get("options", "localedir")
self._translator = {}
def get_translator(self, lang: str):
if lang not in self._translator:
self._translator[lang] = gettext.translation("aurweb",
self._localedir,
languages=[lang],
fallback=True)
self._translator[lang] = gettext.translation(
"aurweb", self._localedir, languages=[lang], fallback=True
)
return self._translator.get(lang)
def translate(self, s: str, lang: str):

View file

@ -13,12 +13,16 @@ class AcceptedTerm(Base):
__mapper_args__ = {"primary_key": [__table__.c.TermsID]}
User = relationship(
_User, backref=backref("accepted_terms", lazy="dynamic"),
foreign_keys=[__table__.c.UsersID])
_User,
backref=backref("accepted_terms", lazy="dynamic"),
foreign_keys=[__table__.c.UsersID],
)
Term = relationship(
_Term, backref=backref("accepted_terms", lazy="dynamic"),
foreign_keys=[__table__.c.TermsID])
_Term,
backref=backref("accepted_terms", lazy="dynamic"),
foreign_keys=[__table__.c.TermsID],
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -27,10 +31,12 @@ class AcceptedTerm(Base):
raise IntegrityError(
statement="Foreign key UsersID cannot be null.",
orig="AcceptedTerms.UserID",
params=("NULL"))
params=("NULL"),
)
if not self.Term and not self.TermsID:
raise IntegrityError(
statement="Foreign key TermID cannot be null.",
orig="AcceptedTerms.TermID",
params=("NULL"))
params=("NULL"),
)

View file

@ -16,7 +16,7 @@ ACCOUNT_TYPE_ID = {
USER: USER_ID,
TRUSTED_USER: TRUSTED_USER_ID,
DEVELOPER: DEVELOPER_ID,
TRUSTED_USER_AND_DEV: TRUSTED_USER_AND_DEV_ID
TRUSTED_USER_AND_DEV: TRUSTED_USER_AND_DEV_ID,
}
# Reversed ACCOUNT_TYPE_ID mapping.
@ -25,6 +25,7 @@ ACCOUNT_TYPE_NAME = {v: k for k, v in ACCOUNT_TYPE_ID.items()}
class AccountType(Base):
"""An ORM model of a single AccountTypes record."""
__table__ = schema.AccountTypes
__tablename__ = __table__.name
__mapper_args__ = {"primary_key": [__table__.c.ID]}
@ -36,5 +37,4 @@ class AccountType(Base):
return str(self.AccountType)
def __repr__(self):
return "<AccountType(ID='%s', AccountType='%s')>" % (
self.ID, str(self))
return "<AccountType(ID='%s', AccountType='%s')>" % (self.ID, str(self))

View file

@ -16,10 +16,12 @@ class ApiRateLimit(Base):
raise IntegrityError(
statement="Column Requests cannot be null.",
orig="ApiRateLimit.Requests",
params=("NULL"))
params=("NULL"),
)
if self.WindowStart is None:
raise IntegrityError(
statement="Column WindowStart cannot be null.",
orig="ApiRateLimit.WindowStart",
params=("NULL"))
params=("NULL"),
)

View file

@ -6,26 +6,19 @@ from aurweb import util
def to_dict(model):
return {
c.name: getattr(model, c.name)
for c in model.__table__.columns
}
return {c.name: getattr(model, c.name) for c in model.__table__.columns}
def to_json(model, indent: int = None):
return json.dumps({
k: util.jsonify(v)
for k, v in to_dict(model).items()
}, indent=indent)
return json.dumps(
{k: util.jsonify(v) for k, v in to_dict(model).items()}, indent=indent
)
Base = declarative_base()
# Setup __table_args__ applicable to every table.
Base.__table_args__ = {
"autoload": False,
"extend_existing": True
}
Base.__table_args__ = {"autoload": False, "extend_existing": True}
# Setup Base.as_dict and Base.json.
#

View file

@ -15,4 +15,5 @@ class Group(Base):
raise IntegrityError(
statement="Column Name cannot be null.",
orig="Groups.Name",
params=("NULL"))
params=("NULL"),
)

View file

@ -16,4 +16,5 @@ class License(Base):
raise IntegrityError(
statement="Column Name cannot be null.",
orig="Licenses.Name",
params=("NULL"))
params=("NULL"),
)

View file

@ -21,16 +21,19 @@ class OfficialProvider(Base):
raise IntegrityError(
statement="Column Name cannot be null.",
orig="OfficialProviders.Name",
params=("NULL"))
params=("NULL"),
)
if not self.Repo:
raise IntegrityError(
statement="Column Repo cannot be null.",
orig="OfficialProviders.Repo",
params=("NULL"))
params=("NULL"),
)
if not self.Provides:
raise IntegrityError(
statement="Column Provides cannot be null.",
orig="OfficialProviders.Provides",
params=("NULL"))
params=("NULL"),
)

View file

@ -12,9 +12,10 @@ class Package(Base):
__mapper_args__ = {"primary_key": [__table__.c.ID]}
PackageBase = relationship(
_PackageBase, backref=backref("packages", lazy="dynamic",
cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID])
_PackageBase,
backref=backref("packages", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID],
)
# No Package instances are official packages.
is_official = False
@ -26,10 +27,12 @@ class Package(Base):
raise IntegrityError(
statement="Foreign key PackageBaseID cannot be null.",
orig="Packages.PackageBaseID",
params=("NULL"))
params=("NULL"),
)
if self.Name is None:
raise IntegrityError(
statement="Column Name cannot be null.",
orig="Packages.Name",
params=("NULL"))
params=("NULL"),
)

View file

@ -12,20 +12,28 @@ class PackageBase(Base):
__mapper_args__ = {"primary_key": [__table__.c.ID]}
Flagger = relationship(
_User, backref=backref("flagged_bases", lazy="dynamic"),
foreign_keys=[__table__.c.FlaggerUID])
_User,
backref=backref("flagged_bases", lazy="dynamic"),
foreign_keys=[__table__.c.FlaggerUID],
)
Submitter = relationship(
_User, backref=backref("submitted_bases", lazy="dynamic"),
foreign_keys=[__table__.c.SubmitterUID])
_User,
backref=backref("submitted_bases", lazy="dynamic"),
foreign_keys=[__table__.c.SubmitterUID],
)
Maintainer = relationship(
_User, backref=backref("maintained_bases", lazy="dynamic"),
foreign_keys=[__table__.c.MaintainerUID])
_User,
backref=backref("maintained_bases", lazy="dynamic"),
foreign_keys=[__table__.c.MaintainerUID],
)
Packager = relationship(
_User, backref=backref("package_bases", lazy="dynamic"),
foreign_keys=[__table__.c.PackagerUID])
_User,
backref=backref("package_bases", lazy="dynamic"),
foreign_keys=[__table__.c.PackagerUID],
)
# A set used to check for floatable values.
TO_FLOAT = {"Popularity"}
@ -37,7 +45,8 @@ class PackageBase(Base):
raise IntegrityError(
statement="Column Name cannot be null.",
orig="PackageBases.Name",
params=("NULL"))
params=("NULL"),
)
# If no SubmittedTS/ModifiedTS is provided on creation, set them
# here to the current utc timestamp.

View file

@ -16,4 +16,5 @@ class PackageBlacklist(Base):
raise IntegrityError(
statement="Column Name cannot be null.",
orig="PackageBlacklist.Name",
params=("NULL"))
params=("NULL"),
)

View file

@ -10,19 +10,19 @@ from aurweb.models.user import User as _User
class PackageComaintainer(Base):
__table__ = schema.PackageComaintainers
__tablename__ = __table__.name
__mapper_args__ = {
"primary_key": [__table__.c.UsersID, __table__.c.PackageBaseID]
}
__mapper_args__ = {"primary_key": [__table__.c.UsersID, __table__.c.PackageBaseID]}
User = relationship(
_User, backref=backref("comaintained", lazy="dynamic",
cascade="all, delete"),
foreign_keys=[__table__.c.UsersID])
_User,
backref=backref("comaintained", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.UsersID],
)
PackageBase = relationship(
_PackageBase, backref=backref("comaintainers", lazy="dynamic",
cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID])
_PackageBase,
backref=backref("comaintainers", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID],
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -31,16 +31,19 @@ class PackageComaintainer(Base):
raise IntegrityError(
statement="Foreign key UsersID cannot be null.",
orig="PackageComaintainers.UsersID",
params=("NULL"))
params=("NULL"),
)
if not self.PackageBase and not self.PackageBaseID:
raise IntegrityError(
statement="Foreign key PackageBaseID cannot be null.",
orig="PackageComaintainers.PackageBaseID",
params=("NULL"))
params=("NULL"),
)
if not self.Priority:
raise IntegrityError(
statement="Column Priority cannot be null.",
orig="PackageComaintainers.Priority",
params=("NULL"))
params=("NULL"),
)

View file

@ -13,21 +13,28 @@ class PackageComment(Base):
__mapper_args__ = {"primary_key": [__table__.c.ID]}
PackageBase = relationship(
_PackageBase, backref=backref("comments", lazy="dynamic",
cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID])
_PackageBase,
backref=backref("comments", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID],
)
User = relationship(
_User, backref=backref("package_comments", lazy="dynamic"),
foreign_keys=[__table__.c.UsersID])
_User,
backref=backref("package_comments", lazy="dynamic"),
foreign_keys=[__table__.c.UsersID],
)
Editor = relationship(
_User, backref=backref("edited_comments", lazy="dynamic"),
foreign_keys=[__table__.c.EditedUsersID])
_User,
backref=backref("edited_comments", lazy="dynamic"),
foreign_keys=[__table__.c.EditedUsersID],
)
Deleter = relationship(
_User, backref=backref("deleted_comments", lazy="dynamic"),
foreign_keys=[__table__.c.DelUsersID])
_User,
backref=backref("deleted_comments", lazy="dynamic"),
foreign_keys=[__table__.c.DelUsersID],
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -36,27 +43,31 @@ class PackageComment(Base):
raise IntegrityError(
statement="Foreign key PackageBaseID cannot be null.",
orig="PackageComments.PackageBaseID",
params=("NULL"))
params=("NULL"),
)
if not self.User and not self.UsersID:
raise IntegrityError(
statement="Foreign key UsersID cannot be null.",
orig="PackageComments.UsersID",
params=("NULL"))
params=("NULL"),
)
if self.Comments is None:
raise IntegrityError(
statement="Column Comments cannot be null.",
orig="PackageComments.Comments",
params=("NULL"))
params=("NULL"),
)
if self.RenderedComment is None:
self.RenderedComment = str()
def maintainers(self):
return list(filter(
return list(
filter(
lambda e: e is not None,
[self.PackageBase.Maintainer] + [
c.User for c in self.PackageBase.comaintainers
]
))
[self.PackageBase.Maintainer]
+ [c.User for c in self.PackageBase.comaintainers],
)
)

View file

@ -22,14 +22,16 @@ class PackageDependency(Base):
}
Package = relationship(
_Package, backref=backref("package_dependencies", lazy="dynamic",
cascade="all, delete"),
foreign_keys=[__table__.c.PackageID])
_Package,
backref=backref("package_dependencies", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageID],
)
DependencyType = relationship(
_DependencyType,
backref=backref("package_dependencies", lazy="dynamic"),
foreign_keys=[__table__.c.DepTypeID])
foreign_keys=[__table__.c.DepTypeID],
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -38,43 +40,58 @@ class PackageDependency(Base):
raise IntegrityError(
statement="Foreign key PackageID cannot be null.",
orig="PackageDependencies.PackageID",
params=("NULL"))
params=("NULL"),
)
if not self.DependencyType and not self.DepTypeID:
raise IntegrityError(
statement="Foreign key DepTypeID cannot be null.",
orig="PackageDependencies.DepTypeID",
params=("NULL"))
params=("NULL"),
)
if self.DepName is None:
raise IntegrityError(
statement="Column DepName cannot be null.",
orig="PackageDependencies.DepName",
params=("NULL"))
params=("NULL"),
)
def is_package(self) -> bool:
pkg = db.query(_Package).filter(_Package.Name == self.DepName).exists()
official = db.query(_OfficialProvider).filter(
_OfficialProvider.Name == self.DepName).exists()
official = (
db.query(_OfficialProvider)
.filter(_OfficialProvider.Name == self.DepName)
.exists()
)
return db.query(pkg).scalar() or db.query(official).scalar()
def provides(self) -> list[PackageRelation]:
from aurweb.models.relation_type import PROVIDES_ID
rels = db.query(PackageRelation).join(_Package).filter(
and_(PackageRelation.RelTypeID == PROVIDES_ID,
PackageRelation.RelName == self.DepName)
).with_entities(
_Package.Name,
literal(False).label("is_official")
).order_by(_Package.Name.asc())
rels = (
db.query(PackageRelation)
.join(_Package)
.filter(
and_(
PackageRelation.RelTypeID == PROVIDES_ID,
PackageRelation.RelName == self.DepName,
)
)
.with_entities(_Package.Name, literal(False).label("is_official"))
.order_by(_Package.Name.asc())
)
official_rels = db.query(_OfficialProvider).filter(
and_(_OfficialProvider.Provides == self.DepName,
_OfficialProvider.Name != self.DepName)
).with_entities(
_OfficialProvider.Name,
literal(True).label("is_official")
).order_by(_OfficialProvider.Name.asc())
official_rels = (
db.query(_OfficialProvider)
.filter(
and_(
_OfficialProvider.Provides == self.DepName,
_OfficialProvider.Name != self.DepName,
)
)
.with_entities(_OfficialProvider.Name, literal(True).label("is_official"))
.order_by(_OfficialProvider.Name.asc())
)
return rels.union(official_rels).all()

View file

@ -10,19 +10,19 @@ from aurweb.models.package import Package as _Package
class PackageGroup(Base):
__table__ = schema.PackageGroups
__tablename__ = __table__.name
__mapper_args__ = {
"primary_key": [__table__.c.PackageID, __table__.c.GroupID]
}
__mapper_args__ = {"primary_key": [__table__.c.PackageID, __table__.c.GroupID]}
Package = relationship(
_Package, backref=backref("package_groups", lazy="dynamic",
cascade="all, delete"),
foreign_keys=[__table__.c.PackageID])
_Package,
backref=backref("package_groups", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageID],
)
Group = relationship(
_Group, backref=backref("package_groups", lazy="dynamic",
cascade="all, delete"),
foreign_keys=[__table__.c.GroupID])
_Group,
backref=backref("package_groups", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.GroupID],
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -31,10 +31,12 @@ class PackageGroup(Base):
raise IntegrityError(
statement="Primary key PackageID cannot be null.",
orig="PackageGroups.PackageID",
params=("NULL"))
params=("NULL"),
)
if not self.Group and not self.GroupID:
raise IntegrityError(
statement="Primary key GroupID cannot be null.",
orig="PackageGroups.GroupID",
params=("NULL"))
params=("NULL"),
)

View file

@ -9,14 +9,13 @@ from aurweb.models.package_base import PackageBase as _PackageBase
class PackageKeyword(Base):
__table__ = schema.PackageKeywords
__tablename__ = __table__.name
__mapper_args__ = {
"primary_key": [__table__.c.PackageBaseID, __table__.c.Keyword]
}
__mapper_args__ = {"primary_key": [__table__.c.PackageBaseID, __table__.c.Keyword]}
PackageBase = relationship(
_PackageBase, backref=backref("keywords", lazy="dynamic",
cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID])
_PackageBase,
backref=backref("keywords", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID],
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -25,4 +24,5 @@ class PackageKeyword(Base):
raise IntegrityError(
statement="Primary key PackageBaseID cannot be null.",
orig="PackageKeywords.PackageBaseID",
params=("NULL"))
params=("NULL"),
)

View file

@ -10,19 +10,19 @@ from aurweb.models.package import Package as _Package
class PackageLicense(Base):
__table__ = schema.PackageLicenses
__tablename__ = __table__.name
__mapper_args__ = {
"primary_key": [__table__.c.PackageID, __table__.c.LicenseID]
}
__mapper_args__ = {"primary_key": [__table__.c.PackageID, __table__.c.LicenseID]}
Package = relationship(
_Package, backref=backref("package_licenses", lazy="dynamic",
cascade="all, delete"),
foreign_keys=[__table__.c.PackageID])
_Package,
backref=backref("package_licenses", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageID],
)
License = relationship(
_License, backref=backref("package_licenses", lazy="dynamic",
cascade="all, delete"),
foreign_keys=[__table__.c.LicenseID])
_License,
backref=backref("package_licenses", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.LicenseID],
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -31,10 +31,12 @@ class PackageLicense(Base):
raise IntegrityError(
statement="Primary key PackageID cannot be null.",
orig="PackageLicenses.PackageID",
params=("NULL"))
params=("NULL"),
)
if not self.License and not self.LicenseID:
raise IntegrityError(
statement="Primary key LicenseID cannot be null.",
orig="PackageLicenses.LicenseID",
params=("NULL"))
params=("NULL"),
)

View file

@ -10,20 +10,19 @@ from aurweb.models.user import User as _User
class PackageNotification(Base):
__table__ = schema.PackageNotifications
__tablename__ = __table__.name
__mapper_args__ = {
"primary_key": [__table__.c.UserID, __table__.c.PackageBaseID]
}
__mapper_args__ = {"primary_key": [__table__.c.UserID, __table__.c.PackageBaseID]}
User = relationship(
_User, backref=backref("notifications", lazy="dynamic",
cascade="all, delete"),
foreign_keys=[__table__.c.UserID])
_User,
backref=backref("notifications", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.UserID],
)
PackageBase = relationship(
_PackageBase,
backref=backref("notifications", lazy="dynamic",
cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID])
backref=backref("notifications", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID],
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -32,10 +31,12 @@ class PackageNotification(Base):
raise IntegrityError(
statement="Foreign key UserID cannot be null.",
orig="PackageNotifications.UserID",
params=("NULL"))
params=("NULL"),
)
if not self.PackageBase and not self.PackageBaseID:
raise IntegrityError(
statement="Foreign key PackageBaseID cannot be null.",
orig="PackageNotifications.PackageBaseID",
params=("NULL"))
params=("NULL"),
)

View file

@ -19,13 +19,16 @@ class PackageRelation(Base):
}
Package = relationship(
_Package, backref=backref("package_relations", lazy="dynamic",
cascade="all, delete"),
foreign_keys=[__table__.c.PackageID])
_Package,
backref=backref("package_relations", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageID],
)
RelationType = relationship(
_RelationType, backref=backref("package_relations", lazy="dynamic"),
foreign_keys=[__table__.c.RelTypeID])
_RelationType,
backref=backref("package_relations", lazy="dynamic"),
foreign_keys=[__table__.c.RelTypeID],
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -34,16 +37,19 @@ class PackageRelation(Base):
raise IntegrityError(
statement="Foreign key PackageID cannot be null.",
orig="PackageRelations.PackageID",
params=("NULL"))
params=("NULL"),
)
if not self.RelationType and not self.RelTypeID:
raise IntegrityError(
statement="Foreign key RelTypeID cannot be null.",
orig="PackageRelations.RelTypeID",
params=("NULL"))
params=("NULL"),
)
if not self.RelName:
raise IntegrityError(
statement="Column RelName cannot be null.",
orig="PackageRelations.RelName",
params=("NULL"))
params=("NULL"),
)

View file

@ -25,26 +25,34 @@ class PackageRequest(Base):
__mapper_args__ = {"primary_key": [__table__.c.ID]}
RequestType = relationship(
_RequestType, backref=backref("package_requests", lazy="dynamic"),
foreign_keys=[__table__.c.ReqTypeID])
_RequestType,
backref=backref("package_requests", lazy="dynamic"),
foreign_keys=[__table__.c.ReqTypeID],
)
User = relationship(
_User, backref=backref("package_requests", lazy="dynamic"),
foreign_keys=[__table__.c.UsersID])
_User,
backref=backref("package_requests", lazy="dynamic"),
foreign_keys=[__table__.c.UsersID],
)
PackageBase = relationship(
_PackageBase, backref=backref("requests", lazy="dynamic"),
foreign_keys=[__table__.c.PackageBaseID])
_PackageBase,
backref=backref("requests", lazy="dynamic"),
foreign_keys=[__table__.c.PackageBaseID],
)
Closer = relationship(
_User, backref=backref("closed_requests", lazy="dynamic"),
foreign_keys=[__table__.c.ClosedUID])
_User,
backref=backref("closed_requests", lazy="dynamic"),
foreign_keys=[__table__.c.ClosedUID],
)
STATUS_DISPLAY = {
PENDING_ID: PENDING,
CLOSED_ID: CLOSED,
ACCEPTED_ID: ACCEPTED,
REJECTED_ID: REJECTED
REJECTED_ID: REJECTED,
}
def __init__(self, **kwargs):
@ -54,37 +62,43 @@ class PackageRequest(Base):
raise IntegrityError(
statement="Foreign key ReqTypeID cannot be null.",
orig="PackageRequests.ReqTypeID",
params=("NULL"))
params=("NULL"),
)
if not self.PackageBase and not self.PackageBaseID:
raise IntegrityError(
statement="Foreign key PackageBaseID cannot be null.",
orig="PackageRequests.PackageBaseID",
params=("NULL"))
params=("NULL"),
)
if not self.PackageBaseName:
raise IntegrityError(
statement="Column PackageBaseName cannot be null.",
orig="PackageRequests.PackageBaseName",
params=("NULL"))
params=("NULL"),
)
if not self.User and not self.UsersID:
raise IntegrityError(
statement="Foreign key UsersID cannot be null.",
orig="PackageRequests.UsersID",
params=("NULL"))
params=("NULL"),
)
if self.Comments is None:
raise IntegrityError(
statement="Column Comments cannot be null.",
orig="PackageRequests.Comments",
params=("NULL"))
params=("NULL"),
)
if self.ClosureComment is None:
raise IntegrityError(
statement="Column ClosureComment cannot be null.",
orig="PackageRequests.ClosureComment",
params=("NULL"))
params=("NULL"),
)
def status_display(self) -> str:
"""Return a display string for the Status column."""

View file

@ -9,17 +9,13 @@ from aurweb.models.package import Package as _Package
class PackageSource(Base):
__table__ = schema.PackageSources
__tablename__ = __table__.name
__mapper_args__ = {
"primary_key": [
__table__.c.PackageID,
__table__.c.Source
]
}
__mapper_args__ = {"primary_key": [__table__.c.PackageID, __table__.c.Source]}
Package = relationship(
_Package, backref=backref("package_sources", lazy="dynamic",
cascade="all, delete"),
foreign_keys=[__table__.c.PackageID])
_Package,
backref=backref("package_sources", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageID],
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -28,7 +24,8 @@ class PackageSource(Base):
raise IntegrityError(
statement="Foreign key PackageID cannot be null.",
orig="PackageSources.PackageID",
params=("NULL"))
params=("NULL"),
)
if not self.Source:
self.Source = "/dev/null"

View file

@ -10,18 +10,19 @@ from aurweb.models.user import User as _User
class PackageVote(Base):
__table__ = schema.PackageVotes
__tablename__ = __table__.name
__mapper_args__ = {
"primary_key": [__table__.c.UsersID, __table__.c.PackageBaseID]
}
__mapper_args__ = {"primary_key": [__table__.c.UsersID, __table__.c.PackageBaseID]}
User = relationship(
_User, backref=backref("package_votes", lazy="dynamic"),
foreign_keys=[__table__.c.UsersID])
_User,
backref=backref("package_votes", lazy="dynamic"),
foreign_keys=[__table__.c.UsersID],
)
PackageBase = relationship(
_PackageBase, backref=backref("package_votes", lazy="dynamic",
cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID])
_PackageBase,
backref=backref("package_votes", lazy="dynamic", cascade="all, delete"),
foreign_keys=[__table__.c.PackageBaseID],
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -30,16 +31,19 @@ class PackageVote(Base):
raise IntegrityError(
statement="Foreign key UsersID cannot be null.",
orig="PackageVotes.UsersID",
params=("NULL"))
params=("NULL"),
)
if not self.PackageBase and not self.PackageBaseID:
raise IntegrityError(
statement="Foreign key PackageBaseID cannot be null.",
orig="PackageVotes.PackageBaseID",
params=("NULL"))
params=("NULL"),
)
if not self.VoteTS:
raise IntegrityError(
statement="Column VoteTS cannot be null.",
orig="PackageVotes.VoteTS",
params=("NULL"))
params=("NULL"),
)

View file

@ -12,8 +12,10 @@ class Session(Base):
__mapper_args__ = {"primary_key": [__table__.c.UsersID]}
User = relationship(
_User, backref=backref("session", uselist=False),
foreign_keys=[__table__.c.UsersID])
_User,
backref=backref("session", uselist=False),
foreign_keys=[__table__.c.UsersID],
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -29,10 +31,13 @@ class Session(Base):
user_exists = db.query(_User).filter(_User.ID == uid).exists()
if not db.query(user_exists).scalar():
raise IntegrityError(
statement=("Foreign key UsersID cannot be null and "
"must be a valid user's ID."),
statement=(
"Foreign key UsersID cannot be null and "
"must be a valid user's ID."
),
orig="Sessions.UsersID",
params=("NULL"))
params=("NULL"),
)
def generate_unique_sid():

View file

@ -12,16 +12,17 @@ class SSHPubKey(Base):
__mapper_args__ = {"primary_key": [__table__.c.Fingerprint]}
User = relationship(
"User", backref=backref("ssh_pub_keys", lazy="dynamic"),
foreign_keys=[__table__.c.UserID])
"User",
backref=backref("ssh_pub_keys", lazy="dynamic"),
foreign_keys=[__table__.c.UserID],
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
def get_fingerprint(pubkey: str) -> str:
proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE,
stderr=PIPE)
proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE, stderr=PIPE)
out, _ = proc.communicate(pubkey.encode())
if proc.returncode:
raise ValueError("The SSH public key is invalid.")

View file

@ -16,10 +16,12 @@ class Term(Base):
raise IntegrityError(
statement="Column Description cannot be null.",
orig="Terms.Description",
params=("NULL"))
params=("NULL"),
)
if not self.URL:
raise IntegrityError(
statement="Column URL cannot be null.",
orig="Terms.URL",
params=("NULL"))
params=("NULL"),
)

View file

@ -10,17 +10,19 @@ from aurweb.models.user import User as _User
class TUVote(Base):
__table__ = schema.TU_Votes
__tablename__ = __table__.name
__mapper_args__ = {
"primary_key": [__table__.c.VoteID, __table__.c.UserID]
}
__mapper_args__ = {"primary_key": [__table__.c.VoteID, __table__.c.UserID]}
VoteInfo = relationship(
_TUVoteInfo, backref=backref("tu_votes", lazy="dynamic"),
foreign_keys=[__table__.c.VoteID])
_TUVoteInfo,
backref=backref("tu_votes", lazy="dynamic"),
foreign_keys=[__table__.c.VoteID],
)
User = relationship(
_User, backref=backref("tu_votes", lazy="dynamic"),
foreign_keys=[__table__.c.UserID])
_User,
backref=backref("tu_votes", lazy="dynamic"),
foreign_keys=[__table__.c.UserID],
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -29,10 +31,12 @@ class TUVote(Base):
raise IntegrityError(
statement="Foreign key VoteID cannot be null.",
orig="TU_Votes.VoteID",
params=("NULL"))
params=("NULL"),
)
if not self.User and not self.UserID:
raise IntegrityError(
statement="Foreign key UserID cannot be null.",
orig="TU_Votes.UserID",
params=("NULL"))
params=("NULL"),
)

View file

@ -14,8 +14,10 @@ class TUVoteInfo(Base):
__mapper_args__ = {"primary_key": [__table__.c.ID]}
Submitter = relationship(
_User, backref=backref("tu_voteinfo_set", lazy="dynamic"),
foreign_keys=[__table__.c.SubmitterID])
_User,
backref=backref("tu_voteinfo_set", lazy="dynamic"),
foreign_keys=[__table__.c.SubmitterID],
)
def __init__(self, **kwargs):
# Default Quorum, Yes, No and Abstain columns to 0.
@ -29,31 +31,36 @@ class TUVoteInfo(Base):
raise IntegrityError(
statement="Column Agenda cannot be null.",
orig="TU_VoteInfo.Agenda",
params=("NULL"))
params=("NULL"),
)
if self.User is None:
raise IntegrityError(
statement="Column User cannot be null.",
orig="TU_VoteInfo.User",
params=("NULL"))
params=("NULL"),
)
if self.Submitted is None:
raise IntegrityError(
statement="Column Submitted cannot be null.",
orig="TU_VoteInfo.Submitted",
params=("NULL"))
params=("NULL"),
)
if self.End is None:
raise IntegrityError(
statement="Column End cannot be null.",
orig="TU_VoteInfo.End",
params=("NULL"))
params=("NULL"),
)
if not self.Submitter:
raise IntegrityError(
statement="Foreign key SubmitterID cannot be null.",
orig="TU_VoteInfo.SubmitterID",
params=("NULL"))
params=("NULL"),
)
def __setattr__(self, key: str, value: typing.Any):
"""Customize setattr to stringify any Quorum keys given."""

View file

@ -1,9 +1,7 @@
import hashlib
from typing import Set
import bcrypt
from fastapi import Request
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
@ -12,7 +10,6 @@ from sqlalchemy.orm import backref, relationship
import aurweb.config
import aurweb.models.account_type
import aurweb.schema
from aurweb import db, logging, schema, time, util
from aurweb.models.account_type import AccountType as _AccountType
from aurweb.models.ban import is_banned
@ -25,6 +22,7 @@ SALT_ROUNDS_DEFAULT = 12
class User(Base):
"""An ORM model of a single Users record."""
__table__ = schema.Users
__tablename__ = __table__.name
__mapper_args__ = {"primary_key": [__table__.c.ID]}
@ -33,7 +31,8 @@ class User(Base):
_AccountType,
backref=backref("users", lazy="dynamic"),
foreign_keys=[__table__.c.AccountTypeID],
uselist=False)
uselist=False,
)
# High-level variables used to track authentication (not in DB).
authenticated = False
@ -41,22 +40,22 @@ class User(Base):
# Make this static to the class just in case SQLAlchemy ever
# does something to bypass our constructor.
salt_rounds = aurweb.config.getint("options", "salt_rounds",
SALT_ROUNDS_DEFAULT)
salt_rounds = aurweb.config.getint("options", "salt_rounds", SALT_ROUNDS_DEFAULT)
def __init__(self, Passwd: str = str(), **kwargs):
super().__init__(**kwargs, Passwd=str())
# Run this again in the constructor in case we rehashed config.
self.salt_rounds = aurweb.config.getint("options", "salt_rounds",
SALT_ROUNDS_DEFAULT)
self.salt_rounds = aurweb.config.getint(
"options", "salt_rounds", SALT_ROUNDS_DEFAULT
)
if Passwd:
self.update_password(Passwd)
def update_password(self, password):
self.Passwd = bcrypt.hashpw(
password.encode(),
bcrypt.gensalt(rounds=self.salt_rounds)).decode()
password.encode(), bcrypt.gensalt(rounds=self.salt_rounds)
).decode()
@staticmethod
def minimum_passwd_length():
@ -74,17 +73,17 @@ class User(Base):
password_is_valid = False
try:
password_is_valid = bcrypt.checkpw(password.encode(),
self.Passwd.encode())
password_is_valid = bcrypt.checkpw(password.encode(), self.Passwd.encode())
except ValueError:
pass
# If our Salt column is not empty, we're using a legacy password.
if not password_is_valid and self.Salt != str():
# Try to login with legacy method.
password_is_valid = hashlib.md5(
f"{self.Salt}{password}".encode()
).hexdigest() == self.Passwd
password_is_valid = (
hashlib.md5(f"{self.Salt}{password}".encode()).hexdigest()
== self.Passwd
)
# We got here, we passed the legacy authentication.
# Update the password to our modern hash style.
@ -96,8 +95,7 @@ class User(Base):
def _login_approved(self, request: Request):
return not is_banned(request) and not self.Suspended
def login(self, request: Request, password: str,
session_time: int = 0) -> str:
def login(self, request: Request, password: str, session_time: int = 0) -> str:
"""Login and authenticate a request."""
from aurweb import db
@ -127,9 +125,9 @@ class User(Base):
self.LastLoginIPAddress = request.client.host
if not self.session:
sid = generate_unique_sid()
self.session = db.create(Session, User=self,
SessionID=sid,
LastUpdateTS=now_ts)
self.session = db.create(
Session, User=self, SessionID=sid, LastUpdateTS=now_ts
)
else:
last_updated = self.session.LastUpdateTS
if last_updated and last_updated < now_ts:
@ -148,9 +146,9 @@ class User(Base):
return self.session.SessionID
def has_credential(self, credential: Set[int],
approved: list["User"] = list()):
def has_credential(self, credential: Set[int], approved: list["User"] = list()):
from aurweb.auth.creds import has_credential
return has_credential(self, credential, approved)
def logout(self, request: Request):
@ -162,13 +160,13 @@ class User(Base):
def is_trusted_user(self):
return self.AccountType.ID in {
aurweb.models.account_type.TRUSTED_USER_ID,
aurweb.models.account_type.TRUSTED_USER_AND_DEV_ID
aurweb.models.account_type.TRUSTED_USER_AND_DEV_ID,
}
def is_developer(self):
return self.AccountType.ID in {
aurweb.models.account_type.DEVELOPER_ID,
aurweb.models.account_type.TRUSTED_USER_AND_DEV_ID
aurweb.models.account_type.TRUSTED_USER_AND_DEV_ID,
}
def is_elevated(self):
@ -196,15 +194,19 @@ class User(Base):
:return: Boolean indicating whether `self` can edit `target`
"""
from aurweb.auth import creds
has_cred = self.has_credential(creds.ACCOUNT_EDIT, approved=[target])
return has_cred and self.AccountTypeID >= target.AccountTypeID
def voted_for(self, package) -> bool:
"""Has this User voted for package?"""
from aurweb.models.package_vote import PackageVote
return bool(package.PackageBase.package_votes.filter(
return bool(
package.PackageBase.package_votes.filter(
PackageVote.UsersID == self.ID
).scalar())
).scalar()
)
def notified(self, package) -> bool:
"""Is this User being notified about package (or package base)?
@ -225,9 +227,11 @@ class User(Base):
# Run an exists() query where a pkgbase-related
# PackageNotification exists for self (a user).
return bool(db.query(
return bool(
db.query(
query.filter(PackageNotification.UserID == self.ID).exists()
).scalar())
).scalar()
)
def packages(self):
"""Returns an ORM query to Package objects owned by this user.
@ -241,16 +245,24 @@ class User(Base):
"""
from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase
return db.query(Package).join(PackageBase).filter(
return (
db.query(Package)
.join(PackageBase)
.filter(
or_(
PackageBase.PackagerUID == self.ID,
PackageBase.MaintainerUID == self.ID
PackageBase.MaintainerUID == self.ID,
)
)
)
def __repr__(self):
return "<User(ID='%s', AccountType='%s', Username='%s')>" % (
self.ID, str(self.AccountType), self.Username)
self.ID,
str(self.AccountType),
self.Username,
)
def __str__(self) -> str:
return self.Username

View file

@ -7,46 +7,55 @@ from aurweb import config, db, l10n, time, util
from aurweb.exceptions import InvariantError
from aurweb.models import PackageBase, PackageRequest, User
from aurweb.models.package_request import ACCEPTED_ID, PENDING_ID, REJECTED_ID
from aurweb.models.request_type import DELETION, DELETION_ID, MERGE, MERGE_ID, ORPHAN, ORPHAN_ID
from aurweb.models.request_type import (
DELETION,
DELETION_ID,
MERGE,
MERGE_ID,
ORPHAN,
ORPHAN_ID,
)
from aurweb.scripts import notify
class ClosureFactory:
"""A factory class used to autogenerate closure comments."""
REQTYPE_NAMES = {
DELETION_ID: DELETION,
MERGE_ID: MERGE,
ORPHAN_ID: ORPHAN
}
REQTYPE_NAMES = {DELETION_ID: DELETION, MERGE_ID: MERGE, ORPHAN_ID: ORPHAN}
def _deletion_closure(self, requester: User,
pkgbase: PackageBase,
target: PackageBase = None):
return (f"[Autogenerated] Accepted deletion for {pkgbase.Name}.")
def _deletion_closure(
self, requester: User, pkgbase: PackageBase, target: PackageBase = None
):
return f"[Autogenerated] Accepted deletion for {pkgbase.Name}."
def _merge_closure(self, requester: User,
pkgbase: PackageBase,
target: PackageBase = None):
return (f"[Autogenerated] Accepted merge for {pkgbase.Name} "
f"into {target.Name}.")
def _merge_closure(
self, requester: User, pkgbase: PackageBase, target: PackageBase = None
):
return (
f"[Autogenerated] Accepted merge for {pkgbase.Name} " f"into {target.Name}."
)
def _orphan_closure(self, requester: User,
pkgbase: PackageBase,
target: PackageBase = None):
return (f"[Autogenerated] Accepted orphan for {pkgbase.Name}.")
def _orphan_closure(
self, requester: User, pkgbase: PackageBase, target: PackageBase = None
):
return f"[Autogenerated] Accepted orphan for {pkgbase.Name}."
def _rejected_merge_closure(self, requester: User,
pkgbase: PackageBase,
target: PackageBase = None):
return (f"[Autogenerated] Another request to merge {pkgbase.Name} "
f"into {target.Name} has rendered this request invalid.")
def _rejected_merge_closure(
self, requester: User, pkgbase: PackageBase, target: PackageBase = None
):
return (
f"[Autogenerated] Another request to merge {pkgbase.Name} "
f"into {target.Name} has rendered this request invalid."
)
def get_closure(self, reqtype_id: int,
def get_closure(
self,
reqtype_id: int,
requester: User,
pkgbase: PackageBase,
target: PackageBase = None,
status: int = ACCEPTED_ID) -> str:
status: int = ACCEPTED_ID,
) -> str:
"""
Return a closure comment handled by this class.
@ -69,8 +78,9 @@ class ClosureFactory:
return handler(requester, pkgbase, target)
def update_closure_comment(pkgbase: PackageBase, reqtype_id: int,
comments: str, target: PackageBase = None) -> None:
def update_closure_comment(
pkgbase: PackageBase, reqtype_id: int, comments: str, target: PackageBase = None
) -> None:
"""
Update all pending requests related to `pkgbase` with a closure comment.
@ -90,8 +100,10 @@ def update_closure_comment(pkgbase: PackageBase, reqtype_id: int,
return
query = pkgbase.requests.filter(
and_(PackageRequest.ReqTypeID == reqtype_id,
PackageRequest.Status == PENDING_ID))
and_(
PackageRequest.ReqTypeID == reqtype_id, PackageRequest.Status == PENDING_ID
)
)
if reqtype_id == MERGE_ID:
query = query.filter(PackageRequest.MergeBaseName == target.Name)
@ -101,8 +113,7 @@ def update_closure_comment(pkgbase: PackageBase, reqtype_id: int,
def verify_orphan_request(user: User, pkgbase: PackageBase):
"""Verify that an undue orphan request exists in `requests`."""
requests = pkgbase.requests.filter(
PackageRequest.ReqTypeID == ORPHAN_ID)
requests = pkgbase.requests.filter(PackageRequest.ReqTypeID == ORPHAN_ID)
for pkgreq in requests:
idle_time = config.getint("options", "request_idle_time")
time_delta = time.utcnow() - pkgreq.RequestTS
@ -115,9 +126,13 @@ def verify_orphan_request(user: User, pkgbase: PackageBase):
return False
def close_pkgreq(pkgreq: PackageRequest, closer: User,
pkgbase: PackageBase, target: Optional[PackageBase],
status: int) -> None:
def close_pkgreq(
pkgreq: PackageRequest,
closer: User,
pkgbase: PackageBase,
target: Optional[PackageBase],
status: int,
) -> None:
"""
Close a package request with `pkgreq`.Status == `status`.
@ -130,16 +145,15 @@ def close_pkgreq(pkgreq: PackageRequest, closer: User,
now = time.utcnow()
pkgreq.Status = status
pkgreq.Closer = closer
pkgreq.ClosureComment = (
pkgreq.ClosureComment or ClosureFactory().get_closure(
pkgreq.ReqTypeID, closer, pkgbase, target, status)
pkgreq.ClosureComment = pkgreq.ClosureComment or ClosureFactory().get_closure(
pkgreq.ReqTypeID, closer, pkgbase, target, status
)
pkgreq.ClosedTS = now
def handle_request(request: Request, reqtype_id: int,
pkgbase: PackageBase,
target: PackageBase = None) -> list[notify.Notification]:
def handle_request(
request: Request, reqtype_id: int, pkgbase: PackageBase, target: PackageBase = None
) -> list[notify.Notification]:
"""
Handle package requests before performing an action.
@ -165,17 +179,20 @@ def handle_request(request: Request, reqtype_id: int,
if reqtype_id == ORPHAN_ID:
if not verify_orphan_request(request.user, pkgbase):
_ = l10n.get_translator_for_request(request)
raise InvariantError(_(
"No due existing orphan requests to accept for %s."
) % pkgbase.Name)
raise InvariantError(
_("No due existing orphan requests to accept for %s.") % pkgbase.Name
)
# Produce a base query for requests related to `pkgbase`, based
# on ReqTypeID matching `reqtype_id`, pending status and a correct
# PackagBaseName column.
query: orm.Query = pkgbase.requests.filter(
and_(PackageRequest.ReqTypeID == reqtype_id,
and_(
PackageRequest.ReqTypeID == reqtype_id,
PackageRequest.Status == PENDING_ID,
PackageRequest.PackageBaseName == pkgbase.Name))
PackageRequest.PackageBaseName == pkgbase.Name,
)
)
# Build a query for records we should accept. For merge requests,
# this is specific to a matching MergeBaseName. For others, this
@ -183,8 +200,7 @@ def handle_request(request: Request, reqtype_id: int,
accept_query: orm.Query = query
if target:
# If a `target` was supplied, filter by MergeBaseName
accept_query = query.filter(
PackageRequest.MergeBaseName == target.Name)
accept_query = query.filter(PackageRequest.MergeBaseName == target.Name)
# Build an accept list out of `accept_query`.
to_accept: list[PackageRequest] = accept_query.all()
@ -203,14 +219,16 @@ def handle_request(request: Request, reqtype_id: int,
if not to_accept:
utcnow = time.utcnow()
with db.begin():
pkgreq = db.create(PackageRequest,
pkgreq = db.create(
PackageRequest,
ReqTypeID=reqtype_id,
RequestTS=utcnow,
User=request.user,
PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
Comments="Autogenerated by aurweb.",
ClosureComment=str())
ClosureComment=str(),
)
# If it's a merge request, set MergeBaseName to `target`.Name.
if pkgreq.ReqTypeID == MERGE_ID:
@ -222,15 +240,20 @@ def handle_request(request: Request, reqtype_id: int,
# Update requests with their new status and closures.
with db.begin():
util.apply_all(to_accept, lambda p: close_pkgreq(
p, request.user, pkgbase, target, ACCEPTED_ID))
util.apply_all(to_reject, lambda p: close_pkgreq(
p, request.user, pkgbase, target, REJECTED_ID))
util.apply_all(
to_accept,
lambda p: close_pkgreq(p, request.user, pkgbase, target, ACCEPTED_ID),
)
util.apply_all(
to_reject,
lambda p: close_pkgreq(p, request.user, pkgbase, target, REJECTED_ID),
)
# Create RequestCloseNotifications for all requests involved.
for pkgreq in (to_accept + to_reject):
for pkgreq in to_accept + to_reject:
notif = notify.RequestCloseNotification(
request.user.ID, pkgreq.ID, pkgreq.status_display())
request.user.ID, pkgreq.ID, pkgreq.status_display()
)
notifs.append(notif)
# Return notifications to the caller for sending.

View file

@ -4,7 +4,12 @@ from sqlalchemy import and_, case, or_, orm
from aurweb import db, models
from aurweb.models import Package, PackageBase, User
from aurweb.models.dependency_type import CHECKDEPENDS_ID, DEPENDS_ID, MAKEDEPENDS_ID, OPTDEPENDS_ID
from aurweb.models.dependency_type import (
CHECKDEPENDS_ID,
DEPENDS_ID,
MAKEDEPENDS_ID,
OPTDEPENDS_ID,
)
from aurweb.models.package_comaintainer import PackageComaintainer
from aurweb.models.package_keyword import PackageKeyword
from aurweb.models.package_notification import PackageNotification

View file

@ -3,7 +3,6 @@ from http import HTTPStatus
from typing import Tuple, Union
import orjson
from fastapi import HTTPException
from sqlalchemy import orm
@ -61,13 +60,13 @@ def dep_extra_desc(dep: models.PackageDependency) -> str:
@register_filter("pkgname_link")
def pkgname_link(pkgname: str) -> str:
record = db.query(Package).filter(
Package.Name == pkgname).exists()
record = db.query(Package).filter(Package.Name == pkgname).exists()
if db.query(record).scalar():
return f"/packages/{pkgname}"
official = db.query(OfficialProvider).filter(
OfficialProvider.Name == pkgname).exists()
official = (
db.query(OfficialProvider).filter(OfficialProvider.Name == pkgname).exists()
)
if db.query(official).scalar():
base = "/".join([OFFICIAL_BASE, "packages"])
return f"{base}/?q={pkgname}"
@ -83,16 +82,14 @@ def package_link(package: Union[Package, OfficialProvider]) -> str:
@register_filter("provides_markup")
def provides_markup(provides: Providers) -> str:
return ", ".join([
f'<a href="{package_link(pkg)}">{pkg.Name}</a>'
for pkg in provides
])
return ", ".join(
[f'<a href="{package_link(pkg)}">{pkg.Name}</a>' for pkg in provides]
)
def get_pkg_or_base(
name: str,
cls: Union[models.Package, models.PackageBase] = models.PackageBase) \
-> Union[models.Package, models.PackageBase]:
name: str, cls: Union[models.Package, models.PackageBase] = models.PackageBase
) -> Union[models.Package, models.PackageBase]:
"""Get a PackageBase instance by its name or raise a 404 if
it can't be found in the database.
@ -109,8 +106,7 @@ def get_pkg_or_base(
return instance
def get_pkgbase_comment(pkgbase: models.PackageBase, id: int) \
-> models.PackageComment:
def get_pkgbase_comment(pkgbase: models.PackageBase, id: int) -> models.PackageComment:
comment = pkgbase.comments.filter(models.PackageComment.ID == id).first()
if not comment:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
@ -122,8 +118,7 @@ def out_of_date(packages: orm.Query) -> orm.Query:
return packages.filter(models.PackageBase.OutOfDateTS.isnot(None))
def updated_packages(limit: int = 0,
cache_ttl: int = 600) -> list[models.Package]:
def updated_packages(limit: int = 0, cache_ttl: int = 600) -> list[models.Package]:
"""Return a list of valid Package objects ordered by their
ModifiedTS column in descending order from cache, after setting
the cache when no key yet exists.
@ -139,10 +134,11 @@ def updated_packages(limit: int = 0,
return orjson.loads(packages)
with db.begin():
query = db.query(models.Package).join(models.PackageBase).filter(
models.PackageBase.PackagerUID.isnot(None)
).order_by(
models.PackageBase.ModifiedTS.desc()
query = (
db.query(models.Package)
.join(models.PackageBase)
.filter(models.PackageBase.PackagerUID.isnot(None))
.order_by(models.PackageBase.ModifiedTS.desc())
)
if limit:
@ -152,13 +148,13 @@ def updated_packages(limit: int = 0,
for pkg in query:
# For each Package returned by the query, append a dict
# containing Package columns we're interested in.
packages.append({
packages.append(
{
"Name": pkg.Name,
"Version": pkg.Version,
"PackageBase": {
"ModifiedTS": pkg.PackageBase.ModifiedTS
"PackageBase": {"ModifiedTS": pkg.PackageBase.ModifiedTS},
}
})
)
# Store the JSON serialization of the package_updates key into Redis.
redis.set("package_updates", orjson.dumps(packages))
@ -168,8 +164,7 @@ def updated_packages(limit: int = 0,
return packages
def query_voted(query: list[models.Package],
user: models.User) -> dict[int, bool]:
def query_voted(query: list[models.Package], user: models.User) -> dict[int, bool]:
"""Produce a dictionary of package base ID keys to boolean values,
which indicate whether or not the package base has a vote record
related to user.
@ -180,19 +175,17 @@ def query_voted(query: list[models.Package],
"""
output = defaultdict(bool)
query_set = {pkg.PackageBaseID for pkg in query}
voted = db.query(models.PackageVote).join(
models.PackageBase,
models.PackageBase.ID.in_(query_set)
).filter(
models.PackageVote.UsersID == user.ID
voted = (
db.query(models.PackageVote)
.join(models.PackageBase, models.PackageBase.ID.in_(query_set))
.filter(models.PackageVote.UsersID == user.ID)
)
for vote in voted:
output[vote.PackageBase.ID] = True
return output
def query_notified(query: list[models.Package],
user: models.User) -> dict[int, bool]:
def query_notified(query: list[models.Package], user: models.User) -> dict[int, bool]:
"""Produce a dictionary of package base ID keys to boolean values,
which indicate whether or not the package base has a notification
record related to user.
@ -203,19 +196,17 @@ def query_notified(query: list[models.Package],
"""
output = defaultdict(bool)
query_set = {pkg.PackageBaseID for pkg in query}
notified = db.query(models.PackageNotification).join(
models.PackageBase,
models.PackageBase.ID.in_(query_set)
).filter(
models.PackageNotification.UserID == user.ID
notified = (
db.query(models.PackageNotification)
.join(models.PackageBase, models.PackageBase.ID.in_(query_set))
.filter(models.PackageNotification.UserID == user.ID)
)
for notif in notified:
output[notif.PackageBase.ID] = True
return output
def pkg_required(pkgname: str, provides: list[str]) \
-> list[PackageDependency]:
def pkg_required(pkgname: str, provides: list[str]) -> list[PackageDependency]:
"""
Get dependencies that match a string in `[pkgname] + provides`.
@ -225,9 +216,12 @@ def pkg_required(pkgname: str, provides: list[str]) \
:return: List of PackageDependency instances
"""
targets = set([pkgname] + provides)
query = db.query(PackageDependency).join(Package).filter(
PackageDependency.DepName.in_(targets)
).order_by(Package.Name.asc())
query = (
db.query(PackageDependency)
.join(Package)
.filter(PackageDependency.DepName.in_(targets))
.order_by(Package.Name.asc())
)
return query

View file

@ -14,15 +14,15 @@ logger = logging.get_logger(__name__)
def pkgbase_notify_instance(request: Request, pkgbase: PackageBase) -> None:
notif = db.query(pkgbase.notifications.filter(
notif = db.query(
pkgbase.notifications.filter(
PackageNotification.UserID == request.user.ID
).exists()).scalar()
).exists()
).scalar()
has_cred = request.user.has_credential(creds.PKGBASE_NOTIFY)
if has_cred and not notif:
with db.begin():
db.create(PackageNotification,
PackageBase=pkgbase,
User=request.user)
db.create(PackageNotification, PackageBase=pkgbase, User=request.user)
def pkgbase_unnotify_instance(request: Request, pkgbase: PackageBase) -> None:
@ -36,8 +36,11 @@ def pkgbase_unnotify_instance(request: Request, pkgbase: PackageBase) -> None:
def pkgbase_unflag_instance(request: Request, pkgbase: PackageBase) -> None:
has_cred = request.user.has_credential(creds.PKGBASE_UNFLAG, approved=[
pkgbase.Flagger, pkgbase.Maintainer] + [c.User for c in pkgbase.comaintainers])
has_cred = request.user.has_credential(
creds.PKGBASE_UNFLAG,
approved=[pkgbase.Flagger, pkgbase.Maintainer]
+ [c.User for c in pkgbase.comaintainers],
)
if has_cred:
with db.begin():
pkgbase.OutOfDateTS = None
@ -93,9 +96,9 @@ def pkgbase_adopt_instance(request: Request, pkgbase: PackageBase) -> None:
notif.send()
def pkgbase_delete_instance(request: Request, pkgbase: PackageBase,
comments: str = str()) \
-> list[notify.Notification]:
def pkgbase_delete_instance(
request: Request, pkgbase: PackageBase, comments: str = str()
) -> list[notify.Notification]:
notifs = handle_request(request, DELETION_ID, pkgbase) + [
notify.DeleteNotification(request.user.ID, pkgbase.ID)
]
@ -107,8 +110,9 @@ def pkgbase_delete_instance(request: Request, pkgbase: PackageBase,
return notifs
def pkgbase_merge_instance(request: Request, pkgbase: PackageBase,
target: PackageBase, comments: str = str()) -> None:
def pkgbase_merge_instance(
request: Request, pkgbase: PackageBase, target: PackageBase, comments: str = str()
) -> None:
pkgbasename = str(pkgbase.Name)
# Create notifications.
@ -144,8 +148,10 @@ def pkgbase_merge_instance(request: Request, pkgbase: PackageBase,
db.delete(pkgbase)
# Log this out for accountability purposes.
logger.info(f"Trusted User '{request.user.Username}' merged "
f"'{pkgbasename}' into '{target.Name}'.")
logger.info(
f"Trusted User '{request.user.Username}' merged "
f"'{pkgbasename}' into '{target.Name}'."
)
# Send notifications.
util.apply_all(notifs, lambda n: n.send())

View file

@ -10,18 +10,22 @@ from aurweb.models.package_comment import PackageComment
from aurweb.models.package_request import PENDING_ID, PackageRequest
from aurweb.models.package_vote import PackageVote
from aurweb.scripts import notify
from aurweb.templates import make_context as _make_context
from aurweb.templates import make_variable_context as _make_variable_context
from aurweb.templates import (
make_context as _make_context,
make_variable_context as _make_variable_context,
)
async def make_variable_context(request: Request, pkgbase: PackageBase) \
-> dict[str, Any]:
async def make_variable_context(
request: Request, pkgbase: PackageBase
) -> dict[str, Any]:
ctx = await _make_variable_context(request, pkgbase.Name)
return make_context(request, pkgbase, ctx)
def make_context(request: Request, pkgbase: PackageBase,
context: dict[str, Any] = None) -> dict[str, Any]:
def make_context(
request: Request, pkgbase: PackageBase, context: dict[str, Any] = None
) -> dict[str, Any]:
"""Make a basic context for package or pkgbase.
:param request: FastAPI request
@ -34,14 +38,16 @@ def make_context(request: Request, pkgbase: PackageBase,
# Per page and offset.
offset, per_page = util.sanitize_params(
request.query_params.get("O", defaults.O),
request.query_params.get("PP", defaults.COMMENTS_PER_PAGE))
request.query_params.get("PP", defaults.COMMENTS_PER_PAGE),
)
context["O"] = offset
context["PP"] = per_page
context["git_clone_uri_anon"] = config.get("options", "git_clone_uri_anon")
context["git_clone_uri_priv"] = config.get("options", "git_clone_uri_priv")
context["pkgbase"] = pkgbase
context["comaintainers"] = [
c.User for c in pkgbase.comaintainers.order_by(
c.User
for c in pkgbase.comaintainers.order_by(
PackageComaintainer.Priority.asc()
).all()
]
@ -53,9 +59,11 @@ def make_context(request: Request, pkgbase: PackageBase,
context["comments_total"] = pkgbase.comments.order_by(
PackageComment.CommentTS.desc()
).count()
context["comments"] = pkgbase.comments.order_by(
PackageComment.CommentTS.desc()
).limit(per_page).offset(offset)
context["comments"] = (
pkgbase.comments.order_by(PackageComment.CommentTS.desc())
.limit(per_page)
.offset(offset)
)
context["pinned_comments"] = pkgbase.comments.filter(
PackageComment.PinnedTS != 0
).order_by(PackageComment.CommentTS.desc())
@ -70,15 +78,15 @@ def make_context(request: Request, pkgbase: PackageBase,
).scalar()
context["requests"] = pkgbase.requests.filter(
and_(PackageRequest.Status == PENDING_ID,
PackageRequest.ClosedTS.is_(None))
and_(PackageRequest.Status == PENDING_ID, PackageRequest.ClosedTS.is_(None))
).count()
return context
def remove_comaintainer(comaint: PackageComaintainer) \
-> notify.ComaintainerRemoveNotification:
def remove_comaintainer(
comaint: PackageComaintainer,
) -> notify.ComaintainerRemoveNotification:
"""
Remove a PackageComaintainer.
@ -107,9 +115,9 @@ def remove_comaintainers(pkgbase: PackageBase, usernames: list[str]) -> None:
"""
notifications = []
with db.begin():
comaintainers = pkgbase.comaintainers.join(User).filter(
User.Username.in_(usernames)
).all()
comaintainers = (
pkgbase.comaintainers.join(User).filter(User.Username.in_(usernames)).all()
)
notifications = [
notify.ComaintainerRemoveNotification(co.User.ID, pkgbase.ID)
for co in comaintainers
@ -133,8 +141,7 @@ def latest_priority(pkgbase: PackageBase) -> int:
"""
# Order comaintainers related to pkgbase by Priority DESC.
record = pkgbase.comaintainers.order_by(
PackageComaintainer.Priority.desc()).first()
record = pkgbase.comaintainers.order_by(PackageComaintainer.Priority.desc()).first()
# Use Priority column if record exists, otherwise 0.
return record.Priority if record else 0
@ -148,8 +155,9 @@ class NoopComaintainerNotification:
return
def add_comaintainer(pkgbase: PackageBase, comaintainer: User) \
-> notify.ComaintainerAddNotification:
def add_comaintainer(
pkgbase: PackageBase, comaintainer: User
) -> notify.ComaintainerAddNotification:
"""
Add a new comaintainer to `pkgbase`.
@ -165,14 +173,19 @@ def add_comaintainer(pkgbase: PackageBase, comaintainer: User) \
new_prio = latest_priority(pkgbase) + 1
with db.begin():
db.create(PackageComaintainer, PackageBase=pkgbase,
User=comaintainer, Priority=new_prio)
db.create(
PackageComaintainer,
PackageBase=pkgbase,
User=comaintainer,
Priority=new_prio,
)
return notify.ComaintainerAddNotification(comaintainer.ID, pkgbase.ID)
def add_comaintainers(request: Request, pkgbase: PackageBase,
usernames: list[str]) -> None:
def add_comaintainers(
request: Request, pkgbase: PackageBase, usernames: list[str]
) -> None:
"""
Add comaintainers to `pkgbase`.
@ -216,7 +229,6 @@ def rotate_comaintainers(pkgbase: PackageBase) -> None:
:param pkgbase: PackageBase instance
"""
comaintainers = pkgbase.comaintainers.order_by(
PackageComaintainer.Priority.asc())
comaintainers = pkgbase.comaintainers.order_by(PackageComaintainer.Priority.asc())
for i, comaint in enumerate(comaintainers):
comaint.Priority = i + 1

View file

@ -5,9 +5,13 @@ from aurweb.exceptions import ValidationError
from aurweb.models import PackageBase
def request(pkgbase: PackageBase,
type: str, comments: str, merge_into: str,
context: dict[str, Any]) -> None:
def request(
pkgbase: PackageBase,
type: str,
comments: str,
merge_into: str,
context: dict[str, Any],
) -> None:
if not comments:
raise ValidationError(["The comment field must not be empty."])
@ -15,21 +19,16 @@ def request(pkgbase: PackageBase,
# Perform merge-related checks.
if not merge_into:
# TODO: This error needs to be translated.
raise ValidationError(
['The "Merge into" field must not be empty.'])
raise ValidationError(['The "Merge into" field must not be empty.'])
target = db.query(PackageBase).filter(
PackageBase.Name == merge_into
).first()
target = db.query(PackageBase).filter(PackageBase.Name == merge_into).first()
if not target:
# TODO: This error needs to be translated.
raise ValidationError([
"The package base you want to merge into does not exist."
])
raise ValidationError(
["The package base you want to merge into does not exist."]
)
db.refresh(target)
if target.ID == pkgbase.ID:
# TODO: This error needs to be translated.
raise ValidationError([
"You cannot merge a package base into itself."
])
raise ValidationError(["You cannot merge a package base into itself."])

View file

@ -19,8 +19,9 @@ def instrumentator():
# Their license is included in LICENSES/starlette_exporter.
# The code has been modified to remove child route checks
# (since we don't have any) and to stay within an 80-width limit.
def get_matching_route_path(scope: dict[Any, Any], routes: list[Route],
route_name: Optional[str] = None) -> str:
def get_matching_route_path(
scope: dict[Any, Any], routes: list[Route], route_name: Optional[str] = None
) -> str:
"""
Find a matching route and return its original path string
@ -34,7 +35,7 @@ def get_matching_route_path(scope: dict[Any, Any], routes: list[Route],
if match == Match.FULL:
route_name = route.path
'''
"""
# This path exists in the original function's code, but we
# don't need it (currently), so it's been removed to avoid
# useless test coverage.
@ -47,7 +48,7 @@ def get_matching_route_path(scope: dict[Any, Any], routes: list[Route],
route_name = None
else:
route_name += child_route_name
'''
"""
return route_name
elif match == Match.PARTIAL and route_name is None:
@ -55,9 +56,11 @@ def get_matching_route_path(scope: dict[Any, Any], routes: list[Route],
def http_requests_total() -> Callable[[Info], None]:
metric = Counter("http_requests_total",
metric = Counter(
"http_requests_total",
"Number of HTTP requests.",
labelnames=("method", "path", "status"))
labelnames=("method", "path", "status"),
)
def instrumentation(info: Info) -> None:
if info.request.method.lower() in ("head", "options"): # pragma: no cover
@ -85,7 +88,7 @@ def http_requests_total() -> Callable[[Info], None]:
"type": scope.get("type"),
"path": root_path + scope.get("path"),
"path_params": scope.get("path_params", {}),
"method": scope.get("method")
"method": scope.get("method"),
}
method = scope.get("method")
@ -102,7 +105,8 @@ def http_api_requests_total() -> Callable[[Info], None]:
metric = Counter(
"http_api_requests",
"Number of times an RPC API type has been requested.",
labelnames=("type", "status"))
labelnames=("type", "status"),
)
def instrumentation(info: Info) -> None:
if info.request.method.lower() in ("head", "options"): # pragma: no cover

View file

@ -38,8 +38,7 @@ def _update_ratelimit_db(request: Request):
now = time.utcnow()
time_to_delete = now - window_length
records = db.query(ApiRateLimit).filter(
ApiRateLimit.WindowStart < time_to_delete)
records = db.query(ApiRateLimit).filter(ApiRateLimit.WindowStart < time_to_delete)
with db.begin():
db.delete_all(records)
@ -47,9 +46,7 @@ def _update_ratelimit_db(request: Request):
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
with db.begin():
if not record:
record = db.create(ApiRateLimit,
WindowStart=now,
IP=host, Requests=1)
record = db.create(ApiRateLimit, WindowStart=now, IP=host, Requests=1)
else:
record.Requests += 1

View file

@ -1,9 +1,7 @@
import fakeredis
from redis import ConnectionPool, Redis
import aurweb.config
from aurweb import logging
logger = logging.get_logger(__name__)

View file

@ -3,7 +3,18 @@ API routers for FastAPI.
See https://fastapi.tiangolo.com/tutorial/bigger-applications/
"""
from . import accounts, auth, html, packages, pkgbase, requests, rpc, rss, sso, trusted_user
from . import (
accounts,
auth,
html,
packages,
pkgbase,
requests,
rpc,
rss,
sso,
trusted_user,
)
"""
aurweb application routes. This constant can be any iterable

View file

@ -1,6 +1,5 @@
import copy
import typing
from http import HTTPStatus
from typing import Any
@ -9,7 +8,6 @@ from fastapi.responses import HTMLResponse, RedirectResponse
from sqlalchemy import and_, or_
import aurweb.config
from aurweb import cookies, db, l10n, logging, models, util
from aurweb.auth import account_type_required, requires_auth, requires_guest
from aurweb.captcha import get_captcha_salts
@ -37,21 +35,23 @@ async def passreset(request: Request):
@router.post("/passreset", response_class=HTMLResponse)
@handle_form_exceptions
@requires_guest
async def passreset_post(request: Request,
async def passreset_post(
request: Request,
user: str = Form(...),
resetkey: str = Form(default=None),
password: str = Form(default=None),
confirm: str = Form(default=None)):
confirm: str = Form(default=None),
):
context = await make_variable_context(request, "Password Reset")
# The user parameter being required, we can match against
criteria = or_(models.User.Username == user, models.User.Email == user)
db_user = db.query(models.User,
and_(criteria, models.User.Suspended == 0)).first()
db_user = db.query(models.User, and_(criteria, models.User.Suspended == 0)).first()
if db_user is None:
context["errors"] = ["Invalid e-mail."]
return render_template(request, "passreset.html", context,
status_code=HTTPStatus.NOT_FOUND)
return render_template(
request, "passreset.html", context, status_code=HTTPStatus.NOT_FOUND
)
db.refresh(db_user)
if resetkey:
@ -59,29 +59,34 @@ async def passreset_post(request: Request,
if not db_user.ResetKey or resetkey != db_user.ResetKey:
context["errors"] = ["Invalid e-mail."]
return render_template(request, "passreset.html", context,
status_code=HTTPStatus.NOT_FOUND)
return render_template(
request, "passreset.html", context, status_code=HTTPStatus.NOT_FOUND
)
if not user or not password:
context["errors"] = ["Missing a required field."]
return render_template(request, "passreset.html", context,
status_code=HTTPStatus.BAD_REQUEST)
return render_template(
request, "passreset.html", context, status_code=HTTPStatus.BAD_REQUEST
)
if password != confirm:
# If the provided password does not match the provided confirm.
context["errors"] = ["Password fields do not match."]
return render_template(request, "passreset.html", context,
status_code=HTTPStatus.BAD_REQUEST)
return render_template(
request, "passreset.html", context, status_code=HTTPStatus.BAD_REQUEST
)
if len(password) < models.User.minimum_passwd_length():
# Translate the error here, which simplifies error output
# in the jinja2 template.
_ = get_translator_for_request(request)
context["errors"] = [_(
"Your password must be at least %s characters.") % (
str(models.User.minimum_passwd_length()))]
return render_template(request, "passreset.html", context,
status_code=HTTPStatus.BAD_REQUEST)
context["errors"] = [
_("Your password must be at least %s characters.")
% (str(models.User.minimum_passwd_length()))
]
return render_template(
request, "passreset.html", context, status_code=HTTPStatus.BAD_REQUEST
)
# We got to this point; everything matched up. Update the password
# and remove the ResetKey.
@ -92,8 +97,9 @@ async def passreset_post(request: Request,
db_user.update_password(password)
# Render ?step=complete.
return RedirectResponse(url="/passreset?step=complete",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(
url="/passreset?step=complete", status_code=HTTPStatus.SEE_OTHER
)
# If we got here, we continue with issuing a resetkey for the user.
resetkey = generate_resetkey()
@ -103,12 +109,12 @@ async def passreset_post(request: Request,
ResetKeyNotification(db_user.ID).send()
# Render ?step=confirm.
return RedirectResponse(url="/passreset?step=confirm",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(
url="/passreset?step=confirm", status_code=HTTPStatus.SEE_OTHER
)
def process_account_form(request: Request, user: models.User,
args: dict[str, Any]):
def process_account_form(request: Request, user: models.User, args: dict[str, Any]):
"""Process an account form. All fields are optional and only checks
requirements in the case they are present.
@ -146,7 +152,7 @@ def process_account_form(request: Request, user: models.User,
validate.username_in_use,
validate.email_in_use,
validate.invalid_account_type,
validate.invalid_captcha
validate.invalid_captcha,
]
try:
@ -158,10 +164,9 @@ def process_account_form(request: Request, user: models.User,
return (True, [])
def make_account_form_context(context: dict,
request: Request,
user: models.User,
args: dict):
def make_account_form_context(
context: dict, request: Request, user: models.User, args: dict
):
"""Modify a FastAPI context and add attributes for the account form.
:param context: FastAPI context
@ -173,15 +178,17 @@ def make_account_form_context(context: dict,
# Do not modify the original context.
context = copy.copy(context)
context["account_types"] = list(filter(
context["account_types"] = list(
filter(
lambda e: request.user.AccountTypeID >= e[0],
[
(at.USER_ID, f"Normal {at.USER}"),
(at.TRUSTED_USER_ID, at.TRUSTED_USER),
(at.DEVELOPER_ID, at.DEVELOPER),
(at.TRUSTED_USER_AND_DEV_ID, at.TRUSTED_USER_AND_DEV)
]
))
(at.TRUSTED_USER_AND_DEV_ID, at.TRUSTED_USER_AND_DEV),
],
)
)
if request.user.is_authenticated():
context["username"] = args.get("U", user.Username)
@ -229,7 +236,8 @@ def make_account_form_context(context: dict,
@router.get("/register", response_class=HTMLResponse)
@requires_guest
async def account_register(request: Request,
async def account_register(
request: Request,
U: str = Form(default=str()), # Username
E: str = Form(default=str()), # Email
H: str = Form(default=False), # Hide Email
@ -238,15 +246,14 @@ async def account_register(request: Request,
HP: str = Form(default=None), # Homepage
I: str = Form(default=None), # IRC Nick
K: str = Form(default=None), # PGP Key FP
L: str = Form(default=aurweb.config.get(
"options", "default_lang")),
TZ: str = Form(default=aurweb.config.get(
"options", "default_timezone")),
L: str = Form(default=aurweb.config.get("options", "default_lang")),
TZ: str = Form(default=aurweb.config.get("options", "default_timezone")),
PK: str = Form(default=None),
CN: bool = Form(default=False), # Comment Notify
CU: bool = Form(default=False), # Update Notify
CO: bool = Form(default=False), # Owner Notify
captcha: str = Form(default=str())):
captcha: str = Form(default=str()),
):
context = await make_variable_context(request, "Register")
context["captcha_salt"] = get_captcha_salts()[0]
context = make_account_form_context(context, request, None, dict())
@ -256,32 +263,32 @@ async def account_register(request: Request,
@router.post("/register", response_class=HTMLResponse)
@handle_form_exceptions
@requires_guest
async def account_register_post(request: Request,
async def account_register_post(
request: Request,
U: str = Form(default=str()), # Username
E: str = Form(default=str()), # Email
H: str = Form(default=False), # Hide Email
BE: str = Form(default=None), # Backup Email
R: str = Form(default=''), # Real Name
R: str = Form(default=""), # Real Name
HP: str = Form(default=None), # Homepage
I: str = Form(default=None), # IRC Nick
K: str = Form(default=None), # PGP Key
L: str = Form(default=aurweb.config.get(
"options", "default_lang")),
TZ: str = Form(default=aurweb.config.get(
"options", "default_timezone")),
L: str = Form(default=aurweb.config.get("options", "default_lang")),
TZ: str = Form(default=aurweb.config.get("options", "default_timezone")),
PK: str = Form(default=str()), # SSH PubKey
CN: bool = Form(default=False),
UN: bool = Form(default=False),
ON: bool = Form(default=False),
captcha: str = Form(default=None),
captcha_salt: str = Form(...)):
captcha_salt: str = Form(...),
):
context = await make_variable_context(request, "Register")
args = dict(await request.form())
args["K"] = args.get("K", str()).replace(" ", "")
K = args.get("K")
# Force "H" into a boolean.
args["H"] = H = (args.get("H", str()) == "on")
args["H"] = H = args.get("H", str()) == "on"
context = make_account_form_context(context, request, None, args)
ok, errors = process_account_form(request, request.user, args)
@ -289,30 +296,45 @@ async def account_register_post(request: Request,
# If the field values given do not meet the requirements,
# return HTTP 400 with an error.
context["errors"] = errors
return render_template(request, "register.html", context,
status_code=HTTPStatus.BAD_REQUEST)
return render_template(
request, "register.html", context, status_code=HTTPStatus.BAD_REQUEST
)
if not captcha:
context["errors"] = ["The CAPTCHA is missing."]
return render_template(request, "register.html", context,
status_code=HTTPStatus.BAD_REQUEST)
return render_template(
request, "register.html", context, status_code=HTTPStatus.BAD_REQUEST
)
# Create a user with no password with a resetkey, then send
# an email off about it.
resetkey = generate_resetkey()
# By default, we grab the User account type to associate with.
atype = db.query(models.AccountType,
models.AccountType.AccountType == "User").first()
atype = db.query(
models.AccountType, models.AccountType.AccountType == "User"
).first()
# Create a user given all parameters available.
with db.begin():
user = db.create(models.User, Username=U,
Email=E, HideEmail=H, BackupEmail=BE,
RealName=R, Homepage=HP, IRCNick=I, PGPKey=K,
LangPreference=L, Timezone=TZ, CommentNotify=CN,
UpdateNotify=UN, OwnershipNotify=ON,
ResetKey=resetkey, AccountType=atype)
user = db.create(
models.User,
Username=U,
Email=E,
HideEmail=H,
BackupEmail=BE,
RealName=R,
Homepage=HP,
IRCNick=I,
PGPKey=K,
LangPreference=L,
Timezone=TZ,
CommentNotify=CN,
UpdateNotify=UN,
OwnershipNotify=ON,
ResetKey=resetkey,
AccountType=atype,
)
# If a PK was given and either one does not exist or the given
# PK mismatches the existing user's SSHPubKey.PubKey.
@ -323,8 +345,9 @@ async def account_register_post(request: Request,
pk = " ".join(k)
fprint = get_fingerprint(pk)
with db.begin():
db.create(models.SSHPubKey, UserID=user.ID,
PubKey=pk, Fingerprint=fprint)
db.create(
models.SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fprint
)
# Send a reset key notification to the new user.
WelcomeNotification(user.ID).send()
@ -334,8 +357,9 @@ async def account_register_post(request: Request,
return render_template(request, "register.html", context)
def cannot_edit(request: Request, user: models.User) \
-> typing.Optional[RedirectResponse]:
def cannot_edit(
request: Request, user: models.User
) -> typing.Optional[RedirectResponse]:
"""
Decide if `request.user` cannot edit `user`.
@ -373,7 +397,8 @@ async def account_edit(request: Request, username: str):
@router.post("/account/{username}/edit", response_class=HTMLResponse)
@handle_form_exceptions
@requires_auth
async def account_edit_post(request: Request,
async def account_edit_post(
request: Request,
username: str,
U: str = Form(default=str()), # Username
J: bool = Form(default=False),
@ -384,10 +409,8 @@ async def account_edit_post(request: Request,
HP: str = Form(default=None), # Homepage
I: str = Form(default=None), # IRC Nick
K: str = Form(default=None), # PGP Key
L: str = Form(aurweb.config.get(
"options", "default_lang")),
TZ: str = Form(aurweb.config.get(
"options", "default_timezone")),
L: str = Form(aurweb.config.get("options", "default_lang")),
TZ: str = Form(aurweb.config.get("options", "default_timezone")),
P: str = Form(default=str()), # New Password
C: str = Form(default=None), # Password Confirm
PK: str = Form(default=None), # PubKey
@ -395,9 +418,9 @@ async def account_edit_post(request: Request,
UN: bool = Form(default=False), # Update Notify
ON: bool = Form(default=False), # Owner Notify
T: int = Form(default=None),
passwd: str = Form(default=str())):
user = db.query(models.User).filter(
models.User.Username == username).first()
passwd: str = Form(default=str()),
):
user = db.query(models.User).filter(models.User.Username == username).first()
response = cannot_edit(request, user)
if response:
return response
@ -416,13 +439,15 @@ async def account_edit_post(request: Request,
if not passwd:
context["errors"] = ["Invalid password."]
return render_template(request, "account/edit.html", context,
status_code=HTTPStatus.BAD_REQUEST)
return render_template(
request, "account/edit.html", context, status_code=HTTPStatus.BAD_REQUEST
)
if not ok:
context["errors"] = errors
return render_template(request, "account/edit.html", context,
status_code=HTTPStatus.BAD_REQUEST)
return render_template(
request, "account/edit.html", context, status_code=HTTPStatus.BAD_REQUEST
)
updates = [
update.simple,
@ -430,7 +455,7 @@ async def account_edit_post(request: Request,
update.timezone,
update.ssh_pubkey,
update.account_type,
update.password
update.password,
]
for f in updates:
@ -441,18 +466,17 @@ async def account_edit_post(request: Request,
# Update cookies with requests, in case they were changed.
response = render_template(request, "account/edit.html", context)
return cookies.update_response_cookies(request, response,
aurtz=TZ, aurlang=L)
return cookies.update_response_cookies(request, response, aurtz=TZ, aurlang=L)
@router.get("/account/{username}")
async def account(request: Request, username: str):
_ = l10n.get_translator_for_request(request)
context = await make_variable_context(
request, _("Account") + " " + username)
context = await make_variable_context(request, _("Account") + " " + username)
if not request.user.is_authenticated():
return render_template(request, "account/show.html", context,
status_code=HTTPStatus.UNAUTHORIZED)
return render_template(
request, "account/show.html", context, status_code=HTTPStatus.UNAUTHORIZED
)
# Get related User record, if possible.
user = get_user_by_name(username)
@ -463,8 +487,7 @@ async def account(request: Request, username: str):
context["pgp_key"] = " ".join([k[i : i + 4] for i in range(0, len(k), 4)])
login_ts = None
session = db.query(models.Session).filter(
models.Session.UsersID == user.ID).first()
session = db.query(models.Session).filter(models.Session.UsersID == user.ID).first()
if session:
login_ts = user.session.LastUpdateTS
context["login_ts"] = login_ts
@ -480,15 +503,14 @@ async def account_comments(request: Request, username: str):
context = make_context(request, "Accounts")
context["username"] = username
context["comments"] = user.package_comments.order_by(
models.PackageComment.CommentTS.desc())
models.PackageComment.CommentTS.desc()
)
return render_template(request, "account/comments.html", context)
@router.get("/accounts")
@requires_auth
@account_type_required({at.TRUSTED_USER,
at.DEVELOPER,
at.TRUSTED_USER_AND_DEV})
@account_type_required({at.TRUSTED_USER, at.DEVELOPER, at.TRUSTED_USER_AND_DEV})
async def accounts(request: Request):
context = make_context(request, "Accounts")
return render_template(request, "account/search.html", context)
@ -497,10 +519,9 @@ async def accounts(request: Request):
@router.post("/accounts")
@handle_form_exceptions
@requires_auth
@account_type_required({at.TRUSTED_USER,
at.DEVELOPER,
at.TRUSTED_USER_AND_DEV})
async def accounts_post(request: Request,
@account_type_required({at.TRUSTED_USER, at.DEVELOPER, at.TRUSTED_USER_AND_DEV})
async def accounts_post(
request: Request,
O: int = Form(default=0), # Offset
SB: str = Form(default=str()), # Sort By
U: str = Form(default=str()), # Username
@ -509,7 +530,8 @@ async def accounts_post(request: Request,
E: str = Form(default=str()), # Email
R: str = Form(default=str()), # Real Name
I: str = Form(default=str()), # IRC Nick
K: str = Form(default=str())): # PGP Key
K: str = Form(default=str()),
): # PGP Key
context = await make_variable_context(request, "Accounts")
context["pp"] = pp = 50 # Hits per page.
@ -534,7 +556,7 @@ async def accounts_post(request: Request,
"u": at.USER_ID,
"t": at.TRUSTED_USER_ID,
"d": at.DEVELOPER_ID,
"td": at.TRUSTED_USER_AND_DEV_ID
"td": at.TRUSTED_USER_AND_DEV_ID,
}
account_type_id = account_types.get(T, None)
@ -545,7 +567,8 @@ async def accounts_post(request: Request,
# Populate this list with any additional statements to
# be ANDed together.
statements = [
v for k, v in [
v
for k, v in [
(account_type_id is not None, models.AccountType.ID == account_type_id),
(bool(U), models.User.Username.like(f"%{U}%")),
(bool(S), models.User.Suspended == S),
@ -553,7 +576,8 @@ async def accounts_post(request: Request,
(bool(R), models.User.RealName.like(f"%{R}%")),
(bool(I), models.User.IRCNick.like(f"%{I}%")),
(bool(K), models.User.PGPKey.like(f"%{K}%")),
] if k
]
if k
]
# Filter the query by coe-mbining all statements added above into
@ -571,9 +595,7 @@ async def accounts_post(request: Request,
return render_template(request, "account/index.html", context)
def render_terms_of_service(request: Request,
context: dict,
terms: typing.Iterable):
def render_terms_of_service(request: Request, context: dict, terms: typing.Iterable):
if not terms:
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)
context["unaccepted_terms"] = terms
@ -585,14 +607,21 @@ def render_terms_of_service(request: Request,
async def terms_of_service(request: Request):
# Query the database for terms that were previously accepted,
# but now have a bumped Revision that needs to be accepted.
diffs = db.query(models.Term).join(models.AcceptedTerm).filter(
models.AcceptedTerm.Revision < models.Term.Revision).all()
diffs = (
db.query(models.Term)
.join(models.AcceptedTerm)
.filter(models.AcceptedTerm.Revision < models.Term.Revision)
.all()
)
# Query the database for any terms that have not yet been accepted.
unaccepted = db.query(models.Term).filter(
~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID))).all()
unaccepted = (
db.query(models.Term)
.filter(~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID)))
.all()
)
for record in (diffs + unaccepted):
for record in diffs + unaccepted:
db.refresh(record)
# Translate the 'Terms of Service' part of our page title.
@ -607,16 +636,22 @@ async def terms_of_service(request: Request):
@router.post("/tos")
@handle_form_exceptions
@requires_auth
async def terms_of_service_post(request: Request,
accept: bool = Form(default=False)):
async def terms_of_service_post(request: Request, accept: bool = Form(default=False)):
# Query the database for terms that were previously accepted,
# but now have a bumped Revision that needs to be accepted.
diffs = db.query(models.Term).join(models.AcceptedTerm).filter(
models.AcceptedTerm.Revision < models.Term.Revision).all()
diffs = (
db.query(models.Term)
.join(models.AcceptedTerm)
.filter(models.AcceptedTerm.Revision < models.Term.Revision)
.all()
)
# Query the database for any terms that have not yet been accepted.
unaccepted = db.query(models.Term).filter(
~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID))).all()
unaccepted = (
db.query(models.Term)
.filter(~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID)))
.all()
)
if not accept:
# Translate the 'Terms of Service' part of our page title.
@ -628,7 +663,8 @@ async def terms_of_service_post(request: Request,
# them instead of reiterating the process in terms_of_service.
accept_needed = sorted(unaccepted + diffs)
return render_terms_of_service(
request, context, util.apply_all(accept_needed, db.refresh))
request, context, util.apply_all(accept_needed, db.refresh)
)
with db.begin():
# For each term we found, query for the matching accepted term
@ -636,13 +672,18 @@ async def terms_of_service_post(request: Request,
for term in diffs:
db.refresh(term)
accepted_term = request.user.accepted_terms.filter(
models.AcceptedTerm.TermsID == term.ID).first()
models.AcceptedTerm.TermsID == term.ID
).first()
accepted_term.Revision = term.Revision
# For each term that was never accepted, accept it!
for term in unaccepted:
db.refresh(term)
db.create(models.AcceptedTerm, User=request.user,
Term=term, Revision=term.Revision)
db.create(
models.AcceptedTerm,
User=request.user,
Term=term,
Revision=term.Revision,
)
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)

View file

@ -5,7 +5,6 @@ from fastapi.responses import HTMLResponse, RedirectResponse
from sqlalchemy import or_
import aurweb.config
from aurweb import cookies, db
from aurweb.auth import requires_auth, requires_guest
from aurweb.exceptions import handle_form_exceptions
@ -32,55 +31,73 @@ async def login_get(request: Request, next: str = "/"):
@router.post("/login", response_class=HTMLResponse)
@handle_form_exceptions
@requires_guest
async def login_post(request: Request,
async def login_post(
request: Request,
next: str = Form(...),
user: str = Form(default=str()),
passwd: str = Form(default=str()),
remember_me: bool = Form(default=False)):
remember_me: bool = Form(default=False),
):
# TODO: Once the Origin header gets broader adoption, this code can be
# slightly simplified to use it.
login_path = aurweb.config.get("options", "aur_location") + "/login"
referer = request.headers.get("Referer")
if not referer or not referer.startswith(login_path):
_ = get_translator_for_request(request)
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST,
detail=_("Bad Referer header."))
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, detail=_("Bad Referer header.")
)
with db.begin():
user = db.query(User).filter(
or_(User.Username == user, User.Email == user)
).first()
user = (
db.query(User)
.filter(or_(User.Username == user, User.Email == user))
.first()
)
if not user:
return await login_template(request, next,
errors=["Bad username or password."])
return await login_template(request, next, errors=["Bad username or password."])
if user.Suspended:
return await login_template(request, next,
errors=["Account Suspended"])
return await login_template(request, next, errors=["Account Suspended"])
cookie_timeout = cookies.timeout(remember_me)
sid = user.login(request, passwd, cookie_timeout)
if not sid:
return await login_template(request, next,
errors=["Bad username or password."])
return await login_template(request, next, errors=["Bad username or password."])
response = RedirectResponse(url=next,
status_code=HTTPStatus.SEE_OTHER)
response = RedirectResponse(url=next, status_code=HTTPStatus.SEE_OTHER)
secure = aurweb.config.getboolean("options", "disable_http_login")
response.set_cookie("AURSID", sid, max_age=cookie_timeout,
secure=secure, httponly=secure,
samesite=cookies.samesite())
response.set_cookie("AURTZ", user.Timezone,
secure=secure, httponly=secure,
samesite=cookies.samesite())
response.set_cookie("AURLANG", user.LangPreference,
secure=secure, httponly=secure,
samesite=cookies.samesite())
response.set_cookie("AURREMEMBER", remember_me,
secure=secure, httponly=secure,
samesite=cookies.samesite())
response.set_cookie(
"AURSID",
sid,
max_age=cookie_timeout,
secure=secure,
httponly=secure,
samesite=cookies.samesite(),
)
response.set_cookie(
"AURTZ",
user.Timezone,
secure=secure,
httponly=secure,
samesite=cookies.samesite(),
)
response.set_cookie(
"AURLANG",
user.LangPreference,
secure=secure,
httponly=secure,
samesite=cookies.samesite(),
)
response.set_cookie(
"AURREMEMBER",
remember_me,
secure=secure,
httponly=secure,
samesite=cookies.samesite(),
)
return response
@ -93,8 +110,7 @@ async def logout(request: Request, next: str = Form(default="/")):
# Use 303 since we may be handling a post request, that'll get it
# to redirect to a get request.
response = RedirectResponse(url=next,
status_code=HTTPStatus.SEE_OTHER)
response = RedirectResponse(url=next, status_code=HTTPStatus.SEE_OTHER)
response.delete_cookie("AURSID")
response.delete_cookie("AURTZ")
return response

View file

@ -2,17 +2,20 @@
decorators in some way; more complex routes should be defined in their
own modules and imported here. """
import os
from http import HTTPStatus
from fastapi import APIRouter, Form, HTTPException, Request, Response
from fastapi.responses import HTMLResponse, RedirectResponse
from prometheus_client import CONTENT_TYPE_LATEST, CollectorRegistry, generate_latest, multiprocess
from prometheus_client import (
CONTENT_TYPE_LATEST,
CollectorRegistry,
generate_latest,
multiprocess,
)
from sqlalchemy import and_, case, or_
import aurweb.config
import aurweb.models.package_request
from aurweb import cookies, db, logging, models, time, util
from aurweb.cache import db_count_cache
from aurweb.exceptions import handle_form_exceptions
@ -34,10 +37,12 @@ async def favicon(request: Request):
@router.post("/language", response_class=RedirectResponse)
@handle_form_exceptions
async def language(request: Request,
async def language(
request: Request,
set_lang: str = Form(...),
next: str = Form(...),
q: str = Form(default=None)):
q: str = Form(default=None),
):
"""
A POST route used to set a session's language.
@ -45,7 +50,7 @@ async def language(request: Request,
setting the language on any page, we want to preserve query
parameters across the redirect.
"""
if next[0] != '/':
if next[0] != "/":
return HTMLResponse(b"Invalid 'next' parameter.", status_code=400)
query_string = "?" + q if q else str()
@ -56,12 +61,13 @@ async def language(request: Request,
request.user.LangPreference = set_lang
# In any case, set the response's AURLANG cookie that never expires.
response = RedirectResponse(url=f"{next}{query_string}",
status_code=HTTPStatus.SEE_OTHER)
response = RedirectResponse(
url=f"{next}{query_string}", status_code=HTTPStatus.SEE_OTHER
)
secure = aurweb.config.getboolean("options", "disable_http_login")
response.set_cookie("AURLANG", set_lang,
secure=secure, httponly=secure,
samesite=cookies.samesite())
response.set_cookie(
"AURLANG", set_lang, secure=secure, httponly=secure, samesite=cookies.samesite()
)
return response
@ -69,7 +75,7 @@ async def language(request: Request,
async def index(request: Request):
"""Homepage route."""
context = make_context(request, "Home")
context['ssh_fingerprints'] = util.get_ssh_fingerprints()
context["ssh_fingerprints"] = util.get_ssh_fingerprints()
bases = db.query(models.PackageBase)
@ -79,24 +85,33 @@ async def index(request: Request):
# Package statistics.
query = bases.filter(models.PackageBase.PackagerUID.isnot(None))
context["package_count"] = await db_count_cache(
redis, "package_count", query, expire=cache_expire)
redis, "package_count", query, expire=cache_expire
)
query = bases.filter(
and_(models.PackageBase.MaintainerUID.is_(None),
models.PackageBase.PackagerUID.isnot(None))
and_(
models.PackageBase.MaintainerUID.is_(None),
models.PackageBase.PackagerUID.isnot(None),
)
)
context["orphan_count"] = await db_count_cache(
redis, "orphan_count", query, expire=cache_expire)
redis, "orphan_count", query, expire=cache_expire
)
query = db.query(models.User)
context["user_count"] = await db_count_cache(
redis, "user_count", query, expire=cache_expire)
redis, "user_count", query, expire=cache_expire
)
query = query.filter(
or_(models.User.AccountTypeID == TRUSTED_USER_ID,
models.User.AccountTypeID == TRUSTED_USER_AND_DEV_ID))
or_(
models.User.AccountTypeID == TRUSTED_USER_ID,
models.User.AccountTypeID == TRUSTED_USER_AND_DEV_ID,
)
)
context["trusted_user_count"] = await db_count_cache(
redis, "trusted_user_count", query, expire=cache_expire)
redis, "trusted_user_count", query, expire=cache_expire
)
# Current timestamp.
now = time.utcnow()
@ -106,31 +121,40 @@ async def index(request: Request):
one_hour = 3600
updated = bases.filter(
and_(models.PackageBase.ModifiedTS - models.PackageBase.SubmittedTS >= one_hour,
models.PackageBase.PackagerUID.isnot(None))
and_(
models.PackageBase.ModifiedTS - models.PackageBase.SubmittedTS >= one_hour,
models.PackageBase.PackagerUID.isnot(None),
)
)
query = bases.filter(
and_(models.PackageBase.SubmittedTS >= seven_days_ago,
models.PackageBase.PackagerUID.isnot(None))
and_(
models.PackageBase.SubmittedTS >= seven_days_ago,
models.PackageBase.PackagerUID.isnot(None),
)
)
context["seven_days_old_added"] = await db_count_cache(
redis, "seven_days_old_added", query, expire=cache_expire)
redis, "seven_days_old_added", query, expire=cache_expire
)
query = updated.filter(models.PackageBase.ModifiedTS >= seven_days_ago)
context["seven_days_old_updated"] = await db_count_cache(
redis, "seven_days_old_updated", query, expire=cache_expire)
redis, "seven_days_old_updated", query, expire=cache_expire
)
year = seven_days * 52 # Fifty two weeks worth: one year.
year_ago = now - year
query = updated.filter(models.PackageBase.ModifiedTS >= year_ago)
context["year_old_updated"] = await db_count_cache(
redis, "year_old_updated", query, expire=cache_expire)
redis, "year_old_updated", query, expire=cache_expire
)
query = bases.filter(
models.PackageBase.ModifiedTS - models.PackageBase.SubmittedTS < 3600)
models.PackageBase.ModifiedTS - models.PackageBase.SubmittedTS < 3600
)
context["never_updated"] = await db_count_cache(
redis, "never_updated", query, expire=cache_expire)
redis, "never_updated", query, expire=cache_expire
)
# Get the 15 most recently updated packages.
context["package_updates"] = updated_packages(15, cache_expire)
@ -140,78 +164,92 @@ async def index(request: Request):
# the dashboard display.
packages = db.query(models.Package).join(models.PackageBase)
maintained = packages.join(
maintained = (
packages.join(
models.PackageComaintainer,
models.PackageComaintainer.PackageBaseID == models.PackageBase.ID,
isouter=True
).join(
isouter=True,
)
.join(
models.User,
or_(models.PackageBase.MaintainerUID == models.User.ID,
models.PackageComaintainer.UsersID == models.User.ID)
).filter(
models.User.ID == request.user.ID
or_(
models.PackageBase.MaintainerUID == models.User.ID,
models.PackageComaintainer.UsersID == models.User.ID,
),
)
.filter(models.User.ID == request.user.ID)
)
# Packages maintained by the user that have been flagged.
context["flagged_packages"] = maintained.filter(
models.PackageBase.OutOfDateTS.isnot(None)
).order_by(
models.PackageBase.ModifiedTS.desc(), models.Package.Name.asc()
).limit(50).all()
context["flagged_packages"] = (
maintained.filter(models.PackageBase.OutOfDateTS.isnot(None))
.order_by(models.PackageBase.ModifiedTS.desc(), models.Package.Name.asc())
.limit(50)
.all()
)
# Flagged packages that request.user has voted for.
context["flagged_packages_voted"] = query_voted(
context.get("flagged_packages"), request.user)
context.get("flagged_packages"), request.user
)
# Flagged packages that request.user is being notified about.
context["flagged_packages_notified"] = query_notified(
context.get("flagged_packages"), request.user)
context.get("flagged_packages"), request.user
)
archive_time = aurweb.config.getint('options', 'request_archive_time')
archive_time = aurweb.config.getint("options", "request_archive_time")
start = now - archive_time
# Package requests created by request.user.
context["package_requests"] = request.user.package_requests.filter(
context["package_requests"] = (
request.user.package_requests.filter(
models.PackageRequest.RequestTS >= start
).order_by(
)
.order_by(
# Order primarily by the Status column being PENDING_ID,
# and secondarily by RequestTS; both in descending order.
case([(models.PackageRequest.Status == PENDING_ID, 1)],
else_=0).desc(),
models.PackageRequest.RequestTS.desc()
).limit(50).all()
case([(models.PackageRequest.Status == PENDING_ID, 1)], else_=0).desc(),
models.PackageRequest.RequestTS.desc(),
)
.limit(50)
.all()
)
# Packages that the request user maintains or comaintains.
context["packages"] = maintained.filter(
models.User.ID == models.PackageBase.MaintainerUID
).order_by(
models.PackageBase.ModifiedTS.desc(), models.Package.Name.desc()
).limit(50).all()
context["packages"] = (
maintained.filter(models.User.ID == models.PackageBase.MaintainerUID)
.order_by(models.PackageBase.ModifiedTS.desc(), models.Package.Name.desc())
.limit(50)
.all()
)
# Packages that request.user has voted for.
context["packages_voted"] = query_voted(
context.get("packages"), request.user)
context["packages_voted"] = query_voted(context.get("packages"), request.user)
# Packages that request.user is being notified about.
context["packages_notified"] = query_notified(
context.get("packages"), request.user)
context.get("packages"), request.user
)
# Any packages that the request user comaintains.
context["comaintained"] = packages.join(
models.PackageComaintainer
).filter(
models.PackageComaintainer.UsersID == request.user.ID
).order_by(
models.PackageBase.ModifiedTS.desc(), models.Package.Name.desc()
).limit(50).all()
context["comaintained"] = (
packages.join(models.PackageComaintainer)
.filter(models.PackageComaintainer.UsersID == request.user.ID)
.order_by(models.PackageBase.ModifiedTS.desc(), models.Package.Name.desc())
.limit(50)
.all()
)
# Comaintained packages that request.user has voted for.
context["comaintained_voted"] = query_voted(
context.get("comaintained"), request.user)
context.get("comaintained"), request.user
)
# Comaintained packages that request.user is being notified about.
context["comaintained_notified"] = query_notified(
context.get("comaintained"), request.user)
context.get("comaintained"), request.user
)
return render_template(request, "index.html", context)
@ -232,16 +270,15 @@ async def archive_sha256(request: Request, archive: str):
@router.get("/metrics")
async def metrics(request: Request):
if not os.environ.get("PROMETHEUS_MULTIPROC_DIR", None):
return Response("Prometheus metrics are not enabled.",
status_code=HTTPStatus.SERVICE_UNAVAILABLE)
return Response(
"Prometheus metrics are not enabled.",
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
data = generate_latest(registry)
headers = {
"Content-Type": CONTENT_TYPE_LATEST,
"Content-Length": str(len(data))
}
headers = {"Content-Type": CONTENT_TYPE_LATEST, "Content-Length": str(len(data))}
return Response(data, headers=headers)

View file

@ -5,7 +5,6 @@ from typing import Any
from fastapi import APIRouter, Form, Query, Request, Response
import aurweb.filters # noqa: F401
from aurweb import config, db, defaults, logging, models, util
from aurweb.auth import creds, requires_auth
from aurweb.exceptions import InvariantError, handle_form_exceptions
@ -13,23 +12,24 @@ from aurweb.models.relation_type import CONFLICTS_ID, PROVIDES_ID, REPLACES_ID
from aurweb.packages import util as pkgutil
from aurweb.packages.search import PackageSearch
from aurweb.packages.util import get_pkg_or_base
from aurweb.pkgbase import actions as pkgbase_actions
from aurweb.pkgbase import util as pkgbaseutil
from aurweb.pkgbase import actions as pkgbase_actions, util as pkgbaseutil
from aurweb.templates import make_context, make_variable_context, render_template
logger = logging.get_logger(__name__)
router = APIRouter()
async def packages_get(request: Request, context: dict[str, Any],
status_code: HTTPStatus = HTTPStatus.OK):
async def packages_get(
request: Request, context: dict[str, Any], status_code: HTTPStatus = HTTPStatus.OK
):
# Query parameters used in this request.
context["q"] = dict(request.query_params)
# Per page and offset.
offset, per_page = util.sanitize_params(
request.query_params.get("O", defaults.O),
request.query_params.get("PP", defaults.PP))
request.query_params.get("PP", defaults.PP),
)
context["O"] = offset
# Limit PP to options.max_search_results
@ -82,8 +82,7 @@ async def packages_get(request: Request, context: dict[str, Any],
if submit == "Orphans":
# If the user clicked the "Orphans" button, we only want
# orphaned packages.
search.query = search.query.filter(
models.PackageBase.MaintainerUID.is_(None))
search.query = search.query.filter(models.PackageBase.MaintainerUID.is_(None))
# Collect search result count here; we've applied our keywords.
# Including more query operations below, like ordering, will
@ -94,7 +93,9 @@ async def packages_get(request: Request, context: dict[str, Any],
search.sort_by(sort_by, sort_order)
# Insert search results into the context.
results = search.results().with_entities(
results = (
search.results()
.with_entities(
models.Package.ID,
models.Package.Name,
models.Package.PackageBaseID,
@ -105,15 +106,18 @@ async def packages_get(request: Request, context: dict[str, Any],
models.PackageBase.OutOfDateTS,
models.User.Username.label("Maintainer"),
models.PackageVote.PackageBaseID.label("Voted"),
models.PackageNotification.PackageBaseID.label("Notify")
).group_by(models.Package.Name)
models.PackageNotification.PackageBaseID.label("Notify"),
)
.group_by(models.Package.Name)
)
packages = results.limit(per_page).offset(offset)
context["packages"] = packages
context["packages_count"] = num_packages
return render_template(request, "packages/index.html", context,
status_code=status_code)
return render_template(
request, "packages/index.html", context, status_code=status_code
)
@router.get("/packages")
@ -123,9 +127,12 @@ async def packages(request: Request) -> Response:
@router.get("/packages/{name}")
async def package(request: Request, name: str,
async def package(
request: Request,
name: str,
all_deps: bool = Query(default=False),
all_reqs: bool = Query(default=False)) -> Response:
all_reqs: bool = Query(default=False),
) -> Response:
"""
Get a package by name.
@ -156,26 +163,21 @@ async def package(request: Request, name: str,
# Add our base information.
context = await pkgbaseutil.make_variable_context(request, pkgbase)
context.update(
{
"all_deps": all_deps,
"all_reqs": all_reqs
}
)
context.update({"all_deps": all_deps, "all_reqs": all_reqs})
context["package"] = pkg
# Package sources.
context["sources"] = pkg.package_sources.order_by(
models.PackageSource.Source.asc()).all()
models.PackageSource.Source.asc()
).all()
# Listing metadata.
context["max_listing"] = max_listing = 20
# Package dependencies.
deps = pkg.package_dependencies.order_by(
models.PackageDependency.DepTypeID.asc(),
models.PackageDependency.DepName.asc()
models.PackageDependency.DepTypeID.asc(), models.PackageDependency.DepName.asc()
)
context["depends_count"] = deps.count()
if not all_deps:
@ -183,8 +185,7 @@ async def package(request: Request, name: str,
context["dependencies"] = deps.all()
# Package requirements (other packages depend on this one).
reqs = pkgutil.pkg_required(
pkg.Name, [p.RelName for p in rels_data.get("p", [])])
reqs = pkgutil.pkg_required(pkg.Name, [p.RelName for p in rels_data.get("p", [])])
context["reqs_count"] = reqs.count()
if not all_reqs:
reqs = reqs.limit(max_listing)
@ -210,8 +211,7 @@ async def package(request: Request, name: str,
return render_template(request, "packages/show.html", context)
async def packages_unflag(request: Request, package_ids: list[int] = [],
**kwargs):
async def packages_unflag(request: Request, package_ids: list[int] = [], **kwargs):
if not package_ids:
return (False, ["You did not select any packages to unflag."])
@ -220,11 +220,11 @@ async def packages_unflag(request: Request, package_ids: list[int] = [],
bases = set()
package_ids = set(package_ids) # Convert this to a set for O(1).
packages = db.query(models.Package).filter(
models.Package.ID.in_(package_ids)).all()
packages = db.query(models.Package).filter(models.Package.ID.in_(package_ids)).all()
for pkg in packages:
has_cred = request.user.has_credential(
creds.PKGBASE_UNFLAG, approved=[pkg.PackageBase.Flagger])
creds.PKGBASE_UNFLAG, approved=[pkg.PackageBase.Flagger]
)
if not has_cred:
return (False, ["You did not select any packages to unflag."])
@ -236,20 +236,17 @@ async def packages_unflag(request: Request, package_ids: list[int] = [],
return (True, ["The selected packages have been unflagged."])
async def packages_notify(request: Request, package_ids: list[int] = [],
**kwargs):
async def packages_notify(request: Request, package_ids: list[int] = [], **kwargs):
# In cases where we encounter errors with the request, we'll
# use this error tuple as a return value.
# TODO: This error does not yet have a translation.
error_tuple = (False,
["You did not select any packages to be notified about."])
error_tuple = (False, ["You did not select any packages to be notified about."])
if not package_ids:
return error_tuple
bases = set()
package_ids = set(package_ids)
packages = db.query(models.Package).filter(
models.Package.ID.in_(package_ids)).all()
packages = db.query(models.Package).filter(models.Package.ID.in_(package_ids)).all()
for pkg in packages:
if pkg.PackageBase not in bases:
@ -257,9 +254,11 @@ async def packages_notify(request: Request, package_ids: list[int] = [],
# Perform some checks on what the user selected for notify.
for pkgbase in bases:
notif = db.query(pkgbase.notifications.filter(
notif = db.query(
pkgbase.notifications.filter(
models.PackageNotification.UserID == request.user.ID
).exists()).scalar()
).exists()
).scalar()
has_cred = request.user.has_credential(creds.PKGBASE_NOTIFY)
# If the request user either does not have credentials
@ -275,23 +274,20 @@ async def packages_notify(request: Request, package_ids: list[int] = [],
return (True, ["The selected packages' notifications have been enabled."])
async def packages_unnotify(request: Request, package_ids: list[int] = [],
**kwargs):
async def packages_unnotify(request: Request, package_ids: list[int] = [], **kwargs):
if not package_ids:
# TODO: This error does not yet have a translation.
return (False,
["You did not select any packages for notification removal."])
return (False, ["You did not select any packages for notification removal."])
# TODO: This error does not yet have a translation.
error_tuple = (
False,
["A package you selected does not have notifications enabled."]
["A package you selected does not have notifications enabled."],
)
bases = set()
package_ids = set(package_ids)
packages = db.query(models.Package).filter(
models.Package.ID.in_(package_ids)).all()
packages = db.query(models.Package).filter(models.Package.ID.in_(package_ids)).all()
for pkg in packages:
if pkg.PackageBase not in bases:
@ -299,9 +295,11 @@ async def packages_unnotify(request: Request, package_ids: list[int] = [],
# Perform some checks on what the user selected for notify.
for pkgbase in bases:
notif = db.query(pkgbase.notifications.filter(
notif = db.query(
pkgbase.notifications.filter(
models.PackageNotification.UserID == request.user.ID
).exists()).scalar()
).exists()
).scalar()
if not notif:
return error_tuple
@ -312,19 +310,24 @@ async def packages_unnotify(request: Request, package_ids: list[int] = [],
return (True, ["The selected packages' notifications have been removed."])
async def packages_adopt(request: Request, package_ids: list[int] = [],
confirm: bool = False, **kwargs):
async def packages_adopt(
request: Request, package_ids: list[int] = [], confirm: bool = False, **kwargs
):
if not package_ids:
return (False, ["You did not select any packages to adopt."])
if not confirm:
return (False, ["The selected packages have not been adopted, "
"check the confirmation checkbox."])
return (
False,
[
"The selected packages have not been adopted, "
"check the confirmation checkbox."
],
)
bases = set()
package_ids = set(package_ids)
packages = db.query(models.Package).filter(
models.Package.ID.in_(package_ids)).all()
packages = db.query(models.Package).filter(models.Package.ID.in_(package_ids)).all()
for pkg in packages:
if pkg.PackageBase not in bases:
@ -335,8 +338,10 @@ async def packages_adopt(request: Request, package_ids: list[int] = [],
has_cred = request.user.has_credential(creds.PKGBASE_ADOPT)
if not (has_cred or not pkgbase.Maintainer):
# TODO: This error needs to be translated.
return (False, ["You are not allowed to adopt one of the "
"packages you selected."])
return (
False,
["You are not allowed to adopt one of the " "packages you selected."],
)
# Now, really adopt the bases.
for pkgbase in bases:
@ -345,8 +350,7 @@ async def packages_adopt(request: Request, package_ids: list[int] = [],
return (True, ["The selected packages have been adopted."])
def disown_all(request: Request, pkgbases: list[models.PackageBase]) \
-> list[str]:
def disown_all(request: Request, pkgbases: list[models.PackageBase]) -> list[str]:
errors = []
for pkgbase in pkgbases:
try:
@ -356,19 +360,24 @@ def disown_all(request: Request, pkgbases: list[models.PackageBase]) \
return errors
async def packages_disown(request: Request, package_ids: list[int] = [],
confirm: bool = False, **kwargs):
async def packages_disown(
request: Request, package_ids: list[int] = [], confirm: bool = False, **kwargs
):
if not package_ids:
return (False, ["You did not select any packages to disown."])
if not confirm:
return (False, ["The selected packages have not been disowned, "
"check the confirmation checkbox."])
return (
False,
[
"The selected packages have not been disowned, "
"check the confirmation checkbox."
],
)
bases = set()
package_ids = set(package_ids)
packages = db.query(models.Package).filter(
models.Package.ID.in_(package_ids)).all()
packages = db.query(models.Package).filter(models.Package.ID.in_(package_ids)).all()
for pkg in packages:
if pkg.PackageBase not in bases:
@ -376,12 +385,15 @@ async def packages_disown(request: Request, package_ids: list[int] = [],
# Check that the user has credentials for every package they selected.
for pkgbase in bases:
has_cred = request.user.has_credential(creds.PKGBASE_DISOWN,
approved=[pkgbase.Maintainer])
has_cred = request.user.has_credential(
creds.PKGBASE_DISOWN, approved=[pkgbase.Maintainer]
)
if not has_cred:
# TODO: This error needs to be translated.
return (False, ["You are not allowed to disown one "
"of the packages you selected."])
return (
False,
["You are not allowed to disown one " "of the packages you selected."],
)
# Now, disown all the bases if we can.
if errors := disown_all(request, bases):
@ -390,23 +402,31 @@ async def packages_disown(request: Request, package_ids: list[int] = [],
return (True, ["The selected packages have been disowned."])
async def packages_delete(request: Request, package_ids: list[int] = [],
confirm: bool = False, merge_into: str = str(),
**kwargs):
async def packages_delete(
request: Request,
package_ids: list[int] = [],
confirm: bool = False,
merge_into: str = str(),
**kwargs,
):
if not package_ids:
return (False, ["You did not select any packages to delete."])
if not confirm:
return (False, ["The selected packages have not been deleted, "
"check the confirmation checkbox."])
return (
False,
[
"The selected packages have not been deleted, "
"check the confirmation checkbox."
],
)
if not request.user.has_credential(creds.PKGBASE_DELETE):
return (False, ["You do not have permission to delete packages."])
# set-ify package_ids and query the database for related records.
package_ids = set(package_ids)
packages = db.query(models.Package).filter(
models.Package.ID.in_(package_ids)).all()
packages = db.query(models.Package).filter(models.Package.ID.in_(package_ids)).all()
if len(packages) != len(package_ids):
# Let the user know there was an issue with their input: they have
@ -422,12 +442,15 @@ async def packages_delete(request: Request, package_ids: list[int] = [],
notifs += pkgbase_actions.pkgbase_delete_instance(request, pkgbase)
# Log out the fact that this happened for accountability.
logger.info(f"Privileged user '{request.user.Username}' deleted the "
f"following package bases: {str(deleted_bases)}.")
logger.info(
f"Privileged user '{request.user.Username}' deleted the "
f"following package bases: {str(deleted_bases)}."
)
util.apply_all(notifs, lambda n: n.send())
return (True, ["The selected packages have been deleted."])
# A mapping of action string -> callback functions used within the
# `packages_post` route below. We expect any action callback to
# return a tuple in the format: (succeeded: bool, message: list[str]).
@ -444,10 +467,12 @@ PACKAGE_ACTIONS = {
@router.post("/packages")
@handle_form_exceptions
@requires_auth
async def packages_post(request: Request,
async def packages_post(
request: Request,
IDs: list[int] = Form(default=[]),
action: str = Form(default=str()),
confirm: bool = Form(default=False)):
confirm: bool = Form(default=False),
):
# If an invalid action is specified, just render GET /packages
# with an BAD_REQUEST status_code.

View file

@ -16,9 +16,7 @@ from aurweb.models.package_vote import PackageVote
from aurweb.models.request_type import DELETION_ID, MERGE_ID, ORPHAN_ID
from aurweb.packages.requests import update_closure_comment
from aurweb.packages.util import get_pkg_or_base, get_pkgbase_comment
from aurweb.pkgbase import actions
from aurweb.pkgbase import util as pkgbaseutil
from aurweb.pkgbase import validate
from aurweb.pkgbase import actions, util as pkgbaseutil, validate
from aurweb.scripts import notify, popupdate
from aurweb.scripts.rendercomment import update_comment_render_fastapi
from aurweb.templates import make_variable_context, render_template
@ -44,8 +42,9 @@ async def pkgbase(request: Request, name: str) -> Response:
packages = pkgbase.packages.all()
pkg = packages[0]
if len(packages) == 1 and pkg.Name == pkgbase.Name:
return RedirectResponse(f"/packages/{pkg.Name}",
status_code=int(HTTPStatus.SEE_OTHER))
return RedirectResponse(
f"/packages/{pkg.Name}", status_code=int(HTTPStatus.SEE_OTHER)
)
# Add our base information.
context = pkgbaseutil.make_context(request, pkgbase)
@ -69,8 +68,7 @@ async def pkgbase_voters(request: Request, name: str) -> Response:
pkgbase = get_pkg_or_base(name, PackageBase)
if not request.user.has_credential(creds.PKGBASE_LIST_VOTERS):
return RedirectResponse(f"/pkgbase/{name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
context = templates.make_context(request, "Voters")
context["pkgbase"] = pkgbase
@ -82,8 +80,7 @@ async def pkgbase_flag_comment(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase)
if pkgbase.OutOfDateTS is None:
return RedirectResponse(f"/pkgbase/{name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
context = templates.make_context(request, "Flag Comment")
context["pkgbase"] = pkgbase
@ -92,13 +89,15 @@ async def pkgbase_flag_comment(request: Request, name: str):
@router.post("/pkgbase/{name}/keywords")
@handle_form_exceptions
async def pkgbase_keywords(request: Request, name: str,
keywords: str = Form(default=str())):
async def pkgbase_keywords(
request: Request, name: str, keywords: str = Form(default=str())
):
pkgbase = get_pkg_or_base(name, PackageBase)
approved = [pkgbase.Maintainer] + [c.User for c in pkgbase.comaintainers]
has_cred = creds.has_credential(request.user, creds.PKGBASE_SET_KEYWORDS,
approved=approved)
has_cred = creds.has_credential(
request.user, creds.PKGBASE_SET_KEYWORDS, approved=approved
)
if not has_cred:
return Response(status_code=HTTPStatus.UNAUTHORIZED)
@ -108,15 +107,14 @@ async def pkgbase_keywords(request: Request, name: str,
# Delete all keywords which are not supplied by the user.
with db.begin():
other_keywords = pkgbase.keywords.filter(
~PackageKeyword.Keyword.in_(keywords))
other_keyword_strings = set(
kwd.Keyword.lower() for kwd in other_keywords)
other_keywords = pkgbase.keywords.filter(~PackageKeyword.Keyword.in_(keywords))
other_keyword_strings = set(kwd.Keyword.lower() for kwd in other_keywords)
existing_keywords = set(
kwd.Keyword.lower() for kwd in
pkgbase.keywords.filter(
~PackageKeyword.Keyword.in_(other_keyword_strings))
kwd.Keyword.lower()
for kwd in pkgbase.keywords.filter(
~PackageKeyword.Keyword.in_(other_keyword_strings)
)
)
db.delete_all(other_keywords)
@ -124,8 +122,7 @@ async def pkgbase_keywords(request: Request, name: str,
for keyword in new_keywords:
db.create(PackageKeyword, PackageBase=pkgbase, Keyword=keyword)
return RedirectResponse(f"/pkgbase/{name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
@router.get("/pkgbase/{name}/flag")
@ -135,8 +132,7 @@ async def pkgbase_flag_get(request: Request, name: str):
has_cred = request.user.has_credential(creds.PKGBASE_FLAG)
if not has_cred or pkgbase.OutOfDateTS is not None:
return RedirectResponse(f"/pkgbase/{name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
context = templates.make_context(request, "Flag Package Out-Of-Date")
context["pkgbase"] = pkgbase
@ -146,17 +142,20 @@ async def pkgbase_flag_get(request: Request, name: str):
@router.post("/pkgbase/{name}/flag")
@handle_form_exceptions
@requires_auth
async def pkgbase_flag_post(request: Request, name: str,
comments: str = Form(default=str())):
async def pkgbase_flag_post(
request: Request, name: str, comments: str = Form(default=str())
):
pkgbase = get_pkg_or_base(name, PackageBase)
if not comments:
context = templates.make_context(request, "Flag Package Out-Of-Date")
context["pkgbase"] = pkgbase
context["errors"] = ["The selected packages have not been flagged, "
"please enter a comment."]
return render_template(request, "pkgbase/flag.html", context,
status_code=HTTPStatus.BAD_REQUEST)
context["errors"] = [
"The selected packages have not been flagged, " "please enter a comment."
]
return render_template(
request, "pkgbase/flag.html", context, status_code=HTTPStatus.BAD_REQUEST
)
has_cred = request.user.has_credential(creds.PKGBASE_FLAG)
if has_cred and not pkgbase.OutOfDateTS:
@ -168,17 +167,18 @@ async def pkgbase_flag_post(request: Request, name: str,
notify.FlagNotification(request.user.ID, pkgbase.ID).send()
return RedirectResponse(f"/pkgbase/{name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
@router.post("/pkgbase/{name}/comments")
@handle_form_exceptions
@requires_auth
async def pkgbase_comments_post(
request: Request, name: str,
request: Request,
name: str,
comment: str = Form(default=str()),
enable_notifications: bool = Form(default=False)):
enable_notifications: bool = Form(default=False),
):
"""Add a new comment via POST request."""
pkgbase = get_pkg_or_base(name, PackageBase)
@ -189,29 +189,34 @@ async def pkgbase_comments_post(
# update the db record.
now = time.utcnow()
with db.begin():
comment = db.create(PackageComment, User=request.user,
comment = db.create(
PackageComment,
User=request.user,
PackageBase=pkgbase,
Comments=comment, RenderedComment=str(),
CommentTS=now)
Comments=comment,
RenderedComment=str(),
CommentTS=now,
)
if enable_notifications and not request.user.notified(pkgbase):
db.create(PackageNotification,
User=request.user,
PackageBase=pkgbase)
db.create(PackageNotification, User=request.user, PackageBase=pkgbase)
update_comment_render_fastapi(comment)
notif = notify.CommentNotification(request.user.ID, pkgbase.ID, comment.ID)
notif.send()
# Redirect to the pkgbase page.
return RedirectResponse(f"/pkgbase/{pkgbase.Name}#comment-{comment.ID}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(
f"/pkgbase/{pkgbase.Name}#comment-{comment.ID}",
status_code=HTTPStatus.SEE_OTHER,
)
@router.get("/pkgbase/{name}/comments/{id}/form")
@requires_auth
async def pkgbase_comment_form(request: Request, name: str, id: int,
next: str = Query(default=None)):
async def pkgbase_comment_form(
request: Request, name: str, id: int, next: str = Query(default=None)
):
"""
Produce a comment form for comment {id}.
@ -244,14 +249,16 @@ async def pkgbase_comment_form(request: Request, name: str, id: int,
context["next"] = next
form = templates.render_raw_template(
request, "partials/packages/comment_form.html", context)
request, "partials/packages/comment_form.html", context
)
return JSONResponse({"form": form})
@router.get("/pkgbase/{name}/comments/{id}/edit")
@requires_auth
async def pkgbase_comment_edit(request: Request, name: str, id: int,
next: str = Form(default=None)):
async def pkgbase_comment_edit(
request: Request, name: str, id: int, next: str = Form(default=None)
):
"""
Render the non-javascript edit form.
@ -276,10 +283,13 @@ async def pkgbase_comment_edit(request: Request, name: str, id: int,
@handle_form_exceptions
@requires_auth
async def pkgbase_comment_post(
request: Request, name: str, id: int,
request: Request,
name: str,
id: int,
comment: str = Form(default=str()),
enable_notifications: bool = Form(default=False),
next: str = Form(default=None)):
next: str = Form(default=None),
):
"""Edit an existing comment."""
pkgbase = get_pkg_or_base(name, PackageBase)
db_comment = get_pkgbase_comment(pkgbase, id)
@ -302,24 +312,24 @@ async def pkgbase_comment_post(
PackageNotification.PackageBaseID == pkgbase.ID
).first()
if enable_notifications and not db_notif:
db.create(PackageNotification,
User=request.user,
PackageBase=pkgbase)
db.create(PackageNotification, User=request.user, PackageBase=pkgbase)
update_comment_render_fastapi(db_comment)
if not next:
next = f"/pkgbase/{pkgbase.Name}"
# Redirect to the pkgbase page anchored to the updated comment.
return RedirectResponse(f"{next}#comment-{db_comment.ID}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(
f"{next}#comment-{db_comment.ID}", status_code=HTTPStatus.SEE_OTHER
)
@router.post("/pkgbase/{name}/comments/{id}/pin")
@handle_form_exceptions
@requires_auth
async def pkgbase_comment_pin(request: Request, name: str, id: int,
next: str = Form(default=None)):
async def pkgbase_comment_pin(
request: Request, name: str, id: int, next: str = Form(default=None)
):
"""
Pin a comment.
@ -332,13 +342,15 @@ async def pkgbase_comment_pin(request: Request, name: str, id: int,
pkgbase = get_pkg_or_base(name, PackageBase)
comment = get_pkgbase_comment(pkgbase, id)
has_cred = request.user.has_credential(creds.COMMENT_PIN,
approved=comment.maintainers())
has_cred = request.user.has_credential(
creds.COMMENT_PIN, approved=comment.maintainers()
)
if not has_cred:
_ = l10n.get_translator_for_request(request)
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail=_("You are not allowed to pin this comment."))
detail=_("You are not allowed to pin this comment."),
)
now = time.utcnow()
with db.begin():
@ -353,8 +365,9 @@ async def pkgbase_comment_pin(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/comments/{id}/unpin")
@handle_form_exceptions
@requires_auth
async def pkgbase_comment_unpin(request: Request, name: str, id: int,
next: str = Form(default=None)):
async def pkgbase_comment_unpin(
request: Request, name: str, id: int, next: str = Form(default=None)
):
"""
Unpin a comment.
@ -367,13 +380,15 @@ async def pkgbase_comment_unpin(request: Request, name: str, id: int,
pkgbase = get_pkg_or_base(name, PackageBase)
comment = get_pkgbase_comment(pkgbase, id)
has_cred = request.user.has_credential(creds.COMMENT_PIN,
approved=comment.maintainers())
has_cred = request.user.has_credential(
creds.COMMENT_PIN, approved=comment.maintainers()
)
if not has_cred:
_ = l10n.get_translator_for_request(request)
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail=_("You are not allowed to unpin this comment."))
detail=_("You are not allowed to unpin this comment."),
)
with db.begin():
comment.PinnedTS = 0
@ -387,8 +402,9 @@ async def pkgbase_comment_unpin(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/comments/{id}/delete")
@handle_form_exceptions
@requires_auth
async def pkgbase_comment_delete(request: Request, name: str, id: int,
next: str = Form(default=None)):
async def pkgbase_comment_delete(
request: Request, name: str, id: int, next: str = Form(default=None)
):
"""
Delete a comment.
@ -405,13 +421,13 @@ async def pkgbase_comment_delete(request: Request, name: str, id: int,
pkgbase = get_pkg_or_base(name, PackageBase)
comment = get_pkgbase_comment(pkgbase, id)
authorized = request.user.has_credential(creds.COMMENT_DELETE,
[comment.User])
authorized = request.user.has_credential(creds.COMMENT_DELETE, [comment.User])
if not authorized:
_ = l10n.get_translator_for_request(request)
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail=_("You are not allowed to delete this comment."))
detail=_("You are not allowed to delete this comment."),
)
now = time.utcnow()
with db.begin():
@ -427,8 +443,9 @@ async def pkgbase_comment_delete(request: Request, name: str, id: int,
@router.post("/pkgbase/{name}/comments/{id}/undelete")
@handle_form_exceptions
@requires_auth
async def pkgbase_comment_undelete(request: Request, name: str, id: int,
next: str = Form(default=None)):
async def pkgbase_comment_undelete(
request: Request, name: str, id: int, next: str = Form(default=None)
):
"""
Undelete a comment.
@ -445,13 +462,15 @@ async def pkgbase_comment_undelete(request: Request, name: str, id: int,
pkgbase = get_pkg_or_base(name, PackageBase)
comment = get_pkgbase_comment(pkgbase, id)
has_cred = request.user.has_credential(creds.COMMENT_UNDELETE,
approved=[comment.User])
has_cred = request.user.has_credential(
creds.COMMENT_UNDELETE, approved=[comment.User]
)
if not has_cred:
_ = l10n.get_translator_for_request(request)
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail=_("You are not allowed to undelete this comment."))
detail=_("You are not allowed to undelete this comment."),
)
with db.begin():
comment.Deleter = None
@ -469,23 +488,17 @@ async def pkgbase_comment_undelete(request: Request, name: str, id: int,
async def pkgbase_vote(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase)
vote = pkgbase.package_votes.filter(
PackageVote.UsersID == request.user.ID
).first()
vote = pkgbase.package_votes.filter(PackageVote.UsersID == request.user.ID).first()
has_cred = request.user.has_credential(creds.PKGBASE_VOTE)
if has_cred and not vote:
now = time.utcnow()
with db.begin():
db.create(PackageVote,
User=request.user,
PackageBase=pkgbase,
VoteTS=now)
db.create(PackageVote, User=request.user, PackageBase=pkgbase, VoteTS=now)
# Update NumVotes/Popularity.
popupdate.run_single(pkgbase)
return RedirectResponse(f"/pkgbase/{name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
@router.post("/pkgbase/{name}/unvote")
@ -494,9 +507,7 @@ async def pkgbase_vote(request: Request, name: str):
async def pkgbase_unvote(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase)
vote = pkgbase.package_votes.filter(
PackageVote.UsersID == request.user.ID
).first()
vote = pkgbase.package_votes.filter(PackageVote.UsersID == request.user.ID).first()
has_cred = request.user.has_credential(creds.PKGBASE_VOTE)
if has_cred and vote:
with db.begin():
@ -505,8 +516,7 @@ async def pkgbase_unvote(request: Request, name: str):
# Update NumVotes/Popularity.
popupdate.run_single(pkgbase)
return RedirectResponse(f"/pkgbase/{name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
@router.post("/pkgbase/{name}/notify")
@ -515,8 +525,7 @@ async def pkgbase_unvote(request: Request, name: str):
async def pkgbase_notify(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase)
actions.pkgbase_notify_instance(request, pkgbase)
return RedirectResponse(f"/pkgbase/{name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
@router.post("/pkgbase/{name}/unnotify")
@ -525,8 +534,7 @@ async def pkgbase_notify(request: Request, name: str):
async def pkgbase_unnotify(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase)
actions.pkgbase_unnotify_instance(request, pkgbase)
return RedirectResponse(f"/pkgbase/{name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
@router.post("/pkgbase/{name}/unflag")
@ -535,20 +543,19 @@ async def pkgbase_unnotify(request: Request, name: str):
async def pkgbase_unflag(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase)
actions.pkgbase_unflag_instance(request, pkgbase)
return RedirectResponse(f"/pkgbase/{name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
@router.get("/pkgbase/{name}/disown")
@requires_auth
async def pkgbase_disown_get(request: Request, name: str,
next: str = Query(default=str())):
async def pkgbase_disown_get(
request: Request, name: str, next: str = Query(default=str())
):
pkgbase = get_pkg_or_base(name, PackageBase)
comaints = {c.User for c in pkgbase.comaintainers}
approved = [pkgbase.Maintainer] + list(comaints)
has_cred = request.user.has_credential(creds.PKGBASE_DISOWN,
approved=approved)
has_cred = request.user.has_credential(creds.PKGBASE_DISOWN, approved=approved)
if not has_cred:
return RedirectResponse(f"/pkgbase/{name}", HTTPStatus.SEE_OTHER)
@ -563,27 +570,33 @@ async def pkgbase_disown_get(request: Request, name: str,
@router.post("/pkgbase/{name}/disown")
@handle_form_exceptions
@requires_auth
async def pkgbase_disown_post(request: Request, name: str,
async def pkgbase_disown_post(
request: Request,
name: str,
comments: str = Form(default=str()),
confirm: bool = Form(default=False),
next: str = Form(default=str())):
next: str = Form(default=str()),
):
pkgbase = get_pkg_or_base(name, PackageBase)
comaints = {c.User for c in pkgbase.comaintainers}
approved = [pkgbase.Maintainer] + list(comaints)
has_cred = request.user.has_credential(creds.PKGBASE_DISOWN,
approved=approved)
has_cred = request.user.has_credential(creds.PKGBASE_DISOWN, approved=approved)
if not has_cred:
return RedirectResponse(f"/pkgbase/{name}",
HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", HTTPStatus.SEE_OTHER)
context = templates.make_context(request, "Disown Package")
context["pkgbase"] = pkgbase
if not confirm:
context["errors"] = [("The selected packages have not been disowned, "
"check the confirmation checkbox.")]
return render_template(request, "pkgbase/disown.html", context,
status_code=HTTPStatus.BAD_REQUEST)
context["errors"] = [
(
"The selected packages have not been disowned, "
"check the confirmation checkbox."
)
]
return render_template(
request, "pkgbase/disown.html", context, status_code=HTTPStatus.BAD_REQUEST
)
if request.user != pkgbase.Maintainer and request.user not in comaints:
with db.begin():
@ -593,8 +606,9 @@ async def pkgbase_disown_post(request: Request, name: str,
actions.pkgbase_disown_instance(request, pkgbase)
except InvariantError as exc:
context["errors"] = [str(exc)]
return render_template(request, "pkgbase/disown.html", context,
status_code=HTTPStatus.BAD_REQUEST)
return render_template(
request, "pkgbase/disown.html", context, status_code=HTTPStatus.BAD_REQUEST
)
if not next:
next = f"/pkgbase/{name}"
@ -615,8 +629,7 @@ async def pkgbase_adopt_post(request: Request, name: str):
# if no maintainer currently exists.
actions.pkgbase_adopt_instance(request, pkgbase)
return RedirectResponse(f"/pkgbase/{name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
@router.get("/pkgbase/{name}/comaintainers")
@ -627,20 +640,20 @@ async def pkgbase_comaintainers(request: Request, name: str) -> Response:
# Unauthorized users (Non-TU/Dev and not the pkgbase maintainer)
# get redirected to the package base's page.
has_creds = request.user.has_credential(creds.PKGBASE_EDIT_COMAINTAINERS,
approved=[pkgbase.Maintainer])
has_creds = request.user.has_credential(
creds.PKGBASE_EDIT_COMAINTAINERS, approved=[pkgbase.Maintainer]
)
if not has_creds:
return RedirectResponse(f"/pkgbase/{name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
# Add our base information.
context = templates.make_context(request, "Manage Co-maintainers")
context.update({
context.update(
{
"pkgbase": pkgbase,
"comaintainers": [
c.User.Username for c in pkgbase.comaintainers
]
})
"comaintainers": [c.User.Username for c in pkgbase.comaintainers],
}
)
return render_template(request, "pkgbase/comaintainers.html", context)
@ -648,50 +661,52 @@ async def pkgbase_comaintainers(request: Request, name: str) -> Response:
@router.post("/pkgbase/{name}/comaintainers")
@handle_form_exceptions
@requires_auth
async def pkgbase_comaintainers_post(request: Request, name: str,
users: str = Form(default=str())) \
-> Response:
async def pkgbase_comaintainers_post(
request: Request, name: str, users: str = Form(default=str())
) -> Response:
# Get the PackageBase.
pkgbase = get_pkg_or_base(name, PackageBase)
# Unauthorized users (Non-TU/Dev and not the pkgbase maintainer)
# get redirected to the package base's page.
has_creds = request.user.has_credential(creds.PKGBASE_EDIT_COMAINTAINERS,
approved=[pkgbase.Maintainer])
has_creds = request.user.has_credential(
creds.PKGBASE_EDIT_COMAINTAINERS, approved=[pkgbase.Maintainer]
)
if not has_creds:
return RedirectResponse(f"/pkgbase/{name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
users = {e.strip() for e in users.split("\n") if bool(e.strip())}
records = {c.User.Username for c in pkgbase.comaintainers}
users_to_rm = records.difference(users)
pkgbaseutil.remove_comaintainers(pkgbase, users_to_rm)
logger.debug(f"{request.user} removed comaintainers from "
f"{pkgbase.Name}: {users_to_rm}")
logger.debug(
f"{request.user} removed comaintainers from " f"{pkgbase.Name}: {users_to_rm}"
)
users_to_add = users.difference(records)
error = pkgbaseutil.add_comaintainers(request, pkgbase, users_to_add)
if error:
context = templates.make_context(request, "Manage Co-maintainers")
context["pkgbase"] = pkgbase
context["comaintainers"] = [
c.User.Username for c in pkgbase.comaintainers
]
context["comaintainers"] = [c.User.Username for c in pkgbase.comaintainers]
context["errors"] = [error]
return render_template(request, "pkgbase/comaintainers.html", context)
logger.debug(f"{request.user} added comaintainers to "
f"{pkgbase.Name}: {users_to_add}")
logger.debug(
f"{request.user} added comaintainers to " f"{pkgbase.Name}: {users_to_add}"
)
return RedirectResponse(f"/pkgbase/{pkgbase.Name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(
f"/pkgbase/{pkgbase.Name}", status_code=HTTPStatus.SEE_OTHER
)
@router.get("/pkgbase/{name}/request")
@requires_auth
async def pkgbase_request(request: Request, name: str,
next: str = Query(default=str())):
async def pkgbase_request(
request: Request, name: str, next: str = Query(default=str())
):
pkgbase = get_pkg_or_base(name, PackageBase)
context = await make_variable_context(request, "Submit Request")
context["pkgbase"] = pkgbase
@ -702,28 +717,28 @@ async def pkgbase_request(request: Request, name: str,
@router.post("/pkgbase/{name}/request")
@handle_form_exceptions
@requires_auth
async def pkgbase_request_post(request: Request, name: str,
async def pkgbase_request_post(
request: Request,
name: str,
type: str = Form(...),
merge_into: str = Form(default=None),
comments: str = Form(default=str()),
next: str = Form(default=str())):
next: str = Form(default=str()),
):
pkgbase = get_pkg_or_base(name, PackageBase)
# Create our render context.
context = await make_variable_context(request, "Submit Request")
context["pkgbase"] = pkgbase
types = {
"deletion": DELETION_ID,
"merge": MERGE_ID,
"orphan": ORPHAN_ID
}
types = {"deletion": DELETION_ID, "merge": MERGE_ID, "orphan": ORPHAN_ID}
if type not in types:
# In the case that someone crafted a POST request with an invalid
# type, just return them to the request form with BAD_REQUEST status.
return render_template(request, "pkgbase/request.html", context,
status_code=HTTPStatus.BAD_REQUEST)
return render_template(
request, "pkgbase/request.html", context, status_code=HTTPStatus.BAD_REQUEST
)
try:
validate.request(pkgbase, type, comments, merge_into, context)
@ -735,7 +750,8 @@ async def pkgbase_request_post(request: Request, name: str,
# All good. Create a new PackageRequest based on the given type.
now = time.utcnow()
with db.begin():
pkgreq = db.create(PackageRequest,
pkgreq = db.create(
PackageRequest,
ReqTypeID=types.get(type),
User=request.user,
RequestTS=now,
@ -743,12 +759,17 @@ async def pkgbase_request_post(request: Request, name: str,
PackageBaseName=pkgbase.Name,
MergeBaseName=merge_into,
Comments=comments,
ClosureComment=str())
ClosureComment=str(),
)
# Prepare notification object.
notif = notify.RequestOpenNotification(
request.user.ID, pkgreq.ID, type,
pkgreq.PackageBase.ID, merge_into=merge_into or None)
request.user.ID,
pkgreq.ID,
type,
pkgreq.PackageBase.ID,
merge_into=merge_into or None,
)
# Send the notification now that we're out of the DB scope.
notif.send()
@ -767,13 +788,13 @@ async def pkgbase_request_post(request: Request, name: str,
pkgbase.Maintainer = None
pkgreq.Status = ACCEPTED_ID
notif = notify.RequestCloseNotification(
request.user.ID, pkgreq.ID, pkgreq.status_display())
request.user.ID, pkgreq.ID, pkgreq.status_display()
)
notif.send()
logger.debug(f"New request #{pkgreq.ID} is marked for auto-orphan.")
elif type == "deletion" and is_maintainer and outdated:
# This request should be auto-accepted.
notifs = actions.pkgbase_delete_instance(
request, pkgbase, comments=comments)
notifs = actions.pkgbase_delete_instance(request, pkgbase, comments=comments)
util.apply_all(notifs, lambda n: n.send())
logger.debug(f"New request #{pkgreq.ID} is marked for auto-deletion.")
@ -783,11 +804,11 @@ async def pkgbase_request_post(request: Request, name: str,
@router.get("/pkgbase/{name}/delete")
@requires_auth
async def pkgbase_delete_get(request: Request, name: str,
next: str = Query(default=str())):
async def pkgbase_delete_get(
request: Request, name: str, next: str = Query(default=str())
):
if not request.user.has_credential(creds.PKGBASE_DELETE):
return RedirectResponse(f"/pkgbase/{name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
context = templates.make_context(request, "Package Deletion")
context["pkgbase"] = get_pkg_or_base(name, PackageBase)
@ -798,53 +819,60 @@ async def pkgbase_delete_get(request: Request, name: str,
@router.post("/pkgbase/{name}/delete")
@handle_form_exceptions
@requires_auth
async def pkgbase_delete_post(request: Request, name: str,
async def pkgbase_delete_post(
request: Request,
name: str,
confirm: bool = Form(default=False),
comments: str = Form(default=str()),
next: str = Form(default="/packages")):
next: str = Form(default="/packages"),
):
pkgbase = get_pkg_or_base(name, PackageBase)
if not request.user.has_credential(creds.PKGBASE_DELETE):
return RedirectResponse(f"/pkgbase/{name}",
status_code=HTTPStatus.SEE_OTHER)
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
if not confirm:
context = templates.make_context(request, "Package Deletion")
context["pkgbase"] = pkgbase
context["errors"] = [("The selected packages have not been deleted, "
"check the confirmation checkbox.")]
return render_template(request, "pkgbase/delete.html", context,
status_code=HTTPStatus.BAD_REQUEST)
context["errors"] = [
(
"The selected packages have not been deleted, "
"check the confirmation checkbox."
)
]
return render_template(
request, "pkgbase/delete.html", context, status_code=HTTPStatus.BAD_REQUEST
)
if comments:
# Update any existing deletion requests' ClosureComment.
with db.begin():
requests = pkgbase.requests.filter(
and_(PackageRequest.Status == PENDING_ID,
PackageRequest.ReqTypeID == DELETION_ID)
and_(
PackageRequest.Status == PENDING_ID,
PackageRequest.ReqTypeID == DELETION_ID,
)
)
for pkgreq in requests:
pkgreq.ClosureComment = comments
notifs = actions.pkgbase_delete_instance(
request, pkgbase, comments=comments)
notifs = actions.pkgbase_delete_instance(request, pkgbase, comments=comments)
util.apply_all(notifs, lambda n: n.send())
return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER)
@router.get("/pkgbase/{name}/merge")
@requires_auth
async def pkgbase_merge_get(request: Request, name: str,
async def pkgbase_merge_get(
request: Request,
name: str,
into: str = Query(default=str()),
next: str = Query(default=str())):
next: str = Query(default=str()),
):
pkgbase = get_pkg_or_base(name, PackageBase)
context = templates.make_context(request, "Package Merging")
context.update({
"pkgbase": pkgbase,
"into": into,
"next": next
})
context.update({"pkgbase": pkgbase, "into": into, "next": next})
status_code = HTTPStatus.OK
# TODO: Lookup errors from credential instead of hardcoding them.
@ -852,51 +880,58 @@ async def pkgbase_merge_get(request: Request, name: str,
# Perhaps additionally: bad_credential_status_code(creds.PKGBASE_MERGE).
# Don't take these examples verbatim. We should find good naming.
if not request.user.has_credential(creds.PKGBASE_MERGE):
context["errors"] = [
"Only Trusted Users and Developers can merge packages."]
context["errors"] = ["Only Trusted Users and Developers can merge packages."]
status_code = HTTPStatus.UNAUTHORIZED
return render_template(request, "pkgbase/merge.html", context,
status_code=status_code)
return render_template(
request, "pkgbase/merge.html", context, status_code=status_code
)
@router.post("/pkgbase/{name}/merge")
@handle_form_exceptions
@requires_auth
async def pkgbase_merge_post(request: Request, name: str,
async def pkgbase_merge_post(
request: Request,
name: str,
into: str = Form(default=str()),
comments: str = Form(default=str()),
confirm: bool = Form(default=False),
next: str = Form(default=str())):
next: str = Form(default=str()),
):
pkgbase = get_pkg_or_base(name, PackageBase)
context = await make_variable_context(request, "Package Merging")
context["pkgbase"] = pkgbase
# TODO: Lookup errors from credential instead of hardcoding them.
if not request.user.has_credential(creds.PKGBASE_MERGE):
context["errors"] = [
"Only Trusted Users and Developers can merge packages."]
return render_template(request, "pkgbase/merge.html", context,
status_code=HTTPStatus.UNAUTHORIZED)
context["errors"] = ["Only Trusted Users and Developers can merge packages."]
return render_template(
request, "pkgbase/merge.html", context, status_code=HTTPStatus.UNAUTHORIZED
)
if not confirm:
context["errors"] = ["The selected packages have not been deleted, "
"check the confirmation checkbox."]
return render_template(request, "pkgbase/merge.html", context,
status_code=HTTPStatus.BAD_REQUEST)
context["errors"] = [
"The selected packages have not been deleted, "
"check the confirmation checkbox."
]
return render_template(
request, "pkgbase/merge.html", context, status_code=HTTPStatus.BAD_REQUEST
)
try:
target = get_pkg_or_base(into, PackageBase)
except HTTPException:
context["errors"] = [
"Cannot find package to merge votes and comments into."]
return render_template(request, "pkgbase/merge.html", context,
status_code=HTTPStatus.BAD_REQUEST)
context["errors"] = ["Cannot find package to merge votes and comments into."]
return render_template(
request, "pkgbase/merge.html", context, status_code=HTTPStatus.BAD_REQUEST
)
if pkgbase == target:
context["errors"] = ["Cannot merge a package base with itself."]
return render_template(request, "pkgbase/merge.html", context,
status_code=HTTPStatus.BAD_REQUEST)
return render_template(
request, "pkgbase/merge.html", context, status_code=HTTPStatus.BAD_REQUEST
)
with db.begin():
update_closure_comment(pkgbase, MERGE_ID, comments, target=target)

View file

@ -18,9 +18,11 @@ router = APIRouter()
@router.get("/requests")
@requires_auth
async def requests(request: Request,
async def requests(
request: Request,
O: int = Query(default=defaults.O),
PP: int = Query(default=defaults.PP)):
PP: int = Query(default=defaults.PP),
):
context = make_context(request, "Requests")
context["q"] = dict(request.query_params)
@ -30,8 +32,7 @@ async def requests(request: Request,
context["PP"] = PP
# A PackageRequest query, with left inner joined User and RequestType.
query = db.query(PackageRequest).join(
User, User.ID == PackageRequest.UsersID)
query = db.query(PackageRequest).join(User, User.ID == PackageRequest.UsersID)
# If the request user is not elevated (TU or Dev), then
# filter PackageRequests which are owned by the request user.
@ -39,12 +40,17 @@ async def requests(request: Request,
query = query.filter(PackageRequest.UsersID == request.user.ID)
context["total"] = query.count()
context["results"] = query.order_by(
context["results"] = (
query.order_by(
# Order primarily by the Status column being PENDING_ID,
# and secondarily by RequestTS; both in descending order.
case([(PackageRequest.Status == PENDING_ID, 1)], else_=0).desc(),
PackageRequest.RequestTS.desc()
).limit(PP).offset(O).all()
PackageRequest.RequestTS.desc(),
)
.limit(PP)
.offset(O)
.all()
)
return render_template(request, "requests.html", context)
@ -66,8 +72,9 @@ async def request_close(request: Request, id: int):
@router.post("/requests/{id}/close")
@handle_form_exceptions
@requires_auth
async def request_close_post(request: Request, id: int,
comments: str = Form(default=str())):
async def request_close_post(
request: Request, id: int, comments: str = Form(default=str())
):
pkgreq = get_pkgreq_by_id(id)
# `pkgreq`.User can close their own request.
@ -87,7 +94,8 @@ async def request_close_post(request: Request, id: int,
pkgreq.Status = REJECTED_ID
notify_ = notify.RequestCloseNotification(
request.user.ID, pkgreq.ID, pkgreq.status_display())
request.user.ID, pkgreq.ID, pkgreq.status_display()
)
notify_.send()
return RedirectResponse("/requests", status_code=HTTPStatus.SEE_OTHER)

View file

@ -1,12 +1,10 @@
import hashlib
import re
from http import HTTPStatus
from typing import Optional
from urllib.parse import unquote
import orjson
from fastapi import APIRouter, Form, Query, Request, Response
from fastapi.responses import JSONResponse
@ -39,9 +37,7 @@ def parse_args(request: Request):
# Create a list of (key, value) pairs of the given 'arg' and 'arg[]'
# query parameters from last to first.
query = list(reversed(unquote(request.url.query).split("&")))
parts = [
e.split("=", 1) for e in query if e.startswith(("arg=", "arg[]="))
]
parts = [e.split("=", 1) for e in query if e.startswith(("arg=", "arg[]="))]
args = []
if parts:
@ -63,24 +59,28 @@ def parse_args(request: Request):
return args
JSONP_EXPR = re.compile(r'^[a-zA-Z0-9()_.]{1,128}$')
JSONP_EXPR = re.compile(r"^[a-zA-Z0-9()_.]{1,128}$")
async def rpc_request(request: Request,
async def rpc_request(
request: Request,
v: Optional[int] = None,
type: Optional[str] = None,
by: Optional[str] = defaults.RPC_SEARCH_BY,
arg: Optional[str] = None,
args: Optional[list[str]] = [],
callback: Optional[str] = None):
callback: Optional[str] = None,
):
# Create a handle to our RPC class.
rpc = RPC(version=v, type=type)
# If ratelimit was exceeded, return a 429 Too Many Requests.
if check_ratelimit(request):
return JSONResponse(rpc.error("Rate limit reached"),
status_code=int(HTTPStatus.TOO_MANY_REQUESTS))
return JSONResponse(
rpc.error("Rate limit reached"),
status_code=int(HTTPStatus.TOO_MANY_REQUESTS),
)
# If `callback` was provided, produce a text/javascript response
# valid for the jsonp callback. Otherwise, by default, return
@ -115,15 +115,11 @@ async def rpc_request(request: Request,
# The ETag header expects quotes to surround any identifier.
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag
headers = {
"Content-Type": content_type,
"ETag": f'"{etag}"'
}
headers = {"Content-Type": content_type, "ETag": f'"{etag}"'}
if_none_match = request.headers.get("If-None-Match", str())
if if_none_match and if_none_match.strip("\t\n\r\" ") == etag:
return Response(headers=headers,
status_code=int(HTTPStatus.NOT_MODIFIED))
if if_none_match and if_none_match.strip('\t\n\r" ') == etag:
return Response(headers=headers, status_code=int(HTTPStatus.NOT_MODIFIED))
if callback:
content = f"/**/{callback}({content.decode()})"
@ -135,13 +131,15 @@ async def rpc_request(request: Request,
@router.get("/rpc.php") # Temporary! Remove on 03/04
@router.get("/rpc/")
@router.get("/rpc")
async def rpc(request: Request,
async def rpc(
request: Request,
v: Optional[int] = Query(default=None),
type: Optional[str] = Query(default=None),
by: Optional[str] = Query(default=defaults.RPC_SEARCH_BY),
arg: Optional[str] = Query(default=None),
args: Optional[list[str]] = Query(default=[], alias="arg[]"),
callback: Optional[str] = Query(default=None)):
callback: Optional[str] = Query(default=None),
):
if not request.url.query:
return documentation()
return await rpc_request(request, v, type, by, arg, args, callback)
@ -152,11 +150,13 @@ async def rpc(request: Request,
@router.post("/rpc/")
@router.post("/rpc")
@handle_form_exceptions
async def rpc_post(request: Request,
async def rpc_post(
request: Request,
v: Optional[int] = Form(default=None),
type: Optional[str] = Form(default=None),
by: Optional[str] = Form(default=defaults.RPC_SEARCH_BY),
arg: Optional[str] = Form(default=None),
args: Optional[list[str]] = Form(default=[], alias="arg[]"),
callback: Optional[str] = Form(default=None)):
callback: Optional[str] = Form(default=None),
):
return await rpc_request(request, v, type, by, arg, args, callback)

View file

@ -10,8 +10,7 @@ from aurweb.models import Package, PackageBase
router = APIRouter()
def make_rss_feed(request: Request, packages: list,
date_attr: str):
def make_rss_feed(request: Request, packages: list, date_attr: str):
"""Create an RSS Feed string for some packages.
:param request: A FastAPI request
@ -26,10 +25,12 @@ def make_rss_feed(request: Request, packages: list,
base = f"{request.url.scheme}://{request.url.netloc}"
feed.link(href=base, rel="alternate")
feed.link(href=f"{base}/rss", rel="self")
feed.image(title="AUR Newest Packages",
feed.image(
title="AUR Newest Packages",
url=f"{base}/static/css/archnavbar/aurlogo.png",
link=base,
description="AUR Newest Packages Feed")
description="AUR Newest Packages Feed",
)
for pkg in packages:
entry = feed.add_entry(order="append")
@ -53,8 +54,12 @@ def make_rss_feed(request: Request, packages: list,
@router.get("/rss/")
async def rss(request: Request):
packages = db.query(Package).join(PackageBase).order_by(
PackageBase.SubmittedTS.desc()).limit(100)
packages = (
db.query(Package)
.join(PackageBase)
.order_by(PackageBase.SubmittedTS.desc())
.limit(100)
)
feed = make_rss_feed(request, packages, "SubmittedTS")
response = Response(feed, media_type="application/rss+xml")
@ -69,8 +74,12 @@ async def rss(request: Request):
@router.get("/rss/modified")
async def rss_modified(request: Request):
packages = db.query(Package).join(PackageBase).order_by(
PackageBase.ModifiedTS.desc()).limit(100)
packages = (
db.query(Package)
.join(PackageBase)
.order_by(PackageBase.ModifiedTS.desc())
.limit(100)
)
feed = make_rss_feed(request, packages, "ModifiedTS")
response = Response(feed, media_type="application/rss+xml")

View file

@ -1,11 +1,9 @@
import time
import uuid
from http import HTTPStatus
from urllib.parse import urlencode
import fastapi
from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi import Depends, HTTPException
from fastapi.responses import RedirectResponse
@ -14,7 +12,6 @@ from starlette.requests import Request
import aurweb.config
import aurweb.db
from aurweb import util
from aurweb.l10n import get_translator_for_request
from aurweb.schema import Bans, Sessions, Users
@ -43,14 +40,18 @@ async def login(request: Request, redirect: str = None):
The `redirect` argument is a query parameter specifying the post-login
redirect URL.
"""
authenticate_url = aurweb.config.get("options", "aur_location") + "/sso/authenticate"
authenticate_url = (
aurweb.config.get("options", "aur_location") + "/sso/authenticate"
)
if redirect:
authenticate_url = authenticate_url + "?" + urlencode([("redirect", redirect)])
return await oauth.sso.authorize_redirect(request, authenticate_url, prompt="login")
def is_account_suspended(conn, user_id):
row = conn.execute(select([Users.c.Suspended]).where(Users.c.ID == user_id)).fetchone()
row = conn.execute(
select([Users.c.Suspended]).where(Users.c.ID == user_id)
).fetchone()
return row is not None and bool(row[0])
@ -60,23 +61,27 @@ def open_session(request, conn, user_id):
"""
if is_account_suspended(conn, user_id):
_ = get_translator_for_request(request)
raise HTTPException(status_code=HTTPStatus.FORBIDDEN,
detail=_('Account suspended'))
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail=_("Account suspended")
)
# TODO This is a terrible message because it could imply the attempt at
# logging in just caused the suspension.
sid = uuid.uuid4().hex
conn.execute(Sessions.insert().values(
conn.execute(
Sessions.insert().values(
UsersID=user_id,
SessionID=sid,
LastUpdateTS=time.time(),
))
)
)
# Update users last login information.
conn.execute(Users.update()
conn.execute(
Users.update()
.where(Users.c.ID == user_id)
.values(LastLogin=int(time.time()),
LastLoginIPAddress=request.client.host))
.values(LastLogin=int(time.time()), LastLoginIPAddress=request.client.host)
)
return sid
@ -98,7 +103,9 @@ def is_aur_url(url):
@router.get("/sso/authenticate")
async def authenticate(request: Request, redirect: str = None, conn=Depends(aurweb.db.connect)):
async def authenticate(
request: Request, redirect: str = None, conn=Depends(aurweb.db.connect)
):
"""
Receive an OpenID Connect ID token, validate it, then process it to create
an new AUR session.
@ -107,9 +114,12 @@ async def authenticate(request: Request, redirect: str = None, conn=Depends(aurw
_ = get_translator_for_request(request)
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail=_('The login form is currently disabled for your IP address, '
'probably due to sustained spam attacks. Sorry for the '
'inconvenience.'))
detail=_(
"The login form is currently disabled for your IP address, "
"probably due to sustained spam attacks. Sorry for the "
"inconvenience."
),
)
try:
token = await oauth.sso.authorize_access_token(request)
@ -120,30 +130,41 @@ async def authenticate(request: Request, redirect: str = None, conn=Depends(aurw
_ = get_translator_for_request(request)
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=_('Bad OAuth token. Please retry logging in from the start.'))
detail=_("Bad OAuth token. Please retry logging in from the start."),
)
sub = user.get("sub") # this is the SSO account ID in JWT terminology
if not sub:
_ = get_translator_for_request(request)
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST,
detail=_("JWT is missing its `sub` field."))
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=_("JWT is missing its `sub` field."),
)
aur_accounts = conn.execute(select([Users.c.ID]).where(Users.c.SSOAccountID == sub)) \
.fetchall()
aur_accounts = conn.execute(
select([Users.c.ID]).where(Users.c.SSOAccountID == sub)
).fetchall()
if not aur_accounts:
return "Sorry, we dont seem to know you Sir " + sub
elif len(aur_accounts) == 1:
sid = open_session(request, conn, aur_accounts[0][Users.c.ID])
response = RedirectResponse(redirect if redirect and is_aur_url(redirect) else "/")
response = RedirectResponse(
redirect if redirect and is_aur_url(redirect) else "/"
)
secure_cookies = aurweb.config.getboolean("options", "disable_http_login")
response.set_cookie(key="AURSID", value=sid, httponly=True,
secure=secure_cookies)
response.set_cookie(
key="AURSID", value=sid, httponly=True, secure=secure_cookies
)
if "id_token" in token:
# We save the id_token for the SSO logout. Its not too important
# though, so if we cant find it, we can live without it.
response.set_cookie(key="SSO_ID_TOKEN", value=token["id_token"],
path="/sso/", httponly=True,
secure=secure_cookies)
response.set_cookie(
key="SSO_ID_TOKEN",
value=token["id_token"],
path="/sso/",
httponly=True,
secure=secure_cookies,
)
return util.add_samesite_fields(response, "strict")
else:
# Weve got a severe integrity violation.
@ -165,8 +186,12 @@ async def logout(request: Request):
return RedirectResponse("/")
metadata = await oauth.sso.load_server_metadata()
query = urlencode({'post_logout_redirect_uri': aurweb.config.get('options', 'aur_location'),
'id_token_hint': id_token})
response = RedirectResponse(metadata["end_session_endpoint"] + '?' + query)
query = urlencode(
{
"post_logout_redirect_uri": aurweb.config.get("options", "aur_location"),
"id_token_hint": id_token,
}
)
response = RedirectResponse(metadata["end_session_endpoint"] + "?" + query)
response.delete_cookie("SSO_ID_TOKEN", path="/sso/")
return response

View file

@ -1,6 +1,5 @@
import html
import typing
from http import HTTPStatus
from typing import Any
@ -30,32 +29,35 @@ ADDVOTE_SPECIFICS = {
"add_tu": (7 * 24 * 60 * 60, 0.66),
"remove_tu": (7 * 24 * 60 * 60, 0.75),
"remove_inactive_tu": (5 * 24 * 60 * 60, 0.66),
"bylaws": (7 * 24 * 60 * 60, 0.75)
"bylaws": (7 * 24 * 60 * 60, 0.75),
}
def populate_trusted_user_counts(context: dict[str, Any]) -> None:
tu_query = db.query(User).filter(
or_(User.AccountTypeID == TRUSTED_USER_ID,
User.AccountTypeID == TRUSTED_USER_AND_DEV_ID)
or_(
User.AccountTypeID == TRUSTED_USER_ID,
User.AccountTypeID == TRUSTED_USER_AND_DEV_ID,
)
)
context["trusted_user_count"] = tu_query.count()
# In case any records have a None InactivityTS.
active_tu_query = tu_query.filter(
or_(User.InactivityTS.is_(None),
User.InactivityTS == 0)
or_(User.InactivityTS.is_(None), User.InactivityTS == 0)
)
context["active_trusted_user_count"] = active_tu_query.count()
@router.get("/tu")
@requires_auth
async def trusted_user(request: Request,
async def trusted_user(
request: Request,
coff: int = 0, # current offset
cby: str = "desc", # current by
poff: int = 0, # past offset
pby: str = "desc"): # past by
pby: str = "desc",
): # past by
"""Proposal listings."""
if not request.user.has_credential(creds.TU_LIST_VOTES):
@ -81,40 +83,47 @@ async def trusted_user(request: Request,
past_by = "desc"
context["past_by"] = past_by
current_votes = db.query(models.TUVoteInfo).filter(
models.TUVoteInfo.End > ts).order_by(
models.TUVoteInfo.Submitted.desc())
current_votes = (
db.query(models.TUVoteInfo)
.filter(models.TUVoteInfo.End > ts)
.order_by(models.TUVoteInfo.Submitted.desc())
)
context["current_votes_count"] = current_votes.count()
current_votes = current_votes.limit(pp).offset(current_off)
context["current_votes"] = reversed(current_votes.all()) \
if current_by == "asc" else current_votes.all()
context["current_votes"] = (
reversed(current_votes.all()) if current_by == "asc" else current_votes.all()
)
context["current_off"] = current_off
past_votes = db.query(models.TUVoteInfo).filter(
models.TUVoteInfo.End <= ts).order_by(
models.TUVoteInfo.Submitted.desc())
past_votes = (
db.query(models.TUVoteInfo)
.filter(models.TUVoteInfo.End <= ts)
.order_by(models.TUVoteInfo.Submitted.desc())
)
context["past_votes_count"] = past_votes.count()
past_votes = past_votes.limit(pp).offset(past_off)
context["past_votes"] = reversed(past_votes.all()) \
if past_by == "asc" else past_votes.all()
context["past_votes"] = (
reversed(past_votes.all()) if past_by == "asc" else past_votes.all()
)
context["past_off"] = past_off
last_vote = func.max(models.TUVote.VoteID).label("LastVote")
last_votes_by_tu = db.query(models.TUVote).join(models.User).join(
models.TUVoteInfo,
models.TUVoteInfo.ID == models.TUVote.VoteID
).filter(
and_(models.TUVote.VoteID == models.TUVoteInfo.ID,
last_votes_by_tu = (
db.query(models.TUVote)
.join(models.User)
.join(models.TUVoteInfo, models.TUVoteInfo.ID == models.TUVote.VoteID)
.filter(
and_(
models.TUVote.VoteID == models.TUVoteInfo.ID,
models.User.ID == models.TUVote.UserID,
models.TUVoteInfo.End < ts,
or_(models.User.AccountTypeID == 2,
models.User.AccountTypeID == 4))
).with_entities(
models.TUVote.UserID,
last_vote,
models.User.Username
).group_by(models.TUVote.UserID).order_by(
last_vote.desc(), models.User.Username.asc())
or_(models.User.AccountTypeID == 2, models.User.AccountTypeID == 4),
)
)
.with_entities(models.TUVote.UserID, last_vote, models.User.Username)
.group_by(models.TUVote.UserID)
.order_by(last_vote.desc(), models.User.Username.asc())
)
context["last_votes_by_tu"] = last_votes_by_tu.all()
context["current_by_next"] = "asc" if current_by == "desc" else "desc"
@ -126,17 +135,21 @@ async def trusted_user(request: Request,
"coff": current_off,
"cby": current_by,
"poff": past_off,
"pby": past_by
"pby": past_by,
}
return render_template(request, "tu/index.html", context)
def render_proposal(request: Request, context: dict, proposal: int,
def render_proposal(
request: Request,
context: dict,
proposal: int,
voteinfo: models.TUVoteInfo,
voters: typing.Iterable[models.User],
vote: models.TUVote,
status_code: HTTPStatus = HTTPStatus.OK):
status_code: HTTPStatus = HTTPStatus.OK,
):
"""Render a single TU proposal."""
context["proposal"] = proposal
context["voteinfo"] = voteinfo
@ -146,8 +159,9 @@ def render_proposal(request: Request, context: dict, proposal: int,
participation = (total / voteinfo.ActiveTUs) if voteinfo.ActiveTUs else 0
context["participation"] = participation
accepted = (voteinfo.Yes > voteinfo.ActiveTUs / 2) or \
(participation > voteinfo.Quorum and voteinfo.Yes > voteinfo.No)
accepted = (voteinfo.Yes > voteinfo.ActiveTUs / 2) or (
participation > voteinfo.Quorum and voteinfo.Yes > voteinfo.No
)
context["accepted"] = accepted
can_vote = voters.filter(models.TUVote.User == request.user).first() is None
@ -159,8 +173,7 @@ def render_proposal(request: Request, context: dict, proposal: int,
context["vote"] = vote
context["has_voted"] = vote is not None
return render_template(request, "tu/show.html", context,
status_code=status_code)
return render_template(request, "tu/show.html", context, status_code=status_code)
@router.get("/tu/{proposal}")
@ -172,16 +185,27 @@ async def trusted_user_proposal(request: Request, proposal: int):
context = await make_variable_context(request, "Trusted User")
proposal = int(proposal)
voteinfo = db.query(models.TUVoteInfo).filter(
models.TUVoteInfo.ID == proposal).first()
voteinfo = (
db.query(models.TUVoteInfo).filter(models.TUVoteInfo.ID == proposal).first()
)
if not voteinfo:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
voters = db.query(models.User).join(models.TUVote).filter(
models.TUVote.VoteID == voteinfo.ID)
vote = db.query(models.TUVote).filter(
and_(models.TUVote.UserID == request.user.ID,
models.TUVote.VoteID == voteinfo.ID)).first()
voters = (
db.query(models.User)
.join(models.TUVote)
.filter(models.TUVote.VoteID == voteinfo.ID)
)
vote = (
db.query(models.TUVote)
.filter(
and_(
models.TUVote.UserID == request.user.ID,
models.TUVote.VoteID == voteinfo.ID,
)
)
.first()
)
if not request.user.has_credential(creds.TU_VOTE):
context["error"] = "Only Trusted Users are allowed to vote."
if voteinfo.User == request.user.Username:
@ -196,24 +220,36 @@ async def trusted_user_proposal(request: Request, proposal: int):
@router.post("/tu/{proposal}")
@handle_form_exceptions
@requires_auth
async def trusted_user_proposal_post(request: Request, proposal: int,
decision: str = Form(...)):
async def trusted_user_proposal_post(
request: Request, proposal: int, decision: str = Form(...)
):
if not request.user.has_credential(creds.TU_LIST_VOTES):
return RedirectResponse("/tu", status_code=HTTPStatus.SEE_OTHER)
context = await make_variable_context(request, "Trusted User")
proposal = int(proposal) # Make sure it's an int.
voteinfo = db.query(models.TUVoteInfo).filter(
models.TUVoteInfo.ID == proposal).first()
voteinfo = (
db.query(models.TUVoteInfo).filter(models.TUVoteInfo.ID == proposal).first()
)
if not voteinfo:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
voters = db.query(models.User).join(models.TUVote).filter(
models.TUVote.VoteID == voteinfo.ID)
vote = db.query(models.TUVote).filter(
and_(models.TUVote.UserID == request.user.ID,
models.TUVote.VoteID == voteinfo.ID)).first()
voters = (
db.query(models.User)
.join(models.TUVote)
.filter(models.TUVote.VoteID == voteinfo.ID)
)
vote = (
db.query(models.TUVote)
.filter(
and_(
models.TUVote.UserID == request.user.ID,
models.TUVote.VoteID == voteinfo.ID,
)
)
.first()
)
status_code = HTTPStatus.OK
if not request.user.has_credential(creds.TU_VOTE):
@ -227,16 +263,15 @@ async def trusted_user_proposal_post(request: Request, proposal: int,
status_code = HTTPStatus.BAD_REQUEST
if status_code != HTTPStatus.OK:
return render_proposal(request, context, proposal,
voteinfo, voters, vote,
status_code=status_code)
return render_proposal(
request, context, proposal, voteinfo, voters, vote, status_code=status_code
)
if decision in {"Yes", "No", "Abstain"}:
# Increment whichever decision was given to us.
setattr(voteinfo, decision, getattr(voteinfo, decision) + 1)
else:
return Response("Invalid 'decision' value.",
status_code=HTTPStatus.BAD_REQUEST)
return Response("Invalid 'decision' value.", status_code=HTTPStatus.BAD_REQUEST)
with db.begin():
vote = db.create(models.TUVote, User=request.user, VoteInfo=voteinfo)
@ -247,8 +282,9 @@ async def trusted_user_proposal_post(request: Request, proposal: int,
@router.get("/addvote")
@requires_auth
async def trusted_user_addvote(request: Request, user: str = str(),
type: str = "add_tu", agenda: str = str()):
async def trusted_user_addvote(
request: Request, user: str = str(), type: str = "add_tu", agenda: str = str()
):
if not request.user.has_credential(creds.TU_ADD_VOTE):
return RedirectResponse("/tu", status_code=HTTPStatus.SEE_OTHER)
@ -268,10 +304,12 @@ async def trusted_user_addvote(request: Request, user: str = str(),
@router.post("/addvote")
@handle_form_exceptions
@requires_auth
async def trusted_user_addvote_post(request: Request,
async def trusted_user_addvote_post(
request: Request,
user: str = Form(default=str()),
type: str = Form(default=str()),
agenda: str = Form(default=str())):
agenda: str = Form(default=str()),
):
if not request.user.has_credential(creds.TU_ADD_VOTE):
return RedirectResponse("/tu", status_code=HTTPStatus.SEE_OTHER)
@ -288,21 +326,24 @@ async def trusted_user_addvote_post(request: Request,
# Alright, get some database records, if we can.
if type != "bylaws":
user_record = db.query(models.User).filter(
models.User.Username == user).first()
user_record = db.query(models.User).filter(models.User.Username == user).first()
if user_record is None:
context["error"] = "Username does not exist."
return render_addvote(context, HTTPStatus.NOT_FOUND)
utcnow = time.utcnow()
voteinfo = db.query(models.TUVoteInfo).filter(
and_(models.TUVoteInfo.User == user,
models.TUVoteInfo.End > utcnow)).count()
voteinfo = (
db.query(models.TUVoteInfo)
.filter(
and_(models.TUVoteInfo.User == user, models.TUVoteInfo.End > utcnow)
)
.count()
)
if voteinfo:
_ = l10n.get_translator_for_request(request)
context["error"] = _(
"%s already has proposal running for them.") % (
html.escape(user),)
context["error"] = _("%s already has proposal running for them.") % (
html.escape(user),
)
return render_addvote(context, HTTPStatus.BAD_REQUEST)
if type not in ADDVOTE_SPECIFICS:
@ -323,16 +364,27 @@ async def trusted_user_addvote_post(request: Request,
# Create a new TUVoteInfo (proposal)!
with db.begin():
active_tus = db.query(User).filter(
and_(User.Suspended == 0,
active_tus = (
db.query(User)
.filter(
and_(
User.Suspended == 0,
User.InactivityTS.isnot(None),
User.AccountTypeID.in_(types))
).count()
voteinfo = db.create(models.TUVoteInfo, User=user,
User.AccountTypeID.in_(types),
)
)
.count()
)
voteinfo = db.create(
models.TUVoteInfo,
User=user,
Agenda=html.escape(agenda),
Submitted=timestamp, End=(timestamp + duration),
Quorum=quorum, ActiveTUs=active_tus,
Submitter=request.user)
Submitted=timestamp,
End=(timestamp + duration),
Quorum=quorum,
ActiveTUs=active_tus,
Submitter=request.user,
)
# Redirect to the new proposal.
endpoint = f"/tu/{voteinfo.ID}"

View file

@ -1,5 +1,4 @@
import os
from collections import defaultdict
from typing import Any, Callable, NewType, Union
@ -7,7 +6,6 @@ from fastapi.responses import HTMLResponse
from sqlalchemy import and_, literal, orm
import aurweb.config as config
from aurweb import db, defaults, models
from aurweb.exceptions import RPCError
from aurweb.filters import number_format
@ -23,8 +21,7 @@ TYPE_MAPPING = {
"replaces": "Replaces",
}
DataGenerator = NewType("DataGenerator",
Callable[[models.Package], dict[str, Any]])
DataGenerator = NewType("DataGenerator", Callable[[models.Package], dict[str, Any]])
def documentation():
@ -66,17 +63,25 @@ class RPC:
# A set of RPC types supported by this API.
EXPOSED_TYPES = {
"info", "multiinfo",
"search", "msearch",
"suggest", "suggest-pkgbase"
"info",
"multiinfo",
"search",
"msearch",
"suggest",
"suggest-pkgbase",
}
# A mapping of type aliases.
TYPE_ALIASES = {"info": "multiinfo"}
EXPOSED_BYS = {
"name-desc", "name", "maintainer",
"depends", "makedepends", "optdepends", "checkdepends"
"name-desc",
"name",
"maintainer",
"depends",
"makedepends",
"optdepends",
"checkdepends",
}
# A mapping of by aliases.
@ -92,7 +97,7 @@ class RPC:
"results": [],
"resultcount": 0,
"type": "error",
"error": message
"error": message,
}
def _verify_inputs(self, by: str = [], args: list[str] = []) -> None:
@ -143,7 +148,7 @@ class RPC:
"Popularity": pop,
"OutOfDate": package.OutOfDateTS,
"FirstSubmitted": package.SubmittedTS,
"LastModified": package.ModifiedTS
"LastModified": package.ModifiedTS,
}
def _get_info_json_data(self, package: models.Package) -> dict[str, Any]:
@ -151,10 +156,7 @@ class RPC:
# All info results have _at least_ an empty list of
# License and Keywords.
data.update({
"License": [],
"Keywords": []
})
data.update({"License": [], "Keywords": []})
# If we actually got extra_info records, update data with
# them for this particular package.
@ -163,9 +165,9 @@ class RPC:
return data
def _assemble_json_data(self, packages: list[models.Package],
data_generator: DataGenerator) \
-> list[dict[str, Any]]:
def _assemble_json_data(
self, packages: list[models.Package], data_generator: DataGenerator
) -> list[dict[str, Any]]:
"""
Assemble JSON data out of a list of packages.
@ -192,16 +194,22 @@ class RPC:
models.User.Username.label("Maintainer"),
).group_by(models.Package.ID)
def _handle_multiinfo_type(self, args: list[str] = [], **kwargs) \
-> list[dict[str, Any]]:
def _handle_multiinfo_type(
self, args: list[str] = [], **kwargs
) -> list[dict[str, Any]]:
self._enforce_args(args)
args = set(args)
packages = db.query(models.Package).join(models.PackageBase).join(
packages = (
db.query(models.Package)
.join(models.PackageBase)
.join(
models.User,
models.User.ID == models.PackageBase.MaintainerUID,
isouter=True
).filter(models.Package.Name.in_(args))
isouter=True,
)
.filter(models.Package.Name.in_(args))
)
max_results = config.getint("options", "max_rpc_results")
packages = self._entities(packages).limit(max_results + 1)
@ -217,65 +225,75 @@ class RPC:
subqueries = [
# PackageDependency
db.query(
models.PackageDependency
).join(models.DependencyType).filter(
models.PackageDependency.PackageID.in_(ids)
).with_entities(
db.query(models.PackageDependency)
.join(models.DependencyType)
.filter(models.PackageDependency.PackageID.in_(ids))
.with_entities(
models.PackageDependency.PackageID.label("ID"),
models.DependencyType.Name.label("Type"),
models.PackageDependency.DepName.label("Name"),
models.PackageDependency.DepCondition.label("Cond")
).distinct().order_by("Name"),
models.PackageDependency.DepCondition.label("Cond"),
)
.distinct()
.order_by("Name"),
# PackageRelation
db.query(
models.PackageRelation
).join(models.RelationType).filter(
models.PackageRelation.PackageID.in_(ids)
).with_entities(
db.query(models.PackageRelation)
.join(models.RelationType)
.filter(models.PackageRelation.PackageID.in_(ids))
.with_entities(
models.PackageRelation.PackageID.label("ID"),
models.RelationType.Name.label("Type"),
models.PackageRelation.RelName.label("Name"),
models.PackageRelation.RelCondition.label("Cond")
).distinct().order_by("Name"),
models.PackageRelation.RelCondition.label("Cond"),
)
.distinct()
.order_by("Name"),
# Groups
db.query(models.PackageGroup).join(
db.query(models.PackageGroup)
.join(
models.Group,
and_(models.PackageGroup.GroupID == models.Group.ID,
models.PackageGroup.PackageID.in_(ids))
).with_entities(
and_(
models.PackageGroup.GroupID == models.Group.ID,
models.PackageGroup.PackageID.in_(ids),
),
)
.with_entities(
models.PackageGroup.PackageID.label("ID"),
literal("Groups").label("Type"),
models.Group.Name.label("Name"),
literal(str()).label("Cond")
).distinct().order_by("Name"),
literal(str()).label("Cond"),
)
.distinct()
.order_by("Name"),
# Licenses
db.query(models.PackageLicense).join(
models.License,
models.PackageLicense.LicenseID == models.License.ID
).filter(
models.PackageLicense.PackageID.in_(ids)
).with_entities(
db.query(models.PackageLicense)
.join(models.License, models.PackageLicense.LicenseID == models.License.ID)
.filter(models.PackageLicense.PackageID.in_(ids))
.with_entities(
models.PackageLicense.PackageID.label("ID"),
literal("License").label("Type"),
models.License.Name.label("Name"),
literal(str()).label("Cond")
).distinct().order_by("Name"),
literal(str()).label("Cond"),
)
.distinct()
.order_by("Name"),
# Keywords
db.query(models.PackageKeyword).join(
db.query(models.PackageKeyword)
.join(
models.Package,
and_(Package.PackageBaseID == PackageKeyword.PackageBaseID,
Package.ID.in_(ids))
).with_entities(
and_(
Package.PackageBaseID == PackageKeyword.PackageBaseID,
Package.ID.in_(ids),
),
)
.with_entities(
models.Package.ID.label("ID"),
literal("Keywords").label("Type"),
models.PackageKeyword.Keyword.label("Name"),
literal(str()).label("Cond")
).distinct().order_by("Name")
literal(str()).label("Cond"),
)
.distinct()
.order_by("Name"),
]
# Union all subqueries together.
@ -295,8 +313,9 @@ class RPC:
return self._assemble_json_data(packages, self._get_info_json_data)
def _handle_search_type(self, by: str = defaults.RPC_SEARCH_BY,
args: list[str] = []) -> list[dict[str, Any]]:
def _handle_search_type(
self, by: str = defaults.RPC_SEARCH_BY, args: list[str] = []
) -> list[dict[str, Any]]:
# If `by` isn't maintainer and we don't have any args, raise an error.
# In maintainer's case, return all orphans if there are no args,
# so we need args to pass through to the handler without errors.
@ -318,49 +337,63 @@ class RPC:
return self._assemble_json_data(results, self._get_json_data)
def _handle_msearch_type(self, args: list[str] = [], **kwargs)\
-> list[dict[str, Any]]:
def _handle_msearch_type(
self, args: list[str] = [], **kwargs
) -> list[dict[str, Any]]:
return self._handle_search_type(by="m", args=args)
def _handle_suggest_type(self, args: list[str] = [], **kwargs)\
-> list[str]:
def _handle_suggest_type(self, args: list[str] = [], **kwargs) -> list[str]:
if not args:
return []
arg = args[0]
packages = db.query(models.Package.Name).join(
models.PackageBase
).filter(
and_(models.PackageBase.PackagerUID.isnot(None),
models.Package.Name.like(f"{arg}%"))
).order_by(models.Package.Name.asc()).limit(20)
packages = (
db.query(models.Package.Name)
.join(models.PackageBase)
.filter(
and_(
models.PackageBase.PackagerUID.isnot(None),
models.Package.Name.like(f"{arg}%"),
)
)
.order_by(models.Package.Name.asc())
.limit(20)
)
return [pkg.Name for pkg in packages]
def _handle_suggest_pkgbase_type(self, args: list[str] = [], **kwargs)\
-> list[str]:
def _handle_suggest_pkgbase_type(self, args: list[str] = [], **kwargs) -> list[str]:
if not args:
return []
arg = args[0]
packages = db.query(models.PackageBase.Name).filter(
and_(models.PackageBase.PackagerUID.isnot(None),
models.PackageBase.Name.like(f"{arg}%"))
).order_by(models.PackageBase.Name.asc()).limit(20)
packages = (
db.query(models.PackageBase.Name)
.filter(
and_(
models.PackageBase.PackagerUID.isnot(None),
models.PackageBase.Name.like(f"{arg}%"),
)
)
.order_by(models.PackageBase.Name.asc())
.limit(20)
)
return [pkg.Name for pkg in packages]
def _is_suggestion(self) -> bool:
return self.type.startswith("suggest")
def _handle_callback(self, by: str, args: list[str])\
-> Union[list[dict[str, Any]], list[str]]:
def _handle_callback(
self, by: str, args: list[str]
) -> Union[list[dict[str, Any]], list[str]]:
# Get a handle to our callback and trap an RPCError with
# an empty list of results based on callback's execution.
callback = getattr(self, f"_handle_{self.type.replace('-', '_')}_type")
results = callback(by=by, args=args)
return results
def handle(self, by: str = defaults.RPC_SEARCH_BY, args: list[str] = [])\
-> Union[list[dict[str, Any]], dict[str, Any]]:
def handle(
self, by: str = defaults.RPC_SEARCH_BY, args: list[str] = []
) -> Union[list[dict[str, Any]], dict[str, Any]]:
"""Request entrypoint. A router should pass v, type and args
to this function and expect an output dictionary to be returned.
@ -392,8 +425,5 @@ class RPC:
return results
# Return JSON output.
data.update({
"resultcount": len(results),
"results": results
})
data.update({"resultcount": len(results), "results": results})
return data

View file

@ -6,7 +6,18 @@ usually be automatically generated. See `migrations/README` for details.
"""
from sqlalchemy import CHAR, TIMESTAMP, Column, ForeignKey, Index, MetaData, String, Table, Text, text
from sqlalchemy import (
CHAR,
TIMESTAMP,
Column,
ForeignKey,
Index,
MetaData,
String,
Table,
Text,
text,
)
from sqlalchemy.dialects.mysql import BIGINT, DECIMAL, INTEGER, TINYINT
from sqlalchemy.ext.compiler import compiles
@ -15,13 +26,13 @@ import aurweb.config
db_backend = aurweb.config.get("database", "backend")
@compiles(TINYINT, 'sqlite')
@compiles(TINYINT, "sqlite")
def compile_tinyint_sqlite(type_, compiler, **kw): # pragma: no cover
"""TINYINT is not supported on SQLite. Substitute it with INTEGER."""
return 'INTEGER'
return "INTEGER"
@compiles(BIGINT, 'sqlite')
@compiles(BIGINT, "sqlite")
def compile_bigint_sqlite(type_, compiler, **kw): # pragma: no cover
"""
For SQLite's AUTOINCREMENT to work on BIGINT columns, we need to map BIGINT
@ -29,429 +40,567 @@ def compile_bigint_sqlite(type_, compiler, **kw): # pragma: no cover
See https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#allowing-autoincrement-behavior-sqlalchemy-types-other-than-integer-integer
""" # noqa: E501
return 'INTEGER'
return "INTEGER"
metadata = MetaData()
# Define the Account Types for the AUR.
AccountTypes = Table(
'AccountTypes', metadata,
Column('ID', TINYINT(unsigned=True), primary_key=True),
Column('AccountType', String(32), nullable=False, server_default=text("''")),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci'
"AccountTypes",
metadata,
Column("ID", TINYINT(unsigned=True), primary_key=True),
Column("AccountType", String(32), nullable=False, server_default=text("''")),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# User information for each user regardless of type.
Users = Table(
'Users', metadata,
Column('ID', INTEGER(unsigned=True), primary_key=True),
Column('AccountTypeID', ForeignKey('AccountTypes.ID', ondelete="NO ACTION"), nullable=False, server_default=text("1")),
Column('Suspended', TINYINT(unsigned=True), nullable=False, server_default=text("0")),
Column('Username', String(32), nullable=False, unique=True),
Column('Email', String(254), nullable=False, unique=True),
Column('BackupEmail', String(254)),
Column('HideEmail', TINYINT(unsigned=True), nullable=False, server_default=text("0")),
Column('Passwd', String(255), nullable=False),
Column('Salt', CHAR(32), nullable=False, server_default=text("''")),
Column('ResetKey', CHAR(32), nullable=False, server_default=text("''")),
Column('RealName', String(64), nullable=False, server_default=text("''")),
Column('LangPreference', String(6), nullable=False, server_default=text("'en'")),
Column('Timezone', String(32), nullable=False, server_default=text("'UTC'")),
Column('Homepage', Text),
Column('IRCNick', String(32), nullable=False, server_default=text("''")),
Column('PGPKey', String(40)),
Column('LastLogin', BIGINT(unsigned=True), nullable=False, server_default=text("0")),
Column('LastLoginIPAddress', String(45)),
Column('LastSSHLogin', BIGINT(unsigned=True), nullable=False, server_default=text("0")),
Column('LastSSHLoginIPAddress', String(45)),
Column('InactivityTS', BIGINT(unsigned=True), nullable=False, server_default=text("0")),
Column('RegistrationTS', TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP")),
Column('CommentNotify', TINYINT(1), nullable=False, server_default=text("1")),
Column('UpdateNotify', TINYINT(1), nullable=False, server_default=text("0")),
Column('OwnershipNotify', TINYINT(1), nullable=False, server_default=text("1")),
Column('SSOAccountID', String(255), nullable=True, unique=True),
Index('UsersAccountTypeID', 'AccountTypeID'),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"Users",
metadata,
Column("ID", INTEGER(unsigned=True), primary_key=True),
Column(
"AccountTypeID",
ForeignKey("AccountTypes.ID", ondelete="NO ACTION"),
nullable=False,
server_default=text("1"),
),
Column(
"Suspended", TINYINT(unsigned=True), nullable=False, server_default=text("0")
),
Column("Username", String(32), nullable=False, unique=True),
Column("Email", String(254), nullable=False, unique=True),
Column("BackupEmail", String(254)),
Column(
"HideEmail", TINYINT(unsigned=True), nullable=False, server_default=text("0")
),
Column("Passwd", String(255), nullable=False),
Column("Salt", CHAR(32), nullable=False, server_default=text("''")),
Column("ResetKey", CHAR(32), nullable=False, server_default=text("''")),
Column("RealName", String(64), nullable=False, server_default=text("''")),
Column("LangPreference", String(6), nullable=False, server_default=text("'en'")),
Column("Timezone", String(32), nullable=False, server_default=text("'UTC'")),
Column("Homepage", Text),
Column("IRCNick", String(32), nullable=False, server_default=text("''")),
Column("PGPKey", String(40)),
Column(
"LastLogin", BIGINT(unsigned=True), nullable=False, server_default=text("0")
),
Column("LastLoginIPAddress", String(45)),
Column(
"LastSSHLogin", BIGINT(unsigned=True), nullable=False, server_default=text("0")
),
Column("LastSSHLoginIPAddress", String(45)),
Column(
"InactivityTS", BIGINT(unsigned=True), nullable=False, server_default=text("0")
),
Column(
"RegistrationTS",
TIMESTAMP,
nullable=False,
server_default=text("CURRENT_TIMESTAMP"),
),
Column("CommentNotify", TINYINT(1), nullable=False, server_default=text("1")),
Column("UpdateNotify", TINYINT(1), nullable=False, server_default=text("0")),
Column("OwnershipNotify", TINYINT(1), nullable=False, server_default=text("1")),
Column("SSOAccountID", String(255), nullable=True, unique=True),
Index("UsersAccountTypeID", "AccountTypeID"),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# SSH public keys used for the aurweb SSH/Git interface.
SSHPubKeys = Table(
'SSHPubKeys', metadata,
Column('UserID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False),
Column('Fingerprint', String(44), primary_key=True),
Column('PubKey', String(4096), nullable=False),
mysql_engine='InnoDB', mysql_charset='utf8mb4', mysql_collate='utf8mb4_bin',
"SSHPubKeys",
metadata,
Column("UserID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
Column("Fingerprint", String(44), primary_key=True),
Column("PubKey", String(4096), nullable=False),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_bin",
)
# Track Users logging in/out of AUR web site.
Sessions = Table(
'Sessions', metadata,
Column('UsersID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False),
Column('SessionID', CHAR(32), nullable=False, unique=True),
Column('LastUpdateTS', BIGINT(unsigned=True), nullable=False),
mysql_engine='InnoDB', mysql_charset='utf8mb4', mysql_collate='utf8mb4_bin',
"Sessions",
metadata,
Column("UsersID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
Column("SessionID", CHAR(32), nullable=False, unique=True),
Column("LastUpdateTS", BIGINT(unsigned=True), nullable=False),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_bin",
)
# Information on package bases
PackageBases = Table(
'PackageBases', metadata,
Column('ID', INTEGER(unsigned=True), primary_key=True),
Column('Name', String(255), nullable=False, unique=True),
Column('NumVotes', INTEGER(unsigned=True), nullable=False, server_default=text("0")),
Column('Popularity',
DECIMAL(10, 6, unsigned=True)
if db_backend == "mysql" else String(17),
nullable=False, server_default=text("0")),
Column('OutOfDateTS', BIGINT(unsigned=True)),
Column('FlaggerComment', Text, nullable=False),
Column('SubmittedTS', BIGINT(unsigned=True), nullable=False),
Column('ModifiedTS', BIGINT(unsigned=True), nullable=False),
Column('FlaggerUID', ForeignKey('Users.ID', ondelete='SET NULL')), # who flagged the package out-of-date?
"PackageBases",
metadata,
Column("ID", INTEGER(unsigned=True), primary_key=True),
Column("Name", String(255), nullable=False, unique=True),
Column(
"NumVotes", INTEGER(unsigned=True), nullable=False, server_default=text("0")
),
Column(
"Popularity",
DECIMAL(10, 6, unsigned=True) if db_backend == "mysql" else String(17),
nullable=False,
server_default=text("0"),
),
Column("OutOfDateTS", BIGINT(unsigned=True)),
Column("FlaggerComment", Text, nullable=False),
Column("SubmittedTS", BIGINT(unsigned=True), nullable=False),
Column("ModifiedTS", BIGINT(unsigned=True), nullable=False),
Column(
"FlaggerUID", ForeignKey("Users.ID", ondelete="SET NULL")
), # who flagged the package out-of-date?
# deleting a user will cause packages to be orphaned, not deleted
Column('SubmitterUID', ForeignKey('Users.ID', ondelete='SET NULL')), # who submitted it?
Column('MaintainerUID', ForeignKey('Users.ID', ondelete='SET NULL')), # User
Column('PackagerUID', ForeignKey('Users.ID', ondelete='SET NULL')), # Last packager
Index('BasesMaintainerUID', 'MaintainerUID'),
Index('BasesNumVotes', 'NumVotes'),
Index('BasesPackagerUID', 'PackagerUID'),
Index('BasesSubmitterUID', 'SubmitterUID'),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
Column(
"SubmitterUID", ForeignKey("Users.ID", ondelete="SET NULL")
), # who submitted it?
Column("MaintainerUID", ForeignKey("Users.ID", ondelete="SET NULL")), # User
Column("PackagerUID", ForeignKey("Users.ID", ondelete="SET NULL")), # Last packager
Index("BasesMaintainerUID", "MaintainerUID"),
Index("BasesNumVotes", "NumVotes"),
Index("BasesPackagerUID", "PackagerUID"),
Index("BasesSubmitterUID", "SubmitterUID"),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Keywords of package bases
PackageKeywords = Table(
'PackageKeywords', metadata,
Column('PackageBaseID', ForeignKey('PackageBases.ID', ondelete='CASCADE'), primary_key=True, nullable=True),
Column('Keyword', String(255), primary_key=True, nullable=False, server_default=text("''")),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"PackageKeywords",
metadata,
Column(
"PackageBaseID",
ForeignKey("PackageBases.ID", ondelete="CASCADE"),
primary_key=True,
nullable=True,
),
Column(
"Keyword",
String(255),
primary_key=True,
nullable=False,
server_default=text("''"),
),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Information about the actual packages
Packages = Table(
'Packages', metadata,
Column('ID', INTEGER(unsigned=True), primary_key=True),
Column('PackageBaseID', ForeignKey('PackageBases.ID', ondelete='CASCADE'), nullable=False),
Column('Name', String(255), nullable=False, unique=True),
Column('Version', String(255), nullable=False, server_default=text("''")),
Column('Description', String(255)),
Column('URL', String(8000)),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"Packages",
metadata,
Column("ID", INTEGER(unsigned=True), primary_key=True),
Column(
"PackageBaseID",
ForeignKey("PackageBases.ID", ondelete="CASCADE"),
nullable=False,
),
Column("Name", String(255), nullable=False, unique=True),
Column("Version", String(255), nullable=False, server_default=text("''")),
Column("Description", String(255)),
Column("URL", String(8000)),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Information about licenses
Licenses = Table(
'Licenses', metadata,
Column('ID', INTEGER(unsigned=True), primary_key=True),
Column('Name', String(255), nullable=False, unique=True),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"Licenses",
metadata,
Column("ID", INTEGER(unsigned=True), primary_key=True),
Column("Name", String(255), nullable=False, unique=True),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Information about package-license-relations
PackageLicenses = Table(
'PackageLicenses', metadata,
Column('PackageID', ForeignKey('Packages.ID', ondelete='CASCADE'), primary_key=True, nullable=True),
Column('LicenseID', ForeignKey('Licenses.ID', ondelete='CASCADE'), primary_key=True, nullable=True),
mysql_engine='InnoDB',
"PackageLicenses",
metadata,
Column(
"PackageID",
ForeignKey("Packages.ID", ondelete="CASCADE"),
primary_key=True,
nullable=True,
),
Column(
"LicenseID",
ForeignKey("Licenses.ID", ondelete="CASCADE"),
primary_key=True,
nullable=True,
),
mysql_engine="InnoDB",
)
# Information about groups
Groups = Table(
'Groups', metadata,
Column('ID', INTEGER(unsigned=True), primary_key=True),
Column('Name', String(255), nullable=False, unique=True),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"Groups",
metadata,
Column("ID", INTEGER(unsigned=True), primary_key=True),
Column("Name", String(255), nullable=False, unique=True),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Information about package-group-relations
PackageGroups = Table(
'PackageGroups', metadata,
Column('PackageID', ForeignKey('Packages.ID', ondelete='CASCADE'), primary_key=True, nullable=True),
Column('GroupID', ForeignKey('Groups.ID', ondelete='CASCADE'), primary_key=True, nullable=True),
mysql_engine='InnoDB',
"PackageGroups",
metadata,
Column(
"PackageID",
ForeignKey("Packages.ID", ondelete="CASCADE"),
primary_key=True,
nullable=True,
),
Column(
"GroupID",
ForeignKey("Groups.ID", ondelete="CASCADE"),
primary_key=True,
nullable=True,
),
mysql_engine="InnoDB",
)
# Define the package dependency types
DependencyTypes = Table(
'DependencyTypes', metadata,
Column('ID', TINYINT(unsigned=True), primary_key=True),
Column('Name', String(32), nullable=False, server_default=text("''")),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"DependencyTypes",
metadata,
Column("ID", TINYINT(unsigned=True), primary_key=True),
Column("Name", String(32), nullable=False, server_default=text("''")),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Track which dependencies a package has
PackageDepends = Table(
'PackageDepends', metadata,
Column('PackageID', ForeignKey('Packages.ID', ondelete='CASCADE'), nullable=False),
Column('DepTypeID', ForeignKey('DependencyTypes.ID', ondelete="NO ACTION"), nullable=False),
Column('DepName', String(255), nullable=False),
Column('DepDesc', String(255)),
Column('DepCondition', String(255)),
Column('DepArch', String(255)),
Index('DependsDepName', 'DepName'),
Index('DependsPackageID', 'PackageID'),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"PackageDepends",
metadata,
Column("PackageID", ForeignKey("Packages.ID", ondelete="CASCADE"), nullable=False),
Column(
"DepTypeID",
ForeignKey("DependencyTypes.ID", ondelete="NO ACTION"),
nullable=False,
),
Column("DepName", String(255), nullable=False),
Column("DepDesc", String(255)),
Column("DepCondition", String(255)),
Column("DepArch", String(255)),
Index("DependsDepName", "DepName"),
Index("DependsPackageID", "PackageID"),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Define the package relation types
RelationTypes = Table(
'RelationTypes', metadata,
Column('ID', TINYINT(unsigned=True), primary_key=True),
Column('Name', String(32), nullable=False, server_default=text("''")),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"RelationTypes",
metadata,
Column("ID", TINYINT(unsigned=True), primary_key=True),
Column("Name", String(32), nullable=False, server_default=text("''")),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Track which conflicts, provides and replaces a package has
PackageRelations = Table(
'PackageRelations', metadata,
Column('PackageID', ForeignKey('Packages.ID', ondelete='CASCADE'), nullable=False),
Column('RelTypeID', ForeignKey('RelationTypes.ID', ondelete="NO ACTION"), nullable=False),
Column('RelName', String(255), nullable=False),
Column('RelCondition', String(255)),
Column('RelArch', String(255)),
Index('RelationsPackageID', 'PackageID'),
Index('RelationsRelName', 'RelName'),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"PackageRelations",
metadata,
Column("PackageID", ForeignKey("Packages.ID", ondelete="CASCADE"), nullable=False),
Column(
"RelTypeID",
ForeignKey("RelationTypes.ID", ondelete="NO ACTION"),
nullable=False,
),
Column("RelName", String(255), nullable=False),
Column("RelCondition", String(255)),
Column("RelArch", String(255)),
Index("RelationsPackageID", "PackageID"),
Index("RelationsRelName", "RelName"),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Track which sources a package has
PackageSources = Table(
'PackageSources', metadata,
Column('PackageID', ForeignKey('Packages.ID', ondelete='CASCADE'), nullable=False),
Column('Source', String(8000), nullable=False, server_default=text("'/dev/null'")),
Column('SourceArch', String(255)),
Index('SourcesPackageID', 'PackageID'),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"PackageSources",
metadata,
Column("PackageID", ForeignKey("Packages.ID", ondelete="CASCADE"), nullable=False),
Column("Source", String(8000), nullable=False, server_default=text("'/dev/null'")),
Column("SourceArch", String(255)),
Index("SourcesPackageID", "PackageID"),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Track votes for packages
PackageVotes = Table(
'PackageVotes', metadata,
Column('UsersID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False),
Column('PackageBaseID', ForeignKey('PackageBases.ID', ondelete='CASCADE'), nullable=False),
Column('VoteTS', BIGINT(unsigned=True), nullable=False),
Index('VoteUsersIDPackageID', 'UsersID', 'PackageBaseID', unique=True),
Index('VotesPackageBaseID', 'PackageBaseID'),
Index('VotesUsersID', 'UsersID'),
mysql_engine='InnoDB',
"PackageVotes",
metadata,
Column("UsersID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
Column(
"PackageBaseID",
ForeignKey("PackageBases.ID", ondelete="CASCADE"),
nullable=False,
),
Column("VoteTS", BIGINT(unsigned=True), nullable=False),
Index("VoteUsersIDPackageID", "UsersID", "PackageBaseID", unique=True),
Index("VotesPackageBaseID", "PackageBaseID"),
Index("VotesUsersID", "UsersID"),
mysql_engine="InnoDB",
)
# Record comments for packages
PackageComments = Table(
'PackageComments', metadata,
Column('ID', BIGINT(unsigned=True), primary_key=True),
Column('PackageBaseID', ForeignKey('PackageBases.ID', ondelete='CASCADE'), nullable=False),
Column('UsersID', ForeignKey('Users.ID', ondelete='SET NULL')),
Column('Comments', Text, nullable=False),
Column('RenderedComment', Text, nullable=False),
Column('CommentTS', BIGINT(unsigned=True), nullable=False, server_default=text("0")),
Column('EditedTS', BIGINT(unsigned=True)),
Column('EditedUsersID', ForeignKey('Users.ID', ondelete='SET NULL')),
Column('DelTS', BIGINT(unsigned=True)),
Column('DelUsersID', ForeignKey('Users.ID', ondelete='CASCADE')),
Column('PinnedTS', BIGINT(unsigned=True), nullable=False, server_default=text("0")),
Index('CommentsPackageBaseID', 'PackageBaseID'),
Index('CommentsUsersID', 'UsersID'),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"PackageComments",
metadata,
Column("ID", BIGINT(unsigned=True), primary_key=True),
Column(
"PackageBaseID",
ForeignKey("PackageBases.ID", ondelete="CASCADE"),
nullable=False,
),
Column("UsersID", ForeignKey("Users.ID", ondelete="SET NULL")),
Column("Comments", Text, nullable=False),
Column("RenderedComment", Text, nullable=False),
Column(
"CommentTS", BIGINT(unsigned=True), nullable=False, server_default=text("0")
),
Column("EditedTS", BIGINT(unsigned=True)),
Column("EditedUsersID", ForeignKey("Users.ID", ondelete="SET NULL")),
Column("DelTS", BIGINT(unsigned=True)),
Column("DelUsersID", ForeignKey("Users.ID", ondelete="CASCADE")),
Column("PinnedTS", BIGINT(unsigned=True), nullable=False, server_default=text("0")),
Index("CommentsPackageBaseID", "PackageBaseID"),
Index("CommentsUsersID", "UsersID"),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Package base co-maintainers
PackageComaintainers = Table(
'PackageComaintainers', metadata,
Column('UsersID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False),
Column('PackageBaseID', ForeignKey('PackageBases.ID', ondelete='CASCADE'), nullable=False),
Column('Priority', INTEGER(unsigned=True), nullable=False),
Index('ComaintainersPackageBaseID', 'PackageBaseID'),
Index('ComaintainersUsersID', 'UsersID'),
mysql_engine='InnoDB',
"PackageComaintainers",
metadata,
Column("UsersID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
Column(
"PackageBaseID",
ForeignKey("PackageBases.ID", ondelete="CASCADE"),
nullable=False,
),
Column("Priority", INTEGER(unsigned=True), nullable=False),
Index("ComaintainersPackageBaseID", "PackageBaseID"),
Index("ComaintainersUsersID", "UsersID"),
mysql_engine="InnoDB",
)
# Package base notifications
PackageNotifications = Table(
'PackageNotifications', metadata,
Column('PackageBaseID', ForeignKey('PackageBases.ID', ondelete='CASCADE'), nullable=False),
Column('UserID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False),
Index('NotifyUserIDPkgID', 'UserID', 'PackageBaseID', unique=True),
mysql_engine='InnoDB',
"PackageNotifications",
metadata,
Column(
"PackageBaseID",
ForeignKey("PackageBases.ID", ondelete="CASCADE"),
nullable=False,
),
Column("UserID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
Index("NotifyUserIDPkgID", "UserID", "PackageBaseID", unique=True),
mysql_engine="InnoDB",
)
# Package name blacklist
PackageBlacklist = Table(
'PackageBlacklist', metadata,
Column('ID', INTEGER(unsigned=True), primary_key=True),
Column('Name', String(64), nullable=False, unique=True),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"PackageBlacklist",
metadata,
Column("ID", INTEGER(unsigned=True), primary_key=True),
Column("Name", String(64), nullable=False, unique=True),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Providers in the official repositories
OfficialProviders = Table(
'OfficialProviders', metadata,
Column('ID', INTEGER(unsigned=True), primary_key=True),
Column('Name', String(64), nullable=False),
Column('Repo', String(64), nullable=False),
Column('Provides', String(64), nullable=False),
Index('ProviderNameProvides', 'Name', 'Provides', unique=True),
mysql_engine='InnoDB', mysql_charset='utf8mb4', mysql_collate='utf8mb4_bin',
"OfficialProviders",
metadata,
Column("ID", INTEGER(unsigned=True), primary_key=True),
Column("Name", String(64), nullable=False),
Column("Repo", String(64), nullable=False),
Column("Provides", String(64), nullable=False),
Index("ProviderNameProvides", "Name", "Provides", unique=True),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_bin",
)
# Define package request types
RequestTypes = Table(
'RequestTypes', metadata,
Column('ID', TINYINT(unsigned=True), primary_key=True),
Column('Name', String(32), nullable=False, server_default=text("''")),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"RequestTypes",
metadata,
Column("ID", TINYINT(unsigned=True), primary_key=True),
Column("Name", String(32), nullable=False, server_default=text("''")),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Package requests
PackageRequests = Table(
'PackageRequests', metadata,
Column('ID', BIGINT(unsigned=True), primary_key=True),
Column('ReqTypeID', ForeignKey('RequestTypes.ID', ondelete="NO ACTION"), nullable=False),
Column('PackageBaseID', ForeignKey('PackageBases.ID', ondelete='SET NULL')),
Column('PackageBaseName', String(255), nullable=False),
Column('MergeBaseName', String(255)),
Column('UsersID', ForeignKey('Users.ID', ondelete='SET NULL')),
Column('Comments', Text, nullable=False),
Column('ClosureComment', Text, nullable=False),
Column('RequestTS', BIGINT(unsigned=True), nullable=False, server_default=text("0")),
Column('ClosedTS', BIGINT(unsigned=True)),
Column('ClosedUID', ForeignKey('Users.ID', ondelete='SET NULL')),
Column('Status', TINYINT(unsigned=True), nullable=False, server_default=text("0")),
Index('RequestsPackageBaseID', 'PackageBaseID'),
Index('RequestsUsersID', 'UsersID'),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"PackageRequests",
metadata,
Column("ID", BIGINT(unsigned=True), primary_key=True),
Column(
"ReqTypeID", ForeignKey("RequestTypes.ID", ondelete="NO ACTION"), nullable=False
),
Column("PackageBaseID", ForeignKey("PackageBases.ID", ondelete="SET NULL")),
Column("PackageBaseName", String(255), nullable=False),
Column("MergeBaseName", String(255)),
Column("UsersID", ForeignKey("Users.ID", ondelete="SET NULL")),
Column("Comments", Text, nullable=False),
Column("ClosureComment", Text, nullable=False),
Column(
"RequestTS", BIGINT(unsigned=True), nullable=False, server_default=text("0")
),
Column("ClosedTS", BIGINT(unsigned=True)),
Column("ClosedUID", ForeignKey("Users.ID", ondelete="SET NULL")),
Column("Status", TINYINT(unsigned=True), nullable=False, server_default=text("0")),
Index("RequestsPackageBaseID", "PackageBaseID"),
Index("RequestsUsersID", "UsersID"),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Vote information
TU_VoteInfo = Table(
'TU_VoteInfo', metadata,
Column('ID', INTEGER(unsigned=True), primary_key=True),
Column('Agenda', Text, nullable=False),
Column('User', String(32), nullable=False),
Column('Submitted', BIGINT(unsigned=True), nullable=False),
Column('End', BIGINT(unsigned=True), nullable=False),
Column('Quorum',
DECIMAL(2, 2, unsigned=True)
if db_backend == "mysql" else String(5),
nullable=False),
Column('SubmitterID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False),
Column('Yes', INTEGER(unsigned=True), nullable=False, server_default=text("'0'")),
Column('No', INTEGER(unsigned=True), nullable=False, server_default=text("'0'")),
Column('Abstain', INTEGER(unsigned=True), nullable=False, server_default=text("'0'")),
Column('ActiveTUs', INTEGER(unsigned=True), nullable=False, server_default=text("'0'")),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"TU_VoteInfo",
metadata,
Column("ID", INTEGER(unsigned=True), primary_key=True),
Column("Agenda", Text, nullable=False),
Column("User", String(32), nullable=False),
Column("Submitted", BIGINT(unsigned=True), nullable=False),
Column("End", BIGINT(unsigned=True), nullable=False),
Column(
"Quorum",
DECIMAL(2, 2, unsigned=True) if db_backend == "mysql" else String(5),
nullable=False,
),
Column("SubmitterID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
Column("Yes", INTEGER(unsigned=True), nullable=False, server_default=text("'0'")),
Column("No", INTEGER(unsigned=True), nullable=False, server_default=text("'0'")),
Column(
"Abstain", INTEGER(unsigned=True), nullable=False, server_default=text("'0'")
),
Column(
"ActiveTUs", INTEGER(unsigned=True), nullable=False, server_default=text("'0'")
),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Individual vote records
TU_Votes = Table(
'TU_Votes', metadata,
Column('VoteID', ForeignKey('TU_VoteInfo.ID', ondelete='CASCADE'), nullable=False),
Column('UserID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False),
mysql_engine='InnoDB',
"TU_Votes",
metadata,
Column("VoteID", ForeignKey("TU_VoteInfo.ID", ondelete="CASCADE"), nullable=False),
Column("UserID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
mysql_engine="InnoDB",
)
# Malicious user banning
Bans = Table(
'Bans', metadata,
Column('IPAddress', String(45), primary_key=True),
Column('BanTS', TIMESTAMP, nullable=False),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"Bans",
metadata,
Column("IPAddress", String(45), primary_key=True),
Column("BanTS", TIMESTAMP, nullable=False),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Terms and Conditions
Terms = Table(
'Terms', metadata,
Column('ID', INTEGER(unsigned=True), primary_key=True),
Column('Description', String(255), nullable=False),
Column('URL', String(8000), nullable=False),
Column('Revision', INTEGER(unsigned=True), nullable=False, server_default=text("1")),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"Terms",
metadata,
Column("ID", INTEGER(unsigned=True), primary_key=True),
Column("Description", String(255), nullable=False),
Column("URL", String(8000), nullable=False),
Column(
"Revision", INTEGER(unsigned=True), nullable=False, server_default=text("1")
),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)
# Terms and Conditions accepted by users
AcceptedTerms = Table(
'AcceptedTerms', metadata,
Column('UsersID', ForeignKey('Users.ID', ondelete='CASCADE'), nullable=False),
Column('TermsID', ForeignKey('Terms.ID', ondelete='CASCADE'), nullable=False),
Column('Revision', INTEGER(unsigned=True), nullable=False, server_default=text("0")),
mysql_engine='InnoDB',
"AcceptedTerms",
metadata,
Column("UsersID", ForeignKey("Users.ID", ondelete="CASCADE"), nullable=False),
Column("TermsID", ForeignKey("Terms.ID", ondelete="CASCADE"), nullable=False),
Column(
"Revision", INTEGER(unsigned=True), nullable=False, server_default=text("0")
),
mysql_engine="InnoDB",
)
# Rate limits for API
ApiRateLimit = Table(
'ApiRateLimit', metadata,
Column('IP', String(45), primary_key=True, unique=True, default=str()),
Column('Requests', INTEGER(11), nullable=False),
Column('WindowStart', BIGINT(20), nullable=False),
Index('ApiRateLimitWindowStart', 'WindowStart'),
mysql_engine='InnoDB',
mysql_charset='utf8mb4',
mysql_collate='utf8mb4_general_ci',
"ApiRateLimit",
metadata,
Column("IP", String(45), primary_key=True, unique=True, default=str()),
Column("Requests", INTEGER(11), nullable=False),
Column("WindowStart", BIGINT(20), nullable=False),
Index("ApiRateLimitWindowStart", "WindowStart"),
mysql_engine="InnoDB",
mysql_charset="utf8mb4",
mysql_collate="utf8mb4_general_ci",
)

View file

@ -11,7 +11,6 @@ import sys
import traceback
import aurweb.models.account_type as at
from aurweb import db
from aurweb.models.account_type import AccountType
from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint
@ -30,8 +29,9 @@ def parse_args():
parser.add_argument("--ssh-pubkey", help="SSH PubKey")
choices = at.ACCOUNT_TYPE_NAME.values()
parser.add_argument("-t", "--type", help="Account Type",
choices=choices, default=at.USER)
parser.add_argument(
"-t", "--type", help="Account Type", choices=choices, default=at.USER
)
return parser.parse_args()
@ -40,25 +40,29 @@ def main():
args = parse_args()
db.get_engine()
type = db.query(AccountType,
AccountType.AccountType == args.type).first()
type = db.query(AccountType, AccountType.AccountType == args.type).first()
with db.begin():
user = db.create(User, Username=args.username,
Email=args.email, Passwd=args.password,
RealName=args.realname, IRCNick=args.ircnick,
PGPKey=args.pgp_key, AccountType=type)
user = db.create(
User,
Username=args.username,
Email=args.email,
Passwd=args.password,
RealName=args.realname,
IRCNick=args.ircnick,
PGPKey=args.pgp_key,
AccountType=type,
)
if args.ssh_pubkey:
pubkey = args.ssh_pubkey.strip()
# Remove host from the pubkey if it's there.
pubkey = ' '.join(pubkey.split(' ')[:2])
pubkey = " ".join(pubkey.split(" ")[:2])
with db.begin():
db.create(SSHPubKey,
User=user,
PubKey=pubkey,
Fingerprint=get_fingerprint(pubkey))
db.create(
SSHPubKey, User=user, PubKey=pubkey, Fingerprint=get_fingerprint(pubkey)
)
print(user.json())
return 0

View file

@ -3,11 +3,9 @@
import re
import pyalpm
from sqlalchemy import and_
import aurweb.config
from aurweb import db, util
from aurweb.models import OfficialProvider
@ -18,8 +16,8 @@ def _main(force: bool = False):
repomap = dict()
db_path = aurweb.config.get("aurblup", "db-path")
sync_dbs = aurweb.config.get('aurblup', 'sync-dbs').split(' ')
server = aurweb.config.get('aurblup', 'server')
sync_dbs = aurweb.config.get("aurblup", "sync-dbs").split(" ")
server = aurweb.config.get("aurblup", "server")
h = pyalpm.Handle("/", db_path)
for sync_db in sync_dbs:
@ -35,28 +33,35 @@ def _main(force: bool = False):
providers.add((pkg.name, pkg.name))
repomap[(pkg.name, pkg.name)] = repo.name
for provision in pkg.provides:
provisionname = re.sub(r'(<|=|>).*', '', provision)
provisionname = re.sub(r"(<|=|>).*", "", provision)
providers.add((pkg.name, provisionname))
repomap[(pkg.name, provisionname)] = repo.name
with db.begin():
old_providers = set(
db.query(OfficialProvider).with_entities(
db.query(OfficialProvider)
.with_entities(
OfficialProvider.Name.label("Name"),
OfficialProvider.Provides.label("Provides")
).distinct().order_by("Name").all()
OfficialProvider.Provides.label("Provides"),
)
.distinct()
.order_by("Name")
.all()
)
for name, provides in old_providers.difference(providers):
db.delete_all(db.query(OfficialProvider).filter(
and_(OfficialProvider.Name == name,
OfficialProvider.Provides == provides)
))
db.delete_all(
db.query(OfficialProvider).filter(
and_(
OfficialProvider.Name == name,
OfficialProvider.Provides == provides,
)
)
)
for name, provides in providers.difference(old_providers):
repo = repomap.get((name, provides))
db.create(OfficialProvider, Name=name,
Repo=repo, Provides=provides)
db.create(OfficialProvider, Name=name, Repo=repo, Provides=provides)
def main(force: bool = False):
@ -64,5 +69,5 @@ def main(force: bool = False):
_main(force)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -50,12 +50,12 @@ def parse_args():
actions = ["get", "set", "unset"]
parser = argparse.ArgumentParser(
description="aurweb configuration tool",
formatter_class=lambda prog: fmt_cls(prog=prog, max_help_position=80))
formatter_class=lambda prog: fmt_cls(prog=prog, max_help_position=80),
)
parser.add_argument("action", choices=actions, help="script action")
parser.add_argument("section", help="config section")
parser.add_argument("option", help="config option")
parser.add_argument("value", nargs="?", default=0,
help="config option value")
parser.add_argument("value", nargs="?", default=0, help="config option value")
return parser.parse_args()

View file

@ -25,16 +25,13 @@ import os
import shutil
import sys
import tempfile
from collections import defaultdict
from typing import Any
import orjson
from sqlalchemy import literal, orm
import aurweb.config
from aurweb import db, filters, logging, models, util
from aurweb.benchmark import Benchmark
from aurweb.models import Package, PackageBase, User
@ -90,65 +87,68 @@ def get_extended_dict(query: orm.Query):
def get_extended_fields():
subqueries = [
# PackageDependency
db.query(
models.PackageDependency
).join(models.DependencyType).with_entities(
db.query(models.PackageDependency)
.join(models.DependencyType)
.with_entities(
models.PackageDependency.PackageID.label("ID"),
models.DependencyType.Name.label("Type"),
models.PackageDependency.DepName.label("Name"),
models.PackageDependency.DepCondition.label("Cond")
).distinct().order_by("Name"),
models.PackageDependency.DepCondition.label("Cond"),
)
.distinct()
.order_by("Name"),
# PackageRelation
db.query(
models.PackageRelation
).join(models.RelationType).with_entities(
db.query(models.PackageRelation)
.join(models.RelationType)
.with_entities(
models.PackageRelation.PackageID.label("ID"),
models.RelationType.Name.label("Type"),
models.PackageRelation.RelName.label("Name"),
models.PackageRelation.RelCondition.label("Cond")
).distinct().order_by("Name"),
models.PackageRelation.RelCondition.label("Cond"),
)
.distinct()
.order_by("Name"),
# Groups
db.query(models.PackageGroup).join(
models.Group,
models.PackageGroup.GroupID == models.Group.ID
).with_entities(
db.query(models.PackageGroup)
.join(models.Group, models.PackageGroup.GroupID == models.Group.ID)
.with_entities(
models.PackageGroup.PackageID.label("ID"),
literal("Groups").label("Type"),
models.Group.Name.label("Name"),
literal(str()).label("Cond")
).distinct().order_by("Name"),
literal(str()).label("Cond"),
)
.distinct()
.order_by("Name"),
# Licenses
db.query(models.PackageLicense).join(
models.License,
models.PackageLicense.LicenseID == models.License.ID
).with_entities(
db.query(models.PackageLicense)
.join(models.License, models.PackageLicense.LicenseID == models.License.ID)
.with_entities(
models.PackageLicense.PackageID.label("ID"),
literal("License").label("Type"),
models.License.Name.label("Name"),
literal(str()).label("Cond")
).distinct().order_by("Name"),
literal(str()).label("Cond"),
)
.distinct()
.order_by("Name"),
# Keywords
db.query(models.PackageKeyword).join(
models.Package,
Package.PackageBaseID == models.PackageKeyword.PackageBaseID
).with_entities(
db.query(models.PackageKeyword)
.join(
models.Package, Package.PackageBaseID == models.PackageKeyword.PackageBaseID
)
.with_entities(
models.Package.ID.label("ID"),
literal("Keywords").label("Type"),
models.PackageKeyword.Keyword.label("Name"),
literal(str()).label("Cond")
).distinct().order_by("Name")
literal(str()).label("Cond"),
)
.distinct()
.order_by("Name"),
]
query = subqueries[0].union_all(*subqueries[1:])
return get_extended_dict(query)
EXTENDED_FIELD_HANDLERS = {
"--extended": get_extended_fields
}
EXTENDED_FIELD_HANDLERS = {"--extended": get_extended_fields}
def as_dict(package: Package) -> dict[str, Any]:
@ -181,23 +181,21 @@ def _main():
archivedir = aurweb.config.get("mkpkglists", "archivedir")
os.makedirs(archivedir, exist_ok=True)
PACKAGES = aurweb.config.get('mkpkglists', 'packagesfile')
META = aurweb.config.get('mkpkglists', 'packagesmetafile')
META_EXT = aurweb.config.get('mkpkglists', 'packagesmetaextfile')
PKGBASE = aurweb.config.get('mkpkglists', 'pkgbasefile')
USERS = aurweb.config.get('mkpkglists', 'userfile')
PACKAGES = aurweb.config.get("mkpkglists", "packagesfile")
META = aurweb.config.get("mkpkglists", "packagesmetafile")
META_EXT = aurweb.config.get("mkpkglists", "packagesmetaextfile")
PKGBASE = aurweb.config.get("mkpkglists", "pkgbasefile")
USERS = aurweb.config.get("mkpkglists", "userfile")
bench = Benchmark()
logger.info("Started re-creating archives, wait a while...")
query = db.query(Package).join(
PackageBase,
PackageBase.ID == Package.PackageBaseID
).join(
User,
PackageBase.MaintainerUID == User.ID,
isouter=True
).filter(PackageBase.PackagerUID.isnot(None)).with_entities(
query = (
db.query(Package)
.join(PackageBase, PackageBase.ID == Package.PackageBaseID)
.join(User, PackageBase.MaintainerUID == User.ID, isouter=True)
.filter(PackageBase.PackagerUID.isnot(None))
.with_entities(
Package.ID,
Package.Name,
PackageBase.ID.label("PackageBaseID"),
@ -210,8 +208,11 @@ def _main():
PackageBase.OutOfDateTS.label("OutOfDate"),
User.Username.label("Maintainer"),
PackageBase.SubmittedTS.label("FirstSubmitted"),
PackageBase.ModifiedTS.label("LastModified")
).distinct().order_by("Name")
PackageBase.ModifiedTS.label("LastModified"),
)
.distinct()
.order_by("Name")
)
# Produce packages-meta-v1.json.gz
output = list()
@ -252,7 +253,7 @@ def _main():
# We stream out package json objects line per line, so
# we also need to include the ',' character at the end
# of package lines (excluding the last package).
suffix = b",\n" if i < n else b'\n'
suffix = b",\n" if i < n else b"\n"
# Write out to packagesmetafile
output.append(item)
@ -273,8 +274,7 @@ def _main():
util.apply_all(gzips.values(), lambda gz: gz.close())
# Produce pkgbase.gz
query = db.query(PackageBase.Name).filter(
PackageBase.PackagerUID.isnot(None)).all()
query = db.query(PackageBase.Name).filter(PackageBase.PackagerUID.isnot(None)).all()
tmp_pkgbase = os.path.join(tmpdir, os.path.basename(PKGBASE))
with gzip.open(tmp_pkgbase, "wt") as f:
f.writelines([f"{base.Name}\n" for i, base in enumerate(query)])
@ -317,5 +317,5 @@ def main():
_main()
if __name__ == '__main__':
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load diff

View file

@ -11,8 +11,8 @@ def _main():
limit_to = time.utcnow() - 86400
query = db.query(PackageBase).filter(
and_(PackageBase.SubmittedTS < limit_to,
PackageBase.PackagerUID.is_(None)))
and_(PackageBase.SubmittedTS < limit_to, PackageBase.PackagerUID.is_(None))
)
db.delete_all(query)
@ -22,5 +22,5 @@ def main():
_main()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -1,8 +1,7 @@
#!/usr/bin/env python3
from sqlalchemy import and_, func
from sqlalchemy.sql.functions import coalesce
from sqlalchemy.sql.functions import sum as _sum
from sqlalchemy.sql.functions import coalesce, sum as _sum
from aurweb import db, time
from aurweb.models import PackageBase, PackageVote
@ -20,18 +19,26 @@ def run_variable(pkgbases: list[PackageBase] = []) -> None:
now = time.utcnow()
# NumVotes subquery.
votes_subq = db.get_session().query(
func.count("*")
).select_from(PackageVote).filter(
PackageVote.PackageBaseID == PackageBase.ID
votes_subq = (
db.get_session()
.query(func.count("*"))
.select_from(PackageVote)
.filter(PackageVote.PackageBaseID == PackageBase.ID)
)
# Popularity subquery.
pop_subq = db.get_session().query(
pop_subq = (
db.get_session()
.query(
coalesce(_sum(func.pow(0.98, (now - PackageVote.VoteTS) / 86400)), 0.0),
).select_from(PackageVote).filter(
and_(PackageVote.PackageBaseID == PackageBase.ID,
PackageVote.VoteTS.isnot(None))
)
.select_from(PackageVote)
.filter(
and_(
PackageVote.PackageBaseID == PackageBase.ID,
PackageVote.VoteTS.isnot(None),
)
)
)
with db.begin():
@ -42,10 +49,12 @@ def run_variable(pkgbases: list[PackageBase] = []) -> None:
ids = {pkgbase.ID for pkgbase in pkgbases}
query = query.filter(PackageBase.ID.in_(ids))
query.update({
query.update(
{
"NumVotes": votes_subq.scalar_subquery(),
"Popularity": pop_subq.scalar_subquery()
})
"Popularity": pop_subq.scalar_subquery(),
}
)
def run_single(pkgbase: PackageBase) -> None:
@ -65,5 +74,5 @@ def main():
run_variable()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -1,7 +1,6 @@
#!/usr/bin/env python3
import sys
from urllib.parse import quote_plus
from xml.etree.ElementTree import Element
@ -10,7 +9,6 @@ import markdown
import pygit2
import aurweb.config
from aurweb import db, logging, util
from aurweb.models import PackageComment
@ -25,13 +23,15 @@ class LinkifyExtension(markdown.extensions.Extension):
# Captures http(s) and ftp URLs until the first non URL-ish character.
# Excludes trailing punctuation.
_urlre = (r'(\b(?:https?|ftp):\/\/[\w\/\#~:.?+=&%@!\-;,]+?'
r'(?=[.:?\-;,]*(?:[^\w\/\#~:.?+=&%@!\-;,]|$)))')
_urlre = (
r"(\b(?:https?|ftp):\/\/[\w\/\#~:.?+=&%@!\-;,]+?"
r"(?=[.:?\-;,]*(?:[^\w\/\#~:.?+=&%@!\-;,]|$)))"
)
def extendMarkdown(self, md):
processor = markdown.inlinepatterns.AutolinkInlineProcessor(self._urlre, md)
# Register it right after the default <>-link processor (priority 120).
md.inlinePatterns.register(processor, 'linkify', 119)
md.inlinePatterns.register(processor, "linkify", 119)
class FlysprayLinksInlineProcessor(markdown.inlinepatterns.InlineProcessor):
@ -43,16 +43,16 @@ class FlysprayLinksInlineProcessor(markdown.inlinepatterns.InlineProcessor):
"""
def handleMatch(self, m, data):
el = Element('a')
el.set('href', f'https://bugs.archlinux.org/task/{m.group(1)}')
el = Element("a")
el.set("href", f"https://bugs.archlinux.org/task/{m.group(1)}")
el.text = markdown.util.AtomicString(m.group(0))
return (el, m.start(0), m.end(0))
class FlysprayLinksExtension(markdown.extensions.Extension):
def extendMarkdown(self, md):
processor = FlysprayLinksInlineProcessor(r'\bFS#(\d+)\b', md)
md.inlinePatterns.register(processor, 'flyspray-links', 118)
processor = FlysprayLinksInlineProcessor(r"\bFS#(\d+)\b", md)
md.inlinePatterns.register(processor, "flyspray-links", 118)
class GitCommitsInlineProcessor(markdown.inlinepatterns.InlineProcessor):
@ -65,10 +65,10 @@ class GitCommitsInlineProcessor(markdown.inlinepatterns.InlineProcessor):
"""
def __init__(self, md, head):
repo_path = aurweb.config.get('serve', 'repo-path')
repo_path = aurweb.config.get("serve", "repo-path")
self._repo = pygit2.Repository(repo_path)
self._head = head
super().__init__(r'\b([0-9a-f]{7,40})\b', md)
super().__init__(r"\b([0-9a-f]{7,40})\b", md)
def handleMatch(self, m, data):
oid = m.group(1)
@ -76,13 +76,12 @@ class GitCommitsInlineProcessor(markdown.inlinepatterns.InlineProcessor):
# Unknown OID; preserve the orginal text.
return (None, None, None)
el = Element('a')
el = Element("a")
commit_uri = aurweb.config.get("options", "commit_uri")
prefixlen = util.git_search(self._repo, oid)
el.set('href', commit_uri % (
quote_plus(self._head),
quote_plus(oid[:prefixlen])
))
el.set(
"href", commit_uri % (quote_plus(self._head), quote_plus(oid[:prefixlen]))
)
el.text = markdown.util.AtomicString(oid[:prefixlen])
return (el, m.start(0), m.end(0))
@ -97,7 +96,7 @@ class GitCommitsExtension(markdown.extensions.Extension):
def extendMarkdown(self, md):
try:
processor = GitCommitsInlineProcessor(md, self._head)
md.inlinePatterns.register(processor, 'git-commits', 117)
md.inlinePatterns.register(processor, "git-commits", 117)
except pygit2.GitError:
logger.error(f"No git repository found for '{self._head}'.")
@ -105,16 +104,16 @@ class GitCommitsExtension(markdown.extensions.Extension):
class HeadingTreeprocessor(markdown.treeprocessors.Treeprocessor):
def run(self, doc):
for elem in doc:
if elem.tag == 'h1':
elem.tag = 'h5'
elif elem.tag in ['h2', 'h3', 'h4', 'h5']:
elem.tag = 'h6'
if elem.tag == "h1":
elem.tag = "h5"
elif elem.tag in ["h2", "h3", "h4", "h5"]:
elem.tag = "h6"
class HeadingExtension(markdown.extensions.Extension):
def extendMarkdown(self, md):
# Priority doesn't matter since we don't conflict with other processors.
md.treeprocessors.register(HeadingTreeprocessor(md), 'heading', 30)
md.treeprocessors.register(HeadingTreeprocessor(md), "heading", 30)
def save_rendered_comment(comment: PackageComment, html: str):
@ -130,16 +129,26 @@ def update_comment_render(comment: PackageComment) -> None:
text = comment.Comments
pkgbasename = comment.PackageBase.Name
html = markdown.markdown(text, extensions=[
'fenced_code',
html = markdown.markdown(
text,
extensions=[
"fenced_code",
LinkifyExtension(),
FlysprayLinksExtension(),
GitCommitsExtension(pkgbasename),
HeadingExtension()
])
HeadingExtension(),
],
)
allowed_tags = (bleach.sanitizer.ALLOWED_TAGS
+ ['p', 'pre', 'h4', 'h5', 'h6', 'br', 'hr'])
allowed_tags = bleach.sanitizer.ALLOWED_TAGS + [
"p",
"pre",
"h4",
"h5",
"h6",
"br",
"hr",
]
html = bleach.clean(html, tags=allowed_tags)
save_rendered_comment(comment, html)
db.refresh(comment)
@ -148,11 +157,9 @@ def update_comment_render(comment: PackageComment) -> None:
def main():
db.get_engine()
comment_id = int(sys.argv[1])
comment = db.query(PackageComment).filter(
PackageComment.ID == comment_id
).first()
comment = db.query(PackageComment).filter(PackageComment.ID == comment_id).first()
update_comment_render(comment)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -3,12 +3,11 @@
from sqlalchemy import and_
import aurweb.config
from aurweb import db, time
from aurweb.models import TUVoteInfo
from aurweb.scripts import notify
notify_cmd = aurweb.config.get('notifications', 'notify-cmd')
notify_cmd = aurweb.config.get("notifications", "notify-cmd")
def main():
@ -23,13 +22,12 @@ def main():
filter_to = now + end
query = db.query(TUVoteInfo.ID).filter(
and_(TUVoteInfo.End >= filter_from,
TUVoteInfo.End <= filter_to)
and_(TUVoteInfo.End >= filter_from, TUVoteInfo.End <= filter_to)
)
for voteinfo in query:
notif = notify.TUVoteReminderNotification(voteinfo.ID)
notif.send()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -9,14 +9,16 @@ from aurweb.models import User
def _main():
limit_to = time.utcnow() - 86400 * 7
update_ = update(User).where(
User.LastLogin < limit_to
).values(LastLoginIPAddress=None)
update_ = (
update(User).where(User.LastLogin < limit_to).values(LastLoginIPAddress=None)
)
db.get_session().execute(update_)
update_ = update(User).where(
User.LastSSHLogin < limit_to
).values(LastSSHLoginIPAddress=None)
update_ = (
update(User)
.where(User.LastSSHLogin < limit_to)
.values(LastSSHLoginIPAddress=None)
)
db.get_session().execute(update_)
@ -26,5 +28,5 @@ def main():
_main()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -16,18 +16,16 @@ import subprocess
import sys
import tempfile
import time
from typing import Iterable
import aurweb.config
import aurweb.schema
from aurweb.exceptions import AurwebException
children = []
temporary_dir = None
verbosity = 0
asgi_backend = ''
asgi_backend = ""
workers = 1
PHP_BINARY = os.environ.get("PHP_BINARY", "php")
@ -60,22 +58,21 @@ def validate_php_config() -> None:
:return: None
"""
try:
proc = subprocess.Popen([PHP_BINARY, "-m"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
proc = subprocess.Popen(
[PHP_BINARY, "-m"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
out, _ = proc.communicate()
except FileNotFoundError:
raise AurwebException(f"Unable to locate the '{PHP_BINARY}' "
"executable.")
raise AurwebException(f"Unable to locate the '{PHP_BINARY}' " "executable.")
assert proc.returncode == 0, ("Received non-zero error code "
f"{proc.returncode} from '{PHP_BINARY}'.")
assert proc.returncode == 0, (
"Received non-zero error code " f"{proc.returncode} from '{PHP_BINARY}'."
)
modules = out.decode().splitlines()
for module in PHP_MODULES:
if module not in modules:
raise AurwebException(
f"PHP does not have the '{module}' module enabled.")
raise AurwebException(f"PHP does not have the '{module}' module enabled.")
def generate_nginx_config():
@ -91,7 +88,8 @@ def generate_nginx_config():
config_path = os.path.join(temporary_dir, "nginx.conf")
config = open(config_path, "w")
# We double nginx's braces because they conflict with Python's f-strings.
config.write(f"""
config.write(
f"""
events {{}}
daemon off;
error_log /dev/stderr info;
@ -124,7 +122,8 @@ def generate_nginx_config():
}}
}}
}}
""")
"""
)
return config_path
@ -146,20 +145,23 @@ def start():
return
atexit.register(stop)
if 'AUR_CONFIG' in os.environ:
os.environ['AUR_CONFIG'] = os.path.realpath(os.environ['AUR_CONFIG'])
if "AUR_CONFIG" in os.environ:
os.environ["AUR_CONFIG"] = os.path.realpath(os.environ["AUR_CONFIG"])
try:
terminal_width = os.get_terminal_size().columns
except OSError:
terminal_width = 80
print("{ruler}\n"
print(
"{ruler}\n"
"Spawing PHP and FastAPI, then nginx as a reverse proxy.\n"
"Check out {aur_location}\n"
"Hit ^C to terminate everything.\n"
"{ruler}"
.format(ruler=("-" * terminal_width),
aur_location=aurweb.config.get('options', 'aur_location')))
"{ruler}".format(
ruler=("-" * terminal_width),
aur_location=aurweb.config.get("options", "aur_location"),
)
)
# PHP
php_address = aurweb.config.get("php", "bind_address")
@ -168,8 +170,9 @@ def start():
spawn_child(["php", "-S", php_address, "-t", htmldir])
# FastAPI
fastapi_host, fastapi_port = aurweb.config.get(
"fastapi", "bind_address").rsplit(":", 1)
fastapi_host, fastapi_port = aurweb.config.get("fastapi", "bind_address").rsplit(
":", 1
)
# Logging config.
aurwebdir = aurweb.config.get("options", "aurwebdir")
@ -178,20 +181,33 @@ def start():
backend_args = {
"hypercorn": ["-b", f"{fastapi_host}:{fastapi_port}"],
"uvicorn": ["--host", fastapi_host, "--port", fastapi_port],
"gunicorn": ["--bind", f"{fastapi_host}:{fastapi_port}",
"-k", "uvicorn.workers.UvicornWorker",
"-w", str(workers)]
"gunicorn": [
"--bind",
f"{fastapi_host}:{fastapi_port}",
"-k",
"uvicorn.workers.UvicornWorker",
"-w",
str(workers),
],
}
backend_args = backend_args.get(asgi_backend)
spawn_child([
"python", "-m", asgi_backend,
"--log-config", fastapi_log_config,
] + backend_args + ["aurweb.asgi:app"])
spawn_child(
[
"python",
"-m",
asgi_backend,
"--log-config",
fastapi_log_config,
]
+ backend_args
+ ["aurweb.asgi:app"]
)
# nginx
spawn_child(["nginx", "-p", temporary_dir, "-c", generate_nginx_config()])
print(f"""
print(
f"""
> Started nginx.
>
> PHP backend: http://{php_address}
@ -201,11 +217,13 @@ def start():
> FastAPI frontend: http://{fastapi_host}:{FASTAPI_NGINX_PORT}
>
> Frontends are hosted via nginx and should be preferred.
""")
"""
)
def _kill_children(children: Iterable, exceptions: list[Exception] = []) \
-> list[Exception]:
def _kill_children(
children: Iterable, exceptions: list[Exception] = []
) -> list[Exception]:
"""
Kill each process found in `children`.
@ -223,8 +241,9 @@ def _kill_children(children: Iterable, exceptions: list[Exception] = []) \
return exceptions
def _wait_for_children(children: Iterable, exceptions: list[Exception] = []) \
-> list[Exception]:
def _wait_for_children(
children: Iterable, exceptions: list[Exception] = []
) -> list[Exception]:
"""
Wait for each process to end found in `children`.
@ -261,21 +280,31 @@ def stop() -> None:
exceptions = _wait_for_children(children, exceptions)
children = []
if exceptions:
raise ProcessExceptions("Errors terminating the child processes:",
exceptions)
raise ProcessExceptions("Errors terminating the child processes:", exceptions)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog='python -m aurweb.spawn',
description='Start aurweb\'s test server.')
parser.add_argument('-v', '--verbose', action='count', default=0,
help='increase verbosity')
choices = ['hypercorn', 'gunicorn', 'uvicorn']
parser.add_argument('-b', '--backend', choices=choices, default='uvicorn',
help='asgi backend used to launch the python server')
parser.add_argument("-w", "--workers", default=1, type=int,
help="number of workers to use in gunicorn")
prog="python -m aurweb.spawn", description="Start aurweb's test server."
)
parser.add_argument(
"-v", "--verbose", action="count", default=0, help="increase verbosity"
)
choices = ["hypercorn", "gunicorn", "uvicorn"]
parser.add_argument(
"-b",
"--backend",
choices=choices,
default="uvicorn",
help="asgi backend used to launch the python server",
)
parser.add_argument(
"-w",
"--workers",
default=1,
type=int,
help="number of workers to use in gunicorn",
)
args = parser.parse_args()
try:

View file

@ -1,24 +1,23 @@
import copy
import functools
import os
from http import HTTPStatus
from typing import Callable
import jinja2
from fastapi import Request
from fastapi.responses import HTMLResponse
import aurweb.config
from aurweb import cookies, l10n, time
# Prepare jinja2 objects.
_loader = jinja2.FileSystemLoader(os.path.join(
aurweb.config.get("options", "aurwebdir"), "templates"))
_env = jinja2.Environment(loader=_loader, autoescape=True,
extensions=["jinja2.ext.i18n"])
_loader = jinja2.FileSystemLoader(
os.path.join(aurweb.config.get("options", "aurwebdir"), "templates")
)
_env = jinja2.Environment(
loader=_loader, autoescape=True, extensions=["jinja2.ext.i18n"]
)
def register_filter(name: str) -> Callable:
@ -35,26 +34,31 @@ def register_filter(name: str) -> Callable:
:param name: Filter name
:return: Callable used for filter
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
_env.filters[name] = wrapper
return wrapper
return decorator
def register_function(name: str) -> Callable:
""" A decorator that can be used to register a function.
"""
"""A decorator that can be used to register a function."""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
if name in _env.globals:
raise KeyError(f"Jinja already has a function named '{name}'")
_env.globals[name] = wrapper
return wrapper
return decorator
@ -85,7 +89,7 @@ def make_context(request: Request, title: str, next: str = None):
"config": aurweb.config,
"creds": aurweb.auth.creds,
"next": next if next else request.url.path,
"version": os.environ.get("COMMIT_HASH", aurweb.config.AURWEB_VERSION)
"version": os.environ.get("COMMIT_HASH", aurweb.config.AURWEB_VERSION),
}
@ -93,9 +97,11 @@ async def make_variable_context(request: Request, title: str, next: str = None):
"""Make a context with variables provided by the user
(query params via GET or form data via POST)."""
context = make_context(request, title, next)
to_copy = dict(request.query_params) \
if request.method.lower() == "get" \
to_copy = (
dict(request.query_params)
if request.method.lower() == "get"
else dict(await request.form())
)
for k, v in to_copy.items():
context[k] = v
@ -126,10 +132,9 @@ def render_raw_template(request: Request, path: str, context: dict):
return template.render(context)
def render_template(request: Request,
path: str,
context: dict,
status_code: HTTPStatus = HTTPStatus.OK):
def render_template(
request: Request, path: str, context: dict, status_code: HTTPStatus = HTTPStatus.OK
):
"""Render a template as an HTMLResponse."""
rendered = render_raw_template(request, path, context)
response = HTMLResponse(rendered, status_code=int(status_code))

View file

@ -1,5 +1,4 @@
import aurweb.db
from aurweb import models

View file

@ -17,6 +17,7 @@ class AlpmDatabase:
This class can be used to add or remove packages from a
test repository.
"""
repo = "test"
def __init__(self, database_root: str):
@ -35,13 +36,14 @@ class AlpmDatabase:
os.makedirs(pkgdir)
return pkgdir
def add(self, pkgname: str, pkgver: str, arch: str,
provides: list[str] = []) -> None:
def add(
self, pkgname: str, pkgver: str, arch: str, provides: list[str] = []
) -> None:
context = {
"pkgname": pkgname,
"pkgver": pkgver,
"arch": arch,
"provides": provides
"provides": provides,
}
template = base_template("testing/alpm_package.j2")
pkgdir = self._get_pkgdir(pkgname, pkgver, self.repo)
@ -76,8 +78,9 @@ class AlpmDatabase:
self.clean()
cmdline = ["bash", "-c", "bsdtar -czvf ../test.db *"]
proc = subprocess.run(cmdline, cwd=self.repopath)
assert proc.returncode == 0, \
f"Bad return code while creating alpm database: {proc.returncode}"
assert (
proc.returncode == 0
), f"Bad return code while creating alpm database: {proc.returncode}"
# Print out the md5 hash value of the new test.db.
test_db = os.path.join(self.remote, "test.db")

View file

@ -5,7 +5,6 @@ import email
import os
import re
import sys
from typing import TextIO
@ -28,6 +27,7 @@ class Email:
print(email.headers)
"""
TEST_DIR = "test-emails"
def __init__(self, serial: int = 1, autoparse: bool = True):
@ -61,7 +61,7 @@ class Email:
value = os.environ.get("PYTEST_CURRENT_TEST", "email").split(" ")[0]
if suite:
value = value.split(":")[0]
return re.sub(r'(\/|\.|,|:)', "_", value)
return re.sub(r"(\/|\.|,|:)", "_", value)
@staticmethod
def count() -> int:
@ -159,6 +159,6 @@ class Email:
lines += [
f"== Email #{i + 1} ==",
email.glue(),
f"== End of Email #{i + 1}"
f"== End of Email #{i + 1}",
]
print("\n".join(lines), file=file)

View file

@ -1,6 +1,5 @@
import hashlib
import os
from typing import Callable
from posix_ipc import O_CREAT, Semaphore

View file

@ -1,6 +1,5 @@
import os
import shlex
from subprocess import PIPE, Popen
from typing import Tuple

View file

@ -3,6 +3,7 @@ import aurweb.config
class User:
"""A fake User model."""
# Fake columns.
LangPreference = aurweb.config.get("options", "default_lang")
Timezone = aurweb.config.get("options", "default_timezone")
@ -16,6 +17,7 @@ class User:
class Client:
"""A fake FastAPI Request.client object."""
# A fake host.
host = "127.0.0.1"
@ -26,15 +28,18 @@ class URL:
class Request:
"""A fake Request object which mimics a FastAPI Request for tests."""
client = Client()
url = URL()
def __init__(self,
def __init__(
self,
user: User = User(),
authenticated: bool = False,
method: str = "GET",
headers: dict[str, str] = dict(),
cookies: dict[str, str] = dict()) -> "Request":
cookies: dict[str, str] = dict(),
) -> "Request":
self.user = user
self.user.authenticated = authenticated

View file

@ -42,4 +42,5 @@ class FakeSMTP:
class FakeSMTP_SSL(FakeSMTP):
"""A fake version of smtplib.SMTP_SSL used for testing."""
use_ssl = True

View file

@ -1,5 +1,4 @@
import zoneinfo
from collections import OrderedDict
from datetime import datetime
from urllib.parse import unquote
@ -24,7 +23,7 @@ def tz_offset(name: str):
offset = dt.utcoffset().total_seconds() / 60 / 60
# Prefix the offset string with a - or +.
offset_string = '-' if offset < 0 else '+'
offset_string = "-" if offset < 0 else "+"
# Remove any negativity from the offset. We want a good offset. :)
offset = abs(offset)
@ -42,15 +41,21 @@ def tz_offset(name: str):
return offset_string
SUPPORTED_TIMEZONES = OrderedDict({
SUPPORTED_TIMEZONES = OrderedDict(
{
# Flatten out the list of tuples into an OrderedDict.
timezone: offset for timezone, offset in sorted([
timezone: offset
for timezone, offset in sorted(
[
# Comprehend a list of tuples (timezone, offset display string)
# and sort them by (offset, timezone).
(tz, "(UTC%s) %s" % (tz_offset(tz), tz))
for tz in zoneinfo.available_timezones()
], key=lambda element: (tz_offset(element[0]), element[0]))
})
],
key=lambda element: (tz_offset(element[0]), element[0]),
)
}
)
def get_request_timezone(request: Request):

View file

@ -8,12 +8,23 @@ from aurweb.models.ssh_pub_key import get_fingerprint
from aurweb.util import strtobool
def simple(U: str = str(), E: str = str(), H: bool = False,
BE: str = str(), R: str = str(), HP: str = str(),
I: str = str(), K: str = str(), J: bool = False,
CN: bool = False, UN: bool = False, ON: bool = False,
S: bool = False, user: models.User = None,
**kwargs) -> None:
def simple(
U: str = str(),
E: str = str(),
H: bool = False,
BE: str = str(),
R: str = str(),
HP: str = str(),
I: str = str(),
K: str = str(),
J: bool = False,
CN: bool = False,
UN: bool = False,
ON: bool = False,
S: bool = False,
user: models.User = None,
**kwargs,
) -> None:
now = time.utcnow()
with db.begin():
user.Username = U or user.Username
@ -31,22 +42,26 @@ def simple(U: str = str(), E: str = str(), H: bool = False,
user.OwnershipNotify = strtobool(ON)
def language(L: str = str(),
def language(
L: str = str(),
request: Request = None,
user: models.User = None,
context: dict[str, Any] = {},
**kwargs) -> None:
**kwargs,
) -> None:
if L and L != user.LangPreference:
with db.begin():
user.LangPreference = L
context["language"] = L
def timezone(TZ: str = str(),
def timezone(
TZ: str = str(),
request: Request = None,
user: models.User = None,
context: dict[str, Any] = {},
**kwargs) -> None:
**kwargs,
) -> None:
if TZ and TZ != user.Timezone:
with db.begin():
user.Timezone = TZ
@ -67,8 +82,7 @@ def ssh_pubkey(PK: str = str(), user: models.User = None, **kwargs) -> None:
with db.begin():
# Delete any existing keys we can't find.
to_remove = user.ssh_pub_keys.filter(
~SSHPubKey.Fingerprint.in_(fprints))
to_remove = user.ssh_pub_keys.filter(~SSHPubKey.Fingerprint.in_(fprints))
db.delete_all(to_remove)
# For each key, if it does not yet exist, create it.
@ -79,24 +93,27 @@ def ssh_pubkey(PK: str = str(), user: models.User = None, **kwargs) -> None:
).exists()
if not db.query(exists).scalar():
# No public key exists, create one.
db.create(models.SSHPubKey, UserID=user.ID,
db.create(
models.SSHPubKey,
UserID=user.ID,
PubKey=" ".join([prefix, key]),
Fingerprint=fprints[i])
Fingerprint=fprints[i],
)
def account_type(T: int = None,
user: models.User = None,
**kwargs) -> None:
def account_type(T: int = None, user: models.User = None, **kwargs) -> None:
if T is not None and (T := int(T)) != user.AccountTypeID:
with db.begin():
user.AccountTypeID = T
def password(P: str = str(),
def password(
P: str = str(),
request: Request = None,
user: models.User = None,
context: dict[str, Any] = {},
**kwargs) -> None:
**kwargs,
) -> None:
if P and not user.valid_password(P):
# Remove the fields we consumed for passwords.
context["P"] = context["C"] = str()

View file

@ -25,42 +25,44 @@ def invalid_fields(E: str = str(), U: str = str(), **kwargs) -> None:
raise ValidationError(["Missing a required field."])
def invalid_suspend_permission(request: Request = None,
user: models.User = None,
S: str = "False",
**kwargs) -> None:
def invalid_suspend_permission(
request: Request = None, user: models.User = None, S: str = "False", **kwargs
) -> None:
if not request.user.is_elevated() and strtobool(S) != bool(user.Suspended):
raise ValidationError([
"You do not have permission to suspend accounts."])
raise ValidationError(["You do not have permission to suspend accounts."])
def invalid_username(request: Request = None, U: str = str(),
_: l10n.Translator = None,
**kwargs) -> None:
def invalid_username(
request: Request = None, U: str = str(), _: l10n.Translator = None, **kwargs
) -> None:
if not util.valid_username(U):
username_min_len = config.getint("options", "username_min_len")
username_max_len = config.getint("options", "username_max_len")
raise ValidationError([
raise ValidationError(
[
"The username is invalid.",
[
_("It must be between %s and %s characters long") % (
username_min_len, username_max_len),
_("It must be between %s and %s characters long")
% (username_min_len, username_max_len),
"Start and end with a letter or number",
"Can contain only one period, underscore or hyphen.",
],
]
])
)
def invalid_password(P: str = str(), C: str = str(),
_: l10n.Translator = None, **kwargs) -> None:
def invalid_password(
P: str = str(), C: str = str(), _: l10n.Translator = None, **kwargs
) -> None:
if P:
if not util.valid_password(P):
username_min_len = config.getint(
"options", "username_min_len")
raise ValidationError([
_("Your password must be at least %s characters.") % (
username_min_len)
])
username_min_len = config.getint("options", "username_min_len")
raise ValidationError(
[
_("Your password must be at least %s characters.")
% (username_min_len)
]
)
elif not C:
raise ValidationError(["Please confirm your new password."])
elif P != C:
@ -71,15 +73,18 @@ def is_banned(request: Request = None, **kwargs) -> None:
host = request.client.host
exists = db.query(models.Ban, models.Ban.IPAddress == host).exists()
if db.query(exists).scalar():
raise ValidationError([
raise ValidationError(
[
"Account registration has been disabled for your "
"IP address, probably due to sustained spam attacks. "
"Sorry for the inconvenience."
])
]
)
def invalid_user_password(request: Request = None, passwd: str = str(),
**kwargs) -> None:
def invalid_user_password(
request: Request = None, passwd: str = str(), **kwargs
) -> None:
if request.user.is_authenticated():
if not request.user.valid_password(passwd):
raise ValidationError(["Invalid password."])
@ -97,8 +102,9 @@ def invalid_backup_email(BE: str = str(), **kwargs) -> None:
def invalid_homepage(HP: str = str(), **kwargs) -> None:
if HP and not util.valid_homepage(HP):
raise ValidationError([
"The home page is invalid, please specify the full HTTP(s) URL."])
raise ValidationError(
["The home page is invalid, please specify the full HTTP(s) URL."]
)
def invalid_pgp_key(K: str = str(), **kwargs) -> None:
@ -106,8 +112,9 @@ def invalid_pgp_key(K: str = str(), **kwargs) -> None:
raise ValidationError(["The PGP key fingerprint is invalid."])
def invalid_ssh_pubkey(PK: str = str(), user: models.User = None,
_: l10n.Translator = None, **kwargs) -> None:
def invalid_ssh_pubkey(
PK: str = str(), user: models.User = None, _: l10n.Translator = None, **kwargs
) -> None:
if not PK:
return
@ -119,15 +126,23 @@ def invalid_ssh_pubkey(PK: str = str(), user: models.User = None,
for prefix, key in keys:
fingerprint = get_fingerprint(f"{prefix} {key}")
exists = db.query(models.SSHPubKey).filter(
and_(models.SSHPubKey.UserID != user.ID,
models.SSHPubKey.Fingerprint == fingerprint)
).exists()
exists = (
db.query(models.SSHPubKey)
.filter(
and_(
models.SSHPubKey.UserID != user.ID,
models.SSHPubKey.Fingerprint == fingerprint,
)
)
.exists()
)
if db.query(exists).scalar():
raise ValidationError([
_("The SSH public key, %s%s%s, is already in use.") % (
"<strong>", fingerprint, "</strong>")
])
raise ValidationError(
[
_("The SSH public key, %s%s%s, is already in use.")
% ("<strong>", fingerprint, "</strong>")
]
)
def invalid_language(L: str = str(), **kwargs) -> None:
@ -140,60 +155,78 @@ def invalid_timezone(TZ: str = str(), **kwargs) -> None:
raise ValidationError(["Timezone is not currently supported."])
def username_in_use(U: str = str(), user: models.User = None,
_: l10n.Translator = None, **kwargs) -> None:
exists = db.query(models.User).filter(
and_(models.User.ID != user.ID,
models.User.Username == U)
).exists()
def username_in_use(
U: str = str(), user: models.User = None, _: l10n.Translator = None, **kwargs
) -> None:
exists = (
db.query(models.User)
.filter(and_(models.User.ID != user.ID, models.User.Username == U))
.exists()
)
if db.query(exists).scalar():
# If the username already exists...
raise ValidationError([
_("The username, %s%s%s, is already in use.") % (
"<strong>", U, "</strong>")
])
raise ValidationError(
[
_("The username, %s%s%s, is already in use.")
% ("<strong>", U, "</strong>")
]
)
def email_in_use(E: str = str(), user: models.User = None,
_: l10n.Translator = None, **kwargs) -> None:
exists = db.query(models.User).filter(
and_(models.User.ID != user.ID,
models.User.Email == E)
).exists()
def email_in_use(
E: str = str(), user: models.User = None, _: l10n.Translator = None, **kwargs
) -> None:
exists = (
db.query(models.User)
.filter(and_(models.User.ID != user.ID, models.User.Email == E))
.exists()
)
if db.query(exists).scalar():
# If the email already exists...
raise ValidationError([
_("The address, %s%s%s, is already in use.") % (
"<strong>", E, "</strong>")
])
raise ValidationError(
[
_("The address, %s%s%s, is already in use.")
% ("<strong>", E, "</strong>")
]
)
def invalid_account_type(T: int = None, request: Request = None,
def invalid_account_type(
T: int = None,
request: Request = None,
user: models.User = None,
_: l10n.Translator = None,
**kwargs) -> None:
**kwargs,
) -> None:
if T is not None and (T := int(T)) != user.AccountTypeID:
name = ACCOUNT_TYPE_NAME.get(T, None)
has_cred = request.user.has_credential(creds.ACCOUNT_CHANGE_TYPE)
if name is None:
raise ValidationError(["Invalid account type provided."])
elif not has_cred:
raise ValidationError([
"You do not have permission to change account types."])
raise ValidationError(
["You do not have permission to change account types."]
)
elif T > request.user.AccountTypeID:
# If the chosen account type is higher than the editor's account
# type, the editor doesn't have permission to set the new type.
error = _("You do not have permission to change "
"this user's account type to %s.") % name
error = (
_(
"You do not have permission to change "
"this user's account type to %s."
)
% name
)
raise ValidationError([error])
logger.debug(f"Trusted User '{request.user.Username}' has "
logger.debug(
f"Trusted User '{request.user.Username}' has "
f"modified '{user.Username}' account's type to"
f" {name}.")
f" {name}."
)
def invalid_captcha(captcha_salt: str = None, captcha: str = None,
**kwargs) -> None:
def invalid_captcha(captcha_salt: str = None, captcha: str = None, **kwargs) -> None:
if captcha_salt and captcha_salt not in get_captcha_salts():
raise ValidationError(["This CAPTCHA has expired. Please try again."])

View file

@ -2,7 +2,6 @@ import math
import re
import secrets
import string
from datetime import datetime
from http import HTTPStatus
from subprocess import PIPE, Popen
@ -11,12 +10,10 @@ from urllib.parse import urlparse
import fastapi
import pygit2
from email_validator import EmailSyntaxError, validate_email
from fastapi.responses import JSONResponse
import aurweb.config
from aurweb import defaults, logging
logger = logging.get_logger(__name__)
@ -24,7 +21,7 @@ logger = logging.get_logger(__name__)
def make_random_string(length: int) -> str:
alphanumerics = string.ascii_lowercase + string.digits
return ''.join([secrets.choice(alphanumerics) for i in range(length)])
return "".join([secrets.choice(alphanumerics) for i in range(length)])
def make_nonce(length: int = 8):
@ -45,7 +42,7 @@ def valid_username(username):
# Check that username contains: one or more alphanumeric
# characters, an optional separator of '.', '-' or '_', followed
# by alphanumeric characters.
return re.match(r'^[a-zA-Z0-9]+[.\-_]?[a-zA-Z0-9]+$', username)
return re.match(r"^[a-zA-Z0-9]+[.\-_]?[a-zA-Z0-9]+$", username)
def valid_email(email):
@ -151,8 +148,7 @@ def git_search(repo: pygit2.Repository, commit_hash: str) -> int:
return prefixlen
async def error_or_result(next: Callable, *args, **kwargs) \
-> fastapi.Response:
async def error_or_result(next: Callable, *args, **kwargs) -> fastapi.Response:
"""
Try to return a response from `next`.
@ -176,7 +172,7 @@ async def error_or_result(next: Callable, *args, **kwargs) \
def parse_ssh_key(string: str) -> Tuple[str, str]:
"""Parse an SSH public key."""
invalid_exc = ValueError("The SSH public key is invalid.")
parts = re.sub(r'\s\s+', ' ', string.strip()).split()
parts = re.sub(r"\s\s+", " ", string.strip()).split()
if len(parts) < 2:
raise invalid_exc
@ -185,8 +181,7 @@ def parse_ssh_key(string: str) -> Tuple[str, str]:
if prefix not in prefixes:
raise invalid_exc
proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE,
stderr=PIPE)
proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE, stderr=PIPE)
out, _ = proc.communicate(f"{prefix} {key}".encode())
if proc.returncode:
raise invalid_exc

View file

@ -108,4 +108,3 @@ The following list of steps describes exactly how this verification works:
- `options.disable_http_login: 1`
- `options.login_timeout: <default_provided_in_config.defaults>`
- `options.persistent_cookie_timeout: <default_provided_in_config.defaults>`

View file

@ -147,4 +147,3 @@ http {
'' close;
}
}

View file

@ -2,7 +2,6 @@ import logging
import logging.config
import sqlalchemy
from alembic import context
import aurweb.db
@ -69,9 +68,7 @@ def run_migrations_online():
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()

View file

@ -15,7 +15,6 @@ import os
import random
import sys
import time
from datetime import datetime
import bcrypt
@ -26,8 +25,8 @@ USER_ID = 5 # Users.ID of first bogus user
PKG_ID = 1 # Packages.ID of first package
# how many users to 'register'
MAX_USERS = int(os.environ.get("MAX_USERS", 38000))
MAX_DEVS = .1 # what percentage of MAX_USERS are Developers
MAX_TUS = .2 # what percentage of MAX_USERS are Trusted Users
MAX_DEVS = 0.1 # what percentage of MAX_USERS are Developers
MAX_TUS = 0.2 # what percentage of MAX_USERS are Trusted Users
# how many packages to load
MAX_PKGS = int(os.environ.get("MAX_PKGS", 32000))
PKG_DEPS = (1, 15) # min/max depends a package has
@ -35,7 +34,7 @@ PKG_RELS = (1, 5) # min/max relations a package has
PKG_SRC = (1, 3) # min/max sources a package has
PKG_CMNTS = (1, 5) # min/max number of comments a package has
CATEGORIES_COUNT = 17 # the number of categories from aur-schema
VOTING = (0, .001) # percentage range for package voting
VOTING = (0, 0.001) # percentage range for package voting
# number of open trusted user proposals
OPEN_PROPOSALS = int(os.environ.get("OPEN_PROPOSALS", 15))
# number of closed trusted user proposals
@ -116,7 +115,7 @@ def normalize(unicode_data):
"""We only accept ascii for usernames. Also use this to normalize
package names; our database utf8mb4 collations compare with Unicode
Equivalence."""
return unicode_data.encode('ascii', 'ignore').decode('ascii')
return unicode_data.encode("ascii", "ignore").decode("ascii")
# select random usernames
@ -196,10 +195,12 @@ for u in user_keys:
# "{salt}{username}"
to_hash = f"{salt}{u}"
h = hashlib.new('md5')
h = hashlib.new("md5")
h.update(to_hash.encode())
s = ("INSERT INTO Users (ID, AccountTypeID, Username, Email, Passwd, Salt)"
" VALUES (%d, %d, '%s', '%s@example.com', '%s', '%s');\n")
s = (
"INSERT INTO Users (ID, AccountTypeID, Username, Email, Passwd, Salt)"
" VALUES (%d, %d, '%s', '%s@example.com', '%s', '%s');\n"
)
s = s % (seen_users[u], account_type, u, u, h.hexdigest(), salt)
out.write(s)
@ -230,13 +231,17 @@ for p in list(seen_pkgs.keys()):
uuid = genUID() # the submitter/user
s = ("INSERT INTO PackageBases (ID, Name, FlaggerComment, SubmittedTS, ModifiedTS, "
"SubmitterUID, MaintainerUID, PackagerUID) VALUES (%d, '%s', '', %d, %d, %d, %s, %s);\n")
s = (
"INSERT INTO PackageBases (ID, Name, FlaggerComment, SubmittedTS, ModifiedTS, "
"SubmitterUID, MaintainerUID, PackagerUID) VALUES (%d, '%s', '', %d, %d, %d, %s, %s);\n"
)
s = s % (seen_pkgs[p], p, NOW, NOW, uuid, muid, puid)
out.write(s)
s = ("INSERT INTO Packages (ID, PackageBaseID, Name, Version) VALUES "
"(%d, %d, '%s', '%s');\n")
s = (
"INSERT INTO Packages (ID, PackageBaseID, Name, Version) VALUES "
"(%d, %d, '%s', '%s');\n"
)
s = s % (seen_pkgs[p], seen_pkgs[p], p, genVersion())
out.write(s)
@ -247,8 +252,10 @@ for p in list(seen_pkgs.keys()):
num_comments = random.randrange(PKG_CMNTS[0], PKG_CMNTS[1])
for i in range(0, num_comments):
now = NOW + random.randrange(400, 86400 * 3)
s = ("INSERT INTO PackageComments (PackageBaseID, UsersID,"
" Comments, RenderedComment, CommentTS) VALUES (%d, %d, '%s', '', %d);\n")
s = (
"INSERT INTO PackageComments (PackageBaseID, UsersID,"
" Comments, RenderedComment, CommentTS) VALUES (%d, %d, '%s', '', %d);\n"
)
s = s % (seen_pkgs[p], genUID(), genFortune(), now)
out.write(s)
@ -258,14 +265,17 @@ utcnow = int(datetime.utcnow().timestamp())
track_votes = {}
log.debug("Casting votes for packages.")
for u in user_keys:
num_votes = random.randrange(int(len(seen_pkgs) * VOTING[0]),
int(len(seen_pkgs) * VOTING[1]))
num_votes = random.randrange(
int(len(seen_pkgs) * VOTING[0]), int(len(seen_pkgs) * VOTING[1])
)
pkgvote = {}
for v in range(num_votes):
pkg = random.randrange(1, len(seen_pkgs) + 1)
if pkg not in pkgvote:
s = ("INSERT INTO PackageVotes (UsersID, PackageBaseID, VoteTS)"
" VALUES (%d, %d, %d);\n")
s = (
"INSERT INTO PackageVotes (UsersID, PackageBaseID, VoteTS)"
" VALUES (%d, %d, %d);\n"
)
s = s % (seen_users[u], pkg, utcnow)
pkgvote[pkg] = 1
if pkg not in track_votes:
@ -310,9 +320,12 @@ for p in seen_pkgs_keys:
src_file = user_keys[random.randrange(0, len(user_keys))]
src = "%s%s.%s/%s/%s-%s.tar.gz" % (
RANDOM_URL[random.randrange(0, len(RANDOM_URL))],
p, RANDOM_TLDS[random.randrange(0, len(RANDOM_TLDS))],
p,
RANDOM_TLDS[random.randrange(0, len(RANDOM_TLDS))],
RANDOM_LOCS[random.randrange(0, len(RANDOM_LOCS))],
src_file, genVersion())
src_file,
genVersion(),
)
s = "INSERT INTO PackageSources(PackageID, Source) VALUES (%d, '%s');\n"
s = s % (seen_pkgs[p], src)
out.write(s)
@ -334,8 +347,10 @@ for t in range(0, OPEN_PROPOSALS + CLOSE_PROPOSALS):
else:
user = user_keys[random.randrange(0, len(user_keys))]
suid = trustedusers[random.randrange(0, len(trustedusers))]
s = ("INSERT INTO TU_VoteInfo (Agenda, User, Submitted, End,"
" Quorum, SubmitterID) VALUES ('%s', '%s', %d, %d, 0.0, %d);\n")
s = (
"INSERT INTO TU_VoteInfo (Agenda, User, Submitted, End,"
" Quorum, SubmitterID) VALUES ('%s', '%s', %d, %d, 0.0, %d);\n"
)
s = s % (genFortune(), user, start, end, suid)
out.write(s)
count += 1

View file

@ -65,4 +65,3 @@
</form>
</div>
{% endblock %}

View file

@ -79,4 +79,3 @@
</td>
</tr>
</table>

View file

@ -39,12 +39,10 @@ ahead of each function takes too long when compared to this method.
"""
import os
import pathlib
from multiprocessing import Lock
import py
import pytest
from posix_ipc import O_CREAT, Semaphore
from sqlalchemy import create_engine
from sqlalchemy.engine import URL
@ -54,7 +52,6 @@ from sqlalchemy.orm import scoped_session
import aurweb.config
import aurweb.db
from aurweb import initdb, logging, testing
from aurweb.testing.email import Email
from aurweb.testing.filelock import FileLock
@ -78,13 +75,10 @@ def test_engine() -> Engine:
unix_socket = aurweb.config.get_with_fallback("database", "socket", None)
kwargs = {
"username": aurweb.config.get("database", "user"),
"password": aurweb.config.get_with_fallback(
"database", "password", None),
"password": aurweb.config.get_with_fallback("database", "password", None),
"host": aurweb.config.get("database", "host"),
"port": aurweb.config.get_with_fallback("database", "port", None),
"query": {
"unix_socket": unix_socket
}
"query": {"unix_socket": unix_socket},
}
backend = aurweb.config.get("database", "backend")
@ -99,6 +93,7 @@ class AlembicArgs:
This structure is needed to pass conftest-specific arguments
to initdb.run duration database creation.
"""
verbose = False
use_alembic = True

View file

@ -1,5 +1,4 @@
import pytest
from sqlalchemy.exc import IntegrityError
from aurweb import db
@ -17,17 +16,21 @@ def setup(db_test):
@pytest.fixture
def user() -> User:
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountTypeID=USER_ID)
user = db.create(
User,
Username="test",
Email="test@example.org",
RealName="Test User",
Passwd="testPassword",
AccountTypeID=USER_ID,
)
yield user
@pytest.fixture
def term() -> Term:
with db.begin():
term = db.create(Term, Description="Test term",
URL="https://test.term")
term = db.create(Term, Description="Test term", URL="https://test.term")
yield term

View file

@ -28,20 +28,24 @@ def test_account_type(account_type):
# Next, test our string functions.
assert str(account_type) == "TestUser"
assert repr(account_type) == \
"<AccountType(ID='%s', AccountType='TestUser')>" % (
account_type.ID)
assert repr(account_type) == "<AccountType(ID='%s', AccountType='TestUser')>" % (
account_type.ID
)
record = db.query(AccountType,
AccountType.AccountType == "TestUser").first()
record = db.query(AccountType, AccountType.AccountType == "TestUser").first()
assert account_type == record
def test_user_account_type_relationship(account_type):
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
user = db.create(
User,
Username="test",
Email="test@example.org",
RealName="Test User",
Passwd="testPassword",
AccountType=account_type,
)
assert user.AccountType == account_type

View file

@ -1,6 +1,5 @@
import re
import tempfile
from datetime import datetime
from http import HTTPStatus
from logging import DEBUG
@ -8,17 +7,21 @@ from subprocess import Popen
import lxml.html
import pytest
from fastapi.testclient import TestClient
import aurweb.models.account_type as at
from aurweb import captcha, db, logging, time
from aurweb.asgi import app
from aurweb.db import create, query
from aurweb.models.accepted_term import AcceptedTerm
from aurweb.models.account_type import (DEVELOPER_ID, TRUSTED_USER, TRUSTED_USER_AND_DEV_ID, TRUSTED_USER_ID, USER_ID,
AccountType)
from aurweb.models.account_type import (
DEVELOPER_ID,
TRUSTED_USER,
TRUSTED_USER_AND_DEV_ID,
TRUSTED_USER_ID,
USER_ID,
AccountType,
)
from aurweb.models.ban import Ban
from aurweb.models.session import Session
from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint
@ -39,8 +42,11 @@ def make_ssh_pubkey():
# dependency to passing this test).
with tempfile.TemporaryDirectory() as tmpdir:
with open("/dev/null", "w") as null:
proc = Popen(["ssh-keygen", "-f", f"{tmpdir}/test.ssh", "-N", ""],
stdout=null, stderr=null)
proc = Popen(
["ssh-keygen", "-f", f"{tmpdir}/test.ssh", "-N", ""],
stdout=null,
stderr=null,
)
proc.wait()
assert proc.returncode == 0
@ -60,9 +66,13 @@ def client() -> TestClient:
def create_user(username: str) -> User:
email = f"{username}@example.org"
user = create(User, Username=username, Email=email,
user = create(
User,
Username=username,
Email=email,
Passwd="testPassword",
AccountTypeID=USER_ID)
AccountTypeID=USER_ID,
)
return user
@ -85,8 +95,9 @@ def test_get_passreset_authed_redirects(client: TestClient, user: User):
assert sid is not None
with client as request:
response = request.get("/passreset", cookies={"AURSID": sid},
allow_redirects=False)
response = request.get(
"/passreset", cookies={"AURSID": sid}, allow_redirects=False
)
assert response.status_code == int(HTTPStatus.SEE_OTHER)
assert response.headers.get("location") == "/"
@ -129,10 +140,12 @@ def test_post_passreset_authed_redirects(client: TestClient, user: User):
assert sid is not None
with client as request:
response = request.post("/passreset",
response = request.post(
"/passreset",
cookies={"AURSID": sid},
data={"user": "blah"},
allow_redirects=False)
allow_redirects=False,
)
assert response.status_code == int(HTTPStatus.SEE_OTHER)
assert response.headers.get("location") == "/"
@ -166,8 +179,9 @@ def test_post_passreset_user_suspended(client: TestClient, user: User):
def test_post_passreset_resetkey(client: TestClient, user: User):
with db.begin():
user.session = Session(UsersID=user.ID, SessionID="blah",
LastUpdateTS=time.utcnow())
user.session = Session(
UsersID=user.ID, SessionID="blah", LastUpdateTS=time.utcnow()
)
# Prepare a password reset.
with client as request:
@ -182,7 +196,7 @@ def test_post_passreset_resetkey(client: TestClient, user: User):
"user": TEST_USERNAME,
"resetkey": resetkey,
"password": "abcd1234",
"confirm": "abcd1234"
"confirm": "abcd1234",
}
with client as request:
@ -200,10 +214,7 @@ def make_resetkey(client: TestClient, user: User):
def make_passreset_data(user: User, resetkey: str):
return {
"user": user.Username,
"resetkey": resetkey
}
return {"user": user.Username, "resetkey": resetkey}
def test_post_passreset_error_invalid_email(client: TestClient, user: User):
@ -240,8 +251,7 @@ def test_post_passreset_error_missing_field(client: TestClient, user: User):
assert error in response.content.decode("utf-8")
def test_post_passreset_error_password_mismatch(client: TestClient,
user: User):
def test_post_passreset_error_password_mismatch(client: TestClient, user: User):
resetkey = make_resetkey(client, user)
post_data = make_passreset_data(user, resetkey)
@ -257,8 +267,7 @@ def test_post_passreset_error_password_mismatch(client: TestClient,
assert error in response.content.decode("utf-8")
def test_post_passreset_error_password_requirements(client: TestClient,
user: User):
def test_post_passreset_error_password_requirements(client: TestClient, user: User):
resetkey = make_resetkey(client, user)
post_data = make_passreset_data(user, resetkey)
@ -297,7 +306,7 @@ def post_register(request, **kwargs):
"L": "en",
"TZ": "UTC",
"captcha": answer,
"captcha_salt": salt
"captcha_salt": salt,
}
# For any kwargs given, override their k:v pairs in data.
@ -380,9 +389,11 @@ def test_post_register_error_ip_banned(client: TestClient):
assert response.status_code == int(HTTPStatus.BAD_REQUEST)
content = response.content.decode()
assert ("Account registration has been disabled for your IP address, " +
"probably due to sustained spam attacks. Sorry for the " +
"inconvenience.") in content
assert (
"Account registration has been disabled for your IP address, "
+ "probably due to sustained spam attacks. Sorry for the "
+ "inconvenience."
) in content
def test_post_register_error_missing_username(client: TestClient):
@ -489,7 +500,7 @@ def test_post_register_error_invalid_pgp_fingerprints(client: TestClient):
expected = "The PGP key fingerprint is invalid."
assert expected in content
pk = 'z' + ('a' * 39)
pk = "z" + ("a" * 39)
with client as request:
response = post_register(request, K=pk)
@ -569,8 +580,11 @@ def test_post_register_error_ssh_pubkey_taken(client: TestClient, user: User):
# dependency to passing this test).
with tempfile.TemporaryDirectory() as tmpdir:
with open("/dev/null", "w") as null:
proc = Popen(["ssh-keygen", "-f", f"{tmpdir}/test.ssh", "-N", ""],
stdout=null, stderr=null)
proc = Popen(
["ssh-keygen", "-f", f"{tmpdir}/test.ssh", "-N", ""],
stdout=null,
stderr=null,
)
proc.wait()
assert proc.returncode == 0
@ -602,8 +616,11 @@ def test_post_register_with_ssh_pubkey(client: TestClient):
# dependency to passing this test).
with tempfile.TemporaryDirectory() as tmpdir:
with open("/dev/null", "w") as null:
proc = Popen(["ssh-keygen", "-f", f"{tmpdir}/test.ssh", "-N", ""],
stdout=null, stderr=null)
proc = Popen(
["ssh-keygen", "-f", f"{tmpdir}/test.ssh", "-N", ""],
stdout=null,
stderr=null,
)
proc.wait()
assert proc.returncode == 0
@ -700,14 +717,18 @@ def test_get_account_edit_unauthorized(client: TestClient, user: User):
sid = user.login(request, "testPassword")
with db.begin():
user2 = create(User, Username="test2", Email="test2@example.org",
Passwd="testPassword", AccountTypeID=USER_ID)
user2 = create(
User,
Username="test2",
Email="test2@example.org",
Passwd="testPassword",
AccountTypeID=USER_ID,
)
endpoint = f"/account/{user2.Username}/edit"
with client as request:
# Try to edit `test2` while authenticated as `test`.
response = request.get(endpoint, cookies={"AURSID": sid},
allow_redirects=False)
response = request.get(endpoint, cookies={"AURSID": sid}, allow_redirects=False)
assert response.status_code == int(HTTPStatus.SEE_OTHER)
expected = f"/account/{user2.Username}"
@ -718,16 +739,15 @@ def test_post_account_edit(client: TestClient, user: User):
request = Request()
sid = user.login(request, "testPassword")
post_data = {
"U": "test",
"E": "test666@example.org",
"passwd": "testPassword"
}
post_data = {"U": "test", "E": "test666@example.org", "passwd": "testPassword"}
with client as request:
response = request.post("/account/test/edit", cookies={
"AURSID": sid
}, data=post_data, allow_redirects=False)
response = request.post(
"/account/test/edit",
cookies={"AURSID": sid},
data=post_data,
allow_redirects=False,
)
assert response.status_code == int(HTTPStatus.OK)
@ -772,8 +792,7 @@ def test_post_account_edit_type_as_dev(client: TestClient, tu_user: User):
assert user2.AccountTypeID == at.DEVELOPER_ID
def test_post_account_edit_invalid_type_as_tu(client: TestClient,
tu_user: User):
def test_post_account_edit_invalid_type_as_tu(client: TestClient, tu_user: User):
with db.begin():
user2 = create_user("test_tu")
tu_user.AccountTypeID = at.TRUSTED_USER_ID
@ -792,8 +811,10 @@ def test_post_account_edit_invalid_type_as_tu(client: TestClient,
assert user2.AccountTypeID == at.USER_ID
errors = get_errors(resp.text)
expected = ("You do not have permission to change this user's "
f"account type to {at.DEVELOPER}.")
expected = (
"You do not have permission to change this user's "
f"account type to {at.DEVELOPER}."
)
assert errors[0].text.strip() == expected
@ -807,16 +828,13 @@ def test_post_account_edit_dev(client: TestClient, tu_user: User):
request = Request()
sid = tu_user.login(request, "testPassword")
post_data = {
"U": "test",
"E": "test666@example.org",
"passwd": "testPassword"
}
post_data = {"U": "test", "E": "test666@example.org", "passwd": "testPassword"}
endpoint = f"/account/{tu_user.Username}/edit"
with client as request:
response = request.post(endpoint, cookies={"AURSID": sid},
data=post_data, allow_redirects=False)
response = request.post(
endpoint, cookies={"AURSID": sid}, data=post_data, allow_redirects=False
)
assert response.status_code == int(HTTPStatus.OK)
expected = "The account, <strong>test</strong>, "
@ -832,13 +850,16 @@ def test_post_account_edit_language(client: TestClient, user: User):
"U": "test",
"E": "test@example.org",
"L": "de", # German
"passwd": "testPassword"
"passwd": "testPassword",
}
with client as request:
response = request.post("/account/test/edit", cookies={
"AURSID": sid
}, data=post_data, allow_redirects=False)
response = request.post(
"/account/test/edit",
cookies={"AURSID": sid},
data=post_data,
allow_redirects=False,
)
assert response.status_code == int(HTTPStatus.OK)
@ -859,33 +880,33 @@ def test_post_account_edit_timezone(client: TestClient, user: User):
"U": "test",
"E": "test@example.org",
"TZ": "CET",
"passwd": "testPassword"
"passwd": "testPassword",
}
with client as request:
response = request.post("/account/test/edit", cookies={
"AURSID": sid
}, data=post_data, allow_redirects=False)
response = request.post(
"/account/test/edit",
cookies={"AURSID": sid},
data=post_data,
allow_redirects=False,
)
assert response.status_code == int(HTTPStatus.OK)
def test_post_account_edit_error_missing_password(client: TestClient,
user: User):
def test_post_account_edit_error_missing_password(client: TestClient, user: User):
request = Request()
sid = user.login(request, "testPassword")
post_data = {
"U": "test",
"E": "test@example.org",
"TZ": "CET",
"passwd": ""
}
post_data = {"U": "test", "E": "test@example.org", "TZ": "CET", "passwd": ""}
with client as request:
response = request.post("/account/test/edit", cookies={
"AURSID": sid
}, data=post_data, allow_redirects=False)
response = request.post(
"/account/test/edit",
cookies={"AURSID": sid},
data=post_data,
allow_redirects=False,
)
assert response.status_code == int(HTTPStatus.BAD_REQUEST)
@ -893,22 +914,19 @@ def test_post_account_edit_error_missing_password(client: TestClient,
assert "Invalid password." in content
def test_post_account_edit_error_invalid_password(client: TestClient,
user: User):
def test_post_account_edit_error_invalid_password(client: TestClient, user: User):
request = Request()
sid = user.login(request, "testPassword")
post_data = {
"U": "test",
"E": "test@example.org",
"TZ": "CET",
"passwd": "invalid"
}
post_data = {"U": "test", "E": "test@example.org", "TZ": "CET", "passwd": "invalid"}
with client as request:
response = request.post("/account/test/edit", cookies={
"AURSID": sid
}, data=post_data, allow_redirects=False)
response = request.post(
"/account/test/edit",
cookies={"AURSID": sid},
data=post_data,
allow_redirects=False,
)
assert response.status_code == int(HTTPStatus.BAD_REQUEST)
@ -916,18 +934,18 @@ def test_post_account_edit_error_invalid_password(client: TestClient,
assert "Invalid password." in content
def test_post_account_edit_suspend_unauthorized(client: TestClient,
user: User):
def test_post_account_edit_suspend_unauthorized(client: TestClient, user: User):
cookies = {"AURSID": user.login(Request(), "testPassword")}
post_data = {
"U": "test",
"E": "test@example.org",
"S": True,
"passwd": "testPassword"
"passwd": "testPassword",
}
with client as request:
resp = request.post(f"/account/{user.Username}/edit", data=post_data,
cookies=cookies)
resp = request.post(
f"/account/{user.Username}/edit", data=post_data, cookies=cookies
)
assert resp.status_code == int(HTTPStatus.BAD_REQUEST)
errors = get_errors(resp.text)
@ -945,11 +963,12 @@ def test_post_account_edit_inactivity(client: TestClient, user: User):
"U": "test",
"E": "test@example.org",
"J": True,
"passwd": "testPassword"
"passwd": "testPassword",
}
with client as request:
resp = request.post(f"/account/{user.Username}/edit", data=post_data,
cookies=cookies)
resp = request.post(
f"/account/{user.Username}/edit", data=post_data, cookies=cookies
)
assert resp.status_code == int(HTTPStatus.OK)
# Make sure the user record got updated correctly.
@ -957,8 +976,9 @@ def test_post_account_edit_inactivity(client: TestClient, user: User):
post_data.update({"J": False})
with client as request:
resp = request.post(f"/account/{user.Username}/edit", data=post_data,
cookies=cookies)
resp = request.post(
f"/account/{user.Username}/edit", data=post_data, cookies=cookies
)
assert resp.status_code == int(HTTPStatus.OK)
assert user.InactivityTS == 0
@ -974,7 +994,7 @@ def test_post_account_edit_suspended(client: TestClient, user: User):
"U": "test",
"E": "test@example.org",
"S": True,
"passwd": "testPassword"
"passwd": "testPassword",
}
endpoint = f"/account/{user.Username}/edit"
with client as request:
@ -997,21 +1017,27 @@ def test_post_account_edit_error_unauthorized(client: TestClient, user: User):
sid = user.login(request, "testPassword")
with db.begin():
user2 = create(User, Username="test2", Email="test2@example.org",
Passwd="testPassword", AccountTypeID=USER_ID)
user2 = create(
User,
Username="test2",
Email="test2@example.org",
Passwd="testPassword",
AccountTypeID=USER_ID,
)
post_data = {
"U": "test",
"E": "test@example.org",
"TZ": "CET",
"passwd": "testPassword"
"passwd": "testPassword",
}
endpoint = f"/account/{user2.Username}/edit"
with client as request:
# Attempt to edit 'test2' while logged in as 'test'.
response = request.post(endpoint, cookies={"AURSID": sid},
data=post_data, allow_redirects=False)
response = request.post(
endpoint, cookies={"AURSID": sid}, data=post_data, allow_redirects=False
)
assert response.status_code == int(HTTPStatus.SEE_OTHER)
expected = f"/account/{user2.Username}"
@ -1026,13 +1052,16 @@ def test_post_account_edit_ssh_pub_key(client: TestClient, user: User):
"U": "test",
"E": "test@example.org",
"PK": make_ssh_pubkey(),
"passwd": "testPassword"
"passwd": "testPassword",
}
with client as request:
response = request.post("/account/test/edit", cookies={
"AURSID": sid
}, data=post_data, allow_redirects=False)
response = request.post(
"/account/test/edit",
cookies={"AURSID": sid},
data=post_data,
allow_redirects=False,
)
assert response.status_code == int(HTTPStatus.OK)
@ -1040,9 +1069,12 @@ def test_post_account_edit_ssh_pub_key(client: TestClient, user: User):
post_data["PK"] = make_ssh_pubkey()
with client as request:
response = request.post("/account/test/edit", cookies={
"AURSID": sid
}, data=post_data, allow_redirects=False)
response = request.post(
"/account/test/edit",
cookies={"AURSID": sid},
data=post_data,
allow_redirects=False,
)
assert response.status_code == int(HTTPStatus.OK)
@ -1055,13 +1087,16 @@ def test_post_account_edit_missing_ssh_pubkey(client: TestClient, user: User):
"U": user.Username,
"E": user.Email,
"PK": make_ssh_pubkey(),
"passwd": "testPassword"
"passwd": "testPassword",
}
with client as request:
response = request.post("/account/test/edit", cookies={
"AURSID": sid
}, data=post_data, allow_redirects=False)
response = request.post(
"/account/test/edit",
cookies={"AURSID": sid},
data=post_data,
allow_redirects=False,
)
assert response.status_code == int(HTTPStatus.OK)
@ -1069,13 +1104,16 @@ def test_post_account_edit_missing_ssh_pubkey(client: TestClient, user: User):
"U": user.Username,
"E": user.Email,
"PK": str(), # Pass an empty string now to walk the delete path.
"passwd": "testPassword"
"passwd": "testPassword",
}
with client as request:
response = request.post("/account/test/edit", cookies={
"AURSID": sid
}, data=post_data, allow_redirects=False)
response = request.post(
"/account/test/edit",
cookies={"AURSID": sid},
data=post_data,
allow_redirects=False,
)
assert response.status_code == int(HTTPStatus.OK)
@ -1087,12 +1125,13 @@ def test_post_account_edit_invalid_ssh_pubkey(client: TestClient, user: User):
"U": "test",
"E": "test@example.org",
"PK": pubkey,
"passwd": "testPassword"
"passwd": "testPassword",
}
cookies = {"AURSID": user.login(Request(), "testPassword")}
with client as request:
response = request.post("/account/test/edit", data=data,
cookies=cookies, allow_redirects=False)
response = request.post(
"/account/test/edit", data=data, cookies=cookies, allow_redirects=False
)
assert response.status_code == int(HTTPStatus.BAD_REQUEST)
@ -1106,13 +1145,16 @@ def test_post_account_edit_password(client: TestClient, user: User):
"E": "test@example.org",
"P": "newPassword",
"C": "newPassword",
"passwd": "testPassword"
"passwd": "testPassword",
}
with client as request:
response = request.post("/account/test/edit", cookies={
"AURSID": sid
}, data=post_data, allow_redirects=False)
response = request.post(
"/account/test/edit",
cookies={"AURSID": sid},
data=post_data,
allow_redirects=False,
)
assert response.status_code == int(HTTPStatus.OK)
@ -1132,7 +1174,7 @@ def test_post_account_edit_self_type_as_user(client: TestClient, user: User):
"U": user.Username,
"E": user.Email,
"T": TRUSTED_USER_ID,
"passwd": "testPassword"
"passwd": "testPassword",
}
with client as request:
resp = request.post(endpoint, data=data, cookies=cookies)
@ -1151,8 +1193,7 @@ def test_post_account_edit_other_user_as_user(client: TestClient, user: User):
endpoint = f"/account/{user2.Username}/edit"
with client as request:
resp = request.get(endpoint, cookies=cookies,
allow_redirects=False)
resp = request.get(endpoint, cookies=cookies, allow_redirects=False)
assert resp.status_code == int(HTTPStatus.SEE_OTHER)
assert resp.headers.get("location") == f"/account/{user2.Username}"
@ -1172,7 +1213,7 @@ def test_post_account_edit_self_type_as_tu(client: TestClient, tu_user: User):
"U": tu_user.Username,
"E": tu_user.Email,
"T": USER_ID,
"passwd": "testPassword"
"passwd": "testPassword",
}
with client as request:
resp = request.post(endpoint, data=data, cookies=cookies)
@ -1182,7 +1223,8 @@ def test_post_account_edit_self_type_as_tu(client: TestClient, tu_user: User):
def test_post_account_edit_other_user_type_as_tu(
client: TestClient, tu_user: User, caplog: pytest.LogCaptureFixture):
client: TestClient, tu_user: User, caplog: pytest.LogCaptureFixture
):
caplog.set_level(DEBUG)
with db.begin():
@ -1202,7 +1244,7 @@ def test_post_account_edit_other_user_type_as_tu(
"U": user2.Username,
"E": user2.Email,
"T": TRUSTED_USER_ID,
"passwd": "testPassword"
"passwd": "testPassword",
}
with client as request:
resp = request.post(endpoint, data=data, cookies=cookies)
@ -1212,14 +1254,17 @@ def test_post_account_edit_other_user_type_as_tu(
assert user2.AccountTypeID == TRUSTED_USER_ID
# and also that this got logged out at DEBUG level.
expected = (f"Trusted User '{tu_user.Username}' has "
expected = (
f"Trusted User '{tu_user.Username}' has "
f"modified '{user2.Username}' account's type to"
f" {TRUSTED_USER}.")
f" {TRUSTED_USER}."
)
assert expected in caplog.text
def test_post_account_edit_other_user_type_as_tu_invalid_type(
client: TestClient, tu_user: User, caplog: pytest.LogCaptureFixture):
client: TestClient, tu_user: User, caplog: pytest.LogCaptureFixture
):
with db.begin():
user2 = create_user("test2")
@ -1227,12 +1272,7 @@ def test_post_account_edit_other_user_type_as_tu_invalid_type(
endpoint = f"/account/{user2.Username}/edit"
# As a TU, we can modify other user's account types.
data = {
"U": user2.Username,
"E": user2.Email,
"T": 0,
"passwd": "testPassword"
}
data = {"U": user2.Username, "E": user2.Email, "T": 0, "passwd": "testPassword"}
with client as request:
resp = request.post(endpoint, data=data, cookies=cookies)
assert resp.status_code == int(HTTPStatus.BAD_REQUEST)
@ -1247,8 +1287,9 @@ def test_get_account(client: TestClient, user: User):
sid = user.login(request, "testPassword")
with client as request:
response = request.get("/account/test", cookies={"AURSID": sid},
allow_redirects=False)
response = request.get(
"/account/test", cookies={"AURSID": sid}, allow_redirects=False
)
assert response.status_code == int(HTTPStatus.OK)
@ -1258,8 +1299,9 @@ def test_get_account_not_found(client: TestClient, user: User):
sid = user.login(request, "testPassword")
with client as request:
response = request.get("/account/not_found", cookies={"AURSID": sid},
allow_redirects=False)
response = request.get(
"/account/not_found", cookies={"AURSID": sid}, allow_redirects=False
)
assert response.status_code == int(HTTPStatus.NOT_FOUND)
@ -1360,8 +1402,7 @@ def test_post_accounts(client: TestClient, user: User, tu_user: User):
columns = rows[i].xpath("./td")
assert len(columns) == 7
username, atype, suspended, real_name, \
irc_nick, pgp_key, edit = columns
username, atype, suspended, real_name, irc_nick, pgp_key, edit = columns
username = next(iter(username.xpath("./a")))
assert username.text.strip() == _user.Username
@ -1379,8 +1420,10 @@ def test_post_accounts(client: TestClient, user: User, tu_user: User):
else:
assert not edit
logger.debug('Checked user row {"id": %s, "username": "%s"}.'
% (_user.ID, _user.Username))
logger.debug(
'Checked user row {"id": %s, "username": "%s"}.'
% (_user.ID, _user.Username)
)
def test_post_accounts_username(client: TestClient, user: User, tu_user: User):
@ -1389,8 +1432,7 @@ def test_post_accounts_username(client: TestClient, user: User, tu_user: User):
cookies = {"AURSID": sid}
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"U": user.Username})
response = request.post("/accounts", cookies=cookies, data={"U": user.Username})
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
@ -1403,34 +1445,33 @@ def test_post_accounts_username(client: TestClient, user: User, tu_user: User):
assert username.text.strip() == user.Username
def test_post_accounts_account_type(client: TestClient, user: User,
tu_user: User):
def test_post_accounts_account_type(client: TestClient, user: User, tu_user: User):
# Check the different account type options.
sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid}
# Make a user with the "User" role here so we can
# test the `u` parameter.
account_type = query(AccountType,
AccountType.AccountType == "User").first()
account_type = query(AccountType, AccountType.AccountType == "User").first()
with db.begin():
create(User, Username="test_2",
create(
User,
Username="test_2",
Email="test_2@example.org",
RealName="Test User 2",
Passwd="testPassword",
AccountType=account_type)
AccountType=account_type,
)
# Expect no entries; we marked our only user as a User type.
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"T": "t"})
response = request.post("/accounts", cookies=cookies, data={"T": "t"})
assert response.status_code == int(HTTPStatus.OK)
assert len(get_rows(response.text)) == 0
# So, let's also ensure that specifying "u" returns our user.
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"T": "u"})
response = request.post("/accounts", cookies=cookies, data={"T": "u"})
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
@ -1443,13 +1484,12 @@ def test_post_accounts_account_type(client: TestClient, user: User,
# Set our only user to a Trusted User.
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_ID
).first()
user.AccountType = (
query(AccountType).filter(AccountType.ID == TRUSTED_USER_ID).first()
)
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"T": "t"})
response = request.post("/accounts", cookies=cookies, data={"T": "t"})
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
@ -1461,13 +1501,12 @@ def test_post_accounts_account_type(client: TestClient, user: User,
assert type.text.strip() == "Trusted User"
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == DEVELOPER_ID
).first()
user.AccountType = (
query(AccountType).filter(AccountType.ID == DEVELOPER_ID).first()
)
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"T": "d"})
response = request.post("/accounts", cookies=cookies, data={"T": "d"})
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
@ -1479,13 +1518,12 @@ def test_post_accounts_account_type(client: TestClient, user: User,
assert type.text.strip() == "Developer"
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
user.AccountType = (
query(AccountType).filter(AccountType.ID == TRUSTED_USER_AND_DEV_ID).first()
)
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"T": "td"})
response = request.post("/accounts", cookies=cookies, data={"T": "td"})
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
@ -1517,8 +1555,7 @@ def test_post_accounts_status(client: TestClient, user: User, tu_user: User):
user.Suspended = True
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"S": True})
response = request.post("/accounts", cookies=cookies, data={"S": True})
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
@ -1535,8 +1572,7 @@ def test_post_accounts_email(client: TestClient, user: User, tu_user: User):
# Search via email.
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"E": user.Email})
response = request.post("/accounts", cookies=cookies, data={"E": user.Email})
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
@ -1549,8 +1585,7 @@ def test_post_accounts_realname(client: TestClient, user: User, tu_user: User):
cookies = {"AURSID": sid}
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"R": user.RealName})
response = request.post("/accounts", cookies=cookies, data={"R": user.RealName})
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
@ -1563,8 +1598,7 @@ def test_post_accounts_irc(client: TestClient, user: User, tu_user: User):
cookies = {"AURSID": sid}
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"I": user.IRCNick})
response = request.post("/accounts", cookies=cookies, data={"I": user.IRCNick})
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
@ -1589,22 +1623,19 @@ def test_post_accounts_sortby(client: TestClient, user: User, tu_user: User):
first_rows = rows
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"SB": "u"})
response = request.post("/accounts", cookies=cookies, data={"SB": "u"})
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
assert len(rows) == 2
def compare_text_values(column, lhs, rhs):
return [row[column].text for row in lhs] \
== [row[column].text for row in rhs]
return [row[column].text for row in lhs] == [row[column].text for row in rhs]
# Test the username rows are ordered the same.
assert compare_text_values(0, first_rows, rows) is True
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"SB": "i"})
response = request.post("/accounts", cookies=cookies, data={"SB": "i"})
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
assert len(rows) == 2
@ -1614,8 +1645,7 @@ def test_post_accounts_sortby(client: TestClient, user: User, tu_user: User):
# Sort by "i" -> RealName.
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"SB": "r"})
response = request.post("/accounts", cookies=cookies, data={"SB": "r"})
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
assert len(rows) == 2
@ -1624,9 +1654,9 @@ def test_post_accounts_sortby(client: TestClient, user: User, tu_user: User):
assert compare_text_values(4, first_rows, reversed(rows)) is True
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
user.AccountType = (
query(AccountType).filter(AccountType.ID == TRUSTED_USER_AND_DEV_ID).first()
)
# Fetch first_rows again with our new AccountType ordering.
with client as request:
@ -1638,8 +1668,7 @@ def test_post_accounts_sortby(client: TestClient, user: User, tu_user: User):
# Sort by "t" -> AccountType.
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"SB": "t"})
response = request.post("/accounts", cookies=cookies, data={"SB": "t"})
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
assert len(rows) == 2
@ -1657,8 +1686,7 @@ def test_post_accounts_pgp_key(client: TestClient, user: User, tu_user: User):
# Search via PGPKey.
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"K": user.PGPKey})
response = request.post("/accounts", cookies=cookies, data={"K": user.PGPKey})
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
@ -1668,15 +1696,17 @@ def test_post_accounts_pgp_key(client: TestClient, user: User, tu_user: User):
def test_post_accounts_paged(client: TestClient, user: User, tu_user: User):
# Create 150 users.
users = [user]
account_type = query(AccountType,
AccountType.AccountType == "User").first()
account_type = query(AccountType, AccountType.AccountType == "User").first()
with db.begin():
for i in range(150):
_user = create(User, Username=f"test_#{i}",
_user = create(
User,
Username=f"test_#{i}",
Email=f"test_#{i}@example.org",
RealName=f"Test User #{i}",
Passwd="testPassword",
AccountType=account_type)
AccountType=account_type,
)
users.append(_user)
sid = user.login(Request(), "testPassword")
@ -1709,8 +1739,9 @@ def test_post_accounts_paged(client: TestClient, user: User, tu_user: User):
assert "disabled" not in page_next.attrib
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"O": 50}) # +50 offset.
response = request.post(
"/accounts", cookies=cookies, data={"O": 50}
) # +50 offset.
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
@ -1724,8 +1755,9 @@ def test_post_accounts_paged(client: TestClient, user: User, tu_user: User):
assert username.text.strip() == _user.Username
with client as request:
response = request.post("/accounts", cookies=cookies,
data={"O": 101}) # Last page.
response = request.post(
"/accounts", cookies=cookies, data={"O": 101}
) # Last page.
assert response.status_code == int(HTTPStatus.OK)
rows = get_rows(response.text)
@ -1741,8 +1773,9 @@ def test_post_accounts_paged(client: TestClient, user: User, tu_user: User):
def test_get_terms_of_service(client: TestClient, user: User):
with db.begin():
term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1)
term = create(
Term, Description="Test term.", URL="http://localhost", Revision=1
)
with client as request:
response = request.get("/tos", allow_redirects=False)
@ -1764,8 +1797,9 @@ def test_get_terms_of_service(client: TestClient, user: User):
assert response.status_code == int(HTTPStatus.OK)
with db.begin():
accepted_term = create(AcceptedTerm, User=user,
Term=term, Revision=term.Revision)
accepted_term = create(
AcceptedTerm, User=user, Term=term, Revision=term.Revision
)
with client as request:
response = request.get("/tos", cookies=cookies, allow_redirects=False)
@ -1800,8 +1834,9 @@ def test_post_terms_of_service(client: TestClient, user: User):
# Create a fresh Term.
with db.begin():
term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1)
term = create(
Term, Description="Test term.", URL="http://localhost", Revision=1
)
# Test that the term we just created is listed.
with client as request:
@ -1810,8 +1845,7 @@ def test_post_terms_of_service(client: TestClient, user: User):
# Make a POST request to /tos with the agree checkbox disabled (False).
with client as request:
response = request.post("/tos", data={"accept": False},
cookies=cookies)
response = request.post("/tos", data={"accept": False}, cookies=cookies)
assert response.status_code == int(HTTPStatus.OK)
# Make a POST request to /tos with the agree checkbox enabled (True).
@ -1820,8 +1854,7 @@ def test_post_terms_of_service(client: TestClient, user: User):
assert response.status_code == int(HTTPStatus.SEE_OTHER)
# Query the db for the record created by the post request.
accepted_term = query(AcceptedTerm,
AcceptedTerm.TermsID == term.ID).first()
accepted_term = query(AcceptedTerm, AcceptedTerm.TermsID == term.ID).first()
assert accepted_term.User == user
assert accepted_term.Term == term

View file

@ -3,16 +3,17 @@ from unittest import mock
import pytest
import aurweb.models.account_type as at
from aurweb import db
from aurweb.models import User
from aurweb.scripts import adduser
from aurweb.testing.requests import Request
TEST_SSH_PUBKEY = ("ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAI"
TEST_SSH_PUBKEY = (
"ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAI"
"bmlzdHAyNTYAAABBBEURnkiY6JoLyqDE8Li1XuAW+LHmkmLDMW/GL5wY"
"7k4/A+Ta7bjA3MOKrF9j4EuUTvCuNXULxvpfSqheTFWZc+g= "
"kevr@volcano")
"kevr@volcano"
)
@pytest.fixture(autouse=True)
@ -38,18 +39,36 @@ def test_adduser():
def test_adduser_tu():
run_main([
"-u", "test", "-e", "test@example.org", "-p", "abcd1234",
"-t", at.TRUSTED_USER
])
run_main(
[
"-u",
"test",
"-e",
"test@example.org",
"-p",
"abcd1234",
"-t",
at.TRUSTED_USER,
]
)
test = db.query(User).filter(User.Username == "test").first()
assert test is not None
assert test.AccountTypeID == at.TRUSTED_USER_ID
def test_adduser_ssh_pk():
run_main(["-u", "test", "-e", "test@example.org", "-p", "abcd1234",
"--ssh-pubkey", TEST_SSH_PUBKEY])
run_main(
[
"-u",
"test",
"-e",
"test@example.org",
"-p",
"abcd1234",
"--ssh-pubkey",
TEST_SSH_PUBKEY,
]
)
test = db.query(User).filter(User.Username == "test").first()
assert test is not None
assert TEST_SSH_PUBKEY.startswith(test.ssh_pub_keys.first().PubKey)

Some files were not shown because too many files have changed in this diff Show more