diff --git a/aurweb/rpc.py b/aurweb/rpc.py index 6e2a27fe..70d8c2fd 100644 --- a/aurweb/rpc.py +++ b/aurweb/rpc.py @@ -202,7 +202,12 @@ class RPC: models.User.ID == models.PackageBase.MaintainerUID, isouter=True ).filter(models.Package.Name.in_(args)) - packages = self._entities(packages) + + max_results = config.getint("options", "max_rpc_results") + packages = self._entities(packages).limit(max_results + 1) + + if packages.count() > max_results: + raise RPCError("Too many package results.") ids = {pkg.ID for pkg in packages} @@ -274,12 +279,7 @@ class RPC: ] # Union all subqueries together. - max_results = config.getint("options", "max_rpc_results") - query = subqueries[0].union_all(*subqueries[1:]).limit( - max_results + 1).all() - - if len(query) > max_results: - raise RPCError("Too many package results.") + query = subqueries[0].union_all(*subqueries[1:]).all() # Store our extra information in a class-wise dictionary, # which contains package id -> extra info dict mappings. diff --git a/test/test_rpc.py b/test/test_rpc.py index a67a026e..0d6b2931 100644 --- a/test/test_rpc.py +++ b/test/test_rpc.py @@ -15,6 +15,7 @@ import aurweb.models.relation_type as rt from aurweb import asgi, config, db, rpc, scripts, time from aurweb.models.account_type import USER_ID +from aurweb.models.dependency_type import DEPENDS_ID from aurweb.models.license import License from aurweb.models.package import Package from aurweb.models.package_base import PackageBase @@ -23,6 +24,7 @@ from aurweb.models.package_keyword import PackageKeyword from aurweb.models.package_license import PackageLicense from aurweb.models.package_relation import PackageRelation from aurweb.models.package_vote import PackageVote +from aurweb.models.relation_type import PROVIDES_ID from aurweb.models.user import User from aurweb.redis import redis_connection @@ -814,6 +816,16 @@ def test_rpc_too_many_search_results(client: TestClient, def test_rpc_too_many_info_results(client: TestClient, packages: List[Package]): + # Make many of these packages depend and rely on each other. + # This way, we can test to see that the exceeded limit stays true + # regardless of the number of related records. + with db.begin(): + for i in range(len(packages) - 1): + db.create(PackageDependency, DepTypeID=DEPENDS_ID, + Package=packages[i], DepName=packages[i + 1].Name) + db.create(PackageRelation, RelTypeID=PROVIDES_ID, + Package=packages[i], RelName=packages[i + 1].Name) + config_getint = config.getint def mock_config(section: str, key: str):