mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
fix(test): Fixes for "TestClient" changes
Seems that client is optional according to the ASGI spec. https://asgi.readthedocs.io/en/latest/specs/www.html With Starlette 0.35 the TestClient connection scope is None for "client". https://github.com/encode/starlette/pull/2377 Signed-off-by: moson <moson@archlinux.org>
This commit is contained in:
parent
22e1577324
commit
2fcd793a58
8 changed files with 29 additions and 16 deletions
|
@ -2,6 +2,7 @@ from fastapi import Request
|
||||||
|
|
||||||
from aurweb import db, schema
|
from aurweb import db, schema
|
||||||
from aurweb.models.declarative import Base
|
from aurweb.models.declarative import Base
|
||||||
|
from aurweb.util import get_client_ip
|
||||||
|
|
||||||
|
|
||||||
class Ban(Base):
|
class Ban(Base):
|
||||||
|
@ -14,6 +15,6 @@ class Ban(Base):
|
||||||
|
|
||||||
|
|
||||||
def is_banned(request: Request):
|
def is_banned(request: Request):
|
||||||
ip = request.client.host
|
ip = get_client_ip(request)
|
||||||
exists = db.query(Ban).filter(Ban.IPAddress == ip).exists()
|
exists = db.query(Ban).filter(Ban.IPAddress == ip).exists()
|
||||||
return db.query(exists).scalar()
|
return db.query(exists).scalar()
|
||||||
|
|
|
@ -122,7 +122,7 @@ class User(Base):
|
||||||
try:
|
try:
|
||||||
with db.begin():
|
with db.begin():
|
||||||
self.LastLogin = now_ts
|
self.LastLogin = now_ts
|
||||||
self.LastLoginIPAddress = request.client.host
|
self.LastLoginIPAddress = util.get_client_ip(request)
|
||||||
if not self.session:
|
if not self.session:
|
||||||
sid = generate_unique_sid()
|
sid = generate_unique_sid()
|
||||||
self.session = db.create(
|
self.session = db.create(
|
||||||
|
|
|
@ -4,6 +4,7 @@ from redis.client import Pipeline
|
||||||
from aurweb import aur_logging, config, db, time
|
from aurweb import aur_logging, config, db, time
|
||||||
from aurweb.aur_redis import redis_connection
|
from aurweb.aur_redis import redis_connection
|
||||||
from aurweb.models import ApiRateLimit
|
from aurweb.models import ApiRateLimit
|
||||||
|
from aurweb.util import get_client_ip
|
||||||
|
|
||||||
logger = aur_logging.get_logger(__name__)
|
logger = aur_logging.get_logger(__name__)
|
||||||
|
|
||||||
|
@ -13,7 +14,7 @@ def _update_ratelimit_redis(request: Request, pipeline: Pipeline):
|
||||||
now = time.utcnow()
|
now = time.utcnow()
|
||||||
time_to_delete = now - window_length
|
time_to_delete = now - window_length
|
||||||
|
|
||||||
host = request.client.host
|
host = get_client_ip(request)
|
||||||
window_key = f"ratelimit-ws:{host}"
|
window_key = f"ratelimit-ws:{host}"
|
||||||
requests_key = f"ratelimit:{host}"
|
requests_key = f"ratelimit:{host}"
|
||||||
|
|
||||||
|
@ -55,7 +56,7 @@ def _update_ratelimit_db(request: Request):
|
||||||
record.Requests += 1
|
record.Requests += 1
|
||||||
return record
|
return record
|
||||||
|
|
||||||
host = request.client.host
|
host = get_client_ip(request)
|
||||||
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
|
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
|
||||||
record = retry_create(record, now, host)
|
record = retry_create(record, now, host)
|
||||||
|
|
||||||
|
@ -92,7 +93,7 @@ def check_ratelimit(request: Request):
|
||||||
record = update_ratelimit(request, pipeline)
|
record = update_ratelimit(request, pipeline)
|
||||||
|
|
||||||
# Get cache value, else None.
|
# Get cache value, else None.
|
||||||
host = request.client.host
|
host = get_client_ip(request)
|
||||||
pipeline.get(f"ratelimit:{host}")
|
pipeline.get(f"ratelimit:{host}")
|
||||||
requests = pipeline.execute()[0]
|
requests = pipeline.execute()[0]
|
||||||
|
|
||||||
|
|
|
@ -80,7 +80,9 @@ def open_session(request, conn, user_id):
|
||||||
conn.execute(
|
conn.execute(
|
||||||
Users.update()
|
Users.update()
|
||||||
.where(Users.c.ID == user_id)
|
.where(Users.c.ID == user_id)
|
||||||
.values(LastLogin=int(time.time()), LastLoginIPAddress=request.client.host)
|
.values(
|
||||||
|
LastLogin=int(time.time()), LastLoginIPAddress=util.get_client_ip(request)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return sid
|
return sid
|
||||||
|
@ -110,7 +112,7 @@ async def authenticate(
|
||||||
Receive an OpenID Connect ID token, validate it, then process it to create
|
Receive an OpenID Connect ID token, validate it, then process it to create
|
||||||
an new AUR session.
|
an new AUR session.
|
||||||
"""
|
"""
|
||||||
if is_ip_banned(conn, request.client.host):
|
if is_ip_banned(conn, util.get_client_ip(request)):
|
||||||
_ = get_translator_for_request(request)
|
_ = get_translator_for_request(request)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.FORBIDDEN,
|
status_code=HTTPStatus.FORBIDDEN,
|
||||||
|
|
|
@ -67,7 +67,7 @@ def invalid_password(
|
||||||
|
|
||||||
|
|
||||||
def is_banned(request: Request = None, **kwargs) -> None:
|
def is_banned(request: Request = None, **kwargs) -> None:
|
||||||
host = request.client.host
|
host = util.get_client_ip(request)
|
||||||
exists = db.query(models.Ban, models.Ban.IPAddress == host).exists()
|
exists = db.query(models.Ban, models.Ban.IPAddress == host).exists()
|
||||||
if db.query(exists).scalar():
|
if db.query(exists).scalar():
|
||||||
raise ValidationError(
|
raise ValidationError(
|
||||||
|
|
|
@ -208,3 +208,11 @@ def hash_query(query: Query):
|
||||||
return sha1(
|
return sha1(
|
||||||
str(query.statement.compile(compile_kwargs={"literal_binds": True})).encode()
|
str(query.statement.compile(compile_kwargs={"literal_binds": True})).encode()
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def get_client_ip(request: fastapi.Request) -> str:
|
||||||
|
"""
|
||||||
|
Returns the client's IP address for a Request.
|
||||||
|
Falls back to 'no-client' is request.client is None
|
||||||
|
"""
|
||||||
|
return request.client.host if request.client else "no-client"
|
||||||
|
|
|
@ -391,9 +391,10 @@ def test_post_register_error_invalid_captcha(client: TestClient):
|
||||||
|
|
||||||
|
|
||||||
def test_post_register_error_ip_banned(client: TestClient):
|
def test_post_register_error_ip_banned(client: TestClient):
|
||||||
# 'testclient' is used as request.client.host via FastAPI TestClient.
|
# 'no-client' is our fallback value in case request.client is None
|
||||||
|
# which is the case for TestClient
|
||||||
with db.begin():
|
with db.begin():
|
||||||
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow())
|
create(Ban, IPAddress="no-client", BanTS=datetime.utcnow())
|
||||||
|
|
||||||
with client as request:
|
with client as request:
|
||||||
response = post_register(request)
|
response = post_register(request)
|
||||||
|
|
|
@ -310,10 +310,10 @@ def pipeline():
|
||||||
redis = redis_connection()
|
redis = redis_connection()
|
||||||
pipeline = redis.pipeline()
|
pipeline = redis.pipeline()
|
||||||
|
|
||||||
# The 'testclient' host is used when requesting the app
|
# 'no-client' is our fallback value in case request.client is None
|
||||||
# via fastapi.testclient.TestClient.
|
# which is the case for TestClient
|
||||||
pipeline.delete("ratelimit-ws:testclient")
|
pipeline.delete("ratelimit-ws:no-client")
|
||||||
pipeline.delete("ratelimit:testclient")
|
pipeline.delete("ratelimit:no-client")
|
||||||
pipeline.execute()
|
pipeline.execute()
|
||||||
|
|
||||||
yield pipeline
|
yield pipeline
|
||||||
|
@ -760,8 +760,8 @@ def test_rpc_ratelimit(
|
||||||
assert response.status_code == int(HTTPStatus.TOO_MANY_REQUESTS)
|
assert response.status_code == int(HTTPStatus.TOO_MANY_REQUESTS)
|
||||||
|
|
||||||
# Delete the cached records.
|
# Delete the cached records.
|
||||||
pipeline.delete("ratelimit-ws:testclient")
|
pipeline.delete("ratelimit-ws:no-client")
|
||||||
pipeline.delete("ratelimit:testclient")
|
pipeline.delete("ratelimit:no-client")
|
||||||
one, two = pipeline.execute()
|
one, two = pipeline.execute()
|
||||||
assert one and two
|
assert one and two
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue