diff --git a/aurweb/routers/packages.py b/aurweb/routers/packages.py index b0da3bf9..14b91221 100644 --- a/aurweb/routers/packages.py +++ b/aurweb/routers/packages.py @@ -30,8 +30,11 @@ async def packages_get(request: Request, context: Dict[str, Any], context["q"] = dict(request.query_params) # Per page and offset. - per_page = context["PP"] = int(request.query_params.get("PP", 50)) - offset = context["O"] = int(request.query_params.get("O", 0)) + offset, per_page = util.sanitize_params( + request.query_params.get("O", defaults.O), + request.query_params.get("PP", defaults.PP)) + context["O"] = offset + context["PP"] = per_page # Query search by. search_by = context["SeB"] = request.query_params.get("SeB", "nd") diff --git a/aurweb/util.py b/aurweb/util.py index dd7491d3..88142cbc 100644 --- a/aurweb/util.py +++ b/aurweb/util.py @@ -7,7 +7,7 @@ import secrets import string from datetime import datetime -from typing import Any, Callable, Dict, Iterable +from typing import Any, Callable, Dict, Iterable, Tuple from urllib.parse import urlencode, urlparse from zoneinfo import ZoneInfo @@ -18,7 +18,7 @@ from jinja2 import pass_context import aurweb.config -from aurweb import logging +from aurweb import defaults, logging logger = logging.get_logger(__name__) @@ -155,3 +155,17 @@ def get_ssh_fingerprints(): def apply_all(iterable: Iterable, fn: Callable): for item in iterable: fn(item) + + +def sanitize_params(offset: str, per_page: str) -> Tuple[int, int]: + try: + offset = int(offset) + except ValueError: + offset = defaults.O + + try: + per_page = int(per_page) + except ValueError: + per_page = defaults.PP + + return (offset, per_page) diff --git a/test/test_packages_routes.py b/test/test_packages_routes.py index b4a582e3..2ef3f3d8 100644 --- a/test/test_packages_routes.py +++ b/test/test_packages_routes.py @@ -486,15 +486,11 @@ def test_pkgbase(client: TestClient, package: Package): def test_packages(client: TestClient, packages: List[Package]): - """ Test the / packages route with defaults. - - Defaults: - 50 results per page - offset of 0 - """ with client as request: response = request.get("/packages", params={ - "SeB": "X" # "X" isn't valid, defaults to "nd" + "SeB": "X", # "X" isn't valid, defaults to "nd" + "PP": "1 or 1", + "O": "0 or 0" }) assert response.status_code == int(HTTPStatus.OK)