diff --git a/aurweb/asgi.py b/aurweb/asgi.py index a674fec6..35166c73 100644 --- a/aurweb/asgi.py +++ b/aurweb/asgi.py @@ -2,6 +2,8 @@ import asyncio import http import typing +from urllib.parse import quote_plus + from fastapi import FastAPI, HTTPException, Request from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.staticfiles import StaticFiles @@ -120,3 +122,23 @@ async def check_terms_of_service(request: Request, call_next: typing.Callable): task = asyncio.create_task(call_next(request)) await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED) return task.result() + + +@app.middleware("http") +async def id_redirect_middleware(request: Request, call_next: typing.Callable): + id = request.query_params.get("id") + + if id is not None: + # Preserve query string. + qs = [] + for k, v in request.query_params.items(): + if k != "id": + qs.append(f"{k}={quote_plus(str(v))}") + qs = str() if not qs else '?' + '&'.join(qs) + + path = request.url.path.rstrip('/') + return RedirectResponse(f"{path}/{id}{qs}") + + task = asyncio.create_task(call_next(request)) + await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED) + return task.result() diff --git a/test/test_routes.py b/test/test_routes.py index d67f4a48..a2d1786e 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -148,3 +148,13 @@ def test_nonce_csp(): if not (nonce_verified := (script.get("nonce") == nonce)): break assert nonce_verified is True + + +def test_id_redirect(): + with client as request: + response = request.get("/", params={ + "id": "test", # This param will be rewritten into Location. + "key": "value", # Test that this param persists. + "key2": "value2" # And this one. + }, allow_redirects=False) + assert response.headers.get("location") == "/test?key=value&key2=value2"