From dc4cc9b604a9085f631c2649909f6767e6f2ce3e Mon Sep 17 00:00:00 2001 From: Kevin Morris Date: Sat, 19 Jun 2021 01:10:53 -0700 Subject: [PATCH] add aurweb.asgi.id_redirect_middleware A new middleware which redirects requests going to '/route?id=some_id' to '/route/some_id'. In the FastAPI application, we'll prefer using restful layouts where possible where resource-based ids are parameters of the request uri: '/route/{resource_id}'. Signed-off-by: Kevin Morris --- aurweb/asgi.py | 22 ++++++++++++++++++++++ test/test_routes.py | 10 ++++++++++ 2 files changed, 32 insertions(+) diff --git a/aurweb/asgi.py b/aurweb/asgi.py index a674fec6..35166c73 100644 --- a/aurweb/asgi.py +++ b/aurweb/asgi.py @@ -2,6 +2,8 @@ import asyncio import http import typing +from urllib.parse import quote_plus + from fastapi import FastAPI, HTTPException, Request from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.staticfiles import StaticFiles @@ -120,3 +122,23 @@ async def check_terms_of_service(request: Request, call_next: typing.Callable): task = asyncio.create_task(call_next(request)) await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED) return task.result() + + +@app.middleware("http") +async def id_redirect_middleware(request: Request, call_next: typing.Callable): + id = request.query_params.get("id") + + if id is not None: + # Preserve query string. + qs = [] + for k, v in request.query_params.items(): + if k != "id": + qs.append(f"{k}={quote_plus(str(v))}") + qs = str() if not qs else '?' + '&'.join(qs) + + path = request.url.path.rstrip('/') + return RedirectResponse(f"{path}/{id}{qs}") + + task = asyncio.create_task(call_next(request)) + await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED) + return task.result() diff --git a/test/test_routes.py b/test/test_routes.py index d67f4a48..a2d1786e 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -148,3 +148,13 @@ def test_nonce_csp(): if not (nonce_verified := (script.get("nonce") == nonce)): break assert nonce_verified is True + + +def test_id_redirect(): + with client as request: + response = request.get("/", params={ + "id": "test", # This param will be rewritten into Location. + "key": "value", # Test that this param persists. + "key2": "value2" # And this one. + }, allow_redirects=False) + assert response.headers.get("location") == "/test?key=value&key2=value2"