[FastAPI] Refactor db modifications

For SQLAlchemy to automatically understand updates from the
external world, it must use an `autocommit=True` in its session.

This change breaks how we were using commit previously, as
`autocommit=True` causes SQLAlchemy to commit when a
SessionTransaction context hits __exit__.

So, a refactoring was required of our tests: All usage of
any `db.{create,delete}` must be called **within** a
SessionTransaction context, created via new `db.begin()`.

From this point forward, we're going to require:

```
with db.begin():
    db.create(...)
    db.delete(...)
    db.session.delete(object)
```

With this, we now get external DB modifications automatically
without reloading or restarting the FastAPI server, which we
absolutely need for production.

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2021-09-02 16:26:48 -07:00
parent b52059d437
commit a5943bf2ad
No known key found for this signature in database
GPG key ID: F7E46DED420788F3
37 changed files with 998 additions and 902 deletions

View file

@ -59,20 +59,15 @@ def query(model, *args, **kwargs):
return session.query(model).filter(*args, **kwargs) return session.query(model).filter(*args, **kwargs)
def create(model, autocommit: bool = True, *args, **kwargs): def create(model, *args, **kwargs):
instance = model(*args, **kwargs) instance = model(*args, **kwargs)
add(instance) return add(instance)
if autocommit is True:
commit()
return instance
def delete(model, *args, autocommit: bool = True, **kwargs): def delete(model, *args, **kwargs):
instance = session.query(model).filter(*args, **kwargs) instance = session.query(model).filter(*args, **kwargs)
for record in instance: for record in instance:
session.delete(record) session.delete(record)
if autocommit is True:
commit()
def rollback(): def rollback():
@ -84,8 +79,25 @@ def add(model):
return model return model
def commit(): def begin():
session.commit() """ Begin an SQLAlchemy SessionTransaction.
This context is **required** to perform an modifications to the
database.
Example:
with db.begin():
object = db.create(...)
# On __exit__, db.commit() is run.
with db.begin():
object = db.delete(...)
# On __exit__, db.commit() is run.
:return: A new SessionTransaction based on session
"""
return session.begin()
def get_sqlalchemy_url(): def get_sqlalchemy_url():
@ -155,23 +167,23 @@ def get_engine(echo: bool = False):
connect_args=connect_args, connect_args=connect_args,
echo=echo) echo=echo)
Session = sessionmaker(autocommit=True, autoflush=False, bind=engine)
session = Session()
if db_backend == "sqlite": if db_backend == "sqlite":
# For SQLite, we need to add some custom functions as # For SQLite, we need to add some custom functions as
# they are used in the reference graph method. # they are used in the reference graph method.
def regexp(regex, item): def regexp(regex, item):
return bool(re.search(regex, str(item))) return bool(re.search(regex, str(item)))
@event.listens_for(engine, "begin") @event.listens_for(engine, "connect")
def do_begin(conn): def do_begin(conn, record):
create_deterministic_function = functools.partial( create_deterministic_function = functools.partial(
conn.connection.create_function, conn.create_function,
deterministic=True deterministic=True
) )
create_deterministic_function("REGEXP", 2, regexp) create_deterministic_function("REGEXP", 2, regexp)
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
return engine return engine

View file

@ -102,7 +102,7 @@ class User(Base):
def login(self, request: Request, password: str, session_time=0): def login(self, request: Request, password: str, session_time=0):
""" Login and authenticate a request. """ """ Login and authenticate a request. """
from aurweb.db import session from aurweb import db
from aurweb.models.session import Session, generate_unique_sid from aurweb.models.session import Session, generate_unique_sid
if not self._login_approved(request): if not self._login_approved(request):
@ -112,10 +112,7 @@ class User(Base):
if not self.authenticated: if not self.authenticated:
return None return None
self.LastLogin = now_ts = datetime.utcnow().timestamp() now_ts = datetime.utcnow().timestamp()
self.LastLoginIPAddress = request.client.host
session.commit()
session_ts = now_ts + ( session_ts = now_ts + (
session_time if session_time session_time if session_time
else aurweb.config.getint("options", "login_timeout") else aurweb.config.getint("options", "login_timeout")
@ -123,11 +120,14 @@ class User(Base):
sid = None sid = None
with db.begin():
self.LastLogin = now_ts
self.LastLoginIPAddress = request.client.host
if not self.session: if not self.session:
sid = generate_unique_sid() sid = generate_unique_sid()
self.session = Session(UsersID=self.ID, SessionID=sid, self.session = Session(UsersID=self.ID, SessionID=sid,
LastUpdateTS=session_ts) LastUpdateTS=session_ts)
session.add(self.session) db.add(self.session)
else: else:
last_updated = self.session.LastUpdateTS last_updated = self.session.LastUpdateTS
if last_updated and last_updated < now_ts: if last_updated and last_updated < now_ts:
@ -138,8 +138,6 @@ class User(Base):
self.session.LastUpdateTS = session_ts self.session.LastUpdateTS = session_ts
session.commit()
request.cookies["AURSID"] = self.session.SessionID request.cookies["AURSID"] = self.session.SessionID
return self.session.SessionID return self.session.SessionID
@ -149,13 +147,11 @@ class User(Base):
return aurweb.auth.has_credential(self, cred, approved) return aurweb.auth.has_credential(self, cred, approved)
def logout(self, request): def logout(self, request):
from aurweb.db import session
del request.cookies["AURSID"] del request.cookies["AURSID"]
self.authenticated = False self.authenticated = False
if self.session: if self.session:
session.delete(self.session) with db.begin():
session.commit() db.session.delete(self.session)
def is_trusted_user(self): def is_trusted_user(self):
return self.AccountType.ID in { return self.AccountType.ID in {

View file

@ -43,8 +43,6 @@ async def passreset_post(request: Request,
resetkey: str = Form(default=None), resetkey: str = Form(default=None),
password: str = Form(default=None), password: str = Form(default=None),
confirm: str = Form(default=None)): confirm: str = Form(default=None)):
from aurweb.db import session
context = await make_variable_context(request, "Password Reset") context = await make_variable_context(request, "Password Reset")
# The user parameter being required, we can match against # The user parameter being required, we can match against
@ -86,12 +84,11 @@ async def passreset_post(request: Request,
# We got to this point; everything matched up. Update the password # We got to this point; everything matched up. Update the password
# and remove the ResetKey. # and remove the ResetKey.
with db.begin():
user.ResetKey = str() user.ResetKey = str()
user.update_password(password)
if user.session: if user.session:
session.delete(user.session) db.session.delete(user.session)
session.commit() user.update_password(password)
# Render ?step=complete. # Render ?step=complete.
return RedirectResponse(url="/passreset?step=complete", return RedirectResponse(url="/passreset?step=complete",
@ -99,8 +96,8 @@ async def passreset_post(request: Request,
# If we got here, we continue with issuing a resetkey for the user. # If we got here, we continue with issuing a resetkey for the user.
resetkey = db.make_random_value(User, User.ResetKey) resetkey = db.make_random_value(User, User.ResetKey)
with db.begin():
user.ResetKey = resetkey user.ResetKey = resetkey
session.commit()
executor = db.ConnectionExecutor(db.get_engine().raw_connection()) executor = db.ConnectionExecutor(db.get_engine().raw_connection())
ResetKeyNotification(executor, user.ID).send() ResetKeyNotification(executor, user.ID).send()
@ -364,8 +361,6 @@ async def account_register_post(request: Request,
ON: bool = Form(default=False), ON: bool = Form(default=False),
captcha: str = Form(default=None), captcha: str = Form(default=None),
captcha_salt: str = Form(...)): captcha_salt: str = Form(...)):
from aurweb.db import session
context = await make_variable_context(request, "Register") context = await make_variable_context(request, "Register")
args = dict(await request.form()) args = dict(await request.form())
@ -394,11 +389,13 @@ async def account_register_post(request: Request,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
# Create a user given all parameters available. # Create a user given all parameters available.
user = db.create(User, Username=U, Email=E, HideEmail=H, BackupEmail=BE, with db.begin():
user = db.create(User, Username=U,
Email=E, HideEmail=H, BackupEmail=BE,
RealName=R, Homepage=HP, IRCNick=I, PGPKey=K, RealName=R, Homepage=HP, IRCNick=I, PGPKey=K,
LangPreference=L, Timezone=TZ, CommentNotify=CN, LangPreference=L, Timezone=TZ, CommentNotify=CN,
UpdateNotify=UN, OwnershipNotify=ON, ResetKey=resetkey, UpdateNotify=UN, OwnershipNotify=ON,
AccountType=account_type) ResetKey=resetkey, AccountType=account_type)
# If a PK was given and either one does not exist or the given # If a PK was given and either one does not exist or the given
# PK mismatches the existing user's SSHPubKey.PubKey. # PK mismatches the existing user's SSHPubKey.PubKey.
@ -410,10 +407,10 @@ async def account_register_post(request: Request,
# Remove the host part. # Remove the host part.
pubkey = parts[0] + " " + parts[1] pubkey = parts[0] + " " + parts[1]
fingerprint = get_fingerprint(pubkey) fingerprint = get_fingerprint(pubkey)
with db.begin():
user.ssh_pub_key = SSHPubKey(UserID=user.ID, user.ssh_pub_key = SSHPubKey(UserID=user.ID,
PubKey=pubkey, PubKey=pubkey,
Fingerprint=fingerprint) Fingerprint=fingerprint)
session.commit()
# Send a reset key notification to the new user. # Send a reset key notification to the new user.
executor = db.ConnectionExecutor(db.get_engine().raw_connection()) executor = db.ConnectionExecutor(db.get_engine().raw_connection())
@ -499,6 +496,7 @@ async def account_edit_post(request: Request,
status_code=int(HTTPStatus.BAD_REQUEST)) status_code=int(HTTPStatus.BAD_REQUEST))
# Set all updated fields as needed. # Set all updated fields as needed.
with db.begin():
user.Username = U or user.Username user.Username = U or user.Username
user.Email = E or user.Email user.Email = E or user.Email
user.HideEmail = bool(H) user.HideEmail = bool(H)
@ -512,20 +510,24 @@ async def account_edit_post(request: Request,
# If we update the language, update the cookie as well. # If we update the language, update the cookie as well.
if L and L != user.LangPreference: if L and L != user.LangPreference:
request.cookies["AURLANG"] = L request.cookies["AURLANG"] = L
with db.begin():
user.LangPreference = L user.LangPreference = L
context["language"] = L context["language"] = L
# If we update the timezone, also update the cookie. # If we update the timezone, also update the cookie.
if TZ and TZ != user.Timezone: if TZ and TZ != user.Timezone:
with db.begin():
user.Timezone = TZ user.Timezone = TZ
request.cookies["AURTZ"] = TZ request.cookies["AURTZ"] = TZ
context["timezone"] = TZ context["timezone"] = TZ
with db.begin():
user.CommentNotify = bool(CN) user.CommentNotify = bool(CN)
user.UpdateNotify = bool(UN) user.UpdateNotify = bool(UN)
user.OwnershipNotify = bool(ON) user.OwnershipNotify = bool(ON)
# If a PK is given, compare it against the target user's PK. # If a PK is given, compare it against the target user's PK.
with db.begin():
if PK: if PK:
# Get the second token in the public key, which is the actual key. # Get the second token in the public key, which is the actual key.
pubkey = PK.strip().rstrip() pubkey = PK.strip().rstrip()
@ -547,15 +549,14 @@ async def account_edit_post(request: Request,
# Else, if the user has a public key already, delete it. # Else, if the user has a public key already, delete it.
session.delete(user.ssh_pub_key) session.delete(user.ssh_pub_key)
# Commit changes, if any.
session.commit()
if P and not user.valid_password(P): if P and not user.valid_password(P):
# Remove the fields we consumed for passwords. # Remove the fields we consumed for passwords.
context["P"] = context["C"] = str() context["P"] = context["C"] = str()
# If a password was given and it doesn't match the user's, update it. # If a password was given and it doesn't match the user's, update it.
with db.begin():
user.update_password(P) user.update_password(P)
if user == request.user: if user == request.user:
# If the target user is the request user, login with # If the target user is the request user, login with
# the updated password and update AURSID. # the updated password and update AURSID.
@ -731,6 +732,7 @@ async def terms_of_service_post(request: Request,
accept_needed = sorted(unaccepted + diffs) accept_needed = sorted(unaccepted + diffs)
return render_terms_of_service(request, context, accept_needed) return render_terms_of_service(request, context, accept_needed)
with db.begin():
# For each term we found, query for the matching accepted term # For each term we found, query for the matching accepted term
# and update its Revision to the term's current Revision. # and update its Revision to the term's current Revision.
for term in diffs: for term in diffs:
@ -741,11 +743,6 @@ async def terms_of_service_post(request: Request,
# For each term that was never accepted, accept it! # For each term that was never accepted, accept it!
for term in unaccepted: for term in unaccepted:
db.create(AcceptedTerm, User=request.user, db.create(AcceptedTerm, User=request.user,
Term=term, Revision=term.Revision, Term=term, Revision=term.Revision)
autocommit=False)
if diffs or unaccepted:
# If we had any terms to update, commit the changes.
db.commit()
return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER)) return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER))

View file

@ -44,8 +44,6 @@ async def language(request: Request,
setting the language on any page, we want to preserve query setting the language on any page, we want to preserve query
parameters across the redirect. parameters across the redirect.
""" """
from aurweb.db import session
if next[0] != '/': if next[0] != '/':
return HTMLResponse(b"Invalid 'next' parameter.", status_code=400) return HTMLResponse(b"Invalid 'next' parameter.", status_code=400)
@ -53,8 +51,8 @@ async def language(request: Request,
# If the user is authenticated, update the user's LangPreference. # If the user is authenticated, update the user's LangPreference.
if request.user.is_authenticated(): if request.user.is_authenticated():
with db.begin():
request.user.LangPreference = set_lang request.user.LangPreference = set_lang
session.commit()
# In any case, set the response's AURLANG cookie that never expires. # In any case, set the response's AURLANG cookie that never expires.
response = RedirectResponse(url=f"{next}{query_string}", response = RedirectResponse(url=f"{next}{query_string}",

View file

@ -214,10 +214,9 @@ async def trusted_user_proposal_post(request: Request,
return Response("Invalid 'decision' value.", return Response("Invalid 'decision' value.",
status_code=int(HTTPStatus.BAD_REQUEST)) status_code=int(HTTPStatus.BAD_REQUEST))
vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo, with db.begin():
autocommit=False) vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo)
voteinfo.ActiveTUs += 1 voteinfo.ActiveTUs += 1
db.commit()
context["error"] = "You've already voted for this proposal." context["error"] = "You've already voted for this proposal."
return render_proposal(request, context, proposal, voteinfo, voters, vote) return render_proposal(request, context, proposal, voteinfo, voters, vote)
@ -294,6 +293,7 @@ async def trusted_user_addvote_post(request: Request,
agenda = re.sub(r'<[/]?style.*>', '', agenda) agenda = re.sub(r'<[/]?style.*>', '', agenda)
# Create a new TUVoteInfo (proposal)! # Create a new TUVoteInfo (proposal)!
with db.begin():
voteinfo = db.create(TUVoteInfo, voteinfo = db.create(TUVoteInfo,
User=user, User=user,
Agenda=agenda, Agenda=agenda,

View file

@ -1,6 +1,6 @@
import pytest import pytest
from aurweb.db import create, delete, query from aurweb.db import begin, create, delete, query
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.user import User from aurweb.models.user import User
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -14,10 +14,12 @@ def setup():
global account_type global account_type
with begin():
account_type = create(AccountType, AccountType="TestUser") account_type = create(AccountType, AccountType="TestUser")
yield account_type yield account_type
with begin():
delete(AccountType, AccountType.ID == account_type.ID) delete(AccountType, AccountType.ID == account_type.ID)
@ -38,6 +40,7 @@ def test_account_type():
def test_user_account_type_relationship(): def test_user_account_type_relationship():
with begin():
user = create(User, Username="test", Email="test@example.org", user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword", RealName="Test User", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)
@ -46,4 +49,5 @@ def test_user_account_type_relationship():
# This must be deleted here to avoid foreign key issues when # This must be deleted here to avoid foreign key issues when
# deleting the temporary AccountType in the fixture. # deleting the temporary AccountType in the fixture.
with begin():
delete(User, User.ID == user.ID) delete(User, User.ID == user.ID)

View file

@ -11,9 +11,9 @@ import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from aurweb import captcha from aurweb import captcha, db
from aurweb.asgi import app from aurweb.asgi import app
from aurweb.db import commit, create, query from aurweb.db import create, query
from aurweb.models.accepted_term import AcceptedTerm from aurweb.models.accepted_term import AcceptedTerm
from aurweb.models.account_type import DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID, TRUSTED_USER_ID, AccountType from aurweb.models.account_type import DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID, TRUSTED_USER_ID, AccountType
from aurweb.models.ban import Ban from aurweb.models.ban import Ban
@ -57,6 +57,8 @@ def setup():
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
with db.begin():
user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL, user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL,
RealName="Test UserZ", Passwd="testPassword", RealName="Test UserZ", Passwd="testPassword",
IRCNick="testZ", AccountType=account_type) IRCNick="testZ", AccountType=account_type)
@ -70,9 +72,10 @@ def setup():
@pytest.fixture @pytest.fixture
def tu_user(): def tu_user():
user.AccountType = query(AccountType, with db.begin():
AccountType.ID == TRUSTED_USER_AND_DEV_ID).first() user.AccountType = query(AccountType).filter(
commit() AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
yield user yield user
@ -149,11 +152,9 @@ def test_post_passreset_user():
def test_post_passreset_resetkey(): def test_post_passreset_resetkey():
from aurweb.db import session with db.begin():
user.session = Session(UsersID=user.ID, SessionID="blah", user.session = Session(UsersID=user.ID, SessionID="blah",
LastUpdateTS=datetime.utcnow().timestamp()) LastUpdateTS=datetime.utcnow().timestamp())
session.commit()
# Prepare a password reset. # Prepare a password reset.
with client as request: with client as request:
@ -357,6 +358,7 @@ def test_post_register_error_invalid_captcha():
def test_post_register_error_ip_banned(): def test_post_register_error_ip_banned():
# 'testclient' is used as request.client.host via FastAPI TestClient. # 'testclient' is used as request.client.host via FastAPI TestClient.
with db.begin():
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow()) create(Ban, IPAddress="testclient", BanTS=datetime.utcnow())
with client as request: with client as request:
@ -576,6 +578,7 @@ def test_post_register_error_ssh_pubkey_taken():
# Take the sha256 fingerprint of the ssh public key, create it. # Take the sha256 fingerprint of the ssh public key, create it.
fp = get_fingerprint(pk) fp = get_fingerprint(pk)
with db.begin():
create(SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fp) create(SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fp)
with client as request: with client as request:
@ -660,13 +663,11 @@ def test_post_account_edit():
def test_post_account_edit_dev(): def test_post_account_edit_dev():
from aurweb.db import session
# Modify our user to be a "Trusted User & Developer" # Modify our user to be a "Trusted User & Developer"
name = "Trusted User & Developer" name = "Trusted User & Developer"
tu_or_dev = query(AccountType, AccountType.AccountType == name).first() tu_or_dev = query(AccountType, AccountType.AccountType == name).first()
with db.begin():
user.AccountType = tu_or_dev user.AccountType = tu_or_dev
session.commit()
request = Request() request = Request()
sid = user.login(request, "testPassword") sid = user.login(request, "testPassword")
@ -1001,22 +1002,20 @@ def get_rows(html):
def test_post_accounts(tu_user): def test_post_accounts(tu_user):
# Set a PGPKey. # Set a PGPKey.
with db.begin():
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30" user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
# Create a few more users. # Create a few more users.
users = [user] users = [user]
with db.begin():
for i in range(10): for i in range(10):
_user = create(User, Username=f"test_{i}", _user = create(User, Username=f"test_{i}",
Email=f"test_{i}@example.org", Email=f"test_{i}@example.org",
RealName=f"Test #{i}", RealName=f"Test #{i}",
Passwd="testPassword", Passwd="testPassword",
IRCNick=f"test_#{i}", IRCNick=f"test_#{i}")
autocommit=False)
users.append(_user) users.append(_user)
# Commit everything to the database.
commit()
sid = user.login(Request(), "testPassword") sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid} cookies = {"AURSID": sid}
@ -1085,6 +1084,7 @@ def test_post_accounts_account_type(tu_user):
# test the `u` parameter. # test the `u` parameter.
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
with db.begin():
create(User, Username="test_2", create(User, Username="test_2",
Email="test_2@example.org", Email="test_2@example.org",
RealName="Test User 2", RealName="Test User 2",
@ -1113,9 +1113,10 @@ def test_post_accounts_account_type(tu_user):
assert type.text.strip() == "User" assert type.text.strip() == "User"
# Set our only user to a Trusted User. # Set our only user to a Trusted User.
user.AccountType = query(AccountType, with db.begin():
AccountType.ID == TRUSTED_USER_ID).first() user.AccountType = query(AccountType).filter(
commit() AccountType.ID == TRUSTED_USER_ID
).first()
with client as request: with client as request:
response = request.post("/accounts/", cookies=cookies, response = request.post("/accounts/", cookies=cookies,
@ -1130,9 +1131,10 @@ def test_post_accounts_account_type(tu_user):
assert type.text.strip() == "Trusted User" assert type.text.strip() == "Trusted User"
user.AccountType = query(AccountType, with db.begin():
AccountType.ID == DEVELOPER_ID).first() user.AccountType = query(AccountType).filter(
commit() AccountType.ID == DEVELOPER_ID
).first()
with client as request: with client as request:
response = request.post("/accounts/", cookies=cookies, response = request.post("/accounts/", cookies=cookies,
@ -1147,10 +1149,10 @@ def test_post_accounts_account_type(tu_user):
assert type.text.strip() == "Developer" assert type.text.strip() == "Developer"
user.AccountType = query(AccountType, with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_AND_DEV_ID AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first() ).first()
commit()
with client as request: with client as request:
response = request.post("/accounts/", cookies=cookies, response = request.post("/accounts/", cookies=cookies,
@ -1182,8 +1184,8 @@ def test_post_accounts_status(tu_user):
username, type, status, realname, irc, pgp_key, edit = row username, type, status, realname, irc, pgp_key, edit = row
assert status.text.strip() == "Active" assert status.text.strip() == "Active"
with db.begin():
user.Suspended = True user.Suspended = True
commit()
with client as request: with client as request:
response = request.post("/accounts/", cookies=cookies, response = request.post("/accounts/", cookies=cookies,
@ -1244,6 +1246,7 @@ def test_post_accounts_sortby(tu_user):
# Create a second user so we can compare sorts. # Create a second user so we can compare sorts.
account_type = query(AccountType, account_type = query(AccountType,
AccountType.ID == DEVELOPER_ID).first() AccountType.ID == DEVELOPER_ID).first()
with db.begin():
create(User, Username="test2", create(User, Username="test2",
Email="test2@example.org", Email="test2@example.org",
RealName="Test User 2", RealName="Test User 2",
@ -1297,9 +1300,10 @@ def test_post_accounts_sortby(tu_user):
# Test the rows are reversed when ordering by RealName. # Test the rows are reversed when ordering by RealName.
assert compare_text_values(4, first_rows, reversed(rows)) is True assert compare_text_values(4, first_rows, reversed(rows)) is True
user.AccountType = query(AccountType, with db.begin():
AccountType.ID == TRUSTED_USER_AND_DEV_ID).first() user.AccountType = query(AccountType).filter(
commit() AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
# Fetch first_rows again with our new AccountType ordering. # Fetch first_rows again with our new AccountType ordering.
with client as request: with client as request:
@ -1322,8 +1326,8 @@ def test_post_accounts_sortby(tu_user):
def test_post_accounts_pgp_key(tu_user): def test_post_accounts_pgp_key(tu_user):
with db.begin():
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30" user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
commit()
sid = user.login(Request(), "testPassword") sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid} cookies = {"AURSID": sid}
@ -1343,15 +1347,14 @@ def test_post_accounts_paged(tu_user):
users = [user] users = [user]
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
with db.begin():
for i in range(150): for i in range(150):
_user = create(User, Username=f"test_#{i}", _user = create(User, Username=f"test_#{i}",
Email=f"test_#{i}@example.org", Email=f"test_#{i}@example.org",
RealName=f"Test User #{i}", RealName=f"Test User #{i}",
Passwd="testPassword", Passwd="testPassword",
AccountType=account_type, AccountType=account_type)
autocommit=False)
users.append(_user) users.append(_user)
commit()
sid = user.login(Request(), "testPassword") sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid} cookies = {"AURSID": sid}
@ -1414,6 +1417,7 @@ def test_post_accounts_paged(tu_user):
def test_get_terms_of_service(): def test_get_terms_of_service():
with db.begin():
term = create(Term, Description="Test term.", term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1) URL="http://localhost", Revision=1)
@ -1436,6 +1440,7 @@ def test_get_terms_of_service():
response = request.get("/tos", cookies=cookies, allow_redirects=False) response = request.get("/tos", cookies=cookies, allow_redirects=False)
assert response.status_code == int(HTTPStatus.OK) assert response.status_code == int(HTTPStatus.OK)
with db.begin():
accepted_term = create(AcceptedTerm, User=user, accepted_term = create(AcceptedTerm, User=user,
Term=term, Revision=term.Revision) Term=term, Revision=term.Revision)
@ -1445,8 +1450,8 @@ def test_get_terms_of_service():
assert response.status_code == int(HTTPStatus.SEE_OTHER) assert response.status_code == int(HTTPStatus.SEE_OTHER)
# Bump the term's revision. # Bump the term's revision.
with db.begin():
term.Revision = 2 term.Revision = 2
commit()
with client as request: with client as request:
response = request.get("/tos", cookies=cookies, allow_redirects=False) response = request.get("/tos", cookies=cookies, allow_redirects=False)
@ -1454,8 +1459,8 @@ def test_get_terms_of_service():
# yet been agreed to via AcceptedTerm update. # yet been agreed to via AcceptedTerm update.
assert response.status_code == int(HTTPStatus.OK) assert response.status_code == int(HTTPStatus.OK)
with db.begin():
accepted_term.Revision = term.Revision accepted_term.Revision = term.Revision
commit()
with client as request: with client as request:
response = request.get("/tos", cookies=cookies, allow_redirects=False) response = request.get("/tos", cookies=cookies, allow_redirects=False)
@ -1471,6 +1476,7 @@ def test_post_terms_of_service():
cookies = {"AURSID": sid} # Auth cookie. cookies = {"AURSID": sid} # Auth cookie.
# Create a fresh Term. # Create a fresh Term.
with db.begin():
term = create(Term, Description="Test term.", term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1) URL="http://localhost", Revision=1)
@ -1497,8 +1503,8 @@ def test_post_terms_of_service():
assert accepted_term.Term == term assert accepted_term.Term == term
# Update the term to revision 2. # Update the term to revision 2.
with db.begin():
term.Revision = 2 term.Revision = 2
commit()
# A GET request gives us the new revision to accept. # A GET request gives us the new revision to accept.
with client as request: with client as request:

View file

@ -2,6 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb import db
from aurweb.db import create from aurweb.db import create
from aurweb.models.api_rate_limit import ApiRateLimit from aurweb.models.api_rate_limit import ApiRateLimit
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -13,6 +14,7 @@ def setup():
def test_api_rate_key_creation(): def test_api_rate_key_creation():
with db.begin():
rate = create(ApiRateLimit, IP="127.0.0.1", Requests=10, WindowStart=1) rate = create(ApiRateLimit, IP="127.0.0.1", Requests=10, WindowStart=1)
assert rate.IP == "127.0.0.1" assert rate.IP == "127.0.0.1"
assert rate.Requests == 10 assert rate.Requests == 10
@ -20,19 +22,20 @@ def test_api_rate_key_creation():
def test_api_rate_key_ip_default(): def test_api_rate_key_ip_default():
with db.begin():
api_rate_limit = create(ApiRateLimit, Requests=10, WindowStart=1) api_rate_limit = create(ApiRateLimit, Requests=10, WindowStart=1)
assert api_rate_limit.IP == str() assert api_rate_limit.IP == str()
def test_api_rate_key_null_requests_raises_exception(): def test_api_rate_key_null_requests_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(ApiRateLimit, IP="127.0.0.1", WindowStart=1) create(ApiRateLimit, IP="127.0.0.1", WindowStart=1)
session.rollback() db.rollback()
def test_api_rate_key_null_window_start_raises_exception(): def test_api_rate_key_null_window_start_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(ApiRateLimit, IP="127.0.0.1", Requests=1) create(ApiRateLimit, IP="127.0.0.1", Requests=1)
session.rollback() db.rollback()

View file

@ -4,6 +4,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb import db
from aurweb.auth import BasicAuthBackend, account_type_required, has_credential from aurweb.auth import BasicAuthBackend, account_type_required, has_credential
from aurweb.db import create, query from aurweb.db import create, query
from aurweb.models.account_type import USER, USER_ID, AccountType from aurweb.models.account_type import USER, USER_ID, AccountType
@ -23,6 +24,7 @@ def setup():
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
with db.begin():
user = create(User, Username="test", Email="test@example.com", user = create(User, Username="test", Email="test@example.com",
RealName="Test User", Passwd="testPassword", RealName="Test User", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)
@ -51,14 +53,13 @@ async def test_auth_backend_invalid_sid():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_auth_backend_invalid_user_id(): async def test_auth_backend_invalid_user_id():
from aurweb.db import session
# Create a new session with a fake user id. # Create a new session with a fake user id.
now_ts = datetime.utcnow().timestamp() now_ts = datetime.utcnow().timestamp()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(Session, UsersID=666, SessionID="realSession", create(Session, UsersID=666, SessionID="realSession",
LastUpdateTS=now_ts + 5) LastUpdateTS=now_ts + 5)
session.rollback() db.rollback()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -66,6 +67,7 @@ async def test_basic_auth_backend():
# This time, everything matches up. We expect the user to # This time, everything matches up. We expect the user to
# equal the real_user. # equal the real_user.
now_ts = datetime.utcnow().timestamp() now_ts = datetime.utcnow().timestamp()
with db.begin():
create(Session, UsersID=user.ID, SessionID="realSession", create(Session, UsersID=user.ID, SessionID="realSession",
LastUpdateTS=now_ts + 5) LastUpdateTS=now_ts + 5)
request.cookies["AURSID"] = "realSession" request.cookies["AURSID"] = "realSession"

View file

@ -9,7 +9,7 @@ from fastapi.testclient import TestClient
import aurweb.config import aurweb.config
from aurweb.asgi import app from aurweb.asgi import app
from aurweb.db import create, query from aurweb.db import begin, create, query
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.session import Session from aurweb.models.session import Session
from aurweb.models.user import User from aurweb.models.user import User
@ -32,6 +32,7 @@ def setup():
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
with begin():
user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL, user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL,
RealName="Test User", Passwd="testPassword", RealName="Test User", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)

View file

@ -6,6 +6,7 @@ import pytest
from sqlalchemy import exc as sa_exc from sqlalchemy import exc as sa_exc
from aurweb import db
from aurweb.db import create from aurweb.db import create
from aurweb.models.ban import Ban, is_banned from aurweb.models.ban import Ban, is_banned
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -21,6 +22,7 @@ def setup():
setup_test_db("Bans") setup_test_db("Bans")
ts = datetime.utcnow() + timedelta(seconds=30) ts = datetime.utcnow() + timedelta(seconds=30)
with db.begin():
ban = create(Ban, IPAddress="127.0.0.1", BanTS=ts) ban = create(Ban, IPAddress="127.0.0.1", BanTS=ts)
request = Request() request = Request()
@ -35,17 +37,17 @@ def test_invalid_ban():
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(sa_exc.IntegrityError):
bad_ban = Ban(BanTS=datetime.utcnow()) bad_ban = Ban(BanTS=datetime.utcnow())
session.add(bad_ban)
# We're adding a ban with no primary key; this causes an # We're adding a ban with no primary key; this causes an
# SQLAlchemy warnings when committing to the DB. # SQLAlchemy warnings when committing to the DB.
# Ignore them. # Ignore them.
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", sa_exc.SAWarning) warnings.simplefilter("ignore", sa_exc.SAWarning)
session.commit() with db.begin():
session.add(bad_ban)
# Since we got a transaction failure, we need to rollback. # Since we got a transaction failure, we need to rollback.
session.rollback() db.rollback()
def test_banned(): def test_banned():

View file

@ -278,18 +278,15 @@ def test_connection_execute_paramstyle_unsupported():
def test_create_delete(): def test_create_delete():
with db.begin():
db.create(AccountType, AccountType="test") db.create(AccountType, AccountType="test")
record = db.query(AccountType, AccountType.AccountType == "test").first() record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is not None assert record is not None
db.delete(AccountType, AccountType.AccountType == "test")
record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is None
# Create and delete a record with autocommit=False. with db.begin():
db.create(AccountType, AccountType="test", autocommit=False) db.delete(AccountType, AccountType.AccountType == "test")
db.commit()
db.delete(AccountType, AccountType.AccountType == "test", autocommit=False)
db.commit()
record = db.query(AccountType, AccountType.AccountType == "test").first() record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is None assert record is None
@ -297,8 +294,8 @@ def test_create_delete():
def test_add_commit(): def test_add_commit():
# Use db.add and db.commit to add a temporary record. # Use db.add and db.commit to add a temporary record.
account_type = AccountType(AccountType="test") account_type = AccountType(AccountType="test")
with db.begin():
db.add(account_type) db.add(account_type)
db.commit()
# Assert it got created in the DB. # Assert it got created in the DB.
assert bool(account_type.ID) assert bool(account_type.ID)
@ -308,6 +305,7 @@ def test_add_commit():
assert record == account_type assert record == account_type
# Remove the record. # Remove the record.
with db.begin():
db.delete(AccountType, AccountType.ID == account_type.ID) db.delete(AccountType, AccountType.ID == account_type.ID)

View file

@ -1,6 +1,6 @@
import pytest import pytest
from aurweb.db import create, delete, query from aurweb.db import begin, create, delete, query
from aurweb.models.dependency_type import DependencyType from aurweb.models.dependency_type import DependencyType
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -19,13 +19,17 @@ def test_dependency_types():
def test_dependency_type_creation(): def test_dependency_type_creation():
with begin():
dependency_type = create(DependencyType, Name="Test Type") dependency_type = create(DependencyType, Name="Test Type")
assert bool(dependency_type.ID) assert bool(dependency_type.ID)
assert dependency_type.Name == "Test Type" assert dependency_type.Name == "Test Type"
with begin():
delete(DependencyType, DependencyType.ID == dependency_type.ID) delete(DependencyType, DependencyType.ID == dependency_type.ID)
def test_dependency_type_null_name_uses_default(): def test_dependency_type_null_name_uses_default():
with begin():
dependency_type = create(DependencyType) dependency_type = create(DependencyType)
assert dependency_type.Name == str() assert dependency_type.Name == str()
with begin():
delete(DependencyType, DependencyType.ID == dependency_type.ID) delete(DependencyType, DependencyType.ID == dependency_type.ID)

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create from aurweb import db
from aurweb.models.group import Group from aurweb.models.group import Group
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -13,13 +13,14 @@ def setup():
def test_group_creation(): def test_group_creation():
group = create(Group, Name="Test Group") with db.begin():
group = db.create(Group, Name="Test Group")
assert bool(group.ID) assert bool(group.ID)
assert group.Name == "Test Group" assert group.Name == "Test Group"
def test_group_null_name_raises_exception(): def test_group_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(Group) with db.begin():
session.rollback() db.create(Group)
db.rollback()

View file

@ -38,8 +38,10 @@ def setup():
@pytest.fixture @pytest.fixture
def user(): def user():
yield db.create(User, Username="test", Email="test@example.org", with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
Passwd="testPassword", AccountTypeID=USER_ID) Passwd="testPassword", AccountTypeID=USER_ID)
yield user
@pytest.fixture @pytest.fixture
@ -68,18 +70,15 @@ def packages(user):
# For i..num_packages, create a package named pkg_{i}. # For i..num_packages, create a package named pkg_{i}.
pkgs = [] pkgs = []
now = int(datetime.utcnow().timestamp()) now = int(datetime.utcnow().timestamp())
with db.begin():
for i in range(num_packages): for i in range(num_packages):
pkgbase = db.create(PackageBase, Name=f"pkg_{i}", pkgbase = db.create(PackageBase, Name=f"pkg_{i}",
Maintainer=user, Packager=user, Maintainer=user, Packager=user,
autocommit=False, SubmittedTS=now, SubmittedTS=now, ModifiedTS=now)
ModifiedTS=now) pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
pkg = db.create(Package, PackageBase=pkgbase,
Name=pkgbase.Name, autocommit=False)
pkgs.append(pkg) pkgs.append(pkg)
now += 1 now += 1
db.commit()
yield pkgs yield pkgs
@ -159,10 +158,11 @@ def test_homepage_updates(redis, packages):
def test_homepage_dashboard(redis, packages, user): def test_homepage_dashboard(redis, packages, user):
# Create Comaintainer records for all of the packages. # Create Comaintainer records for all of the packages.
with db.begin():
for pkg in packages: for pkg in packages:
db.create(PackageComaintainer, PackageBase=pkg.PackageBase, db.create(PackageComaintainer,
User=user, Priority=1, autocommit=False) PackageBase=pkg.PackageBase,
db.commit() User=user, Priority=1)
cookies = {"AURSID": user.login(Request(), "testPassword")} cookies = {"AURSID": user.login(Request(), "testPassword")}
with client as request: with client as request:
@ -193,6 +193,7 @@ def test_homepage_dashboard_requests(redis, packages, user):
pkg = packages[0] pkg = packages[0]
reqtype = db.query(RequestType, RequestType.ID == DELETION_ID).first() reqtype = db.query(RequestType, RequestType.ID == DELETION_ID).first()
with db.begin():
pkgreq = db.create(PackageRequest, PackageBase=pkg.PackageBase, pkgreq = db.create(PackageRequest, PackageBase=pkg.PackageBase,
PackageBaseName=pkg.PackageBase.Name, PackageBaseName=pkg.PackageBase.Name,
User=user, Comments=str(), User=user, Comments=str(),
@ -213,8 +214,8 @@ def test_homepage_dashboard_requests(redis, packages, user):
def test_homepage_dashboard_flagged_packages(redis, packages, user): def test_homepage_dashboard_flagged_packages(redis, packages, user):
# Set the first Package flagged by setting its OutOfDateTS column. # Set the first Package flagged by setting its OutOfDateTS column.
pkg = packages[0] pkg = packages[0]
with db.begin():
pkg.PackageBase.OutOfDateTS = int(datetime.utcnow().timestamp()) pkg.PackageBase.OutOfDateTS = int(datetime.utcnow().timestamp())
db.commit()
cookies = {"AURSID": user.login(Request(), "testPassword")} cookies = {"AURSID": user.login(Request(), "testPassword")}
with client as request: with client as request:

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create from aurweb import db
from aurweb.models.license import License from aurweb.models.license import License
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -13,13 +13,14 @@ def setup():
def test_license_creation(): def test_license_creation():
license = create(License, Name="Test License") with db.begin():
license = db.create(License, Name="Test License")
assert bool(license.ID) assert bool(license.ID)
assert license.Name == "Test License" assert license.Name == "Test License"
def test_license_null_name_raises_exception(): def test_license_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(License) with db.begin():
session.rollback() db.create(License)
db.rollback()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create from aurweb import db
from aurweb.models.official_provider import OfficialProvider from aurweb.models.official_provider import OfficialProvider
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -13,7 +13,8 @@ def setup():
def test_official_provider_creation(): def test_official_provider_creation():
oprovider = create(OfficialProvider, with db.begin():
oprovider = db.create(OfficialProvider,
Name="some-name", Name="some-name",
Repo="some-repo", Repo="some-repo",
Provides="some-provides") Provides="some-provides")
@ -25,13 +26,15 @@ def test_official_provider_creation():
def test_official_provider_cs(): def test_official_provider_cs():
""" Test case sensitivity of the database table. """ """ Test case sensitivity of the database table. """
oprovider = create(OfficialProvider, with db.begin():
oprovider = db.create(OfficialProvider,
Name="some-name", Name="some-name",
Repo="some-repo", Repo="some-repo",
Provides="some-provides") Provides="some-provides")
assert bool(oprovider.ID) assert bool(oprovider.ID)
oprovider_cs = create(OfficialProvider, with db.begin():
oprovider_cs = db.create(OfficialProvider,
Name="SOME-NAME", Name="SOME-NAME",
Repo="SOME-REPO", Repo="SOME-REPO",
Provides="SOME-PROVIDES") Provides="SOME-PROVIDES")
@ -49,27 +52,27 @@ def test_official_provider_cs():
def test_official_provider_null_name_raises_exception(): def test_official_provider_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(OfficialProvider, with db.begin():
db.create(OfficialProvider,
Repo="some-repo", Repo="some-repo",
Provides="some-provides") Provides="some-provides")
session.rollback() db.rollback()
def test_official_provider_null_repo_raises_exception(): def test_official_provider_null_repo_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(OfficialProvider, with db.begin():
db.create(OfficialProvider,
Name="some-name", Name="some-name",
Provides="some-provides") Provides="some-provides")
session.rollback() db.rollback()
def test_official_provider_null_provides_raises_exception(): def test_official_provider_null_provides_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(OfficialProvider, with db.begin():
db.create(OfficialProvider,
Name="some-name", Name="some-name",
Repo="some-repo") Repo="some-repo")
session.rollback() db.rollback()

View file

@ -3,7 +3,7 @@ import pytest
from sqlalchemy import and_ from sqlalchemy import and_
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create, query from aurweb import db
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.package import Package from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase from aurweb.models.package_base import PackageBase
@ -19,16 +19,18 @@ def setup():
setup_test_db("Packages", "PackageBases", "Users") setup_test_db("Packages", "PackageBases", "Users")
account_type = query(AccountType, account_type = db.query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword", RealName="Test User", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)
pkgbase = create(PackageBase, pkgbase = db.create(PackageBase,
Name="beautiful-package", Name="beautiful-package",
Maintainer=user) Maintainer=user)
package = create(Package, package = db.create(Package,
PackageBase=pkgbase, PackageBase=pkgbase,
Name=pkgbase.Name, Name=pkgbase.Name,
Description="Test description.", Description="Test description.",
@ -36,8 +38,6 @@ def setup():
def test_package(): def test_package():
from aurweb.db import session
assert pkgbase == package.PackageBase assert pkgbase == package.PackageBase
assert package.Name == "beautiful-package" assert package.Name == "beautiful-package"
assert package.Description == "Test description." assert package.Description == "Test description."
@ -45,33 +45,31 @@ def test_package():
assert package.URL == "https://test.package" assert package.URL == "https://test.package"
# Update package Version. # Update package Version.
with db.begin():
package.Version = "1.2.3" package.Version = "1.2.3"
session.commit()
# Make sure it got updated in the database. # Make sure it got updated in the database.
record = query(Package, record = db.query(Package,
and_(Package.ID == package.ID, and_(Package.ID == package.ID,
Package.Version == "1.2.3")).first() Package.Version == "1.2.3")).first()
assert record is not None assert record is not None
def test_package_null_pkgbase_raises_exception(): def test_package_null_pkgbase_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(Package, with db.begin():
db.create(Package,
Name="some-package", Name="some-package",
Description="Some description.", Description="Some description.",
URL="https://some.package") URL="https://some.package")
session.rollback() db.rollback()
def test_package_null_name_raises_exception(): def test_package_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(Package, with db.begin():
db.create(Package,
PackageBase=pkgbase, PackageBase=pkgbase,
Description="Some description.", Description="Some description.",
URL="https://some.package") URL="https://some.package")
session.rollback() db.rollback()

View file

@ -4,7 +4,7 @@ from sqlalchemy.exc import IntegrityError
import aurweb.config import aurweb.config
from aurweb.db import create, query from aurweb import db
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.package_base import PackageBase from aurweb.models.package_base import PackageBase
from aurweb.models.user import User from aurweb.models.user import User
@ -19,15 +19,17 @@ def setup():
setup_test_db("Users", "PackageBases") setup_test_db("Users", "PackageBases")
account_type = query(AccountType, account_type = db.query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org", with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword", RealName="Test User", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)
def test_package_base(): def test_package_base():
pkgbase = create(PackageBase, with db.begin():
pkgbase = db.create(PackageBase,
Name="beautiful-package", Name="beautiful-package",
Maintainer=user) Maintainer=user)
assert pkgbase in user.maintained_bases assert pkgbase in user.maintained_bases
@ -38,6 +40,7 @@ def test_package_base():
# Set Popularity to a string, then get it by attribute to # Set Popularity to a string, then get it by attribute to
# exercise the string -> float conversion path. # exercise the string -> float conversion path.
with db.begin():
pkgbase.Popularity = "0.0" pkgbase.Popularity = "0.0"
assert pkgbase.Popularity == 0.0 assert pkgbase.Popularity == 0.0
@ -47,22 +50,23 @@ def test_package_base_ci():
if aurweb.config.get("database", "backend") == "sqlite": if aurweb.config.get("database", "backend") == "sqlite":
return None # SQLite doesn't seem handle this. return None # SQLite doesn't seem handle this.
from aurweb.db import session with db.begin():
pkgbase = db.create(PackageBase,
pkgbase = create(PackageBase,
Name="beautiful-package", Name="beautiful-package",
Maintainer=user) Maintainer=user)
assert bool(pkgbase.ID) assert bool(pkgbase.ID)
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageBase, with db.begin():
db.create(PackageBase,
Name="Beautiful-Package", Name="Beautiful-Package",
Maintainer=user) Maintainer=user)
session.rollback() db.rollback()
def test_package_base_relationships(): def test_package_base_relationships():
pkgbase = create(PackageBase, with db.begin():
pkgbase = db.create(PackageBase,
Name="beautiful-package", Name="beautiful-package",
Flagger=user, Flagger=user,
Maintainer=user, Maintainer=user,
@ -75,8 +79,7 @@ def test_package_base_relationships():
def test_package_base_null_name_raises_exception(): def test_package_base_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageBase) with db.begin():
session.rollback() db.create(PackageBase)
db.rollback()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create, rollback from aurweb import db
from aurweb.models.package_base import PackageBase from aurweb.models.package_base import PackageBase
from aurweb.models.package_blacklist import PackageBlacklist from aurweb.models.package_blacklist import PackageBlacklist
from aurweb.models.user import User from aurweb.models.user import User
@ -17,18 +17,20 @@ def setup():
setup_test_db("PackageBlacklist", "PackageBases", "Users") setup_test_db("PackageBlacklist", "PackageBases", "Users")
user = create(User, Username="test", Email="test@example.org", user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword") RealName="Test User", Passwd="testPassword")
pkgbase = create(PackageBase, Name="test-package", Maintainer=user) pkgbase = db.create(PackageBase, Name="test-package", Maintainer=user)
def test_package_blacklist_creation(): def test_package_blacklist_creation():
package_blacklist = create(PackageBlacklist, Name="evil-package") with db.begin():
package_blacklist = db.create(PackageBlacklist, Name="evil-package")
assert bool(package_blacklist.ID) assert bool(package_blacklist.ID)
assert package_blacklist.Name == "evil-package" assert package_blacklist.Name == "evil-package"
def test_package_blacklist_null_name_raises_exception(): def test_package_blacklist_null_name_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageBlacklist) with db.begin():
rollback() db.create(PackageBlacklist)
db.rollback()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create, query, rollback from aurweb.db import begin, create, query, rollback
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.package_base import PackageBase from aurweb.models.package_base import PackageBase
from aurweb.models.package_comment import PackageComment from aurweb.models.package_comment import PackageComment
@ -20,6 +20,7 @@ def setup():
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
with begin():
user = create(User, Username="test", Email="test@example.org", user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword", RealName="Test User", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)
@ -27,6 +28,7 @@ def setup():
def test_package_comment_creation(): def test_package_comment_creation():
with begin():
package_comment = create(PackageComment, package_comment = create(PackageComment,
PackageBase=pkgbase, PackageBase=pkgbase,
User=user, User=user,
@ -37,6 +39,7 @@ def test_package_comment_creation():
def test_package_comment_null_package_base_raises_exception(): def test_package_comment_null_package_base_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with begin():
create(PackageComment, User=user, Comments="Test comment.", create(PackageComment, User=user, Comments="Test comment.",
RenderedComment="Test rendered comment.") RenderedComment="Test rendered comment.")
rollback() rollback()
@ -44,19 +47,23 @@ def test_package_comment_null_package_base_raises_exception():
def test_package_comment_null_user_raises_exception(): def test_package_comment_null_user_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageComment, PackageBase=pkgbase, Comments="Test comment.", with begin():
create(PackageComment, PackageBase=pkgbase,
Comments="Test comment.",
RenderedComment="Test rendered comment.") RenderedComment="Test rendered comment.")
rollback() rollback()
def test_package_comment_null_comments_raises_exception(): def test_package_comment_null_comments_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with begin():
create(PackageComment, PackageBase=pkgbase, User=user, create(PackageComment, PackageBase=pkgbase, User=user,
RenderedComment="Test rendered comment.") RenderedComment="Test rendered comment.")
rollback() rollback()
def test_package_comment_null_renderedcomment_defaults(): def test_package_comment_null_renderedcomment_defaults():
with begin():
record = create(PackageComment, record = create(PackageComment,
PackageBase=pkgbase, PackageBase=pkgbase,
User=user, User=user,

View file

@ -2,7 +2,8 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import commit, create, query from aurweb import db
from aurweb.db import create, query
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.dependency_type import DependencyType from aurweb.models.dependency_type import DependencyType
from aurweb.models.package import Package from aurweb.models.package import Package
@ -22,10 +23,11 @@ def setup():
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
with db.begin():
user = create(User, Username="test", Email="test@example.org", user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword", RealName="Test User", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)
pkgbase = create(PackageBase, pkgbase = create(PackageBase,
Name="test-package", Name="test-package",
Maintainer=user) Maintainer=user)
@ -38,6 +40,8 @@ def setup():
def test_package_dependencies(): def test_package_dependencies():
depends = query(DependencyType, DependencyType.Name == "depends").first() depends = query(DependencyType, DependencyType.Name == "depends").first()
with db.begin():
pkgdep = create(PackageDependency, Package=package, pkgdep = create(PackageDependency, Package=package,
DependencyType=depends, DependencyType=depends,
DepName="test-dep") DepName="test-dep")
@ -49,8 +53,8 @@ def test_package_dependencies():
makedepends = query(DependencyType, makedepends = query(DependencyType,
DependencyType.Name == "makedepends").first() DependencyType.Name == "makedepends").first()
with db.begin():
pkgdep.DependencyType = makedepends pkgdep.DependencyType = makedepends
commit()
assert pkgdep.DepName == "test-dep" assert pkgdep.DepName == "test-dep"
assert pkgdep.Package == package assert pkgdep.Package == package
assert pkgdep.DependencyType == makedepends assert pkgdep.DependencyType == makedepends
@ -59,8 +63,8 @@ def test_package_dependencies():
checkdepends = query(DependencyType, checkdepends = query(DependencyType,
DependencyType.Name == "checkdepends").first() DependencyType.Name == "checkdepends").first()
with db.begin():
pkgdep.DependencyType = checkdepends pkgdep.DependencyType = checkdepends
commit()
assert pkgdep.DepName == "test-dep" assert pkgdep.DepName == "test-dep"
assert pkgdep.Package == package assert pkgdep.Package == package
assert pkgdep.DependencyType == checkdepends assert pkgdep.DependencyType == checkdepends
@ -69,8 +73,8 @@ def test_package_dependencies():
optdepends = query(DependencyType, optdepends = query(DependencyType,
DependencyType.Name == "optdepends").first() DependencyType.Name == "optdepends").first()
with db.begin():
pkgdep.DependencyType = optdepends pkgdep.DependencyType = optdepends
commit()
assert pkgdep.DepName == "test-dep" assert pkgdep.DepName == "test-dep"
assert pkgdep.Package == package assert pkgdep.Package == package
assert pkgdep.DependencyType == optdepends assert pkgdep.DependencyType == optdepends
@ -79,6 +83,7 @@ def test_package_dependencies():
assert not pkgdep.is_package() assert not pkgdep.is_package()
with db.begin():
base = create(PackageBase, Name=pkgdep.DepName, Maintainer=user) base = create(PackageBase, Name=pkgdep.DepName, Maintainer=user)
create(Package, PackageBase=base, Name=pkgdep.DepName) create(Package, PackageBase=base, Name=pkgdep.DepName)
@ -86,32 +91,29 @@ def test_package_dependencies():
def test_package_dependencies_null_package_raises_exception(): def test_package_dependencies_null_package_raises_exception():
from aurweb.db import session
depends = query(DependencyType, DependencyType.Name == "depends").first() depends = query(DependencyType, DependencyType.Name == "depends").first()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(PackageDependency, create(PackageDependency,
DependencyType=depends, DependencyType=depends,
DepName="test-dep") DepName="test-dep")
session.rollback() db.rollback()
def test_package_dependencies_null_dependency_type_raises_exception(): def test_package_dependencies_null_dependency_type_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(PackageDependency, create(PackageDependency,
Package=package, Package=package,
DepName="test-dep") DepName="test-dep")
session.rollback() db.rollback()
def test_package_dependencies_null_depname_raises_exception(): def test_package_dependencies_null_depname_raises_exception():
from aurweb.db import session
depends = query(DependencyType, DependencyType.Name == "depends").first() depends = query(DependencyType, DependencyType.Name == "depends").first()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(PackageDependency, create(PackageDependency,
Package=package, Package=package,
DependencyType=depends) DependencyType=depends)
session.rollback() db.rollback()

View file

@ -2,7 +2,8 @@ import pytest
from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.exc import IntegrityError, OperationalError
from aurweb.db import commit, create, query from aurweb import db
from aurweb.db import create, query
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.package import Package from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase from aurweb.models.package_base import PackageBase
@ -22,10 +23,11 @@ def setup():
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
with db.begin():
user = create(User, Username="test", Email="test@example.org", user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword", RealName="Test User", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)
pkgbase = create(PackageBase, pkgbase = create(PackageBase,
Name="test-package", Name="test-package",
Maintainer=user) Maintainer=user)
@ -38,6 +40,8 @@ def setup():
def test_package_relation(): def test_package_relation():
conflicts = query(RelationType, RelationType.Name == "conflicts").first() conflicts = query(RelationType, RelationType.Name == "conflicts").first()
with db.begin():
pkgrel = create(PackageRelation, Package=package, pkgrel = create(PackageRelation, Package=package,
RelationType=conflicts, RelationType=conflicts,
RelName="test-relation") RelName="test-relation")
@ -48,8 +52,8 @@ def test_package_relation():
assert pkgrel in package.package_relations assert pkgrel in package.package_relations
provides = query(RelationType, RelationType.Name == "provides").first() provides = query(RelationType, RelationType.Name == "provides").first()
with db.begin():
pkgrel.RelationType = provides pkgrel.RelationType = provides
commit()
assert pkgrel.RelName == "test-relation" assert pkgrel.RelName == "test-relation"
assert pkgrel.Package == package assert pkgrel.Package == package
assert pkgrel.RelationType == provides assert pkgrel.RelationType == provides
@ -57,8 +61,8 @@ def test_package_relation():
assert pkgrel in package.package_relations assert pkgrel in package.package_relations
replaces = query(RelationType, RelationType.Name == "replaces").first() replaces = query(RelationType, RelationType.Name == "replaces").first()
with db.begin():
pkgrel.RelationType = replaces pkgrel.RelationType = replaces
commit()
assert pkgrel.RelName == "test-relation" assert pkgrel.RelName == "test-relation"
assert pkgrel.Package == package assert pkgrel.Package == package
assert pkgrel.RelationType == replaces assert pkgrel.RelationType == replaces
@ -67,36 +71,33 @@ def test_package_relation():
def test_package_relation_null_package_raises_exception(): def test_package_relation_null_package_raises_exception():
from aurweb.db import session
conflicts = query(RelationType, RelationType.Name == "conflicts").first() conflicts = query(RelationType, RelationType.Name == "conflicts").first()
assert conflicts is not None assert conflicts is not None
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(PackageRelation, create(PackageRelation,
RelationType=conflicts, RelationType=conflicts,
RelName="test-relation") RelName="test-relation")
session.rollback() db.rollback()
def test_package_relation_null_relation_type_raises_exception(): def test_package_relation_null_relation_type_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(PackageRelation, create(PackageRelation,
Package=package, Package=package,
RelName="test-relation") RelName="test-relation")
session.rollback() db.rollback()
def test_package_relation_null_relname_raises_exception(): def test_package_relation_null_relname_raises_exception():
from aurweb.db import session
depends = query(RelationType, RelationType.Name == "conflicts").first() depends = query(RelationType, RelationType.Name == "conflicts").first()
assert depends is not None assert depends is not None
with pytest.raises((OperationalError, IntegrityError)): with pytest.raises((OperationalError, IntegrityError)):
with db.begin():
create(PackageRelation, create(PackageRelation,
Package=package, Package=package,
RelationType=depends) RelationType=depends)
session.rollback() db.rollback()

View file

@ -4,7 +4,8 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import commit, create, query, rollback from aurweb import db
from aurweb.db import create, query, rollback
from aurweb.models.package_base import PackageBase from aurweb.models.package_base import PackageBase
from aurweb.models.package_request import (ACCEPTED, ACCEPTED_ID, CLOSED, CLOSED_ID, PENDING, PENDING_ID, REJECTED, from aurweb.models.package_request import (ACCEPTED, ACCEPTED_ID, CLOSED, CLOSED_ID, PENDING, PENDING_ID, REJECTED,
REJECTED_ID, PackageRequest) REJECTED_ID, PackageRequest)
@ -21,6 +22,7 @@ def setup():
setup_test_db("PackageRequests", "PackageBases", "Users") setup_test_db("PackageRequests", "PackageBases", "Users")
with db.begin():
user = create(User, Username="test", Email="test@example.org", user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword") RealName="Test User", Passwd="testPassword")
pkgbase = create(PackageBase, Name="test-package", Maintainer=user) pkgbase = create(PackageBase, Name="test-package", Maintainer=user)
@ -30,6 +32,7 @@ def test_package_request_creation():
request_type = query(RequestType, RequestType.Name == "merge").first() request_type = query(RequestType, RequestType.Name == "merge").first()
assert request_type.Name == "merge" assert request_type.Name == "merge"
with db.begin():
package_request = create(PackageRequest, RequestType=request_type, package_request = create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase, User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name, PackageBaseName=pkgbase.Name,
@ -54,6 +57,7 @@ def test_package_request_closed():
assert request_type.Name == "merge" assert request_type.Name == "merge"
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
with db.begin():
package_request = create(PackageRequest, RequestType=request_type, package_request = create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase, User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name, PackageBaseName=pkgbase.Name,
@ -69,6 +73,7 @@ def test_package_request_closed():
def test_package_request_null_request_type_raises_exception(): def test_package_request_null_request_type_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(PackageRequest, User=user, PackageBase=pkgbase, create(PackageRequest, User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name, PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str()) Comments=str(), ClosureComment=str())
@ -78,8 +83,9 @@ def test_package_request_null_request_type_raises_exception():
def test_package_request_null_user_raises_exception(): def test_package_request_null_user_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first() request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type, PackageBase=pkgbase, with db.begin():
PackageBaseName=pkgbase.Name, create(PackageRequest, RequestType=request_type,
PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str()) Comments=str(), ClosureComment=str())
rollback() rollback()
@ -87,6 +93,7 @@ def test_package_request_null_user_raises_exception():
def test_package_request_null_package_base_raises_exception(): def test_package_request_null_package_base_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first() request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(PackageRequest, RequestType=request_type, create(PackageRequest, RequestType=request_type,
User=user, PackageBaseName=pkgbase.Name, User=user, PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str()) Comments=str(), ClosureComment=str())
@ -96,6 +103,7 @@ def test_package_request_null_package_base_raises_exception():
def test_package_request_null_package_base_name_raises_exception(): def test_package_request_null_package_base_name_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first() request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(PackageRequest, RequestType=request_type, create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase, User=user, PackageBase=pkgbase,
Comments=str(), ClosureComment=str()) Comments=str(), ClosureComment=str())
@ -105,8 +113,9 @@ def test_package_request_null_package_base_name_raises_exception():
def test_package_request_null_comments_raises_exception(): def test_package_request_null_comments_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first() request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type, with db.begin():
User=user, PackageBase=pkgbase, PackageBaseName=pkgbase.Name, create(PackageRequest, RequestType=request_type, User=user,
PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
ClosureComment=str()) ClosureComment=str())
rollback() rollback()
@ -114,8 +123,9 @@ def test_package_request_null_comments_raises_exception():
def test_package_request_null_closure_comment_raises_exception(): def test_package_request_null_closure_comment_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first() request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type, with db.begin():
User=user, PackageBase=pkgbase, PackageBaseName=pkgbase.Name, create(PackageRequest, RequestType=request_type, User=user,
PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
Comments=str()) Comments=str())
rollback() rollback()
@ -124,6 +134,7 @@ def test_package_request_status_display():
""" Test status_display() based on the Status column value. """ """ Test status_display() based on the Status column value. """
request_type = query(RequestType, RequestType.Name == "merge").first() request_type = query(RequestType, RequestType.Name == "merge").first()
with db.begin():
pkgreq = create(PackageRequest, RequestType=request_type, pkgreq = create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase, User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name, PackageBaseName=pkgbase.Name,
@ -131,19 +142,19 @@ def test_package_request_status_display():
Status=PENDING_ID) Status=PENDING_ID)
assert pkgreq.status_display() == PENDING assert pkgreq.status_display() == PENDING
with db.begin():
pkgreq.Status = CLOSED_ID pkgreq.Status = CLOSED_ID
commit()
assert pkgreq.status_display() == CLOSED assert pkgreq.status_display() == CLOSED
with db.begin():
pkgreq.Status = ACCEPTED_ID pkgreq.Status = ACCEPTED_ID
commit()
assert pkgreq.status_display() == ACCEPTED assert pkgreq.status_display() == ACCEPTED
with db.begin():
pkgreq.Status = REJECTED_ID pkgreq.Status = REJECTED_ID
commit()
assert pkgreq.status_display() == REJECTED assert pkgreq.status_display() == REJECTED
with db.begin():
pkgreq.Status = 124 pkgreq.Status = 124
commit()
with pytest.raises(KeyError): with pytest.raises(KeyError):
pkgreq.status_display() pkgreq.status_display()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create, query, rollback from aurweb.db import begin, create, query, rollback
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.package import Package from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase from aurweb.models.package_base import PackageBase
@ -21,6 +21,7 @@ def setup():
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
with begin():
user = create(User, Username="test", Email="test@example.org", user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword", RealName="Test User", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)
@ -31,6 +32,7 @@ def setup():
def test_package_source(): def test_package_source():
with begin():
pkgsource = create(PackageSource, Package=package) pkgsource = create(PackageSource, Package=package)
assert pkgsource.Package == package assert pkgsource.Package == package
# By default, PackageSources.Source assigns the string '/dev/null'. # By default, PackageSources.Source assigns the string '/dev/null'.
@ -40,5 +42,6 @@ def test_package_source():
def test_package_source_null_package_raises_exception(): def test_package_source_null_package_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with begin():
create(PackageSource) create(PackageSource)
rollback() rollback()

View file

@ -28,31 +28,25 @@ def package_endpoint(package: Package) -> str:
return f"/packages/{package.Name}" return f"/packages/{package.Name}"
def create_package(pkgname: str, maintainer: User, def create_package(pkgname: str, maintainer: User) -> Package:
autocommit: bool = True) -> Package:
pkgbase = db.create(PackageBase, pkgbase = db.create(PackageBase,
Name=pkgname, Name=pkgname,
Maintainer=maintainer, Maintainer=maintainer)
autocommit=False) return db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase)
return db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase,
autocommit=autocommit)
def create_package_dep(package: Package, depname: str, def create_package_dep(package: Package, depname: str,
dep_type_name: str = "depends", dep_type_name: str = "depends") -> PackageDependency:
autocommit: bool = True) -> PackageDependency:
dep_type = db.query(DependencyType, dep_type = db.query(DependencyType,
DependencyType.Name == dep_type_name).first() DependencyType.Name == dep_type_name).first()
return db.create(PackageDependency, return db.create(PackageDependency,
DependencyType=dep_type, DependencyType=dep_type,
Package=package, Package=package,
DepName=depname, DepName=depname)
autocommit=autocommit)
def create_package_rel(package: Package, def create_package_rel(package: Package,
relname: str, relname: str) -> PackageRelation:
autocommit: bool = True) -> PackageRelation:
rel_type = db.query(RelationType, rel_type = db.query(RelationType,
RelationType.ID == PROVIDES_ID).first() RelationType.ID == PROVIDES_ID).first()
return db.create(PackageRelation, return db.create(PackageRelation,
@ -84,31 +78,37 @@ def client() -> TestClient:
def user() -> User: def user() -> User:
""" Yield a user. """ """ Yield a user. """
account_type = db.query(AccountType, AccountType.ID == USER_ID).first() account_type = db.query(AccountType, AccountType.ID == USER_ID).first()
yield db.create(User, Username="test", with db.begin():
user = db.create(User, Username="test",
Email="test@example.org", Email="test@example.org",
Passwd="testPassword", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)
yield user
@pytest.fixture @pytest.fixture
def maintainer() -> User: def maintainer() -> User:
""" Yield a specific User used to maintain packages. """ """ Yield a specific User used to maintain packages. """
account_type = db.query(AccountType, AccountType.ID == USER_ID).first() account_type = db.query(AccountType, AccountType.ID == USER_ID).first()
yield db.create(User, Username="test_maintainer", with db.begin():
maintainer = db.create(User, Username="test_maintainer",
Email="test_maintainer@example.org", Email="test_maintainer@example.org",
Passwd="testPassword", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)
yield maintainer
@pytest.fixture @pytest.fixture
def package(maintainer: User) -> Package: def package(maintainer: User) -> Package:
""" Yield a Package created by user. """ """ Yield a Package created by user. """
with db.begin():
pkgbase = db.create(PackageBase, pkgbase = db.create(PackageBase,
Name="test-package", Name="test-package",
Maintainer=maintainer) Maintainer=maintainer)
yield db.create(Package, package = db.create(Package,
PackageBase=pkgbase, PackageBase=pkgbase,
Name=pkgbase.Name) Name=pkgbase.Name)
yield package
def test_package_not_found(client: TestClient): def test_package_not_found(client: TestClient):
@ -121,6 +121,7 @@ def test_package_official_not_found(client: TestClient, package: Package):
""" When a Package has a matching OfficialProvider record, it is not """ When a Package has a matching OfficialProvider record, it is not
hosted on AUR, but in the official repositories. Getting a package hosted on AUR, but in the official repositories. Getting a package
with this kind of record should return a status code 404. """ with this kind of record should return a status code 404. """
with db.begin():
db.create(OfficialProvider, db.create(OfficialProvider,
Name=package.Name, Name=package.Name,
Repo="core", Repo="core",
@ -157,6 +158,7 @@ def test_package(client: TestClient, package: Package):
def test_package_comments(client: TestClient, user: User, package: Package): def test_package_comments(client: TestClient, user: User, package: Package):
now = (datetime.utcnow().timestamp()) now = (datetime.utcnow().timestamp())
with db.begin():
comment = db.create(PackageComment, PackageBase=package.PackageBase, comment = db.create(PackageComment, PackageBase=package.PackageBase,
User=user, Comments="Test comment", CommentTS=now) User=user, Comments="Test comment", CommentTS=now)
@ -178,6 +180,7 @@ def test_package_comments(client: TestClient, user: User, package: Package):
def test_package_requests_display(client: TestClient, user: User, def test_package_requests_display(client: TestClient, user: User,
package: Package): package: Package):
type_ = db.query(RequestType, RequestType.ID == DELETION_ID).first() type_ = db.query(RequestType, RequestType.ID == DELETION_ID).first()
with db.begin():
db.create(PackageRequest, PackageBase=package.PackageBase, db.create(PackageRequest, PackageBase=package.PackageBase,
PackageBaseName=package.PackageBase.Name, PackageBaseName=package.PackageBase.Name,
User=user, RequestType=type_, User=user, RequestType=type_,
@ -195,6 +198,7 @@ def test_package_requests_display(client: TestClient, user: User,
assert target.text.strip() == "1 pending request" assert target.text.strip() == "1 pending request"
type_ = db.query(RequestType, RequestType.ID == DELETION_ID).first() type_ = db.query(RequestType, RequestType.ID == DELETION_ID).first()
with db.begin():
db.create(PackageRequest, PackageBase=package.PackageBase, db.create(PackageRequest, PackageBase=package.PackageBase,
PackageBaseName=package.PackageBase.Name, PackageBaseName=package.PackageBase.Name,
User=user, RequestType=type_, User=user, RequestType=type_,
@ -271,49 +275,42 @@ def test_package_authenticated_maintainer(client: TestClient,
def test_package_dependencies(client: TestClient, maintainer: User, def test_package_dependencies(client: TestClient, maintainer: User,
package: Package): package: Package):
# Create a normal dependency of type depends. # Create a normal dependency of type depends.
dep_pkg = create_package("test-dep-1", maintainer, autocommit=False) with db.begin():
dep = create_package_dep(package, dep_pkg.Name, autocommit=False) dep_pkg = create_package("test-dep-1", maintainer)
dep = create_package_dep(package, dep_pkg.Name)
dep.DepArch = "x86_64" dep.DepArch = "x86_64"
# Also, create a makedepends. # Also, create a makedepends.
make_dep_pkg = create_package("test-dep-2", maintainer, autocommit=False) make_dep_pkg = create_package("test-dep-2", maintainer)
make_dep = create_package_dep(package, make_dep_pkg.Name, make_dep = create_package_dep(package, make_dep_pkg.Name,
dep_type_name="makedepends", dep_type_name="makedepends")
autocommit=False)
# And... a checkdepends! # And... a checkdepends!
check_dep_pkg = create_package("test-dep-3", maintainer, autocommit=False) check_dep_pkg = create_package("test-dep-3", maintainer)
check_dep = create_package_dep(package, check_dep_pkg.Name, check_dep = create_package_dep(package, check_dep_pkg.Name,
dep_type_name="checkdepends", dep_type_name="checkdepends")
autocommit=False)
# Geez. Just stop. This is optdepends. # Geez. Just stop. This is optdepends.
opt_dep_pkg = create_package("test-dep-4", maintainer, autocommit=False) opt_dep_pkg = create_package("test-dep-4", maintainer)
opt_dep = create_package_dep(package, opt_dep_pkg.Name, opt_dep = create_package_dep(package, opt_dep_pkg.Name,
dep_type_name="optdepends", dep_type_name="optdepends")
autocommit=False)
# Heh. Another optdepends to test one with a description. # Heh. Another optdepends to test one with a description.
opt_desc_dep_pkg = create_package("test-dep-5", maintainer, opt_desc_dep_pkg = create_package("test-dep-5", maintainer)
autocommit=False)
opt_desc_dep = create_package_dep(package, opt_desc_dep_pkg.Name, opt_desc_dep = create_package_dep(package, opt_desc_dep_pkg.Name,
dep_type_name="optdepends", dep_type_name="optdepends")
autocommit=False)
opt_desc_dep.DepDesc = "Test description." opt_desc_dep.DepDesc = "Test description."
broken_dep = create_package_dep(package, "test-dep-6", broken_dep = create_package_dep(package, "test-dep-6",
dep_type_name="depends", dep_type_name="depends")
autocommit=False)
# Create an official provider record. # Create an official provider record.
db.create(OfficialProvider, Name="test-dep-99", db.create(OfficialProvider, Name="test-dep-99",
Repo="core", Provides="test-dep-99", Repo="core", Provides="test-dep-99")
autocommit=False) official_dep = create_package_dep(package, "test-dep-99")
official_dep = create_package_dep(package, "test-dep-99",
autocommit=False)
# Also, create a provider who provides our test-dep-99. # Also, create a provider who provides our test-dep-99.
provider = create_package("test-provider", maintainer, autocommit=False) provider = create_package("test-provider", maintainer)
create_package_rel(provider, dep.DepName) create_package_rel(provider, dep.DepName)
with client as request: with client as request:
@ -358,6 +355,7 @@ def test_pkgbase_redirect(client: TestClient, package: Package):
def test_pkgbase(client: TestClient, package: Package): def test_pkgbase(client: TestClient, package: Package):
with db.begin():
second = db.create(Package, Name="second-pkg", second = db.create(Package, Name="second-pkg",
PackageBase=package.PackageBase) PackageBase=package.PackageBase)

View file

@ -26,17 +26,21 @@ def setup():
@pytest.fixture @pytest.fixture
def maintainer() -> User: def maintainer() -> User:
account_type = db.query(AccountType, AccountType.ID == USER_ID).first() account_type = db.query(AccountType, AccountType.ID == USER_ID).first()
yield db.create(User, Username="test_maintainer", with db.begin():
maintainer = db.create(User, Username="test_maintainer",
Email="test_maintainer@examepl.org", Email="test_maintainer@examepl.org",
Passwd="testPassword", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)
yield maintainer
@pytest.fixture @pytest.fixture
def package(maintainer: User) -> Package: def package(maintainer: User) -> Package:
with db.begin():
pkgbase = db.create(PackageBase, Name="test-pkg", pkgbase = db.create(PackageBase, Name="test-pkg",
Packager=maintainer, Maintainer=maintainer) Packager=maintainer, Maintainer=maintainer)
yield db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase) package = db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase)
yield package
@pytest.fixture @pytest.fixture
@ -45,6 +49,7 @@ def client() -> TestClient:
def test_package_link(client: TestClient, maintainer: User, package: Package): def test_package_link(client: TestClient, maintainer: User, package: Package):
with db.begin():
db.create(OfficialProvider, db.create(OfficialProvider,
Name=package.Name, Name=package.Name,
Repo="core", Repo="core",

View file

@ -1,6 +1,6 @@
import pytest import pytest
from aurweb.db import create, delete, query from aurweb import db
from aurweb.models.relation_type import RelationType from aurweb.models.relation_type import RelationType
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -11,22 +11,25 @@ def setup():
def test_relation_type_creation(): def test_relation_type_creation():
relation_type = create(RelationType, Name="test-relation") with db.begin():
relation_type = db.create(RelationType, Name="test-relation")
assert bool(relation_type.ID) assert bool(relation_type.ID)
assert relation_type.Name == "test-relation" assert relation_type.Name == "test-relation"
delete(RelationType, RelationType.ID == relation_type.ID) with db.begin():
db.delete(RelationType, RelationType.ID == relation_type.ID)
def test_relation_types(): def test_relation_types():
conflicts = query(RelationType, RelationType.Name == "conflicts").first() conflicts = db.query(RelationType, RelationType.Name == "conflicts").first()
assert conflicts is not None assert conflicts is not None
assert conflicts.Name == "conflicts" assert conflicts.Name == "conflicts"
provides = query(RelationType, RelationType.Name == "provides").first() provides = db.query(RelationType, RelationType.Name == "provides").first()
assert provides is not None assert provides is not None
assert provides.Name == "provides" assert provides.Name == "provides"
replaces = query(RelationType, RelationType.Name == "replaces").first() replaces = db.query(RelationType, RelationType.Name == "replaces").first()
assert replaces is not None assert replaces is not None
assert replaces.Name == "replaces" assert replaces.Name == "replaces"

View file

@ -1,6 +1,6 @@
import pytest import pytest
from aurweb.db import create, delete, query from aurweb import db
from aurweb.models.request_type import DELETION_ID, MERGE_ID, ORPHAN_ID, RequestType from aurweb.models.request_type import DELETION_ID, MERGE_ID, ORPHAN_ID, RequestType
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -11,25 +11,33 @@ def setup():
def test_request_type_creation(): def test_request_type_creation():
request_type = create(RequestType, Name="Test Request") with db.begin():
request_type = db.create(RequestType, Name="Test Request")
assert bool(request_type.ID) assert bool(request_type.ID)
assert request_type.Name == "Test Request" assert request_type.Name == "Test Request"
delete(RequestType, RequestType.ID == request_type.ID)
with db.begin():
db.delete(RequestType, RequestType.ID == request_type.ID)
def test_request_type_null_name_returns_empty_string(): def test_request_type_null_name_returns_empty_string():
request_type = create(RequestType) with db.begin():
request_type = db.create(RequestType)
assert bool(request_type.ID) assert bool(request_type.ID)
assert request_type.Name == str() assert request_type.Name == str()
delete(RequestType, RequestType.ID == request_type.ID)
with db.begin():
db.delete(RequestType, RequestType.ID == request_type.ID)
def test_request_type_name_display(): def test_request_type_name_display():
deletion = query(RequestType, RequestType.ID == DELETION_ID).first() deletion = db.query(RequestType, RequestType.ID == DELETION_ID).first()
assert deletion.name_display() == "Deletion" assert deletion.name_display() == "Deletion"
orphan = query(RequestType, RequestType.ID == ORPHAN_ID).first() orphan = db.query(RequestType, RequestType.ID == ORPHAN_ID).first()
assert orphan.name_display() == "Orphan" assert orphan.name_display() == "Orphan"
merge = query(RequestType, RequestType.ID == MERGE_ID).first() merge = db.query(RequestType, RequestType.ID == MERGE_ID).first()
assert merge.name_display() == "Merge" assert merge.name_display() == "Merge"

View file

@ -8,8 +8,8 @@ import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from aurweb import db
from aurweb.asgi import app from aurweb.asgi import app
from aurweb.db import create, query
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.user import User from aurweb.models.user import User
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -24,9 +24,11 @@ def setup():
setup_test_db("Users", "Sessions") setup_test_db("Users", "Sessions")
account_type = query(AccountType, account_type = db.query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword", RealName="Test User", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)

View file

@ -49,14 +49,13 @@ def packages(user):
now = int(datetime.utcnow().timestamp()) now = int(datetime.utcnow().timestamp())
# Create 101 packages; we limit 100 on RSS feeds. # Create 101 packages; we limit 100 on RSS feeds.
with db.begin():
for i in range(101): for i in range(101):
pkgbase = db.create( pkgbase = db.create(
PackageBase, Maintainer=user, Name=f"test-package-{i}", PackageBase, Maintainer=user, Name=f"test-package-{i}",
SubmittedTS=(now + i), ModifiedTS=(now + i), autocommit=False) SubmittedTS=(now + i), ModifiedTS=(now + i))
pkg = db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase, pkg = db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase)
autocommit=False)
pkgs.append(pkg) pkgs.append(pkg)
db.commit()
yield pkgs yield pkgs

View file

@ -4,7 +4,7 @@ from unittest import mock
import pytest import pytest
from aurweb.db import create, query from aurweb import db
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.session import Session, generate_unique_sid from aurweb.models.session import Session, generate_unique_sid
from aurweb.models.user import User from aurweb.models.user import User
@ -19,12 +19,15 @@ def setup():
setup_test_db("Users", "Sessions") setup_test_db("Users", "Sessions")
account_type = query(AccountType, account_type = db.query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org", with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
ResetKey="testReset", Passwd="testPassword", ResetKey="testReset", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)
session = create(Session, UsersID=user.ID, SessionID="testSession",
with db.begin():
session = db.create(Session, UsersID=user.ID, SessionID="testSession",
LastUpdateTS=datetime.utcnow().timestamp()) LastUpdateTS=datetime.utcnow().timestamp())
@ -35,10 +38,13 @@ def test_session():
def test_session_cs(): def test_session_cs():
""" Test case sensitivity of the database table. """ """ Test case sensitivity of the database table. """
user2 = create(User, Username="test2", Email="test2@example.org", with db.begin():
user2 = db.create(User, Username="test2", Email="test2@example.org",
ResetKey="testReset2", Passwd="testPassword", ResetKey="testReset2", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)
session_cs = create(Session, UsersID=user2.ID,
with db.begin():
session_cs = db.create(Session, UsersID=user2.ID,
SessionID="TESTSESSION", SessionID="TESTSESSION",
LastUpdateTS=datetime.utcnow().timestamp()) LastUpdateTS=datetime.utcnow().timestamp())
assert session_cs.SessionID == "TESTSESSION" assert session_cs.SessionID == "TESTSESSION"

View file

@ -1,6 +1,6 @@
import pytest import pytest
from aurweb.db import create, query from aurweb import db
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint
from aurweb.models.user import User from aurweb.models.user import User
@ -19,16 +19,15 @@ def setup():
setup_test_db("Users", "SSHPubKeys") setup_test_db("Users", "SSHPubKeys")
account_type = query(AccountType, account_type = db.query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org", with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword", RealName="Test User", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)
assert account_type == user.AccountType with db.begin():
assert account_type.ID == user.AccountTypeID ssh_pub_key = db.create(SSHPubKey,
ssh_pub_key = create(SSHPubKey,
UserID=user.ID, UserID=user.ID,
Fingerprint="testFingerprint", Fingerprint="testFingerprint",
PubKey="testPubKey") PubKey="testPubKey")
@ -43,7 +42,8 @@ def test_ssh_pub_key():
def test_ssh_pub_key_cs(): def test_ssh_pub_key_cs():
""" Test case sensitivity of the database table. """ """ Test case sensitivity of the database table. """
ssh_pub_key_cs = create(SSHPubKey, UserID=user.ID, with db.begin():
ssh_pub_key_cs = db.create(SSHPubKey, UserID=user.ID,
Fingerprint="TESTFINGERPRINT", Fingerprint="TESTFINGERPRINT",
PubKey="TESTPUBKEY") PubKey="TESTPUBKEY")

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create from aurweb import db
from aurweb.models.term import Term from aurweb.models.term import Term
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -18,7 +18,8 @@ def setup():
def test_term_creation(): def test_term_creation():
term = create(Term, Description="Term description", with db.begin():
term = db.create(Term, Description="Term description",
URL="https://fake_url.io") URL="https://fake_url.io")
assert bool(term.ID) assert bool(term.ID)
assert term.Description == "Term description" assert term.Description == "Term description"
@ -27,14 +28,14 @@ def test_term_creation():
def test_term_null_description_raises_exception(): def test_term_null_description_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(Term, URL="https://fake_url.io") with db.begin():
session.rollback() db.create(Term, URL="https://fake_url.io")
db.rollback()
def test_term_null_url_raises_exception(): def test_term_null_url_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(Term, Description="Term description") with db.begin():
session.rollback() db.create(Term, Description="Term description")
db.rollback()

View file

@ -90,33 +90,33 @@ def client():
def tu_user(): def tu_user():
tu_type = db.query(AccountType, tu_type = db.query(AccountType,
AccountType.AccountType == "Trusted User").first() AccountType.AccountType == "Trusted User").first()
yield db.create(User, Username="test_tu", Email="test_tu@example.org", with db.begin():
tu_user = db.create(User, Username="test_tu",
Email="test_tu@example.org",
RealName="Test TU", Passwd="testPassword", RealName="Test TU", Passwd="testPassword",
AccountType=tu_type) AccountType=tu_type)
yield tu_user
@pytest.fixture @pytest.fixture
def user(): def user():
user_type = db.query(AccountType, user_type = db.query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
yield db.create(User, Username="test", Email="test@example.org", with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword", RealName="Test User", Passwd="testPassword",
AccountType=user_type) AccountType=user_type)
yield user
@pytest.fixture @pytest.fixture
def proposal(tu_user): def proposal(user, tu_user):
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
agenda = "Test proposal." agenda = "Test proposal."
start = ts - 5 start = ts - 5
end = ts + 1000 end = ts + 1000
user_type = db.query(AccountType, with db.begin():
AccountType.AccountType == "User").first()
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=user_type)
voteinfo = db.create(TUVoteInfo, voteinfo = db.create(TUVoteInfo,
Agenda=agenda, Quorum=0.0, Agenda=agenda, Quorum=0.0,
User=user.Username, Submitter=tu_user, User=user.Username, Submitter=tu_user,
@ -170,6 +170,7 @@ def test_tu_index(client, tu_user):
("Test agenda 2", ts - 1000, ts - 5) # Not running anymore. ("Test agenda 2", ts - 1000, ts - 5) # Not running anymore.
] ]
vote_records = [] vote_records = []
with db.begin():
for vote in votes: for vote in votes:
agenda, start, end = vote agenda, start, end = vote
vote_records.append( vote_records.append(
@ -179,6 +180,7 @@ def test_tu_index(client, tu_user):
Quorum=0.0, Quorum=0.0,
Submitter=tu_user)) Submitter=tu_user))
with db.begin():
# Vote on an ended proposal. # Vote on an ended proposal.
vote_record = vote_records[1] vote_record = vote_records[1]
vote_record.Yes += 1 vote_record.Yes += 1
@ -255,13 +257,14 @@ def test_tu_index(client, tu_user):
def test_tu_index_table_paging(client, tu_user): def test_tu_index_table_paging(client, tu_user):
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
with db.begin():
for i in range(25): for i in range(25):
# Create 25 current votes. # Create 25 current votes.
db.create(TUVoteInfo, Agenda=f"Agenda #{i}", db.create(TUVoteInfo, Agenda=f"Agenda #{i}",
User=tu_user.Username, User=tu_user.Username,
Submitted=(ts - 5), End=(ts + 1000), Submitted=(ts - 5), End=(ts + 1000),
Quorum=0.0, Quorum=0.0,
Submitter=tu_user, autocommit=False) Submitter=tu_user)
for i in range(25): for i in range(25):
# Create 25 past votes. # Create 25 past votes.
@ -269,8 +272,7 @@ def test_tu_index_table_paging(client, tu_user):
User=tu_user.Username, User=tu_user.Username,
Submitted=(ts - 1000), End=(ts - 5), Submitted=(ts - 1000), End=(ts - 5),
Quorum=0.0, Quorum=0.0,
Submitter=tu_user, autocommit=False) Submitter=tu_user)
db.commit()
cookies = {"AURSID": tu_user.login(Request(), "testPassword")} cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request: with client as request:
@ -363,13 +365,14 @@ def test_tu_index_table_paging(client, tu_user):
def test_tu_index_sorting(client, tu_user): def test_tu_index_sorting(client, tu_user):
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
with db.begin():
for i in range(2): for i in range(2):
# Create 'Agenda #1' and 'Agenda #2'. # Create 'Agenda #1' and 'Agenda #2'.
db.create(TUVoteInfo, Agenda=f"Agenda #{i + 1}", db.create(TUVoteInfo, Agenda=f"Agenda #{i + 1}",
User=tu_user.Username, User=tu_user.Username,
Submitted=(ts + 5), End=(ts + 1000), Submitted=(ts + 5), End=(ts + 1000),
Quorum=0.0, Quorum=0.0,
Submitter=tu_user, autocommit=False) Submitter=tu_user)
# Let's order each vote one day after the other. # Let's order each vote one day after the other.
# This will allow us to test the sorting nature # This will allow us to test the sorting nature
@ -432,6 +435,7 @@ def test_tu_index_sorting(client, tu_user):
def test_tu_index_last_votes(client, tu_user, user): def test_tu_index_last_votes(client, tu_user, user):
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
with db.begin():
# Create a proposal which has ended. # Create a proposal which has ended.
voteinfo = db.create(TUVoteInfo, Agenda="Test agenda", voteinfo = db.create(TUVoteInfo, Agenda="Test agenda",
User=user.Username, User=user.Username,
@ -529,10 +533,10 @@ def test_tu_running_proposal(client, proposal):
assert abstain.attrib["value"] == "Abstain" assert abstain.attrib["value"] == "Abstain"
# Create a vote. # Create a vote.
with db.begin():
db.create(TUVote, VoteInfo=voteinfo, User=tu_user) db.create(TUVote, VoteInfo=voteinfo, User=tu_user)
voteinfo.ActiveTUs += 1 voteinfo.ActiveTUs += 1
voteinfo.Yes += 1 voteinfo.Yes += 1
db.commit()
# Make another request now that we've voted. # Make another request now that we've voted.
with client as request: with client as request:
@ -556,8 +560,8 @@ def test_tu_ended_proposal(client, proposal):
tu_user, user, voteinfo = proposal tu_user, user, voteinfo = proposal
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
with db.begin():
voteinfo.End = ts - 5 # 5 seconds ago. voteinfo.End = ts - 5 # 5 seconds ago.
db.commit()
# Initiate an authenticated GET request to /tu/{proposal_id}. # Initiate an authenticated GET request to /tu/{proposal_id}.
proposal_id = voteinfo.ID proposal_id = voteinfo.ID
@ -635,8 +639,8 @@ def test_tu_proposal_vote_unauthorized(client, proposal):
dev_type = db.query(AccountType, dev_type = db.query(AccountType,
AccountType.AccountType == "Developer").first() AccountType.AccountType == "Developer").first()
with db.begin():
tu_user.AccountType = dev_type tu_user.AccountType = dev_type
db.commit()
cookies = {"AURSID": tu_user.login(Request(), "testPassword")} cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request: with client as request:
@ -664,8 +668,8 @@ def test_tu_proposal_vote_cant_self_vote(client, proposal):
tu_user, user, voteinfo = proposal tu_user, user, voteinfo = proposal
# Update voteinfo.User. # Update voteinfo.User.
with db.begin():
voteinfo.User = tu_user.Username voteinfo.User = tu_user.Username
db.commit()
cookies = {"AURSID": tu_user.login(Request(), "testPassword")} cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request: with client as request:
@ -692,10 +696,10 @@ def test_tu_proposal_vote_cant_self_vote(client, proposal):
def test_tu_proposal_vote_already_voted(client, proposal): def test_tu_proposal_vote_already_voted(client, proposal):
tu_user, user, voteinfo = proposal tu_user, user, voteinfo = proposal
with db.begin():
db.create(TUVote, VoteInfo=voteinfo, User=tu_user) db.create(TUVote, VoteInfo=voteinfo, User=tu_user)
voteinfo.Yes += 1 voteinfo.Yes += 1
voteinfo.ActiveTUs += 1 voteinfo.ActiveTUs += 1
db.commit()
cookies = {"AURSID": tu_user.login(Request(), "testPassword")} cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request: with client as request:

View file

@ -4,7 +4,8 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import commit, create, query, rollback from aurweb import db
from aurweb.db import create, query, rollback
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.tu_voteinfo import TUVoteInfo from aurweb.models.tu_voteinfo import TUVoteInfo
from aurweb.models.user import User from aurweb.models.user import User
@ -21,6 +22,7 @@ def setup():
tu_type = query(AccountType, tu_type = query(AccountType,
AccountType.AccountType == "Trusted User").first() AccountType.AccountType == "Trusted User").first()
with db.begin():
user = create(User, Username="test", Email="test@example.org", user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword", RealName="Test User", Passwd="testPassword",
AccountType=tu_type) AccountType=tu_type)
@ -28,6 +30,7 @@ def setup():
def test_tu_voteinfo_creation(): def test_tu_voteinfo_creation():
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
with db.begin():
tu_voteinfo = create(TUVoteInfo, tu_voteinfo = create(TUVoteInfo,
Agenda="Blah blah.", Agenda="Blah blah.",
User=user.Username, User=user.Username,
@ -51,6 +54,7 @@ def test_tu_voteinfo_creation():
def test_tu_voteinfo_is_running(): def test_tu_voteinfo_is_running():
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
with db.begin():
tu_voteinfo = create(TUVoteInfo, tu_voteinfo = create(TUVoteInfo,
Agenda="Blah blah.", Agenda="Blah blah.",
User=user.Username, User=user.Username,
@ -59,13 +63,14 @@ def test_tu_voteinfo_is_running():
Submitter=user) Submitter=user)
assert tu_voteinfo.is_running() is True assert tu_voteinfo.is_running() is True
with db.begin():
tu_voteinfo.End = ts - 5 tu_voteinfo.End = ts - 5
commit()
assert tu_voteinfo.is_running() is False assert tu_voteinfo.is_running() is False
def test_tu_voteinfo_total_votes(): def test_tu_voteinfo_total_votes():
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
with db.begin():
tu_voteinfo = create(TUVoteInfo, tu_voteinfo = create(TUVoteInfo,
Agenda="Blah blah.", Agenda="Blah blah.",
User=user.Username, User=user.Username,
@ -76,7 +81,6 @@ def test_tu_voteinfo_total_votes():
tu_voteinfo.Yes = 1 tu_voteinfo.Yes = 1
tu_voteinfo.No = 3 tu_voteinfo.No = 3
tu_voteinfo.Abstain = 5 tu_voteinfo.Abstain = 5
commit()
# total_votes() should be the sum of Yes, No and Abstain: 1 + 3 + 5 = 9. # total_votes() should be the sum of Yes, No and Abstain: 1 + 3 + 5 = 9.
assert tu_voteinfo.total_votes() == 9 assert tu_voteinfo.total_votes() == 9
@ -84,6 +88,7 @@ def test_tu_voteinfo_total_votes():
def test_tu_voteinfo_null_submitter_raises_exception(): def test_tu_voteinfo_null_submitter_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(TUVoteInfo, create(TUVoteInfo,
Agenda="Blah blah.", Agenda="Blah blah.",
User=user.Username, User=user.Username,
@ -94,6 +99,7 @@ def test_tu_voteinfo_null_submitter_raises_exception():
def test_tu_voteinfo_null_agenda_raises_exception(): def test_tu_voteinfo_null_agenda_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(TUVoteInfo, create(TUVoteInfo,
User=user.Username, User=user.Username,
Submitted=0, End=0, Submitted=0, End=0,
@ -104,6 +110,7 @@ def test_tu_voteinfo_null_agenda_raises_exception():
def test_tu_voteinfo_null_user_raises_exception(): def test_tu_voteinfo_null_user_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(TUVoteInfo, create(TUVoteInfo,
Agenda="Blah blah.", Agenda="Blah blah.",
Submitted=0, End=0, Submitted=0, End=0,
@ -114,6 +121,7 @@ def test_tu_voteinfo_null_user_raises_exception():
def test_tu_voteinfo_null_submitted_raises_exception(): def test_tu_voteinfo_null_submitted_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(TUVoteInfo, create(TUVoteInfo,
Agenda="Blah blah.", Agenda="Blah blah.",
User=user.Username, User=user.Username,
@ -125,6 +133,7 @@ def test_tu_voteinfo_null_submitted_raises_exception():
def test_tu_voteinfo_null_end_raises_exception(): def test_tu_voteinfo_null_end_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(TUVoteInfo, create(TUVoteInfo,
Agenda="Blah blah.", Agenda="Blah blah.",
User=user.Username, User=user.Username,
@ -136,6 +145,7 @@ def test_tu_voteinfo_null_end_raises_exception():
def test_tu_voteinfo_null_quorum_raises_exception(): def test_tu_voteinfo_null_quorum_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
with db.begin():
create(TUVoteInfo, create(TUVoteInfo,
Agenda="Blah blah.", Agenda="Blah blah.",
User=user.Username, User=user.Username,

View file

@ -9,7 +9,7 @@ import pytest
import aurweb.auth import aurweb.auth
import aurweb.config import aurweb.config
from aurweb.db import commit, create, query from aurweb import db
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.ban import Ban from aurweb.models.ban import Ban
from aurweb.models.package import Package from aurweb.models.package import Package
@ -40,10 +40,11 @@ def setup():
PackageNotification.__tablename__ PackageNotification.__tablename__
) )
account_type = query(AccountType, account_type = db.query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org", with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword", RealName="Test User", Passwd="testPassword",
AccountType=account_type) AccountType=account_type)
@ -70,14 +71,14 @@ def test_user_login_logout():
assert "AURSID" in request.cookies assert "AURSID" in request.cookies
# Expect that User session relationships work right. # Expect that User session relationships work right.
user_session = query(Session, user_session = db.query(Session,
Session.UsersID == user.ID).first() Session.UsersID == user.ID).first()
assert user_session == user.session assert user_session == user.session
assert user.session.SessionID == sid assert user.session.SessionID == sid
assert user.session.User == user assert user.session.User == user
# Search for the user via query API. # Search for the user via query API.
result = query(User, User.ID == user.ID).first() result = db.query(User, User.ID == user.ID).first()
# Compare the result and our original user. # Compare the result and our original user.
assert result == user assert result == user
@ -114,7 +115,8 @@ def test_user_login_twice():
def test_user_login_banned(): def test_user_login_banned():
# Add ban for the next 30 seconds. # Add ban for the next 30 seconds.
banned_timestamp = datetime.utcnow() + timedelta(seconds=30) banned_timestamp = datetime.utcnow() + timedelta(seconds=30)
create(Ban, IPAddress="127.0.0.1", BanTS=banned_timestamp) with db.begin():
db.create(Ban, IPAddress="127.0.0.1", BanTS=banned_timestamp)
request = Request() request = Request()
request.client.host = "127.0.0.1" request.client.host = "127.0.0.1"
@ -122,18 +124,17 @@ def test_user_login_banned():
def test_user_login_suspended(): def test_user_login_suspended():
from aurweb.db import session with db.begin():
user.Suspended = True user.Suspended = True
session.commit()
assert not user.login(Request(), "testPassword") assert not user.login(Request(), "testPassword")
def test_legacy_user_authentication(): def test_legacy_user_authentication():
from aurweb.db import session with db.begin():
user.Salt = bcrypt.gensalt().decode() user.Salt = bcrypt.gensalt().decode()
user.Passwd = hashlib.md5(f"{user.Salt}testPassword".encode()).hexdigest() user.Passwd = hashlib.md5(
session.commit() f"{user.Salt}testPassword".encode()
).hexdigest()
assert not user.valid_password("badPassword") assert not user.valid_password("badPassword")
assert user.valid_password("testPassword") assert user.valid_password("testPassword")
@ -145,7 +146,8 @@ def test_legacy_user_authentication():
def test_user_login_with_outdated_sid(): def test_user_login_with_outdated_sid():
# Make a session with a LastUpdateTS 5 seconds ago, causing # Make a session with a LastUpdateTS 5 seconds ago, causing
# user.login to update it with a new sid. # user.login to update it with a new sid.
create(Session, UsersID=user.ID, SessionID="stub", with db.begin():
db.create(Session, UsersID=user.ID, SessionID="stub",
LastUpdateTS=datetime.utcnow().timestamp() - 5) LastUpdateTS=datetime.utcnow().timestamp() - 5)
sid = user.login(Request(), "testPassword") sid = user.login(Request(), "testPassword")
assert sid and user.is_authenticated() assert sid and user.is_authenticated()
@ -171,7 +173,8 @@ def test_user_has_credential():
def test_user_ssh_pub_key(): def test_user_ssh_pub_key():
assert user.ssh_pub_key is None assert user.ssh_pub_key is None
ssh_pub_key = create(SSHPubKey, UserID=user.ID, with db.begin():
ssh_pub_key = db.create(SSHPubKey, UserID=user.ID,
Fingerprint="testFingerprint", Fingerprint="testFingerprint",
PubKey="testPubKey") PubKey="testPubKey")
@ -179,35 +182,33 @@ def test_user_ssh_pub_key():
def test_user_credential_types(): def test_user_credential_types():
from aurweb.db import session
assert aurweb.auth.user_developer_or_trusted_user(user) assert aurweb.auth.user_developer_or_trusted_user(user)
assert not aurweb.auth.trusted_user(user) assert not aurweb.auth.trusted_user(user)
assert not aurweb.auth.developer(user) assert not aurweb.auth.developer(user)
assert not aurweb.auth.trusted_user_or_dev(user) assert not aurweb.auth.trusted_user_or_dev(user)
trusted_user_type = query(AccountType, trusted_user_type = db.query(AccountType).filter(
AccountType.AccountType == "Trusted User")\ AccountType.AccountType == "Trusted User"
.first() ).first()
with db.begin():
user.AccountType = trusted_user_type user.AccountType = trusted_user_type
session.commit()
assert aurweb.auth.trusted_user(user) assert aurweb.auth.trusted_user(user)
assert aurweb.auth.trusted_user_or_dev(user) assert aurweb.auth.trusted_user_or_dev(user)
developer_type = query(AccountType, developer_type = db.query(AccountType,
AccountType.AccountType == "Developer").first() AccountType.AccountType == "Developer").first()
with db.begin():
user.AccountType = developer_type user.AccountType = developer_type
session.commit()
assert aurweb.auth.developer(user) assert aurweb.auth.developer(user)
assert aurweb.auth.trusted_user_or_dev(user) assert aurweb.auth.trusted_user_or_dev(user)
type_str = "Trusted User & Developer" type_str = "Trusted User & Developer"
elevated_type = query(AccountType, elevated_type = db.query(AccountType,
AccountType.AccountType == type_str).first() AccountType.AccountType == type_str).first()
with db.begin():
user.AccountType = elevated_type user.AccountType = elevated_type
session.commit()
assert aurweb.auth.trusted_user(user) assert aurweb.auth.trusted_user(user)
assert aurweb.auth.developer(user) assert aurweb.auth.developer(user)
@ -233,53 +234,56 @@ def test_user_as_dict():
def test_user_is_trusted_user(): def test_user_is_trusted_user():
tu_type = query(AccountType, tu_type = db.query(AccountType,
AccountType.AccountType == "Trusted User").first() AccountType.AccountType == "Trusted User").first()
with db.begin():
user.AccountType = tu_type user.AccountType = tu_type
commit()
assert user.is_trusted_user() is True assert user.is_trusted_user() is True
# Do it again with the combined role. # Do it again with the combined role.
tu_type = query( tu_type = db.query(
AccountType, AccountType,
AccountType.AccountType == "Trusted User & Developer").first() AccountType.AccountType == "Trusted User & Developer").first()
with db.begin():
user.AccountType = tu_type user.AccountType = tu_type
commit()
assert user.is_trusted_user() is True assert user.is_trusted_user() is True
def test_user_is_developer(): def test_user_is_developer():
dev_type = query(AccountType, dev_type = db.query(AccountType,
AccountType.AccountType == "Developer").first() AccountType.AccountType == "Developer").first()
with db.begin():
user.AccountType = dev_type user.AccountType = dev_type
commit()
assert user.is_developer() is True assert user.is_developer() is True
# Do it again with the combined role. # Do it again with the combined role.
dev_type = query( dev_type = db.query(
AccountType, AccountType,
AccountType.AccountType == "Trusted User & Developer").first() AccountType.AccountType == "Trusted User & Developer").first()
with db.begin():
user.AccountType = dev_type user.AccountType = dev_type
commit()
assert user.is_developer() is True assert user.is_developer() is True
def test_user_voted_for(): def test_user_voted_for():
now = int(datetime.utcnow().timestamp()) now = int(datetime.utcnow().timestamp())
pkgbase = create(PackageBase, Name="pkg1", Maintainer=user) with db.begin():
pkg = create(Package, PackageBase=pkgbase, Name=pkgbase.Name) pkgbase = db.create(PackageBase, Name="pkg1", Maintainer=user)
create(PackageVote, PackageBase=pkgbase, User=user, VoteTS=now) pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
db.create(PackageVote, PackageBase=pkgbase, User=user, VoteTS=now)
assert user.voted_for(pkg) assert user.voted_for(pkg)
def test_user_notified(): def test_user_notified():
pkgbase = create(PackageBase, Name="pkg1", Maintainer=user) with db.begin():
pkg = create(Package, PackageBase=pkgbase, Name=pkgbase.Name) pkgbase = db.create(PackageBase, Name="pkg1", Maintainer=user)
create(PackageNotification, PackageBase=pkgbase, User=user) pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
db.create(PackageNotification, PackageBase=pkgbase, User=user)
assert user.notified(pkg) assert user.notified(pkg)
def test_user_packages(): def test_user_packages():
pkgbase = create(PackageBase, Name="pkg1", Maintainer=user) with db.begin():
pkg = create(Package, PackageBase=pkgbase, Name=pkgbase.Name) pkgbase = db.create(PackageBase, Name="pkg1", Maintainer=user)
pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
assert pkg in user.packages() assert pkg in user.packages()