mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
fix(test): Fixes for "TestClient" changes
Seems that client is optional according to the ASGI spec. https://asgi.readthedocs.io/en/latest/specs/www.html With Starlette 0.35 the TestClient connection scope is None for "client". https://github.com/encode/starlette/pull/2377 Signed-off-by: moson <moson@archlinux.org>
This commit is contained in:
parent
22e1577324
commit
2fcd793a58
8 changed files with 29 additions and 16 deletions
|
@ -2,6 +2,7 @@ from fastapi import Request
|
|||
|
||||
from aurweb import db, schema
|
||||
from aurweb.models.declarative import Base
|
||||
from aurweb.util import get_client_ip
|
||||
|
||||
|
||||
class Ban(Base):
|
||||
|
@ -14,6 +15,6 @@ class Ban(Base):
|
|||
|
||||
|
||||
def is_banned(request: Request):
|
||||
ip = request.client.host
|
||||
ip = get_client_ip(request)
|
||||
exists = db.query(Ban).filter(Ban.IPAddress == ip).exists()
|
||||
return db.query(exists).scalar()
|
||||
|
|
|
@ -122,7 +122,7 @@ class User(Base):
|
|||
try:
|
||||
with db.begin():
|
||||
self.LastLogin = now_ts
|
||||
self.LastLoginIPAddress = request.client.host
|
||||
self.LastLoginIPAddress = util.get_client_ip(request)
|
||||
if not self.session:
|
||||
sid = generate_unique_sid()
|
||||
self.session = db.create(
|
||||
|
|
|
@ -4,6 +4,7 @@ from redis.client import Pipeline
|
|||
from aurweb import aur_logging, config, db, time
|
||||
from aurweb.aur_redis import redis_connection
|
||||
from aurweb.models import ApiRateLimit
|
||||
from aurweb.util import get_client_ip
|
||||
|
||||
logger = aur_logging.get_logger(__name__)
|
||||
|
||||
|
@ -13,7 +14,7 @@ def _update_ratelimit_redis(request: Request, pipeline: Pipeline):
|
|||
now = time.utcnow()
|
||||
time_to_delete = now - window_length
|
||||
|
||||
host = request.client.host
|
||||
host = get_client_ip(request)
|
||||
window_key = f"ratelimit-ws:{host}"
|
||||
requests_key = f"ratelimit:{host}"
|
||||
|
||||
|
@ -55,7 +56,7 @@ def _update_ratelimit_db(request: Request):
|
|||
record.Requests += 1
|
||||
return record
|
||||
|
||||
host = request.client.host
|
||||
host = get_client_ip(request)
|
||||
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
|
||||
record = retry_create(record, now, host)
|
||||
|
||||
|
@ -92,7 +93,7 @@ def check_ratelimit(request: Request):
|
|||
record = update_ratelimit(request, pipeline)
|
||||
|
||||
# Get cache value, else None.
|
||||
host = request.client.host
|
||||
host = get_client_ip(request)
|
||||
pipeline.get(f"ratelimit:{host}")
|
||||
requests = pipeline.execute()[0]
|
||||
|
||||
|
|
|
@ -80,7 +80,9 @@ def open_session(request, conn, user_id):
|
|||
conn.execute(
|
||||
Users.update()
|
||||
.where(Users.c.ID == user_id)
|
||||
.values(LastLogin=int(time.time()), LastLoginIPAddress=request.client.host)
|
||||
.values(
|
||||
LastLogin=int(time.time()), LastLoginIPAddress=util.get_client_ip(request)
|
||||
)
|
||||
)
|
||||
|
||||
return sid
|
||||
|
@ -110,7 +112,7 @@ async def authenticate(
|
|||
Receive an OpenID Connect ID token, validate it, then process it to create
|
||||
an new AUR session.
|
||||
"""
|
||||
if is_ip_banned(conn, request.client.host):
|
||||
if is_ip_banned(conn, util.get_client_ip(request)):
|
||||
_ = get_translator_for_request(request)
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
|
|
|
@ -67,7 +67,7 @@ def invalid_password(
|
|||
|
||||
|
||||
def is_banned(request: Request = None, **kwargs) -> None:
|
||||
host = request.client.host
|
||||
host = util.get_client_ip(request)
|
||||
exists = db.query(models.Ban, models.Ban.IPAddress == host).exists()
|
||||
if db.query(exists).scalar():
|
||||
raise ValidationError(
|
||||
|
|
|
@ -208,3 +208,11 @@ def hash_query(query: Query):
|
|||
return sha1(
|
||||
str(query.statement.compile(compile_kwargs={"literal_binds": True})).encode()
|
||||
).hexdigest()
|
||||
|
||||
|
||||
def get_client_ip(request: fastapi.Request) -> str:
|
||||
"""
|
||||
Returns the client's IP address for a Request.
|
||||
Falls back to 'no-client' is request.client is None
|
||||
"""
|
||||
return request.client.host if request.client else "no-client"
|
||||
|
|
|
@ -391,9 +391,10 @@ def test_post_register_error_invalid_captcha(client: TestClient):
|
|||
|
||||
|
||||
def test_post_register_error_ip_banned(client: TestClient):
|
||||
# 'testclient' is used as request.client.host via FastAPI TestClient.
|
||||
# 'no-client' is our fallback value in case request.client is None
|
||||
# which is the case for TestClient
|
||||
with db.begin():
|
||||
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow())
|
||||
create(Ban, IPAddress="no-client", BanTS=datetime.utcnow())
|
||||
|
||||
with client as request:
|
||||
response = post_register(request)
|
||||
|
|
|
@ -310,10 +310,10 @@ def pipeline():
|
|||
redis = redis_connection()
|
||||
pipeline = redis.pipeline()
|
||||
|
||||
# The 'testclient' host is used when requesting the app
|
||||
# via fastapi.testclient.TestClient.
|
||||
pipeline.delete("ratelimit-ws:testclient")
|
||||
pipeline.delete("ratelimit:testclient")
|
||||
# 'no-client' is our fallback value in case request.client is None
|
||||
# which is the case for TestClient
|
||||
pipeline.delete("ratelimit-ws:no-client")
|
||||
pipeline.delete("ratelimit:no-client")
|
||||
pipeline.execute()
|
||||
|
||||
yield pipeline
|
||||
|
@ -760,8 +760,8 @@ def test_rpc_ratelimit(
|
|||
assert response.status_code == int(HTTPStatus.TOO_MANY_REQUESTS)
|
||||
|
||||
# Delete the cached records.
|
||||
pipeline.delete("ratelimit-ws:testclient")
|
||||
pipeline.delete("ratelimit:testclient")
|
||||
pipeline.delete("ratelimit-ws:no-client")
|
||||
pipeline.delete("ratelimit:no-client")
|
||||
one, two = pipeline.execute()
|
||||
assert one and two
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue