From 4103ab49c9c6922e89f89a25fc3b2b5b461c1bcb Mon Sep 17 00:00:00 2001 From: Kevin Morris Date: Sun, 14 Nov 2021 15:36:06 -0800 Subject: [PATCH] housekeep(fastapi): rework aurweb.db session API Changes: ------- - Add aurweb.db.get_session() - Returns aurweb.db's global `session` instance - Provides us a way to change the implementation of the session instance without interrupting user code. - Use aurweb.db.get_session() in session API methods - Add docstrings to session API methods - Refactor aurweb.db.delete - Normalize aurweb.db.delete to an alias of session.delete - Refresh instances in places we depend on their non-PK columns being up to date. Signed-off-by: Kevin Morris --- aurweb/auth.py | 8 ++- aurweb/db.py | 89 ++++++++++++++++++++------------- aurweb/models/ban.py | 9 ++-- aurweb/models/user.py | 2 +- aurweb/packages/util.py | 18 +++++-- aurweb/ratelimit.py | 4 +- aurweb/routers/accounts.py | 49 ++++++++---------- aurweb/routers/packages.py | 68 +++++++++++++------------ aurweb/rpc.py | 31 ++++++++++-- aurweb/scripts/popupdate.py | 2 +- aurweb/scripts/rendercomment.py | 12 +++-- aurweb/testing/__init__.py | 10 ++-- aurweb/users/__init__.py | 0 aurweb/users/util.py | 19 +++++++ aurweb/util.py | 1 + test/test_account_type.py | 4 +- test/test_db.py | 6 +-- test/test_dependency_type.py | 4 +- test/test_packages_util.py | 6 +++ test/test_ratelimit.py | 2 +- test/test_relation_type.py | 2 +- test/test_request_type.py | 4 +- 22 files changed, 212 insertions(+), 138 deletions(-) create mode 100644 aurweb/users/__init__.py create mode 100644 aurweb/users/util.py diff --git a/aurweb/auth.py b/aurweb/auth.py index 38754db0..98a43fd5 100644 --- a/aurweb/auth.py +++ b/aurweb/auth.py @@ -13,7 +13,7 @@ from starlette.requests import HTTPConnection import aurweb.config -from aurweb import l10n, util +from aurweb import db, l10n, util from aurweb.models import Session, User from aurweb.models.account_type import ACCOUNT_TYPE_ID from aurweb.templates import make_variable_context, render_template @@ -98,14 +98,12 @@ class AnonymousUser: class BasicAuthBackend(AuthenticationBackend): async def authenticate(self, conn: HTTPConnection): - from aurweb.db import session - sid = conn.cookies.get("AURSID") if not sid: return (None, AnonymousUser()) now_ts = datetime.utcnow().timestamp() - record = session.query(Session).filter( + record = db.query(Session).filter( and_(Session.SessionID == sid, Session.LastUpdateTS >= now_ts)).first() @@ -116,7 +114,7 @@ class BasicAuthBackend(AuthenticationBackend): # At this point, we cannot have an invalid user if the record # exists, due to ForeignKey constraints in the schema upheld # by mysqlclient. - user = session.query(User).filter(User.ID == record.UsersID).first() + user = db.query(User).filter(User.ID == record.UsersID).first() user.nonce = util.make_nonce() user.authenticated = True diff --git a/aurweb/db.py b/aurweb/db.py index c1e80751..39232d5a 100644 --- a/aurweb/db.py +++ b/aurweb/db.py @@ -2,10 +2,10 @@ import functools import math import re -from typing import Iterable +from typing import Iterable, NewType from sqlalchemy import event -from sqlalchemy.orm import scoped_session +from sqlalchemy.orm import Query, scoped_session import aurweb.config import aurweb.util @@ -22,6 +22,9 @@ session = None # Global introspected object memo. introspected = dict() +# A mocked up type. +Base = NewType("aurweb.models.declarative_base.Base", "Base") + def make_random_value(table: str, column: str): """ Generate a unique, random value for a string column in a table. @@ -58,55 +61,69 @@ def make_random_value(table: str, column: str): return string -def query(model, *args, **kwargs): - return session.query(model).filter(*args, **kwargs) +def get_session(): + """ Return aurweb.db's global session. """ + return session -def create(model, *args, **kwargs): - instance = model(*args, **kwargs) +def refresh(model: Base) -> Base: + """ Refresh the session's knowledge of `model`. """ + get_session().refresh(model) + return model + + +def query(Model: Base, *args, **kwargs) -> Query: + """ + Perform an ORM query against the database session. + + This method also runs Query.filter on the resulting model + query with *args and **kwargs. + + :param Model: Declarative ORM class + """ + return get_session().query(Model).filter(*args, **kwargs) + + +def create(Model: Base, *args, **kwargs) -> Base: + """ + Create a record and add() it to the database session. + + :param Model: Declarative ORM class + :return: Model instance + """ + instance = Model(*args, **kwargs) return add(instance) -def delete(model, *args, **kwargs): - instance = session.query(model).filter(*args, **kwargs) - for record in instance: - session.delete(record) +def delete(model: Base) -> None: + """ + Delete a set of records found by Query.filter(*args, **kwargs). + + :param Model: Declarative ORM class + """ + get_session().delete(model) -def delete_all(iterable: Iterable): - with begin(): - for obj in iterable: - session.delete(obj) +def delete_all(iterable: Iterable) -> None: + """ Delete each instance found in `iterable`. """ + session_ = get_session() + aurweb.util.apply_all(iterable, session_.delete) -def rollback(): - session.rollback() +def rollback() -> None: + """ Rollback the database session. """ + get_session().rollback() -def add(model): - session.add(model) +def add(model: Base) -> Base: + """ Add `model` to the database session. """ + get_session().add(model) return model def begin(): - """ Begin an SQLAlchemy SessionTransaction. - - This context is **required** to perform an modifications to the - database. - - Example: - - with db.begin(): - object = db.create(...) - # On __exit__, db.commit() is run. - - with db.begin(): - object = db.delete(...) - # On __exit__, db.commit() is run. - - :return: A new SessionTransaction based on session - """ - return session.begin() + """ Begin an SQLAlchemy SessionTransaction. """ + return get_session().begin() def get_sqlalchemy_url(): diff --git a/aurweb/models/ban.py b/aurweb/models/ban.py index a70be7b9..0fcb6d2e 100644 --- a/aurweb/models/ban.py +++ b/aurweb/models/ban.py @@ -1,6 +1,6 @@ from fastapi import Request -from aurweb import schema +from aurweb import db, schema from aurweb.models.declarative import Base @@ -10,11 +10,10 @@ class Ban(Base): __mapper_args__ = {"primary_key": [__table__.c.IPAddress]} def __init__(self, **kwargs): - self.IPAddress = kwargs.get("IPAddress") - self.BanTS = kwargs.get("BanTS") + super().__init__(**kwargs) def is_banned(request: Request): - from aurweb.db import session ip = request.client.host - return session.query(Ban).filter(Ban.IPAddress == ip).first() is not None + exists = db.query(Ban).filter(Ban.IPAddress == ip).exists() + return db.query(exists).scalar() diff --git a/aurweb/models/user.py b/aurweb/models/user.py index 8db34c38..43910db9 100644 --- a/aurweb/models/user.py +++ b/aurweb/models/user.py @@ -146,7 +146,7 @@ class User(Base): self.authenticated = False if self.session: with db.begin(): - db.session.delete(self.session) + db.delete(self.session) def is_trusted_user(self): return self.AccountType.ID in { diff --git a/aurweb/packages/util.py b/aurweb/packages/util.py index cdec26f3..78f5bf18 100644 --- a/aurweb/packages/util.py +++ b/aurweb/packages/util.py @@ -110,18 +110,26 @@ def get_pkg_or_base( raise HTTPException(status_code=HTTPStatus.NOT_FOUND) instance = db.query(cls).filter(cls.Name == name).first() - if cls == models.PackageBase and not instance: + if not instance: raise HTTPException(status_code=HTTPStatus.NOT_FOUND) - return instance + return db.refresh(instance) -def get_pkgbase_comment( - pkgbase: models.PackageBase, id: int) -> models.PackageComment: +def get_pkgbase_comment(pkgbase: models.PackageBase, id: int) \ + -> models.PackageComment: comment = pkgbase.comments.filter(models.PackageComment.ID == id).first() if not comment: raise HTTPException(status_code=HTTPStatus.NOT_FOUND) - return comment + return db.refresh(comment) + + +def get_pkgreq_by_id(id: int): + pkgreq = db.query(models.PackageRequest).filter( + models.PackageRequest.ID == id).first() + if not pkgreq: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND) + return db.refresh(pkgreq) @register_filter("out_of_date") diff --git a/aurweb/ratelimit.py b/aurweb/ratelimit.py index e306f7a7..a71cb1cc 100644 --- a/aurweb/ratelimit.py +++ b/aurweb/ratelimit.py @@ -40,8 +40,10 @@ def _update_ratelimit_db(request: Request): now = int(datetime.utcnow().timestamp()) time_to_delete = now - window_length + records = db.query(ApiRateLimit).filter( + ApiRateLimit.WindowStart < time_to_delete) with db.begin(): - db.delete(ApiRateLimit, ApiRateLimit.WindowStart < time_to_delete) + db.delete_all(records) host = request.client.host record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first() diff --git a/aurweb/routers/accounts.py b/aurweb/routers/accounts.py index 83c16ed0..aca322b5 100644 --- a/aurweb/routers/accounts.py +++ b/aurweb/routers/accounts.py @@ -4,7 +4,7 @@ import typing from datetime import datetime from http import HTTPStatus -from fastapi import APIRouter, Form, HTTPException, Request +from fastapi import APIRouter, Form, Request from fastapi.responses import HTMLResponse, RedirectResponse from sqlalchemy import and_, func, or_ @@ -20,6 +20,7 @@ from aurweb.models.account_type import (DEVELOPER, DEVELOPER_ID, TRUSTED_USER, T from aurweb.models.ssh_pub_key import get_fingerprint from aurweb.scripts.notify import ResetKeyNotification, WelcomeNotification from aurweb.templates import make_context, make_variable_context, render_template +from aurweb.users.util import get_user_by_name router = APIRouter() logger = logging.get_logger(__name__) @@ -49,6 +50,7 @@ async def passreset_post(request: Request, return render_template(request, "passreset.html", context, status_code=HTTPStatus.NOT_FOUND) + db.refresh(user) if resetkey: context["resetkey"] = resetkey @@ -83,7 +85,7 @@ async def passreset_post(request: Request, with db.begin(): user.ResetKey = str() if user.session: - db.session.delete(user.session) + db.delete(user.session) user.update_password(password) # Render ?step=complete. @@ -458,15 +460,15 @@ def cannot_edit(request, user): @router.get("/account/{username}/edit", response_class=HTMLResponse) @auth_required(True, redirect="/account/{username}") -async def account_edit(request: Request, - username: str): +async def account_edit(request: Request, username: str): user = db.query(models.User, models.User.Username == username).first() + response = cannot_edit(request, user) if response: return response context = await make_variable_context(request, "Accounts") - context["user"] = user + context["user"] = db.refresh(user) context = make_account_form_context(context, request, user, dict()) return render_template(request, "account/edit.html", context) @@ -497,16 +499,14 @@ async def account_edit_post(request: Request, ON: bool = Form(default=False), # Owner Notify T: int = Form(default=None), passwd: str = Form(default=str())): - from aurweb.db import session - - user = session.query(models.User).filter( + user = db.query(models.User).filter( models.User.Username == username).first() response = cannot_edit(request, user) if response: return response context = await make_variable_context(request, "Accounts") - context["user"] = user + context["user"] = db.refresh(user) args = dict(await request.form()) context = make_account_form_context(context, request, user, args) @@ -575,7 +575,7 @@ async def account_edit_post(request: Request, user.ssh_pub_key.Fingerprint = fingerprint elif user.ssh_pub_key: # Else, if the user has a public key already, delete it. - session.delete(user.ssh_pub_key) + db.delete(user.ssh_pub_key) if T and T != user.AccountTypeID: with db.begin(): @@ -617,27 +617,16 @@ account_template = ( status_code=HTTPStatus.UNAUTHORIZED) async def account(request: Request, username: str): _ = l10n.get_translator_for_request(request) - context = await make_variable_context(request, - _("Account") + " " + username) - - user = db.query(models.User, models.User.Username == username).first() - if not user: - raise HTTPException(status_code=HTTPStatus.NOT_FOUND) - - context["user"] = user - + context = await make_variable_context( + request, _("Account") + " " + username) + context["user"] = get_user_by_name(username) return render_template(request, "account/show.html", context) @router.get("/account/{username}/comments") @auth_required(redirect="/account/{username}/comments") async def account_comments(request: Request, username: str): - user = db.query(models.User).filter( - models.User.Username == username - ).first() - if not user: - raise HTTPException(status_code=HTTPStatus.NOT_FOUND) - + user = get_user_by_name(username) context = make_context(request, "Accounts") context["username"] = username context["comments"] = user.package_comments.order_by( @@ -725,7 +714,7 @@ async def accounts_post(request: Request, # Finally, order and truncate our users for the current page. users = query.order_by(*order_by).limit(pp).offset(offset) - context["users"] = users + context["users"] = util.apply_all(users, db.refresh) return render_template(request, "account/index.html", context) @@ -751,6 +740,9 @@ async def terms_of_service(request: Request): unaccepted = db.query(models.Term).filter( ~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID))).all() + for record in (diffs + unaccepted): + db.refresh(record) + # Translate the 'Terms of Service' part of our page title. _ = l10n.get_translator_for_request(request) title = f"AUR {_('Terms of Service')}" @@ -782,18 +774,21 @@ async def terms_of_service_post(request: Request, # We already did the database filters here, so let's just use # them instead of reiterating the process in terms_of_service. accept_needed = sorted(unaccepted + diffs) - return render_terms_of_service(request, context, accept_needed) + return render_terms_of_service( + request, context, util.apply_all(accept_needed, db.refresh)) with db.begin(): # For each term we found, query for the matching accepted term # and update its Revision to the term's current Revision. for term in diffs: + db.refresh(term) accepted_term = request.user.accepted_terms.filter( models.AcceptedTerm.TermsID == term.ID).first() accepted_term.Revision = term.Revision # For each term that was never accepted, accept it! for term in unaccepted: + db.refresh(term) db.create(models.AcceptedTerm, User=request.user, Term=term, Revision=term.Revision) diff --git a/aurweb/routers/packages.py b/aurweb/routers/packages.py index 0949909e..07e8af72 100644 --- a/aurweb/routers/packages.py +++ b/aurweb/routers/packages.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List from fastapi import APIRouter, Form, HTTPException, Query, Request, Response from fastapi.responses import JSONResponse, RedirectResponse -from sqlalchemy import and_, case +from sqlalchemy import case import aurweb.filters import aurweb.packages.util @@ -15,9 +15,9 @@ from aurweb.models.package_request import ACCEPTED_ID, PENDING_ID, REJECTED_ID from aurweb.models.relation_type import CONFLICTS_ID, PROVIDES_ID, REPLACES_ID from aurweb.models.request_type import DELETION_ID, MERGE, MERGE_ID from aurweb.packages.search import PackageSearch -from aurweb.packages.util import get_pkg_or_base, get_pkgbase_comment, query_notified, query_voted +from aurweb.packages.util import get_pkg_or_base, get_pkgbase_comment, get_pkgreq_by_id, query_notified, query_voted from aurweb.scripts import notify, popupdate -from aurweb.scripts.rendercomment import update_comment_render +from aurweb.scripts.rendercomment import update_comment_render_fastapi from aurweb.templates import make_context, make_variable_context, render_raw_template, render_template logger = logging.get_logger(__name__) @@ -92,7 +92,10 @@ async def packages_get(request: Request, context: Dict[str, Any], # Insert search results into the context. results = search.results() - context["packages"] = results.limit(per_page).offset(offset) + + packages = results.limit(per_page).offset(offset) + util.apply_all(packages, db.refresh) + context["packages"] = packages context["packages_voted"] = query_voted( context.get("packages"), request.user) context["packages_notified"] = query_notified( @@ -132,6 +135,7 @@ def create_request_if_missing(requests: List[models.PackageRequest], ClosedTS=now, Closer=user) requests.append(pkgreq) + return pkgreq def delete_package(deleter: models.User, package: models.Package): @@ -147,8 +151,9 @@ def delete_package(deleter: models.User, package: models.Package): ).first() with db.begin(): - create_request_if_missing( + pkgreq = create_request_if_missing( requests, reqtype, deleter, package) + db.refresh(pkgreq) bases_to_delete.append(package.PackageBase) @@ -171,8 +176,9 @@ def delete_package(deleter: models.User, package: models.Package): ) # Perform all the deletions. - db.delete_all([package]) - db.delete_all(bases_to_delete) + with db.begin(): + db.delete(package) + db.delete_all(bases_to_delete) # Send out all the notifications. util.apply_all(notifications, lambda n: n.send()) @@ -221,8 +227,7 @@ async def make_single_context(request: Request, async def package(request: Request, name: str) -> Response: # Get the Package. pkg = get_pkg_or_base(name, models.Package) - pkgbase = (get_pkg_or_base(name, models.PackageBase) - if not pkg else pkg.PackageBase) + pkgbase = pkg.PackageBase # Add our base information. context = await make_single_context(request, pkgbase) @@ -312,7 +317,7 @@ async def pkgbase_comments_post( db.create(models.PackageNotification, User=request.user, PackageBase=pkgbase) - update_comment_render(comment.ID) + update_comment_render_fastapi(comment) # Redirect to the pkgbase page. return RedirectResponse(f"/pkgbase/{pkgbase.Name}#comment-{comment.ID}", @@ -374,7 +379,7 @@ async def pkgbase_comment_post( db.create(models.PackageNotification, User=request.user, PackageBase=pkgbase) - update_comment_render(db_comment.ID) + update_comment_render_fastapi(db_comment) if not next: next = f"/pkgbase/{pkgbase.Name}" @@ -539,7 +544,7 @@ def remove_users(pkgbase, usernames): conn, comaintainer.User.ID, pkgbase.ID ) ) - db.session.delete(comaintainer) + db.delete(comaintainer) # Send out notifications if need be. for notify_ in notifications: @@ -679,14 +684,8 @@ async def requests(request: Request, @router.get("/pkgbase/{name}/request") @auth_required(True, redirect="/pkgbase/{name}/request") async def package_request(request: Request, name: str): + pkgbase = get_pkg_or_base(name, models.PackageBase) context = await make_variable_context(request, "Submit Request") - - pkgbase = db.query(models.PackageBase).filter( - models.PackageBase.Name == name).first() - - if not pkgbase: - raise HTTPException(status_code=HTTPStatus.NOT_FOUND) - context["pkgbase"] = pkgbase return render_template(request, "pkgbase/request.html", context) @@ -729,6 +728,7 @@ async def pkgbase_request_post(request: Request, name: str, ] return render_template(request, "pkgbase/request.html", context) + db.refresh(target) if target.ID == pkgbase.ID: # TODO: This error needs to be translated. context["errors"] = [ @@ -767,8 +767,7 @@ async def pkgbase_request_post(request: Request, name: str, @router.get("/requests/{id}/close") @auth_required(True, redirect="/requests/{id}/close") async def requests_close(request: Request, id: int): - pkgreq = db.query(models.PackageRequest).filter( - models.PackageRequest.ID == id).first() + pkgreq = get_pkgreq_by_id(id) if not request.user.is_elevated() and request.user != pkgreq.User: # Request user doesn't have permission here: redirect to '/'. return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER) @@ -783,8 +782,7 @@ async def requests_close(request: Request, id: int): async def requests_close_post(request: Request, id: int, reason: int = Form(default=0), comments: str = Form(default=str())): - pkgreq = db.query(models.PackageRequest).filter( - models.PackageRequest.ID == id).first() + pkgreq = get_pkgreq_by_id(id) if not request.user.is_elevated() and request.user != pkgreq.User: # Request user doesn't have permission here: redirect to '/'. return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER) @@ -823,13 +821,17 @@ async def pkgbase_keywords(request: Request, name: str, keywords = set(keywords.split(" ")) # Delete all keywords which are not supplied by the user. - with db.begin(): - db.delete(models.PackageKeyword, - and_(models.PackageKeyword.PackageBaseID == pkgbase.ID, - ~models.PackageKeyword.Keyword.in_(keywords))) + other_keywords = pkgbase.keywords.filter( + ~models.PackageKeyword.Keyword.in_(keywords)) + other_keyword_strings = [kwd.Keyword for kwd in other_keywords] - existing_keywords = set(kwd.Keyword for kwd in pkgbase.keywords.all()) + existing_keywords = set( + kwd.Keyword for kwd in + pkgbase.keywords.filter( + ~models.PackageKeyword.Keyword.in_(other_keyword_strings)) + ) with db.begin(): + db.delete_all(other_keywords) for keyword in keywords.difference(existing_keywords): db.create(models.PackageKeyword, PackageBase=pkgbase, @@ -940,7 +942,7 @@ def pkgbase_unnotify_instance(request: Request, pkgbase: models.PackageBase): has_cred = request.user.has_credential("CRED_PKGBASE_NOTIFY") if has_cred and notif: with db.begin(): - db.session.delete(notif) + db.delete(notif) @router.post("/pkgbase/{name}/unnotify") @@ -988,7 +990,7 @@ async def pkgbase_unvote(request: Request, name: str): has_cred = request.user.has_credential("CRED_PKGBASE_VOTE") if has_cred and vote: with db.begin(): - db.session.delete(vote) + db.delete(vote) # Update NumVotes/Popularity. conn = db.ConnectionExecutor(db.get_engine().raw_connection()) @@ -1015,7 +1017,7 @@ def pkgbase_disown_instance(request: Request, pkgbase: models.PackageBase): if co: with db.begin(): pkgbase.Maintainer = co.User - db.session.delete(co) + db.delete(co) else: pkgbase.Maintainer = None @@ -1463,8 +1465,8 @@ def pkgbase_merge_instance(request: Request, pkgbase: models.PackageBase, with db.begin(): # Delete pkgbase and its packages now that everything's merged. for pkg in pkgbase.packages: - db.session.delete(pkg) - db.session.delete(pkgbase) + db.delete(pkg) + db.delete(pkgbase) # Accept merge requests related to this pkgbase and target. for pkgreq in requests: diff --git a/aurweb/rpc.py b/aurweb/rpc.py index 4ab005af..03662790 100644 --- a/aurweb/rpc.py +++ b/aurweb/rpc.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Any, Dict, List +from typing import Any, Callable, Dict, List, NewType from sqlalchemy import and_ @@ -25,6 +25,10 @@ REL_TYPES = { } +DataGenerator = NewType("DataGenerator", + Callable[[models.Package], Dict[str, Any]]) + + class RPCError(Exception): pass @@ -188,15 +192,32 @@ class RPC: self._update_json_relations(package, data) return data - def _handle_multiinfo_type(self, args: List[str] = [], **kwargs): + def _assemble_json_data(self, packages: List[models.Package], + data_generator: DataGenerator) \ + -> List[Dict[str, Any]]: + """ + Assemble JSON data out of a list of packages. + + :param packages: A list of Package instances or a Package ORM query + :param data_generator: Generator callable of single-Package JSON data + """ + output = [] + for pkg in packages: + db.refresh(pkg) + output.append(data_generator(pkg)) + return output + + def _handle_multiinfo_type(self, args: List[str] = [], **kwargs) \ + -> List[Dict[str, Any]]: self._enforce_args(args) args = set(args) packages = db.query(models.Package).filter( models.Package.Name.in_(args)) - return [self._get_info_json_data(pkg) for pkg in packages] + return self._assemble_json_data(packages, self._get_info_json_data) def _handle_search_type(self, by: str = defaults.RPC_SEARCH_BY, - args: List[str] = []): + args: List[str] = []) \ + -> List[Dict[str, Any]]: # If `by` isn't maintainer and we don't have any args, raise an error. # In maintainer's case, return all orphans if there are no args, # so we need args to pass through to the handler without errors. @@ -212,7 +233,7 @@ class RPC: max_results = config.getint("options", "max_rpc_results") results = search.results().limit(max_results) - return [self._get_json_data(pkg) for pkg in results] + return self._assemble_json_data(results, self._get_json_data) def _handle_msearch_type(self, args: List[str] = [], **kwargs): return self._handle_search_type(by="m", args=args) diff --git a/aurweb/scripts/popupdate.py b/aurweb/scripts/popupdate.py index fa82208d..db4ba170 100755 --- a/aurweb/scripts/popupdate.py +++ b/aurweb/scripts/popupdate.py @@ -29,7 +29,7 @@ def run_single(conn, pkgbase): conn.commit() conn.close() - aurweb.db.session.refresh(pkgbase) + aurweb.db.refresh(pkgbase) def main(): diff --git a/aurweb/scripts/rendercomment.py b/aurweb/scripts/rendercomment.py index a00448d8..efa5357f 100755 --- a/aurweb/scripts/rendercomment.py +++ b/aurweb/scripts/rendercomment.py @@ -129,9 +129,14 @@ def save_rendered_comment(conn, commentid, html): [html, commentid]) -def update_comment_render(commentid): - conn = aurweb.db.Connection() +def update_comment_render_fastapi(comment): + conn = aurweb.db.ConnectionExecutor( + aurweb.db.get_engine().raw_connection()) + update_comment_render(conn, comment.ID) + aurweb.db.refresh(comment) + +def update_comment_render(conn, commentid): text, pkgbase = get_comment(conn, commentid) html = markdown.markdown(text, extensions=[ 'fenced_code', @@ -152,7 +157,8 @@ def update_comment_render(commentid): def main(): commentid = int(sys.argv[1]) - update_comment_render(commentid) + conn = aurweb.db.Connection() + update_comment_render(conn, commentid) if __name__ == '__main__': diff --git a/aurweb/testing/__init__.py b/aurweb/testing/__init__.py index 65d34253..2dd377e1 100644 --- a/aurweb/testing/__init__.py +++ b/aurweb/testing/__init__.py @@ -19,7 +19,7 @@ def references_graph(table): "regexp_1": r'(?i)\s+references\s+("|\')?', "regexp_2": r'("|\')?\s*\(', } - cursor = aurweb.db.session.execute(query, params=params) + cursor = aurweb.db.get_session().execute(query, params=params) return [row[0] for row in cursor.fetchall()] @@ -51,7 +51,7 @@ def setup_test_db(*args): db_backend = aurweb.config.get("database", "backend") if db_backend != "sqlite": # pragma: no cover - aurweb.db.session.execute("SET FOREIGN_KEY_CHECKS = 0") + aurweb.db.get_session().execute("SET FOREIGN_KEY_CHECKS = 0") else: # We're using sqlite, setup tables to be deleted without violating # foreign key constraints by graphing references. @@ -59,10 +59,10 @@ def setup_test_db(*args): references_graph(table) for table in tables)) for table in tables: - aurweb.db.session.execute(f"DELETE FROM {table}") + aurweb.db.get_session().execute(f"DELETE FROM {table}") if db_backend != "sqlite": # pragma: no cover - aurweb.db.session.execute("SET FOREIGN_KEY_CHECKS = 1") + aurweb.db.get_session().execute("SET FOREIGN_KEY_CHECKS = 1") # Expunge all objects from SQLAlchemy's IdentityMap. - aurweb.db.session.expunge_all() + aurweb.db.get_session().expunge_all() diff --git a/aurweb/users/__init__.py b/aurweb/users/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aurweb/users/util.py b/aurweb/users/util.py new file mode 100644 index 00000000..e9635f08 --- /dev/null +++ b/aurweb/users/util.py @@ -0,0 +1,19 @@ +from http import HTTPStatus + +from fastapi import HTTPException + +from aurweb import db +from aurweb.models import User + + +def get_user_by_name(username: str) -> User: + """ + Query a user by its username. + + :param username: User.Username + :return: User instance + """ + user = db.query(User).filter(User.Username == username).first() + if not user: + raise HTTPException(status_code=int(HTTPStatus.NOT_FOUND)) + return db.refresh(user) diff --git a/aurweb/util.py b/aurweb/util.py index 88142cbc..1c2042fa 100644 --- a/aurweb/util.py +++ b/aurweb/util.py @@ -155,6 +155,7 @@ def get_ssh_fingerprints(): def apply_all(iterable: Iterable, fn: Callable): for item in iterable: fn(item) + return iterable def sanitize_params(offset: str, per_page: str) -> Tuple[int, int]: diff --git a/test/test_account_type.py b/test/test_account_type.py index 86e68253..12472348 100644 --- a/test/test_account_type.py +++ b/test/test_account_type.py @@ -20,7 +20,7 @@ def setup(): yield account_type with begin(): - delete(AccountType, AccountType.ID == account_type.ID) + delete(account_type) def test_account_type(): @@ -50,4 +50,4 @@ def test_user_account_type_relationship(): # This must be deleted here to avoid foreign key issues when # deleting the temporary AccountType in the fixture. with begin(): - delete(User, User.ID == user.ID) + delete(user) diff --git a/test/test_db.py b/test/test_db.py index 7798d2f6..8283a957 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -279,13 +279,13 @@ def test_connection_execute_paramstyle_unsupported(): def test_create_delete(): with db.begin(): - db.create(AccountType, AccountType="test") + account_type = db.create(AccountType, AccountType="test") record = db.query(AccountType, AccountType.AccountType == "test").first() assert record is not None with db.begin(): - db.delete(AccountType, AccountType.AccountType == "test") + db.delete(account_type) record = db.query(AccountType, AccountType.AccountType == "test").first() assert record is None @@ -306,7 +306,7 @@ def test_add_commit(): # Remove the record. with db.begin(): - db.delete(AccountType, AccountType.ID == account_type.ID) + db.delete(account_type) def test_connection_executor_mysql_paramstyle(): diff --git a/test/test_dependency_type.py b/test/test_dependency_type.py index 4d555123..cb8dece4 100644 --- a/test/test_dependency_type.py +++ b/test/test_dependency_type.py @@ -24,7 +24,7 @@ def test_dependency_type_creation(): assert bool(dependency_type.ID) assert dependency_type.Name == "Test Type" with begin(): - delete(DependencyType, DependencyType.ID == dependency_type.ID) + delete(dependency_type) def test_dependency_type_null_name_uses_default(): @@ -32,4 +32,4 @@ def test_dependency_type_null_name_uses_default(): dependency_type = create(DependencyType) assert dependency_type.Name == str() with begin(): - delete(DependencyType, DependencyType.ID == dependency_type.ID) + delete(dependency_type) diff --git a/test/test_packages_util.py b/test/test_packages_util.py index 1396734b..622c08c2 100644 --- a/test/test_packages_util.py +++ b/test/test_packages_util.py @@ -2,6 +2,7 @@ from datetime import datetime import pytest +from fastapi import HTTPException from fastapi.testclient import TestClient from aurweb import asgi, db @@ -98,3 +99,8 @@ def test_query_notified(maintainer: User, package: Package): query = db.query(Package).filter(Package.ID == package.ID).all() query_notified = util.query_notified(query, maintainer) assert query_notified[package.PackageBase.ID] + + +def test_pkgreq_by_id_not_found(): + with pytest.raises(HTTPException): + util.get_pkgreq_by_id(0) diff --git a/test/test_ratelimit.py b/test/test_ratelimit.py index 2634b714..0a72a7e4 100644 --- a/test/test_ratelimit.py +++ b/test/test_ratelimit.py @@ -103,7 +103,7 @@ def test_ratelimit_db(get: mock.MagicMock, getboolean: mock.MagicMock, # Delete the ApiRateLimit record. with db.begin(): - db.delete(ApiRateLimit) + db.delete(db.query(ApiRateLimit).first()) # Should be good to go again! assert not check_ratelimit(request) diff --git a/test/test_relation_type.py b/test/test_relation_type.py index fbc22c71..d2dabceb 100644 --- a/test/test_relation_type.py +++ b/test/test_relation_type.py @@ -18,7 +18,7 @@ def test_relation_type_creation(): assert relation_type.Name == "test-relation" with db.begin(): - db.delete(RelationType, RelationType.ID == relation_type.ID) + db.delete(relation_type) def test_relation_types(): diff --git a/test/test_request_type.py b/test/test_request_type.py index 8d21c2d9..0db24921 100644 --- a/test/test_request_type.py +++ b/test/test_request_type.py @@ -18,7 +18,7 @@ def test_request_type_creation(): assert request_type.Name == "Test Request" with db.begin(): - db.delete(RequestType, RequestType.ID == request_type.ID) + db.delete(request_type) def test_request_type_null_name_returns_empty_string(): @@ -29,7 +29,7 @@ def test_request_type_null_name_returns_empty_string(): assert request_type.Name == str() with db.begin(): - db.delete(RequestType, RequestType.ID == request_type.ID) + db.delete(request_type) def test_request_type_name_display():