diff --git a/aurweb/asgi.py b/aurweb/asgi.py index 95bd5e77..2dd546aa 100644 --- a/aurweb/asgi.py +++ b/aurweb/asgi.py @@ -19,7 +19,9 @@ from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.sessions import SessionMiddleware +import aurweb.captcha # noqa: F401 import aurweb.config +import aurweb.filters # noqa: F401 import aurweb.logging import aurweb.pkgbase.util as pkgbaseutil diff --git a/aurweb/auth/__init__.py b/aurweb/auth/__init__.py index 841110fa..8e7faed1 100644 --- a/aurweb/auth/__init__.py +++ b/aurweb/auth/__init__.py @@ -13,7 +13,7 @@ from starlette.requests import HTTPConnection import aurweb.config -from aurweb import db, l10n, util +from aurweb import db, filters, l10n, util from aurweb.models import Session, User from aurweb.models.account_type import ACCOUNT_TYPE_ID @@ -166,7 +166,7 @@ def _auth_required(auth_goal: bool = True): raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=_("Bad Referer header.")) url = referer[len(aur) - 1:] - url = "/login?" + util.urlencode({"next": url}) + url = "/login?" + filters.urlencode({"next": url}) return RedirectResponse(url, status_code=int(HTTPStatus.SEE_OTHER)) return wrapper diff --git a/aurweb/captcha.py b/aurweb/captcha.py index 529b09e1..34d99e53 100644 --- a/aurweb/captcha.py +++ b/aurweb/captcha.py @@ -5,6 +5,7 @@ from jinja2 import pass_context from aurweb.db import query from aurweb.models import User +from aurweb.templates import register_filter def get_captcha_salts(): @@ -41,6 +42,7 @@ def get_captcha_answer(token): return hashlib.md5((text + "\n").encode()).hexdigest()[:6] +@register_filter("captcha_salt") @pass_context def captcha_salt_filter(context): """ Returns the most recent CAPTCHA salt in the list of salts. """ @@ -48,6 +50,7 @@ def captcha_salt_filter(context): return salts[0] +@register_filter("captcha_cmdline") @pass_context def captcha_cmdline_filter(context, salt): """ Returns a CAPTCHA challenge for a given salt. """ diff --git a/aurweb/filters.py b/aurweb/filters.py index f9f56b5d..9b731501 100644 --- a/aurweb/filters.py +++ b/aurweb/filters.py @@ -1,10 +1,19 @@ -from typing import Any, Dict +import copy +import math +from datetime import datetime +from typing import Any, Dict +from urllib.parse import quote_plus, urlencode +from zoneinfo import ZoneInfo + +import fastapi import paginate from jinja2 import pass_context -from aurweb import config, util +import aurweb.models + +from aurweb import config, l10n from aurweb.templates import register_filter, register_function @@ -30,7 +39,7 @@ def pager_nav(context: Dict[str, Any], def create_url(page: int): nonlocal q offset = max(page * pp - pp, 0) - qs = util.to_qs(util.extend_query(q, ["O", offset])) + qs = to_qs(extend_query(q, ["O", offset])) return f"{prefix}?{qs}" # Use the paginate module to produce our linkage. @@ -58,3 +67,84 @@ def config_getint(section: str, key: str) -> int: @register_function("round") def do_round(f: float) -> int: return round(f) + + +@register_filter("tr") +@pass_context +def tr(context: Dict[str, Any], value: str): + """ A translation filter; example: {{ "Hello" | tr("de") }}. """ + _ = l10n.get_translator_for_request(context.get("request")) + return _(value) + + +@register_filter("tn") +@pass_context +def tn(context: Dict[str, Any], count: int, + singular: str, plural: str) -> str: + """ A singular and plural translation filter. + + Example: + {{ some_integer | tn("singular %d", "plural %d") }} + + :param context: Response context + :param count: The number used to decide singular or plural state + :param singular: The singular translation + :param plural: The plural translation + :return: Translated string + """ + gettext = l10n.get_raw_translator_for_request(context.get("request")) + return gettext.ngettext(singular, plural, count) + + +@register_filter("dt") +def timestamp_to_datetime(timestamp: int): + return datetime.utcfromtimestamp(int(timestamp)) + + +@register_filter("as_timezone") +def as_timezone(dt: datetime, timezone: str): + return dt.astimezone(tz=ZoneInfo(timezone)) + + +@register_filter("extend_query") +def extend_query(query: Dict[str, Any], *additions) -> Dict[str, Any]: + """ Add additional key value pairs to query. """ + q = copy.copy(query) + for k, v in list(additions): + q[k] = v + return q + + +@register_filter("urlencode") +def to_qs(query: Dict[str, Any]) -> str: + return urlencode(query, doseq=True) + + +@register_filter("get_vote") +def get_vote(voteinfo, request: fastapi.Request): + from aurweb.models import TUVote + return voteinfo.tu_votes.filter(TUVote.User == request.user).first() + + +@register_filter("number_format") +def number_format(value: float, places: int): + """ A converter function similar to PHP's number_format. """ + return f"{value:.{places}f}" + + +@register_filter("account_url") +@pass_context +def account_url(context: Dict[str, Any], + user: "aurweb.models.user.User") -> str: + base = aurweb.config.get("options", "aur_location") + return f"{base}/account/{user.Username}" + + +@register_filter("quote_plus") +def _quote_plus(*args, **kwargs) -> str: + return quote_plus(*args, **kwargs) + + +@register_filter("ceil") +def ceil(*args, **kwargs) -> int: + return math.ceil(*args, **kwargs) diff --git a/aurweb/l10n.py b/aurweb/l10n.py index c4938d64..f3bbc1da 100644 --- a/aurweb/l10n.py +++ b/aurweb/l10n.py @@ -1,10 +1,8 @@ import gettext -import typing from collections import OrderedDict from fastapi import Request -from jinja2 import pass_context import aurweb.config @@ -86,28 +84,3 @@ def get_translator_for_request(request: Request): return translator.translate(message, lang) return translate - - -@pass_context -def tr(context: typing.Any, value: str): - """ A translation filter; example: {{ "Hello" | tr("de") }}. """ - _ = get_translator_for_request(context.get("request")) - return _(value) - - -@pass_context -def tn(context: typing.Dict[str, typing.Any], count: int, - singular: str, plural: str) -> str: - """ A singular and plural translation filter. - - Example: - {{ some_integer | tn("singular %d", "plural %d") }} - - :param context: Response context - :param count: The number used to decide singular or plural state - :param singular: The singular translation - :param plural: The plural translation - :return: Translated string - """ - gettext = get_raw_translator_for_request(context.get("request")) - return gettext.ngettext(singular, plural, count) diff --git a/aurweb/routers/rss.py b/aurweb/routers/rss.py index 672a47d6..454ff497 100644 --- a/aurweb/routers/rss.py +++ b/aurweb/routers/rss.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Request from fastapi.responses import Response from feedgen.feed import FeedGenerator -from aurweb import db, util +from aurweb import db, filters from aurweb.models import Package, PackageBase router = APIRouter() @@ -39,8 +39,8 @@ def make_rss_feed(request: Request, packages: list, entry.description(pkg.Description or str()) attr = getattr(pkg.PackageBase, date_attr) - dt = util.timestamp_to_datetime(attr) - dt = util.as_timezone(dt, request.user.Timezone) + dt = filters.timestamp_to_datetime(attr) + dt = filters.as_timezone(dt, request.user.Timezone) entry.pubDate(dt.strftime("%Y-%m-%d %H:%M:%S%z")) entry.source(f"{base}") diff --git a/aurweb/rpc.py b/aurweb/rpc.py index 8757d9f9..90e03a41 100644 --- a/aurweb/rpc.py +++ b/aurweb/rpc.py @@ -8,8 +8,9 @@ from sqlalchemy import and_, literal, orm import aurweb.config as config -from aurweb import db, defaults, models, util +from aurweb import db, defaults, models from aurweb.exceptions import RPCError +from aurweb.filters import number_format from aurweb.packages.search import RPCSearch TYPE_MAPPING = { @@ -124,7 +125,7 @@ class RPC: # Produce RPC API compatible Popularity: If zero, it's an integer # 0, otherwise, it's formatted to the 6th decimal place. pop = package.Popularity - pop = 0 if not pop else float(util.number_format(pop, 6)) + pop = 0 if not pop else float(number_format(pop, 6)) snapshot_uri = config.get("options", "snapshot_uri") return { diff --git a/aurweb/scripts/mkpkglists.py b/aurweb/scripts/mkpkglists.py index 92de7931..d97bc01c 100755 --- a/aurweb/scripts/mkpkglists.py +++ b/aurweb/scripts/mkpkglists.py @@ -31,7 +31,7 @@ from sqlalchemy import literal, orm import aurweb.config -from aurweb import db, logging, models, util +from aurweb import db, filters, logging, models, util from aurweb.benchmark import Benchmark from aurweb.models import Package, PackageBase, User @@ -264,7 +264,7 @@ def _main(): with gzip.open(USERS, "wt") as f: f.writelines([f"{user.Username}\n" for i, user in enumerate(query)]) - seconds = util.number_format(bench.end(), 4) + seconds = filters.number_format(bench.end(), 4) logger.info(f"Completed in {seconds} seconds.") diff --git a/aurweb/scripts/notify.py b/aurweb/scripts/notify.py index 1f875a9c..91593e7f 100755 --- a/aurweb/scripts/notify.py +++ b/aurweb/scripts/notify.py @@ -13,6 +13,7 @@ from sqlalchemy import and_, or_ import aurweb.config import aurweb.db +import aurweb.filters import aurweb.l10n from aurweb import db, l10n, logging @@ -160,7 +161,7 @@ class ServerErrorNotification(Notification): def get_body(self, lang: str) -> str: """ A forcibly English email body. """ - dt = aurweb.util.timestamp_to_datetime(self._utc) + dt = aurweb.filters.timestamp_to_datetime(self._utc) dts = dt.strftime("%Y-%m-%d %H:%M") return (f"Traceback ID: {self._tb_id}\n" f"Location: {aur_location}\n" diff --git a/aurweb/templates.py b/aurweb/templates.py index 74c993f8..82f115d1 100644 --- a/aurweb/templates.py +++ b/aurweb/templates.py @@ -1,23 +1,20 @@ import copy import functools -import math import os import zoneinfo from datetime import datetime from http import HTTPStatus from typing import Callable -from urllib.parse import quote_plus import jinja2 from fastapi import Request from fastapi.responses import HTMLResponse -import aurweb.auth.creds import aurweb.config -from aurweb import captcha, cookies, l10n, time, util +from aurweb import cookies, l10n, time # Prepare jinja2 objects. _loader = jinja2.FileSystemLoader(os.path.join( @@ -25,27 +22,6 @@ _loader = jinja2.FileSystemLoader(os.path.join( _env = jinja2.Environment(loader=_loader, autoescape=True, extensions=["jinja2.ext.i18n"]) -# Add t{r,n} translation filters. -_env.filters["tr"] = l10n.tr -_env.filters["tn"] = l10n.tn - -# Utility filters. -_env.filters["dt"] = util.timestamp_to_datetime -_env.filters["as_timezone"] = util.as_timezone -_env.filters["extend_query"] = util.extend_query -_env.filters["urlencode"] = util.to_qs -_env.filters["quote_plus"] = quote_plus -_env.filters["get_vote"] = util.get_vote -_env.filters["number_format"] = util.number_format -_env.filters["ceil"] = math.ceil - -# Add captcha filters. -_env.filters["captcha_salt"] = captcha.captcha_salt_filter -_env.filters["captcha_cmdline"] = captcha.captcha_cmdline_filter - -# Add account utility filters. -_env.filters["account_url"] = util.account_url - def register_filter(name: str) -> Callable: """ A decorator that can be used to register a filter. @@ -65,8 +41,6 @@ def register_filter(name: str) -> Callable: @functools.wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) - if name in _env.filters: - raise KeyError(f"Jinja already has a filter named '{name}'") _env.filters[name] = wrapper return wrapper return decorator @@ -88,6 +62,7 @@ def register_function(name: str) -> Callable: def make_context(request: Request, title: str, next: str = None): """ Create a context for a jinja2 TemplateResponse. """ + import aurweb.auth.creds commit_url = aurweb.config.get_with_fallback("devel", "commit_url", None) commit_hash = aurweb.config.get_with_fallback("devel", "commit_hash", None) diff --git a/aurweb/util.py b/aurweb/util.py index bda743fd..776cf516 100644 --- a/aurweb/util.py +++ b/aurweb/util.py @@ -1,5 +1,4 @@ import base64 -import copy import math import re import secrets @@ -8,16 +7,14 @@ import string from datetime import datetime from distutils.util import strtobool as _strtobool from http import HTTPStatus -from typing import Any, Callable, Dict, Iterable, Tuple -from urllib.parse import urlencode, urlparse -from zoneinfo import ZoneInfo +from typing import Callable, Iterable, Tuple +from urllib.parse import urlparse import fastapi import pygit2 from email_validator import EmailNotValidError, EmailUndeliverableError, validate_email from fastapi.responses import JSONResponse -from jinja2 import pass_context import aurweb.config @@ -107,43 +104,6 @@ def valid_ssh_pubkey(pk): return base64.b64encode(base64.b64decode(tokens[1])).decode() == tokens[1] -@pass_context -def account_url(context: Dict[str, Any], - user: "aurweb.models.user.User") -> str: - base = aurweb.config.get("options", "aur_location") - return f"{base}/account/{user.Username}" - - -def timestamp_to_datetime(timestamp: int): - return datetime.utcfromtimestamp(int(timestamp)) - - -def as_timezone(dt: datetime, timezone: str): - return dt.astimezone(tz=ZoneInfo(timezone)) - - -def extend_query(query: Dict[str, Any], *additions) -> Dict[str, Any]: - """ Add additional key value pairs to query. """ - q = copy.copy(query) - for k, v in list(additions): - q[k] = v - return q - - -def to_qs(query: Dict[str, Any]) -> str: - return urlencode(query, doseq=True) - - -def get_vote(voteinfo, request: fastapi.Request): - from aurweb.models import TUVote - return voteinfo.tu_votes.filter(TUVote.User == request.user).first() - - -def number_format(value: float, places: int): - """ A converter function similar to PHP's number_format. """ - return f"{value:.{places}f}" - - def jsonify(obj): """ Perform a conversion on obj if it's needed. """ if isinstance(obj, datetime): diff --git a/test/test_filters.py b/test/test_filters.py new file mode 100644 index 00000000..53d95cdf --- /dev/null +++ b/test/test_filters.py @@ -0,0 +1,36 @@ +from datetime import datetime +from zoneinfo import ZoneInfo + +from aurweb import filters + + +def test_timestamp_to_datetime(): + ts = datetime.utcnow().timestamp() + dt = datetime.utcfromtimestamp(int(ts)) + assert filters.timestamp_to_datetime(ts) == dt + + +def test_as_timezone(): + ts = datetime.utcnow().timestamp() + dt = filters.timestamp_to_datetime(ts) + assert filters.as_timezone(dt, "UTC") == dt.astimezone(tz=ZoneInfo("UTC")) + + +def test_number_format(): + assert filters.number_format(0.222, 2) == "0.22" + assert filters.number_format(0.226, 2) == "0.23" + + +def test_extend_query(): + """ Test extension of a query via extend_query. """ + query = {"a": "b"} + extended = filters.extend_query(query, ("a", "c"), ("b", "d")) + assert extended.get("a") == "c" + assert extended.get("b") == "d" + + +def test_to_qs(): + """ Test conversion from a query dictionary to a query string. """ + query = {"a": "b", "c": [1, 2, 3]} + qs = filters.to_qs(query) + assert qs == "a=b&c=1&c=2&c=3" diff --git a/test/test_l10n.py b/test/test_l10n.py index 1c2ae95a..c24c5f55 100644 --- a/test/test_l10n.py +++ b/test/test_l10n.py @@ -1,5 +1,5 @@ """ Test our l10n module. """ -from aurweb import l10n +from aurweb import filters, l10n from aurweb.testing.requests import Request @@ -43,8 +43,10 @@ def test_tn_filter(): request.cookies["AURLANG"] = "en" context = {"language": "en", "request": request} - translated = l10n.tn(context, 1, "%d package found.", "%d packages found.") + translated = filters.tn(context, 1, "%d package found.", + "%d packages found.") assert translated == "%d package found." - translated = l10n.tn(context, 2, "%d package found.", "%d packages found.") + translated = filters.tn(context, 2, "%d package found.", + "%d packages found.") assert translated == "%d packages found." diff --git a/test/test_templates.py b/test/test_templates.py index 6104c126..7d393b61 100644 --- a/test/test_templates.py +++ b/test/test_templates.py @@ -8,6 +8,8 @@ import pytest import aurweb.filters # noqa: F401 from aurweb import config, db, templates +from aurweb.filters import as_timezone, number_format +from aurweb.filters import timestamp_to_datetime as to_dt from aurweb.models import Package, PackageBase, User from aurweb.models.account_type import USER_ID from aurweb.models.license import License @@ -17,8 +19,6 @@ from aurweb.models.relation_type import PROVIDES_ID, REPLACES_ID from aurweb.templates import base_template, make_context, register_filter, register_function from aurweb.testing.html import parse_root from aurweb.testing.requests import Request -from aurweb.util import as_timezone, number_format -from aurweb.util import timestamp_to_datetime as to_dt GIT_CLONE_URI_ANON = "anon_%s" GIT_CLONE_URI_PRIV = "priv_%s" @@ -79,15 +79,6 @@ def create_license(pkg: Package, license_name: str) -> PackageLicense: return pkglic -def test_register_filter_exists_key_error(): - """ Most instances of register_filter are tested through module - imports or template renders, so we only test failures here. """ - with pytest.raises(KeyError): - @register_filter("func") - def some_func(): - pass - - def test_register_function_exists_key_error(): """ Most instances of register_filter are tested through module imports or template renders, so we only test failures here. """ diff --git a/test/test_trusted_user_routes.py b/test/test_trusted_user_routes.py index b050af22..ae5ad418 100644 --- a/test/test_trusted_user_routes.py +++ b/test/test_trusted_user_routes.py @@ -10,7 +10,7 @@ import pytest from fastapi.testclient import TestClient -from aurweb import config, db, util +from aurweb import config, db, filters from aurweb.models.account_type import DEVELOPER_ID, AccountType from aurweb.models.tu_vote import TUVote from aurweb.models.tu_voteinfo import TUVoteInfo @@ -130,7 +130,7 @@ def test_tu_index_guest(client): response = request.get("/tu", allow_redirects=False, headers=headers) assert response.status_code == int(HTTPStatus.SEE_OTHER) - params = util.urlencode({"next": "/tu"}) + params = filters.urlencode({"next": "/tu"}) assert response.headers.get("location") == f"/login?{params}" diff --git a/test/test_util.py b/test/test_util.py index 91a0f475..51d978fb 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -1,8 +1,6 @@ import json -from datetime import datetime from http import HTTPStatus -from zoneinfo import ZoneInfo import fastapi import pytest @@ -13,38 +11,6 @@ from aurweb import filters, util from aurweb.testing.requests import Request -def test_timestamp_to_datetime(): - ts = datetime.utcnow().timestamp() - dt = datetime.utcfromtimestamp(int(ts)) - assert util.timestamp_to_datetime(ts) == dt - - -def test_as_timezone(): - ts = datetime.utcnow().timestamp() - dt = util.timestamp_to_datetime(ts) - assert util.as_timezone(dt, "UTC") == dt.astimezone(tz=ZoneInfo("UTC")) - - -def test_number_format(): - assert util.number_format(0.222, 2) == "0.22" - assert util.number_format(0.226, 2) == "0.23" - - -def test_extend_query(): - """ Test extension of a query via extend_query. """ - query = {"a": "b"} - extended = util.extend_query(query, ("a", "c"), ("b", "d")) - assert extended.get("a") == "c" - assert extended.get("b") == "d" - - -def test_to_qs(): - """ Test conversion from a query dictionary to a query string. """ - query = {"a": "b", "c": [1, 2, 3]} - qs = util.to_qs(query) - assert qs == "a=b&c=1&c=2&c=3" - - def test_round(): assert filters.do_round(1.3) == 1 assert filters.do_round(1.5) == 2