mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
feat(rpc): enforce ratelimiting
New configuration options: - `[ratelimit] cache` - A boolean indicating whether we should use configured cache (1) or database (0) for ratelimiting. Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
parent
6662975005
commit
65240c8343
5 changed files with 280 additions and 2 deletions
110
aurweb/ratelimit.py
Normal file
110
aurweb/ratelimit.py
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from redis.client import Pipeline
|
||||||
|
|
||||||
|
from aurweb import config, db, logging
|
||||||
|
from aurweb.models import ApiRateLimit
|
||||||
|
from aurweb.redis import redis_connection
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _update_ratelimit_redis(request: Request, pipeline: Pipeline):
|
||||||
|
window_length = config.getint("ratelimit", "window_length")
|
||||||
|
now = int(datetime.utcnow().timestamp())
|
||||||
|
time_to_delete = now - window_length
|
||||||
|
|
||||||
|
host = request.client.host
|
||||||
|
window_key = f"ratelimit-ws:{host}"
|
||||||
|
requests_key = f"ratelimit:{host}"
|
||||||
|
|
||||||
|
pipeline.get(window_key)
|
||||||
|
window = pipeline.execute()[0]
|
||||||
|
|
||||||
|
if not window or int(window.decode()) < time_to_delete:
|
||||||
|
pipeline.set(window_key, now)
|
||||||
|
pipeline.expire(window_key, window_length)
|
||||||
|
|
||||||
|
pipeline.set(requests_key, 1)
|
||||||
|
pipeline.expire(requests_key, window_length)
|
||||||
|
|
||||||
|
pipeline.execute()
|
||||||
|
else:
|
||||||
|
pipeline.incr(requests_key)
|
||||||
|
pipeline.execute()
|
||||||
|
|
||||||
|
|
||||||
|
def _update_ratelimit_db(request: Request):
|
||||||
|
window_length = config.getint("ratelimit", "window_length")
|
||||||
|
now = int(datetime.utcnow().timestamp())
|
||||||
|
time_to_delete = now - window_length
|
||||||
|
|
||||||
|
with db.begin():
|
||||||
|
db.delete(ApiRateLimit, ApiRateLimit.WindowStart < time_to_delete)
|
||||||
|
|
||||||
|
host = request.client.host
|
||||||
|
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
|
||||||
|
with db.begin():
|
||||||
|
if not record:
|
||||||
|
record = db.create(ApiRateLimit,
|
||||||
|
WindowStart=now,
|
||||||
|
IP=host, Requests=1)
|
||||||
|
else:
|
||||||
|
record.Requests += 1
|
||||||
|
|
||||||
|
logger.debug(record.Requests)
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
def update_ratelimit(request: Request, pipeline: Pipeline):
|
||||||
|
""" Update the ratelimit stored in Redis or the database depending
|
||||||
|
on AUR_CONFIG's [options] cache setting.
|
||||||
|
|
||||||
|
This Redis-capable function is slightly different than most. If Redis
|
||||||
|
is not configured to use a real server, this function instead uses
|
||||||
|
the database to persist tracking of a particular host.
|
||||||
|
|
||||||
|
:param request: FastAPI request
|
||||||
|
:param pipeline: redis.client.Pipeline
|
||||||
|
:returns: ApiRateLimit record when Redis cache is not configured, else None
|
||||||
|
"""
|
||||||
|
if config.getboolean("ratelimit", "cache"):
|
||||||
|
return _update_ratelimit_redis(request, pipeline)
|
||||||
|
return _update_ratelimit_db(request)
|
||||||
|
|
||||||
|
|
||||||
|
def check_ratelimit(request: Request):
|
||||||
|
""" Increment and check to see if request has exceeded their rate limit.
|
||||||
|
|
||||||
|
:param request: FastAPI request
|
||||||
|
:returns: True if the request host has exceeded the rate limit else False
|
||||||
|
"""
|
||||||
|
redis = redis_connection()
|
||||||
|
pipeline = redis.pipeline()
|
||||||
|
|
||||||
|
record = update_ratelimit(request, pipeline)
|
||||||
|
|
||||||
|
# Get cache value, else None.
|
||||||
|
host = request.client.host
|
||||||
|
pipeline.get(f"ratelimit:{host}")
|
||||||
|
requests = pipeline.execute()[0]
|
||||||
|
|
||||||
|
# Take into account the split paths. When Redis is used, a
|
||||||
|
# valid cache value will be returned which must be converted
|
||||||
|
# to an int. Otherwise, use the database record returned
|
||||||
|
# by update_ratelimit.
|
||||||
|
if not config.getboolean("ratelimit", "cache"):
|
||||||
|
# If we got nothing from pipeline.get, we did not use
|
||||||
|
# the Redis path of logic: use the DB record's count.
|
||||||
|
requests = record.Requests
|
||||||
|
else:
|
||||||
|
# Otherwise, just case Redis results over to an int.
|
||||||
|
requests = int(requests.decode())
|
||||||
|
|
||||||
|
limit = config.getint("ratelimit", "request_limit")
|
||||||
|
exceeded_ratelimit = requests > limit
|
||||||
|
if exceeded_ratelimit:
|
||||||
|
logger.debug(f"{host} has exceeded the ratelimit.")
|
||||||
|
|
||||||
|
return exceeded_ratelimit
|
|
@ -1,9 +1,11 @@
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from fastapi import APIRouter, Query, Request
|
from fastapi import APIRouter, Query, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from aurweb.ratelimit import check_ratelimit
|
||||||
from aurweb.rpc import RPC
|
from aurweb.rpc import RPC
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
@ -64,6 +66,11 @@ async def rpc(request: Request,
|
||||||
# Create a handle to our RPC class.
|
# Create a handle to our RPC class.
|
||||||
rpc = RPC(version=v, type=type)
|
rpc = RPC(version=v, type=type)
|
||||||
|
|
||||||
|
# If ratelimit was exceeded, return a 429 Too Many Requests.
|
||||||
|
if check_ratelimit(request):
|
||||||
|
return JSONResponse(rpc.error("Rate limit reached"),
|
||||||
|
status_code=int(HTTPStatus.TOO_MANY_REQUESTS))
|
||||||
|
|
||||||
# Prepare list of arguments for input. If 'arg' was given, it'll
|
# Prepare list of arguments for input. If 'arg' was given, it'll
|
||||||
# be a list with one element.
|
# be a list with one element.
|
||||||
arguments = parse_args(request)
|
arguments = parse_args(request)
|
||||||
|
|
|
@ -48,6 +48,11 @@ redis_address = redis://localhost
|
||||||
[ratelimit]
|
[ratelimit]
|
||||||
request_limit = 4000
|
request_limit = 4000
|
||||||
window_length = 86400
|
window_length = 86400
|
||||||
|
; Force-utilize cache for ratelimiting. In FastAPI, forced cache (1)
|
||||||
|
; will cause the ratelimit path to use a real or fake Redis instance
|
||||||
|
; depending on the configured options.cache setting. Otherwise,
|
||||||
|
; cache will be ignored and the database will be used.
|
||||||
|
cache = 1
|
||||||
|
|
||||||
[notifications]
|
[notifications]
|
||||||
notify-cmd = /usr/bin/aurweb-notify
|
notify-cmd = /usr/bin/aurweb-notify
|
||||||
|
|
109
test/test_ratelimit.py
Normal file
109
test/test_ratelimit.py
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from redis.client import Pipeline
|
||||||
|
|
||||||
|
from aurweb import config, db, logging
|
||||||
|
from aurweb.models import ApiRateLimit
|
||||||
|
from aurweb.ratelimit import check_ratelimit
|
||||||
|
from aurweb.redis import redis_connection
|
||||||
|
from aurweb.testing import setup_test_db
|
||||||
|
from aurweb.testing.requests import Request
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup():
|
||||||
|
setup_test_db(ApiRateLimit.__tablename__)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pipeline():
|
||||||
|
redis = redis_connection()
|
||||||
|
pipeline = redis.pipeline()
|
||||||
|
|
||||||
|
pipeline.delete("ratelimit-ws:127.0.0.1")
|
||||||
|
pipeline.delete("ratelimit:127.0.0.1")
|
||||||
|
pipeline.execute()
|
||||||
|
|
||||||
|
yield pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def mock_config_getint(section: str, key: str):
|
||||||
|
if key == "request_limit":
|
||||||
|
return 4
|
||||||
|
elif key == "window_length":
|
||||||
|
return 100
|
||||||
|
return config.getint(section, key)
|
||||||
|
|
||||||
|
|
||||||
|
def mock_config_getboolean(return_value: int = 0):
|
||||||
|
def fn(section: str, key: str):
|
||||||
|
if section == "ratelimit" and key == "cache":
|
||||||
|
return return_value
|
||||||
|
return config.getboolean(section, key)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def mock_config_get(return_value: str = "none"):
|
||||||
|
def fn(section: str, key: str):
|
||||||
|
if section == "options" and key == "cache":
|
||||||
|
return return_value
|
||||||
|
return config.get(section, key)
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("aurweb.config.getint", side_effect=mock_config_getint)
|
||||||
|
@mock.patch("aurweb.config.getboolean", side_effect=mock_config_getboolean(1))
|
||||||
|
@mock.patch("aurweb.config.get", side_effect=mock_config_get("none"))
|
||||||
|
def test_ratelimit_redis(get: mock.MagicMock, getboolean: mock.MagicMock,
|
||||||
|
getint: mock.MagicMock, pipeline: Pipeline):
|
||||||
|
""" This test will only cover aurweb.ratelimit's Redis
|
||||||
|
path if a real Redis server is configured. Otherwise,
|
||||||
|
it'll use the database. """
|
||||||
|
|
||||||
|
# We'll need a Request for everything here.
|
||||||
|
request = Request()
|
||||||
|
|
||||||
|
# Run check_ratelimit for our request_limit. These should succeed.
|
||||||
|
for i in range(4):
|
||||||
|
assert not check_ratelimit(request)
|
||||||
|
|
||||||
|
# This check_ratelimit should fail, being the 4001th request.
|
||||||
|
assert check_ratelimit(request)
|
||||||
|
|
||||||
|
# Delete the Redis keys.
|
||||||
|
host = request.client.host
|
||||||
|
pipeline.delete(f"ratelimit-ws:{host}")
|
||||||
|
pipeline.delete(f"ratelimit:{host}")
|
||||||
|
one, two = pipeline.execute()
|
||||||
|
assert one and two
|
||||||
|
|
||||||
|
# Should be good to go again!
|
||||||
|
assert not check_ratelimit(request)
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("aurweb.config.getint", side_effect=mock_config_getint)
|
||||||
|
@mock.patch("aurweb.config.getboolean", side_effect=mock_config_getboolean(0))
|
||||||
|
@mock.patch("aurweb.config.get", side_effect=mock_config_get("none"))
|
||||||
|
def test_ratelimit_db(get: mock.MagicMock, getboolean: mock.MagicMock,
|
||||||
|
getint: mock.MagicMock, pipeline: Pipeline):
|
||||||
|
|
||||||
|
# We'll need a Request for everything here.
|
||||||
|
request = Request()
|
||||||
|
|
||||||
|
# Run check_ratelimit for our request_limit. These should succeed.
|
||||||
|
for i in range(4):
|
||||||
|
assert not check_ratelimit(request)
|
||||||
|
|
||||||
|
# This check_ratelimit should fail, being the 4001th request.
|
||||||
|
assert check_ratelimit(request)
|
||||||
|
|
||||||
|
# Delete the ApiRateLimit record.
|
||||||
|
with db.begin():
|
||||||
|
db.delete(ApiRateLimit)
|
||||||
|
|
||||||
|
# Should be good to go again!
|
||||||
|
assert not check_ratelimit(request)
|
|
@ -1,9 +1,13 @@
|
||||||
|
from http import HTTPStatus
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
from redis.client import Pipeline
|
||||||
|
|
||||||
from aurweb import db, scripts
|
from aurweb import config, db, scripts
|
||||||
from aurweb.asgi import app
|
from aurweb.asgi import app
|
||||||
from aurweb.db import begin, create, query
|
from aurweb.db import begin, create, query
|
||||||
from aurweb.models.account_type import AccountType
|
from aurweb.models.account_type import AccountType
|
||||||
|
@ -18,6 +22,7 @@ from aurweb.models.package_relation import PackageRelation
|
||||||
from aurweb.models.package_vote import PackageVote
|
from aurweb.models.package_vote import PackageVote
|
||||||
from aurweb.models.relation_type import RelationType
|
from aurweb.models.relation_type import RelationType
|
||||||
from aurweb.models.user import User
|
from aurweb.models.user import User
|
||||||
|
from aurweb.redis import redis_connection
|
||||||
from aurweb.testing import setup_test_db
|
from aurweb.testing import setup_test_db
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,7 +36,7 @@ def setup():
|
||||||
# Set up tables.
|
# Set up tables.
|
||||||
setup_test_db("Users", "PackageBases", "Packages", "Licenses",
|
setup_test_db("Users", "PackageBases", "Packages", "Licenses",
|
||||||
"PackageDepends", "PackageRelations", "PackageLicenses",
|
"PackageDepends", "PackageRelations", "PackageLicenses",
|
||||||
"PackageKeywords", "PackageVotes")
|
"PackageKeywords", "PackageVotes", "ApiRateLimit")
|
||||||
|
|
||||||
# Create test package details.
|
# Create test package details.
|
||||||
with begin():
|
with begin():
|
||||||
|
@ -178,6 +183,18 @@ def setup():
|
||||||
scripts.popupdate.run_single(conn, pkgbase1)
|
scripts.popupdate.run_single(conn, pkgbase1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pipeline():
|
||||||
|
redis = redis_connection()
|
||||||
|
pipeline = redis.pipeline()
|
||||||
|
|
||||||
|
pipeline.delete("ratelimit-ws:testclient")
|
||||||
|
pipeline.delete("ratelimit:testclient")
|
||||||
|
one, two = pipeline.execute()
|
||||||
|
|
||||||
|
yield pipeline
|
||||||
|
|
||||||
|
|
||||||
def test_rpc_singular_info():
|
def test_rpc_singular_info():
|
||||||
# Define expected response.
|
# Define expected response.
|
||||||
expected_data = {
|
expected_data = {
|
||||||
|
@ -441,3 +458,33 @@ def test_rpc_unimplemented_types():
|
||||||
data = response.json()
|
data = response.json()
|
||||||
expected = f"Request type '{type}' is not yet implemented."
|
expected = f"Request type '{type}' is not yet implemented."
|
||||||
assert data.get("error") == expected
|
assert data.get("error") == expected
|
||||||
|
|
||||||
|
|
||||||
|
def mock_config_getint(section: str, key: str):
|
||||||
|
if key == "request_limit":
|
||||||
|
return 4
|
||||||
|
elif key == "window_length":
|
||||||
|
return 100
|
||||||
|
return config.getint(section, key)
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("aurweb.config.getint", side_effect=mock_config_getint)
|
||||||
|
def test_rpc_ratelimit(getint: mock.MagicMock, pipeline: Pipeline):
|
||||||
|
for i in range(4):
|
||||||
|
# The first 4 requests should be good.
|
||||||
|
response = make_request("/rpc?v=5&type=suggest-pkgbase&arg=big")
|
||||||
|
assert response.status_code == int(HTTPStatus.OK)
|
||||||
|
|
||||||
|
# The fifth request should be banned.
|
||||||
|
response = make_request("/rpc?v=5&type=suggest-pkgbase&arg=big")
|
||||||
|
assert response.status_code == int(HTTPStatus.TOO_MANY_REQUESTS)
|
||||||
|
|
||||||
|
# Delete the cached records.
|
||||||
|
pipeline.delete("ratelimit-ws:testclient")
|
||||||
|
pipeline.delete("ratelimit:testclient")
|
||||||
|
one, two = pipeline.execute()
|
||||||
|
assert one and two
|
||||||
|
|
||||||
|
# The new first request should be good.
|
||||||
|
response = make_request("/rpc?v=5&type=suggest-pkgbase&arg=big")
|
||||||
|
assert response.status_code == int(HTTPStatus.OK)
|
||||||
|
|
Loading…
Add table
Reference in a new issue