fix(test): Fixes for "TestClient" changes

Seems that client is optional according to the ASGI spec.
https://asgi.readthedocs.io/en/latest/specs/www.html

With Starlette 0.35 the TestClient connection  scope is None for "client".
https://github.com/encode/starlette/pull/2377

Signed-off-by: moson <moson@archlinux.org>
This commit is contained in:
moson 2024-01-19 16:37:42 +01:00
parent 22e1577324
commit 2fcd793a58
No known key found for this signature in database
GPG key ID: 4A4760AB4EE15296
8 changed files with 29 additions and 16 deletions

View file

@ -2,6 +2,7 @@ from fastapi import Request
from aurweb import db, schema
from aurweb.models.declarative import Base
from aurweb.util import get_client_ip
class Ban(Base):
@ -14,6 +15,6 @@ class Ban(Base):
def is_banned(request: Request):
ip = request.client.host
ip = get_client_ip(request)
exists = db.query(Ban).filter(Ban.IPAddress == ip).exists()
return db.query(exists).scalar()

View file

@ -122,7 +122,7 @@ class User(Base):
try:
with db.begin():
self.LastLogin = now_ts
self.LastLoginIPAddress = request.client.host
self.LastLoginIPAddress = util.get_client_ip(request)
if not self.session:
sid = generate_unique_sid()
self.session = db.create(

View file

@ -4,6 +4,7 @@ from redis.client import Pipeline
from aurweb import aur_logging, config, db, time
from aurweb.aur_redis import redis_connection
from aurweb.models import ApiRateLimit
from aurweb.util import get_client_ip
logger = aur_logging.get_logger(__name__)
@ -13,7 +14,7 @@ def _update_ratelimit_redis(request: Request, pipeline: Pipeline):
now = time.utcnow()
time_to_delete = now - window_length
host = request.client.host
host = get_client_ip(request)
window_key = f"ratelimit-ws:{host}"
requests_key = f"ratelimit:{host}"
@ -55,7 +56,7 @@ def _update_ratelimit_db(request: Request):
record.Requests += 1
return record
host = request.client.host
host = get_client_ip(request)
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
record = retry_create(record, now, host)
@ -92,7 +93,7 @@ def check_ratelimit(request: Request):
record = update_ratelimit(request, pipeline)
# Get cache value, else None.
host = request.client.host
host = get_client_ip(request)
pipeline.get(f"ratelimit:{host}")
requests = pipeline.execute()[0]

View file

@ -80,7 +80,9 @@ def open_session(request, conn, user_id):
conn.execute(
Users.update()
.where(Users.c.ID == user_id)
.values(LastLogin=int(time.time()), LastLoginIPAddress=request.client.host)
.values(
LastLogin=int(time.time()), LastLoginIPAddress=util.get_client_ip(request)
)
)
return sid
@ -110,7 +112,7 @@ async def authenticate(
Receive an OpenID Connect ID token, validate it, then process it to create
an new AUR session.
"""
if is_ip_banned(conn, request.client.host):
if is_ip_banned(conn, util.get_client_ip(request)):
_ = get_translator_for_request(request)
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,

View file

@ -67,7 +67,7 @@ def invalid_password(
def is_banned(request: Request = None, **kwargs) -> None:
host = request.client.host
host = util.get_client_ip(request)
exists = db.query(models.Ban, models.Ban.IPAddress == host).exists()
if db.query(exists).scalar():
raise ValidationError(

View file

@ -208,3 +208,11 @@ def hash_query(query: Query):
return sha1(
str(query.statement.compile(compile_kwargs={"literal_binds": True})).encode()
).hexdigest()
def get_client_ip(request: fastapi.Request) -> str:
"""
Returns the client's IP address for a Request.
Falls back to 'no-client' is request.client is None
"""
return request.client.host if request.client else "no-client"

View file

@ -391,9 +391,10 @@ def test_post_register_error_invalid_captcha(client: TestClient):
def test_post_register_error_ip_banned(client: TestClient):
# 'testclient' is used as request.client.host via FastAPI TestClient.
# 'no-client' is our fallback value in case request.client is None
# which is the case for TestClient
with db.begin():
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow())
create(Ban, IPAddress="no-client", BanTS=datetime.utcnow())
with client as request:
response = post_register(request)

View file

@ -310,10 +310,10 @@ def pipeline():
redis = redis_connection()
pipeline = redis.pipeline()
# The 'testclient' host is used when requesting the app
# via fastapi.testclient.TestClient.
pipeline.delete("ratelimit-ws:testclient")
pipeline.delete("ratelimit:testclient")
# 'no-client' is our fallback value in case request.client is None
# which is the case for TestClient
pipeline.delete("ratelimit-ws:no-client")
pipeline.delete("ratelimit:no-client")
pipeline.execute()
yield pipeline
@ -760,8 +760,8 @@ def test_rpc_ratelimit(
assert response.status_code == int(HTTPStatus.TOO_MANY_REQUESTS)
# Delete the cached records.
pipeline.delete("ratelimit-ws:testclient")
pipeline.delete("ratelimit:testclient")
pipeline.delete("ratelimit-ws:no-client")
pipeline.delete("ratelimit:no-client")
one, two = pipeline.execute()
assert one and two