mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
feat(rpc): enforce ratelimiting
New configuration options: - `[ratelimit] cache` - A boolean indicating whether we should use configured cache (1) or database (0) for ratelimiting. Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
parent
6662975005
commit
65240c8343
5 changed files with 280 additions and 2 deletions
110
aurweb/ratelimit.py
Normal file
110
aurweb/ratelimit.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
from datetime import datetime
|
||||
|
||||
from fastapi import Request
|
||||
from redis.client import Pipeline
|
||||
|
||||
from aurweb import config, db, logging
|
||||
from aurweb.models import ApiRateLimit
|
||||
from aurweb.redis import redis_connection
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _update_ratelimit_redis(request: Request, pipeline: Pipeline):
|
||||
window_length = config.getint("ratelimit", "window_length")
|
||||
now = int(datetime.utcnow().timestamp())
|
||||
time_to_delete = now - window_length
|
||||
|
||||
host = request.client.host
|
||||
window_key = f"ratelimit-ws:{host}"
|
||||
requests_key = f"ratelimit:{host}"
|
||||
|
||||
pipeline.get(window_key)
|
||||
window = pipeline.execute()[0]
|
||||
|
||||
if not window or int(window.decode()) < time_to_delete:
|
||||
pipeline.set(window_key, now)
|
||||
pipeline.expire(window_key, window_length)
|
||||
|
||||
pipeline.set(requests_key, 1)
|
||||
pipeline.expire(requests_key, window_length)
|
||||
|
||||
pipeline.execute()
|
||||
else:
|
||||
pipeline.incr(requests_key)
|
||||
pipeline.execute()
|
||||
|
||||
|
||||
def _update_ratelimit_db(request: Request):
|
||||
window_length = config.getint("ratelimit", "window_length")
|
||||
now = int(datetime.utcnow().timestamp())
|
||||
time_to_delete = now - window_length
|
||||
|
||||
with db.begin():
|
||||
db.delete(ApiRateLimit, ApiRateLimit.WindowStart < time_to_delete)
|
||||
|
||||
host = request.client.host
|
||||
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
|
||||
with db.begin():
|
||||
if not record:
|
||||
record = db.create(ApiRateLimit,
|
||||
WindowStart=now,
|
||||
IP=host, Requests=1)
|
||||
else:
|
||||
record.Requests += 1
|
||||
|
||||
logger.debug(record.Requests)
|
||||
return record
|
||||
|
||||
|
||||
def update_ratelimit(request: Request, pipeline: Pipeline):
|
||||
""" Update the ratelimit stored in Redis or the database depending
|
||||
on AUR_CONFIG's [options] cache setting.
|
||||
|
||||
This Redis-capable function is slightly different than most. If Redis
|
||||
is not configured to use a real server, this function instead uses
|
||||
the database to persist tracking of a particular host.
|
||||
|
||||
:param request: FastAPI request
|
||||
:param pipeline: redis.client.Pipeline
|
||||
:returns: ApiRateLimit record when Redis cache is not configured, else None
|
||||
"""
|
||||
if config.getboolean("ratelimit", "cache"):
|
||||
return _update_ratelimit_redis(request, pipeline)
|
||||
return _update_ratelimit_db(request)
|
||||
|
||||
|
||||
def check_ratelimit(request: Request):
|
||||
""" Increment and check to see if request has exceeded their rate limit.
|
||||
|
||||
:param request: FastAPI request
|
||||
:returns: True if the request host has exceeded the rate limit else False
|
||||
"""
|
||||
redis = redis_connection()
|
||||
pipeline = redis.pipeline()
|
||||
|
||||
record = update_ratelimit(request, pipeline)
|
||||
|
||||
# Get cache value, else None.
|
||||
host = request.client.host
|
||||
pipeline.get(f"ratelimit:{host}")
|
||||
requests = pipeline.execute()[0]
|
||||
|
||||
# Take into account the split paths. When Redis is used, a
|
||||
# valid cache value will be returned which must be converted
|
||||
# to an int. Otherwise, use the database record returned
|
||||
# by update_ratelimit.
|
||||
if not config.getboolean("ratelimit", "cache"):
|
||||
# If we got nothing from pipeline.get, we did not use
|
||||
# the Redis path of logic: use the DB record's count.
|
||||
requests = record.Requests
|
||||
else:
|
||||
# Otherwise, just case Redis results over to an int.
|
||||
requests = int(requests.decode())
|
||||
|
||||
limit = config.getint("ratelimit", "request_limit")
|
||||
exceeded_ratelimit = requests > limit
|
||||
if exceeded_ratelimit:
|
||||
logger.debug(f"{host} has exceeded the ratelimit.")
|
||||
|
||||
return exceeded_ratelimit
|
|
@ -1,9 +1,11 @@
|
|||
from http import HTTPStatus
|
||||
from typing import List, Optional
|
||||
from urllib.parse import unquote
|
||||
|
||||
from fastapi import APIRouter, Query, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from aurweb.ratelimit import check_ratelimit
|
||||
from aurweb.rpc import RPC
|
||||
|
||||
router = APIRouter()
|
||||
|
@ -64,6 +66,11 @@ async def rpc(request: Request,
|
|||
# Create a handle to our RPC class.
|
||||
rpc = RPC(version=v, type=type)
|
||||
|
||||
# If ratelimit was exceeded, return a 429 Too Many Requests.
|
||||
if check_ratelimit(request):
|
||||
return JSONResponse(rpc.error("Rate limit reached"),
|
||||
status_code=int(HTTPStatus.TOO_MANY_REQUESTS))
|
||||
|
||||
# Prepare list of arguments for input. If 'arg' was given, it'll
|
||||
# be a list with one element.
|
||||
arguments = parse_args(request)
|
||||
|
|
|
@ -48,6 +48,11 @@ redis_address = redis://localhost
|
|||
[ratelimit]
|
||||
request_limit = 4000
|
||||
window_length = 86400
|
||||
; Force-utilize cache for ratelimiting. In FastAPI, forced cache (1)
|
||||
; will cause the ratelimit path to use a real or fake Redis instance
|
||||
; depending on the configured options.cache setting. Otherwise,
|
||||
; cache will be ignored and the database will be used.
|
||||
cache = 1
|
||||
|
||||
[notifications]
|
||||
notify-cmd = /usr/bin/aurweb-notify
|
||||
|
|
109
test/test_ratelimit.py
Normal file
109
test/test_ratelimit.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from redis.client import Pipeline
|
||||
|
||||
from aurweb import config, db, logging
|
||||
from aurweb.models import ApiRateLimit
|
||||
from aurweb.ratelimit import check_ratelimit
|
||||
from aurweb.redis import redis_connection
|
||||
from aurweb.testing import setup_test_db
|
||||
from aurweb.testing.requests import Request
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup():
|
||||
setup_test_db(ApiRateLimit.__tablename__)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline():
|
||||
redis = redis_connection()
|
||||
pipeline = redis.pipeline()
|
||||
|
||||
pipeline.delete("ratelimit-ws:127.0.0.1")
|
||||
pipeline.delete("ratelimit:127.0.0.1")
|
||||
pipeline.execute()
|
||||
|
||||
yield pipeline
|
||||
|
||||
|
||||
def mock_config_getint(section: str, key: str):
|
||||
if key == "request_limit":
|
||||
return 4
|
||||
elif key == "window_length":
|
||||
return 100
|
||||
return config.getint(section, key)
|
||||
|
||||
|
||||
def mock_config_getboolean(return_value: int = 0):
|
||||
def fn(section: str, key: str):
|
||||
if section == "ratelimit" and key == "cache":
|
||||
return return_value
|
||||
return config.getboolean(section, key)
|
||||
return fn
|
||||
|
||||
|
||||
def mock_config_get(return_value: str = "none"):
|
||||
def fn(section: str, key: str):
|
||||
if section == "options" and key == "cache":
|
||||
return return_value
|
||||
return config.get(section, key)
|
||||
return fn
|
||||
|
||||
|
||||
@mock.patch("aurweb.config.getint", side_effect=mock_config_getint)
|
||||
@mock.patch("aurweb.config.getboolean", side_effect=mock_config_getboolean(1))
|
||||
@mock.patch("aurweb.config.get", side_effect=mock_config_get("none"))
|
||||
def test_ratelimit_redis(get: mock.MagicMock, getboolean: mock.MagicMock,
|
||||
getint: mock.MagicMock, pipeline: Pipeline):
|
||||
""" This test will only cover aurweb.ratelimit's Redis
|
||||
path if a real Redis server is configured. Otherwise,
|
||||
it'll use the database. """
|
||||
|
||||
# We'll need a Request for everything here.
|
||||
request = Request()
|
||||
|
||||
# Run check_ratelimit for our request_limit. These should succeed.
|
||||
for i in range(4):
|
||||
assert not check_ratelimit(request)
|
||||
|
||||
# This check_ratelimit should fail, being the 4001th request.
|
||||
assert check_ratelimit(request)
|
||||
|
||||
# Delete the Redis keys.
|
||||
host = request.client.host
|
||||
pipeline.delete(f"ratelimit-ws:{host}")
|
||||
pipeline.delete(f"ratelimit:{host}")
|
||||
one, two = pipeline.execute()
|
||||
assert one and two
|
||||
|
||||
# Should be good to go again!
|
||||
assert not check_ratelimit(request)
|
||||
|
||||
|
||||
@mock.patch("aurweb.config.getint", side_effect=mock_config_getint)
|
||||
@mock.patch("aurweb.config.getboolean", side_effect=mock_config_getboolean(0))
|
||||
@mock.patch("aurweb.config.get", side_effect=mock_config_get("none"))
|
||||
def test_ratelimit_db(get: mock.MagicMock, getboolean: mock.MagicMock,
|
||||
getint: mock.MagicMock, pipeline: Pipeline):
|
||||
|
||||
# We'll need a Request for everything here.
|
||||
request = Request()
|
||||
|
||||
# Run check_ratelimit for our request_limit. These should succeed.
|
||||
for i in range(4):
|
||||
assert not check_ratelimit(request)
|
||||
|
||||
# This check_ratelimit should fail, being the 4001th request.
|
||||
assert check_ratelimit(request)
|
||||
|
||||
# Delete the ApiRateLimit record.
|
||||
with db.begin():
|
||||
db.delete(ApiRateLimit)
|
||||
|
||||
# Should be good to go again!
|
||||
assert not check_ratelimit(request)
|
|
@ -1,9 +1,13 @@
|
|||
from http import HTTPStatus
|
||||
from unittest import mock
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from redis.client import Pipeline
|
||||
|
||||
from aurweb import db, scripts
|
||||
from aurweb import config, db, scripts
|
||||
from aurweb.asgi import app
|
||||
from aurweb.db import begin, create, query
|
||||
from aurweb.models.account_type import AccountType
|
||||
|
@ -18,6 +22,7 @@ from aurweb.models.package_relation import PackageRelation
|
|||
from aurweb.models.package_vote import PackageVote
|
||||
from aurweb.models.relation_type import RelationType
|
||||
from aurweb.models.user import User
|
||||
from aurweb.redis import redis_connection
|
||||
from aurweb.testing import setup_test_db
|
||||
|
||||
|
||||
|
@ -31,7 +36,7 @@ def setup():
|
|||
# Set up tables.
|
||||
setup_test_db("Users", "PackageBases", "Packages", "Licenses",
|
||||
"PackageDepends", "PackageRelations", "PackageLicenses",
|
||||
"PackageKeywords", "PackageVotes")
|
||||
"PackageKeywords", "PackageVotes", "ApiRateLimit")
|
||||
|
||||
# Create test package details.
|
||||
with begin():
|
||||
|
@ -178,6 +183,18 @@ def setup():
|
|||
scripts.popupdate.run_single(conn, pkgbase1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline():
|
||||
redis = redis_connection()
|
||||
pipeline = redis.pipeline()
|
||||
|
||||
pipeline.delete("ratelimit-ws:testclient")
|
||||
pipeline.delete("ratelimit:testclient")
|
||||
one, two = pipeline.execute()
|
||||
|
||||
yield pipeline
|
||||
|
||||
|
||||
def test_rpc_singular_info():
|
||||
# Define expected response.
|
||||
expected_data = {
|
||||
|
@ -441,3 +458,33 @@ def test_rpc_unimplemented_types():
|
|||
data = response.json()
|
||||
expected = f"Request type '{type}' is not yet implemented."
|
||||
assert data.get("error") == expected
|
||||
|
||||
|
||||
def mock_config_getint(section: str, key: str):
|
||||
if key == "request_limit":
|
||||
return 4
|
||||
elif key == "window_length":
|
||||
return 100
|
||||
return config.getint(section, key)
|
||||
|
||||
|
||||
@mock.patch("aurweb.config.getint", side_effect=mock_config_getint)
|
||||
def test_rpc_ratelimit(getint: mock.MagicMock, pipeline: Pipeline):
|
||||
for i in range(4):
|
||||
# The first 4 requests should be good.
|
||||
response = make_request("/rpc?v=5&type=suggest-pkgbase&arg=big")
|
||||
assert response.status_code == int(HTTPStatus.OK)
|
||||
|
||||
# The fifth request should be banned.
|
||||
response = make_request("/rpc?v=5&type=suggest-pkgbase&arg=big")
|
||||
assert response.status_code == int(HTTPStatus.TOO_MANY_REQUESTS)
|
||||
|
||||
# Delete the cached records.
|
||||
pipeline.delete("ratelimit-ws:testclient")
|
||||
pipeline.delete("ratelimit:testclient")
|
||||
one, two = pipeline.execute()
|
||||
assert one and two
|
||||
|
||||
# The new first request should be good.
|
||||
response = make_request("/rpc?v=5&type=suggest-pkgbase&arg=big")
|
||||
assert response.status_code == int(HTTPStatus.OK)
|
||||
|
|
Loading…
Add table
Reference in a new issue