diff --git a/aurweb/routers/auth.py b/aurweb/routers/auth.py index 4e6a416a..b8e83c7d 100644 --- a/aurweb/routers/auth.py +++ b/aurweb/routers/auth.py @@ -1,13 +1,14 @@ from datetime import datetime from http import HTTPStatus -from fastapi import APIRouter, Form, Request +from fastapi import APIRouter, Form, HTTPException, Request from fastapi.responses import HTMLResponse, RedirectResponse import aurweb.config from aurweb import cookies from aurweb.auth import auth_required +from aurweb.l10n import get_translator_for_request from aurweb.models import User from aurweb.templates import make_variable_context, render_template @@ -35,6 +36,15 @@ async def login_post(request: Request, user: str = Form(default=str()), passwd: str = Form(default=str()), remember_me: bool = Form(default=False)): + # TODO: Once the Origin header gets broader adoption, this code can be + # slightly simplified to use it. + login_path = aurweb.config.get("options", "aur_location") + "/login" + referer = request.headers.get("Referer") + if not referer or not referer.startswith(login_path): + _ = get_translator_for_request(request) + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, + detail=_("Bad Referer header.")) + from aurweb.db import session user = session.query(User).filter(User.Username == user).first() diff --git a/po/aurweb.pot b/po/aurweb.pot index f4deee70..0b3996a8 100644 --- a/po/aurweb.pot +++ b/po/aurweb.pot @@ -968,6 +968,10 @@ msgstr "" msgid "Package details could not be found." msgstr "" +#: aurweb/routers/auth.py +msgid "Bad Referer header." +msgstr "" + #: aurweb/routers/packages.py msgid "You did not select any packages to be notified about." msgstr "" diff --git a/test/test_auth_routes.py b/test/test_auth_routes.py index 313f9927..39afc6f9 100644 --- a/test/test_auth_routes.py +++ b/test/test_auth_routes.py @@ -18,6 +18,9 @@ from aurweb.testing import setup_test_db # Some test global constants. TEST_USERNAME = "test" TEST_EMAIL = "test@example.org" +TEST_REFERER = { + "referer": aurweb.config.get("options", "aur_location") + "/login", +} # Global mutables. user = client = None @@ -39,6 +42,10 @@ def setup(): client = TestClient(app) + # Necessary for forged login CSRF protection on the login route. Set here + # instead of only on the necessary requests for convenience. + client.headers.update(TEST_REFERER) + def test_login_logout(): post_data = { @@ -92,6 +99,10 @@ def test_secure_login(mock): # Create a local TestClient here since we mocked configuration. client = TestClient(app) + # Necessary for forged login CSRF protection on the login route. Set here + # instead of only on the necessary requests for convenience. + client.headers.update(TEST_REFERER) + # Data used for our upcoming http post request. post_data = { "user": user.Username, @@ -246,3 +257,26 @@ def test_login_incorrect_password(): assert post_data["user"] in content assert post_data["passwd"] not in content assert "checked" not in content + + +def test_login_bad_referer(): + post_data = { + "user": "test", + "passwd": "testPassword", + "next": "/", + } + + # Create new TestClient without a Referer header. + client = TestClient(app) + + with client as request: + response = request.post("/login", data=post_data) + assert "AURSID" not in response.cookies + + BAD_REFERER = { + "referer": aurweb.config.get("options", "aur_location") + ".mal.local", + } + with client as request: + response = request.post("/login", data=post_data, headers=BAD_REFERER) + assert response.status_code == int(HTTPStatus.BAD_REQUEST) + assert "AURSID" not in response.cookies