[FastAPI] Refactor db modifications

For SQLAlchemy to automatically understand updates from the
external world, it must use an `autocommit=True` in its session.

This change breaks how we were using commit previously, as
`autocommit=True` causes SQLAlchemy to commit when a
SessionTransaction context hits __exit__.

So, a refactoring was required of our tests: All usage of
any `db.{create,delete}` must be called **within** a
SessionTransaction context, created via new `db.begin()`.

From this point forward, we're going to require:

```
with db.begin():
    db.create(...)
    db.delete(...)
    db.session.delete(object)
```

With this, we now get external DB modifications automatically
without reloading or restarting the FastAPI server, which we
absolutely need for production.

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2021-09-02 16:26:48 -07:00
parent b52059d437
commit a5943bf2ad
No known key found for this signature in database
GPG key ID: F7E46DED420788F3
37 changed files with 998 additions and 902 deletions

View file

@ -59,20 +59,15 @@ def query(model, *args, **kwargs):
return session.query(model).filter(*args, **kwargs)
def create(model, autocommit: bool = True, *args, **kwargs):
def create(model, *args, **kwargs):
instance = model(*args, **kwargs)
add(instance)
if autocommit is True:
commit()
return instance
return add(instance)
def delete(model, *args, autocommit: bool = True, **kwargs):
def delete(model, *args, **kwargs):
instance = session.query(model).filter(*args, **kwargs)
for record in instance:
session.delete(record)
if autocommit is True:
commit()
def rollback():
@ -84,8 +79,25 @@ def add(model):
return model
def commit():
session.commit()
def begin():
""" Begin an SQLAlchemy SessionTransaction.
This context is **required** to perform an modifications to the
database.
Example:
with db.begin():
object = db.create(...)
# On __exit__, db.commit() is run.
with db.begin():
object = db.delete(...)
# On __exit__, db.commit() is run.
:return: A new SessionTransaction based on session
"""
return session.begin()
def get_sqlalchemy_url():
@ -155,23 +167,23 @@ def get_engine(echo: bool = False):
connect_args=connect_args,
echo=echo)
Session = sessionmaker(autocommit=True, autoflush=False, bind=engine)
session = Session()
if db_backend == "sqlite":
# For SQLite, we need to add some custom functions as
# they are used in the reference graph method.
def regexp(regex, item):
return bool(re.search(regex, str(item)))
@event.listens_for(engine, "begin")
def do_begin(conn):
@event.listens_for(engine, "connect")
def do_begin(conn, record):
create_deterministic_function = functools.partial(
conn.connection.create_function,
conn.create_function,
deterministic=True
)
create_deterministic_function("REGEXP", 2, regexp)
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
return engine

View file

@ -102,7 +102,7 @@ class User(Base):
def login(self, request: Request, password: str, session_time=0):
""" Login and authenticate a request. """
from aurweb.db import session
from aurweb import db
from aurweb.models.session import Session, generate_unique_sid
if not self._login_approved(request):
@ -112,10 +112,7 @@ class User(Base):
if not self.authenticated:
return None
self.LastLogin = now_ts = datetime.utcnow().timestamp()
self.LastLoginIPAddress = request.client.host
session.commit()
now_ts = datetime.utcnow().timestamp()
session_ts = now_ts + (
session_time if session_time
else aurweb.config.getint("options", "login_timeout")
@ -123,11 +120,14 @@ class User(Base):
sid = None
with db.begin():
self.LastLogin = now_ts
self.LastLoginIPAddress = request.client.host
if not self.session:
sid = generate_unique_sid()
self.session = Session(UsersID=self.ID, SessionID=sid,
LastUpdateTS=session_ts)
session.add(self.session)
db.add(self.session)
else:
last_updated = self.session.LastUpdateTS
if last_updated and last_updated < now_ts:
@ -138,8 +138,6 @@ class User(Base):
self.session.LastUpdateTS = session_ts
session.commit()
request.cookies["AURSID"] = self.session.SessionID
return self.session.SessionID
@ -149,13 +147,11 @@ class User(Base):
return aurweb.auth.has_credential(self, cred, approved)
def logout(self, request):
from aurweb.db import session
del request.cookies["AURSID"]
self.authenticated = False
if self.session:
session.delete(self.session)
session.commit()
with db.begin():
db.session.delete(self.session)
def is_trusted_user(self):
return self.AccountType.ID in {

View file

@ -43,8 +43,6 @@ async def passreset_post(request: Request,
resetkey: str = Form(default=None),
password: str = Form(default=None),
confirm: str = Form(default=None)):
from aurweb.db import session
context = await make_variable_context(request, "Password Reset")
# The user parameter being required, we can match against
@ -86,12 +84,11 @@ async def passreset_post(request: Request,
# We got to this point; everything matched up. Update the password
# and remove the ResetKey.
with db.begin():
user.ResetKey = str()
user.update_password(password)
if user.session:
session.delete(user.session)
session.commit()
db.session.delete(user.session)
user.update_password(password)
# Render ?step=complete.
return RedirectResponse(url="/passreset?step=complete",
@ -99,8 +96,8 @@ async def passreset_post(request: Request,
# If we got here, we continue with issuing a resetkey for the user.
resetkey = db.make_random_value(User, User.ResetKey)
with db.begin():
user.ResetKey = resetkey
session.commit()
executor = db.ConnectionExecutor(db.get_engine().raw_connection())
ResetKeyNotification(executor, user.ID).send()
@ -364,8 +361,6 @@ async def account_register_post(request: Request,
ON: bool = Form(default=False),
captcha: str = Form(default=None),
captcha_salt: str = Form(...)):
from aurweb.db import session
context = await make_variable_context(request, "Register")
args = dict(await request.form())
@ -394,11 +389,13 @@ async def account_register_post(request: Request,
AccountType.AccountType == "User").first()
# Create a user given all parameters available.
user = db.create(User, Username=U, Email=E, HideEmail=H, BackupEmail=BE,
with db.begin():
user = db.create(User, Username=U,
Email=E, HideEmail=H, BackupEmail=BE,
RealName=R, Homepage=HP, IRCNick=I, PGPKey=K,
LangPreference=L, Timezone=TZ, CommentNotify=CN,
UpdateNotify=UN, OwnershipNotify=ON, ResetKey=resetkey,
AccountType=account_type)
UpdateNotify=UN, OwnershipNotify=ON,
ResetKey=resetkey, AccountType=account_type)
# If a PK was given and either one does not exist or the given
# PK mismatches the existing user's SSHPubKey.PubKey.
@ -410,10 +407,10 @@ async def account_register_post(request: Request,
# Remove the host part.
pubkey = parts[0] + " " + parts[1]
fingerprint = get_fingerprint(pubkey)
with db.begin():
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
PubKey=pubkey,
Fingerprint=fingerprint)
session.commit()
# Send a reset key notification to the new user.
executor = db.ConnectionExecutor(db.get_engine().raw_connection())
@ -499,6 +496,7 @@ async def account_edit_post(request: Request,
status_code=int(HTTPStatus.BAD_REQUEST))
# Set all updated fields as needed.
with db.begin():
user.Username = U or user.Username
user.Email = E or user.Email
user.HideEmail = bool(H)
@ -512,20 +510,24 @@ async def account_edit_post(request: Request,
# If we update the language, update the cookie as well.
if L and L != user.LangPreference:
request.cookies["AURLANG"] = L
with db.begin():
user.LangPreference = L
context["language"] = L
# If we update the timezone, also update the cookie.
if TZ and TZ != user.Timezone:
with db.begin():
user.Timezone = TZ
request.cookies["AURTZ"] = TZ
context["timezone"] = TZ
with db.begin():
user.CommentNotify = bool(CN)
user.UpdateNotify = bool(UN)
user.OwnershipNotify = bool(ON)
# If a PK is given, compare it against the target user's PK.
with db.begin():
if PK:
# Get the second token in the public key, which is the actual key.
pubkey = PK.strip().rstrip()
@ -547,15 +549,14 @@ async def account_edit_post(request: Request,
# Else, if the user has a public key already, delete it.
session.delete(user.ssh_pub_key)
# Commit changes, if any.
session.commit()
if P and not user.valid_password(P):
# Remove the fields we consumed for passwords.
context["P"] = context["C"] = str()
# If a password was given and it doesn't match the user's, update it.
with db.begin():
user.update_password(P)
if user == request.user:
# If the target user is the request user, login with
# the updated password and update AURSID.
@ -731,6 +732,7 @@ async def terms_of_service_post(request: Request,
accept_needed = sorted(unaccepted + diffs)
return render_terms_of_service(request, context, accept_needed)
with db.begin():
# For each term we found, query for the matching accepted term
# and update its Revision to the term's current Revision.
for term in diffs:
@ -741,11 +743,6 @@ async def terms_of_service_post(request: Request,
# For each term that was never accepted, accept it!
for term in unaccepted:
db.create(AcceptedTerm, User=request.user,
Term=term, Revision=term.Revision,
autocommit=False)
if diffs or unaccepted:
# If we had any terms to update, commit the changes.
db.commit()
Term=term, Revision=term.Revision)
return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER))

View file

@ -44,8 +44,6 @@ async def language(request: Request,
setting the language on any page, we want to preserve query
parameters across the redirect.
"""
from aurweb.db import session
if next[0] != '/':
return HTMLResponse(b"Invalid 'next' parameter.", status_code=400)
@ -53,8 +51,8 @@ async def language(request: Request,
# If the user is authenticated, update the user's LangPreference.
if request.user.is_authenticated():
with db.begin():
request.user.LangPreference = set_lang
session.commit()
# In any case, set the response's AURLANG cookie that never expires.
response = RedirectResponse(url=f"{next}{query_string}",

View file

@ -214,10 +214,9 @@ async def trusted_user_proposal_post(request: Request,
return Response("Invalid 'decision' value.",
status_code=int(HTTPStatus.BAD_REQUEST))
vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo,
autocommit=False)
with db.begin():
vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo)
voteinfo.ActiveTUs += 1
db.commit()
context["error"] = "You've already voted for this proposal."
return render_proposal(request, context, proposal, voteinfo, voters, vote)
@ -294,6 +293,7 @@ async def trusted_user_addvote_post(request: Request,
agenda = re.sub(r'<[/]?style.*>', '', agenda)
# Create a new TUVoteInfo (proposal)!
with db.begin():
voteinfo = db.create(TUVoteInfo,
User=user,
Agenda=agenda,

View file

@ -1,6 +1,6 @@
import pytest
from aurweb.db import create, delete, query
from aurweb.db import begin, create, delete, query
from aurweb.models.account_type import AccountType
from aurweb.models.user import User
from aurweb.testing import setup_test_db
@ -14,10 +14,12 @@ def setup():
global account_type
with begin():
account_type = create(AccountType, AccountType="TestUser")
yield account_type
with begin():
delete(AccountType, AccountType.ID == account_type.ID)
@ -38,6 +40,7 @@ def test_account_type():
def test_user_account_type_relationship():
with begin():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
@ -46,4 +49,5 @@ def test_user_account_type_relationship():
# This must be deleted here to avoid foreign key issues when
# deleting the temporary AccountType in the fixture.
with begin():
delete(User, User.ID == user.ID)

View file

@ -11,9 +11,9 @@ import pytest
from fastapi.testclient import TestClient
from aurweb import captcha
from aurweb import captcha, db
from aurweb.asgi import app
from aurweb.db import commit, create, query
from aurweb.db import create, query
from aurweb.models.accepted_term import AcceptedTerm
from aurweb.models.account_type import DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID, TRUSTED_USER_ID, AccountType
from aurweb.models.ban import Ban
@ -57,6 +57,8 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
with db.begin():
user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL,
RealName="Test UserZ", Passwd="testPassword",
IRCNick="testZ", AccountType=account_type)
@ -70,9 +72,10 @@ def setup():
@pytest.fixture
def tu_user():
user.AccountType = query(AccountType,
AccountType.ID == TRUSTED_USER_AND_DEV_ID).first()
commit()
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
yield user
@ -149,11 +152,9 @@ def test_post_passreset_user():
def test_post_passreset_resetkey():
from aurweb.db import session
with db.begin():
user.session = Session(UsersID=user.ID, SessionID="blah",
LastUpdateTS=datetime.utcnow().timestamp())
session.commit()
# Prepare a password reset.
with client as request:
@ -357,6 +358,7 @@ def test_post_register_error_invalid_captcha():
def test_post_register_error_ip_banned():
# 'testclient' is used as request.client.host via FastAPI TestClient.
with db.begin():
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow())
with client as request:
@ -576,6 +578,7 @@ def test_post_register_error_ssh_pubkey_taken():
# Take the sha256 fingerprint of the ssh public key, create it.
fp = get_fingerprint(pk)
with db.begin():
create(SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fp)
with client as request:
@ -660,13 +663,11 @@ def test_post_account_edit():
def test_post_account_edit_dev():
from aurweb.db import session
# Modify our user to be a "Trusted User & Developer"
name = "Trusted User & Developer"
tu_or_dev = query(AccountType, AccountType.AccountType == name).first()
with db.begin():
user.AccountType = tu_or_dev
session.commit()
request = Request()
sid = user.login(request, "testPassword")
@ -1001,22 +1002,20 @@ def get_rows(html):
def test_post_accounts(tu_user):
# Set a PGPKey.
with db.begin():
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
# Create a few more users.
users = [user]
with db.begin():
for i in range(10):
_user = create(User, Username=f"test_{i}",
Email=f"test_{i}@example.org",
RealName=f"Test #{i}",
Passwd="testPassword",
IRCNick=f"test_#{i}",
autocommit=False)
IRCNick=f"test_#{i}")
users.append(_user)
# Commit everything to the database.
commit()
sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid}
@ -1085,6 +1084,7 @@ def test_post_accounts_account_type(tu_user):
# test the `u` parameter.
account_type = query(AccountType,
AccountType.AccountType == "User").first()
with db.begin():
create(User, Username="test_2",
Email="test_2@example.org",
RealName="Test User 2",
@ -1113,9 +1113,10 @@ def test_post_accounts_account_type(tu_user):
assert type.text.strip() == "User"
# Set our only user to a Trusted User.
user.AccountType = query(AccountType,
AccountType.ID == TRUSTED_USER_ID).first()
commit()
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_ID
).first()
with client as request:
response = request.post("/accounts/", cookies=cookies,
@ -1130,9 +1131,10 @@ def test_post_accounts_account_type(tu_user):
assert type.text.strip() == "Trusted User"
user.AccountType = query(AccountType,
AccountType.ID == DEVELOPER_ID).first()
commit()
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == DEVELOPER_ID
).first()
with client as request:
response = request.post("/accounts/", cookies=cookies,
@ -1147,10 +1149,10 @@ def test_post_accounts_account_type(tu_user):
assert type.text.strip() == "Developer"
user.AccountType = query(AccountType,
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
commit()
with client as request:
response = request.post("/accounts/", cookies=cookies,
@ -1182,8 +1184,8 @@ def test_post_accounts_status(tu_user):
username, type, status, realname, irc, pgp_key, edit = row
assert status.text.strip() == "Active"
with db.begin():
user.Suspended = True
commit()
with client as request:
response = request.post("/accounts/", cookies=cookies,
@ -1244,6 +1246,7 @@ def test_post_accounts_sortby(tu_user):
# Create a second user so we can compare sorts.
account_type = query(AccountType,
AccountType.ID == DEVELOPER_ID).first()
with db.begin():
create(User, Username="test2",
Email="test2@example.org",
RealName="Test User 2",
@ -1297,9 +1300,10 @@ def test_post_accounts_sortby(tu_user):
# Test the rows are reversed when ordering by RealName.
assert compare_text_values(4, first_rows, reversed(rows)) is True
user.AccountType = query(AccountType,
AccountType.ID == TRUSTED_USER_AND_DEV_ID).first()
commit()
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
# Fetch first_rows again with our new AccountType ordering.
with client as request:
@ -1322,8 +1326,8 @@ def test_post_accounts_sortby(tu_user):
def test_post_accounts_pgp_key(tu_user):
with db.begin():
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
commit()
sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid}
@ -1343,15 +1347,14 @@ def test_post_accounts_paged(tu_user):
users = [user]
account_type = query(AccountType,
AccountType.AccountType == "User").first()
with db.begin():
for i in range(150):
_user = create(User, Username=f"test_#{i}",
Email=f"test_#{i}@example.org",
RealName=f"Test User #{i}",
Passwd="testPassword",
AccountType=account_type,
autocommit=False)
AccountType=account_type)
users.append(_user)
commit()
sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid}
@ -1414,6 +1417,7 @@ def test_post_accounts_paged(tu_user):
def test_get_terms_of_service():
with db.begin():
term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1)
@ -1436,6 +1440,7 @@ def test_get_terms_of_service():
response = request.get("/tos", cookies=cookies, allow_redirects=False)
assert response.status_code == int(HTTPStatus.OK)
with db.begin():
accepted_term = create(AcceptedTerm, User=user,
Term=term, Revision=term.Revision)
@ -1445,8 +1450,8 @@ def test_get_terms_of_service():
assert response.status_code == int(HTTPStatus.SEE_OTHER)
# Bump the term's revision.
with db.begin():
term.Revision = 2
commit()
with client as request:
response = request.get("/tos", cookies=cookies, allow_redirects=False)
@ -1454,8 +1459,8 @@ def test_get_terms_of_service():
# yet been agreed to via AcceptedTerm update.
assert response.status_code == int(HTTPStatus.OK)
with db.begin():
accepted_term.Revision = term.Revision
commit()
with client as request:
response = request.get("/tos", cookies=cookies, allow_redirects=False)
@ -1471,6 +1476,7 @@ def test_post_terms_of_service():
cookies = {"AURSID": sid} # Auth cookie.
# Create a fresh Term.
with db.begin():
term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1)
@ -1497,8 +1503,8 @@ def test_post_terms_of_service():
assert accepted_term.Term == term
# Update the term to revision 2.
with db.begin():
term.Revision = 2
commit()
# A GET request gives us the new revision to accept.
with client as request:

View file

@ -2,6 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb import db
from aurweb.db import create
from aurweb.models.api_rate_limit import ApiRateLimit
from aurweb.testing import setup_test_db
@ -13,6 +14,7 @@ def setup():
def test_api_rate_key_creation():
with db.begin():
rate = create(ApiRateLimit, IP="127.0.0.1", Requests=10, WindowStart=1)
assert rate.IP == "127.0.0.1"
assert rate.Requests == 10
@ -20,19 +22,20 @@ def test_api_rate_key_creation():
def test_api_rate_key_ip_default():
with db.begin():
api_rate_limit = create(ApiRateLimit, Requests=10, WindowStart=1)
assert api_rate_limit.IP == str()
def test_api_rate_key_null_requests_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
with db.begin():
create(ApiRateLimit, IP="127.0.0.1", WindowStart=1)
session.rollback()
db.rollback()
def test_api_rate_key_null_window_start_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
with db.begin():
create(ApiRateLimit, IP="127.0.0.1", Requests=1)
session.rollback()
db.rollback()

View file

@ -4,6 +4,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb import db
from aurweb.auth import BasicAuthBackend, account_type_required, has_credential
from aurweb.db import create, query
from aurweb.models.account_type import USER, USER_ID, AccountType
@ -23,6 +24,7 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
with db.begin():
user = create(User, Username="test", Email="test@example.com",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
@ -51,14 +53,13 @@ async def test_auth_backend_invalid_sid():
@pytest.mark.asyncio
async def test_auth_backend_invalid_user_id():
from aurweb.db import session
# Create a new session with a fake user id.
now_ts = datetime.utcnow().timestamp()
with pytest.raises(IntegrityError):
with db.begin():
create(Session, UsersID=666, SessionID="realSession",
LastUpdateTS=now_ts + 5)
session.rollback()
db.rollback()
@pytest.mark.asyncio
@ -66,6 +67,7 @@ async def test_basic_auth_backend():
# This time, everything matches up. We expect the user to
# equal the real_user.
now_ts = datetime.utcnow().timestamp()
with db.begin():
create(Session, UsersID=user.ID, SessionID="realSession",
LastUpdateTS=now_ts + 5)
request.cookies["AURSID"] = "realSession"

View file

@ -9,7 +9,7 @@ from fastapi.testclient import TestClient
import aurweb.config
from aurweb.asgi import app
from aurweb.db import create, query
from aurweb.db import begin, create, query
from aurweb.models.account_type import AccountType
from aurweb.models.session import Session
from aurweb.models.user import User
@ -32,6 +32,7 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
with begin():
user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL,
RealName="Test User", Passwd="testPassword",
AccountType=account_type)

View file

@ -6,6 +6,7 @@ import pytest
from sqlalchemy import exc as sa_exc
from aurweb import db
from aurweb.db import create
from aurweb.models.ban import Ban, is_banned
from aurweb.testing import setup_test_db
@ -21,6 +22,7 @@ def setup():
setup_test_db("Bans")
ts = datetime.utcnow() + timedelta(seconds=30)
with db.begin():
ban = create(Ban, IPAddress="127.0.0.1", BanTS=ts)
request = Request()
@ -35,17 +37,17 @@ def test_invalid_ban():
with pytest.raises(sa_exc.IntegrityError):
bad_ban = Ban(BanTS=datetime.utcnow())
session.add(bad_ban)
# We're adding a ban with no primary key; this causes an
# SQLAlchemy warnings when committing to the DB.
# Ignore them.
with warnings.catch_warnings():
warnings.simplefilter("ignore", sa_exc.SAWarning)
session.commit()
with db.begin():
session.add(bad_ban)
# Since we got a transaction failure, we need to rollback.
session.rollback()
db.rollback()
def test_banned():

View file

@ -278,18 +278,15 @@ def test_connection_execute_paramstyle_unsupported():
def test_create_delete():
with db.begin():
db.create(AccountType, AccountType="test")
record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is not None
db.delete(AccountType, AccountType.AccountType == "test")
record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is None
# Create and delete a record with autocommit=False.
db.create(AccountType, AccountType="test", autocommit=False)
db.commit()
db.delete(AccountType, AccountType.AccountType == "test", autocommit=False)
db.commit()
with db.begin():
db.delete(AccountType, AccountType.AccountType == "test")
record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is None
@ -297,8 +294,8 @@ def test_create_delete():
def test_add_commit():
# Use db.add and db.commit to add a temporary record.
account_type = AccountType(AccountType="test")
with db.begin():
db.add(account_type)
db.commit()
# Assert it got created in the DB.
assert bool(account_type.ID)
@ -308,6 +305,7 @@ def test_add_commit():
assert record == account_type
# Remove the record.
with db.begin():
db.delete(AccountType, AccountType.ID == account_type.ID)

View file

@ -1,6 +1,6 @@
import pytest
from aurweb.db import create, delete, query
from aurweb.db import begin, create, delete, query
from aurweb.models.dependency_type import DependencyType
from aurweb.testing import setup_test_db
@ -19,13 +19,17 @@ def test_dependency_types():
def test_dependency_type_creation():
with begin():
dependency_type = create(DependencyType, Name="Test Type")
assert bool(dependency_type.ID)
assert dependency_type.Name == "Test Type"
with begin():
delete(DependencyType, DependencyType.ID == dependency_type.ID)
def test_dependency_type_null_name_uses_default():
with begin():
dependency_type = create(DependencyType)
assert dependency_type.Name == str()
with begin():
delete(DependencyType, DependencyType.ID == dependency_type.ID)

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import create
from aurweb import db
from aurweb.models.group import Group
from aurweb.testing import setup_test_db
@ -13,13 +13,14 @@ def setup():
def test_group_creation():
group = create(Group, Name="Test Group")
with db.begin():
group = db.create(Group, Name="Test Group")
assert bool(group.ID)
assert group.Name == "Test Group"
def test_group_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(Group)
session.rollback()
with db.begin():
db.create(Group)
db.rollback()

View file

@ -38,8 +38,10 @@ def setup():
@pytest.fixture
def user():
yield db.create(User, Username="test", Email="test@example.org",
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
Passwd="testPassword", AccountTypeID=USER_ID)
yield user
@pytest.fixture
@ -68,18 +70,15 @@ def packages(user):
# For i..num_packages, create a package named pkg_{i}.
pkgs = []
now = int(datetime.utcnow().timestamp())
with db.begin():
for i in range(num_packages):
pkgbase = db.create(PackageBase, Name=f"pkg_{i}",
Maintainer=user, Packager=user,
autocommit=False, SubmittedTS=now,
ModifiedTS=now)
pkg = db.create(Package, PackageBase=pkgbase,
Name=pkgbase.Name, autocommit=False)
SubmittedTS=now, ModifiedTS=now)
pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
pkgs.append(pkg)
now += 1
db.commit()
yield pkgs
@ -159,10 +158,11 @@ def test_homepage_updates(redis, packages):
def test_homepage_dashboard(redis, packages, user):
# Create Comaintainer records for all of the packages.
with db.begin():
for pkg in packages:
db.create(PackageComaintainer, PackageBase=pkg.PackageBase,
User=user, Priority=1, autocommit=False)
db.commit()
db.create(PackageComaintainer,
PackageBase=pkg.PackageBase,
User=user, Priority=1)
cookies = {"AURSID": user.login(Request(), "testPassword")}
with client as request:
@ -193,6 +193,7 @@ def test_homepage_dashboard_requests(redis, packages, user):
pkg = packages[0]
reqtype = db.query(RequestType, RequestType.ID == DELETION_ID).first()
with db.begin():
pkgreq = db.create(PackageRequest, PackageBase=pkg.PackageBase,
PackageBaseName=pkg.PackageBase.Name,
User=user, Comments=str(),
@ -213,8 +214,8 @@ def test_homepage_dashboard_requests(redis, packages, user):
def test_homepage_dashboard_flagged_packages(redis, packages, user):
# Set the first Package flagged by setting its OutOfDateTS column.
pkg = packages[0]
with db.begin():
pkg.PackageBase.OutOfDateTS = int(datetime.utcnow().timestamp())
db.commit()
cookies = {"AURSID": user.login(Request(), "testPassword")}
with client as request:

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import create
from aurweb import db
from aurweb.models.license import License
from aurweb.testing import setup_test_db
@ -13,13 +13,14 @@ def setup():
def test_license_creation():
license = create(License, Name="Test License")
with db.begin():
license = db.create(License, Name="Test License")
assert bool(license.ID)
assert license.Name == "Test License"
def test_license_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(License)
session.rollback()
with db.begin():
db.create(License)
db.rollback()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import create
from aurweb import db
from aurweb.models.official_provider import OfficialProvider
from aurweb.testing import setup_test_db
@ -13,7 +13,8 @@ def setup():
def test_official_provider_creation():
oprovider = create(OfficialProvider,
with db.begin():
oprovider = db.create(OfficialProvider,
Name="some-name",
Repo="some-repo",
Provides="some-provides")
@ -25,13 +26,15 @@ def test_official_provider_creation():
def test_official_provider_cs():
""" Test case sensitivity of the database table. """
oprovider = create(OfficialProvider,
with db.begin():
oprovider = db.create(OfficialProvider,
Name="some-name",
Repo="some-repo",
Provides="some-provides")
assert bool(oprovider.ID)
oprovider_cs = create(OfficialProvider,
with db.begin():
oprovider_cs = db.create(OfficialProvider,
Name="SOME-NAME",
Repo="SOME-REPO",
Provides="SOME-PROVIDES")
@ -49,27 +52,27 @@ def test_official_provider_cs():
def test_official_provider_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(OfficialProvider,
with db.begin():
db.create(OfficialProvider,
Repo="some-repo",
Provides="some-provides")
session.rollback()
db.rollback()
def test_official_provider_null_repo_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(OfficialProvider,
with db.begin():
db.create(OfficialProvider,
Name="some-name",
Provides="some-provides")
session.rollback()
db.rollback()
def test_official_provider_null_provides_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(OfficialProvider,
with db.begin():
db.create(OfficialProvider,
Name="some-name",
Repo="some-repo")
session.rollback()
db.rollback()

View file

@ -3,7 +3,7 @@ import pytest
from sqlalchemy import and_
from sqlalchemy.exc import IntegrityError
from aurweb.db import create, query
from aurweb import db
from aurweb.models.account_type import AccountType
from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase
@ -19,16 +19,18 @@ def setup():
setup_test_db("Packages", "PackageBases", "Users")
account_type = query(AccountType,
account_type = db.query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
pkgbase = create(PackageBase,
pkgbase = db.create(PackageBase,
Name="beautiful-package",
Maintainer=user)
package = create(Package,
package = db.create(Package,
PackageBase=pkgbase,
Name=pkgbase.Name,
Description="Test description.",
@ -36,8 +38,6 @@ def setup():
def test_package():
from aurweb.db import session
assert pkgbase == package.PackageBase
assert package.Name == "beautiful-package"
assert package.Description == "Test description."
@ -45,33 +45,31 @@ def test_package():
assert package.URL == "https://test.package"
# Update package Version.
with db.begin():
package.Version = "1.2.3"
session.commit()
# Make sure it got updated in the database.
record = query(Package,
record = db.query(Package,
and_(Package.ID == package.ID,
Package.Version == "1.2.3")).first()
assert record is not None
def test_package_null_pkgbase_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(Package,
with db.begin():
db.create(Package,
Name="some-package",
Description="Some description.",
URL="https://some.package")
session.rollback()
db.rollback()
def test_package_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(Package,
with db.begin():
db.create(Package,
PackageBase=pkgbase,
Description="Some description.",
URL="https://some.package")
session.rollback()
db.rollback()

View file

@ -4,7 +4,7 @@ from sqlalchemy.exc import IntegrityError
import aurweb.config
from aurweb.db import create, query
from aurweb import db
from aurweb.models.account_type import AccountType
from aurweb.models.package_base import PackageBase
from aurweb.models.user import User
@ -19,15 +19,17 @@ def setup():
setup_test_db("Users", "PackageBases")
account_type = query(AccountType,
account_type = db.query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
def test_package_base():
pkgbase = create(PackageBase,
with db.begin():
pkgbase = db.create(PackageBase,
Name="beautiful-package",
Maintainer=user)
assert pkgbase in user.maintained_bases
@ -38,6 +40,7 @@ def test_package_base():
# Set Popularity to a string, then get it by attribute to
# exercise the string -> float conversion path.
with db.begin():
pkgbase.Popularity = "0.0"
assert pkgbase.Popularity == 0.0
@ -47,22 +50,23 @@ def test_package_base_ci():
if aurweb.config.get("database", "backend") == "sqlite":
return None # SQLite doesn't seem handle this.
from aurweb.db import session
pkgbase = create(PackageBase,
with db.begin():
pkgbase = db.create(PackageBase,
Name="beautiful-package",
Maintainer=user)
assert bool(pkgbase.ID)
with pytest.raises(IntegrityError):
create(PackageBase,
with db.begin():
db.create(PackageBase,
Name="Beautiful-Package",
Maintainer=user)
session.rollback()
db.rollback()
def test_package_base_relationships():
pkgbase = create(PackageBase,
with db.begin():
pkgbase = db.create(PackageBase,
Name="beautiful-package",
Flagger=user,
Maintainer=user,
@ -75,8 +79,7 @@ def test_package_base_relationships():
def test_package_base_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(PackageBase)
session.rollback()
with db.begin():
db.create(PackageBase)
db.rollback()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import create, rollback
from aurweb import db
from aurweb.models.package_base import PackageBase
from aurweb.models.package_blacklist import PackageBlacklist
from aurweb.models.user import User
@ -17,18 +17,20 @@ def setup():
setup_test_db("PackageBlacklist", "PackageBases", "Users")
user = create(User, Username="test", Email="test@example.org",
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword")
pkgbase = create(PackageBase, Name="test-package", Maintainer=user)
pkgbase = db.create(PackageBase, Name="test-package", Maintainer=user)
def test_package_blacklist_creation():
package_blacklist = create(PackageBlacklist, Name="evil-package")
with db.begin():
package_blacklist = db.create(PackageBlacklist, Name="evil-package")
assert bool(package_blacklist.ID)
assert package_blacklist.Name == "evil-package"
def test_package_blacklist_null_name_raises_exception():
with pytest.raises(IntegrityError):
create(PackageBlacklist)
rollback()
with db.begin():
db.create(PackageBlacklist)
db.rollback()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import create, query, rollback
from aurweb.db import begin, create, query, rollback
from aurweb.models.account_type import AccountType
from aurweb.models.package_base import PackageBase
from aurweb.models.package_comment import PackageComment
@ -20,6 +20,7 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
with begin():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
@ -27,6 +28,7 @@ def setup():
def test_package_comment_creation():
with begin():
package_comment = create(PackageComment,
PackageBase=pkgbase,
User=user,
@ -37,6 +39,7 @@ def test_package_comment_creation():
def test_package_comment_null_package_base_raises_exception():
with pytest.raises(IntegrityError):
with begin():
create(PackageComment, User=user, Comments="Test comment.",
RenderedComment="Test rendered comment.")
rollback()
@ -44,19 +47,23 @@ def test_package_comment_null_package_base_raises_exception():
def test_package_comment_null_user_raises_exception():
with pytest.raises(IntegrityError):
create(PackageComment, PackageBase=pkgbase, Comments="Test comment.",
with begin():
create(PackageComment, PackageBase=pkgbase,
Comments="Test comment.",
RenderedComment="Test rendered comment.")
rollback()
def test_package_comment_null_comments_raises_exception():
with pytest.raises(IntegrityError):
with begin():
create(PackageComment, PackageBase=pkgbase, User=user,
RenderedComment="Test rendered comment.")
rollback()
def test_package_comment_null_renderedcomment_defaults():
with begin():
record = create(PackageComment,
PackageBase=pkgbase,
User=user,

View file

@ -2,7 +2,8 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import commit, create, query
from aurweb import db
from aurweb.db import create, query
from aurweb.models.account_type import AccountType
from aurweb.models.dependency_type import DependencyType
from aurweb.models.package import Package
@ -22,10 +23,11 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
with db.begin():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
pkgbase = create(PackageBase,
Name="test-package",
Maintainer=user)
@ -38,6 +40,8 @@ def setup():
def test_package_dependencies():
depends = query(DependencyType, DependencyType.Name == "depends").first()
with db.begin():
pkgdep = create(PackageDependency, Package=package,
DependencyType=depends,
DepName="test-dep")
@ -49,8 +53,8 @@ def test_package_dependencies():
makedepends = query(DependencyType,
DependencyType.Name == "makedepends").first()
with db.begin():
pkgdep.DependencyType = makedepends
commit()
assert pkgdep.DepName == "test-dep"
assert pkgdep.Package == package
assert pkgdep.DependencyType == makedepends
@ -59,8 +63,8 @@ def test_package_dependencies():
checkdepends = query(DependencyType,
DependencyType.Name == "checkdepends").first()
with db.begin():
pkgdep.DependencyType = checkdepends
commit()
assert pkgdep.DepName == "test-dep"
assert pkgdep.Package == package
assert pkgdep.DependencyType == checkdepends
@ -69,8 +73,8 @@ def test_package_dependencies():
optdepends = query(DependencyType,
DependencyType.Name == "optdepends").first()
with db.begin():
pkgdep.DependencyType = optdepends
commit()
assert pkgdep.DepName == "test-dep"
assert pkgdep.Package == package
assert pkgdep.DependencyType == optdepends
@ -79,6 +83,7 @@ def test_package_dependencies():
assert not pkgdep.is_package()
with db.begin():
base = create(PackageBase, Name=pkgdep.DepName, Maintainer=user)
create(Package, PackageBase=base, Name=pkgdep.DepName)
@ -86,32 +91,29 @@ def test_package_dependencies():
def test_package_dependencies_null_package_raises_exception():
from aurweb.db import session
depends = query(DependencyType, DependencyType.Name == "depends").first()
with pytest.raises(IntegrityError):
with db.begin():
create(PackageDependency,
DependencyType=depends,
DepName="test-dep")
session.rollback()
db.rollback()
def test_package_dependencies_null_dependency_type_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
with db.begin():
create(PackageDependency,
Package=package,
DepName="test-dep")
session.rollback()
db.rollback()
def test_package_dependencies_null_depname_raises_exception():
from aurweb.db import session
depends = query(DependencyType, DependencyType.Name == "depends").first()
with pytest.raises(IntegrityError):
with db.begin():
create(PackageDependency,
Package=package,
DependencyType=depends)
session.rollback()
db.rollback()

View file

@ -2,7 +2,8 @@ import pytest
from sqlalchemy.exc import IntegrityError, OperationalError
from aurweb.db import commit, create, query
from aurweb import db
from aurweb.db import create, query
from aurweb.models.account_type import AccountType
from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase
@ -22,10 +23,11 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
with db.begin():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
pkgbase = create(PackageBase,
Name="test-package",
Maintainer=user)
@ -38,6 +40,8 @@ def setup():
def test_package_relation():
conflicts = query(RelationType, RelationType.Name == "conflicts").first()
with db.begin():
pkgrel = create(PackageRelation, Package=package,
RelationType=conflicts,
RelName="test-relation")
@ -48,8 +52,8 @@ def test_package_relation():
assert pkgrel in package.package_relations
provides = query(RelationType, RelationType.Name == "provides").first()
with db.begin():
pkgrel.RelationType = provides
commit()
assert pkgrel.RelName == "test-relation"
assert pkgrel.Package == package
assert pkgrel.RelationType == provides
@ -57,8 +61,8 @@ def test_package_relation():
assert pkgrel in package.package_relations
replaces = query(RelationType, RelationType.Name == "replaces").first()
with db.begin():
pkgrel.RelationType = replaces
commit()
assert pkgrel.RelName == "test-relation"
assert pkgrel.Package == package
assert pkgrel.RelationType == replaces
@ -67,36 +71,33 @@ def test_package_relation():
def test_package_relation_null_package_raises_exception():
from aurweb.db import session
conflicts = query(RelationType, RelationType.Name == "conflicts").first()
assert conflicts is not None
with pytest.raises(IntegrityError):
with db.begin():
create(PackageRelation,
RelationType=conflicts,
RelName="test-relation")
session.rollback()
db.rollback()
def test_package_relation_null_relation_type_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
with db.begin():
create(PackageRelation,
Package=package,
RelName="test-relation")
session.rollback()
db.rollback()
def test_package_relation_null_relname_raises_exception():
from aurweb.db import session
depends = query(RelationType, RelationType.Name == "conflicts").first()
assert depends is not None
with pytest.raises((OperationalError, IntegrityError)):
with db.begin():
create(PackageRelation,
Package=package,
RelationType=depends)
session.rollback()
db.rollback()

View file

@ -4,7 +4,8 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import commit, create, query, rollback
from aurweb import db
from aurweb.db import create, query, rollback
from aurweb.models.package_base import PackageBase
from aurweb.models.package_request import (ACCEPTED, ACCEPTED_ID, CLOSED, CLOSED_ID, PENDING, PENDING_ID, REJECTED,
REJECTED_ID, PackageRequest)
@ -21,6 +22,7 @@ def setup():
setup_test_db("PackageRequests", "PackageBases", "Users")
with db.begin():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword")
pkgbase = create(PackageBase, Name="test-package", Maintainer=user)
@ -30,6 +32,7 @@ def test_package_request_creation():
request_type = query(RequestType, RequestType.Name == "merge").first()
assert request_type.Name == "merge"
with db.begin():
package_request = create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
@ -54,6 +57,7 @@ def test_package_request_closed():
assert request_type.Name == "merge"
ts = int(datetime.utcnow().timestamp())
with db.begin():
package_request = create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
@ -69,6 +73,7 @@ def test_package_request_closed():
def test_package_request_null_request_type_raises_exception():
with pytest.raises(IntegrityError):
with db.begin():
create(PackageRequest, User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str())
@ -78,8 +83,9 @@ def test_package_request_null_request_type_raises_exception():
def test_package_request_null_user_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
with db.begin():
create(PackageRequest, RequestType=request_type,
PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str())
rollback()
@ -87,6 +93,7 @@ def test_package_request_null_user_raises_exception():
def test_package_request_null_package_base_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError):
with db.begin():
create(PackageRequest, RequestType=request_type,
User=user, PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str())
@ -96,6 +103,7 @@ def test_package_request_null_package_base_raises_exception():
def test_package_request_null_package_base_name_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError):
with db.begin():
create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase,
Comments=str(), ClosureComment=str())
@ -105,8 +113,9 @@ def test_package_request_null_package_base_name_raises_exception():
def test_package_request_null_comments_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
with db.begin():
create(PackageRequest, RequestType=request_type, User=user,
PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
ClosureComment=str())
rollback()
@ -114,8 +123,9 @@ def test_package_request_null_comments_raises_exception():
def test_package_request_null_closure_comment_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
with db.begin():
create(PackageRequest, RequestType=request_type, User=user,
PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
Comments=str())
rollback()
@ -124,6 +134,7 @@ def test_package_request_status_display():
""" Test status_display() based on the Status column value. """
request_type = query(RequestType, RequestType.Name == "merge").first()
with db.begin():
pkgreq = create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
@ -131,19 +142,19 @@ def test_package_request_status_display():
Status=PENDING_ID)
assert pkgreq.status_display() == PENDING
with db.begin():
pkgreq.Status = CLOSED_ID
commit()
assert pkgreq.status_display() == CLOSED
with db.begin():
pkgreq.Status = ACCEPTED_ID
commit()
assert pkgreq.status_display() == ACCEPTED
with db.begin():
pkgreq.Status = REJECTED_ID
commit()
assert pkgreq.status_display() == REJECTED
with db.begin():
pkgreq.Status = 124
commit()
with pytest.raises(KeyError):
pkgreq.status_display()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import create, query, rollback
from aurweb.db import begin, create, query, rollback
from aurweb.models.account_type import AccountType
from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase
@ -21,6 +21,7 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
with begin():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
@ -31,6 +32,7 @@ def setup():
def test_package_source():
with begin():
pkgsource = create(PackageSource, Package=package)
assert pkgsource.Package == package
# By default, PackageSources.Source assigns the string '/dev/null'.
@ -40,5 +42,6 @@ def test_package_source():
def test_package_source_null_package_raises_exception():
with pytest.raises(IntegrityError):
with begin():
create(PackageSource)
rollback()

View file

@ -28,31 +28,25 @@ def package_endpoint(package: Package) -> str:
return f"/packages/{package.Name}"
def create_package(pkgname: str, maintainer: User,
autocommit: bool = True) -> Package:
def create_package(pkgname: str, maintainer: User) -> Package:
pkgbase = db.create(PackageBase,
Name=pkgname,
Maintainer=maintainer,
autocommit=False)
return db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase,
autocommit=autocommit)
Maintainer=maintainer)
return db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase)
def create_package_dep(package: Package, depname: str,
dep_type_name: str = "depends",
autocommit: bool = True) -> PackageDependency:
dep_type_name: str = "depends") -> PackageDependency:
dep_type = db.query(DependencyType,
DependencyType.Name == dep_type_name).first()
return db.create(PackageDependency,
DependencyType=dep_type,
Package=package,
DepName=depname,
autocommit=autocommit)
DepName=depname)
def create_package_rel(package: Package,
relname: str,
autocommit: bool = True) -> PackageRelation:
relname: str) -> PackageRelation:
rel_type = db.query(RelationType,
RelationType.ID == PROVIDES_ID).first()
return db.create(PackageRelation,
@ -84,31 +78,37 @@ def client() -> TestClient:
def user() -> User:
""" Yield a user. """
account_type = db.query(AccountType, AccountType.ID == USER_ID).first()
yield db.create(User, Username="test",
with db.begin():
user = db.create(User, Username="test",
Email="test@example.org",
Passwd="testPassword",
AccountType=account_type)
yield user
@pytest.fixture
def maintainer() -> User:
""" Yield a specific User used to maintain packages. """
account_type = db.query(AccountType, AccountType.ID == USER_ID).first()
yield db.create(User, Username="test_maintainer",
with db.begin():
maintainer = db.create(User, Username="test_maintainer",
Email="test_maintainer@example.org",
Passwd="testPassword",
AccountType=account_type)
yield maintainer
@pytest.fixture
def package(maintainer: User) -> Package:
""" Yield a Package created by user. """
with db.begin():
pkgbase = db.create(PackageBase,
Name="test-package",
Maintainer=maintainer)
yield db.create(Package,
package = db.create(Package,
PackageBase=pkgbase,
Name=pkgbase.Name)
yield package
def test_package_not_found(client: TestClient):
@ -121,6 +121,7 @@ def test_package_official_not_found(client: TestClient, package: Package):
""" When a Package has a matching OfficialProvider record, it is not
hosted on AUR, but in the official repositories. Getting a package
with this kind of record should return a status code 404. """
with db.begin():
db.create(OfficialProvider,
Name=package.Name,
Repo="core",
@ -157,6 +158,7 @@ def test_package(client: TestClient, package: Package):
def test_package_comments(client: TestClient, user: User, package: Package):
now = (datetime.utcnow().timestamp())
with db.begin():
comment = db.create(PackageComment, PackageBase=package.PackageBase,
User=user, Comments="Test comment", CommentTS=now)
@ -178,6 +180,7 @@ def test_package_comments(client: TestClient, user: User, package: Package):
def test_package_requests_display(client: TestClient, user: User,
package: Package):
type_ = db.query(RequestType, RequestType.ID == DELETION_ID).first()
with db.begin():
db.create(PackageRequest, PackageBase=package.PackageBase,
PackageBaseName=package.PackageBase.Name,
User=user, RequestType=type_,
@ -195,6 +198,7 @@ def test_package_requests_display(client: TestClient, user: User,
assert target.text.strip() == "1 pending request"
type_ = db.query(RequestType, RequestType.ID == DELETION_ID).first()
with db.begin():
db.create(PackageRequest, PackageBase=package.PackageBase,
PackageBaseName=package.PackageBase.Name,
User=user, RequestType=type_,
@ -271,49 +275,42 @@ def test_package_authenticated_maintainer(client: TestClient,
def test_package_dependencies(client: TestClient, maintainer: User,
package: Package):
# Create a normal dependency of type depends.
dep_pkg = create_package("test-dep-1", maintainer, autocommit=False)
dep = create_package_dep(package, dep_pkg.Name, autocommit=False)
with db.begin():
dep_pkg = create_package("test-dep-1", maintainer)
dep = create_package_dep(package, dep_pkg.Name)
dep.DepArch = "x86_64"
# Also, create a makedepends.
make_dep_pkg = create_package("test-dep-2", maintainer, autocommit=False)
make_dep_pkg = create_package("test-dep-2", maintainer)
make_dep = create_package_dep(package, make_dep_pkg.Name,
dep_type_name="makedepends",
autocommit=False)
dep_type_name="makedepends")
# And... a checkdepends!
check_dep_pkg = create_package("test-dep-3", maintainer, autocommit=False)
check_dep_pkg = create_package("test-dep-3", maintainer)
check_dep = create_package_dep(package, check_dep_pkg.Name,
dep_type_name="checkdepends",
autocommit=False)
dep_type_name="checkdepends")
# Geez. Just stop. This is optdepends.
opt_dep_pkg = create_package("test-dep-4", maintainer, autocommit=False)
opt_dep_pkg = create_package("test-dep-4", maintainer)
opt_dep = create_package_dep(package, opt_dep_pkg.Name,
dep_type_name="optdepends",
autocommit=False)
dep_type_name="optdepends")
# Heh. Another optdepends to test one with a description.
opt_desc_dep_pkg = create_package("test-dep-5", maintainer,
autocommit=False)
opt_desc_dep_pkg = create_package("test-dep-5", maintainer)
opt_desc_dep = create_package_dep(package, opt_desc_dep_pkg.Name,
dep_type_name="optdepends",
autocommit=False)
dep_type_name="optdepends")
opt_desc_dep.DepDesc = "Test description."
broken_dep = create_package_dep(package, "test-dep-6",
dep_type_name="depends",
autocommit=False)
dep_type_name="depends")
# Create an official provider record.
db.create(OfficialProvider, Name="test-dep-99",
Repo="core", Provides="test-dep-99",
autocommit=False)
official_dep = create_package_dep(package, "test-dep-99",
autocommit=False)
Repo="core", Provides="test-dep-99")
official_dep = create_package_dep(package, "test-dep-99")
# Also, create a provider who provides our test-dep-99.
provider = create_package("test-provider", maintainer, autocommit=False)
provider = create_package("test-provider", maintainer)
create_package_rel(provider, dep.DepName)
with client as request:
@ -358,6 +355,7 @@ def test_pkgbase_redirect(client: TestClient, package: Package):
def test_pkgbase(client: TestClient, package: Package):
with db.begin():
second = db.create(Package, Name="second-pkg",
PackageBase=package.PackageBase)

View file

@ -26,17 +26,21 @@ def setup():
@pytest.fixture
def maintainer() -> User:
account_type = db.query(AccountType, AccountType.ID == USER_ID).first()
yield db.create(User, Username="test_maintainer",
with db.begin():
maintainer = db.create(User, Username="test_maintainer",
Email="test_maintainer@examepl.org",
Passwd="testPassword",
AccountType=account_type)
yield maintainer
@pytest.fixture
def package(maintainer: User) -> Package:
with db.begin():
pkgbase = db.create(PackageBase, Name="test-pkg",
Packager=maintainer, Maintainer=maintainer)
yield db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase)
package = db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase)
yield package
@pytest.fixture
@ -45,6 +49,7 @@ def client() -> TestClient:
def test_package_link(client: TestClient, maintainer: User, package: Package):
with db.begin():
db.create(OfficialProvider,
Name=package.Name,
Repo="core",

View file

@ -1,6 +1,6 @@
import pytest
from aurweb.db import create, delete, query
from aurweb import db
from aurweb.models.relation_type import RelationType
from aurweb.testing import setup_test_db
@ -11,22 +11,25 @@ def setup():
def test_relation_type_creation():
relation_type = create(RelationType, Name="test-relation")
with db.begin():
relation_type = db.create(RelationType, Name="test-relation")
assert bool(relation_type.ID)
assert relation_type.Name == "test-relation"
delete(RelationType, RelationType.ID == relation_type.ID)
with db.begin():
db.delete(RelationType, RelationType.ID == relation_type.ID)
def test_relation_types():
conflicts = query(RelationType, RelationType.Name == "conflicts").first()
conflicts = db.query(RelationType, RelationType.Name == "conflicts").first()
assert conflicts is not None
assert conflicts.Name == "conflicts"
provides = query(RelationType, RelationType.Name == "provides").first()
provides = db.query(RelationType, RelationType.Name == "provides").first()
assert provides is not None
assert provides.Name == "provides"
replaces = query(RelationType, RelationType.Name == "replaces").first()
replaces = db.query(RelationType, RelationType.Name == "replaces").first()
assert replaces is not None
assert replaces.Name == "replaces"

View file

@ -1,6 +1,6 @@
import pytest
from aurweb.db import create, delete, query
from aurweb import db
from aurweb.models.request_type import DELETION_ID, MERGE_ID, ORPHAN_ID, RequestType
from aurweb.testing import setup_test_db
@ -11,25 +11,33 @@ def setup():
def test_request_type_creation():
request_type = create(RequestType, Name="Test Request")
with db.begin():
request_type = db.create(RequestType, Name="Test Request")
assert bool(request_type.ID)
assert request_type.Name == "Test Request"
delete(RequestType, RequestType.ID == request_type.ID)
with db.begin():
db.delete(RequestType, RequestType.ID == request_type.ID)
def test_request_type_null_name_returns_empty_string():
request_type = create(RequestType)
with db.begin():
request_type = db.create(RequestType)
assert bool(request_type.ID)
assert request_type.Name == str()
delete(RequestType, RequestType.ID == request_type.ID)
with db.begin():
db.delete(RequestType, RequestType.ID == request_type.ID)
def test_request_type_name_display():
deletion = query(RequestType, RequestType.ID == DELETION_ID).first()
deletion = db.query(RequestType, RequestType.ID == DELETION_ID).first()
assert deletion.name_display() == "Deletion"
orphan = query(RequestType, RequestType.ID == ORPHAN_ID).first()
orphan = db.query(RequestType, RequestType.ID == ORPHAN_ID).first()
assert orphan.name_display() == "Orphan"
merge = query(RequestType, RequestType.ID == MERGE_ID).first()
merge = db.query(RequestType, RequestType.ID == MERGE_ID).first()
assert merge.name_display() == "Merge"

View file

@ -8,8 +8,8 @@ import pytest
from fastapi.testclient import TestClient
from aurweb import db
from aurweb.asgi import app
from aurweb.db import create, query
from aurweb.models.account_type import AccountType
from aurweb.models.user import User
from aurweb.testing import setup_test_db
@ -24,9 +24,11 @@ def setup():
setup_test_db("Users", "Sessions")
account_type = query(AccountType,
account_type = db.query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)

View file

@ -49,14 +49,13 @@ def packages(user):
now = int(datetime.utcnow().timestamp())
# Create 101 packages; we limit 100 on RSS feeds.
with db.begin():
for i in range(101):
pkgbase = db.create(
PackageBase, Maintainer=user, Name=f"test-package-{i}",
SubmittedTS=(now + i), ModifiedTS=(now + i), autocommit=False)
pkg = db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase,
autocommit=False)
SubmittedTS=(now + i), ModifiedTS=(now + i))
pkg = db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase)
pkgs.append(pkg)
db.commit()
yield pkgs

View file

@ -4,7 +4,7 @@ from unittest import mock
import pytest
from aurweb.db import create, query
from aurweb import db
from aurweb.models.account_type import AccountType
from aurweb.models.session import Session, generate_unique_sid
from aurweb.models.user import User
@ -19,12 +19,15 @@ def setup():
setup_test_db("Users", "Sessions")
account_type = query(AccountType,
account_type = db.query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
ResetKey="testReset", Passwd="testPassword",
AccountType=account_type)
session = create(Session, UsersID=user.ID, SessionID="testSession",
with db.begin():
session = db.create(Session, UsersID=user.ID, SessionID="testSession",
LastUpdateTS=datetime.utcnow().timestamp())
@ -35,10 +38,13 @@ def test_session():
def test_session_cs():
""" Test case sensitivity of the database table. """
user2 = create(User, Username="test2", Email="test2@example.org",
with db.begin():
user2 = db.create(User, Username="test2", Email="test2@example.org",
ResetKey="testReset2", Passwd="testPassword",
AccountType=account_type)
session_cs = create(Session, UsersID=user2.ID,
with db.begin():
session_cs = db.create(Session, UsersID=user2.ID,
SessionID="TESTSESSION",
LastUpdateTS=datetime.utcnow().timestamp())
assert session_cs.SessionID == "TESTSESSION"

View file

@ -1,6 +1,6 @@
import pytest
from aurweb.db import create, query
from aurweb import db
from aurweb.models.account_type import AccountType
from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint
from aurweb.models.user import User
@ -19,16 +19,15 @@ def setup():
setup_test_db("Users", "SSHPubKeys")
account_type = query(AccountType,
account_type = db.query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
assert account_type == user.AccountType
assert account_type.ID == user.AccountTypeID
ssh_pub_key = create(SSHPubKey,
with db.begin():
ssh_pub_key = db.create(SSHPubKey,
UserID=user.ID,
Fingerprint="testFingerprint",
PubKey="testPubKey")
@ -43,7 +42,8 @@ def test_ssh_pub_key():
def test_ssh_pub_key_cs():
""" Test case sensitivity of the database table. """
ssh_pub_key_cs = create(SSHPubKey, UserID=user.ID,
with db.begin():
ssh_pub_key_cs = db.create(SSHPubKey, UserID=user.ID,
Fingerprint="TESTFINGERPRINT",
PubKey="TESTPUBKEY")

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import create
from aurweb import db
from aurweb.models.term import Term
from aurweb.testing import setup_test_db
@ -18,7 +18,8 @@ def setup():
def test_term_creation():
term = create(Term, Description="Term description",
with db.begin():
term = db.create(Term, Description="Term description",
URL="https://fake_url.io")
assert bool(term.ID)
assert term.Description == "Term description"
@ -27,14 +28,14 @@ def test_term_creation():
def test_term_null_description_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(Term, URL="https://fake_url.io")
session.rollback()
with db.begin():
db.create(Term, URL="https://fake_url.io")
db.rollback()
def test_term_null_url_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(Term, Description="Term description")
session.rollback()
with db.begin():
db.create(Term, Description="Term description")
db.rollback()

View file

@ -90,33 +90,33 @@ def client():
def tu_user():
tu_type = db.query(AccountType,
AccountType.AccountType == "Trusted User").first()
yield db.create(User, Username="test_tu", Email="test_tu@example.org",
with db.begin():
tu_user = db.create(User, Username="test_tu",
Email="test_tu@example.org",
RealName="Test TU", Passwd="testPassword",
AccountType=tu_type)
yield tu_user
@pytest.fixture
def user():
user_type = db.query(AccountType,
AccountType.AccountType == "User").first()
yield db.create(User, Username="test", Email="test@example.org",
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=user_type)
yield user
@pytest.fixture
def proposal(tu_user):
def proposal(user, tu_user):
ts = int(datetime.utcnow().timestamp())
agenda = "Test proposal."
start = ts - 5
end = ts + 1000
user_type = db.query(AccountType,
AccountType.AccountType == "User").first()
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=user_type)
with db.begin():
voteinfo = db.create(TUVoteInfo,
Agenda=agenda, Quorum=0.0,
User=user.Username, Submitter=tu_user,
@ -170,6 +170,7 @@ def test_tu_index(client, tu_user):
("Test agenda 2", ts - 1000, ts - 5) # Not running anymore.
]
vote_records = []
with db.begin():
for vote in votes:
agenda, start, end = vote
vote_records.append(
@ -179,6 +180,7 @@ def test_tu_index(client, tu_user):
Quorum=0.0,
Submitter=tu_user))
with db.begin():
# Vote on an ended proposal.
vote_record = vote_records[1]
vote_record.Yes += 1
@ -255,13 +257,14 @@ def test_tu_index(client, tu_user):
def test_tu_index_table_paging(client, tu_user):
ts = int(datetime.utcnow().timestamp())
with db.begin():
for i in range(25):
# Create 25 current votes.
db.create(TUVoteInfo, Agenda=f"Agenda #{i}",
User=tu_user.Username,
Submitted=(ts - 5), End=(ts + 1000),
Quorum=0.0,
Submitter=tu_user, autocommit=False)
Submitter=tu_user)
for i in range(25):
# Create 25 past votes.
@ -269,8 +272,7 @@ def test_tu_index_table_paging(client, tu_user):
User=tu_user.Username,
Submitted=(ts - 1000), End=(ts - 5),
Quorum=0.0,
Submitter=tu_user, autocommit=False)
db.commit()
Submitter=tu_user)
cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request:
@ -363,13 +365,14 @@ def test_tu_index_table_paging(client, tu_user):
def test_tu_index_sorting(client, tu_user):
ts = int(datetime.utcnow().timestamp())
with db.begin():
for i in range(2):
# Create 'Agenda #1' and 'Agenda #2'.
db.create(TUVoteInfo, Agenda=f"Agenda #{i + 1}",
User=tu_user.Username,
Submitted=(ts + 5), End=(ts + 1000),
Quorum=0.0,
Submitter=tu_user, autocommit=False)
Submitter=tu_user)
# Let's order each vote one day after the other.
# This will allow us to test the sorting nature
@ -432,6 +435,7 @@ def test_tu_index_sorting(client, tu_user):
def test_tu_index_last_votes(client, tu_user, user):
ts = int(datetime.utcnow().timestamp())
with db.begin():
# Create a proposal which has ended.
voteinfo = db.create(TUVoteInfo, Agenda="Test agenda",
User=user.Username,
@ -529,10 +533,10 @@ def test_tu_running_proposal(client, proposal):
assert abstain.attrib["value"] == "Abstain"
# Create a vote.
with db.begin():
db.create(TUVote, VoteInfo=voteinfo, User=tu_user)
voteinfo.ActiveTUs += 1
voteinfo.Yes += 1
db.commit()
# Make another request now that we've voted.
with client as request:
@ -556,8 +560,8 @@ def test_tu_ended_proposal(client, proposal):
tu_user, user, voteinfo = proposal
ts = int(datetime.utcnow().timestamp())
with db.begin():
voteinfo.End = ts - 5 # 5 seconds ago.
db.commit()
# Initiate an authenticated GET request to /tu/{proposal_id}.
proposal_id = voteinfo.ID
@ -635,8 +639,8 @@ def test_tu_proposal_vote_unauthorized(client, proposal):
dev_type = db.query(AccountType,
AccountType.AccountType == "Developer").first()
with db.begin():
tu_user.AccountType = dev_type
db.commit()
cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request:
@ -664,8 +668,8 @@ def test_tu_proposal_vote_cant_self_vote(client, proposal):
tu_user, user, voteinfo = proposal
# Update voteinfo.User.
with db.begin():
voteinfo.User = tu_user.Username
db.commit()
cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request:
@ -692,10 +696,10 @@ def test_tu_proposal_vote_cant_self_vote(client, proposal):
def test_tu_proposal_vote_already_voted(client, proposal):
tu_user, user, voteinfo = proposal
with db.begin():
db.create(TUVote, VoteInfo=voteinfo, User=tu_user)
voteinfo.Yes += 1
voteinfo.ActiveTUs += 1
db.commit()
cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request:

View file

@ -4,7 +4,8 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import commit, create, query, rollback
from aurweb import db
from aurweb.db import create, query, rollback
from aurweb.models.account_type import AccountType
from aurweb.models.tu_voteinfo import TUVoteInfo
from aurweb.models.user import User
@ -21,6 +22,7 @@ def setup():
tu_type = query(AccountType,
AccountType.AccountType == "Trusted User").first()
with db.begin():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=tu_type)
@ -28,6 +30,7 @@ def setup():
def test_tu_voteinfo_creation():
ts = int(datetime.utcnow().timestamp())
with db.begin():
tu_voteinfo = create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
@ -51,6 +54,7 @@ def test_tu_voteinfo_creation():
def test_tu_voteinfo_is_running():
ts = int(datetime.utcnow().timestamp())
with db.begin():
tu_voteinfo = create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
@ -59,13 +63,14 @@ def test_tu_voteinfo_is_running():
Submitter=user)
assert tu_voteinfo.is_running() is True
with db.begin():
tu_voteinfo.End = ts - 5
commit()
assert tu_voteinfo.is_running() is False
def test_tu_voteinfo_total_votes():
ts = int(datetime.utcnow().timestamp())
with db.begin():
tu_voteinfo = create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
@ -76,7 +81,6 @@ def test_tu_voteinfo_total_votes():
tu_voteinfo.Yes = 1
tu_voteinfo.No = 3
tu_voteinfo.Abstain = 5
commit()
# total_votes() should be the sum of Yes, No and Abstain: 1 + 3 + 5 = 9.
assert tu_voteinfo.total_votes() == 9
@ -84,6 +88,7 @@ def test_tu_voteinfo_total_votes():
def test_tu_voteinfo_null_submitter_raises_exception():
with pytest.raises(IntegrityError):
with db.begin():
create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
@ -94,6 +99,7 @@ def test_tu_voteinfo_null_submitter_raises_exception():
def test_tu_voteinfo_null_agenda_raises_exception():
with pytest.raises(IntegrityError):
with db.begin():
create(TUVoteInfo,
User=user.Username,
Submitted=0, End=0,
@ -104,6 +110,7 @@ def test_tu_voteinfo_null_agenda_raises_exception():
def test_tu_voteinfo_null_user_raises_exception():
with pytest.raises(IntegrityError):
with db.begin():
create(TUVoteInfo,
Agenda="Blah blah.",
Submitted=0, End=0,
@ -114,6 +121,7 @@ def test_tu_voteinfo_null_user_raises_exception():
def test_tu_voteinfo_null_submitted_raises_exception():
with pytest.raises(IntegrityError):
with db.begin():
create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
@ -125,6 +133,7 @@ def test_tu_voteinfo_null_submitted_raises_exception():
def test_tu_voteinfo_null_end_raises_exception():
with pytest.raises(IntegrityError):
with db.begin():
create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
@ -136,6 +145,7 @@ def test_tu_voteinfo_null_end_raises_exception():
def test_tu_voteinfo_null_quorum_raises_exception():
with pytest.raises(IntegrityError):
with db.begin():
create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,

View file

@ -9,7 +9,7 @@ import pytest
import aurweb.auth
import aurweb.config
from aurweb.db import commit, create, query
from aurweb import db
from aurweb.models.account_type import AccountType
from aurweb.models.ban import Ban
from aurweb.models.package import Package
@ -40,10 +40,11 @@ def setup():
PackageNotification.__tablename__
)
account_type = query(AccountType,
account_type = db.query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
@ -70,14 +71,14 @@ def test_user_login_logout():
assert "AURSID" in request.cookies
# Expect that User session relationships work right.
user_session = query(Session,
user_session = db.query(Session,
Session.UsersID == user.ID).first()
assert user_session == user.session
assert user.session.SessionID == sid
assert user.session.User == user
# Search for the user via query API.
result = query(User, User.ID == user.ID).first()
result = db.query(User, User.ID == user.ID).first()
# Compare the result and our original user.
assert result == user
@ -114,7 +115,8 @@ def test_user_login_twice():
def test_user_login_banned():
# Add ban for the next 30 seconds.
banned_timestamp = datetime.utcnow() + timedelta(seconds=30)
create(Ban, IPAddress="127.0.0.1", BanTS=banned_timestamp)
with db.begin():
db.create(Ban, IPAddress="127.0.0.1", BanTS=banned_timestamp)
request = Request()
request.client.host = "127.0.0.1"
@ -122,18 +124,17 @@ def test_user_login_banned():
def test_user_login_suspended():
from aurweb.db import session
with db.begin():
user.Suspended = True
session.commit()
assert not user.login(Request(), "testPassword")
def test_legacy_user_authentication():
from aurweb.db import session
with db.begin():
user.Salt = bcrypt.gensalt().decode()
user.Passwd = hashlib.md5(f"{user.Salt}testPassword".encode()).hexdigest()
session.commit()
user.Passwd = hashlib.md5(
f"{user.Salt}testPassword".encode()
).hexdigest()
assert not user.valid_password("badPassword")
assert user.valid_password("testPassword")
@ -145,7 +146,8 @@ def test_legacy_user_authentication():
def test_user_login_with_outdated_sid():
# Make a session with a LastUpdateTS 5 seconds ago, causing
# user.login to update it with a new sid.
create(Session, UsersID=user.ID, SessionID="stub",
with db.begin():
db.create(Session, UsersID=user.ID, SessionID="stub",
LastUpdateTS=datetime.utcnow().timestamp() - 5)
sid = user.login(Request(), "testPassword")
assert sid and user.is_authenticated()
@ -171,7 +173,8 @@ def test_user_has_credential():
def test_user_ssh_pub_key():
assert user.ssh_pub_key is None
ssh_pub_key = create(SSHPubKey, UserID=user.ID,
with db.begin():
ssh_pub_key = db.create(SSHPubKey, UserID=user.ID,
Fingerprint="testFingerprint",
PubKey="testPubKey")
@ -179,35 +182,33 @@ def test_user_ssh_pub_key():
def test_user_credential_types():
from aurweb.db import session
assert aurweb.auth.user_developer_or_trusted_user(user)
assert not aurweb.auth.trusted_user(user)
assert not aurweb.auth.developer(user)
assert not aurweb.auth.trusted_user_or_dev(user)
trusted_user_type = query(AccountType,
AccountType.AccountType == "Trusted User")\
.first()
trusted_user_type = db.query(AccountType).filter(
AccountType.AccountType == "Trusted User"
).first()
with db.begin():
user.AccountType = trusted_user_type
session.commit()
assert aurweb.auth.trusted_user(user)
assert aurweb.auth.trusted_user_or_dev(user)
developer_type = query(AccountType,
developer_type = db.query(AccountType,
AccountType.AccountType == "Developer").first()
with db.begin():
user.AccountType = developer_type
session.commit()
assert aurweb.auth.developer(user)
assert aurweb.auth.trusted_user_or_dev(user)
type_str = "Trusted User & Developer"
elevated_type = query(AccountType,
elevated_type = db.query(AccountType,
AccountType.AccountType == type_str).first()
with db.begin():
user.AccountType = elevated_type
session.commit()
assert aurweb.auth.trusted_user(user)
assert aurweb.auth.developer(user)
@ -233,53 +234,56 @@ def test_user_as_dict():
def test_user_is_trusted_user():
tu_type = query(AccountType,
tu_type = db.query(AccountType,
AccountType.AccountType == "Trusted User").first()
with db.begin():
user.AccountType = tu_type
commit()
assert user.is_trusted_user() is True
# Do it again with the combined role.
tu_type = query(
tu_type = db.query(
AccountType,
AccountType.AccountType == "Trusted User & Developer").first()
with db.begin():
user.AccountType = tu_type
commit()
assert user.is_trusted_user() is True
def test_user_is_developer():
dev_type = query(AccountType,
dev_type = db.query(AccountType,
AccountType.AccountType == "Developer").first()
with db.begin():
user.AccountType = dev_type
commit()
assert user.is_developer() is True
# Do it again with the combined role.
dev_type = query(
dev_type = db.query(
AccountType,
AccountType.AccountType == "Trusted User & Developer").first()
with db.begin():
user.AccountType = dev_type
commit()
assert user.is_developer() is True
def test_user_voted_for():
now = int(datetime.utcnow().timestamp())
pkgbase = create(PackageBase, Name="pkg1", Maintainer=user)
pkg = create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
create(PackageVote, PackageBase=pkgbase, User=user, VoteTS=now)
with db.begin():
pkgbase = db.create(PackageBase, Name="pkg1", Maintainer=user)
pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
db.create(PackageVote, PackageBase=pkgbase, User=user, VoteTS=now)
assert user.voted_for(pkg)
def test_user_notified():
pkgbase = create(PackageBase, Name="pkg1", Maintainer=user)
pkg = create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
create(PackageNotification, PackageBase=pkgbase, User=user)
with db.begin():
pkgbase = db.create(PackageBase, Name="pkg1", Maintainer=user)
pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
db.create(PackageNotification, PackageBase=pkgbase, User=user)
assert user.notified(pkg)
def test_user_packages():
pkgbase = create(PackageBase, Name="pkg1", Maintainer=user)
pkg = create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
with db.begin():
pkgbase = db.create(PackageBase, Name="pkg1", Maintainer=user)
pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
assert pkg in user.packages()