fix(test): Fixes for "TestClient" changes

Seems that client is optional according to the ASGI spec.
https://asgi.readthedocs.io/en/latest/specs/www.html

With Starlette 0.35 the TestClient connection  scope is None for "client".
https://github.com/encode/starlette/pull/2377

Signed-off-by: moson <moson@archlinux.org>
This commit is contained in:
moson 2024-01-19 16:37:42 +01:00
parent 22e1577324
commit 2fcd793a58
No known key found for this signature in database
GPG key ID: 4A4760AB4EE15296
8 changed files with 29 additions and 16 deletions

View file

@ -2,6 +2,7 @@ from fastapi import Request
from aurweb import db, schema from aurweb import db, schema
from aurweb.models.declarative import Base from aurweb.models.declarative import Base
from aurweb.util import get_client_ip
class Ban(Base): class Ban(Base):
@ -14,6 +15,6 @@ class Ban(Base):
def is_banned(request: Request): def is_banned(request: Request):
ip = request.client.host ip = get_client_ip(request)
exists = db.query(Ban).filter(Ban.IPAddress == ip).exists() exists = db.query(Ban).filter(Ban.IPAddress == ip).exists()
return db.query(exists).scalar() return db.query(exists).scalar()

View file

@ -122,7 +122,7 @@ class User(Base):
try: try:
with db.begin(): with db.begin():
self.LastLogin = now_ts self.LastLogin = now_ts
self.LastLoginIPAddress = request.client.host self.LastLoginIPAddress = util.get_client_ip(request)
if not self.session: if not self.session:
sid = generate_unique_sid() sid = generate_unique_sid()
self.session = db.create( self.session = db.create(

View file

@ -4,6 +4,7 @@ from redis.client import Pipeline
from aurweb import aur_logging, config, db, time from aurweb import aur_logging, config, db, time
from aurweb.aur_redis import redis_connection from aurweb.aur_redis import redis_connection
from aurweb.models import ApiRateLimit from aurweb.models import ApiRateLimit
from aurweb.util import get_client_ip
logger = aur_logging.get_logger(__name__) logger = aur_logging.get_logger(__name__)
@ -13,7 +14,7 @@ def _update_ratelimit_redis(request: Request, pipeline: Pipeline):
now = time.utcnow() now = time.utcnow()
time_to_delete = now - window_length time_to_delete = now - window_length
host = request.client.host host = get_client_ip(request)
window_key = f"ratelimit-ws:{host}" window_key = f"ratelimit-ws:{host}"
requests_key = f"ratelimit:{host}" requests_key = f"ratelimit:{host}"
@ -55,7 +56,7 @@ def _update_ratelimit_db(request: Request):
record.Requests += 1 record.Requests += 1
return record return record
host = request.client.host host = get_client_ip(request)
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first() record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
record = retry_create(record, now, host) record = retry_create(record, now, host)
@ -92,7 +93,7 @@ def check_ratelimit(request: Request):
record = update_ratelimit(request, pipeline) record = update_ratelimit(request, pipeline)
# Get cache value, else None. # Get cache value, else None.
host = request.client.host host = get_client_ip(request)
pipeline.get(f"ratelimit:{host}") pipeline.get(f"ratelimit:{host}")
requests = pipeline.execute()[0] requests = pipeline.execute()[0]

View file

@ -80,7 +80,9 @@ def open_session(request, conn, user_id):
conn.execute( conn.execute(
Users.update() Users.update()
.where(Users.c.ID == user_id) .where(Users.c.ID == user_id)
.values(LastLogin=int(time.time()), LastLoginIPAddress=request.client.host) .values(
LastLogin=int(time.time()), LastLoginIPAddress=util.get_client_ip(request)
)
) )
return sid return sid
@ -110,7 +112,7 @@ async def authenticate(
Receive an OpenID Connect ID token, validate it, then process it to create Receive an OpenID Connect ID token, validate it, then process it to create
an new AUR session. an new AUR session.
""" """
if is_ip_banned(conn, request.client.host): if is_ip_banned(conn, util.get_client_ip(request)):
_ = get_translator_for_request(request) _ = get_translator_for_request(request)
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, status_code=HTTPStatus.FORBIDDEN,

View file

@ -67,7 +67,7 @@ def invalid_password(
def is_banned(request: Request = None, **kwargs) -> None: def is_banned(request: Request = None, **kwargs) -> None:
host = request.client.host host = util.get_client_ip(request)
exists = db.query(models.Ban, models.Ban.IPAddress == host).exists() exists = db.query(models.Ban, models.Ban.IPAddress == host).exists()
if db.query(exists).scalar(): if db.query(exists).scalar():
raise ValidationError( raise ValidationError(

View file

@ -208,3 +208,11 @@ def hash_query(query: Query):
return sha1( return sha1(
str(query.statement.compile(compile_kwargs={"literal_binds": True})).encode() str(query.statement.compile(compile_kwargs={"literal_binds": True})).encode()
).hexdigest() ).hexdigest()
def get_client_ip(request: fastapi.Request) -> str:
"""
Returns the client's IP address for a Request.
Falls back to 'no-client' is request.client is None
"""
return request.client.host if request.client else "no-client"

View file

@ -391,9 +391,10 @@ def test_post_register_error_invalid_captcha(client: TestClient):
def test_post_register_error_ip_banned(client: TestClient): def test_post_register_error_ip_banned(client: TestClient):
# 'testclient' is used as request.client.host via FastAPI TestClient. # 'no-client' is our fallback value in case request.client is None
# which is the case for TestClient
with db.begin(): with db.begin():
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow()) create(Ban, IPAddress="no-client", BanTS=datetime.utcnow())
with client as request: with client as request:
response = post_register(request) response = post_register(request)

View file

@ -310,10 +310,10 @@ def pipeline():
redis = redis_connection() redis = redis_connection()
pipeline = redis.pipeline() pipeline = redis.pipeline()
# The 'testclient' host is used when requesting the app # 'no-client' is our fallback value in case request.client is None
# via fastapi.testclient.TestClient. # which is the case for TestClient
pipeline.delete("ratelimit-ws:testclient") pipeline.delete("ratelimit-ws:no-client")
pipeline.delete("ratelimit:testclient") pipeline.delete("ratelimit:no-client")
pipeline.execute() pipeline.execute()
yield pipeline yield pipeline
@ -760,8 +760,8 @@ def test_rpc_ratelimit(
assert response.status_code == int(HTTPStatus.TOO_MANY_REQUESTS) assert response.status_code == int(HTTPStatus.TOO_MANY_REQUESTS)
# Delete the cached records. # Delete the cached records.
pipeline.delete("ratelimit-ws:testclient") pipeline.delete("ratelimit-ws:no-client")
pipeline.delete("ratelimit:testclient") pipeline.delete("ratelimit:no-client")
one, two = pipeline.execute() one, two = pipeline.execute()
assert one and two assert one and two