mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
Merge branch 'db-rework' into pu
This commit is contained in:
commit
91b570ff0d
24 changed files with 233 additions and 166 deletions
|
@ -13,7 +13,7 @@ from starlette.requests import HTTPConnection
|
||||||
|
|
||||||
import aurweb.config
|
import aurweb.config
|
||||||
|
|
||||||
from aurweb import l10n, util
|
from aurweb import db, l10n, util
|
||||||
from aurweb.models import Session, User
|
from aurweb.models import Session, User
|
||||||
from aurweb.models.account_type import ACCOUNT_TYPE_ID
|
from aurweb.models.account_type import ACCOUNT_TYPE_ID
|
||||||
from aurweb.templates import make_variable_context, render_template
|
from aurweb.templates import make_variable_context, render_template
|
||||||
|
@ -98,14 +98,12 @@ class AnonymousUser:
|
||||||
|
|
||||||
class BasicAuthBackend(AuthenticationBackend):
|
class BasicAuthBackend(AuthenticationBackend):
|
||||||
async def authenticate(self, conn: HTTPConnection):
|
async def authenticate(self, conn: HTTPConnection):
|
||||||
from aurweb.db import session
|
|
||||||
|
|
||||||
sid = conn.cookies.get("AURSID")
|
sid = conn.cookies.get("AURSID")
|
||||||
if not sid:
|
if not sid:
|
||||||
return (None, AnonymousUser())
|
return (None, AnonymousUser())
|
||||||
|
|
||||||
now_ts = datetime.utcnow().timestamp()
|
now_ts = datetime.utcnow().timestamp()
|
||||||
record = session.query(Session).filter(
|
record = db.query(Session).filter(
|
||||||
and_(Session.SessionID == sid,
|
and_(Session.SessionID == sid,
|
||||||
Session.LastUpdateTS >= now_ts)).first()
|
Session.LastUpdateTS >= now_ts)).first()
|
||||||
|
|
||||||
|
@ -116,7 +114,7 @@ class BasicAuthBackend(AuthenticationBackend):
|
||||||
# At this point, we cannot have an invalid user if the record
|
# At this point, we cannot have an invalid user if the record
|
||||||
# exists, due to ForeignKey constraints in the schema upheld
|
# exists, due to ForeignKey constraints in the schema upheld
|
||||||
# by mysqlclient.
|
# by mysqlclient.
|
||||||
user = session.query(User).filter(User.ID == record.UsersID).first()
|
user = db.query(User).filter(User.ID == record.UsersID).first()
|
||||||
user.nonce = util.make_nonce()
|
user.nonce = util.make_nonce()
|
||||||
user.authenticated = True
|
user.authenticated = True
|
||||||
|
|
||||||
|
|
89
aurweb/db.py
89
aurweb/db.py
|
@ -2,10 +2,10 @@ import functools
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from typing import Iterable
|
from typing import Iterable, NewType
|
||||||
|
|
||||||
from sqlalchemy import event
|
from sqlalchemy import event
|
||||||
from sqlalchemy.orm import scoped_session
|
from sqlalchemy.orm import Query, scoped_session
|
||||||
|
|
||||||
import aurweb.config
|
import aurweb.config
|
||||||
import aurweb.util
|
import aurweb.util
|
||||||
|
@ -22,6 +22,9 @@ session = None
|
||||||
# Global introspected object memo.
|
# Global introspected object memo.
|
||||||
introspected = dict()
|
introspected = dict()
|
||||||
|
|
||||||
|
# A mocked up type.
|
||||||
|
Base = NewType("aurweb.models.declarative_base.Base", "Base")
|
||||||
|
|
||||||
|
|
||||||
def make_random_value(table: str, column: str):
|
def make_random_value(table: str, column: str):
|
||||||
""" Generate a unique, random value for a string column in a table.
|
""" Generate a unique, random value for a string column in a table.
|
||||||
|
@ -58,55 +61,69 @@ def make_random_value(table: str, column: str):
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def query(model, *args, **kwargs):
|
def get_session():
|
||||||
return session.query(model).filter(*args, **kwargs)
|
""" Return aurweb.db's global session. """
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
def create(model, *args, **kwargs):
|
def refresh(model: Base) -> Base:
|
||||||
instance = model(*args, **kwargs)
|
""" Refresh the session's knowledge of `model`. """
|
||||||
|
get_session().refresh(model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def query(Model: Base, *args, **kwargs) -> Query:
|
||||||
|
"""
|
||||||
|
Perform an ORM query against the database session.
|
||||||
|
|
||||||
|
This method also runs Query.filter on the resulting model
|
||||||
|
query with *args and **kwargs.
|
||||||
|
|
||||||
|
:param Model: Declarative ORM class
|
||||||
|
"""
|
||||||
|
return get_session().query(Model).filter(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def create(Model: Base, *args, **kwargs) -> Base:
|
||||||
|
"""
|
||||||
|
Create a record and add() it to the database session.
|
||||||
|
|
||||||
|
:param Model: Declarative ORM class
|
||||||
|
:return: Model instance
|
||||||
|
"""
|
||||||
|
instance = Model(*args, **kwargs)
|
||||||
return add(instance)
|
return add(instance)
|
||||||
|
|
||||||
|
|
||||||
def delete(model, *args, **kwargs):
|
def delete(model: Base) -> None:
|
||||||
instance = session.query(model).filter(*args, **kwargs)
|
"""
|
||||||
for record in instance:
|
Delete a set of records found by Query.filter(*args, **kwargs).
|
||||||
session.delete(record)
|
|
||||||
|
:param Model: Declarative ORM class
|
||||||
|
"""
|
||||||
|
get_session().delete(model)
|
||||||
|
|
||||||
|
|
||||||
def delete_all(iterable: Iterable):
|
def delete_all(iterable: Iterable) -> None:
|
||||||
with begin():
|
""" Delete each instance found in `iterable`. """
|
||||||
for obj in iterable:
|
session_ = get_session()
|
||||||
session.delete(obj)
|
aurweb.util.apply_all(iterable, session_.delete)
|
||||||
|
|
||||||
|
|
||||||
def rollback():
|
def rollback() -> None:
|
||||||
session.rollback()
|
""" Rollback the database session. """
|
||||||
|
get_session().rollback()
|
||||||
|
|
||||||
|
|
||||||
def add(model):
|
def add(model: Base) -> Base:
|
||||||
session.add(model)
|
""" Add `model` to the database session. """
|
||||||
|
get_session().add(model)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def begin():
|
def begin():
|
||||||
""" Begin an SQLAlchemy SessionTransaction.
|
""" Begin an SQLAlchemy SessionTransaction. """
|
||||||
|
return get_session().begin()
|
||||||
This context is **required** to perform an modifications to the
|
|
||||||
database.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
with db.begin():
|
|
||||||
object = db.create(...)
|
|
||||||
# On __exit__, db.commit() is run.
|
|
||||||
|
|
||||||
with db.begin():
|
|
||||||
object = db.delete(...)
|
|
||||||
# On __exit__, db.commit() is run.
|
|
||||||
|
|
||||||
:return: A new SessionTransaction based on session
|
|
||||||
"""
|
|
||||||
return session.begin()
|
|
||||||
|
|
||||||
|
|
||||||
def get_sqlalchemy_url():
|
def get_sqlalchemy_url():
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
from aurweb import schema
|
from aurweb import db, schema
|
||||||
from aurweb.models.declarative import Base
|
from aurweb.models.declarative import Base
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,11 +10,10 @@ class Ban(Base):
|
||||||
__mapper_args__ = {"primary_key": [__table__.c.IPAddress]}
|
__mapper_args__ = {"primary_key": [__table__.c.IPAddress]}
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.IPAddress = kwargs.get("IPAddress")
|
super().__init__(**kwargs)
|
||||||
self.BanTS = kwargs.get("BanTS")
|
|
||||||
|
|
||||||
|
|
||||||
def is_banned(request: Request):
|
def is_banned(request: Request):
|
||||||
from aurweb.db import session
|
|
||||||
ip = request.client.host
|
ip = request.client.host
|
||||||
return session.query(Ban).filter(Ban.IPAddress == ip).first() is not None
|
exists = db.query(Ban).filter(Ban.IPAddress == ip).exists()
|
||||||
|
return db.query(exists).scalar()
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import backref, relationship
|
from sqlalchemy.orm import backref, relationship
|
||||||
|
|
||||||
from aurweb import schema
|
from aurweb import db, schema
|
||||||
from aurweb.models.declarative import Base
|
from aurweb.models.declarative import Base
|
||||||
from aurweb.models.dependency_type import DependencyType as _DependencyType
|
from aurweb.models.dependency_type import DependencyType as _DependencyType
|
||||||
|
from aurweb.models.official_provider import OfficialProvider as _OfficialProvider
|
||||||
from aurweb.models.package import Package as _Package
|
from aurweb.models.package import Package as _Package
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,11 +47,7 @@ class PackageDependency(Base):
|
||||||
params=("NULL"))
|
params=("NULL"))
|
||||||
|
|
||||||
def is_package(self) -> bool:
|
def is_package(self) -> bool:
|
||||||
# TODO: Improve the speed of this query if possible.
|
pkg = db.query(_Package).filter(_Package.Name == self.DepName).exists()
|
||||||
from aurweb import db
|
official = db.query(_OfficialProvider).filter(
|
||||||
from aurweb.models.official_provider import OfficialProvider
|
_OfficialProvider.Name == self.DepName).exists()
|
||||||
from aurweb.models.package import Package
|
return db.query(pkg).scalar() or db.query(official).scalar()
|
||||||
pkg = db.query(Package, Package.Name == self.DepName)
|
|
||||||
official = db.query(OfficialProvider,
|
|
||||||
OfficialProvider.Name == self.DepName)
|
|
||||||
return pkg.scalar() or official.scalar()
|
|
||||||
|
|
|
@ -146,7 +146,7 @@ class User(Base):
|
||||||
self.authenticated = False
|
self.authenticated = False
|
||||||
if self.session:
|
if self.session:
|
||||||
with db.begin():
|
with db.begin():
|
||||||
db.session.delete(self.session)
|
db.delete(self.session)
|
||||||
|
|
||||||
def is_trusted_user(self):
|
def is_trusted_user(self):
|
||||||
return self.AccountType.ID in {
|
return self.AccountType.ID in {
|
||||||
|
|
|
@ -110,18 +110,26 @@ def get_pkg_or_base(
|
||||||
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
|
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
|
||||||
|
|
||||||
instance = db.query(cls).filter(cls.Name == name).first()
|
instance = db.query(cls).filter(cls.Name == name).first()
|
||||||
if cls == models.PackageBase and not instance:
|
if not instance:
|
||||||
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
|
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
|
||||||
|
|
||||||
return instance
|
return db.refresh(instance)
|
||||||
|
|
||||||
|
|
||||||
def get_pkgbase_comment(
|
def get_pkgbase_comment(pkgbase: models.PackageBase, id: int) \
|
||||||
pkgbase: models.PackageBase, id: int) -> models.PackageComment:
|
-> models.PackageComment:
|
||||||
comment = pkgbase.comments.filter(models.PackageComment.ID == id).first()
|
comment = pkgbase.comments.filter(models.PackageComment.ID == id).first()
|
||||||
if not comment:
|
if not comment:
|
||||||
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
|
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
|
||||||
return comment
|
return db.refresh(comment)
|
||||||
|
|
||||||
|
|
||||||
|
def get_pkgreq_by_id(id: int):
|
||||||
|
pkgreq = db.query(models.PackageRequest).filter(
|
||||||
|
models.PackageRequest.ID == id).first()
|
||||||
|
if not pkgreq:
|
||||||
|
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
|
||||||
|
return db.refresh(pkgreq)
|
||||||
|
|
||||||
|
|
||||||
@register_filter("out_of_date")
|
@register_filter("out_of_date")
|
||||||
|
|
|
@ -40,8 +40,10 @@ def _update_ratelimit_db(request: Request):
|
||||||
now = int(datetime.utcnow().timestamp())
|
now = int(datetime.utcnow().timestamp())
|
||||||
time_to_delete = now - window_length
|
time_to_delete = now - window_length
|
||||||
|
|
||||||
|
records = db.query(ApiRateLimit).filter(
|
||||||
|
ApiRateLimit.WindowStart < time_to_delete)
|
||||||
with db.begin():
|
with db.begin():
|
||||||
db.delete(ApiRateLimit, ApiRateLimit.WindowStart < time_to_delete)
|
db.delete_all(records)
|
||||||
|
|
||||||
host = request.client.host
|
host = request.client.host
|
||||||
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
|
record = db.query(ApiRateLimit, ApiRateLimit.IP == host).first()
|
||||||
|
|
|
@ -4,7 +4,7 @@ import typing
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
|
||||||
from fastapi import APIRouter, Form, HTTPException, Request
|
from fastapi import APIRouter, Form, Request
|
||||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||||
from sqlalchemy import and_, func, or_
|
from sqlalchemy import and_, func, or_
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ from aurweb.models.account_type import (DEVELOPER, DEVELOPER_ID, TRUSTED_USER, T
|
||||||
from aurweb.models.ssh_pub_key import get_fingerprint
|
from aurweb.models.ssh_pub_key import get_fingerprint
|
||||||
from aurweb.scripts.notify import ResetKeyNotification, WelcomeNotification
|
from aurweb.scripts.notify import ResetKeyNotification, WelcomeNotification
|
||||||
from aurweb.templates import make_context, make_variable_context, render_template
|
from aurweb.templates import make_context, make_variable_context, render_template
|
||||||
|
from aurweb.users.util import get_user_by_name
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
@ -49,6 +50,7 @@ async def passreset_post(request: Request,
|
||||||
return render_template(request, "passreset.html", context,
|
return render_template(request, "passreset.html", context,
|
||||||
status_code=HTTPStatus.NOT_FOUND)
|
status_code=HTTPStatus.NOT_FOUND)
|
||||||
|
|
||||||
|
db.refresh(user)
|
||||||
if resetkey:
|
if resetkey:
|
||||||
context["resetkey"] = resetkey
|
context["resetkey"] = resetkey
|
||||||
|
|
||||||
|
@ -83,7 +85,7 @@ async def passreset_post(request: Request,
|
||||||
with db.begin():
|
with db.begin():
|
||||||
user.ResetKey = str()
|
user.ResetKey = str()
|
||||||
if user.session:
|
if user.session:
|
||||||
db.session.delete(user.session)
|
db.delete(user.session)
|
||||||
user.update_password(password)
|
user.update_password(password)
|
||||||
|
|
||||||
# Render ?step=complete.
|
# Render ?step=complete.
|
||||||
|
@ -458,15 +460,15 @@ def cannot_edit(request, user):
|
||||||
|
|
||||||
@router.get("/account/{username}/edit", response_class=HTMLResponse)
|
@router.get("/account/{username}/edit", response_class=HTMLResponse)
|
||||||
@auth_required(True, redirect="/account/{username}")
|
@auth_required(True, redirect="/account/{username}")
|
||||||
async def account_edit(request: Request,
|
async def account_edit(request: Request, username: str):
|
||||||
username: str):
|
|
||||||
user = db.query(models.User, models.User.Username == username).first()
|
user = db.query(models.User, models.User.Username == username).first()
|
||||||
|
|
||||||
response = cannot_edit(request, user)
|
response = cannot_edit(request, user)
|
||||||
if response:
|
if response:
|
||||||
return response
|
return response
|
||||||
|
|
||||||
context = await make_variable_context(request, "Accounts")
|
context = await make_variable_context(request, "Accounts")
|
||||||
context["user"] = user
|
context["user"] = db.refresh(user)
|
||||||
|
|
||||||
context = make_account_form_context(context, request, user, dict())
|
context = make_account_form_context(context, request, user, dict())
|
||||||
return render_template(request, "account/edit.html", context)
|
return render_template(request, "account/edit.html", context)
|
||||||
|
@ -497,16 +499,14 @@ async def account_edit_post(request: Request,
|
||||||
ON: bool = Form(default=False), # Owner Notify
|
ON: bool = Form(default=False), # Owner Notify
|
||||||
T: int = Form(default=None),
|
T: int = Form(default=None),
|
||||||
passwd: str = Form(default=str())):
|
passwd: str = Form(default=str())):
|
||||||
from aurweb.db import session
|
user = db.query(models.User).filter(
|
||||||
|
|
||||||
user = session.query(models.User).filter(
|
|
||||||
models.User.Username == username).first()
|
models.User.Username == username).first()
|
||||||
response = cannot_edit(request, user)
|
response = cannot_edit(request, user)
|
||||||
if response:
|
if response:
|
||||||
return response
|
return response
|
||||||
|
|
||||||
context = await make_variable_context(request, "Accounts")
|
context = await make_variable_context(request, "Accounts")
|
||||||
context["user"] = user
|
context["user"] = db.refresh(user)
|
||||||
|
|
||||||
args = dict(await request.form())
|
args = dict(await request.form())
|
||||||
context = make_account_form_context(context, request, user, args)
|
context = make_account_form_context(context, request, user, args)
|
||||||
|
@ -575,7 +575,7 @@ async def account_edit_post(request: Request,
|
||||||
user.ssh_pub_key.Fingerprint = fingerprint
|
user.ssh_pub_key.Fingerprint = fingerprint
|
||||||
elif user.ssh_pub_key:
|
elif user.ssh_pub_key:
|
||||||
# Else, if the user has a public key already, delete it.
|
# Else, if the user has a public key already, delete it.
|
||||||
session.delete(user.ssh_pub_key)
|
db.delete(user.ssh_pub_key)
|
||||||
|
|
||||||
if T and T != user.AccountTypeID:
|
if T and T != user.AccountTypeID:
|
||||||
with db.begin():
|
with db.begin():
|
||||||
|
@ -617,27 +617,16 @@ account_template = (
|
||||||
status_code=HTTPStatus.UNAUTHORIZED)
|
status_code=HTTPStatus.UNAUTHORIZED)
|
||||||
async def account(request: Request, username: str):
|
async def account(request: Request, username: str):
|
||||||
_ = l10n.get_translator_for_request(request)
|
_ = l10n.get_translator_for_request(request)
|
||||||
context = await make_variable_context(request,
|
context = await make_variable_context(
|
||||||
_("Account") + " " + username)
|
request, _("Account") + " " + username)
|
||||||
|
context["user"] = get_user_by_name(username)
|
||||||
user = db.query(models.User, models.User.Username == username).first()
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
|
|
||||||
|
|
||||||
context["user"] = user
|
|
||||||
|
|
||||||
return render_template(request, "account/show.html", context)
|
return render_template(request, "account/show.html", context)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/account/{username}/comments")
|
@router.get("/account/{username}/comments")
|
||||||
@auth_required(redirect="/account/{username}/comments")
|
@auth_required(redirect="/account/{username}/comments")
|
||||||
async def account_comments(request: Request, username: str):
|
async def account_comments(request: Request, username: str):
|
||||||
user = db.query(models.User).filter(
|
user = get_user_by_name(username)
|
||||||
models.User.Username == username
|
|
||||||
).first()
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
|
|
||||||
|
|
||||||
context = make_context(request, "Accounts")
|
context = make_context(request, "Accounts")
|
||||||
context["username"] = username
|
context["username"] = username
|
||||||
context["comments"] = user.package_comments.order_by(
|
context["comments"] = user.package_comments.order_by(
|
||||||
|
@ -662,7 +651,7 @@ async def accounts(request: Request):
|
||||||
account_type.TRUSTED_USER_AND_DEV})
|
account_type.TRUSTED_USER_AND_DEV})
|
||||||
async def accounts_post(request: Request,
|
async def accounts_post(request: Request,
|
||||||
O: int = Form(default=0), # Offset
|
O: int = Form(default=0), # Offset
|
||||||
SB: str = Form(default=str()), # Search By
|
SB: str = Form(default=str()), # Sort By
|
||||||
U: str = Form(default=str()), # Username
|
U: str = Form(default=str()), # Username
|
||||||
T: str = Form(default=str()), # Account Type
|
T: str = Form(default=str()), # Account Type
|
||||||
S: bool = Form(default=False), # Suspended
|
S: bool = Form(default=False), # Suspended
|
||||||
|
@ -705,23 +694,19 @@ async def accounts_post(request: Request,
|
||||||
|
|
||||||
# Populate this list with any additional statements to
|
# Populate this list with any additional statements to
|
||||||
# be ANDed together.
|
# be ANDed together.
|
||||||
statements = []
|
statements = [
|
||||||
if account_type_id is not None:
|
v for k, v in [
|
||||||
statements.append(models.AccountType.ID == account_type_id)
|
(account_type_id is not None, models.AccountType.ID == account_type_id),
|
||||||
if U:
|
(bool(U), models.User.Username.like(f"%{U}%")),
|
||||||
statements.append(models.User.Username.like(f"%{U}%"))
|
(bool(S), models.User.Suspended == S),
|
||||||
if S:
|
(bool(E), models.User.Email.like(f"%{E}%")),
|
||||||
statements.append(models.User.Suspended == S)
|
(bool(R), models.User.RealName.like(f"%{R}%")),
|
||||||
if E:
|
(bool(I), models.User.IRCNick.like(f"%{I}%")),
|
||||||
statements.append(models.User.Email.like(f"%{E}%"))
|
(bool(K), models.User.PGPKey.like(f"%{K}%")),
|
||||||
if R:
|
] if k
|
||||||
statements.append(models.User.RealName.like(f"%{R}%"))
|
]
|
||||||
if I:
|
|
||||||
statements.append(models.User.IRCNick.like(f"%{I}%"))
|
|
||||||
if K:
|
|
||||||
statements.append(models.User.PGPKey.like(f"%{K}%"))
|
|
||||||
|
|
||||||
# Filter the query by combining all statements added above into
|
# Filter the query by coe-mbining all statements added above into
|
||||||
# an AND statement, unless there's just one statement, which
|
# an AND statement, unless there's just one statement, which
|
||||||
# we pass on to filter() as args.
|
# we pass on to filter() as args.
|
||||||
if statements:
|
if statements:
|
||||||
|
@ -729,7 +714,7 @@ async def accounts_post(request: Request,
|
||||||
|
|
||||||
# Finally, order and truncate our users for the current page.
|
# Finally, order and truncate our users for the current page.
|
||||||
users = query.order_by(*order_by).limit(pp).offset(offset)
|
users = query.order_by(*order_by).limit(pp).offset(offset)
|
||||||
context["users"] = users
|
context["users"] = util.apply_all(users, db.refresh)
|
||||||
|
|
||||||
return render_template(request, "account/index.html", context)
|
return render_template(request, "account/index.html", context)
|
||||||
|
|
||||||
|
@ -755,6 +740,9 @@ async def terms_of_service(request: Request):
|
||||||
unaccepted = db.query(models.Term).filter(
|
unaccepted = db.query(models.Term).filter(
|
||||||
~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID))).all()
|
~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID))).all()
|
||||||
|
|
||||||
|
for record in (diffs + unaccepted):
|
||||||
|
db.refresh(record)
|
||||||
|
|
||||||
# Translate the 'Terms of Service' part of our page title.
|
# Translate the 'Terms of Service' part of our page title.
|
||||||
_ = l10n.get_translator_for_request(request)
|
_ = l10n.get_translator_for_request(request)
|
||||||
title = f"AUR {_('Terms of Service')}"
|
title = f"AUR {_('Terms of Service')}"
|
||||||
|
@ -786,18 +774,21 @@ async def terms_of_service_post(request: Request,
|
||||||
# We already did the database filters here, so let's just use
|
# We already did the database filters here, so let's just use
|
||||||
# them instead of reiterating the process in terms_of_service.
|
# them instead of reiterating the process in terms_of_service.
|
||||||
accept_needed = sorted(unaccepted + diffs)
|
accept_needed = sorted(unaccepted + diffs)
|
||||||
return render_terms_of_service(request, context, accept_needed)
|
return render_terms_of_service(
|
||||||
|
request, context, util.apply_all(accept_needed, db.refresh))
|
||||||
|
|
||||||
with db.begin():
|
with db.begin():
|
||||||
# For each term we found, query for the matching accepted term
|
# For each term we found, query for the matching accepted term
|
||||||
# and update its Revision to the term's current Revision.
|
# and update its Revision to the term's current Revision.
|
||||||
for term in diffs:
|
for term in diffs:
|
||||||
|
db.refresh(term)
|
||||||
accepted_term = request.user.accepted_terms.filter(
|
accepted_term = request.user.accepted_terms.filter(
|
||||||
models.AcceptedTerm.TermsID == term.ID).first()
|
models.AcceptedTerm.TermsID == term.ID).first()
|
||||||
accepted_term.Revision = term.Revision
|
accepted_term.Revision = term.Revision
|
||||||
|
|
||||||
# For each term that was never accepted, accept it!
|
# For each term that was never accepted, accept it!
|
||||||
for term in unaccepted:
|
for term in unaccepted:
|
||||||
|
db.refresh(term)
|
||||||
db.create(models.AcceptedTerm, User=request.user,
|
db.create(models.AcceptedTerm, User=request.user,
|
||||||
Term=term, Revision=term.Revision)
|
Term=term, Revision=term.Revision)
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List
|
||||||
|
|
||||||
from fastapi import APIRouter, Form, HTTPException, Query, Request, Response
|
from fastapi import APIRouter, Form, HTTPException, Query, Request, Response
|
||||||
from fastapi.responses import JSONResponse, RedirectResponse
|
from fastapi.responses import JSONResponse, RedirectResponse
|
||||||
from sqlalchemy import and_, case
|
from sqlalchemy import case
|
||||||
|
|
||||||
import aurweb.filters
|
import aurweb.filters
|
||||||
import aurweb.packages.util
|
import aurweb.packages.util
|
||||||
|
@ -15,9 +15,9 @@ from aurweb.models.package_request import ACCEPTED_ID, PENDING_ID, REJECTED_ID
|
||||||
from aurweb.models.relation_type import CONFLICTS_ID, PROVIDES_ID, REPLACES_ID
|
from aurweb.models.relation_type import CONFLICTS_ID, PROVIDES_ID, REPLACES_ID
|
||||||
from aurweb.models.request_type import DELETION_ID, MERGE, MERGE_ID
|
from aurweb.models.request_type import DELETION_ID, MERGE, MERGE_ID
|
||||||
from aurweb.packages.search import PackageSearch
|
from aurweb.packages.search import PackageSearch
|
||||||
from aurweb.packages.util import get_pkg_or_base, get_pkgbase_comment, query_notified, query_voted
|
from aurweb.packages.util import get_pkg_or_base, get_pkgbase_comment, get_pkgreq_by_id, query_notified, query_voted
|
||||||
from aurweb.scripts import notify, popupdate
|
from aurweb.scripts import notify, popupdate
|
||||||
from aurweb.scripts.rendercomment import update_comment_render
|
from aurweb.scripts.rendercomment import update_comment_render_fastapi
|
||||||
from aurweb.templates import make_context, make_variable_context, render_raw_template, render_template
|
from aurweb.templates import make_context, make_variable_context, render_raw_template, render_template
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
@ -92,7 +92,10 @@ async def packages_get(request: Request, context: Dict[str, Any],
|
||||||
|
|
||||||
# Insert search results into the context.
|
# Insert search results into the context.
|
||||||
results = search.results()
|
results = search.results()
|
||||||
context["packages"] = results.limit(per_page).offset(offset)
|
|
||||||
|
packages = results.limit(per_page).offset(offset)
|
||||||
|
util.apply_all(packages, db.refresh)
|
||||||
|
context["packages"] = packages
|
||||||
context["packages_voted"] = query_voted(
|
context["packages_voted"] = query_voted(
|
||||||
context.get("packages"), request.user)
|
context.get("packages"), request.user)
|
||||||
context["packages_notified"] = query_notified(
|
context["packages_notified"] = query_notified(
|
||||||
|
@ -132,6 +135,7 @@ def create_request_if_missing(requests: List[models.PackageRequest],
|
||||||
ClosedTS=now,
|
ClosedTS=now,
|
||||||
Closer=user)
|
Closer=user)
|
||||||
requests.append(pkgreq)
|
requests.append(pkgreq)
|
||||||
|
return pkgreq
|
||||||
|
|
||||||
|
|
||||||
def delete_package(deleter: models.User, package: models.Package):
|
def delete_package(deleter: models.User, package: models.Package):
|
||||||
|
@ -147,8 +151,9 @@ def delete_package(deleter: models.User, package: models.Package):
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
with db.begin():
|
with db.begin():
|
||||||
create_request_if_missing(
|
pkgreq = create_request_if_missing(
|
||||||
requests, reqtype, deleter, package)
|
requests, reqtype, deleter, package)
|
||||||
|
db.refresh(pkgreq)
|
||||||
|
|
||||||
bases_to_delete.append(package.PackageBase)
|
bases_to_delete.append(package.PackageBase)
|
||||||
|
|
||||||
|
@ -171,7 +176,8 @@ def delete_package(deleter: models.User, package: models.Package):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Perform all the deletions.
|
# Perform all the deletions.
|
||||||
db.delete_all([package])
|
with db.begin():
|
||||||
|
db.delete(package)
|
||||||
db.delete_all(bases_to_delete)
|
db.delete_all(bases_to_delete)
|
||||||
|
|
||||||
# Send out all the notifications.
|
# Send out all the notifications.
|
||||||
|
@ -221,8 +227,7 @@ async def make_single_context(request: Request,
|
||||||
async def package(request: Request, name: str) -> Response:
|
async def package(request: Request, name: str) -> Response:
|
||||||
# Get the Package.
|
# Get the Package.
|
||||||
pkg = get_pkg_or_base(name, models.Package)
|
pkg = get_pkg_or_base(name, models.Package)
|
||||||
pkgbase = (get_pkg_or_base(name, models.PackageBase)
|
pkgbase = pkg.PackageBase
|
||||||
if not pkg else pkg.PackageBase)
|
|
||||||
|
|
||||||
# Add our base information.
|
# Add our base information.
|
||||||
context = await make_single_context(request, pkgbase)
|
context = await make_single_context(request, pkgbase)
|
||||||
|
@ -312,7 +317,7 @@ async def pkgbase_comments_post(
|
||||||
db.create(models.PackageNotification,
|
db.create(models.PackageNotification,
|
||||||
User=request.user,
|
User=request.user,
|
||||||
PackageBase=pkgbase)
|
PackageBase=pkgbase)
|
||||||
update_comment_render(comment.ID)
|
update_comment_render_fastapi(comment)
|
||||||
|
|
||||||
# Redirect to the pkgbase page.
|
# Redirect to the pkgbase page.
|
||||||
return RedirectResponse(f"/pkgbase/{pkgbase.Name}#comment-{comment.ID}",
|
return RedirectResponse(f"/pkgbase/{pkgbase.Name}#comment-{comment.ID}",
|
||||||
|
@ -374,7 +379,7 @@ async def pkgbase_comment_post(
|
||||||
db.create(models.PackageNotification,
|
db.create(models.PackageNotification,
|
||||||
User=request.user,
|
User=request.user,
|
||||||
PackageBase=pkgbase)
|
PackageBase=pkgbase)
|
||||||
update_comment_render(db_comment.ID)
|
update_comment_render_fastapi(db_comment)
|
||||||
|
|
||||||
if not next:
|
if not next:
|
||||||
next = f"/pkgbase/{pkgbase.Name}"
|
next = f"/pkgbase/{pkgbase.Name}"
|
||||||
|
@ -539,7 +544,7 @@ def remove_users(pkgbase, usernames):
|
||||||
conn, comaintainer.User.ID, pkgbase.ID
|
conn, comaintainer.User.ID, pkgbase.ID
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
db.session.delete(comaintainer)
|
db.delete(comaintainer)
|
||||||
|
|
||||||
# Send out notifications if need be.
|
# Send out notifications if need be.
|
||||||
for notify_ in notifications:
|
for notify_ in notifications:
|
||||||
|
@ -679,14 +684,8 @@ async def requests(request: Request,
|
||||||
@router.get("/pkgbase/{name}/request")
|
@router.get("/pkgbase/{name}/request")
|
||||||
@auth_required(True, redirect="/pkgbase/{name}/request")
|
@auth_required(True, redirect="/pkgbase/{name}/request")
|
||||||
async def package_request(request: Request, name: str):
|
async def package_request(request: Request, name: str):
|
||||||
|
pkgbase = get_pkg_or_base(name, models.PackageBase)
|
||||||
context = await make_variable_context(request, "Submit Request")
|
context = await make_variable_context(request, "Submit Request")
|
||||||
|
|
||||||
pkgbase = db.query(models.PackageBase).filter(
|
|
||||||
models.PackageBase.Name == name).first()
|
|
||||||
|
|
||||||
if not pkgbase:
|
|
||||||
raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
|
|
||||||
|
|
||||||
context["pkgbase"] = pkgbase
|
context["pkgbase"] = pkgbase
|
||||||
return render_template(request, "pkgbase/request.html", context)
|
return render_template(request, "pkgbase/request.html", context)
|
||||||
|
|
||||||
|
@ -729,6 +728,7 @@ async def pkgbase_request_post(request: Request, name: str,
|
||||||
]
|
]
|
||||||
return render_template(request, "pkgbase/request.html", context)
|
return render_template(request, "pkgbase/request.html", context)
|
||||||
|
|
||||||
|
db.refresh(target)
|
||||||
if target.ID == pkgbase.ID:
|
if target.ID == pkgbase.ID:
|
||||||
# TODO: This error needs to be translated.
|
# TODO: This error needs to be translated.
|
||||||
context["errors"] = [
|
context["errors"] = [
|
||||||
|
@ -767,8 +767,7 @@ async def pkgbase_request_post(request: Request, name: str,
|
||||||
@router.get("/requests/{id}/close")
|
@router.get("/requests/{id}/close")
|
||||||
@auth_required(True, redirect="/requests/{id}/close")
|
@auth_required(True, redirect="/requests/{id}/close")
|
||||||
async def requests_close(request: Request, id: int):
|
async def requests_close(request: Request, id: int):
|
||||||
pkgreq = db.query(models.PackageRequest).filter(
|
pkgreq = get_pkgreq_by_id(id)
|
||||||
models.PackageRequest.ID == id).first()
|
|
||||||
if not request.user.is_elevated() and request.user != pkgreq.User:
|
if not request.user.is_elevated() and request.user != pkgreq.User:
|
||||||
# Request user doesn't have permission here: redirect to '/'.
|
# Request user doesn't have permission here: redirect to '/'.
|
||||||
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)
|
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)
|
||||||
|
@ -783,8 +782,7 @@ async def requests_close(request: Request, id: int):
|
||||||
async def requests_close_post(request: Request, id: int,
|
async def requests_close_post(request: Request, id: int,
|
||||||
reason: int = Form(default=0),
|
reason: int = Form(default=0),
|
||||||
comments: str = Form(default=str())):
|
comments: str = Form(default=str())):
|
||||||
pkgreq = db.query(models.PackageRequest).filter(
|
pkgreq = get_pkgreq_by_id(id)
|
||||||
models.PackageRequest.ID == id).first()
|
|
||||||
if not request.user.is_elevated() and request.user != pkgreq.User:
|
if not request.user.is_elevated() and request.user != pkgreq.User:
|
||||||
# Request user doesn't have permission here: redirect to '/'.
|
# Request user doesn't have permission here: redirect to '/'.
|
||||||
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)
|
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)
|
||||||
|
@ -823,13 +821,17 @@ async def pkgbase_keywords(request: Request, name: str,
|
||||||
keywords = set(keywords.split(" "))
|
keywords = set(keywords.split(" "))
|
||||||
|
|
||||||
# Delete all keywords which are not supplied by the user.
|
# Delete all keywords which are not supplied by the user.
|
||||||
with db.begin():
|
other_keywords = pkgbase.keywords.filter(
|
||||||
db.delete(models.PackageKeyword,
|
~models.PackageKeyword.Keyword.in_(keywords))
|
||||||
and_(models.PackageKeyword.PackageBaseID == pkgbase.ID,
|
other_keyword_strings = [kwd.Keyword for kwd in other_keywords]
|
||||||
~models.PackageKeyword.Keyword.in_(keywords)))
|
|
||||||
|
|
||||||
existing_keywords = set(kwd.Keyword for kwd in pkgbase.keywords.all())
|
existing_keywords = set(
|
||||||
|
kwd.Keyword for kwd in
|
||||||
|
pkgbase.keywords.filter(
|
||||||
|
~models.PackageKeyword.Keyword.in_(other_keyword_strings))
|
||||||
|
)
|
||||||
with db.begin():
|
with db.begin():
|
||||||
|
db.delete_all(other_keywords)
|
||||||
for keyword in keywords.difference(existing_keywords):
|
for keyword in keywords.difference(existing_keywords):
|
||||||
db.create(models.PackageKeyword,
|
db.create(models.PackageKeyword,
|
||||||
PackageBase=pkgbase,
|
PackageBase=pkgbase,
|
||||||
|
@ -940,7 +942,7 @@ def pkgbase_unnotify_instance(request: Request, pkgbase: models.PackageBase):
|
||||||
has_cred = request.user.has_credential("CRED_PKGBASE_NOTIFY")
|
has_cred = request.user.has_credential("CRED_PKGBASE_NOTIFY")
|
||||||
if has_cred and notif:
|
if has_cred and notif:
|
||||||
with db.begin():
|
with db.begin():
|
||||||
db.session.delete(notif)
|
db.delete(notif)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/pkgbase/{name}/unnotify")
|
@router.post("/pkgbase/{name}/unnotify")
|
||||||
|
@ -988,7 +990,7 @@ async def pkgbase_unvote(request: Request, name: str):
|
||||||
has_cred = request.user.has_credential("CRED_PKGBASE_VOTE")
|
has_cred = request.user.has_credential("CRED_PKGBASE_VOTE")
|
||||||
if has_cred and vote:
|
if has_cred and vote:
|
||||||
with db.begin():
|
with db.begin():
|
||||||
db.session.delete(vote)
|
db.delete(vote)
|
||||||
|
|
||||||
# Update NumVotes/Popularity.
|
# Update NumVotes/Popularity.
|
||||||
conn = db.ConnectionExecutor(db.get_engine().raw_connection())
|
conn = db.ConnectionExecutor(db.get_engine().raw_connection())
|
||||||
|
@ -1015,7 +1017,7 @@ def pkgbase_disown_instance(request: Request, pkgbase: models.PackageBase):
|
||||||
if co:
|
if co:
|
||||||
with db.begin():
|
with db.begin():
|
||||||
pkgbase.Maintainer = co.User
|
pkgbase.Maintainer = co.User
|
||||||
db.session.delete(co)
|
db.delete(co)
|
||||||
else:
|
else:
|
||||||
pkgbase.Maintainer = None
|
pkgbase.Maintainer = None
|
||||||
|
|
||||||
|
@ -1463,8 +1465,8 @@ def pkgbase_merge_instance(request: Request, pkgbase: models.PackageBase,
|
||||||
with db.begin():
|
with db.begin():
|
||||||
# Delete pkgbase and its packages now that everything's merged.
|
# Delete pkgbase and its packages now that everything's merged.
|
||||||
for pkg in pkgbase.packages:
|
for pkg in pkgbase.packages:
|
||||||
db.session.delete(pkg)
|
db.delete(pkg)
|
||||||
db.session.delete(pkgbase)
|
db.delete(pkgbase)
|
||||||
|
|
||||||
# Accept merge requests related to this pkgbase and target.
|
# Accept merge requests related to this pkgbase and target.
|
||||||
for pkgreq in requests:
|
for pkgreq in requests:
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Callable, Dict, List, NewType
|
||||||
|
|
||||||
from sqlalchemy import and_
|
from sqlalchemy import and_
|
||||||
|
|
||||||
|
@ -25,6 +25,10 @@ REL_TYPES = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
DataGenerator = NewType("DataGenerator",
|
||||||
|
Callable[[models.Package], Dict[str, Any]])
|
||||||
|
|
||||||
|
|
||||||
class RPCError(Exception):
|
class RPCError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -188,15 +192,32 @@ class RPC:
|
||||||
self._update_json_relations(package, data)
|
self._update_json_relations(package, data)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _handle_multiinfo_type(self, args: List[str] = [], **kwargs):
|
def _assemble_json_data(self, packages: List[models.Package],
|
||||||
|
data_generator: DataGenerator) \
|
||||||
|
-> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Assemble JSON data out of a list of packages.
|
||||||
|
|
||||||
|
:param packages: A list of Package instances or a Package ORM query
|
||||||
|
:param data_generator: Generator callable of single-Package JSON data
|
||||||
|
"""
|
||||||
|
output = []
|
||||||
|
for pkg in packages:
|
||||||
|
db.refresh(pkg)
|
||||||
|
output.append(data_generator(pkg))
|
||||||
|
return output
|
||||||
|
|
||||||
|
def _handle_multiinfo_type(self, args: List[str] = [], **kwargs) \
|
||||||
|
-> List[Dict[str, Any]]:
|
||||||
self._enforce_args(args)
|
self._enforce_args(args)
|
||||||
args = set(args)
|
args = set(args)
|
||||||
packages = db.query(models.Package).filter(
|
packages = db.query(models.Package).filter(
|
||||||
models.Package.Name.in_(args))
|
models.Package.Name.in_(args))
|
||||||
return [self._get_info_json_data(pkg) for pkg in packages]
|
return self._assemble_json_data(packages, self._get_info_json_data)
|
||||||
|
|
||||||
def _handle_search_type(self, by: str = defaults.RPC_SEARCH_BY,
|
def _handle_search_type(self, by: str = defaults.RPC_SEARCH_BY,
|
||||||
args: List[str] = []):
|
args: List[str] = []) \
|
||||||
|
-> List[Dict[str, Any]]:
|
||||||
# If `by` isn't maintainer and we don't have any args, raise an error.
|
# If `by` isn't maintainer and we don't have any args, raise an error.
|
||||||
# In maintainer's case, return all orphans if there are no args,
|
# In maintainer's case, return all orphans if there are no args,
|
||||||
# so we need args to pass through to the handler without errors.
|
# so we need args to pass through to the handler without errors.
|
||||||
|
@ -212,7 +233,7 @@ class RPC:
|
||||||
|
|
||||||
max_results = config.getint("options", "max_rpc_results")
|
max_results = config.getint("options", "max_rpc_results")
|
||||||
results = search.results().limit(max_results)
|
results = search.results().limit(max_results)
|
||||||
return [self._get_json_data(pkg) for pkg in results]
|
return self._assemble_json_data(results, self._get_json_data)
|
||||||
|
|
||||||
def _handle_msearch_type(self, args: List[str] = [], **kwargs):
|
def _handle_msearch_type(self, args: List[str] = [], **kwargs):
|
||||||
return self._handle_search_type(by="m", args=args)
|
return self._handle_search_type(by="m", args=args)
|
||||||
|
|
|
@ -29,7 +29,7 @@ def run_single(conn, pkgbase):
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
aurweb.db.session.refresh(pkgbase)
|
aurweb.db.refresh(pkgbase)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
|
@ -129,9 +129,14 @@ def save_rendered_comment(conn, commentid, html):
|
||||||
[html, commentid])
|
[html, commentid])
|
||||||
|
|
||||||
|
|
||||||
def update_comment_render(commentid):
|
def update_comment_render_fastapi(comment):
|
||||||
conn = aurweb.db.Connection()
|
conn = aurweb.db.ConnectionExecutor(
|
||||||
|
aurweb.db.get_engine().raw_connection())
|
||||||
|
update_comment_render(conn, comment.ID)
|
||||||
|
aurweb.db.refresh(comment)
|
||||||
|
|
||||||
|
|
||||||
|
def update_comment_render(conn, commentid):
|
||||||
text, pkgbase = get_comment(conn, commentid)
|
text, pkgbase = get_comment(conn, commentid)
|
||||||
html = markdown.markdown(text, extensions=[
|
html = markdown.markdown(text, extensions=[
|
||||||
'fenced_code',
|
'fenced_code',
|
||||||
|
@ -152,7 +157,8 @@ def update_comment_render(commentid):
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
commentid = int(sys.argv[1])
|
commentid = int(sys.argv[1])
|
||||||
update_comment_render(commentid)
|
conn = aurweb.db.Connection()
|
||||||
|
update_comment_render(conn, commentid)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -19,7 +19,7 @@ def references_graph(table):
|
||||||
"regexp_1": r'(?i)\s+references\s+("|\')?',
|
"regexp_1": r'(?i)\s+references\s+("|\')?',
|
||||||
"regexp_2": r'("|\')?\s*\(',
|
"regexp_2": r'("|\')?\s*\(',
|
||||||
}
|
}
|
||||||
cursor = aurweb.db.session.execute(query, params=params)
|
cursor = aurweb.db.get_session().execute(query, params=params)
|
||||||
return [row[0] for row in cursor.fetchall()]
|
return [row[0] for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,7 +51,7 @@ def setup_test_db(*args):
|
||||||
db_backend = aurweb.config.get("database", "backend")
|
db_backend = aurweb.config.get("database", "backend")
|
||||||
|
|
||||||
if db_backend != "sqlite": # pragma: no cover
|
if db_backend != "sqlite": # pragma: no cover
|
||||||
aurweb.db.session.execute("SET FOREIGN_KEY_CHECKS = 0")
|
aurweb.db.get_session().execute("SET FOREIGN_KEY_CHECKS = 0")
|
||||||
else:
|
else:
|
||||||
# We're using sqlite, setup tables to be deleted without violating
|
# We're using sqlite, setup tables to be deleted without violating
|
||||||
# foreign key constraints by graphing references.
|
# foreign key constraints by graphing references.
|
||||||
|
@ -59,10 +59,10 @@ def setup_test_db(*args):
|
||||||
references_graph(table) for table in tables))
|
references_graph(table) for table in tables))
|
||||||
|
|
||||||
for table in tables:
|
for table in tables:
|
||||||
aurweb.db.session.execute(f"DELETE FROM {table}")
|
aurweb.db.get_session().execute(f"DELETE FROM {table}")
|
||||||
|
|
||||||
if db_backend != "sqlite": # pragma: no cover
|
if db_backend != "sqlite": # pragma: no cover
|
||||||
aurweb.db.session.execute("SET FOREIGN_KEY_CHECKS = 1")
|
aurweb.db.get_session().execute("SET FOREIGN_KEY_CHECKS = 1")
|
||||||
|
|
||||||
# Expunge all objects from SQLAlchemy's IdentityMap.
|
# Expunge all objects from SQLAlchemy's IdentityMap.
|
||||||
aurweb.db.session.expunge_all()
|
aurweb.db.get_session().expunge_all()
|
||||||
|
|
0
aurweb/users/__init__.py
Normal file
0
aurweb/users/__init__.py
Normal file
19
aurweb/users/util.py
Normal file
19
aurweb/users/util.py
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from aurweb import db
|
||||||
|
from aurweb.models import User
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_by_name(username: str) -> User:
|
||||||
|
"""
|
||||||
|
Query a user by its username.
|
||||||
|
|
||||||
|
:param username: User.Username
|
||||||
|
:return: User instance
|
||||||
|
"""
|
||||||
|
user = db.query(User).filter(User.Username == username).first()
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=int(HTTPStatus.NOT_FOUND))
|
||||||
|
return db.refresh(user)
|
|
@ -155,6 +155,7 @@ def get_ssh_fingerprints():
|
||||||
def apply_all(iterable: Iterable, fn: Callable):
|
def apply_all(iterable: Iterable, fn: Callable):
|
||||||
for item in iterable:
|
for item in iterable:
|
||||||
fn(item)
|
fn(item)
|
||||||
|
return iterable
|
||||||
|
|
||||||
|
|
||||||
def sanitize_params(offset: str, per_page: str) -> Tuple[int, int]:
|
def sanitize_params(offset: str, per_page: str) -> Tuple[int, int]:
|
||||||
|
|
|
@ -132,11 +132,11 @@
|
||||||
<tr>
|
<tr>
|
||||||
<th>{{ "Votes" | tr }}:</th>
|
<th>{{ "Votes" | tr }}:</th>
|
||||||
{% if not is_maintainer %}
|
{% if not is_maintainer %}
|
||||||
<td>{{ pkgbase.package_votes.count() }}</td>
|
<td>{{ pkgbase.NumVotes }}</td>
|
||||||
{% else %}
|
{% else %}
|
||||||
<td>
|
<td>
|
||||||
<a href="/pkgbase/{{ pkgbase.Name }}/voters">
|
<a href="/pkgbase/{{ pkgbase.Name }}/voters">
|
||||||
{{ pkgbase.package_votes.count() }}
|
{{ pkgbase.NumVotes }}
|
||||||
</a>
|
</a>
|
||||||
</td>
|
</td>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
|
@ -20,7 +20,7 @@ def setup():
|
||||||
yield account_type
|
yield account_type
|
||||||
|
|
||||||
with begin():
|
with begin():
|
||||||
delete(AccountType, AccountType.ID == account_type.ID)
|
delete(account_type)
|
||||||
|
|
||||||
|
|
||||||
def test_account_type():
|
def test_account_type():
|
||||||
|
@ -50,4 +50,4 @@ def test_user_account_type_relationship():
|
||||||
# This must be deleted here to avoid foreign key issues when
|
# This must be deleted here to avoid foreign key issues when
|
||||||
# deleting the temporary AccountType in the fixture.
|
# deleting the temporary AccountType in the fixture.
|
||||||
with begin():
|
with begin():
|
||||||
delete(User, User.ID == user.ID)
|
delete(user)
|
||||||
|
|
|
@ -279,13 +279,13 @@ def test_connection_execute_paramstyle_unsupported():
|
||||||
|
|
||||||
def test_create_delete():
|
def test_create_delete():
|
||||||
with db.begin():
|
with db.begin():
|
||||||
db.create(AccountType, AccountType="test")
|
account_type = db.create(AccountType, AccountType="test")
|
||||||
|
|
||||||
record = db.query(AccountType, AccountType.AccountType == "test").first()
|
record = db.query(AccountType, AccountType.AccountType == "test").first()
|
||||||
assert record is not None
|
assert record is not None
|
||||||
|
|
||||||
with db.begin():
|
with db.begin():
|
||||||
db.delete(AccountType, AccountType.AccountType == "test")
|
db.delete(account_type)
|
||||||
|
|
||||||
record = db.query(AccountType, AccountType.AccountType == "test").first()
|
record = db.query(AccountType, AccountType.AccountType == "test").first()
|
||||||
assert record is None
|
assert record is None
|
||||||
|
@ -306,7 +306,7 @@ def test_add_commit():
|
||||||
|
|
||||||
# Remove the record.
|
# Remove the record.
|
||||||
with db.begin():
|
with db.begin():
|
||||||
db.delete(AccountType, AccountType.ID == account_type.ID)
|
db.delete(account_type)
|
||||||
|
|
||||||
|
|
||||||
def test_connection_executor_mysql_paramstyle():
|
def test_connection_executor_mysql_paramstyle():
|
||||||
|
|
|
@ -24,7 +24,7 @@ def test_dependency_type_creation():
|
||||||
assert bool(dependency_type.ID)
|
assert bool(dependency_type.ID)
|
||||||
assert dependency_type.Name == "Test Type"
|
assert dependency_type.Name == "Test Type"
|
||||||
with begin():
|
with begin():
|
||||||
delete(DependencyType, DependencyType.ID == dependency_type.ID)
|
delete(dependency_type)
|
||||||
|
|
||||||
|
|
||||||
def test_dependency_type_null_name_uses_default():
|
def test_dependency_type_null_name_uses_default():
|
||||||
|
@ -32,4 +32,4 @@ def test_dependency_type_null_name_uses_default():
|
||||||
dependency_type = create(DependencyType)
|
dependency_type = create(DependencyType)
|
||||||
assert dependency_type.Name == str()
|
assert dependency_type.Name == str()
|
||||||
with begin():
|
with begin():
|
||||||
delete(DependencyType, DependencyType.ID == dependency_type.ID)
|
delete(dependency_type)
|
||||||
|
|
|
@ -2,6 +2,7 @@ from datetime import datetime
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from aurweb import asgi, db
|
from aurweb import asgi, db
|
||||||
|
@ -98,3 +99,8 @@ def test_query_notified(maintainer: User, package: Package):
|
||||||
query = db.query(Package).filter(Package.ID == package.ID).all()
|
query = db.query(Package).filter(Package.ID == package.ID).all()
|
||||||
query_notified = util.query_notified(query, maintainer)
|
query_notified = util.query_notified(query, maintainer)
|
||||||
assert query_notified[package.PackageBase.ID]
|
assert query_notified[package.PackageBase.ID]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pkgreq_by_id_not_found():
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
util.get_pkgreq_by_id(0)
|
||||||
|
|
|
@ -103,7 +103,7 @@ def test_ratelimit_db(get: mock.MagicMock, getboolean: mock.MagicMock,
|
||||||
|
|
||||||
# Delete the ApiRateLimit record.
|
# Delete the ApiRateLimit record.
|
||||||
with db.begin():
|
with db.begin():
|
||||||
db.delete(ApiRateLimit)
|
db.delete(db.query(ApiRateLimit).first())
|
||||||
|
|
||||||
# Should be good to go again!
|
# Should be good to go again!
|
||||||
assert not check_ratelimit(request)
|
assert not check_ratelimit(request)
|
||||||
|
|
|
@ -18,7 +18,7 @@ def test_relation_type_creation():
|
||||||
assert relation_type.Name == "test-relation"
|
assert relation_type.Name == "test-relation"
|
||||||
|
|
||||||
with db.begin():
|
with db.begin():
|
||||||
db.delete(RelationType, RelationType.ID == relation_type.ID)
|
db.delete(relation_type)
|
||||||
|
|
||||||
|
|
||||||
def test_relation_types():
|
def test_relation_types():
|
||||||
|
|
|
@ -18,7 +18,7 @@ def test_request_type_creation():
|
||||||
assert request_type.Name == "Test Request"
|
assert request_type.Name == "Test Request"
|
||||||
|
|
||||||
with db.begin():
|
with db.begin():
|
||||||
db.delete(RequestType, RequestType.ID == request_type.ID)
|
db.delete(request_type)
|
||||||
|
|
||||||
|
|
||||||
def test_request_type_null_name_returns_empty_string():
|
def test_request_type_null_name_returns_empty_string():
|
||||||
|
@ -29,7 +29,7 @@ def test_request_type_null_name_returns_empty_string():
|
||||||
assert request_type.Name == str()
|
assert request_type.Name == str()
|
||||||
|
|
||||||
with db.begin():
|
with db.begin():
|
||||||
db.delete(RequestType, RequestType.ID == request_type.ID)
|
db.delete(request_type)
|
||||||
|
|
||||||
|
|
||||||
def test_request_type_name_display():
|
def test_request_type_name_display():
|
||||||
|
|
Loading…
Add table
Reference in a new issue