From 65240c8343eb01401a5ed44185d4201fc8739721 Mon Sep 17 00:00:00 2001
From: Kevin Morris
Date: Thu, 21 Oct 2021 10:17:34 -0700
Subject: [PATCH] feat(rpc): enforce ratelimiting

New configuration option:

- `[ratelimit] cache` - A boolean indicating whether we should use the
  configured cache (1) or the database (0) for ratelimiting.

Signed-off-by: Kevin Morris
---
 aurweb/ratelimit.py    | 110 +++++++++++++++++++++++++++++++++++++++++
 aurweb/routers/rpc.py  |   7 +++
 conf/config.defaults   |   5 ++
 test/test_ratelimit.py | 109 ++++++++++++++++++++++++++++++++++++++++
 test/test_rpc.py       |  51 ++++++++++++++++++-
 5 files changed, 280 insertions(+), 2 deletions(-)
 create mode 100644 aurweb/ratelimit.py
 create mode 100644 test/test_ratelimit.py

diff --git a/aurweb/ratelimit.py b/aurweb/ratelimit.py
new file mode 100644
index 00000000..e306f7a7
--- /dev/null
+++ b/aurweb/ratelimit.py
@@ -0,0 +1,110 @@
+from datetime import datetime
+
+from fastapi import Request
+from redis.client import Pipeline
+
+from aurweb import config, db, logging
+from aurweb.models import ApiRateLimit
+from aurweb.redis import redis_connection
+
+logger = logging.get_logger(__name__)
+
+
+def _update_ratelimit_redis(request: Request, pipeline: Pipeline):
+    window_length = config.getint("ratelimit", "window_length")
+    now = int(datetime.utcnow().timestamp())
+    time_to_delete = now - window_length
+
+    host = request.client.host
+    window_key = f"ratelimit-ws:{host}"
+    requests_key = f"ratelimit:{host}"
+
+    pipeline.get(window_key)
+    window = pipeline.execute()[0]
+
+    if not window or int(window.decode()) < time_to_delete:
+        pipeline.set(window_key, now)
+        pipeline.expire(window_key, window_length)
+
+        pipeline.set(requests_key, 1)
+        pipeline.expire(requests_key, window_length)
+
+        pipeline.execute()
+    else:
+        pipeline.incr(requests_key)
+        pipeline.execute()
+
+
+def _update_ratelimit_db(request: Request):
+    window_length = config.getint("ratelimit", "window_length")
+    now = int(datetime.utcnow().timestamp())
+    time_to_delete = now - window_length
+
+    with db.begin():
+        db.delete(ApiRateLimit, ApiRateLimit.WindowStart < time_to_delete)
+
+    host = request.client.host
+    record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
+    with db.begin():
+        if not record:
+            record = db.create(ApiRateLimit,
+                               WindowStart=now,
+                               IP=host, Requests=1)
+        else:
+            record.Requests += 1
+
+    logger.debug(record.Requests)
+    return record
+
+
+def update_ratelimit(request: Request, pipeline: Pipeline):
+    """ Update the ratelimit stored in Redis or the database depending
+    on the [ratelimit] cache setting.
+
+    When the cache is enabled, request tracking lives in Redis (real or
+    fake, per the [options] cache setting); otherwise, tracking for a
+    particular host is persisted in the ApiRateLimit database table.
+
+    :param request: FastAPI request
+    :param pipeline: redis.client.Pipeline
+    :returns: ApiRateLimit record when Redis cache is not configured, else None
+    """
+    if config.getboolean("ratelimit", "cache"):
+        return _update_ratelimit_redis(request, pipeline)
+    return _update_ratelimit_db(request)
+
+
+def check_ratelimit(request: Request):
+    """ Increment and check whether the host has exceeded its rate limit.
+
+    :param request: FastAPI request
+    :returns: True if the request host has exceeded the rate limit, else False
+    """
+    redis = redis_connection()
+    pipeline = redis.pipeline()
+
+    record = update_ratelimit(request, pipeline)
+
+    # Get the cached request count for this host, else None.
+    host = request.client.host
+    pipeline.get(f"ratelimit:{host}")
+    requests = pipeline.execute()[0]
+
+    # Account for the two paths. When the Redis cache is used, a
+    # valid cached value is returned and must be converted to an
+    # int. Otherwise, use the database record returned by
+    # update_ratelimit.
+    if not config.getboolean("ratelimit", "cache"):
+        # The cache is disabled, so the Redis value is unused:
+        # take the count from the database record instead.
+        requests = record.Requests
+    else:
+        # Otherwise, just cast the Redis result to an int.
+        requests = int(requests.decode())
+
+    limit = config.getint("ratelimit", "request_limit")
+    exceeded_ratelimit = requests > limit
+    if exceeded_ratelimit:
+        logger.debug(f"{host} has exceeded the ratelimit.")
+
+    return exceeded_ratelimit
diff --git a/aurweb/routers/rpc.py b/aurweb/routers/rpc.py
index feb355c8..0616326b 100644
--- a/aurweb/routers/rpc.py
+++ b/aurweb/routers/rpc.py
@@ -1,9 +1,11 @@
+from http import HTTPStatus
 from typing import List, Optional
 from urllib.parse import unquote
 
 from fastapi import APIRouter, Query, Request
 from fastapi.responses import JSONResponse
 
+from aurweb.ratelimit import check_ratelimit
 from aurweb.rpc import RPC
 
 router = APIRouter()
@@ -64,6 +66,11 @@ async def rpc(request: Request,
     # Create a handle to our RPC class.
     rpc = RPC(version=v, type=type)
 
+    # If the ratelimit was exceeded, return a 429 Too Many Requests.
+    if check_ratelimit(request):
+        return JSONResponse(rpc.error("Rate limit reached"),
+                            status_code=int(HTTPStatus.TOO_MANY_REQUESTS))
+
     # Prepare list of arguments for input. If 'arg' was given, it'll
     # be a list with one element.
     arguments = parse_args(request)
diff --git a/conf/config.defaults b/conf/config.defaults
index 988859a0..da969343 100644
--- a/conf/config.defaults
+++ b/conf/config.defaults
@@ -48,6 +48,11 @@ redis_address = redis://localhost
 [ratelimit]
 request_limit = 4000
 window_length = 86400
+; Use the cache for ratelimiting. When enabled (1), the FastAPI
+; ratelimit path uses a real or fake Redis instance, depending on
+; the configured [options] cache setting. When disabled (0), the
+; cache is ignored and the database is used instead.
+cache = 1
 
 [notifications]
 notify-cmd = /usr/bin/aurweb-notify
diff --git a/test/test_ratelimit.py b/test/test_ratelimit.py
new file mode 100644
index 00000000..2634b714
--- /dev/null
+++ b/test/test_ratelimit.py
@@ -0,0 +1,109 @@
+from unittest import mock
+
+import pytest
+
+from redis.client import Pipeline
+
+from aurweb import config, db, logging
+from aurweb.models import ApiRateLimit
+from aurweb.ratelimit import check_ratelimit
+from aurweb.redis import redis_connection
+from aurweb.testing import setup_test_db
+from aurweb.testing.requests import Request
+
+logger = logging.get_logger(__name__)
+
+
+@pytest.fixture(autouse=True)
+def setup():
+    setup_test_db(ApiRateLimit.__tablename__)
+
+
+@pytest.fixture
+def pipeline():
+    redis = redis_connection()
+    pipeline = redis.pipeline()
+
+    pipeline.delete("ratelimit-ws:127.0.0.1")
+    pipeline.delete("ratelimit:127.0.0.1")
+    pipeline.execute()
+
+    yield pipeline
+
+
+def mock_config_getint(section: str, key: str):
+    if key == "request_limit":
+        return 4
+    elif key == "window_length":
+        return 100
+    return config.getint(section, key)
+
+
+def mock_config_getboolean(return_value: int = 0):
+    def fn(section: str, key: str):
+        if section == "ratelimit" and key == "cache":
+            return return_value
+        return config.getboolean(section, key)
+    return fn
+
+
+def mock_config_get(return_value: str = "none"):
+    def fn(section: str, key: str):
+        if section == "options" and key == "cache":
+            return return_value
+        return config.get(section, key)
+    return fn
+
+
+@mock.patch("aurweb.config.getint", side_effect=mock_config_getint)
+@mock.patch("aurweb.config.getboolean", side_effect=mock_config_getboolean(1))
+@mock.patch("aurweb.config.get", side_effect=mock_config_get("none"))
+def test_ratelimit_redis(get: mock.MagicMock, getboolean: mock.MagicMock,
+                         getint: mock.MagicMock, pipeline: Pipeline):
+    """ This test covers aurweb.ratelimit's Redis path. With the
+    [options] cache setting mocked to "none", a fake Redis instance
+    is used in place of a real server. """
+
+    # We'll need a Request for everything here.
+    request = Request()
+
+    # Run check_ratelimit up to our request_limit. These should succeed.
+    for i in range(4):
+        assert not check_ratelimit(request)
+
+    # This check_ratelimit should fail, being the fifth request.
+    assert check_ratelimit(request)
+
+    # Delete the Redis keys.
+    host = request.client.host
+    pipeline.delete(f"ratelimit-ws:{host}")
+    pipeline.delete(f"ratelimit:{host}")
+    one, two = pipeline.execute()
+    assert one and two
+
+    # Should be good to go again!
+    assert not check_ratelimit(request)
+
+
+@mock.patch("aurweb.config.getint", side_effect=mock_config_getint)
+@mock.patch("aurweb.config.getboolean", side_effect=mock_config_getboolean(0))
+@mock.patch("aurweb.config.get", side_effect=mock_config_get("none"))
+def test_ratelimit_db(get: mock.MagicMock, getboolean: mock.MagicMock,
+                      getint: mock.MagicMock, pipeline: Pipeline):
+
+    # We'll need a Request for everything here.
+    request = Request()
+
+    # Run check_ratelimit up to our request_limit. These should succeed.
+    for i in range(4):
+        assert not check_ratelimit(request)
+
+    # This check_ratelimit should fail, being the fifth request.
+    assert check_ratelimit(request)
+
+    # Delete the ApiRateLimit record.
+    with db.begin():
+        db.delete(ApiRateLimit)
+
+    # Should be good to go again!
+    assert not check_ratelimit(request)
diff --git a/test/test_rpc.py b/test/test_rpc.py
index 38cee0eb..9400ee06 100644
--- a/test/test_rpc.py
+++ b/test/test_rpc.py
@@ -1,9 +1,13 @@
+from http import HTTPStatus
+from unittest import mock
+
 import orjson
 import pytest
 
 from fastapi.testclient import TestClient
+from redis.client import Pipeline
 
-from aurweb import db, scripts
+from aurweb import config, db, scripts
 from aurweb.asgi import app
 from aurweb.db import begin, create, query
 from aurweb.models.account_type import AccountType
@@ -18,6 +22,7 @@ from aurweb.models.package_relation import PackageRelation
 from aurweb.models.package_vote import PackageVote
 from aurweb.models.relation_type import RelationType
 from aurweb.models.user import User
+from aurweb.redis import redis_connection
 from aurweb.testing import setup_test_db
 
 
@@ -31,7 +36,7 @@ def setup():
     # Set up tables.
     setup_test_db("Users", "PackageBases", "Packages", "Licenses",
                   "PackageDepends", "PackageRelations", "PackageLicenses",
-                  "PackageKeywords", "PackageVotes")
+                  "PackageKeywords", "PackageVotes", "ApiRateLimit")
 
     # Create test package details.
     with begin():
@@ -178,6 +183,18 @@
     scripts.popupdate.run_single(conn, pkgbase1)
 
 
+@pytest.fixture
+def pipeline():
+    redis = redis_connection()
+    pipeline = redis.pipeline()
+
+    pipeline.delete("ratelimit-ws:testclient")
+    pipeline.delete("ratelimit:testclient")
+    pipeline.execute()
+
+    yield pipeline
+
+
 def test_rpc_singular_info():
     # Define expected response.
     expected_data = {
@@ -441,3 +458,33 @@ def test_rpc_unimplemented_types():
     data = response.json()
     expected = f"Request type '{type}' is not yet implemented."
     assert data.get("error") == expected
+
+
+def mock_config_getint(section: str, key: str):
+    if key == "request_limit":
+        return 4
+    elif key == "window_length":
+        return 100
+    return config.getint(section, key)
+
+
+@mock.patch("aurweb.config.getint", side_effect=mock_config_getint)
+def test_rpc_ratelimit(getint: mock.MagicMock, pipeline: Pipeline):
+    for i in range(4):
+        # The first 4 requests should be good.
+        response = make_request("/rpc?v=5&type=suggest-pkgbase&arg=big")
+        assert response.status_code == int(HTTPStatus.OK)
+
+    # The fifth request should be rejected with 429 Too Many Requests.
+    response = make_request("/rpc?v=5&type=suggest-pkgbase&arg=big")
+    assert response.status_code == int(HTTPStatus.TOO_MANY_REQUESTS)
+
+    # Delete the cached ratelimit records.
+    pipeline.delete("ratelimit-ws:testclient")
+    pipeline.delete("ratelimit:testclient")
+    one, two = pipeline.execute()
+    assert one and two
+
+    # With the window reset, the first request should be good again.
+    response = make_request("/rpc?v=5&type=suggest-pkgbase&arg=big")
+    assert response.status_code == int(HTTPStatus.OK)
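
For reference, the window logic added in aurweb/ratelimit.py amounts to a fixed-window counter
keyed on the client host: a window-start timestamp plus a request count, both reset once
window_length seconds elapse, with requests rejected once the count passes request_limit. The
sketch below is a minimal, self-contained illustration of that scheme; it uses an in-memory dict
in place of Redis or the ApiRateLimit table, and the RateLimiter class here is purely
illustrative, not part of this patch.

import time


class RateLimiter:
    """ Illustrative fixed-window limiter; not part of aurweb. """

    def __init__(self, request_limit: int = 4000, window_length: int = 86400):
        self.request_limit = request_limit
        self.window_length = window_length
        self.state = {}  # host -> (window_start, request_count)

    def check(self, host: str) -> bool:
        """ Record one request; return True if `host` exceeded the limit. """
        now = int(time.time())
        window_start, requests = self.state.get(host, (now, 0))

        # Once the window expires, start a fresh one; this mirrors the
        # window key expiry / stale ApiRateLimit row deletion in the patch.
        if now - window_start >= self.window_length:
            window_start, requests = now, 0

        requests += 1
        self.state[host] = (window_start, requests)
        return requests > self.request_limit


# Mirrors the values mocked in the tests: a limit of 4 requests per 100 seconds.
limiter = RateLimiter(request_limit=4, window_length=100)
assert not any(limiter.check("127.0.0.1") for _ in range(4))  # first four pass
assert limiter.check("127.0.0.1")  # the fifth request exceeds the limit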