mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
feat(python): handle RuntimeErrors raised through routes
This gets raised when a client closes a connection before receiving a valid response; this is not controllable from our side. Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
parent
bf371c447f
commit
9e7ae5904f
3 changed files with 58 additions and 11 deletions
|
@ -1,4 +1,3 @@
|
||||||
import asyncio
|
|
||||||
import http
|
import http
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
@ -21,7 +20,7 @@ import aurweb.config
|
||||||
import aurweb.logging
|
import aurweb.logging
|
||||||
import aurweb.pkgbase.util as pkgbaseutil
|
import aurweb.pkgbase.util as pkgbaseutil
|
||||||
|
|
||||||
from aurweb import prometheus, util
|
from aurweb import logging, prometheus, util
|
||||||
from aurweb.auth import BasicAuthBackend
|
from aurweb.auth import BasicAuthBackend
|
||||||
from aurweb.db import get_engine, query
|
from aurweb.db import get_engine, query
|
||||||
from aurweb.models import AcceptedTerm, Term
|
from aurweb.models import AcceptedTerm, Term
|
||||||
|
@ -30,6 +29,8 @@ from aurweb.prometheus import instrumentator
|
||||||
from aurweb.routers import APP_ROUTES
|
from aurweb.routers import APP_ROUTES
|
||||||
from aurweb.templates import make_context, render_template
|
from aurweb.templates import make_context, render_template
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
# Setup the FastAPI app.
|
# Setup the FastAPI app.
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
@ -132,9 +133,7 @@ async def add_security_headers(request: Request, call_next: typing.Callable):
|
||||||
RP: Referrer-Policy
|
RP: Referrer-Policy
|
||||||
XFO: X-Frame-Options
|
XFO: X-Frame-Options
|
||||||
"""
|
"""
|
||||||
response = asyncio.create_task(call_next(request))
|
response = await util.error_or_result(call_next, request)
|
||||||
await asyncio.wait({response}, return_when=asyncio.FIRST_COMPLETED)
|
|
||||||
response = response.result()
|
|
||||||
|
|
||||||
# Add CSP header.
|
# Add CSP header.
|
||||||
nonce = request.user.nonce
|
nonce = request.user.nonce
|
||||||
|
@ -174,9 +173,7 @@ async def check_terms_of_service(request: Request, call_next: typing.Callable):
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
"/tos", status_code=int(http.HTTPStatus.SEE_OTHER))
|
"/tos", status_code=int(http.HTTPStatus.SEE_OTHER))
|
||||||
|
|
||||||
task = asyncio.create_task(call_next(request))
|
return await util.error_or_result(call_next, request)
|
||||||
await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED)
|
|
||||||
return task.result()
|
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
|
@ -194,6 +191,4 @@ async def id_redirect_middleware(request: Request, call_next: typing.Callable):
|
||||||
path = request.url.path.rstrip('/')
|
path = request.url.path.rstrip('/')
|
||||||
return RedirectResponse(f"{path}/{id}{qs}")
|
return RedirectResponse(f"{path}/{id}{qs}")
|
||||||
|
|
||||||
task = asyncio.create_task(call_next(request))
|
return await util.error_or_result(call_next, request)
|
||||||
await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED)
|
|
||||||
return task.result()
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ import string
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from distutils.util import strtobool as _strtobool
|
from distutils.util import strtobool as _strtobool
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import Any, Callable, Dict, Iterable, Tuple
|
from typing import Any, Callable, Dict, Iterable, Tuple
|
||||||
from urllib.parse import urlencode, urlparse
|
from urllib.parse import urlencode, urlparse
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
@ -15,6 +16,7 @@ import fastapi
|
||||||
import pygit2
|
import pygit2
|
||||||
|
|
||||||
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
|
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
from jinja2 import pass_context
|
from jinja2 import pass_context
|
||||||
|
|
||||||
import aurweb.config
|
import aurweb.config
|
||||||
|
@ -207,3 +209,25 @@ def git_search(repo: pygit2.Repository, commit_hash: str) -> int:
|
||||||
break
|
break
|
||||||
prefixlen += 1
|
prefixlen += 1
|
||||||
return prefixlen
|
return prefixlen
|
||||||
|
|
||||||
|
|
||||||
|
async def error_or_result(next: Callable, *args, **kwargs) \
|
||||||
|
-> fastapi.Response:
|
||||||
|
"""
|
||||||
|
Try to return a response from `next`.
|
||||||
|
|
||||||
|
If RuntimeError is raised during next(...) execution, return a
|
||||||
|
500 with the exception's error as a JSONResponse.
|
||||||
|
|
||||||
|
:param next: Callable of the next fastapi route callback
|
||||||
|
:param *args: Variable number of arguments passed to the endpoint
|
||||||
|
:param **kwargs: Optional kwargs to pass to the endpoint
|
||||||
|
:return: next(...) retval; if an exc is raised: a 500 response
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await next(*args, **kwargs)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
logger.error(f"RuntimeError: {exc}")
|
||||||
|
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
|
return JSONResponse({"error": str(exc)}, status_code=status_code)
|
||||||
|
return response
|
||||||
|
|
|
@ -1,7 +1,16 @@
|
||||||
|
import json
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from http import HTTPStatus
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
import fastapi
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from aurweb import filters, util
|
from aurweb import filters, util
|
||||||
|
from aurweb.testing.requests import Request
|
||||||
|
|
||||||
|
|
||||||
def test_timestamp_to_datetime():
|
def test_timestamp_to_datetime():
|
||||||
|
@ -57,3 +66,22 @@ def test_git_search_double_commit():
|
||||||
# Locate the shortest prefix length that matches commit_hash.
|
# Locate the shortest prefix length that matches commit_hash.
|
||||||
prefixlen = util.git_search(repo, commit_hash)
|
prefixlen = util.git_search(repo, commit_hash)
|
||||||
assert prefixlen == 13
|
assert prefixlen == 13
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_or_result():
|
||||||
|
|
||||||
|
async def route(request: fastapi.Request):
|
||||||
|
raise RuntimeError("No response returned.")
|
||||||
|
|
||||||
|
response = await util.error_or_result(route, Request())
|
||||||
|
assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
|
|
||||||
|
data = json.loads(response.body)
|
||||||
|
assert data.get("error") == "No response returned."
|
||||||
|
|
||||||
|
async def good_route(request: fastapi.Request):
|
||||||
|
return JSONResponse()
|
||||||
|
|
||||||
|
response = await util.error_or_result(good_route, Request())
|
||||||
|
assert response.status_code == HTTPStatus.OK
|
||||||
|
|
Loading…
Add table
Reference in a new issue