fix: retry transactions who fail due to deadlocks

In my opinion, this kind of handling of transactions is pretty ugly.
The being said, we have issues with running into deadlocks on aur.al,
so this commit works against that immediate bug.

An ideal solution would be to deal with retrying transactions through
the `db.begin()` scope, so we wouldn't have to explicitly annotate
functions as "retry functions," which is what this commit does.

Closes #376

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2022-09-13 12:47:52 -07:00
parent f450b5dfc7
commit ec3152014b
No known key found for this signature in database
GPG key ID: F7E46DED420788F3
16 changed files with 241 additions and 82 deletions

View file

@ -96,6 +96,7 @@ class AnonymousUser:
class BasicAuthBackend(AuthenticationBackend): class BasicAuthBackend(AuthenticationBackend):
@db.async_retry_deadlock
async def authenticate(self, conn: HTTPConnection): async def authenticate(self, conn: HTTPConnection):
unauthenticated = (None, AnonymousUser()) unauthenticated = (None, AnonymousUser())
sid = conn.cookies.get("AURSID") sid = conn.cookies.get("AURSID")
@ -122,7 +123,6 @@ class BasicAuthBackend(AuthenticationBackend):
# At this point, we cannot have an invalid user if the record # At this point, we cannot have an invalid user if the record
# exists, due to ForeignKey constraints in the schema upheld # exists, due to ForeignKey constraints in the schema upheld
# by mysqlclient. # by mysqlclient.
with db.begin():
user = db.query(User).filter(User.ID == record.UsersID).first() user = db.query(User).filter(User.ID == record.UsersID).first()
user.nonce = util.make_nonce() user.nonce = util.make_nonce()
user.authenticated = True user.authenticated = True

View file

@ -161,6 +161,46 @@ def begin():
return get_session().begin() return get_session().begin()
def retry_deadlock(func):
from sqlalchemy.exc import OperationalError
def wrapper(*args, _i: int = 0, **kwargs):
# Retry 10 times, then raise the exception
# If we fail before the 10th, recurse into `wrapper`
# If we fail on the 10th, continue to throw the exception
limit = 10
try:
return func(*args, **kwargs)
except OperationalError as exc:
if _i < limit and "Deadlock found" in str(exc):
# Retry on deadlock by recursing into `wrapper`
return wrapper(*args, _i=_i + 1, **kwargs)
# Otherwise, just raise the exception
raise exc
return wrapper
def async_retry_deadlock(func):
from sqlalchemy.exc import OperationalError
async def wrapper(*args, _i: int = 0, **kwargs):
# Retry 10 times, then raise the exception
# If we fail before the 10th, recurse into `wrapper`
# If we fail on the 10th, continue to throw the exception
limit = 10
try:
return await func(*args, **kwargs)
except OperationalError as exc:
if _i < limit and "Deadlock found" in str(exc):
# Retry on deadlock by recursing into `wrapper`
return await wrapper(*args, _i=_i + 1, **kwargs)
# Otherwise, just raise the exception
raise exc
return wrapper
def get_sqlalchemy_url(): def get_sqlalchemy_url():
""" """
Build an SQLAlchemy URL for use with create_engine. Build an SQLAlchemy URL for use with create_engine.

View file

@ -151,7 +151,7 @@ class User(Base):
return has_credential(self, credential, approved) return has_credential(self, credential, approved)
def logout(self, request: Request): def logout(self, request: Request) -> None:
self.authenticated = False self.authenticated = False
if self.session: if self.session:
with db.begin(): with db.begin():

View file

@ -151,6 +151,7 @@ def close_pkgreq(
pkgreq.ClosedTS = now pkgreq.ClosedTS = now
@db.retry_deadlock
def handle_request( def handle_request(
request: Request, reqtype_id: int, pkgbase: PackageBase, target: PackageBase = None request: Request, reqtype_id: int, pkgbase: PackageBase, target: PackageBase = None
) -> list[notify.Notification]: ) -> list[notify.Notification]:
@ -239,6 +240,8 @@ def handle_request(
to_accept.append(pkgreq) to_accept.append(pkgreq)
# Update requests with their new status and closures. # Update requests with their new status and closures.
@db.retry_deadlock
def retry_closures():
with db.begin(): with db.begin():
util.apply_all( util.apply_all(
to_accept, to_accept,
@ -249,6 +252,8 @@ def handle_request(
lambda p: close_pkgreq(p, request.user, pkgbase, target, REJECTED_ID), lambda p: close_pkgreq(p, request.user, pkgbase, target, REJECTED_ID),
) )
retry_closures()
# Create RequestCloseNotifications for all requests involved. # Create RequestCloseNotifications for all requests involved.
for pkgreq in to_accept + to_reject: for pkgreq in to_accept + to_reject:
notif = notify.RequestCloseNotification( notif = notify.RequestCloseNotification(

View file

@ -99,7 +99,6 @@ def get_pkg_or_base(
:raises HTTPException: With status code 404 if record doesn't exist :raises HTTPException: With status code 404 if record doesn't exist
:return: {Package,PackageBase} instance :return: {Package,PackageBase} instance
""" """
with db.begin():
instance = db.query(cls).filter(cls.Name == name).first() instance = db.query(cls).filter(cls.Name == name).first()
if not instance: if not instance:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
@ -133,7 +132,6 @@ def updated_packages(limit: int = 0, cache_ttl: int = 600) -> list[models.Packag
# If we already have a cache, deserialize it and return. # If we already have a cache, deserialize it and return.
return orjson.loads(packages) return orjson.loads(packages)
with db.begin():
query = ( query = (
db.query(models.Package) db.query(models.Package)
.join(models.PackageBase) .join(models.PackageBase)

View file

@ -2,7 +2,7 @@ from fastapi import Request
from aurweb import db, logging, util from aurweb import db, logging, util
from aurweb.auth import creds from aurweb.auth import creds
from aurweb.models import PackageBase from aurweb.models import PackageBase, User
from aurweb.models.package_comaintainer import PackageComaintainer from aurweb.models.package_comaintainer import PackageComaintainer
from aurweb.models.package_notification import PackageNotification from aurweb.models.package_notification import PackageNotification
from aurweb.models.request_type import DELETION_ID, MERGE_ID, ORPHAN_ID from aurweb.models.request_type import DELETION_ID, MERGE_ID, ORPHAN_ID
@ -13,6 +13,12 @@ from aurweb.scripts import notify, popupdate
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@db.retry_deadlock
def _retry_notify(user: User, pkgbase: PackageBase) -> None:
with db.begin():
db.create(PackageNotification, PackageBase=pkgbase, User=user)
def pkgbase_notify_instance(request: Request, pkgbase: PackageBase) -> None: def pkgbase_notify_instance(request: Request, pkgbase: PackageBase) -> None:
notif = db.query( notif = db.query(
pkgbase.notifications.filter( pkgbase.notifications.filter(
@ -21,8 +27,13 @@ def pkgbase_notify_instance(request: Request, pkgbase: PackageBase) -> None:
).scalar() ).scalar()
has_cred = request.user.has_credential(creds.PKGBASE_NOTIFY) has_cred = request.user.has_credential(creds.PKGBASE_NOTIFY)
if has_cred and not notif: if has_cred and not notif:
_retry_notify(request.user, pkgbase)
@db.retry_deadlock
def _retry_unnotify(notif: PackageNotification, pkgbase: PackageBase) -> None:
with db.begin(): with db.begin():
db.create(PackageNotification, PackageBase=pkgbase, User=request.user) db.delete(notif)
def pkgbase_unnotify_instance(request: Request, pkgbase: PackageBase) -> None: def pkgbase_unnotify_instance(request: Request, pkgbase: PackageBase) -> None:
@ -31,8 +42,15 @@ def pkgbase_unnotify_instance(request: Request, pkgbase: PackageBase) -> None:
).first() ).first()
has_cred = request.user.has_credential(creds.PKGBASE_NOTIFY) has_cred = request.user.has_credential(creds.PKGBASE_NOTIFY)
if has_cred and notif: if has_cred and notif:
_retry_unnotify(notif, pkgbase)
@db.retry_deadlock
def _retry_unflag(pkgbase: PackageBase) -> None:
with db.begin(): with db.begin():
db.delete(notif) pkgbase.OutOfDateTS = None
pkgbase.Flagger = None
pkgbase.FlaggerComment = str()
def pkgbase_unflag_instance(request: Request, pkgbase: PackageBase) -> None: def pkgbase_unflag_instance(request: Request, pkgbase: PackageBase) -> None:
@ -42,20 +60,17 @@ def pkgbase_unflag_instance(request: Request, pkgbase: PackageBase) -> None:
+ [c.User for c in pkgbase.comaintainers], + [c.User for c in pkgbase.comaintainers],
) )
if has_cred: if has_cred:
with db.begin(): _retry_unflag(pkgbase)
pkgbase.OutOfDateTS = None
pkgbase.Flagger = None
pkgbase.FlaggerComment = str()
def pkgbase_disown_instance(request: Request, pkgbase: PackageBase) -> None: @db.retry_deadlock
disowner = request.user def _retry_disown(request: Request, pkgbase: PackageBase):
notifs = [notify.DisownNotification(disowner.ID, pkgbase.ID)] notifs: list[notify.Notification] = []
is_maint = disowner == pkgbase.Maintainer is_maint = request.user == pkgbase.Maintainer
comaint = pkgbase.comaintainers.filter( comaint = pkgbase.comaintainers.filter(
PackageComaintainer.User == disowner PackageComaintainer.User == request.user
).one_or_none() ).one_or_none()
is_comaint = comaint is not None is_comaint = comaint is not None
@ -85,38 +100,48 @@ def pkgbase_disown_instance(request: Request, pkgbase: PackageBase) -> None:
pkgbase.Maintainer = None pkgbase.Maintainer = None
db.delete_all(pkgbase.comaintainers) db.delete_all(pkgbase.comaintainers)
return notifs
def pkgbase_disown_instance(request: Request, pkgbase: PackageBase) -> None:
disowner = request.user
notifs = [notify.DisownNotification(disowner.ID, pkgbase.ID)]
notifs += _retry_disown(request, pkgbase)
util.apply_all(notifs, lambda n: n.send()) util.apply_all(notifs, lambda n: n.send())
def pkgbase_adopt_instance(request: Request, pkgbase: PackageBase) -> None: @db.retry_deadlock
def _retry_adopt(request: Request, pkgbase: PackageBase) -> None:
with db.begin(): with db.begin():
pkgbase.Maintainer = request.user pkgbase.Maintainer = request.user
def pkgbase_adopt_instance(request: Request, pkgbase: PackageBase) -> None:
_retry_adopt(request, pkgbase)
notif = notify.AdoptNotification(request.user.ID, pkgbase.ID) notif = notify.AdoptNotification(request.user.ID, pkgbase.ID)
notif.send() notif.send()
@db.retry_deadlock
def _retry_delete(pkgbase: PackageBase, comments: str) -> None:
with db.begin():
update_closure_comment(pkgbase, DELETION_ID, comments)
db.delete(pkgbase)
def pkgbase_delete_instance( def pkgbase_delete_instance(
request: Request, pkgbase: PackageBase, comments: str = str() request: Request, pkgbase: PackageBase, comments: str = str()
) -> list[notify.Notification]: ) -> list[notify.Notification]:
notif = notify.DeleteNotification(request.user.ID, pkgbase.ID) notif = notify.DeleteNotification(request.user.ID, pkgbase.ID)
notifs = handle_request(request, DELETION_ID, pkgbase) + [notif] notifs = handle_request(request, DELETION_ID, pkgbase) + [notif]
with db.begin(): _retry_delete(pkgbase, comments)
update_closure_comment(pkgbase, DELETION_ID, comments)
db.delete(pkgbase)
return notifs return notifs
def pkgbase_merge_instance( @db.retry_deadlock
request: Request, pkgbase: PackageBase, target: PackageBase, comments: str = str() def _retry_merge(pkgbase: PackageBase, target: PackageBase) -> None:
) -> None:
pkgbasename = str(pkgbase.Name)
# Create notifications.
notifs = handle_request(request, MERGE_ID, pkgbase, target)
# Target votes and notifications sets of user IDs that are # Target votes and notifications sets of user IDs that are
# looking to be migrated. # looking to be migrated.
target_votes = set(v.UsersID for v in target.package_votes) target_votes = set(v.UsersID for v in target.package_votes)
@ -146,6 +171,20 @@ def pkgbase_merge_instance(
db.delete(pkg) db.delete(pkg)
db.delete(pkgbase) db.delete(pkgbase)
def pkgbase_merge_instance(
request: Request,
pkgbase: PackageBase,
target: PackageBase,
comments: str = str(),
) -> None:
pkgbasename = str(pkgbase.Name)
# Create notifications.
notifs = handle_request(request, MERGE_ID, pkgbase, target)
_retry_merge(pkgbase, target)
# Log this out for accountability purposes. # Log this out for accountability purposes.
logger.info( logger.info(
f"Trusted User '{request.user.Username}' merged " f"Trusted User '{request.user.Username}' merged "

View file

@ -106,6 +106,7 @@ def remove_comaintainer(
return notif return notif
@db.retry_deadlock
def remove_comaintainers(pkgbase: PackageBase, usernames: list[str]) -> None: def remove_comaintainers(pkgbase: PackageBase, usernames: list[str]) -> None:
""" """
Remove comaintainers from `pkgbase`. Remove comaintainers from `pkgbase`.
@ -155,6 +156,7 @@ class NoopComaintainerNotification:
return return
@db.retry_deadlock
def add_comaintainer( def add_comaintainer(
pkgbase: PackageBase, comaintainer: User pkgbase: PackageBase, comaintainer: User
) -> notify.ComaintainerAddNotification: ) -> notify.ComaintainerAddNotification:

View file

@ -38,17 +38,26 @@ def _update_ratelimit_db(request: Request):
now = time.utcnow() now = time.utcnow()
time_to_delete = now - window_length time_to_delete = now - window_length
records = db.query(ApiRateLimit).filter(ApiRateLimit.WindowStart < time_to_delete) @db.retry_deadlock
def retry_delete(records: list[ApiRateLimit]) -> None:
with db.begin(): with db.begin():
db.delete_all(records) db.delete_all(records)
host = request.client.host records = db.query(ApiRateLimit).filter(ApiRateLimit.WindowStart < time_to_delete)
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first() retry_delete(records)
@db.retry_deadlock
def retry_create(record: ApiRateLimit, now: int, host: str) -> ApiRateLimit:
with db.begin(): with db.begin():
if not record: if not record:
record = db.create(ApiRateLimit, WindowStart=now, IP=host, Requests=1) record = db.create(ApiRateLimit, WindowStart=now, IP=host, Requests=1)
else: else:
record.Requests += 1 record.Requests += 1
return record
host = request.client.host
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
record = retry_create(record, now, host)
logger.debug(record.Requests) logger.debug(record.Requests)
return record return record

View file

@ -32,6 +32,7 @@ async def passreset(request: Request):
return render_template(request, "passreset.html", context) return render_template(request, "passreset.html", context)
@db.async_retry_deadlock
@router.post("/passreset", response_class=HTMLResponse) @router.post("/passreset", response_class=HTMLResponse)
@handle_form_exceptions @handle_form_exceptions
@requires_guest @requires_guest
@ -260,6 +261,7 @@ async def account_register(
return render_template(request, "register.html", context) return render_template(request, "register.html", context)
@db.async_retry_deadlock
@router.post("/register", response_class=HTMLResponse) @router.post("/register", response_class=HTMLResponse)
@handle_form_exceptions @handle_form_exceptions
@requires_guest @requires_guest
@ -344,10 +346,7 @@ async def account_register_post(
for k in keys: for k in keys:
pk = " ".join(k) pk = " ".join(k)
fprint = get_fingerprint(pk) fprint = get_fingerprint(pk)
with db.begin(): db.create(models.SSHPubKey, User=user, PubKey=pk, Fingerprint=fprint)
db.create(
models.SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fprint
)
# Send a reset key notification to the new user. # Send a reset key notification to the new user.
WelcomeNotification(user.ID).send() WelcomeNotification(user.ID).send()
@ -458,6 +457,8 @@ async def account_edit_post(
update.password, update.password,
] ]
# These update functions are all guarded by retry_deadlock;
# there's no need to guard this route itself.
for f in updates: for f in updates:
f(**args, request=request, user=user, context=context) f(**args, request=request, user=user, context=context)
@ -633,6 +634,7 @@ async def terms_of_service(request: Request):
return render_terms_of_service(request, context, accept_needed) return render_terms_of_service(request, context, accept_needed)
@db.async_retry_deadlock
@router.post("/tos") @router.post("/tos")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth

View file

@ -28,6 +28,11 @@ async def login_get(request: Request, next: str = "/"):
return await login_template(request, next) return await login_template(request, next)
@db.retry_deadlock
def _retry_login(request: Request, user: User, passwd: str, cookie_timeout: int) -> str:
return user.login(request, passwd, cookie_timeout)
@router.post("/login", response_class=HTMLResponse) @router.post("/login", response_class=HTMLResponse)
@handle_form_exceptions @handle_form_exceptions
@requires_guest @requires_guest
@ -48,13 +53,16 @@ async def login_post(
status_code=HTTPStatus.BAD_REQUEST, detail=_("Bad Referer header.") status_code=HTTPStatus.BAD_REQUEST, detail=_("Bad Referer header.")
) )
with db.begin():
user = ( user = (
db.query(User) db.query(User)
.filter(or_(User.Username == user, User.Email == user)) .filter(
or_(
User.Username == user,
User.Email == user,
)
)
.first() .first()
) )
if not user: if not user:
return await login_template(request, next, errors=["Bad username or password."]) return await login_template(request, next, errors=["Bad username or password."])
@ -62,7 +70,7 @@ async def login_post(
return await login_template(request, next, errors=["Account Suspended"]) return await login_template(request, next, errors=["Account Suspended"])
cookie_timeout = cookies.timeout(remember_me) cookie_timeout = cookies.timeout(remember_me)
sid = user.login(request, passwd, cookie_timeout) sid = _retry_login(request, user, passwd, cookie_timeout)
if not sid: if not sid:
return await login_template(request, next, errors=["Bad username or password."]) return await login_template(request, next, errors=["Bad username or password."])
@ -101,12 +109,17 @@ async def login_post(
return response return response
@db.retry_deadlock
def _retry_logout(request: Request) -> None:
request.user.logout(request)
@router.post("/logout") @router.post("/logout")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
async def logout(request: Request, next: str = Form(default="/")): async def logout(request: Request, next: str = Form(default="/")):
if request.user.is_authenticated(): if request.user.is_authenticated():
request.user.logout(request) _retry_logout(request)
# Use 303 since we may be handling a post request, that'll get it # Use 303 since we may be handling a post request, that'll get it
# to redirect to a get request. # to redirect to a get request.

View file

@ -35,6 +35,7 @@ async def favicon(request: Request):
return RedirectResponse("/static/images/favicon.ico") return RedirectResponse("/static/images/favicon.ico")
@db.async_retry_deadlock
@router.post("/language", response_class=RedirectResponse) @router.post("/language", response_class=RedirectResponse)
@handle_form_exceptions @handle_form_exceptions
async def language( async def language(

View file

@ -87,6 +87,7 @@ async def pkgbase_flag_comment(request: Request, name: str):
return render_template(request, "pkgbase/flag-comment.html", context) return render_template(request, "pkgbase/flag-comment.html", context)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/keywords") @router.post("/pkgbase/{name}/keywords")
@handle_form_exceptions @handle_form_exceptions
async def pkgbase_keywords( async def pkgbase_keywords(
@ -139,6 +140,7 @@ async def pkgbase_flag_get(request: Request, name: str):
return render_template(request, "pkgbase/flag.html", context) return render_template(request, "pkgbase/flag.html", context)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/flag") @router.post("/pkgbase/{name}/flag")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -170,6 +172,7 @@ async def pkgbase_flag_post(
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/comments") @router.post("/pkgbase/{name}/comments")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -279,6 +282,7 @@ async def pkgbase_comment_edit(
return render_template(request, "pkgbase/comments/edit.html", context) return render_template(request, "pkgbase/comments/edit.html", context)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/comments/{id}") @router.post("/pkgbase/{name}/comments/{id}")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -324,6 +328,7 @@ async def pkgbase_comment_post(
) )
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/comments/{id}/pin") @router.post("/pkgbase/{name}/comments/{id}/pin")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -362,6 +367,7 @@ async def pkgbase_comment_pin(
return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER) return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/comments/{id}/unpin") @router.post("/pkgbase/{name}/comments/{id}/unpin")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -399,6 +405,7 @@ async def pkgbase_comment_unpin(
return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER) return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/comments/{id}/delete") @router.post("/pkgbase/{name}/comments/{id}/delete")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -440,6 +447,7 @@ async def pkgbase_comment_delete(
return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER) return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/comments/{id}/undelete") @router.post("/pkgbase/{name}/comments/{id}/undelete")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -482,6 +490,7 @@ async def pkgbase_comment_undelete(
return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER) return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/vote") @router.post("/pkgbase/{name}/vote")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -501,6 +510,7 @@ async def pkgbase_vote(request: Request, name: str):
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/unvote") @router.post("/pkgbase/{name}/unvote")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -519,6 +529,7 @@ async def pkgbase_unvote(request: Request, name: str):
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/notify") @router.post("/pkgbase/{name}/notify")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -528,6 +539,7 @@ async def pkgbase_notify(request: Request, name: str):
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/unnotify") @router.post("/pkgbase/{name}/unnotify")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -537,6 +549,7 @@ async def pkgbase_unnotify(request: Request, name: str):
return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse(f"/pkgbase/{name}", status_code=HTTPStatus.SEE_OTHER)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/unflag") @router.post("/pkgbase/{name}/unflag")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -567,6 +580,7 @@ async def pkgbase_disown_get(
return render_template(request, "pkgbase/disown.html", context) return render_template(request, "pkgbase/disown.html", context)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/disown") @router.post("/pkgbase/{name}/disown")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -617,6 +631,7 @@ async def pkgbase_disown_post(
return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER) return RedirectResponse(next, status_code=HTTPStatus.SEE_OTHER)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/adopt") @router.post("/pkgbase/{name}/adopt")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -659,6 +674,7 @@ async def pkgbase_comaintainers(request: Request, name: str) -> Response:
return render_template(request, "pkgbase/comaintainers.html", context) return render_template(request, "pkgbase/comaintainers.html", context)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/comaintainers") @router.post("/pkgbase/{name}/comaintainers")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -715,6 +731,7 @@ async def pkgbase_request(
return render_template(request, "pkgbase/request.html", context) return render_template(request, "pkgbase/request.html", context)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/request") @router.post("/pkgbase/{name}/request")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -817,6 +834,7 @@ async def pkgbase_delete_get(
return render_template(request, "pkgbase/delete.html", context) return render_template(request, "pkgbase/delete.html", context)
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/delete") @router.post("/pkgbase/{name}/delete")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -889,6 +907,7 @@ async def pkgbase_merge_get(
) )
@db.async_retry_deadlock
@router.post("/pkgbase/{name}/merge") @router.post("/pkgbase/{name}/merge")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth

View file

@ -69,6 +69,7 @@ async def request_close(request: Request, id: int):
return render_template(request, "requests/close.html", context) return render_template(request, "requests/close.html", context)
@db.async_retry_deadlock
@router.post("/requests/{id}/close") @router.post("/requests/{id}/close")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth

View file

@ -217,6 +217,7 @@ async def trusted_user_proposal(request: Request, proposal: int):
return render_proposal(request, context, proposal, voteinfo, voters, vote) return render_proposal(request, context, proposal, voteinfo, voters, vote)
@db.async_retry_deadlock
@router.post("/tu/{proposal}") @router.post("/tu/{proposal}")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth
@ -267,13 +268,15 @@ async def trusted_user_proposal_post(
request, context, proposal, voteinfo, voters, vote, status_code=status_code request, context, proposal, voteinfo, voters, vote, status_code=status_code
) )
with db.begin():
if decision in {"Yes", "No", "Abstain"}: if decision in {"Yes", "No", "Abstain"}:
# Increment whichever decision was given to us. # Increment whichever decision was given to us.
setattr(voteinfo, decision, getattr(voteinfo, decision) + 1) setattr(voteinfo, decision, getattr(voteinfo, decision) + 1)
else: else:
return Response("Invalid 'decision' value.", status_code=HTTPStatus.BAD_REQUEST) return Response(
"Invalid 'decision' value.", status_code=HTTPStatus.BAD_REQUEST
)
with db.begin():
vote = db.create(models.TUVote, User=request.user, VoteInfo=voteinfo) vote = db.create(models.TUVote, User=request.user, VoteInfo=voteinfo)
context["error"] = "You've already voted for this proposal." context["error"] = "You've already voted for this proposal."
@ -301,6 +304,7 @@ async def trusted_user_addvote(
return render_template(request, "addvote.html", context) return render_template(request, "addvote.html", context)
@db.async_retry_deadlock
@router.post("/addvote") @router.post("/addvote")
@handle_form_exceptions @handle_form_exceptions
@requires_auth @requires_auth

View file

@ -8,6 +8,7 @@ from aurweb.models.ssh_pub_key import get_fingerprint
from aurweb.util import strtobool from aurweb.util import strtobool
@db.retry_deadlock
def simple( def simple(
U: str = str(), U: str = str(),
E: str = str(), E: str = str(),
@ -42,6 +43,7 @@ def simple(
user.OwnershipNotify = strtobool(ON) user.OwnershipNotify = strtobool(ON)
@db.retry_deadlock
def language( def language(
L: str = str(), L: str = str(),
request: Request = None, request: Request = None,
@ -55,6 +57,7 @@ def language(
context["language"] = L context["language"] = L
@db.retry_deadlock
def timezone( def timezone(
TZ: str = str(), TZ: str = str(),
request: Request = None, request: Request = None,
@ -68,6 +71,7 @@ def timezone(
context["language"] = TZ context["language"] = TZ
@db.retry_deadlock
def ssh_pubkey(PK: str = str(), user: models.User = None, **kwargs) -> None: def ssh_pubkey(PK: str = str(), user: models.User = None, **kwargs) -> None:
if not PK: if not PK:
# If no pubkey is provided, wipe out any pubkeys the user # If no pubkey is provided, wipe out any pubkeys the user
@ -101,12 +105,14 @@ def ssh_pubkey(PK: str = str(), user: models.User = None, **kwargs) -> None:
) )
@db.retry_deadlock
def account_type(T: int = None, user: models.User = None, **kwargs) -> None: def account_type(T: int = None, user: models.User = None, **kwargs) -> None:
if T is not None and (T := int(T)) != user.AccountTypeID: if T is not None and (T := int(T)) != user.AccountTypeID:
with db.begin(): with db.begin():
user.AccountTypeID = T user.AccountTypeID = T
@db.retry_deadlock
def password( def password(
P: str = str(), P: str = str(),
request: Request = None, request: Request = None,

View file

@ -5,6 +5,7 @@ import tempfile
from unittest import mock from unittest import mock
import pytest import pytest
from sqlalchemy.exc import OperationalError
import aurweb.config import aurweb.config
import aurweb.initdb import aurweb.initdb
@ -226,3 +227,22 @@ def test_name_without_pytest_current_test():
with mock.patch.dict("os.environ", {}, clear=True): with mock.patch.dict("os.environ", {}, clear=True):
dbname = aurweb.db.name() dbname = aurweb.db.name()
assert dbname == aurweb.config.get("database", "name") assert dbname == aurweb.config.get("database", "name")
def test_retry_deadlock():
@db.retry_deadlock
def func():
raise OperationalError("Deadlock found", tuple(), "")
with pytest.raises(OperationalError):
func()
@pytest.mark.asyncio
async def test_async_retry_deadlock():
@db.async_retry_deadlock
async def func():
raise OperationalError("Deadlock found", tuple(), "")
with pytest.raises(OperationalError):
await func()