diff --git a/aurweb/testing/requests.py b/aurweb/testing/requests.py index a8c077db..76f7afca 100644 --- a/aurweb/testing/requests.py +++ b/aurweb/testing/requests.py @@ -1,3 +1,5 @@ +from typing import Dict + import aurweb.config @@ -27,7 +29,13 @@ class URL: class Request: """ A fake Request object which mimics a FastAPI Request for tests. """ client = Client() - cookies = dict() - headers = dict() user = User() url = URL() + + def __init__(self, + method: str = "GET", + headers: Dict[str, str] = dict(), + cookies: Dict[str, str] = dict()) -> "Request": + self.method = method.upper() + self.headers = headers + self.cookies = cookies diff --git a/test/test_auth.py b/test/test_auth.py index b63fb96f..b607a038 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -1,11 +1,13 @@ from datetime import datetime +import fastapi import pytest +from fastapi import HTTPException from sqlalchemy.exc import IntegrityError from aurweb import db -from aurweb.auth import AnonymousUser, BasicAuthBackend, account_type_required +from aurweb.auth import AnonymousUser, BasicAuthBackend, account_type_required, auth_required from aurweb.models.account_type import USER, USER_ID from aurweb.models.session import Session from aurweb.models.user import User @@ -74,6 +76,24 @@ async def test_basic_auth_backend(user: User, backend: BasicAuthBackend): assert result == user +@pytest.mark.asyncio +async def test_auth_required_redirection_bad_referrer(): + # Create a fake route function which can be wrapped by auth_required. + def bad_referrer_route(request: fastapi.Request): + pass + + # Get down to the nitty gritty internal wrapper. + bad_referrer_route = auth_required()(bad_referrer_route) + + # Execute the route with a "./blahblahblah" Referer, which does not + # match aur_location; `./` has been used as a prefix to attempt to + # ensure we're providing a fake referer. + with pytest.raises(HTTPException) as exc: + request = Request(method="POST", headers={"Referer": "./blahblahblah"}) + await bad_referrer_route(request) + assert exc.detail == "Bad Referer header." + + def test_account_type_required(): """ This test merely asserts that a few different paths do not raise exceptions. """