feat(python): handle RuntimeErrors raised through routes

This gets raised when a client closes a connection before receiving
a valid response; this is not controllable from our side.

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2022-01-07 18:21:23 -08:00
parent bf371c447f
commit 9e7ae5904f
No known key found for this signature in database
GPG key ID: F7E46DED420788F3
3 changed files with 58 additions and 11 deletions

View file

@ -1,4 +1,3 @@
import asyncio
import http
import os
import re
@ -21,7 +20,7 @@ import aurweb.config
import aurweb.logging
import aurweb.pkgbase.util as pkgbaseutil
from aurweb import prometheus, util
from aurweb import logging, prometheus, util
from aurweb.auth import BasicAuthBackend
from aurweb.db import get_engine, query
from aurweb.models import AcceptedTerm, Term
@ -30,6 +29,8 @@ from aurweb.prometheus import instrumentator
from aurweb.routers import APP_ROUTES
from aurweb.templates import make_context, render_template
logger = logging.get_logger(__name__)
# Setup the FastAPI app.
app = FastAPI()
@ -132,9 +133,7 @@ async def add_security_headers(request: Request, call_next: typing.Callable):
RP: Referrer-Policy
XFO: X-Frame-Options
"""
response = asyncio.create_task(call_next(request))
await asyncio.wait({response}, return_when=asyncio.FIRST_COMPLETED)
response = response.result()
response = await util.error_or_result(call_next, request)
# Add CSP header.
nonce = request.user.nonce
@ -174,9 +173,7 @@ async def check_terms_of_service(request: Request, call_next: typing.Callable):
return RedirectResponse(
"/tos", status_code=int(http.HTTPStatus.SEE_OTHER))
task = asyncio.create_task(call_next(request))
await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED)
return task.result()
return await util.error_or_result(call_next, request)
@app.middleware("http")
@ -194,6 +191,4 @@ async def id_redirect_middleware(request: Request, call_next: typing.Callable):
path = request.url.path.rstrip('/')
return RedirectResponse(f"{path}/{id}{qs}")
task = asyncio.create_task(call_next(request))
await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED)
return task.result()
return await util.error_or_result(call_next, request)

View file

@ -7,6 +7,7 @@ import string
from datetime import datetime
from distutils.util import strtobool as _strtobool
from http import HTTPStatus
from typing import Any, Callable, Dict, Iterable, Tuple
from urllib.parse import urlencode, urlparse
from zoneinfo import ZoneInfo
@ -15,6 +16,7 @@ import fastapi
import pygit2
from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email
from fastapi.responses import JSONResponse
from jinja2 import pass_context
import aurweb.config
@ -207,3 +209,25 @@ def git_search(repo: pygit2.Repository, commit_hash: str) -> int:
break
prefixlen += 1
return prefixlen
async def error_or_result(next: Callable, *args, **kwargs) \
-> fastapi.Response:
"""
Try to return a response from `next`.
If RuntimeError is raised during next(...) execution, return a
500 with the exception's error as a JSONResponse.
:param next: Callable of the next fastapi route callback
:param *args: Variable number of arguments passed to the endpoint
:param **kwargs: Optional kwargs to pass to the endpoint
:return: next(...) retval; if an exc is raised: a 500 response
"""
try:
response = await next(*args, **kwargs)
except RuntimeError as exc:
logger.error(f"RuntimeError: {exc}")
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
return JSONResponse({"error": str(exc)}, status_code=status_code)
return response

View file

@ -1,7 +1,16 @@
import json
from datetime import datetime
from http import HTTPStatus
from zoneinfo import ZoneInfo
import fastapi
import pytest
from fastapi.responses import JSONResponse
from aurweb import filters, util
from aurweb.testing.requests import Request
def test_timestamp_to_datetime():
@ -57,3 +66,22 @@ def test_git_search_double_commit():
# Locate the shortest prefix length that matches commit_hash.
prefixlen = util.git_search(repo, commit_hash)
assert prefixlen == 13
@pytest.mark.asyncio
async def test_error_or_result():
async def route(request: fastapi.Request):
raise RuntimeError("No response returned.")
response = await util.error_or_result(route, Request())
assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
data = json.loads(response.body)
assert data.get("error") == "No response returned."
async def good_route(request: fastapi.Request):
return JSONResponse()
response = await util.error_or_result(good_route, Request())
assert response.status_code == HTTPStatus.OK