diff --git a/aurweb/auth/__init__.py b/aurweb/auth/__init__.py index 0c8bba69..b8056f91 100644 --- a/aurweb/auth/__init__.py +++ b/aurweb/auth/__init__.py @@ -96,6 +96,7 @@ class AnonymousUser: class BasicAuthBackend(AuthenticationBackend): + @db.async_retry_deadlock async def authenticate(self, conn: HTTPConnection): unauthenticated = (None, AnonymousUser()) sid = conn.cookies.get("AURSID") @@ -122,8 +123,7 @@ class BasicAuthBackend(AuthenticationBackend): # At this point, we cannot have an invalid user if the record # exists, due to ForeignKey constraints in the schema upheld # by mysqlclient. - with db.begin(): - user = db.query(User).filter(User.ID == record.UsersID).first() + user = db.query(User).filter(User.ID == record.UsersID).first() user.nonce = util.make_nonce() user.authenticated = True diff --git a/aurweb/db.py b/aurweb/db.py index 7425d928..ab0f80b8 100644 --- a/aurweb/db.py +++ b/aurweb/db.py @@ -161,6 +161,46 @@ def begin(): return get_session().begin() +def retry_deadlock(func): + from sqlalchemy.exc import OperationalError + + def wrapper(*args, _i: int = 0, **kwargs): + # Retry 10 times, then raise the exception + # If we fail before the 10th, recurse into `wrapper` + # If we fail on the 10th, continue to throw the exception + limit = 10 + try: + return func(*args, **kwargs) + except OperationalError as exc: + if _i < limit and "Deadlock found" in str(exc): + # Retry on deadlock by recursing into `wrapper` + return wrapper(*args, _i=_i + 1, **kwargs) + # Otherwise, just raise the exception + raise exc + + return wrapper + + +def async_retry_deadlock(func): + from sqlalchemy.exc import OperationalError + + async def wrapper(*args, _i: int = 0, **kwargs): + # Retry 10 times, then raise the exception + # If we fail before the 10th, recurse into `wrapper` + # If we fail on the 10th, continue to throw the exception + limit = 10 + try: + return await func(*args, **kwargs) + except OperationalError as exc: + if _i < limit and "Deadlock found" in str(exc): + # Retry on deadlock by recursing into `wrapper` + return await wrapper(*args, _i=_i + 1, **kwargs) + # Otherwise, just raise the exception + raise exc + + return wrapper + + def get_sqlalchemy_url(): """ Build an SQLAlchemy URL for use with create_engine. diff --git a/aurweb/models/user.py b/aurweb/models/user.py index 0404c77a..0d638677 100644 --- a/aurweb/models/user.py +++ b/aurweb/models/user.py @@ -151,7 +151,7 @@ class User(Base): return has_credential(self, credential, approved) - def logout(self, request: Request): + def logout(self, request: Request) -> None: self.authenticated = False if self.session: with db.begin(): diff --git a/aurweb/packages/requests.py b/aurweb/packages/requests.py index 7309a880..c09082f5 100644 --- a/aurweb/packages/requests.py +++ b/aurweb/packages/requests.py @@ -151,6 +151,7 @@ def close_pkgreq( pkgreq.ClosedTS = now +@db.retry_deadlock def handle_request( request: Request, reqtype_id: int, pkgbase: PackageBase, target: PackageBase = None ) -> list[notify.Notification]: @@ -239,15 +240,19 @@ def handle_request( to_accept.append(pkgreq) # Update requests with their new status and closures. - with db.begin(): - util.apply_all( - to_accept, - lambda p: close_pkgreq(p, request.user, pkgbase, target, ACCEPTED_ID), - ) - util.apply_all( - to_reject, - lambda p: close_pkgreq(p, request.user, pkgbase, target, REJECTED_ID), - ) + @db.retry_deadlock + def retry_closures(): + with db.begin(): + util.apply_all( + to_accept, + lambda p: close_pkgreq(p, request.user, pkgbase, target, ACCEPTED_ID), + ) + util.apply_all( + to_reject, + lambda p: close_pkgreq(p, request.user, pkgbase, target, REJECTED_ID), + ) + + retry_closures() # Create RequestCloseNotifications for all requests involved. for pkgreq in to_accept + to_reject: diff --git a/aurweb/packages/util.py b/aurweb/packages/util.py index 1ae7f9fe..b6ba7e20 100644 --- a/aurweb/packages/util.py +++ b/aurweb/packages/util.py @@ -99,8 +99,7 @@ def get_pkg_or_base( :raises HTTPException: With status code 404 if record doesn't exist :return: {Package,PackageBase} instance """ - with db.begin(): - instance = db.query(cls).filter(cls.Name == name).first() + instance = db.query(cls).filter(cls.Name == name).first() if not instance: raise HTTPException(status_code=HTTPStatus.NOT_FOUND) return instance @@ -133,16 +132,15 @@ def updated_packages(limit: int = 0, cache_ttl: int = 600) -> list[models.Packag # If we already have a cache, deserialize it and return. return orjson.loads(packages) - with db.begin(): - query = ( - db.query(models.Package) - .join(models.PackageBase) - .filter(models.PackageBase.PackagerUID.isnot(None)) - .order_by(models.PackageBase.ModifiedTS.desc()) - ) + query = ( + db.query(models.Package) + .join(models.PackageBase) + .filter(models.PackageBase.PackagerUID.isnot(None)) + .order_by(models.PackageBase.ModifiedTS.desc()) + ) - if limit: - query = query.limit(limit) + if limit: + query = query.limit(limit) packages = [] for pkg in query: diff --git a/aurweb/pkgbase/actions.py b/aurweb/pkgbase/actions.py index 9e7b0df5..a453cb36 100644 --- a/aurweb/pkgbase/actions.py +++ b/aurweb/pkgbase/actions.py @@ -2,7 +2,7 @@ from fastapi import Request from aurweb import db, logging, util from aurweb.auth import creds -from aurweb.models import PackageBase +from aurweb.models import PackageBase, User from aurweb.models.package_comaintainer import PackageComaintainer from aurweb.models.package_notification import PackageNotification from aurweb.models.request_type import DELETION_ID, MERGE_ID, ORPHAN_ID @@ -13,6 +13,12 @@ from aurweb.scripts import notify, popupdate logger = logging.get_logger(__name__) +@db.retry_deadlock +def _retry_notify(user: User, pkgbase: PackageBase) -> None: + with db.begin(): + db.create(PackageNotification, PackageBase=pkgbase, User=user) + + def pkgbase_notify_instance(request: Request, pkgbase: PackageBase) -> None: notif = db.query( pkgbase.notifications.filter( @@ -21,8 +27,13 @@ def pkgbase_notify_instance(request: Request, pkgbase: PackageBase) -> None: ).scalar() has_cred = request.user.has_credential(creds.PKGBASE_NOTIFY) if has_cred and not notif: - with db.begin(): - db.create(PackageNotification, PackageBase=pkgbase, User=request.user) + _retry_notify(request.user, pkgbase) + + +@db.retry_deadlock +def _retry_unnotify(notif: PackageNotification, pkgbase: PackageBase) -> None: + with db.begin(): + db.delete(notif) def pkgbase_unnotify_instance(request: Request, pkgbase: PackageBase) -> None: @@ -31,8 +42,15 @@ def pkgbase_unnotify_instance(request: Request, pkgbase: PackageBase) -> None: ).first() has_cred = request.user.has_credential(creds.PKGBASE_NOTIFY) if has_cred and notif: - with db.begin(): - db.delete(notif) + _retry_unnotify(notif, pkgbase) + + +@db.retry_deadlock +def _retry_unflag(pkgbase: PackageBase) -> None: + with db.begin(): + pkgbase.OutOfDateTS = None + pkgbase.Flagger = None + pkgbase.FlaggerComment = str() def pkgbase_unflag_instance(request: Request, pkgbase: PackageBase) -> None: @@ -42,20 +60,17 @@ def pkgbase_unflag_instance(request: Request, pkgbase: PackageBase) -> None: + [c.User for c in pkgbase.comaintainers], ) if has_cred: - with db.begin(): - pkgbase.OutOfDateTS = None - pkgbase.Flagger = None - pkgbase.FlaggerComment = str() + _retry_unflag(pkgbase) -def pkgbase_disown_instance(request: Request, pkgbase: PackageBase) -> None: - disowner = request.user - notifs = [notify.DisownNotification(disowner.ID, pkgbase.ID)] +@db.retry_deadlock +def _retry_disown(request: Request, pkgbase: PackageBase): + notifs: list[notify.Notification] = [] - is_maint = disowner == pkgbase.Maintainer + is_maint = request.user == pkgbase.Maintainer comaint = pkgbase.comaintainers.filter( - PackageComaintainer.User == disowner + PackageComaintainer.User == request.user ).one_or_none() is_comaint = comaint is not None @@ -85,38 +100,48 @@ def pkgbase_disown_instance(request: Request, pkgbase: PackageBase) -> None: pkgbase.Maintainer = None db.delete_all(pkgbase.comaintainers) + return notifs + + +def pkgbase_disown_instance(request: Request, pkgbase: PackageBase) -> None: + disowner = request.user + notifs = [notify.DisownNotification(disowner.ID, pkgbase.ID)] + notifs += _retry_disown(request, pkgbase) util.apply_all(notifs, lambda n: n.send()) -def pkgbase_adopt_instance(request: Request, pkgbase: PackageBase) -> None: +@db.retry_deadlock +def _retry_adopt(request: Request, pkgbase: PackageBase) -> None: with db.begin(): pkgbase.Maintainer = request.user + +def pkgbase_adopt_instance(request: Request, pkgbase: PackageBase) -> None: + _retry_adopt(request, pkgbase) notif = notify.AdoptNotification(request.user.ID, pkgbase.ID) notif.send() +@db.retry_deadlock +def _retry_delete(pkgbase: PackageBase, comments: str) -> None: + with db.begin(): + update_closure_comment(pkgbase, DELETION_ID, comments) + db.delete(pkgbase) + + def pkgbase_delete_instance( request: Request, pkgbase: PackageBase, comments: str = str() ) -> list[notify.Notification]: notif = notify.DeleteNotification(request.user.ID, pkgbase.ID) notifs = handle_request(request, DELETION_ID, pkgbase) + [notif] - with db.begin(): - update_closure_comment(pkgbase, DELETION_ID, comments) - db.delete(pkgbase) + _retry_delete(pkgbase, comments) return notifs -def pkgbase_merge_instance( - request: Request, pkgbase: PackageBase, target: PackageBase, comments: str = str() -) -> None: - pkgbasename = str(pkgbase.Name) - - # Create notifications. - notifs = handle_request(request, MERGE_ID, pkgbase, target) - +@db.retry_deadlock +def _retry_merge(pkgbase: PackageBase, target: PackageBase) -> None: # Target votes and notifications sets of user IDs that are # looking to be migrated. target_votes = set(v.UsersID for v in target.package_votes) @@ -146,6 +171,20 @@ def pkgbase_merge_instance( db.delete(pkg) db.delete(pkgbase) + +def pkgbase_merge_instance( + request: Request, + pkgbase: PackageBase, + target: PackageBase, + comments: str = str(), +) -> None: + pkgbasename = str(pkgbase.Name) + + # Create notifications. + notifs = handle_request(request, MERGE_ID, pkgbase, target) + + _retry_merge(pkgbase, target) + # Log this out for accountability purposes. logger.info( f"Trusted User '{request.user.Username}' merged " diff --git a/aurweb/pkgbase/util.py b/aurweb/pkgbase/util.py index 223c3013..968135d1 100644 --- a/aurweb/pkgbase/util.py +++ b/aurweb/pkgbase/util.py @@ -106,6 +106,7 @@ def remove_comaintainer( return notif +@db.retry_deadlock def remove_comaintainers(pkgbase: PackageBase, usernames: list[str]) -> None: """ Remove comaintainers from `pkgbase`. @@ -155,6 +156,7 @@ class NoopComaintainerNotification: return +@db.retry_deadlock def add_comaintainer( pkgbase: PackageBase, comaintainer: User ) -> notify.ComaintainerAddNotification: diff --git a/aurweb/ratelimit.py b/aurweb/ratelimit.py index cb08cdf5..97923a52 100644 --- a/aurweb/ratelimit.py +++ b/aurweb/ratelimit.py @@ -38,17 +38,26 @@ def _update_ratelimit_db(request: Request): now = time.utcnow() time_to_delete = now - window_length + @db.retry_deadlock + def retry_delete(records: list[ApiRateLimit]) -> None: + with db.begin(): + db.delete_all(records) + records = db.query(ApiRateLimit).filter(ApiRateLimit.WindowStart < time_to_delete) - with db.begin(): - db.delete_all(records) + retry_delete(records) + + @db.retry_deadlock + def retry_create(record: ApiRateLimit, now: int, host: str) -> ApiRateLimit: + with db.begin(): + if not record: + record = db.create(ApiRateLimit, WindowStart=now, IP=host, Requests=1) + else: + record.Requests += 1 + return record host = request.client.host record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first() - with db.begin(): - if not record: - record = db.create(ApiRateLimit, WindowStart=now, IP=host, Requests=1) - else: - record.Requests += 1 + record = retry_create(record, now, host) logger.debug(record.Requests) return record diff --git a/aurweb/routers/accounts.py b/aurweb/routers/accounts.py index db05955a..3937757a 100644 --- a/aurweb/routers/accounts.py +++ b/aurweb/routers/accounts.py @@ -32,6 +32,7 @@ async def passreset(request: Request): return render_template(request, "passreset.html", context) +@db.async_retry_deadlock @router.post("/passreset", response_class=HTMLResponse) @handle_form_exceptions @requires_guest @@ -260,6 +261,7 @@ async def account_register( return render_template(request, "register.html", context) +@db.async_retry_deadlock @router.post("/register", response_class=HTMLResponse) @handle_form_exceptions @requires_guest @@ -336,18 +338,15 @@ async def account_register_post( AccountType=atype, ) - # If a PK was given and either one does not exist or the given - # PK mismatches the existing user's SSHPubKey.PubKey. - if PK: - # Get the second element in the PK, which is the actual key. - keys = util.parse_ssh_keys(PK.strip()) - for k in keys: - pk = " ".join(k) - fprint = get_fingerprint(pk) - with db.begin(): - db.create( - models.SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fprint - ) + # If a PK was given and either one does not exist or the given + # PK mismatches the existing user's SSHPubKey.PubKey. + if PK: + # Get the second element in the PK, which is the actual key. + keys = util.parse_ssh_keys(PK.strip()) + for k in keys: + pk = " ".join(k) + fprint = get_fingerprint(pk) + db.create(models.SSHPubKey, User=user, PubKey=pk, Fingerprint=fprint) # Send a reset key notification to the new user. WelcomeNotification(user.ID).send() @@ -458,6 +457,8 @@ async def account_edit_post( update.password, ] + # These update functions are all guarded by retry_deadlock; + # there's no need to guard this route itself. for f in updates: f(**args, request=request, user=user, context=context) @@ -633,6 +634,7 @@ async def terms_of_service(request: Request): return render_terms_of_service(request, context, accept_needed) +@db.async_retry_deadlock @router.post("/tos") @handle_form_exceptions @requires_auth diff --git a/aurweb/routers/auth.py b/aurweb/routers/auth.py index 3f94952e..0e675559 100644 --- a/aurweb/routers/auth.py +++ b/aurweb/routers/auth.py @@ -28,6 +28,11 @@ async def login_get(request: Request, next: str = "/"): return await login_template(request, next) +@db.retry_deadlock +def _retry_login(request: Request, user: User, passwd: str, cookie_timeout: int) -> str: + return user.login(request, passwd, cookie_timeout) + + @router.post("/login", response_class=HTMLResponse) @handle_form_exceptions @requires_guest @@ -48,13 +53,16 @@ async def login_post( status_code=HTTPStatus.BAD_REQUEST, detail=_("Bad Referer header.") ) - with db.begin(): - user = ( - db.query(User) - .filter(or_(User.Username == user, User.Email == user)) - .first() + user = ( + db.query(User) + .filter( + or_( + User.Username == user, + User.Email == user, + ) ) - + .first() + ) if not user: return await login_template(request, next, errors=["Bad username or password."]) @@ -62,7 +70,7 @@ async def login_post( return await login_template(request, next, errors=["Account Suspended"]) cookie_timeout = cookies.timeout(remember_me) - sid = user.login(request, passwd, cookie_timeout) + sid = _retry_login(request, user, passwd, cookie_timeout) if not sid: return await login_template(request, next, errors=["Bad username or password."]) @@ -101,12 +109,17 @@ async def login_post( return response +@db.retry_deadlock +def _retry_logout(request: Request) -> None: + request.user.logout(request) + + @router.post("/logout") @handle_form_exceptions @requires_auth async def logout(request: Request, next: str = Form(default="/")): if request.user.is_authenticated(): - request.user.logout(request) + _retry_logout(request) # Use 303 since we may be handling a post request, that'll get it # to redirect to a get request. diff --git a/aurweb/routers/html.py b/aurweb/routers/html.py index 2148d535..da1ffd55 100644 --- a/aurweb/routers/html.py +++ b/aurweb/routers/html.py @@ -35,6 +35,7 @@ async def favicon(request: Request): return RedirectResponse("/static/images/favicon.ico") +@db.async_retry_deadlock @router.post("/language", response_class=RedirectResponse) @handle_form_exceptions async def language( diff --git a/aurweb/routers/pkgbase.py b/aurweb/routers/pkgbase.py index 076aec1e..3b1ab688 100644 --- a/aurweb/routers/pkgbase.py +++ b/aurweb/routers/pkgbase.py @@ -87,6 +87,7 @@ async def pkgbase_flag_comment(request: Request, name: str): return render_template(request, "pkgbase/flag-comment.html", context) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/keywords") @handle_form_exceptions async def pkgbase_keywords( @@ -139,6 +140,7 @@ async def pkgbase_flag_get(request: Request, name: str): return render_template(request, "pkgbase/flag.html", context) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/flag") @handle_form_exceptions @requires_auth @@ -170,6 +172,7 @@ async def pkgbase_flag_post( return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/comments") @handle_form_exceptions @requires_auth @@ -279,6 +282,7 @@ async def pkgbase_comment_edit( return render_template(request, "pkgbase/comments/edit.html", context) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/comments/{id}") @handle_form_exceptions @requires_auth @@ -324,6 +328,7 @@ async def pkgbase_comment_post( ) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/comments/{id}/pin") @handle_form_exceptions @requires_auth @@ -362,6 +367,7 @@ async def pkgbase_comment_pin( return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/comments/{id}/unpin") @handle_form_exceptions @requires_auth @@ -399,6 +405,7 @@ async def pkgbase_comment_unpin( return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/comments/{id}/delete") @handle_form_exceptions @requires_auth @@ -440,6 +447,7 @@ async def pkgbase_comment_delete( return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/comments/{id}/undelete") @handle_form_exceptions @requires_auth @@ -482,6 +490,7 @@ async def pkgbase_comment_undelete( return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/vote") @handle_form_exceptions @requires_auth @@ -501,6 +510,7 @@ async def pkgbase_vote(request: Request, name: str): return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/unvote") @handle_form_exceptions @requires_auth @@ -519,6 +529,7 @@ async def pkgbase_unvote(request: Request, name: str): return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/notify") @handle_form_exceptions @requires_auth @@ -528,6 +539,7 @@ async def pkgbase_notify(request: Request, name: str): return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/unnotify") @handle_form_exceptions @requires_auth @@ -537,6 +549,7 @@ async def pkgbase_unnotify(request: Request, name: str): return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/unflag") @handle_form_exceptions @requires_auth @@ -567,6 +580,7 @@ async def pkgbase_disown_get( return render_template(request, "pkgbase/disown.html", context) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/disown") @handle_form_exceptions @requires_auth @@ -617,6 +631,7 @@ async def pkgbase_disown_post( return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/adopt") @handle_form_exceptions @requires_auth @@ -659,6 +674,7 @@ async def pkgbase_comaintainers(request: Request, name: str) -> Response: return render_template(request, "pkgbase/comaintainers.html", context) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/comaintainers") @handle_form_exceptions @requires_auth @@ -715,6 +731,7 @@ async def pkgbase_request( return render_template(request, "pkgbase/request.html", context) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/request") @handle_form_exceptions @requires_auth @@ -817,6 +834,7 @@ async def pkgbase_delete_get( return render_template(request, "pkgbase/delete.html", context) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/delete") @handle_form_exceptions @requires_auth @@ -889,6 +907,7 @@ async def pkgbase_merge_get( ) +@db.async_retry_deadlock @router.post("/pkgbase/{name}/merge") @handle_form_exceptions @requires_auth diff --git a/aurweb/routers/requests.py b/aurweb/routers/requests.py index 51be6d2c..bf86bdcc 100644 --- a/aurweb/routers/requests.py +++ b/aurweb/routers/requests.py @@ -69,6 +69,7 @@ async def request_close(request: Request, id: int): return render_template(request, "requests/close.html", context) +@db.async_retry_deadlock @router.post("/requests/{id}/close") @handle_form_exceptions @requires_auth diff --git a/aurweb/routers/trusted_user.py b/aurweb/routers/trusted_user.py index a84bb6bd..37edb072 100644 --- a/aurweb/routers/trusted_user.py +++ b/aurweb/routers/trusted_user.py @@ -217,6 +217,7 @@ async def trusted_user_proposal(request: Request, proposal: int): return render_proposal(request, context, proposal, voteinfo, voters, vote) +@db.async_retry_deadlock @router.post("/tu/{proposal}") @handle_form_exceptions @requires_auth @@ -267,13 +268,15 @@ async def trusted_user_proposal_post( request, context, proposal, voteinfo, voters, vote, status_code=status_code ) - if decision in {"Yes", "No", "Abstain"}: - # Increment whichever decision was given to us. - setattr(voteinfo, decision, getattr(voteinfo, decision) + 1) - else: - return Response("Invalid 'decision' value.", status_code=HTTPStatus.BAD_REQUEST) - with db.begin(): + if decision in {"Yes", "No", "Abstain"}: + # Increment whichever decision was given to us. + setattr(voteinfo, decision, getattr(voteinfo, decision) + 1) + else: + return Response( + "Invalid 'decision' value.", status_code=HTTPStatus.BAD_REQUEST + ) + vote = db.create(models.TUVote, User=request.user, VoteInfo=voteinfo) context["error"] = "You've already voted for this proposal." @@ -301,6 +304,7 @@ async def trusted_user_addvote( return render_template(request, "addvote.html", context) +@db.async_retry_deadlock @router.post("/addvote") @handle_form_exceptions @requires_auth diff --git a/aurweb/users/update.py b/aurweb/users/update.py index 51f2d2e0..6bd4a295 100644 --- a/aurweb/users/update.py +++ b/aurweb/users/update.py @@ -8,6 +8,7 @@ from aurweb.models.ssh_pub_key import get_fingerprint from aurweb.util import strtobool +@db.retry_deadlock def simple( U: str = str(), E: str = str(), @@ -42,6 +43,7 @@ def simple( user.OwnershipNotify = strtobool(ON) +@db.retry_deadlock def language( L: str = str(), request: Request = None, @@ -55,6 +57,7 @@ def language( context["language"] = L +@db.retry_deadlock def timezone( TZ: str = str(), request: Request = None, @@ -68,6 +71,7 @@ def timezone( context["language"] = TZ +@db.retry_deadlock def ssh_pubkey(PK: str = str(), user: models.User = None, **kwargs) -> None: if not PK: # If no pubkey is provided, wipe out any pubkeys the user @@ -101,12 +105,14 @@ def ssh_pubkey(PK: str = str(), user: models.User = None, **kwargs) -> None: ) +@db.retry_deadlock def account_type(T: int = None, user: models.User = None, **kwargs) -> None: if T is not None and (T := int(T)) != user.AccountTypeID: with db.begin(): user.AccountTypeID = T +@db.retry_deadlock def password( P: str = str(), request: Request = None, diff --git a/test/test_db.py b/test/test_db.py index 8ac5607d..22dbdd36 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -5,6 +5,7 @@ import tempfile from unittest import mock import pytest +from sqlalchemy.exc import OperationalError import aurweb.config import aurweb.initdb @@ -226,3 +227,22 @@ def test_name_without_pytest_current_test(): with mock.patch.dict("os.environ", {}, clear=True): dbname = aurweb.db.name() assert dbname == aurweb.config.get("database", "name") + + +def test_retry_deadlock(): + @db.retry_deadlock + def func(): + raise OperationalError("Deadlock found", tuple(), "") + + with pytest.raises(OperationalError): + func() + + +@pytest.mark.asyncio +async def test_async_retry_deadlock(): + @db.async_retry_deadlock + async def func(): + raise OperationalError("Deadlock found", tuple(), "") + + with pytest.raises(OperationalError): + await func()