diff --git a/aurweb/auth/__init__.py b/aurweb/auth/__init__.py index b6dd6e3f..5f55e2fb 100644 --- a/aurweb/auth/__init__.py +++ b/aurweb/auth/__init__.py @@ -7,7 +7,6 @@ import fastapi from fastapi import HTTPException from fastapi.responses import RedirectResponse -from sqlalchemy import and_ from starlette.authentication import AuthCredentials, AuthenticationBackend from starlette.requests import HTTPConnection @@ -97,18 +96,27 @@ class AnonymousUser: class BasicAuthBackend(AuthenticationBackend): async def authenticate(self, conn: HTTPConnection): + unauthenticated = (None, AnonymousUser()) sid = conn.cookies.get("AURSID") if not sid: - return (None, AnonymousUser()) + return unauthenticated - now_ts = datetime.utcnow().timestamp() - record = db.query(Session).filter( - and_(Session.SessionID == sid, - Session.LastUpdateTS >= now_ts)).first() + timeout = aurweb.config.getint("options", "login_timeout") + remembered = ("AURREMEMBER" in conn.cookies + and bool(conn.cookies.get("AURREMEMBER"))) + if remembered: + timeout = aurweb.config.getint("options", + "persistent_cookie_timeout") # If no session with sid and a LastUpdateTS now or later exists. + now_ts = int(datetime.utcnow().timestamp()) + record = db.query(Session).filter(Session.SessionID == sid).first() if not record: - return (None, AnonymousUser()) + return unauthenticated + elif record.LastUpdateTS < (now_ts - timeout): + with db.begin(): + db.delete_all([record]) + return unauthenticated # At this point, we cannot have an invalid user if the record # exists, due to ForeignKey constraints in the schema upheld diff --git a/aurweb/models/user.py b/aurweb/models/user.py index d0bdea30..5ead606e 100644 --- a/aurweb/models/user.py +++ b/aurweb/models/user.py @@ -123,10 +123,6 @@ class User(Base): for i in range(tries): exc = None now_ts = datetime.utcnow().timestamp() - session_ts = now_ts + ( - session_time if session_time - else aurweb.config.getint("options", "login_timeout") - ) try: with db.begin(): self.LastLogin = now_ts @@ -135,12 +131,12 @@ class User(Base): sid = generate_unique_sid() self.session = db.create(Session, User=self, SessionID=sid, - LastUpdateTS=session_ts) + LastUpdateTS=now_ts) else: last_updated = self.session.LastUpdateTS if last_updated and last_updated < now_ts: self.session.SessionID = generate_unique_sid() - self.session.LastUpdateTS = session_ts + self.session.LastUpdateTS = now_ts break except IntegrityError as exc_: exc = exc_ diff --git a/aurweb/routers/auth.py b/aurweb/routers/auth.py index 74763667..8815c896 100644 --- a/aurweb/routers/auth.py +++ b/aurweb/routers/auth.py @@ -73,6 +73,10 @@ async def login_post(request: Request, response.set_cookie("AURLANG", user.LangPreference, secure=secure, httponly=secure, samesite=cookies.samesite()) + response.set_cookie("AURREMEMBER", remember_me, + expires=expires_at, + secure=secure, httponly=secure, + samesite=cookies.samesite()) return response diff --git a/test/test_auth.py b/test/test_auth.py index b607a038..0094aa25 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -6,7 +6,7 @@ import pytest from fastapi import HTTPException from sqlalchemy.exc import IntegrityError -from aurweb import db +from aurweb import config, db from aurweb.auth import AnonymousUser, BasicAuthBackend, account_type_required, auth_required from aurweb.models.account_type import USER, USER_ID from aurweb.models.session import Session @@ -76,6 +76,28 @@ async def test_basic_auth_backend(user: User, backend: BasicAuthBackend): assert result == user +@pytest.mark.asyncio +async def test_expired_session(backend: BasicAuthBackend, user: User): + """ Login, expire the session manually, then authenticate. """ + # First, build a Request with a logged in user. + request = Request() + request.user = user + sid = request.user.login(Request(), "testPassword") + request.cookies["AURSID"] = sid + + # Set Session.LastUpdateTS to 20 seconds expired. + timeout = config.getint("options", "login_timeout") + now_ts = int(datetime.utcnow().timestamp()) + with db.begin(): + request.user.session.LastUpdateTS = now_ts - timeout - 20 + + # Run through authentication backend and get the session + # deleted due to its expiration. + await backend.authenticate(request) + session = db.query(Session).filter(Session.SessionID == sid).first() + assert session is None + + @pytest.mark.asyncio async def test_auth_required_redirection_bad_referrer(): # Create a fake route function which can be wrapped by auth_required. diff --git a/test/test_auth_routes.py b/test/test_auth_routes.py index 3ae8a56c..f3e2a011 100644 --- a/test/test_auth_routes.py +++ b/test/test_auth_routes.py @@ -13,7 +13,6 @@ from aurweb.asgi import app from aurweb.models.account_type import USER_ID from aurweb.models.session import Session from aurweb.models.user import User -from aurweb.testing.requests import Request # Some test global constants. TEST_USERNAME = "test" @@ -136,12 +135,11 @@ def test_secure_login(getboolean: bool, client: TestClient, user: User): def test_authenticated_login(client: TestClient, user: User): post_data = { - "user": "test", + "user": user.Username, "passwd": "testPassword", "next": "/" } - cookies = {"AURSID": user.login(Request(), "testPassword")} with client as request: # Try to login. response = request.post("/login", data=post_data, @@ -153,7 +151,7 @@ def test_authenticated_login(client: TestClient, user: User): # when requesting GET /login as an authenticated user. # Now, let's verify that we receive 403 Forbidden when we # try to get /login as an authenticated user. - response = request.get("/login", cookies=cookies, + response = request.get("/login", cookies=response.cookies, allow_redirects=False) assert response.status_code == int(HTTPStatus.OK) assert "Logged-in as: test" in response.text @@ -200,14 +198,12 @@ def test_login_remember_me(client: TestClient, user: User): cookie_timeout = aurweb.config.getint( "options", "persistent_cookie_timeout") - expected_ts = datetime.utcnow().timestamp() + cookie_timeout - + now_ts = int(datetime.utcnow().timestamp()) session = db.query(Session).filter(Session.UsersID == user.ID).first() - # Expect that LastUpdateTS was within 5 seconds of the expected_ts, - # which is equal to the current timestamp + persistent_cookie_timeout. - assert session.LastUpdateTS > expected_ts - 5 - assert session.LastUpdateTS < expected_ts + 5 + # Expect that LastUpdateTS is not past the cookie timeout + # for a remembered session. + assert session.LastUpdateTS > (now_ts - cookie_timeout) def test_login_incorrect_password_remember_me(client: TestClient, user: User):