diff --git a/aurweb/models/ban.py b/aurweb/models/ban.py index 0fcb6d2e..d2a7250d 100644 --- a/aurweb/models/ban.py +++ b/aurweb/models/ban.py @@ -2,6 +2,7 @@ from fastapi import Request from aurweb import db, schema from aurweb.models.declarative import Base +from aurweb.util import get_client_ip class Ban(Base): @@ -14,6 +15,6 @@ class Ban(Base): def is_banned(request: Request): - ip = request.client.host + ip = get_client_ip(request) exists = db.query(Ban).filter(Ban.IPAddress == ip).exists() return db.query(exists).scalar() diff --git a/aurweb/models/user.py b/aurweb/models/user.py index b64c1c2e..ee2889d2 100644 --- a/aurweb/models/user.py +++ b/aurweb/models/user.py @@ -122,7 +122,7 @@ class User(Base): try: with db.begin(): self.LastLogin = now_ts - self.LastLoginIPAddress = request.client.host + self.LastLoginIPAddress = util.get_client_ip(request) if not self.session: sid = generate_unique_sid() self.session = db.create( diff --git a/aurweb/ratelimit.py b/aurweb/ratelimit.py index ea191972..060f8dcb 100644 --- a/aurweb/ratelimit.py +++ b/aurweb/ratelimit.py @@ -4,6 +4,7 @@ from redis.client import Pipeline from aurweb import aur_logging, config, db, time from aurweb.aur_redis import redis_connection from aurweb.models import ApiRateLimit +from aurweb.util import get_client_ip logger = aur_logging.get_logger(__name__) @@ -13,7 +14,7 @@ def _update_ratelimit_redis(request: Request, pipeline: Pipeline): now = time.utcnow() time_to_delete = now - window_length - host = request.client.host + host = get_client_ip(request) window_key = f"ratelimit-ws:{host}" requests_key = f"ratelimit:{host}" @@ -55,7 +56,7 @@ def _update_ratelimit_db(request: Request): record.Requests += 1 return record - host = request.client.host + host = get_client_ip(request) record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first() record = retry_create(record, now, host) @@ -92,7 +93,7 @@ def check_ratelimit(request: Request): record = update_ratelimit(request, pipeline) # Get cache value, else None. - host = request.client.host + host = get_client_ip(request) pipeline.get(f"ratelimit:{host}") requests = pipeline.execute()[0] diff --git a/aurweb/routers/sso.py b/aurweb/routers/sso.py index e1356cfb..fb99edd6 100644 --- a/aurweb/routers/sso.py +++ b/aurweb/routers/sso.py @@ -80,7 +80,9 @@ def open_session(request, conn, user_id): conn.execute( Users.update() .where(Users.c.ID == user_id) - .values(LastLogin=int(time.time()), LastLoginIPAddress=request.client.host) + .values( + LastLogin=int(time.time()), LastLoginIPAddress=util.get_client_ip(request) + ) ) return sid @@ -110,7 +112,7 @@ async def authenticate( Receive an OpenID Connect ID token, validate it, then process it to create an new AUR session. """ - if is_ip_banned(conn, request.client.host): + if is_ip_banned(conn, util.get_client_ip(request)): _ = get_translator_for_request(request) raise HTTPException( status_code=HTTPStatus.FORBIDDEN, diff --git a/aurweb/users/validate.py b/aurweb/users/validate.py index e49b0bc1..6a84b3c0 100644 --- a/aurweb/users/validate.py +++ b/aurweb/users/validate.py @@ -67,7 +67,7 @@ def invalid_password( def is_banned(request: Request = None, **kwargs) -> None: - host = request.client.host + host = util.get_client_ip(request) exists = db.query(models.Ban, models.Ban.IPAddress == host).exists() if db.query(exists).scalar(): raise ValidationError( diff --git a/aurweb/util.py b/aurweb/util.py index 3410e4d8..e5948a40 100644 --- a/aurweb/util.py +++ b/aurweb/util.py @@ -208,3 +208,11 @@ def hash_query(query: Query): return sha1( str(query.statement.compile(compile_kwargs={"literal_binds": True})).encode() ).hexdigest() + + +def get_client_ip(request: fastapi.Request) -> str: + """ + Returns the client's IP address for a Request. + Falls back to 'no-client' is request.client is None + """ + return request.client.host if request.client else "no-client" diff --git a/test/test_accounts_routes.py b/test/test_accounts_routes.py index a9cb6f7d..ccf1bc99 100644 --- a/test/test_accounts_routes.py +++ b/test/test_accounts_routes.py @@ -391,9 +391,10 @@ def test_post_register_error_invalid_captcha(client: TestClient): def test_post_register_error_ip_banned(client: TestClient): - # 'testclient' is used as request.client.host via FastAPI TestClient. + # 'no-client' is our fallback value in case request.client is None + # which is the case for TestClient with db.begin(): - create(Ban, IPAddress="testclient", BanTS=datetime.utcnow()) + create(Ban, IPAddress="no-client", BanTS=datetime.utcnow()) with client as request: response = post_register(request) diff --git a/test/test_rpc.py b/test/test_rpc.py index d33578d0..9c969b73 100644 --- a/test/test_rpc.py +++ b/test/test_rpc.py @@ -310,10 +310,10 @@ def pipeline(): redis = redis_connection() pipeline = redis.pipeline() - # The 'testclient' host is used when requesting the app - # via fastapi.testclient.TestClient. - pipeline.delete("ratelimit-ws:testclient") - pipeline.delete("ratelimit:testclient") + # 'no-client' is our fallback value in case request.client is None + # which is the case for TestClient + pipeline.delete("ratelimit-ws:no-client") + pipeline.delete("ratelimit:no-client") pipeline.execute() yield pipeline @@ -760,8 +760,8 @@ def test_rpc_ratelimit( assert response.status_code == int(HTTPStatus.TOO_MANY_REQUESTS) # Delete the cached records. - pipeline.delete("ratelimit-ws:testclient") - pipeline.delete("ratelimit:testclient") + pipeline.delete("ratelimit-ws:no-client") + pipeline.delete("ratelimit:no-client") one, two = pipeline.execute() assert one and two