Merge branch 'db-rework' into pu

This commit is contained in:
Kevin Morris 2021-11-15 00:02:56 -08:00
commit 91b570ff0d
No known key found for this signature in database
GPG key ID: F7E46DED420788F3
24 changed files with 233 additions and 166 deletions

View file

@ -13,7 +13,7 @@ from starlette.requests import HTTPConnection
import aurweb.config import aurweb.config
from aurweb import l10n, util from aurweb import db, l10n, util
from aurweb.models import Session, User from aurweb.models import Session, User
from aurweb.models.account_type import ACCOUNT_TYPE_ID from aurweb.models.account_type import ACCOUNT_TYPE_ID
from aurweb.templates import make_variable_context, render_template from aurweb.templates import make_variable_context, render_template
@ -98,14 +98,12 @@ class AnonymousUser:
class BasicAuthBackend(AuthenticationBackend): class BasicAuthBackend(AuthenticationBackend):
async def authenticate(self, conn: HTTPConnection): async def authenticate(self, conn: HTTPConnection):
from aurweb.db import session
sid = conn.cookies.get("AURSID") sid = conn.cookies.get("AURSID")
if not sid: if not sid:
return (None, AnonymousUser()) return (None, AnonymousUser())
now_ts = datetime.utcnow().timestamp() now_ts = datetime.utcnow().timestamp()
record = session.query(Session).filter( record = db.query(Session).filter(
and_(Session.SessionID == sid, and_(Session.SessionID == sid,
Session.LastUpdateTS >= now_ts)).first() Session.LastUpdateTS >= now_ts)).first()
@ -116,7 +114,7 @@ class BasicAuthBackend(AuthenticationBackend):
# At this point, we cannot have an invalid user if the record # At this point, we cannot have an invalid user if the record
# exists, due to ForeignKey constraints in the schema upheld # exists, due to ForeignKey constraints in the schema upheld
# by mysqlclient. # by mysqlclient.
user = session.query(User).filter(User.ID == record.UsersID).first() user = db.query(User).filter(User.ID == record.UsersID).first()
user.nonce = util.make_nonce() user.nonce = util.make_nonce()
user.authenticated = True user.authenticated = True

View file

@ -2,10 +2,10 @@ import functools
import math import math
import re import re
from typing import Iterable from typing import Iterable, NewType
from sqlalchemy import event from sqlalchemy import event
from sqlalchemy.orm import scoped_session from sqlalchemy.orm import Query, scoped_session
import aurweb.config import aurweb.config
import aurweb.util import aurweb.util
@ -22,6 +22,9 @@ session = None
# Global introspected object memo. # Global introspected object memo.
introspected = dict() introspected = dict()
# A mocked up type.
Base = NewType("aurweb.models.declarative_base.Base", "Base")
def make_random_value(table: str, column: str): def make_random_value(table: str, column: str):
""" Generate a unique, random value for a string column in a table. """ Generate a unique, random value for a string column in a table.
@ -58,55 +61,69 @@ def make_random_value(table: str, column: str):
return string return string
def query(model, *args, **kwargs): def get_session():
return session.query(model).filter(*args, **kwargs) """ Return aurweb.db's global session. """
return session
def create(model, *args, **kwargs): def refresh(model: Base) -> Base:
instance = model(*args, **kwargs) """ Refresh the session's knowledge of `model`. """
get_session().refresh(model)
return model
def query(Model: Base, *args, **kwargs) -> Query:
"""
Perform an ORM query against the database session.
This method also runs Query.filter on the resulting model
query with *args and **kwargs.
:param Model: Declarative ORM class
"""
return get_session().query(Model).filter(*args, **kwargs)
def create(Model: Base, *args, **kwargs) -> Base:
"""
Create a record and add() it to the database session.
:param Model: Declarative ORM class
:return: Model instance
"""
instance = Model(*args, **kwargs)
return add(instance) return add(instance)
def delete(model, *args, **kwargs): def delete(model: Base) -> None:
instance = session.query(model).filter(*args, **kwargs) """
for record in instance: Delete a set of records found by Query.filter(*args, **kwargs).
session.delete(record)
:param Model: Declarative ORM class
"""
get_session().delete(model)
def delete_all(iterable: Iterable): def delete_all(iterable: Iterable) -> None:
with begin(): """ Delete each instance found in `iterable`. """
for obj in iterable: session_ = get_session()
session.delete(obj) aurweb.util.apply_all(iterable, session_.delete)
def rollback(): def rollback() -> None:
session.rollback() """ Rollback the database session. """
get_session().rollback()
def add(model): def add(model: Base) -> Base:
session.add(model) """ Add `model` to the database session. """
get_session().add(model)
return model return model
def begin(): def begin():
""" Begin an SQLAlchemy SessionTransaction. """ Begin an SQLAlchemy SessionTransaction. """
return get_session().begin()
This context is **required** to perform an modifications to the
database.
Example:
with db.begin():
object = db.create(...)
# On __exit__, db.commit() is run.
with db.begin():
object = db.delete(...)
# On __exit__, db.commit() is run.
:return: A new SessionTransaction based on session
"""
return session.begin()
def get_sqlalchemy_url(): def get_sqlalchemy_url():

View file

@ -1,6 +1,6 @@
from fastapi import Request from fastapi import Request
from aurweb import schema from aurweb import db, schema
from aurweb.models.declarative import Base from aurweb.models.declarative import Base
@ -10,11 +10,10 @@ class Ban(Base):
__mapper_args__ = {"primary_key": [__table__.c.IPAddress]} __mapper_args__ = {"primary_key": [__table__.c.IPAddress]}
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.IPAddress = kwargs.get("IPAddress") super().__init__(**kwargs)
self.BanTS = kwargs.get("BanTS")
def is_banned(request: Request): def is_banned(request: Request):
from aurweb.db import session
ip = request.client.host ip = request.client.host
return session.query(Ban).filter(Ban.IPAddress == ip).first() is not None exists = db.query(Ban).filter(Ban.IPAddress == ip).exists()
return db.query(exists).scalar()

View file

@ -1,9 +1,10 @@
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import backref, relationship from sqlalchemy.orm import backref, relationship
from aurweb import schema from aurweb import db, schema
from aurweb.models.declarative import Base from aurweb.models.declarative import Base
from aurweb.models.dependency_type import DependencyType as _DependencyType from aurweb.models.dependency_type import DependencyType as _DependencyType
from aurweb.models.official_provider import OfficialProvider as _OfficialProvider
from aurweb.models.package import Package as _Package from aurweb.models.package import Package as _Package
@ -46,11 +47,7 @@ class PackageDependency(Base):
params=("NULL")) params=("NULL"))
def is_package(self) -> bool: def is_package(self) -> bool:
# TODO: Improve the speed of this query if possible. pkg = db.query(_Package).filter(_Package.Name == self.DepName).exists()
from aurweb import db official = db.query(_OfficialProvider).filter(
from aurweb.models.official_provider import OfficialProvider _OfficialProvider.Name == self.DepName).exists()
from aurweb.models.package import Package return db.query(pkg).scalar() or db.query(official).scalar()
pkg = db.query(Package, Package.Name == self.DepName)
official = db.query(OfficialProvider,
OfficialProvider.Name == self.DepName)
return pkg.scalar() or official.scalar()

View file

@ -146,7 +146,7 @@ class User(Base):
self.authenticated = False self.authenticated = False
if self.session: if self.session:
with db.begin(): with db.begin():
db.session.delete(self.session) db.delete(self.session)
def is_trusted_user(self): def is_trusted_user(self):
return self.AccountType.ID in { return self.AccountType.ID in {

View file

@ -110,18 +110,26 @@ def get_pkg_or_base(
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
instance = db.query(cls).filter(cls.Name == name).first() instance = db.query(cls).filter(cls.Name == name).first()
if cls == models.PackageBase and not instance: if not instance:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
return instance return db.refresh(instance)
def get_pkgbase_comment( def get_pkgbase_comment(pkgbase: models.PackageBase, id: int) \
pkgbase: models.PackageBase, id: int) -> models.PackageComment: -> models.PackageComment:
comment = pkgbase.comments.filter(models.PackageComment.ID == id).first() comment = pkgbase.comments.filter(models.PackageComment.ID == id).first()
if not comment: if not comment:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
return comment return db.refresh(comment)
def get_pkgreq_by_id(id: int):
pkgreq = db.query(models.PackageRequest).filter(
models.PackageRequest.ID == id).first()
if not pkgreq:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
return db.refresh(pkgreq)
@register_filter("out_of_date") @register_filter("out_of_date")

View file

@ -40,8 +40,10 @@ def _update_ratelimit_db(request: Request):
now = int(datetime.utcnow().timestamp()) now = int(datetime.utcnow().timestamp())
time_to_delete = now - window_length time_to_delete = now - window_length
records = db.query(ApiRateLimit).filter(
ApiRateLimit.WindowStart < time_to_delete)
with db.begin(): with db.begin():
db.delete(ApiRateLimit, ApiRateLimit.WindowStart < time_to_delete) db.delete_all(records)
host = request.client.host host = request.client.host
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first() record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()

View file

@ -4,7 +4,7 @@ import typing
from datetime import datetime from datetime import datetime
from http import HTTPStatus from http import HTTPStatus
from fastapi import APIRouter, Form, HTTPException, Request from fastapi import APIRouter, Form, Request
from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.responses import HTMLResponse, RedirectResponse
from sqlalchemy import and_, func, or_ from sqlalchemy import and_, func, or_
@ -20,6 +20,7 @@ from aurweb.models.account_type import (DEVELOPER, DEVELOPER_ID, TRUSTED_USER, T
from aurweb.models.ssh_pub_key import get_fingerprint from aurweb.models.ssh_pub_key import get_fingerprint
from aurweb.scripts.notify import ResetKeyNotification, WelcomeNotification from aurweb.scripts.notify import ResetKeyNotification, WelcomeNotification
from aurweb.templates import make_context, make_variable_context, render_template from aurweb.templates import make_context, make_variable_context, render_template
from aurweb.users.util import get_user_by_name
router = APIRouter() router = APIRouter()
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -49,6 +50,7 @@ async def passreset_post(request: Request,
return render_template(request, "passreset.html", context, return render_template(request, "passreset.html", context,
status_code=HTTPStatus.NOT_FOUND) status_code=HTTPStatus.NOT_FOUND)
db.refresh(user)
if resetkey: if resetkey:
context["resetkey"] = resetkey context["resetkey"] = resetkey
@ -83,7 +85,7 @@ async def passreset_post(request: Request,
with db.begin(): with db.begin():
user.ResetKey = str() user.ResetKey = str()
if user.session: if user.session:
db.session.delete(user.session) db.delete(user.session)
user.update_password(password) user.update_password(password)
# Render ?step=complete. # Render ?step=complete.
@ -458,15 +460,15 @@ def cannot_edit(request, user):
@router.get("/account/{username}/edit", response_class=HTMLResponse) @router.get("/account/{username}/edit", response_class=HTMLResponse)
@auth_required(True, redirect="/account/{username}") @auth_required(True, redirect="/account/{username}")
async def account_edit(request: Request, async def account_edit(request: Request, username: str):
username: str):
user = db.query(models.User, models.User.Username == username).first() user = db.query(models.User, models.User.Username == username).first()
response = cannot_edit(request, user) response = cannot_edit(request, user)
if response: if response:
return response return response
context = await make_variable_context(request, "Accounts") context = await make_variable_context(request, "Accounts")
context["user"] = user context["user"] = db.refresh(user)
context = make_account_form_context(context, request, user, dict()) context = make_account_form_context(context, request, user, dict())
return render_template(request, "account/edit.html", context) return render_template(request, "account/edit.html", context)
@ -497,16 +499,14 @@ async def account_edit_post(request: Request,
ON: bool = Form(default=False), # Owner Notify ON: bool = Form(default=False), # Owner Notify
T: int = Form(default=None), T: int = Form(default=None),
passwd: str = Form(default=str())): passwd: str = Form(default=str())):
from aurweb.db import session user = db.query(models.User).filter(
user = session.query(models.User).filter(
models.User.Username == username).first() models.User.Username == username).first()
response = cannot_edit(request, user) response = cannot_edit(request, user)
if response: if response:
return response return response
context = await make_variable_context(request, "Accounts") context = await make_variable_context(request, "Accounts")
context["user"] = user context["user"] = db.refresh(user)
args = dict(await request.form()) args = dict(await request.form())
context = make_account_form_context(context, request, user, args) context = make_account_form_context(context, request, user, args)
@ -575,7 +575,7 @@ async def account_edit_post(request: Request,
user.ssh_pub_key.Fingerprint = fingerprint user.ssh_pub_key.Fingerprint = fingerprint
elif user.ssh_pub_key: elif user.ssh_pub_key:
# Else, if the user has a public key already, delete it. # Else, if the user has a public key already, delete it.
session.delete(user.ssh_pub_key) db.delete(user.ssh_pub_key)
if T and T != user.AccountTypeID: if T and T != user.AccountTypeID:
with db.begin(): with db.begin():
@ -617,27 +617,16 @@ account_template = (
status_code=HTTPStatus.UNAUTHORIZED) status_code=HTTPStatus.UNAUTHORIZED)
async def account(request: Request, username: str): async def account(request: Request, username: str):
_ = l10n.get_translator_for_request(request) _ = l10n.get_translator_for_request(request)
context = await make_variable_context(request, context = await make_variable_context(
_("Account") + " " + username) request, _("Account") + " " + username)
context["user"] = get_user_by_name(username)
user = db.query(models.User, models.User.Username == username).first()
if not user:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
context["user"] = user
return render_template(request, "account/show.html", context) return render_template(request, "account/show.html", context)
@router.get("/account/{username}/comments") @router.get("/account/{username}/comments")
@auth_required(redirect="/account/{username}/comments") @auth_required(redirect="/account/{username}/comments")
async def account_comments(request: Request, username: str): async def account_comments(request: Request, username: str):
user = db.query(models.User).filter( user = get_user_by_name(username)
models.User.Username == username
).first()
if not user:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
context = make_context(request, "Accounts") context = make_context(request, "Accounts")
context["username"] = username context["username"] = username
context["comments"] = user.package_comments.order_by( context["comments"] = user.package_comments.order_by(
@ -662,7 +651,7 @@ async def accounts(request: Request):
account_type.TRUSTED_USER_AND_DEV}) account_type.TRUSTED_USER_AND_DEV})
async def accounts_post(request: Request, async def accounts_post(request: Request,
O: int = Form(default=0), # Offset O: int = Form(default=0), # Offset
SB: str = Form(default=str()), # Search By SB: str = Form(default=str()), # Sort By
U: str = Form(default=str()), # Username U: str = Form(default=str()), # Username
T: str = Form(default=str()), # Account Type T: str = Form(default=str()), # Account Type
S: bool = Form(default=False), # Suspended S: bool = Form(default=False), # Suspended
@ -705,23 +694,19 @@ async def accounts_post(request: Request,
# Populate this list with any additional statements to # Populate this list with any additional statements to
# be ANDed together. # be ANDed together.
statements = [] statements = [
if account_type_id is not None: v for k, v in [
statements.append(models.AccountType.ID == account_type_id) (account_type_id is not None, models.AccountType.ID == account_type_id),
if U: (bool(U), models.User.Username.like(f"%{U}%")),
statements.append(models.User.Username.like(f"%{U}%")) (bool(S), models.User.Suspended == S),
if S: (bool(E), models.User.Email.like(f"%{E}%")),
statements.append(models.User.Suspended == S) (bool(R), models.User.RealName.like(f"%{R}%")),
if E: (bool(I), models.User.IRCNick.like(f"%{I}%")),
statements.append(models.User.Email.like(f"%{E}%")) (bool(K), models.User.PGPKey.like(f"%{K}%")),
if R: ] if k
statements.append(models.User.RealName.like(f"%{R}%")) ]
if I:
statements.append(models.User.IRCNick.like(f"%{I}%"))
if K:
statements.append(models.User.PGPKey.like(f"%{K}%"))
# Filter the query by combining all statements added above into # Filter the query by coe-mbining all statements added above into
# an AND statement, unless there's just one statement, which # an AND statement, unless there's just one statement, which
# we pass on to filter() as args. # we pass on to filter() as args.
if statements: if statements:
@ -729,7 +714,7 @@ async def accounts_post(request: Request,
# Finally, order and truncate our users for the current page. # Finally, order and truncate our users for the current page.
users = query.order_by(*order_by).limit(pp).offset(offset) users = query.order_by(*order_by).limit(pp).offset(offset)
context["users"] = users context["users"] = util.apply_all(users, db.refresh)
return render_template(request, "account/index.html", context) return render_template(request, "account/index.html", context)
@ -755,6 +740,9 @@ async def terms_of_service(request: Request):
unaccepted = db.query(models.Term).filter( unaccepted = db.query(models.Term).filter(
~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID))).all() ~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID))).all()
for record in (diffs + unaccepted):
db.refresh(record)
# Translate the 'Terms of Service' part of our page title. # Translate the 'Terms of Service' part of our page title.
_ = l10n.get_translator_for_request(request) _ = l10n.get_translator_for_request(request)
title = f"AUR {_('Terms of Service')}" title = f"AUR {_('Terms of Service')}"
@ -786,18 +774,21 @@ async def terms_of_service_post(request: Request,
# We already did the database filters here, so let's just use # We already did the database filters here, so let's just use
# them instead of reiterating the process in terms_of_service. # them instead of reiterating the process in terms_of_service.
accept_needed = sorted(unaccepted + diffs) accept_needed = sorted(unaccepted + diffs)
return render_terms_of_service(request, context, accept_needed) return render_terms_of_service(
request, context, util.apply_all(accept_needed, db.refresh))
with db.begin(): with db.begin():
# For each term we found, query for the matching accepted term # For each term we found, query for the matching accepted term
# and update its Revision to the term's current Revision. # and update its Revision to the term's current Revision.
for term in diffs: for term in diffs:
db.refresh(term)
accepted_term = request.user.accepted_terms.filter( accepted_term = request.user.accepted_terms.filter(
models.AcceptedTerm.TermsID == term.ID).first() models.AcceptedTerm.TermsID == term.ID).first()
accepted_term.Revision = term.Revision accepted_term.Revision = term.Revision
# For each term that was never accepted, accept it! # For each term that was never accepted, accept it!
for term in unaccepted: for term in unaccepted:
db.refresh(term)
db.create(models.AcceptedTerm, User=request.user, db.create(models.AcceptedTerm, User=request.user,
Term=term, Revision=term.Revision) Term=term, Revision=term.Revision)

View file

@ -4,7 +4,7 @@ from typing import Any, Dict, List
from fastapi import APIRouter, Form, HTTPException, Query, Request, Response from fastapi import APIRouter, Form, HTTPException, Query, Request, Response
from fastapi.responses import JSONResponse, RedirectResponse from fastapi.responses import JSONResponse, RedirectResponse
from sqlalchemy import and_, case from sqlalchemy import case
import aurweb.filters import aurweb.filters
import aurweb.packages.util import aurweb.packages.util
@ -15,9 +15,9 @@ from aurweb.models.package_request import ACCEPTED_ID, PENDING_ID, REJECTED_ID
from aurweb.models.relation_type import CONFLICTS_ID, PROVIDES_ID, REPLACES_ID from aurweb.models.relation_type import CONFLICTS_ID, PROVIDES_ID, REPLACES_ID
from aurweb.models.request_type import DELETION_ID, MERGE, MERGE_ID from aurweb.models.request_type import DELETION_ID, MERGE, MERGE_ID
from aurweb.packages.search import PackageSearch from aurweb.packages.search import PackageSearch
from aurweb.packages.util import get_pkg_or_base, get_pkgbase_comment, query_notified, query_voted from aurweb.packages.util import get_pkg_or_base, get_pkgbase_comment, get_pkgreq_by_id, query_notified, query_voted
from aurweb.scripts import notify, popupdate from aurweb.scripts import notify, popupdate
from aurweb.scripts.rendercomment import update_comment_render from aurweb.scripts.rendercomment import update_comment_render_fastapi
from aurweb.templates import make_context, make_variable_context, render_raw_template, render_template from aurweb.templates import make_context, make_variable_context, render_raw_template, render_template
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -92,7 +92,10 @@ async def packages_get(request: Request, context: Dict[str, Any],
# Insert search results into the context. # Insert search results into the context.
results = search.results() results = search.results()
context["packages"] = results.limit(per_page).offset(offset)
packages = results.limit(per_page).offset(offset)
util.apply_all(packages, db.refresh)
context["packages"] = packages
context["packages_voted"] = query_voted( context["packages_voted"] = query_voted(
context.get("packages"), request.user) context.get("packages"), request.user)
context["packages_notified"] = query_notified( context["packages_notified"] = query_notified(
@ -132,6 +135,7 @@ def create_request_if_missing(requests: List[models.PackageRequest],
ClosedTS=now, ClosedTS=now,
Closer=user) Closer=user)
requests.append(pkgreq) requests.append(pkgreq)
return pkgreq
def delete_package(deleter: models.User, package: models.Package): def delete_package(deleter: models.User, package: models.Package):
@ -147,8 +151,9 @@ def delete_package(deleter: models.User, package: models.Package):
).first() ).first()
with db.begin(): with db.begin():
create_request_if_missing( pkgreq = create_request_if_missing(
requests, reqtype, deleter, package) requests, reqtype, deleter, package)
db.refresh(pkgreq)
bases_to_delete.append(package.PackageBase) bases_to_delete.append(package.PackageBase)
@ -171,7 +176,8 @@ def delete_package(deleter: models.User, package: models.Package):
) )
# Perform all the deletions. # Perform all the deletions.
db.delete_all([package]) with db.begin():
db.delete(package)
db.delete_all(bases_to_delete) db.delete_all(bases_to_delete)
# Send out all the notifications. # Send out all the notifications.
@ -221,8 +227,7 @@ async def make_single_context(request: Request,
async def package(request: Request, name: str) -> Response: async def package(request: Request, name: str) -> Response:
# Get the Package. # Get the Package.
pkg = get_pkg_or_base(name, models.Package) pkg = get_pkg_or_base(name, models.Package)
pkgbase = (get_pkg_or_base(name, models.PackageBase) pkgbase = pkg.PackageBase
if not pkg else pkg.PackageBase)
# Add our base information. # Add our base information.
context = await make_single_context(request, pkgbase) context = await make_single_context(request, pkgbase)
@ -312,7 +317,7 @@ async def pkgbase_comments_post(
db.create(models.PackageNotification, db.create(models.PackageNotification,
User=request.user, User=request.user,
PackageBase=pkgbase) PackageBase=pkgbase)
update_comment_render(comment.ID) update_comment_render_fastapi(comment)
# Redirect to the pkgbase page. # Redirect to the pkgbase page.
return RedirectResponse(f"/pkgbase/{pkgbase.Name}#comment-{comment.ID}", return RedirectResponse(f"/pkgbase/{pkgbase.Name}#comment-{comment.ID}",
@ -374,7 +379,7 @@ async def pkgbase_comment_post(
db.create(models.PackageNotification, db.create(models.PackageNotification,
User=request.user, User=request.user,
PackageBase=pkgbase) PackageBase=pkgbase)
update_comment_render(db_comment.ID) update_comment_render_fastapi(db_comment)
if not next: if not next:
next = f"/pkgbase/{pkgbase.Name}" next = f"/pkgbase/{pkgbase.Name}"
@ -539,7 +544,7 @@ def remove_users(pkgbase, usernames):
conn, comaintainer.User.ID, pkgbase.ID conn, comaintainer.User.ID, pkgbase.ID
) )
) )
db.session.delete(comaintainer) db.delete(comaintainer)
# Send out notifications if need be. # Send out notifications if need be.
for notify_ in notifications: for notify_ in notifications:
@ -679,14 +684,8 @@ async def requests(request: Request,
@router.get("/pkgbase/{name}/request") @router.get("/pkgbase/{name}/request")
@auth_required(True, redirect="/pkgbase/{name}/request") @auth_required(True, redirect="/pkgbase/{name}/request")
async def package_request(request: Request, name: str): async def package_request(request: Request, name: str):
pkgbase = get_pkg_or_base(name, models.PackageBase)
context = await make_variable_context(request, "Submit Request") context = await make_variable_context(request, "Submit Request")
pkgbase = db.query(models.PackageBase).filter(
models.PackageBase.Name == name).first()
if not pkgbase:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
context["pkgbase"] = pkgbase context["pkgbase"] = pkgbase
return render_template(request, "pkgbase/request.html", context) return render_template(request, "pkgbase/request.html", context)
@ -729,6 +728,7 @@ async def pkgbase_request_post(request: Request, name: str,
] ]
return render_template(request, "pkgbase/request.html", context) return render_template(request, "pkgbase/request.html", context)
db.refresh(target)
if target.ID == pkgbase.ID: if target.ID == pkgbase.ID:
# TODO: This error needs to be translated. # TODO: This error needs to be translated.
context["errors"] = [ context["errors"] = [
@ -767,8 +767,7 @@ async def pkgbase_request_post(request: Request, name: str,
@router.get("/requests/{id}/close") @router.get("/requests/{id}/close")
@auth_required(True, redirect="/requests/{id}/close") @auth_required(True, redirect="/requests/{id}/close")
async def requests_close(request: Request, id: int): async def requests_close(request: Request, id: int):
pkgreq = db.query(models.PackageRequest).filter( pkgreq = get_pkgreq_by_id(id)
models.PackageRequest.ID == id).first()
if not request.user.is_elevated() and request.user != pkgreq.User: if not request.user.is_elevated() and request.user != pkgreq.User:
# Request user doesn't have permission here: redirect to '/'. # Request user doesn't have permission here: redirect to '/'.
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)
@ -783,8 +782,7 @@ async def requests_close(request: Request, id: int):
async def requests_close_post(request: Request, id: int, async def requests_close_post(request: Request, id: int,
reason: int = Form(default=0), reason: int = Form(default=0),
comments: str = Form(default=str())): comments: str = Form(default=str())):
pkgreq = db.query(models.PackageRequest).filter( pkgreq = get_pkgreq_by_id(id)
models.PackageRequest.ID == id).first()
if not request.user.is_elevated() and request.user != pkgreq.User: if not request.user.is_elevated() and request.user != pkgreq.User:
# Request user doesn't have permission here: redirect to '/'. # Request user doesn't have permission here: redirect to '/'.
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)
@ -823,13 +821,17 @@ async def pkgbase_keywords(request: Request, name: str,
keywords = set(keywords.split(" ")) keywords = set(keywords.split(" "))
# Delete all keywords which are not supplied by the user. # Delete all keywords which are not supplied by the user.
with db.begin(): other_keywords = pkgbase.keywords.filter(
db.delete(models.PackageKeyword, ~models.PackageKeyword.Keyword.in_(keywords))
and_(models.PackageKeyword.PackageBaseID == pkgbase.ID, other_keyword_strings = [kwd.Keyword for kwd in other_keywords]
~models.PackageKeyword.Keyword.in_(keywords)))
existing_keywords = set(kwd.Keyword for kwd in pkgbase.keywords.all()) existing_keywords = set(
kwd.Keyword for kwd in
pkgbase.keywords.filter(
~models.PackageKeyword.Keyword.in_(other_keyword_strings))
)
with db.begin(): with db.begin():
db.delete_all(other_keywords)
for keyword in keywords.difference(existing_keywords): for keyword in keywords.difference(existing_keywords):
db.create(models.PackageKeyword, db.create(models.PackageKeyword,
PackageBase=pkgbase, PackageBase=pkgbase,
@ -940,7 +942,7 @@ def pkgbase_unnotify_instance(request: Request, pkgbase: models.PackageBase):
has_cred = request.user.has_credential("CRED_PKGBASE_NOTIFY") has_cred = request.user.has_credential("CRED_PKGBASE_NOTIFY")
if has_cred and notif: if has_cred and notif:
with db.begin(): with db.begin():
db.session.delete(notif) db.delete(notif)
@router.post("/pkgbase/{name}/unnotify") @router.post("/pkgbase/{name}/unnotify")
@ -988,7 +990,7 @@ async def pkgbase_unvote(request: Request, name: str):
has_cred = request.user.has_credential("CRED_PKGBASE_VOTE") has_cred = request.user.has_credential("CRED_PKGBASE_VOTE")
if has_cred and vote: if has_cred and vote:
with db.begin(): with db.begin():
db.session.delete(vote) db.delete(vote)
# Update NumVotes/Popularity. # Update NumVotes/Popularity.
conn = db.ConnectionExecutor(db.get_engine().raw_connection()) conn = db.ConnectionExecutor(db.get_engine().raw_connection())
@ -1015,7 +1017,7 @@ def pkgbase_disown_instance(request: Request, pkgbase: models.PackageBase):
if co: if co:
with db.begin(): with db.begin():
pkgbase.Maintainer = co.User pkgbase.Maintainer = co.User
db.session.delete(co) db.delete(co)
else: else:
pkgbase.Maintainer = None pkgbase.Maintainer = None
@ -1463,8 +1465,8 @@ def pkgbase_merge_instance(request: Request, pkgbase: models.PackageBase,
with db.begin(): with db.begin():
# Delete pkgbase and its packages now that everything's merged. # Delete pkgbase and its packages now that everything's merged.
for pkg in pkgbase.packages: for pkg in pkgbase.packages:
db.session.delete(pkg) db.delete(pkg)
db.session.delete(pkgbase) db.delete(pkgbase)
# Accept merge requests related to this pkgbase and target. # Accept merge requests related to this pkgbase and target.
for pkgreq in requests: for pkgreq in requests:

View file

@ -1,5 +1,5 @@
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, List from typing import Any, Callable, Dict, List, NewType
from sqlalchemy import and_ from sqlalchemy import and_
@ -25,6 +25,10 @@ REL_TYPES = {
} }
DataGenerator = NewType("DataGenerator",
Callable[[models.Package], Dict[str, Any]])
class RPCError(Exception): class RPCError(Exception):
pass pass
@ -188,15 +192,32 @@ class RPC:
self._update_json_relations(package, data) self._update_json_relations(package, data)
return data return data
def _handle_multiinfo_type(self, args: List[str] = [], **kwargs): def _assemble_json_data(self, packages: List[models.Package],
data_generator: DataGenerator) \
-> List[Dict[str, Any]]:
"""
Assemble JSON data out of a list of packages.
:param packages: A list of Package instances or a Package ORM query
:param data_generator: Generator callable of single-Package JSON data
"""
output = []
for pkg in packages:
db.refresh(pkg)
output.append(data_generator(pkg))
return output
def _handle_multiinfo_type(self, args: List[str] = [], **kwargs) \
-> List[Dict[str, Any]]:
self._enforce_args(args) self._enforce_args(args)
args = set(args) args = set(args)
packages = db.query(models.Package).filter( packages = db.query(models.Package).filter(
models.Package.Name.in_(args)) models.Package.Name.in_(args))
return [self._get_info_json_data(pkg) for pkg in packages] return self._assemble_json_data(packages, self._get_info_json_data)
def _handle_search_type(self, by: str = defaults.RPC_SEARCH_BY, def _handle_search_type(self, by: str = defaults.RPC_SEARCH_BY,
args: List[str] = []): args: List[str] = []) \
-> List[Dict[str, Any]]:
# If `by` isn't maintainer and we don't have any args, raise an error. # If `by` isn't maintainer and we don't have any args, raise an error.
# In maintainer's case, return all orphans if there are no args, # In maintainer's case, return all orphans if there are no args,
# so we need args to pass through to the handler without errors. # so we need args to pass through to the handler without errors.
@ -212,7 +233,7 @@ class RPC:
max_results = config.getint("options", "max_rpc_results") max_results = config.getint("options", "max_rpc_results")
results = search.results().limit(max_results) results = search.results().limit(max_results)
return [self._get_json_data(pkg) for pkg in results] return self._assemble_json_data(results, self._get_json_data)
def _handle_msearch_type(self, args: List[str] = [], **kwargs): def _handle_msearch_type(self, args: List[str] = [], **kwargs):
return self._handle_search_type(by="m", args=args) return self._handle_search_type(by="m", args=args)

View file

@ -29,7 +29,7 @@ def run_single(conn, pkgbase):
conn.commit() conn.commit()
conn.close() conn.close()
aurweb.db.session.refresh(pkgbase) aurweb.db.refresh(pkgbase)
def main(): def main():

View file

@ -129,9 +129,14 @@ def save_rendered_comment(conn, commentid, html):
[html, commentid]) [html, commentid])
def update_comment_render(commentid): def update_comment_render_fastapi(comment):
conn = aurweb.db.Connection() conn = aurweb.db.ConnectionExecutor(
aurweb.db.get_engine().raw_connection())
update_comment_render(conn, comment.ID)
aurweb.db.refresh(comment)
def update_comment_render(conn, commentid):
text, pkgbase = get_comment(conn, commentid) text, pkgbase = get_comment(conn, commentid)
html = markdown.markdown(text, extensions=[ html = markdown.markdown(text, extensions=[
'fenced_code', 'fenced_code',
@ -152,7 +157,8 @@ def update_comment_render(commentid):
def main(): def main():
commentid = int(sys.argv[1]) commentid = int(sys.argv[1])
update_comment_render(commentid) conn = aurweb.db.Connection()
update_comment_render(conn, commentid)
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -19,7 +19,7 @@ def references_graph(table):
"regexp_1": r'(?i)\s+references\s+("|\')?', "regexp_1": r'(?i)\s+references\s+("|\')?',
"regexp_2": r'("|\')?\s*\(', "regexp_2": r'("|\')?\s*\(',
} }
cursor = aurweb.db.session.execute(query, params=params) cursor = aurweb.db.get_session().execute(query, params=params)
return [row[0] for row in cursor.fetchall()] return [row[0] for row in cursor.fetchall()]
@ -51,7 +51,7 @@ def setup_test_db(*args):
db_backend = aurweb.config.get("database", "backend") db_backend = aurweb.config.get("database", "backend")
if db_backend != "sqlite": # pragma: no cover if db_backend != "sqlite": # pragma: no cover
aurweb.db.session.execute("SET FOREIGN_KEY_CHECKS = 0") aurweb.db.get_session().execute("SET FOREIGN_KEY_CHECKS = 0")
else: else:
# We're using sqlite, setup tables to be deleted without violating # We're using sqlite, setup tables to be deleted without violating
# foreign key constraints by graphing references. # foreign key constraints by graphing references.
@ -59,10 +59,10 @@ def setup_test_db(*args):
references_graph(table) for table in tables)) references_graph(table) for table in tables))
for table in tables: for table in tables:
aurweb.db.session.execute(f"DELETE FROM {table}") aurweb.db.get_session().execute(f"DELETE FROM {table}")
if db_backend != "sqlite": # pragma: no cover if db_backend != "sqlite": # pragma: no cover
aurweb.db.session.execute("SET FOREIGN_KEY_CHECKS = 1") aurweb.db.get_session().execute("SET FOREIGN_KEY_CHECKS = 1")
# Expunge all objects from SQLAlchemy's IdentityMap. # Expunge all objects from SQLAlchemy's IdentityMap.
aurweb.db.session.expunge_all() aurweb.db.get_session().expunge_all()

0
aurweb/users/__init__.py Normal file
View file

19
aurweb/users/util.py Normal file
View file

@ -0,0 +1,19 @@
from http import HTTPStatus
from fastapi import HTTPException
from aurweb import db
from aurweb.models import User
def get_user_by_name(username: str) -> User:
"""
Query a user by its username.
:param username: User.Username
:return: User instance
"""
user = db.query(User).filter(User.Username == username).first()
if not user:
raise HTTPException(status_code=int(HTTPStatus.NOT_FOUND))
return db.refresh(user)

View file

@ -155,6 +155,7 @@ def get_ssh_fingerprints():
def apply_all(iterable: Iterable, fn: Callable): def apply_all(iterable: Iterable, fn: Callable):
for item in iterable: for item in iterable:
fn(item) fn(item)
return iterable
def sanitize_params(offset: str, per_page: str) -> Tuple[int, int]: def sanitize_params(offset: str, per_page: str) -> Tuple[int, int]:

View file

@ -132,11 +132,11 @@
<tr> <tr>
<th>{{ "Votes" | tr }}:</th> <th>{{ "Votes" | tr }}:</th>
{% if not is_maintainer %} {% if not is_maintainer %}
<td>{{ pkgbase.package_votes.count() }}</td> <td>{{ pkgbase.NumVotes }}</td>
{% else %} {% else %}
<td> <td>
<a href="/pkgbase/{{ pkgbase.Name }}/voters"> <a href="/pkgbase/{{ pkgbase.Name }}/voters">
{{ pkgbase.package_votes.count() }} {{ pkgbase.NumVotes }}
</a> </a>
</td> </td>
{% endif %} {% endif %}

View file

@ -20,7 +20,7 @@ def setup():
yield account_type yield account_type
with begin(): with begin():
delete(AccountType, AccountType.ID == account_type.ID) delete(account_type)
def test_account_type(): def test_account_type():
@ -50,4 +50,4 @@ def test_user_account_type_relationship():
# This must be deleted here to avoid foreign key issues when # This must be deleted here to avoid foreign key issues when
# deleting the temporary AccountType in the fixture. # deleting the temporary AccountType in the fixture.
with begin(): with begin():
delete(User, User.ID == user.ID) delete(user)

View file

@ -279,13 +279,13 @@ def test_connection_execute_paramstyle_unsupported():
def test_create_delete(): def test_create_delete():
with db.begin(): with db.begin():
db.create(AccountType, AccountType="test") account_type = db.create(AccountType, AccountType="test")
record = db.query(AccountType, AccountType.AccountType == "test").first() record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is not None assert record is not None
with db.begin(): with db.begin():
db.delete(AccountType, AccountType.AccountType == "test") db.delete(account_type)
record = db.query(AccountType, AccountType.AccountType == "test").first() record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is None assert record is None
@ -306,7 +306,7 @@ def test_add_commit():
# Remove the record. # Remove the record.
with db.begin(): with db.begin():
db.delete(AccountType, AccountType.ID == account_type.ID) db.delete(account_type)
def test_connection_executor_mysql_paramstyle(): def test_connection_executor_mysql_paramstyle():

View file

@ -24,7 +24,7 @@ def test_dependency_type_creation():
assert bool(dependency_type.ID) assert bool(dependency_type.ID)
assert dependency_type.Name == "Test Type" assert dependency_type.Name == "Test Type"
with begin(): with begin():
delete(DependencyType, DependencyType.ID == dependency_type.ID) delete(dependency_type)
def test_dependency_type_null_name_uses_default(): def test_dependency_type_null_name_uses_default():
@ -32,4 +32,4 @@ def test_dependency_type_null_name_uses_default():
dependency_type = create(DependencyType) dependency_type = create(DependencyType)
assert dependency_type.Name == str() assert dependency_type.Name == str()
with begin(): with begin():
delete(DependencyType, DependencyType.ID == dependency_type.ID) delete(dependency_type)

View file

@ -2,6 +2,7 @@ from datetime import datetime
import pytest import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from aurweb import asgi, db from aurweb import asgi, db
@ -98,3 +99,8 @@ def test_query_notified(maintainer: User, package: Package):
query = db.query(Package).filter(Package.ID == package.ID).all() query = db.query(Package).filter(Package.ID == package.ID).all()
query_notified = util.query_notified(query, maintainer) query_notified = util.query_notified(query, maintainer)
assert query_notified[package.PackageBase.ID] assert query_notified[package.PackageBase.ID]
def test_pkgreq_by_id_not_found():
with pytest.raises(HTTPException):
util.get_pkgreq_by_id(0)

View file

@ -103,7 +103,7 @@ def test_ratelimit_db(get: mock.MagicMock, getboolean: mock.MagicMock,
# Delete the ApiRateLimit record. # Delete the ApiRateLimit record.
with db.begin(): with db.begin():
db.delete(ApiRateLimit) db.delete(db.query(ApiRateLimit).first())
# Should be good to go again! # Should be good to go again!
assert not check_ratelimit(request) assert not check_ratelimit(request)

View file

@ -18,7 +18,7 @@ def test_relation_type_creation():
assert relation_type.Name == "test-relation" assert relation_type.Name == "test-relation"
with db.begin(): with db.begin():
db.delete(RelationType, RelationType.ID == relation_type.ID) db.delete(relation_type)
def test_relation_types(): def test_relation_types():

View file

@ -18,7 +18,7 @@ def test_request_type_creation():
assert request_type.Name == "Test Request" assert request_type.Name == "Test Request"
with db.begin(): with db.begin():
db.delete(RequestType, RequestType.ID == request_type.ID) db.delete(request_type)
def test_request_type_null_name_returns_empty_string(): def test_request_type_null_name_returns_empty_string():
@ -29,7 +29,7 @@ def test_request_type_null_name_returns_empty_string():
assert request_type.Name == str() assert request_type.Name == str()
with db.begin(): with db.begin():
db.delete(RequestType, RequestType.ID == request_type.ID) db.delete(request_type)
def test_request_type_name_display(): def test_request_type_name_display():