Merge branch 'db-rework' into pu

This commit is contained in:
Kevin Morris 2021-11-15 00:02:56 -08:00
commit 91b570ff0d
No known key found for this signature in database
GPG key ID: F7E46DED420788F3
24 changed files with 233 additions and 166 deletions

View file

@ -13,7 +13,7 @@ from starlette.requests import HTTPConnection
import aurweb.config
from aurweb import l10n, util
from aurweb import db, l10n, util
from aurweb.models import Session, User
from aurweb.models.account_type import ACCOUNT_TYPE_ID
from aurweb.templates import make_variable_context, render_template
@ -98,14 +98,12 @@ class AnonymousUser:
class BasicAuthBackend(AuthenticationBackend):
async def authenticate(self, conn: HTTPConnection):
from aurweb.db import session
sid = conn.cookies.get("AURSID")
if not sid:
return (None, AnonymousUser())
now_ts = datetime.utcnow().timestamp()
record = session.query(Session).filter(
record = db.query(Session).filter(
and_(Session.SessionID == sid,
Session.LastUpdateTS >= now_ts)).first()
@ -116,7 +114,7 @@ class BasicAuthBackend(AuthenticationBackend):
# At this point, we cannot have an invalid user if the record
# exists, due to ForeignKey constraints in the schema upheld
# by mysqlclient.
user = session.query(User).filter(User.ID == record.UsersID).first()
user = db.query(User).filter(User.ID == record.UsersID).first()
user.nonce = util.make_nonce()
user.authenticated = True

View file

@ -2,10 +2,10 @@ import functools
import math
import re
from typing import Iterable
from typing import Iterable, NewType
from sqlalchemy import event
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import Query, scoped_session
import aurweb.config
import aurweb.util
@ -22,6 +22,9 @@ session = None
# Global introspected object memo.
introspected = dict()
# A mocked up type.
Base = NewType("aurweb.models.declarative_base.Base", "Base")
def make_random_value(table: str, column: str):
""" Generate a unique, random value for a string column in a table.
@ -58,55 +61,69 @@ def make_random_value(table: str, column: str):
return string
def query(model, *args, **kwargs):
return session.query(model).filter(*args, **kwargs)
def get_session():
""" Return aurweb.db's global session. """
return session
def create(model, *args, **kwargs):
instance = model(*args, **kwargs)
def refresh(model: Base) -> Base:
""" Refresh the session's knowledge of `model`. """
get_session().refresh(model)
return model
def query(Model: Base, *args, **kwargs) -> Query:
"""
Perform an ORM query against the database session.
This method also runs Query.filter on the resulting model
query with *args and **kwargs.
:param Model: Declarative ORM class
"""
return get_session().query(Model).filter(*args, **kwargs)
def create(Model: Base, *args, **kwargs) -> Base:
"""
Create a record and add() it to the database session.
:param Model: Declarative ORM class
:return: Model instance
"""
instance = Model(*args, **kwargs)
return add(instance)
def delete(model, *args, **kwargs):
instance = session.query(model).filter(*args, **kwargs)
for record in instance:
session.delete(record)
def delete(model: Base) -> None:
"""
Delete a set of records found by Query.filter(*args, **kwargs).
:param Model: Declarative ORM class
"""
get_session().delete(model)
def delete_all(iterable: Iterable):
with begin():
for obj in iterable:
session.delete(obj)
def delete_all(iterable: Iterable) -> None:
""" Delete each instance found in `iterable`. """
session_ = get_session()
aurweb.util.apply_all(iterable, session_.delete)
def rollback():
session.rollback()
def rollback() -> None:
""" Rollback the database session. """
get_session().rollback()
def add(model):
session.add(model)
def add(model: Base) -> Base:
""" Add `model` to the database session. """
get_session().add(model)
return model
def begin():
""" Begin an SQLAlchemy SessionTransaction.
This context is **required** to perform an modifications to the
database.
Example:
with db.begin():
object = db.create(...)
# On __exit__, db.commit() is run.
with db.begin():
object = db.delete(...)
# On __exit__, db.commit() is run.
:return: A new SessionTransaction based on session
"""
return session.begin()
""" Begin an SQLAlchemy SessionTransaction. """
return get_session().begin()
def get_sqlalchemy_url():

View file

@ -1,6 +1,6 @@
from fastapi import Request
from aurweb import schema
from aurweb import db, schema
from aurweb.models.declarative import Base
@ -10,11 +10,10 @@ class Ban(Base):
__mapper_args__ = {"primary_key": [__table__.c.IPAddress]}
def __init__(self, **kwargs):
self.IPAddress = kwargs.get("IPAddress")
self.BanTS = kwargs.get("BanTS")
super().__init__(**kwargs)
def is_banned(request: Request):
from aurweb.db import session
ip = request.client.host
return session.query(Ban).filter(Ban.IPAddress == ip).first() is not None
exists = db.query(Ban).filter(Ban.IPAddress == ip).exists()
return db.query(exists).scalar()

View file

@ -1,9 +1,10 @@
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import backref, relationship
from aurweb import schema
from aurweb import db, schema
from aurweb.models.declarative import Base
from aurweb.models.dependency_type import DependencyType as _DependencyType
from aurweb.models.official_provider import OfficialProvider as _OfficialProvider
from aurweb.models.package import Package as _Package
@ -46,11 +47,7 @@ class PackageDependency(Base):
params=("NULL"))
def is_package(self) -> bool:
# TODO: Improve the speed of this query if possible.
from aurweb import db
from aurweb.models.official_provider import OfficialProvider
from aurweb.models.package import Package
pkg = db.query(Package, Package.Name == self.DepName)
official = db.query(OfficialProvider,
OfficialProvider.Name == self.DepName)
return pkg.scalar() or official.scalar()
pkg = db.query(_Package).filter(_Package.Name == self.DepName).exists()
official = db.query(_OfficialProvider).filter(
_OfficialProvider.Name == self.DepName).exists()
return db.query(pkg).scalar() or db.query(official).scalar()

View file

@ -146,7 +146,7 @@ class User(Base):
self.authenticated = False
if self.session:
with db.begin():
db.session.delete(self.session)
db.delete(self.session)
def is_trusted_user(self):
return self.AccountType.ID in {

View file

@ -110,18 +110,26 @@ def get_pkg_or_base(
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
instance = db.query(cls).filter(cls.Name == name).first()
if cls == models.PackageBase and not instance:
if not instance:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
return instance
return db.refresh(instance)
def get_pkgbase_comment(
pkgbase: models.PackageBase, id: int) -> models.PackageComment:
def get_pkgbase_comment(pkgbase: models.PackageBase, id: int) \
-> models.PackageComment:
comment = pkgbase.comments.filter(models.PackageComment.ID == id).first()
if not comment:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
return comment
return db.refresh(comment)
def get_pkgreq_by_id(id: int):
pkgreq = db.query(models.PackageRequest).filter(
models.PackageRequest.ID == id).first()
if not pkgreq:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
return db.refresh(pkgreq)
@register_filter("out_of_date")

View file

@ -40,8 +40,10 @@ def _update_ratelimit_db(request: Request):
now = int(datetime.utcnow().timestamp())
time_to_delete = now - window_length
records = db.query(ApiRateLimit).filter(
ApiRateLimit.WindowStart < time_to_delete)
with db.begin():
db.delete(ApiRateLimit, ApiRateLimit.WindowStart < time_to_delete)
db.delete_all(records)
host = request.client.host
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()

View file

@ -4,7 +4,7 @@ import typing
from datetime import datetime
from http import HTTPStatus
from fastapi import APIRouter, Form, HTTPException, Request
from fastapi import APIRouter, Form, Request
from fastapi.responses import HTMLResponse, RedirectResponse
from sqlalchemy import and_, func, or_
@ -20,6 +20,7 @@ from aurweb.models.account_type import (DEVELOPER, DEVELOPER_ID, TRUSTED_USER, T
from aurweb.models.ssh_pub_key import get_fingerprint
from aurweb.scripts.notify import ResetKeyNotification, WelcomeNotification
from aurweb.templates import make_context, make_variable_context, render_template
from aurweb.users.util import get_user_by_name
router = APIRouter()
logger = logging.get_logger(__name__)
@ -49,6 +50,7 @@ async def passreset_post(request: Request,
return render_template(request, "passreset.html", context,
status_code=HTTPStatus.NOT_FOUND)
db.refresh(user)
if resetkey:
context["resetkey"] = resetkey
@ -83,7 +85,7 @@ async def passreset_post(request: Request,
with db.begin():
user.ResetKey = str()
if user.session:
db.session.delete(user.session)
db.delete(user.session)
user.update_password(password)
# Render ?step=complete.
@ -458,15 +460,15 @@ def cannot_edit(request, user):
@router.get("/account/{username}/edit", response_class=HTMLResponse)
@auth_required(True, redirect="/account/{username}")
async def account_edit(request: Request,
username: str):
async def account_edit(request: Request, username: str):
user = db.query(models.User, models.User.Username == username).first()
response = cannot_edit(request, user)
if response:
return response
context = await make_variable_context(request, "Accounts")
context["user"] = user
context["user"] = db.refresh(user)
context = make_account_form_context(context, request, user, dict())
return render_template(request, "account/edit.html", context)
@ -497,16 +499,14 @@ async def account_edit_post(request: Request,
ON: bool = Form(default=False), # Owner Notify
T: int = Form(default=None),
passwd: str = Form(default=str())):
from aurweb.db import session
user = session.query(models.User).filter(
user = db.query(models.User).filter(
models.User.Username == username).first()
response = cannot_edit(request, user)
if response:
return response
context = await make_variable_context(request, "Accounts")
context["user"] = user
context["user"] = db.refresh(user)
args = dict(await request.form())
context = make_account_form_context(context, request, user, args)
@ -575,7 +575,7 @@ async def account_edit_post(request: Request,
user.ssh_pub_key.Fingerprint = fingerprint
elif user.ssh_pub_key:
# Else, if the user has a public key already, delete it.
session.delete(user.ssh_pub_key)
db.delete(user.ssh_pub_key)
if T and T != user.AccountTypeID:
with db.begin():
@ -617,27 +617,16 @@ account_template = (
status_code=HTTPStatus.UNAUTHORIZED)
async def account(request: Request, username: str):
_ = l10n.get_translator_for_request(request)
context = await make_variable_context(request,
_("Account") + " " + username)
user = db.query(models.User, models.User.Username == username).first()
if not user:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
context["user"] = user
context = await make_variable_context(
request, _("Account") + " " + username)
context["user"] = get_user_by_name(username)
return render_template(request, "account/show.html", context)
@router.get("/account/{username}/comments")
@auth_required(redirect="/account/{username}/comments")
async def account_comments(request: Request, username: str):
user = db.query(models.User).filter(
models.User.Username == username
).first()
if not user:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
user = get_user_by_name(username)
context = make_context(request, "Accounts")
context["username"] = username
context["comments"] = user.package_comments.order_by(
@ -662,7 +651,7 @@ async def accounts(request: Request):
account_type.TRUSTED_USER_AND_DEV})
async def accounts_post(request: Request,
O: int = Form(default=0), # Offset
SB: str = Form(default=str()), # Search By
SB: str = Form(default=str()), # Sort By
U: str = Form(default=str()), # Username
T: str = Form(default=str()), # Account Type
S: bool = Form(default=False), # Suspended
@ -705,23 +694,19 @@ async def accounts_post(request: Request,
# Populate this list with any additional statements to
# be ANDed together.
statements = []
if account_type_id is not None:
statements.append(models.AccountType.ID == account_type_id)
if U:
statements.append(models.User.Username.like(f"%{U}%"))
if S:
statements.append(models.User.Suspended == S)
if E:
statements.append(models.User.Email.like(f"%{E}%"))
if R:
statements.append(models.User.RealName.like(f"%{R}%"))
if I:
statements.append(models.User.IRCNick.like(f"%{I}%"))
if K:
statements.append(models.User.PGPKey.like(f"%{K}%"))
statements = [
v for k, v in [
(account_type_id is not None, models.AccountType.ID == account_type_id),
(bool(U), models.User.Username.like(f"%{U}%")),
(bool(S), models.User.Suspended == S),
(bool(E), models.User.Email.like(f"%{E}%")),
(bool(R), models.User.RealName.like(f"%{R}%")),
(bool(I), models.User.IRCNick.like(f"%{I}%")),
(bool(K), models.User.PGPKey.like(f"%{K}%")),
] if k
]
# Filter the query by combining all statements added above into
# Filter the query by coe-mbining all statements added above into
# an AND statement, unless there's just one statement, which
# we pass on to filter() as args.
if statements:
@ -729,7 +714,7 @@ async def accounts_post(request: Request,
# Finally, order and truncate our users for the current page.
users = query.order_by(*order_by).limit(pp).offset(offset)
context["users"] = users
context["users"] = util.apply_all(users, db.refresh)
return render_template(request, "account/index.html", context)
@ -755,6 +740,9 @@ async def terms_of_service(request: Request):
unaccepted = db.query(models.Term).filter(
~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID))).all()
for record in (diffs + unaccepted):
db.refresh(record)
# Translate the 'Terms of Service' part of our page title.
_ = l10n.get_translator_for_request(request)
title = f"AUR {_('Terms of Service')}"
@ -786,18 +774,21 @@ async def terms_of_service_post(request: Request,
# We already did the database filters here, so let's just use
# them instead of reiterating the process in terms_of_service.
accept_needed = sorted(unaccepted + diffs)
return render_terms_of_service(request, context, accept_needed)
return render_terms_of_service(
request, context, util.apply_all(accept_needed, db.refresh))
with db.begin():
# For each term we found, query for the matching accepted term
# and update its Revision to the term's current Revision.
for term in diffs:
db.refresh(term)
accepted_term = request.user.accepted_terms.filter(
models.AcceptedTerm.TermsID == term.ID).first()
accepted_term.Revision = term.Revision
# For each term that was never accepted, accept it!
for term in unaccepted:
db.refresh(term)
db.create(models.AcceptedTerm, User=request.user,
Term=term, Revision=term.Revision)

View file

@ -4,7 +4,7 @@ from typing import Any, Dict, List
from fastapi import APIRouter, Form, HTTPException, Query, Request, Response
from fastapi.responses import JSONResponse, RedirectResponse
from sqlalchemy import and_, case
from sqlalchemy import case
import aurweb.filters
import aurweb.packages.util
@ -15,9 +15,9 @@ from aurweb.models.package_request import ACCEPTED_ID, PENDING_ID, REJECTED_ID
from aurweb.models.relation_type import CONFLICTS_ID, PROVIDES_ID, REPLACES_ID
from aurweb.models.request_type import DELETION_ID, MERGE, MERGE_ID
from aurweb.packages.search import PackageSearch
from aurweb.packages.util import get_pkg_or_base, get_pkgbase_comment, query_notified, query_voted
from aurweb.packages.util import get_pkg_or_base, get_pkgbase_comment, get_pkgreq_by_id, query_notified, query_voted
from aurweb.scripts import notify, popupdate
from aurweb.scripts.rendercomment import update_comment_render
from aurweb.scripts.rendercomment import update_comment_render_fastapi
from aurweb.templates import make_context, make_variable_context, render_raw_template, render_template
logger = logging.get_logger(__name__)
@ -92,7 +92,10 @@ async def packages_get(request: Request, context: Dict[str, Any],
# Insert search results into the context.
results = search.results()
context["packages"] = results.limit(per_page).offset(offset)
packages = results.limit(per_page).offset(offset)
util.apply_all(packages, db.refresh)
context["packages"] = packages
context["packages_voted"] = query_voted(
context.get("packages"), request.user)
context["packages_notified"] = query_notified(
@ -132,6 +135,7 @@ def create_request_if_missing(requests: List[models.PackageRequest],
ClosedTS=now,
Closer=user)
requests.append(pkgreq)
return pkgreq
def delete_package(deleter: models.User, package: models.Package):
@ -147,8 +151,9 @@ def delete_package(deleter: models.User, package: models.Package):
).first()
with db.begin():
create_request_if_missing(
pkgreq = create_request_if_missing(
requests, reqtype, deleter, package)
db.refresh(pkgreq)
bases_to_delete.append(package.PackageBase)
@ -171,8 +176,9 @@ def delete_package(deleter: models.User, package: models.Package):
)
# Perform all the deletions.
db.delete_all([package])
db.delete_all(bases_to_delete)
with db.begin():
db.delete(package)
db.delete_all(bases_to_delete)
# Send out all the notifications.
util.apply_all(notifications, lambda n: n.send())
@ -221,8 +227,7 @@ async def make_single_context(request: Request,
async def package(request: Request, name: str) -> Response:
# Get the Package.
pkg = get_pkg_or_base(name, models.Package)
pkgbase = (get_pkg_or_base(name, models.PackageBase)
if not pkg else pkg.PackageBase)
pkgbase = pkg.PackageBase
# Add our base information.
context = await make_single_context(request, pkgbase)
@ -312,7 +317,7 @@ async def pkgbase_comments_post(
db.create(models.PackageNotification,
User=request.user,
PackageBase=pkgbase)
update_comment_render(comment.ID)
update_comment_render_fastapi(comment)
# Redirect to the pkgbase page.
return RedirectResponse(f"/pkgbase/{pkgbase.Name}#comment-{comment.ID}",
@ -374,7 +379,7 @@ async def pkgbase_comment_post(
db.create(models.PackageNotification,
User=request.user,
PackageBase=pkgbase)
update_comment_render(db_comment.ID)
update_comment_render_fastapi(db_comment)
if not next:
next = f"/pkgbase/{pkgbase.Name}"
@ -539,7 +544,7 @@ def remove_users(pkgbase, usernames):
conn, comaintainer.User.ID, pkgbase.ID
)
)
db.session.delete(comaintainer)
db.delete(comaintainer)
# Send out notifications if need be.
for notify_ in notifications:
@ -679,14 +684,8 @@ async def requests(request: Request,
@router.get("/pkgbase/{name}/request")
@auth_required(True, redirect="/pkgbase/{name}/request")
async def package_request(request: Request, name: str):
pkgbase = get_pkg_or_base(name, models.PackageBase)
context = await make_variable_context(request, "Submit Request")
pkgbase = db.query(models.PackageBase).filter(
models.PackageBase.Name == name).first()
if not pkgbase:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
context["pkgbase"] = pkgbase
return render_template(request, "pkgbase/request.html", context)
@ -729,6 +728,7 @@ async def pkgbase_request_post(request: Request, name: str,
]
return render_template(request, "pkgbase/request.html", context)
db.refresh(target)
if target.ID == pkgbase.ID:
# TODO: This error needs to be translated.
context["errors"] = [
@ -767,8 +767,7 @@ async def pkgbase_request_post(request: Request, name: str,
@router.get("/requests/{id}/close")
@auth_required(True, redirect="/requests/{id}/close")
async def requests_close(request: Request, id: int):
pkgreq = db.query(models.PackageRequest).filter(
models.PackageRequest.ID == id).first()
pkgreq = get_pkgreq_by_id(id)
if not request.user.is_elevated() and request.user != pkgreq.User:
# Request user doesn't have permission here: redirect to '/'.
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)
@ -783,8 +782,7 @@ async def requests_close(request: Request, id: int):
async def requests_close_post(request: Request, id: int,
reason: int = Form(default=0),
comments: str = Form(default=str())):
pkgreq = db.query(models.PackageRequest).filter(
models.PackageRequest.ID == id).first()
pkgreq = get_pkgreq_by_id(id)
if not request.user.is_elevated() and request.user != pkgreq.User:
# Request user doesn't have permission here: redirect to '/'.
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)
@ -823,13 +821,17 @@ async def pkgbase_keywords(request: Request, name: str,
keywords = set(keywords.split(" "))
# Delete all keywords which are not supplied by the user.
with db.begin():
db.delete(models.PackageKeyword,
and_(models.PackageKeyword.PackageBaseID == pkgbase.ID,
~models.PackageKeyword.Keyword.in_(keywords)))
other_keywords = pkgbase.keywords.filter(
~models.PackageKeyword.Keyword.in_(keywords))
other_keyword_strings = [kwd.Keyword for kwd in other_keywords]
existing_keywords = set(kwd.Keyword for kwd in pkgbase.keywords.all())
existing_keywords = set(
kwd.Keyword for kwd in
pkgbase.keywords.filter(
~models.PackageKeyword.Keyword.in_(other_keyword_strings))
)
with db.begin():
db.delete_all(other_keywords)
for keyword in keywords.difference(existing_keywords):
db.create(models.PackageKeyword,
PackageBase=pkgbase,
@ -940,7 +942,7 @@ def pkgbase_unnotify_instance(request: Request, pkgbase: models.PackageBase):
has_cred = request.user.has_credential("CRED_PKGBASE_NOTIFY")
if has_cred and notif:
with db.begin():
db.session.delete(notif)
db.delete(notif)
@router.post("/pkgbase/{name}/unnotify")
@ -988,7 +990,7 @@ async def pkgbase_unvote(request: Request, name: str):
has_cred = request.user.has_credential("CRED_PKGBASE_VOTE")
if has_cred and vote:
with db.begin():
db.session.delete(vote)
db.delete(vote)
# Update NumVotes/Popularity.
conn = db.ConnectionExecutor(db.get_engine().raw_connection())
@ -1015,7 +1017,7 @@ def pkgbase_disown_instance(request: Request, pkgbase: models.PackageBase):
if co:
with db.begin():
pkgbase.Maintainer = co.User
db.session.delete(co)
db.delete(co)
else:
pkgbase.Maintainer = None
@ -1463,8 +1465,8 @@ def pkgbase_merge_instance(request: Request, pkgbase: models.PackageBase,
with db.begin():
# Delete pkgbase and its packages now that everything's merged.
for pkg in pkgbase.packages:
db.session.delete(pkg)
db.session.delete(pkgbase)
db.delete(pkg)
db.delete(pkgbase)
# Accept merge requests related to this pkgbase and target.
for pkgreq in requests:

View file

@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Any, Dict, List
from typing import Any, Callable, Dict, List, NewType
from sqlalchemy import and_
@ -25,6 +25,10 @@ REL_TYPES = {
}
DataGenerator = NewType("DataGenerator",
Callable[[models.Package], Dict[str, Any]])
class RPCError(Exception):
pass
@ -188,15 +192,32 @@ class RPC:
self._update_json_relations(package, data)
return data
def _handle_multiinfo_type(self, args: List[str] = [], **kwargs):
def _assemble_json_data(self, packages: List[models.Package],
data_generator: DataGenerator) \
-> List[Dict[str, Any]]:
"""
Assemble JSON data out of a list of packages.
:param packages: A list of Package instances or a Package ORM query
:param data_generator: Generator callable of single-Package JSON data
"""
output = []
for pkg in packages:
db.refresh(pkg)
output.append(data_generator(pkg))
return output
def _handle_multiinfo_type(self, args: List[str] = [], **kwargs) \
-> List[Dict[str, Any]]:
self._enforce_args(args)
args = set(args)
packages = db.query(models.Package).filter(
models.Package.Name.in_(args))
return [self._get_info_json_data(pkg) for pkg in packages]
return self._assemble_json_data(packages, self._get_info_json_data)
def _handle_search_type(self, by: str = defaults.RPC_SEARCH_BY,
args: List[str] = []):
args: List[str] = []) \
-> List[Dict[str, Any]]:
# If `by` isn't maintainer and we don't have any args, raise an error.
# In maintainer's case, return all orphans if there are no args,
# so we need args to pass through to the handler without errors.
@ -212,7 +233,7 @@ class RPC:
max_results = config.getint("options", "max_rpc_results")
results = search.results().limit(max_results)
return [self._get_json_data(pkg) for pkg in results]
return self._assemble_json_data(results, self._get_json_data)
def _handle_msearch_type(self, args: List[str] = [], **kwargs):
return self._handle_search_type(by="m", args=args)

View file

@ -29,7 +29,7 @@ def run_single(conn, pkgbase):
conn.commit()
conn.close()
aurweb.db.session.refresh(pkgbase)
aurweb.db.refresh(pkgbase)
def main():

View file

@ -129,9 +129,14 @@ def save_rendered_comment(conn, commentid, html):
[html, commentid])
def update_comment_render(commentid):
conn = aurweb.db.Connection()
def update_comment_render_fastapi(comment):
conn = aurweb.db.ConnectionExecutor(
aurweb.db.get_engine().raw_connection())
update_comment_render(conn, comment.ID)
aurweb.db.refresh(comment)
def update_comment_render(conn, commentid):
text, pkgbase = get_comment(conn, commentid)
html = markdown.markdown(text, extensions=[
'fenced_code',
@ -152,7 +157,8 @@ def update_comment_render(commentid):
def main():
commentid = int(sys.argv[1])
update_comment_render(commentid)
conn = aurweb.db.Connection()
update_comment_render(conn, commentid)
if __name__ == '__main__':

View file

@ -19,7 +19,7 @@ def references_graph(table):
"regexp_1": r'(?i)\s+references\s+("|\')?',
"regexp_2": r'("|\')?\s*\(',
}
cursor = aurweb.db.session.execute(query, params=params)
cursor = aurweb.db.get_session().execute(query, params=params)
return [row[0] for row in cursor.fetchall()]
@ -51,7 +51,7 @@ def setup_test_db(*args):
db_backend = aurweb.config.get("database", "backend")
if db_backend != "sqlite": # pragma: no cover
aurweb.db.session.execute("SET FOREIGN_KEY_CHECKS = 0")
aurweb.db.get_session().execute("SET FOREIGN_KEY_CHECKS = 0")
else:
# We're using sqlite, setup tables to be deleted without violating
# foreign key constraints by graphing references.
@ -59,10 +59,10 @@ def setup_test_db(*args):
references_graph(table) for table in tables))
for table in tables:
aurweb.db.session.execute(f"DELETE FROM {table}")
aurweb.db.get_session().execute(f"DELETE FROM {table}")
if db_backend != "sqlite": # pragma: no cover
aurweb.db.session.execute("SET FOREIGN_KEY_CHECKS = 1")
aurweb.db.get_session().execute("SET FOREIGN_KEY_CHECKS = 1")
# Expunge all objects from SQLAlchemy's IdentityMap.
aurweb.db.session.expunge_all()
aurweb.db.get_session().expunge_all()

0
aurweb/users/__init__.py Normal file
View file

19
aurweb/users/util.py Normal file
View file

@ -0,0 +1,19 @@
from http import HTTPStatus
from fastapi import HTTPException
from aurweb import db
from aurweb.models import User
def get_user_by_name(username: str) -> User:
"""
Query a user by its username.
:param username: User.Username
:return: User instance
"""
user = db.query(User).filter(User.Username == username).first()
if not user:
raise HTTPException(status_code=int(HTTPStatus.NOT_FOUND))
return db.refresh(user)

View file

@ -155,6 +155,7 @@ def get_ssh_fingerprints():
def apply_all(iterable: Iterable, fn: Callable):
for item in iterable:
fn(item)
return iterable
def sanitize_params(offset: str, per_page: str) -> Tuple[int, int]:

View file

@ -132,11 +132,11 @@
<tr>
<th>{{ "Votes" | tr }}:</th>
{% if not is_maintainer %}
<td>{{ pkgbase.package_votes.count() }}</td>
<td>{{ pkgbase.NumVotes }}</td>
{% else %}
<td>
<a href="/pkgbase/{{ pkgbase.Name }}/voters">
{{ pkgbase.package_votes.count() }}
{{ pkgbase.NumVotes }}
</a>
</td>
{% endif %}

View file

@ -20,7 +20,7 @@ def setup():
yield account_type
with begin():
delete(AccountType, AccountType.ID == account_type.ID)
delete(account_type)
def test_account_type():
@ -50,4 +50,4 @@ def test_user_account_type_relationship():
# This must be deleted here to avoid foreign key issues when
# deleting the temporary AccountType in the fixture.
with begin():
delete(User, User.ID == user.ID)
delete(user)

View file

@ -279,13 +279,13 @@ def test_connection_execute_paramstyle_unsupported():
def test_create_delete():
with db.begin():
db.create(AccountType, AccountType="test")
account_type = db.create(AccountType, AccountType="test")
record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is not None
with db.begin():
db.delete(AccountType, AccountType.AccountType == "test")
db.delete(account_type)
record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is None
@ -306,7 +306,7 @@ def test_add_commit():
# Remove the record.
with db.begin():
db.delete(AccountType, AccountType.ID == account_type.ID)
db.delete(account_type)
def test_connection_executor_mysql_paramstyle():

View file

@ -24,7 +24,7 @@ def test_dependency_type_creation():
assert bool(dependency_type.ID)
assert dependency_type.Name == "Test Type"
with begin():
delete(DependencyType, DependencyType.ID == dependency_type.ID)
delete(dependency_type)
def test_dependency_type_null_name_uses_default():
@ -32,4 +32,4 @@ def test_dependency_type_null_name_uses_default():
dependency_type = create(DependencyType)
assert dependency_type.Name == str()
with begin():
delete(DependencyType, DependencyType.ID == dependency_type.ID)
delete(dependency_type)

View file

@ -2,6 +2,7 @@ from datetime import datetime
import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
from aurweb import asgi, db
@ -98,3 +99,8 @@ def test_query_notified(maintainer: User, package: Package):
query = db.query(Package).filter(Package.ID == package.ID).all()
query_notified = util.query_notified(query, maintainer)
assert query_notified[package.PackageBase.ID]
def test_pkgreq_by_id_not_found():
with pytest.raises(HTTPException):
util.get_pkgreq_by_id(0)

View file

@ -103,7 +103,7 @@ def test_ratelimit_db(get: mock.MagicMock, getboolean: mock.MagicMock,
# Delete the ApiRateLimit record.
with db.begin():
db.delete(ApiRateLimit)
db.delete(db.query(ApiRateLimit).first())
# Should be good to go again!
assert not check_ratelimit(request)

View file

@ -18,7 +18,7 @@ def test_relation_type_creation():
assert relation_type.Name == "test-relation"
with db.begin():
db.delete(RelationType, RelationType.ID == relation_type.ID)
db.delete(relation_type)
def test_relation_types():

View file

@ -18,7 +18,7 @@ def test_request_type_creation():
assert request_type.Name == "Test Request"
with db.begin():
db.delete(RequestType, RequestType.ID == request_type.ID)
db.delete(request_type)
def test_request_type_null_name_returns_empty_string():
@ -29,7 +29,7 @@ def test_request_type_null_name_returns_empty_string():
assert request_type.Name == str()
with db.begin():
db.delete(RequestType, RequestType.ID == request_type.ID)
db.delete(request_type)
def test_request_type_name_display():