From 9e7ae5904f45fd22bed6d871256080458fde966c Mon Sep 17 00:00:00 2001 From: Kevin Morris Date: Fri, 7 Jan 2022 18:21:23 -0800 Subject: [PATCH] feat(python): handle RuntimeErrors raised through routes This gets raised when a client closes a connection before receiving a valid response; this is not controllable from our side. Signed-off-by: Kevin Morris --- aurweb/asgi.py | 17 ++++++----------- aurweb/util.py | 24 ++++++++++++++++++++++++ test/test_util.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 11 deletions(-) diff --git a/aurweb/asgi.py b/aurweb/asgi.py index b55eada1..cbb84f8a 100644 --- a/aurweb/asgi.py +++ b/aurweb/asgi.py @@ -1,4 +1,3 @@ -import asyncio import http import os import re @@ -21,7 +20,7 @@ import aurweb.config import aurweb.logging import aurweb.pkgbase.util as pkgbaseutil -from aurweb import prometheus, util +from aurweb import logging, prometheus, util from aurweb.auth import BasicAuthBackend from aurweb.db import get_engine, query from aurweb.models import AcceptedTerm, Term @@ -30,6 +29,8 @@ from aurweb.prometheus import instrumentator from aurweb.routers import APP_ROUTES from aurweb.templates import make_context, render_template +logger = logging.get_logger(__name__) + # Setup the FastAPI app. app = FastAPI() @@ -132,9 +133,7 @@ async def add_security_headers(request: Request, call_next: typing.Callable): RP: Referrer-Policy XFO: X-Frame-Options """ - response = asyncio.create_task(call_next(request)) - await asyncio.wait({response}, return_when=asyncio.FIRST_COMPLETED) - response = response.result() + response = await util.error_or_result(call_next, request) # Add CSP header. nonce = request.user.nonce @@ -174,9 +173,7 @@ async def check_terms_of_service(request: Request, call_next: typing.Callable): return RedirectResponse( "/tos", status_code=int(http.HTTPStatus.SEE_OTHER)) - task = asyncio.create_task(call_next(request)) - await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED) - return task.result() + return await util.error_or_result(call_next, request) @app.middleware("http") @@ -194,6 +191,4 @@ async def id_redirect_middleware(request: Request, call_next: typing.Callable): path = request.url.path.rstrip('/') return RedirectResponse(f"{path}/{id}{qs}") - task = asyncio.create_task(call_next(request)) - await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED) - return task.result() + return await util.error_or_result(call_next, request) diff --git a/aurweb/util.py b/aurweb/util.py index 10a7953b..0eb2671e 100644 --- a/aurweb/util.py +++ b/aurweb/util.py @@ -7,6 +7,7 @@ import string from datetime import datetime from distutils.util import strtobool as _strtobool +from http import HTTPStatus from typing import Any, Callable, Dict, Iterable, Tuple from urllib.parse import urlencode, urlparse from zoneinfo import ZoneInfo @@ -15,6 +16,7 @@ import fastapi import pygit2 from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email +from fastapi.responses import JSONResponse from jinja2 import pass_context import aurweb.config @@ -207,3 +209,25 @@ def git_search(repo: pygit2.Repository, commit_hash: str) -> int: break prefixlen += 1 return prefixlen + + +async def error_or_result(next: Callable, *args, **kwargs) \ + -> fastapi.Response: + """ + Try to return a response from `next`. + + If RuntimeError is raised during next(...) execution, return a + 500 with the exception's error as a JSONResponse. + + :param next: Callable of the next fastapi route callback + :param *args: Variable number of arguments passed to the endpoint + :param **kwargs: Optional kwargs to pass to the endpoint + :return: next(...) retval; if an exc is raised: a 500 response + """ + try: + response = await next(*args, **kwargs) + except RuntimeError as exc: + logger.error(f"RuntimeError: {exc}") + status_code = HTTPStatus.INTERNAL_SERVER_ERROR + return JSONResponse({"error": str(exc)}, status_code=status_code) + return response diff --git a/test/test_util.py b/test/test_util.py index 2529ed1f..41876fbf 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -1,7 +1,16 @@ +import json + from datetime import datetime +from http import HTTPStatus from zoneinfo import ZoneInfo +import fastapi +import pytest + +from fastapi.responses import JSONResponse + from aurweb import filters, util +from aurweb.testing.requests import Request def test_timestamp_to_datetime(): @@ -57,3 +66,22 @@ def test_git_search_double_commit(): # Locate the shortest prefix length that matches commit_hash. prefixlen = util.git_search(repo, commit_hash) assert prefixlen == 13 + + +@pytest.mark.asyncio +async def test_error_or_result(): + + async def route(request: fastapi.Request): + raise RuntimeError("No response returned.") + + response = await util.error_or_result(route, Request()) + assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + + data = json.loads(response.body) + assert data.get("error") == "No response returned." + + async def good_route(request: fastapi.Request): + return JSONResponse() + + response = await util.error_or_result(good_route, Request()) + assert response.status_code == HTTPStatus.OK