aurweb/aurweb/ratelimit.py
Leonidas Spyropoulos 9c0f8f053e
chore: rename logging.py and redis.py to avoid circular imports
Signed-off-by: Leonidas Spyropoulos <artafinde@archlinux.org>
2022-10-22 18:51:38 +01:00

116 lines
3.8 KiB
Python

from fastapi import Request
from redis.client import Pipeline
from aurweb import aur_logging, config, db, time
from aurweb.aur_redis import redis_connection
from aurweb.models import ApiRateLimit
logger = aur_logging.get_logger(__name__)
def _update_ratelimit_redis(request: Request, pipeline: Pipeline):
window_length = config.getint("ratelimit", "window_length")
now = time.utcnow()
time_to_delete = now - window_length
host = request.client.host
window_key = f"ratelimit-ws:{host}"
requests_key = f"ratelimit:{host}"
pipeline.get(window_key)
window = pipeline.execute()[0]
if not window or int(window.decode()) < time_to_delete:
pipeline.set(window_key, now)
pipeline.expire(window_key, window_length)
pipeline.set(requests_key, 1)
pipeline.expire(requests_key, window_length)
pipeline.execute()
else:
pipeline.incr(requests_key)
pipeline.execute()
def _update_ratelimit_db(request: Request):
window_length = config.getint("ratelimit", "window_length")
now = time.utcnow()
time_to_delete = now - window_length
@db.retry_deadlock
def retry_delete(records: list[ApiRateLimit]) -> None:
with db.begin():
db.delete_all(records)
records = db.query(ApiRateLimit).filter(ApiRateLimit.WindowStart < time_to_delete)
retry_delete(records)
@db.retry_deadlock
def retry_create(record: ApiRateLimit, now: int, host: str) -> ApiRateLimit:
with db.begin():
if not record:
record = db.create(ApiRateLimit, WindowStart=now, IP=host, Requests=1)
else:
record.Requests += 1
return record
host = request.client.host
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
record = retry_create(record, now, host)
logger.debug(record.Requests)
return record
def update_ratelimit(request: Request, pipeline: Pipeline):
"""Update the ratelimit stored in Redis or the database depending
on AUR_CONFIG's [options] cache setting.
This Redis-capable function is slightly different than most. If Redis
is not configured to use a real server, this function instead uses
the database to persist tracking of a particular host.
:param request: FastAPI request
:param pipeline: redis.client.Pipeline
:returns: ApiRateLimit record when Redis cache is not configured, else None
"""
if config.getboolean("ratelimit", "cache"):
return _update_ratelimit_redis(request, pipeline)
return _update_ratelimit_db(request)
def check_ratelimit(request: Request):
"""Increment and check to see if request has exceeded their rate limit.
:param request: FastAPI request
:returns: True if the request host has exceeded the rate limit else False
"""
redis = redis_connection()
pipeline = redis.pipeline()
record = update_ratelimit(request, pipeline)
# Get cache value, else None.
host = request.client.host
pipeline.get(f"ratelimit:{host}")
requests = pipeline.execute()[0]
# Take into account the split paths. When Redis is used, a
# valid cache value will be returned which must be converted
# to an int. Otherwise, use the database record returned
# by update_ratelimit.
if not config.getboolean("ratelimit", "cache") or requests is None:
# If we got nothing from pipeline.get, we did not use
# the Redis path of logic: use the DB record's count.
requests = record.Requests
else:
# Otherwise, just case Redis results over to an int.
requests = int(requests.decode())
limit = config.getint("ratelimit", "request_limit")
exceeded_ratelimit = requests > limit
if exceeded_ratelimit:
logger.debug(f"{host} has exceeded the ratelimit.")
return exceeded_ratelimit