feat(python): handle RuntimeErrors raised through routes

This gets raised when a client closes a connection before receiving
a valid response; this is not controllable from our side.

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2022-01-07 18:21:23 -08:00
parent bf371c447f
commit 9e7ae5904f
No known key found for this signature in database
GPG key ID: F7E46DED420788F3
3 changed files with 58 additions and 11 deletions

View file

@ -1,4 +1,3 @@
import asyncio
import http import http
import os import os
import re import re
@ -21,7 +20,7 @@ import aurweb.config
import aurweb.logging import aurweb.logging
import aurweb.pkgbase.util as pkgbaseutil import aurweb.pkgbase.util as pkgbaseutil
from aurweb import prometheus, util from aurweb import logging, prometheus, util
from aurweb.auth import BasicAuthBackend from aurweb.auth import BasicAuthBackend
from aurweb.db import get_engine, query from aurweb.db import get_engine, query
from aurweb.models import AcceptedTerm, Term from aurweb.models import AcceptedTerm, Term
@ -30,6 +29,8 @@ from aurweb.prometheus import instrumentator
from aurweb.routers import APP_ROUTES from aurweb.routers import APP_ROUTES
from aurweb.templates import make_context, render_template from aurweb.templates import make_context, render_template
logger = logging.get_logger(__name__)
# Setup the FastAPI app. # Setup the FastAPI app.
app = FastAPI() app = FastAPI()
@ -132,9 +133,7 @@ async def add_security_headers(request: Request, call_next: typing.Callable):
RP: Referrer-Policy RP: Referrer-Policy
XFO: X-Frame-Options XFO: X-Frame-Options
""" """
response = asyncio.create_task(call_next(request)) response = await util.error_or_result(call_next, request)
await asyncio.wait({response}, return_when=asyncio.FIRST_COMPLETED)
response = response.result()
# Add CSP header. # Add CSP header.
nonce = request.user.nonce nonce = request.user.nonce
@ -174,9 +173,7 @@ async def check_terms_of_service(request: Request, call_next: typing.Callable):
return RedirectResponse( return RedirectResponse(
"/tos", status_code=int(http.HTTPStatus.SEE_OTHER)) "/tos", status_code=int(http.HTTPStatus.SEE_OTHER))
task = asyncio.create_task(call_next(request)) return await util.error_or_result(call_next, request)
await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED)
return task.result()
@app.middleware("http") @app.middleware("http")
@ -194,6 +191,4 @@ async def id_redirect_middleware(request: Request, call_next: typing.Callable):
path = request.url.path.rstrip('/') path = request.url.path.rstrip('/')
return RedirectResponse(f"{path}/{id}{qs}") return RedirectResponse(f"{path}/{id}{qs}")
task = asyncio.create_task(call_next(request)) return await util.error_or_result(call_next, request)
await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED)
return task.result()

View file

@ -7,6 +7,7 @@ import string
from datetime import datetime from datetime import datetime
from distutils.util import strtobool as _strtobool from distutils.util import strtobool as _strtobool
from http import HTTPStatus
from typing import Any, Callable, Dict, Iterable, Tuple from typing import Any, Callable, Dict, Iterable, Tuple
from urllib.parse import urlencode, urlparse from urllib.parse import urlencode, urlparse
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
@ -15,6 +16,7 @@ import fastapi
import pygit2 import pygit2
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
from fastapi.responses import JSONResponse
from jinja2 import pass_context from jinja2 import pass_context
import aurweb.config import aurweb.config
@ -207,3 +209,25 @@ def git_search(repo: pygit2.Repository, commit_hash: str) -> int:
break break
prefixlen += 1 prefixlen += 1
return prefixlen return prefixlen
async def error_or_result(next: Callable, *args, **kwargs) \
-> fastapi.Response:
"""
Try to return a response from `next`.
If RuntimeError is raised during next(...) execution, return a
500 with the exception's error as a JSONResponse.
:param next: Callable of the next fastapi route callback
:param *args: Variable number of arguments passed to the endpoint
:param **kwargs: Optional kwargs to pass to the endpoint
:return: next(...) retval; if an exc is raised: a 500 response
"""
try:
response = await next(*args, **kwargs)
except RuntimeError as exc:
logger.error(f"RuntimeError: {exc}")
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
return JSONResponse({"error": str(exc)}, status_code=status_code)
return response

View file

@ -1,7 +1,16 @@
import json
from datetime import datetime from datetime import datetime
from http import HTTPStatus
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
import fastapi
import pytest
from fastapi.responses import JSONResponse
from aurweb import filters, util from aurweb import filters, util
from aurweb.testing.requests import Request
def test_timestamp_to_datetime(): def test_timestamp_to_datetime():
@ -57,3 +66,22 @@ def test_git_search_double_commit():
# Locate the shortest prefix length that matches commit_hash. # Locate the shortest prefix length that matches commit_hash.
prefixlen = util.git_search(repo, commit_hash) prefixlen = util.git_search(repo, commit_hash)
assert prefixlen == 13 assert prefixlen == 13
@pytest.mark.asyncio
async def test_error_or_result():
async def route(request: fastapi.Request):
raise RuntimeError("No response returned.")
response = await util.error_or_result(route, Request())
assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
data = json.loads(response.body)
assert data.get("error") == "No response returned."
async def good_route(request: fastapi.Request):
return JSONResponse()
response = await util.error_or_result(good_route, Request())
assert response.status_code == HTTPStatus.OK