diff --git a/aurweb/cache.py b/aurweb/cache.py
index 1572e2fc..56bb45b7 100644
--- a/aurweb/cache.py
+++ b/aurweb/cache.py
@@ -1,21 +1,49 @@
-from redis import Redis
+import pickle
+
 from sqlalchemy import orm
 
+from aurweb import config
+from aurweb.aur_redis import redis_connection
 
-async def db_count_cache(
-    redis: Redis, key: str, query: orm.Query, expire: int = None
-) -> int:
+_redis = redis_connection()
+
+
+async def db_count_cache(key: str, query: orm.Query, expire: int = None) -> int:
     """Store and retrieve a query.count() via redis cache.
 
-    :param redis: Redis handle
     :param key: Redis key
     :param query: SQLAlchemy ORM query
     :param expire: Optional expiration in seconds
     :return: query.count()
     """
-    result = redis.get(key)
+    result = _redis.get(key)
     if result is None:
-        redis.set(key, (result := int(query.count())))
+        _redis.set(key, (result := int(query.count())))
         if expire:
-            redis.expire(key, expire)
+            _redis.expire(key, expire)
     return int(result)
+
+
+async def db_query_cache(key: str, query: orm.Query, expire: int = None):
+    """Store and retrieve query results via redis cache.
+
+    Results are pickled before being stored. Once the number of keys in
+    the redis database exceeds the "max_search_entries" option of the
+    "cache" config section, the query is run directly and nothing is cached.
+
+    :param key: Redis key
+    :param query: SQLAlchemy ORM query
+    :param expire: Optional expiration in seconds
+    :return: query.all()
+    """
+    result = _redis.get(key)
+    if result is None:
+        if _redis.dbsize() > config.getint("cache", "max_search_entries", 50000):
+            return query.all()
+        # `ex` sets the expiry together with the value, so no separate
+        # expire() call is needed; a falsy `expire` means no expiration.
+        _redis.set(key, (result := pickle.dumps(query.all())), ex=expire or None)
+
+    # NOTE: pickle.loads() executes whatever the payload says; the redis
+    # instance must be trusted and written to by this module only.
+    return pickle.loads(result)
diff --git a/aurweb/routers/html.py b/aurweb/routers/html.py
index 38303837..fc9f3519 100644
--- a/aurweb/routers/html.py
+++ b/aurweb/routers/html.py
@@ -89,22 +89,20 @@ async def index(request: Request):
 
     bases = db.query(models.PackageBase)
 
-    redis = aurweb.aur_redis.redis_connection()
-    cache_expire = 300  # Five minutes.
-
+    cache_expire = aurweb.config.getint("cache", "expiry_time")
     # Package statistics.
context["package_count"] = await db_count_cache( - redis, "package_count", bases, expire=cache_expire + "package_count", bases, expire=cache_expire ) query = bases.filter(models.PackageBase.MaintainerUID.is_(None)) context["orphan_count"] = await db_count_cache( - redis, "orphan_count", query, expire=cache_expire + "orphan_count", query, expire=cache_expire ) query = db.query(models.User) context["user_count"] = await db_count_cache( - redis, "user_count", query, expire=cache_expire + "user_count", query, expire=cache_expire ) query = query.filter( @@ -114,7 +112,7 @@ async def index(request: Request): ) ) context["trusted_user_count"] = await db_count_cache( - redis, "trusted_user_count", query, expire=cache_expire + "trusted_user_count", query, expire=cache_expire ) # Current timestamp. @@ -130,26 +128,26 @@ async def index(request: Request): query = bases.filter(models.PackageBase.SubmittedTS >= seven_days_ago) context["seven_days_old_added"] = await db_count_cache( - redis, "seven_days_old_added", query, expire=cache_expire + "seven_days_old_added", query, expire=cache_expire ) query = updated.filter(models.PackageBase.ModifiedTS >= seven_days_ago) context["seven_days_old_updated"] = await db_count_cache( - redis, "seven_days_old_updated", query, expire=cache_expire + "seven_days_old_updated", query, expire=cache_expire ) year = seven_days * 52 # Fifty two weeks worth: one year. year_ago = now - year query = updated.filter(models.PackageBase.ModifiedTS >= year_ago) context["year_old_updated"] = await db_count_cache( - redis, "year_old_updated", query, expire=cache_expire + "year_old_updated", query, expire=cache_expire ) query = bases.filter( models.PackageBase.ModifiedTS - models.PackageBase.SubmittedTS < 3600 ) context["never_updated"] = await db_count_cache( - redis, "never_updated", query, expire=cache_expire + "never_updated", query, expire=cache_expire ) # Get the 15 most recently updated packages. 
diff --git a/aurweb/routers/packages.py b/aurweb/routers/packages.py index 83bfe6e2..779efb4b 100644 --- a/aurweb/routers/packages.py +++ b/aurweb/routers/packages.py @@ -7,6 +7,7 @@ from fastapi import APIRouter, Form, Query, Request, Response import aurweb.filters # noqa: F401 from aurweb import aur_logging, config, db, defaults, models, util from aurweb.auth import creds, requires_auth +from aurweb.cache import db_count_cache, db_query_cache from aurweb.exceptions import InvariantError, handle_form_exceptions from aurweb.models.relation_type import CONFLICTS_ID, PROVIDES_ID, REPLACES_ID from aurweb.packages import util as pkgutil @@ -14,6 +15,7 @@ from aurweb.packages.search import PackageSearch from aurweb.packages.util import get_pkg_or_base from aurweb.pkgbase import actions as pkgbase_actions, util as pkgbaseutil from aurweb.templates import make_context, make_variable_context, render_template +from aurweb.util import hash_query logger = aur_logging.get_logger(__name__) router = APIRouter() @@ -87,7 +89,11 @@ async def packages_get( # Collect search result count here; we've applied our keywords. # Including more query operations below, like ordering, will # increase the amount of time required to collect a count. - num_packages = search.count() + # we use redis for caching the results of the query + cache_expire = config.getint("cache", "expiry_time") + num_packages = await db_count_cache( + hash_query(search.query), search.query, cache_expire + ) # Apply user-specified sort column and ordering. 
search.sort_by(sort_by, sort_order) @@ -108,7 +114,12 @@ async def packages_get( models.PackageNotification.PackageBaseID.label("Notify"), ) - packages = results.limit(per_page).offset(offset) + # paging + results = results.limit(per_page).offset(offset) + + # we use redis for caching the results of the query + packages = await db_query_cache(hash_query(results), results, cache_expire) + context["packages"] = packages context["packages_count"] = num_packages diff --git a/aurweb/util.py b/aurweb/util.py index d80b0311..7050b482 100644 --- a/aurweb/util.py +++ b/aurweb/util.py @@ -4,6 +4,7 @@ import secrets import shlex import string from datetime import datetime +from hashlib import sha1 from http import HTTPStatus from subprocess import PIPE, Popen from typing import Callable, Iterable, Tuple, Union @@ -13,6 +14,7 @@ import fastapi import pygit2 from email_validator import EmailSyntaxError, validate_email from fastapi.responses import JSONResponse +from sqlalchemy.orm import Query import aurweb.config from aurweb import aur_logging, defaults @@ -200,3 +202,9 @@ def shell_exec(cmdline: str, cwd: str) -> Tuple[int, str, str]: proc = Popen(args, cwd=cwd, stdout=PIPE, stderr=PIPE) out, err = proc.communicate() return proc.returncode, out.decode().strip(), err.decode().strip() + + +def hash_query(query: Query): + return sha1( + str(query.statement.compile(compile_kwargs={"literal_binds": True})).encode() + ).hexdigest() diff --git a/conf/config.defaults b/conf/config.defaults index c059444d..4e2415ed 100644 --- a/conf/config.defaults +++ b/conf/config.defaults @@ -165,3 +165,9 @@ commit_url = https://gitlab.archlinux.org/archlinux/aurweb/-/commits/%s ; voted on based on `now + range_start <= End <= now + range_end`. 
range_start = 500 range_end = 172800 + +[cache] +; maximum number of keys/entries (for search results) in our redis cache, default is 50000 +max_search_entries = 50000 +; number of seconds after a cache entry expires, default is 3 minutes +expiry_time = 180 diff --git a/test/test_cache.py b/test/test_cache.py index 83a9755a..e19fa6a2 100644 --- a/test/test_cache.py +++ b/test/test_cache.py @@ -1,6 +1,8 @@ +from unittest import mock + import pytest -from aurweb import cache, db +from aurweb import cache, config, db from aurweb.models.account_type import USER_ID from aurweb.models.user import User @@ -10,68 +12,85 @@ def setup(db_test): return -class StubRedis: - """A class which acts as a RedisConnection without using Redis.""" - - cache = dict() - expires = dict() - - def get(self, key, *args): - if "key" not in self.cache: - self.cache[key] = None - return self.cache[key] - - def set(self, key, *args): - self.cache[key] = list(args)[0] - - def expire(self, key, *args): - self.expires[key] = list(args)[0] - - async def execute(self, command, key, *args): - f = getattr(self, command) - return f(key, *args) - - @pytest.fixture -def redis(): - yield StubRedis() +def user() -> User: + with db.begin(): + user = db.create( + User, + Username="test", + Email="test@example.org", + RealName="Test User", + Passwd="testPassword", + AccountTypeID=USER_ID, + ) + yield user + + +@pytest.fixture(autouse=True) +def clear_fakeredis_cache(): + cache._redis.flushall() @pytest.mark.asyncio -async def test_db_count_cache(redis): - db.create( - User, - Username="user1", - Email="user1@example.org", - Passwd="testPassword", - AccountTypeID=USER_ID, - ) - +async def test_db_count_cache(user): query = db.query(User) - # Now, perform several checks that db_count_cache matches query.count(). - # We have no cached value yet. 
- assert await cache.db_count_cache(redis, "key1", query) == query.count() + assert cache._redis.get("key1") is None + + # Add to cache + assert await cache.db_count_cache("key1", query) == query.count() # It's cached now. - assert await cache.db_count_cache(redis, "key1", query) == query.count() + assert cache._redis.get("key1") is not None + + # It does not expire + assert cache._redis.ttl("key1") == -1 + + # Cache a query with an expire. + value = await cache.db_count_cache("key2", query, 100) + assert value == query.count() + + assert cache._redis.ttl("key2") == 100 @pytest.mark.asyncio -async def test_db_count_cache_expires(redis): - db.create( - User, - Username="user1", - Email="user1@example.org", - Passwd="testPassword", - AccountTypeID=USER_ID, - ) - +async def test_db_query_cache(user): query = db.query(User) - # Cache a query with an expire. - value = await cache.db_count_cache(redis, "key1", query, 100) - assert value == query.count() + # We have no cached value yet. + assert cache._redis.get("key1") is None - assert redis.expires["key1"] == 100 + # Add to cache + await cache.db_query_cache("key1", query) + + # It's cached now. + assert cache._redis.get("key1") is not None + + # Modify our user and make sure we got a cached value + user.Username = "changed" + cached = await cache.db_query_cache("key1", query) + assert cached[0].Username != query.all()[0].Username + + # It does not expire + assert cache._redis.ttl("key1") == -1 + + # Cache a query with an expire. 
+ value = await cache.db_query_cache("key2", query, 100) + assert len(value) == query.count() + assert value[0].Username == query.all()[0].Username + + assert cache._redis.ttl("key2") == 100 + + # Test "max_search_entries" options + def mock_max_search_entries(section: str, key: str, fallback: int) -> str: + if section == "cache" and key == "max_search_entries": + return 1 + return config.getint(section, key) + + with mock.patch("aurweb.config.getint", side_effect=mock_max_search_entries): + # Try to add another entry (we already have 2) + await cache.db_query_cache("key3", query) + + # Make sure it was not added because it exceeds our max. + assert cache._redis.get("key3") is None diff --git a/test/test_packages_routes.py b/test/test_packages_routes.py index 93dc404a..fb12e65e 100644 --- a/test/test_packages_routes.py +++ b/test/test_packages_routes.py @@ -5,7 +5,7 @@ from unittest import mock import pytest from fastapi.testclient import TestClient -from aurweb import asgi, config, db, time +from aurweb import asgi, cache, config, db, time from aurweb.filters import datetime_display from aurweb.models import License, PackageLicense from aurweb.models.account_type import USER_ID, AccountType @@ -63,6 +63,11 @@ def setup(db_test): return +@pytest.fixture(autouse=True) +def clear_fakeredis_cache(): + cache._redis.flushall() + + @pytest.fixture def client() -> TestClient: """Yield a FastAPI TestClient.""" @@ -815,6 +820,8 @@ def test_packages_search_by_keywords(client: TestClient, packages: list[Package] # And request packages with that keyword, we should get 1 result. with client as request: + # clear fakeredis cache + cache._redis.flushall() response = request.get("/packages", params={"SeB": "k", "K": "testKeyword"}) assert response.status_code == int(HTTPStatus.OK) @@ -870,6 +877,8 @@ def test_packages_search_by_maintainer( # This time, we should get `package` returned, since it's now an orphan. 
with client as request: + # clear fakeredis cache + cache._redis.flushall() response = request.get("/packages", params={"SeB": "m"}) assert response.status_code == int(HTTPStatus.OK) root = parse_root(response.text) @@ -902,6 +911,8 @@ def test_packages_search_by_comaintainer( # Then test that it's returned by our search. with client as request: + # clear fakeredis cache + cache._redis.flushall() response = request.get( "/packages", params={"SeB": "c", "K": maintainer.Username} ) diff --git a/test/test_util.py b/test/test_util.py index a138d912..042b9ad9 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -5,7 +5,8 @@ import fastapi import pytest from fastapi.responses import JSONResponse -from aurweb import filters, util +from aurweb import db, filters, util +from aurweb.models.user import User from aurweb.testing.requests import Request @@ -146,3 +147,26 @@ def assert_multiple_keys(pks): assert key1 == k1[1] assert pfx2 == k2[0] assert key2 == k2[1] + + +def test_hash_query(): + # No conditions + query = db.query(User) + assert util.hash_query(query) == "75e76026b7d576536e745ec22892cf8f5d7b5d62" + + # With where clause + query = db.query(User).filter(User.Username == "bla") + assert util.hash_query(query) == "4dca710f33b1344c27ec6a3c266970f4fa6a8a00" + + # With where clause and sorting + query = db.query(User).filter(User.Username == "bla").order_by(User.Username) + assert util.hash_query(query) == "ee2c7846fede430776e140f8dfe1d83cd21d2eed" + + # With where clause, sorting and specific columns + query = ( + db.query(User) + .filter(User.Username == "bla") + .order_by(User.Username) + .with_entities(User.Username) + ) + assert util.hash_query(query) == "c1db751be61443d266cf643005eee7a884dac103"