[FastAPI] Refactor db modifications

For SQLAlchemy to automatically understand updates from the
external world, it must use an `autocommit=True` in its session.

This change breaks how we were using commit previously, as
`autocommit=True` causes SQLAlchemy to commit when a
SessionTransaction context hits __exit__.

So, a refactoring was required of our tests: All usage of
any `db.{create,delete}` must be called **within** a
SessionTransaction context, created via new `db.begin()`.

From this point forward, we're going to require:

```
with db.begin():
    db.create(...)
    db.delete(...)
    db.session.delete(object)
```

With this, we now get external DB modifications automatically
without reloading or restarting the FastAPI server, which we
absolutely need for production.

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2021-09-02 16:26:48 -07:00
parent b52059d437
commit a5943bf2ad
No known key found for this signature in database
GPG key ID: F7E46DED420788F3
37 changed files with 998 additions and 902 deletions

View file

@ -59,20 +59,15 @@ def query(model, *args, **kwargs):
return session.query(model).filter(*args, **kwargs) return session.query(model).filter(*args, **kwargs)
def create(model, autocommit: bool = True, *args, **kwargs): def create(model, *args, **kwargs):
instance = model(*args, **kwargs) instance = model(*args, **kwargs)
add(instance) return add(instance)
if autocommit is True:
commit()
return instance
def delete(model, *args, autocommit: bool = True, **kwargs): def delete(model, *args, **kwargs):
instance = session.query(model).filter(*args, **kwargs) instance = session.query(model).filter(*args, **kwargs)
for record in instance: for record in instance:
session.delete(record) session.delete(record)
if autocommit is True:
commit()
def rollback(): def rollback():
@ -84,8 +79,25 @@ def add(model):
return model return model
def commit(): def begin():
session.commit() """ Begin an SQLAlchemy SessionTransaction.
This context is **required** to perform an modifications to the
database.
Example:
with db.begin():
object = db.create(...)
# On __exit__, db.commit() is run.
with db.begin():
object = db.delete(...)
# On __exit__, db.commit() is run.
:return: A new SessionTransaction based on session
"""
return session.begin()
def get_sqlalchemy_url(): def get_sqlalchemy_url():
@ -155,23 +167,23 @@ def get_engine(echo: bool = False):
connect_args=connect_args, connect_args=connect_args,
echo=echo) echo=echo)
Session = sessionmaker(autocommit=True, autoflush=False, bind=engine)
session = Session()
if db_backend == "sqlite": if db_backend == "sqlite":
# For SQLite, we need to add some custom functions as # For SQLite, we need to add some custom functions as
# they are used in the reference graph method. # they are used in the reference graph method.
def regexp(regex, item): def regexp(regex, item):
return bool(re.search(regex, str(item))) return bool(re.search(regex, str(item)))
@event.listens_for(engine, "begin") @event.listens_for(engine, "connect")
def do_begin(conn): def do_begin(conn, record):
create_deterministic_function = functools.partial( create_deterministic_function = functools.partial(
conn.connection.create_function, conn.create_function,
deterministic=True deterministic=True
) )
create_deterministic_function("REGEXP", 2, regexp) create_deterministic_function("REGEXP", 2, regexp)
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
return engine return engine

View file

@ -102,7 +102,7 @@ class User(Base):
def login(self, request: Request, password: str, session_time=0): def login(self, request: Request, password: str, session_time=0):
""" Login and authenticate a request. """ """ Login and authenticate a request. """
from aurweb.db import session from aurweb import db
from aurweb.models.session import Session, generate_unique_sid from aurweb.models.session import Session, generate_unique_sid
if not self._login_approved(request): if not self._login_approved(request):
@ -112,10 +112,7 @@ class User(Base):
if not self.authenticated: if not self.authenticated:
return None return None
self.LastLogin = now_ts = datetime.utcnow().timestamp() now_ts = datetime.utcnow().timestamp()
self.LastLoginIPAddress = request.client.host
session.commit()
session_ts = now_ts + ( session_ts = now_ts + (
session_time if session_time session_time if session_time
else aurweb.config.getint("options", "login_timeout") else aurweb.config.getint("options", "login_timeout")
@ -123,22 +120,23 @@ class User(Base):
sid = None sid = None
if not self.session: with db.begin():
sid = generate_unique_sid() self.LastLogin = now_ts
self.session = Session(UsersID=self.ID, SessionID=sid, self.LastLoginIPAddress = request.client.host
LastUpdateTS=session_ts) if not self.session:
session.add(self.session) sid = generate_unique_sid()
else: self.session = Session(UsersID=self.ID, SessionID=sid,
last_updated = self.session.LastUpdateTS LastUpdateTS=session_ts)
if last_updated and last_updated < now_ts: db.add(self.session)
self.session.SessionID = sid = generate_unique_sid()
else: else:
# Session is still valid; retrieve the current SID. last_updated = self.session.LastUpdateTS
sid = self.session.SessionID if last_updated and last_updated < now_ts:
self.session.SessionID = sid = generate_unique_sid()
else:
# Session is still valid; retrieve the current SID.
sid = self.session.SessionID
self.session.LastUpdateTS = session_ts self.session.LastUpdateTS = session_ts
session.commit()
request.cookies["AURSID"] = self.session.SessionID request.cookies["AURSID"] = self.session.SessionID
return self.session.SessionID return self.session.SessionID
@ -149,13 +147,11 @@ class User(Base):
return aurweb.auth.has_credential(self, cred, approved) return aurweb.auth.has_credential(self, cred, approved)
def logout(self, request): def logout(self, request):
from aurweb.db import session
del request.cookies["AURSID"] del request.cookies["AURSID"]
self.authenticated = False self.authenticated = False
if self.session: if self.session:
session.delete(self.session) with db.begin():
session.commit() db.session.delete(self.session)
def is_trusted_user(self): def is_trusted_user(self):
return self.AccountType.ID in { return self.AccountType.ID in {

View file

@ -43,8 +43,6 @@ async def passreset_post(request: Request,
resetkey: str = Form(default=None), resetkey: str = Form(default=None),
password: str = Form(default=None), password: str = Form(default=None),
confirm: str = Form(default=None)): confirm: str = Form(default=None)):
from aurweb.db import session
context = await make_variable_context(request, "Password Reset") context = await make_variable_context(request, "Password Reset")
# The user parameter being required, we can match against # The user parameter being required, we can match against
@ -86,12 +84,11 @@ async def passreset_post(request: Request,
# We got to this point; everything matched up. Update the password # We got to this point; everything matched up. Update the password
# and remove the ResetKey. # and remove the ResetKey.
user.ResetKey = str() with db.begin():
user.update_password(password) user.ResetKey = str()
if user.session:
if user.session: db.session.delete(user.session)
session.delete(user.session) user.update_password(password)
session.commit()
# Render ?step=complete. # Render ?step=complete.
return RedirectResponse(url="/passreset?step=complete", return RedirectResponse(url="/passreset?step=complete",
@ -99,8 +96,8 @@ async def passreset_post(request: Request,
# If we got here, we continue with issuing a resetkey for the user. # If we got here, we continue with issuing a resetkey for the user.
resetkey = db.make_random_value(User, User.ResetKey) resetkey = db.make_random_value(User, User.ResetKey)
user.ResetKey = resetkey with db.begin():
session.commit() user.ResetKey = resetkey
executor = db.ConnectionExecutor(db.get_engine().raw_connection()) executor = db.ConnectionExecutor(db.get_engine().raw_connection())
ResetKeyNotification(executor, user.ID).send() ResetKeyNotification(executor, user.ID).send()
@ -364,8 +361,6 @@ async def account_register_post(request: Request,
ON: bool = Form(default=False), ON: bool = Form(default=False),
captcha: str = Form(default=None), captcha: str = Form(default=None),
captcha_salt: str = Form(...)): captcha_salt: str = Form(...)):
from aurweb.db import session
context = await make_variable_context(request, "Register") context = await make_variable_context(request, "Register")
args = dict(await request.form()) args = dict(await request.form())
@ -394,11 +389,13 @@ async def account_register_post(request: Request,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
# Create a user given all parameters available. # Create a user given all parameters available.
user = db.create(User, Username=U, Email=E, HideEmail=H, BackupEmail=BE, with db.begin():
RealName=R, Homepage=HP, IRCNick=I, PGPKey=K, user = db.create(User, Username=U,
LangPreference=L, Timezone=TZ, CommentNotify=CN, Email=E, HideEmail=H, BackupEmail=BE,
UpdateNotify=UN, OwnershipNotify=ON, ResetKey=resetkey, RealName=R, Homepage=HP, IRCNick=I, PGPKey=K,
AccountType=account_type) LangPreference=L, Timezone=TZ, CommentNotify=CN,
UpdateNotify=UN, OwnershipNotify=ON,
ResetKey=resetkey, AccountType=account_type)
# If a PK was given and either one does not exist or the given # If a PK was given and either one does not exist or the given
# PK mismatches the existing user's SSHPubKey.PubKey. # PK mismatches the existing user's SSHPubKey.PubKey.
@ -410,10 +407,10 @@ async def account_register_post(request: Request,
# Remove the host part. # Remove the host part.
pubkey = parts[0] + " " + parts[1] pubkey = parts[0] + " " + parts[1]
fingerprint = get_fingerprint(pubkey) fingerprint = get_fingerprint(pubkey)
user.ssh_pub_key = SSHPubKey(UserID=user.ID, with db.begin():
PubKey=pubkey, user.ssh_pub_key = SSHPubKey(UserID=user.ID,
Fingerprint=fingerprint) PubKey=pubkey,
session.commit() Fingerprint=fingerprint)
# Send a reset key notification to the new user. # Send a reset key notification to the new user.
executor = db.ConnectionExecutor(db.get_engine().raw_connection()) executor = db.ConnectionExecutor(db.get_engine().raw_connection())
@ -499,63 +496,67 @@ async def account_edit_post(request: Request,
status_code=int(HTTPStatus.BAD_REQUEST)) status_code=int(HTTPStatus.BAD_REQUEST))
# Set all updated fields as needed. # Set all updated fields as needed.
user.Username = U or user.Username with db.begin():
user.Email = E or user.Email user.Username = U or user.Username
user.HideEmail = bool(H) user.Email = E or user.Email
user.BackupEmail = BE or user.BackupEmail user.HideEmail = bool(H)
user.RealName = R or user.RealName user.BackupEmail = BE or user.BackupEmail
user.Homepage = HP or user.Homepage user.RealName = R or user.RealName
user.IRCNick = I or user.IRCNick user.Homepage = HP or user.Homepage
user.PGPKey = K or user.PGPKey user.IRCNick = I or user.IRCNick
user.InactivityTS = datetime.utcnow().timestamp() if J else 0 user.PGPKey = K or user.PGPKey
user.InactivityTS = datetime.utcnow().timestamp() if J else 0
# If we update the language, update the cookie as well. # If we update the language, update the cookie as well.
if L and L != user.LangPreference: if L and L != user.LangPreference:
request.cookies["AURLANG"] = L request.cookies["AURLANG"] = L
user.LangPreference = L with db.begin():
user.LangPreference = L
context["language"] = L context["language"] = L
# If we update the timezone, also update the cookie. # If we update the timezone, also update the cookie.
if TZ and TZ != user.Timezone: if TZ and TZ != user.Timezone:
user.Timezone = TZ with db.begin():
user.Timezone = TZ
request.cookies["AURTZ"] = TZ request.cookies["AURTZ"] = TZ
context["timezone"] = TZ context["timezone"] = TZ
user.CommentNotify = bool(CN) with db.begin():
user.UpdateNotify = bool(UN) user.CommentNotify = bool(CN)
user.OwnershipNotify = bool(ON) user.UpdateNotify = bool(UN)
user.OwnershipNotify = bool(ON)
# If a PK is given, compare it against the target user's PK. # If a PK is given, compare it against the target user's PK.
if PK: with db.begin():
# Get the second token in the public key, which is the actual key. if PK:
pubkey = PK.strip().rstrip() # Get the second token in the public key, which is the actual key.
parts = pubkey.split(" ") pubkey = PK.strip().rstrip()
if len(parts) == 3: parts = pubkey.split(" ")
# Remove the host part. if len(parts) == 3:
pubkey = parts[0] + " " + parts[1] # Remove the host part.
fingerprint = get_fingerprint(pubkey) pubkey = parts[0] + " " + parts[1]
if not user.ssh_pub_key: fingerprint = get_fingerprint(pubkey)
# No public key exists, create one. if not user.ssh_pub_key:
user.ssh_pub_key = SSHPubKey(UserID=user.ID, # No public key exists, create one.
PubKey=pubkey, user.ssh_pub_key = SSHPubKey(UserID=user.ID,
Fingerprint=fingerprint) PubKey=pubkey,
elif user.ssh_pub_key.PubKey != pubkey: Fingerprint=fingerprint)
# A public key already exists, update it. elif user.ssh_pub_key.PubKey != pubkey:
user.ssh_pub_key.PubKey = pubkey # A public key already exists, update it.
user.ssh_pub_key.Fingerprint = fingerprint user.ssh_pub_key.PubKey = pubkey
elif user.ssh_pub_key: user.ssh_pub_key.Fingerprint = fingerprint
# Else, if the user has a public key already, delete it. elif user.ssh_pub_key:
session.delete(user.ssh_pub_key) # Else, if the user has a public key already, delete it.
session.delete(user.ssh_pub_key)
# Commit changes, if any.
session.commit()
if P and not user.valid_password(P): if P and not user.valid_password(P):
# Remove the fields we consumed for passwords. # Remove the fields we consumed for passwords.
context["P"] = context["C"] = str() context["P"] = context["C"] = str()
# If a password was given and it doesn't match the user's, update it. # If a password was given and it doesn't match the user's, update it.
user.update_password(P) with db.begin():
user.update_password(P)
if user == request.user: if user == request.user:
# If the target user is the request user, login with # If the target user is the request user, login with
# the updated password and update AURSID. # the updated password and update AURSID.
@ -731,21 +732,17 @@ async def terms_of_service_post(request: Request,
accept_needed = sorted(unaccepted + diffs) accept_needed = sorted(unaccepted + diffs)
return render_terms_of_service(request, context, accept_needed) return render_terms_of_service(request, context, accept_needed)
# For each term we found, query for the matching accepted term with db.begin():
# and update its Revision to the term's current Revision. # For each term we found, query for the matching accepted term
for term in diffs: # and update its Revision to the term's current Revision.
accepted_term = request.user.accepted_terms.filter( for term in diffs:
AcceptedTerm.TermsID == term.ID).first() accepted_term = request.user.accepted_terms.filter(
accepted_term.Revision = term.Revision AcceptedTerm.TermsID == term.ID).first()
accepted_term.Revision = term.Revision
# For each term that was never accepted, accept it! # For each term that was never accepted, accept it!
for term in unaccepted: for term in unaccepted:
db.create(AcceptedTerm, User=request.user, db.create(AcceptedTerm, User=request.user,
Term=term, Revision=term.Revision, Term=term, Revision=term.Revision)
autocommit=False)
if diffs or unaccepted:
# If we had any terms to update, commit the changes.
db.commit()
return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER)) return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER))

View file

@ -44,8 +44,6 @@ async def language(request: Request,
setting the language on any page, we want to preserve query setting the language on any page, we want to preserve query
parameters across the redirect. parameters across the redirect.
""" """
from aurweb.db import session
if next[0] != '/': if next[0] != '/':
return HTMLResponse(b"Invalid 'next' parameter.", status_code=400) return HTMLResponse(b"Invalid 'next' parameter.", status_code=400)
@ -53,8 +51,8 @@ async def language(request: Request,
# If the user is authenticated, update the user's LangPreference. # If the user is authenticated, update the user's LangPreference.
if request.user.is_authenticated(): if request.user.is_authenticated():
request.user.LangPreference = set_lang with db.begin():
session.commit() request.user.LangPreference = set_lang
# In any case, set the response's AURLANG cookie that never expires. # In any case, set the response's AURLANG cookie that never expires.
response = RedirectResponse(url=f"{next}{query_string}", response = RedirectResponse(url=f"{next}{query_string}",

View file

@ -214,10 +214,9 @@ async def trusted_user_proposal_post(request: Request,
return Response("Invalid 'decision' value.", return Response("Invalid 'decision' value.",
status_code=int(HTTPStatus.BAD_REQUEST)) status_code=int(HTTPStatus.BAD_REQUEST))
vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo, with db.begin():
autocommit=False) vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo)
voteinfo.ActiveTUs += 1 voteinfo.ActiveTUs += 1
db.commit()
context["error"] = "You've already voted for this proposal." context["error"] = "You've already voted for this proposal."
return render_proposal(request, context, proposal, voteinfo, voters, vote) return render_proposal(request, context, proposal, voteinfo, voters, vote)
@ -294,12 +293,13 @@ async def trusted_user_addvote_post(request: Request,
agenda = re.sub(r'<[/]?style.*>', '', agenda) agenda = re.sub(r'<[/]?style.*>', '', agenda)
# Create a new TUVoteInfo (proposal)! # Create a new TUVoteInfo (proposal)!
voteinfo = db.create(TUVoteInfo, with db.begin():
User=user, voteinfo = db.create(TUVoteInfo,
Agenda=agenda, User=user,
Submitted=timestamp, End=timestamp + duration, Agenda=agenda,
Quorum=quorum, Submitted=timestamp, End=timestamp + duration,
Submitter=request.user) Quorum=quorum,
Submitter=request.user)
# Redirect to the new proposal. # Redirect to the new proposal.
return RedirectResponse(f"/tu/{voteinfo.ID}", return RedirectResponse(f"/tu/{voteinfo.ID}",

View file

@ -1,6 +1,6 @@
import pytest import pytest
from aurweb.db import create, delete, query from aurweb.db import begin, create, delete, query
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.user import User from aurweb.models.user import User
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -14,11 +14,13 @@ def setup():
global account_type global account_type
account_type = create(AccountType, AccountType="TestUser") with begin():
account_type = create(AccountType, AccountType="TestUser")
yield account_type yield account_type
delete(AccountType, AccountType.ID == account_type.ID) with begin():
delete(AccountType, AccountType.ID == account_type.ID)
def test_account_type(): def test_account_type():
@ -38,12 +40,14 @@ def test_account_type():
def test_user_account_type_relationship(): def test_user_account_type_relationship():
user = create(User, Username="test", Email="test@example.org", with begin():
RealName="Test User", Passwd="testPassword", user = create(User, Username="test", Email="test@example.org",
AccountType=account_type) RealName="Test User", Passwd="testPassword",
AccountType=account_type)
assert user.AccountType == account_type assert user.AccountType == account_type
# This must be deleted here to avoid foreign key issues when # This must be deleted here to avoid foreign key issues when
# deleting the temporary AccountType in the fixture. # deleting the temporary AccountType in the fixture.
delete(User, User.ID == user.ID) with begin():
delete(User, User.ID == user.ID)

View file

@ -11,9 +11,9 @@ import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from aurweb import captcha from aurweb import captcha, db
from aurweb.asgi import app from aurweb.asgi import app
from aurweb.db import commit, create, query from aurweb.db import create, query
from aurweb.models.accepted_term import AcceptedTerm from aurweb.models.accepted_term import AcceptedTerm
from aurweb.models.account_type import DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID, TRUSTED_USER_ID, AccountType from aurweb.models.account_type import DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID, TRUSTED_USER_ID, AccountType
from aurweb.models.ban import Ban from aurweb.models.ban import Ban
@ -57,9 +57,11 @@ def setup():
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL,
RealName="Test UserZ", Passwd="testPassword", with db.begin():
IRCNick="testZ", AccountType=account_type) user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL,
RealName="Test UserZ", Passwd="testPassword",
IRCNick="testZ", AccountType=account_type)
yield user yield user
@ -70,9 +72,10 @@ def setup():
@pytest.fixture @pytest.fixture
def tu_user(): def tu_user():
user.AccountType = query(AccountType, with db.begin():
AccountType.ID == TRUSTED_USER_AND_DEV_ID).first() user.AccountType = query(AccountType).filter(
commit() AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
yield user yield user
@ -149,11 +152,9 @@ def test_post_passreset_user():
def test_post_passreset_resetkey(): def test_post_passreset_resetkey():
from aurweb.db import session with db.begin():
user.session = Session(UsersID=user.ID, SessionID="blah",
user.session = Session(UsersID=user.ID, SessionID="blah", LastUpdateTS=datetime.utcnow().timestamp())
LastUpdateTS=datetime.utcnow().timestamp())
session.commit()
# Prepare a password reset. # Prepare a password reset.
with client as request: with client as request:
@ -357,7 +358,8 @@ def test_post_register_error_invalid_captcha():
def test_post_register_error_ip_banned(): def test_post_register_error_ip_banned():
# 'testclient' is used as request.client.host via FastAPI TestClient. # 'testclient' is used as request.client.host via FastAPI TestClient.
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow()) with db.begin():
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow())
with client as request: with client as request:
response = post_register(request) response = post_register(request)
@ -576,7 +578,8 @@ def test_post_register_error_ssh_pubkey_taken():
# Take the sha256 fingerprint of the ssh public key, create it. # Take the sha256 fingerprint of the ssh public key, create it.
fp = get_fingerprint(pk) fp = get_fingerprint(pk)
create(SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fp) with db.begin():
create(SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fp)
with client as request: with client as request:
response = post_register(request, PK=pk) response = post_register(request, PK=pk)
@ -660,13 +663,11 @@ def test_post_account_edit():
def test_post_account_edit_dev(): def test_post_account_edit_dev():
from aurweb.db import session
# Modify our user to be a "Trusted User & Developer" # Modify our user to be a "Trusted User & Developer"
name = "Trusted User & Developer" name = "Trusted User & Developer"
tu_or_dev = query(AccountType, AccountType.AccountType == name).first() tu_or_dev = query(AccountType, AccountType.AccountType == name).first()
user.AccountType = tu_or_dev with db.begin():
session.commit() user.AccountType = tu_or_dev
request = Request() request = Request()
sid = user.login(request, "testPassword") sid = user.login(request, "testPassword")
@ -1001,21 +1002,19 @@ def get_rows(html):
def test_post_accounts(tu_user): def test_post_accounts(tu_user):
# Set a PGPKey. # Set a PGPKey.
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30" with db.begin():
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
# Create a few more users. # Create a few more users.
users = [user] users = [user]
for i in range(10): with db.begin():
_user = create(User, Username=f"test_{i}", for i in range(10):
Email=f"test_{i}@example.org", _user = create(User, Username=f"test_{i}",
RealName=f"Test #{i}", Email=f"test_{i}@example.org",
Passwd="testPassword", RealName=f"Test #{i}",
IRCNick=f"test_#{i}", Passwd="testPassword",
autocommit=False) IRCNick=f"test_#{i}")
users.append(_user) users.append(_user)
# Commit everything to the database.
commit()
sid = user.login(Request(), "testPassword") sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid} cookies = {"AURSID": sid}
@ -1085,11 +1084,12 @@ def test_post_accounts_account_type(tu_user):
# test the `u` parameter. # test the `u` parameter.
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
create(User, Username="test_2", with db.begin():
Email="test_2@example.org", create(User, Username="test_2",
RealName="Test User 2", Email="test_2@example.org",
Passwd="testPassword", RealName="Test User 2",
AccountType=account_type) Passwd="testPassword",
AccountType=account_type)
# Expect no entries; we marked our only user as a User type. # Expect no entries; we marked our only user as a User type.
with client as request: with client as request:
@ -1113,9 +1113,10 @@ def test_post_accounts_account_type(tu_user):
assert type.text.strip() == "User" assert type.text.strip() == "User"
# Set our only user to a Trusted User. # Set our only user to a Trusted User.
user.AccountType = query(AccountType, with db.begin():
AccountType.ID == TRUSTED_USER_ID).first() user.AccountType = query(AccountType).filter(
commit() AccountType.ID == TRUSTED_USER_ID
).first()
with client as request: with client as request:
response = request.post("/accounts/", cookies=cookies, response = request.post("/accounts/", cookies=cookies,
@ -1130,9 +1131,10 @@ def test_post_accounts_account_type(tu_user):
assert type.text.strip() == "Trusted User" assert type.text.strip() == "Trusted User"
user.AccountType = query(AccountType, with db.begin():
AccountType.ID == DEVELOPER_ID).first() user.AccountType = query(AccountType).filter(
commit() AccountType.ID == DEVELOPER_ID
).first()
with client as request: with client as request:
response = request.post("/accounts/", cookies=cookies, response = request.post("/accounts/", cookies=cookies,
@ -1147,10 +1149,10 @@ def test_post_accounts_account_type(tu_user):
assert type.text.strip() == "Developer" assert type.text.strip() == "Developer"
user.AccountType = query(AccountType, with db.begin():
AccountType.ID == TRUSTED_USER_AND_DEV_ID user.AccountType = query(AccountType).filter(
).first() AccountType.ID == TRUSTED_USER_AND_DEV_ID
commit() ).first()
with client as request: with client as request:
response = request.post("/accounts/", cookies=cookies, response = request.post("/accounts/", cookies=cookies,
@ -1182,8 +1184,8 @@ def test_post_accounts_status(tu_user):
username, type, status, realname, irc, pgp_key, edit = row username, type, status, realname, irc, pgp_key, edit = row
assert status.text.strip() == "Active" assert status.text.strip() == "Active"
user.Suspended = True with db.begin():
commit() user.Suspended = True
with client as request: with client as request:
response = request.post("/accounts/", cookies=cookies, response = request.post("/accounts/", cookies=cookies,
@ -1244,12 +1246,13 @@ def test_post_accounts_sortby(tu_user):
# Create a second user so we can compare sorts. # Create a second user so we can compare sorts.
account_type = query(AccountType, account_type = query(AccountType,
AccountType.ID == DEVELOPER_ID).first() AccountType.ID == DEVELOPER_ID).first()
create(User, Username="test2", with db.begin():
Email="test2@example.org", create(User, Username="test2",
RealName="Test User 2", Email="test2@example.org",
Passwd="testPassword", RealName="Test User 2",
IRCNick="test2", Passwd="testPassword",
AccountType=account_type) IRCNick="test2",
AccountType=account_type)
sid = user.login(Request(), "testPassword") sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid} cookies = {"AURSID": sid}
@ -1297,9 +1300,10 @@ def test_post_accounts_sortby(tu_user):
# Test the rows are reversed when ordering by RealName. # Test the rows are reversed when ordering by RealName.
assert compare_text_values(4, first_rows, reversed(rows)) is True assert compare_text_values(4, first_rows, reversed(rows)) is True
user.AccountType = query(AccountType, with db.begin():
AccountType.ID == TRUSTED_USER_AND_DEV_ID).first() user.AccountType = query(AccountType).filter(
commit() AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
# Fetch first_rows again with our new AccountType ordering. # Fetch first_rows again with our new AccountType ordering.
with client as request: with client as request:
@ -1322,8 +1326,8 @@ def test_post_accounts_sortby(tu_user):
def test_post_accounts_pgp_key(tu_user): def test_post_accounts_pgp_key(tu_user):
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30" with db.begin():
commit() user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
sid = user.login(Request(), "testPassword") sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid} cookies = {"AURSID": sid}
@ -1343,15 +1347,14 @@ def test_post_accounts_paged(tu_user):
users = [user] users = [user]
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
for i in range(150): with db.begin():
_user = create(User, Username=f"test_#{i}", for i in range(150):
Email=f"test_#{i}@example.org", _user = create(User, Username=f"test_#{i}",
RealName=f"Test User #{i}", Email=f"test_#{i}@example.org",
Passwd="testPassword", RealName=f"Test User #{i}",
AccountType=account_type, Passwd="testPassword",
autocommit=False) AccountType=account_type)
users.append(_user) users.append(_user)
commit()
sid = user.login(Request(), "testPassword") sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid} cookies = {"AURSID": sid}
@ -1414,8 +1417,9 @@ def test_post_accounts_paged(tu_user):
def test_get_terms_of_service(): def test_get_terms_of_service():
term = create(Term, Description="Test term.", with db.begin():
URL="http://localhost", Revision=1) term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1)
with client as request: with client as request:
response = request.get("/tos", allow_redirects=False) response = request.get("/tos", allow_redirects=False)
@ -1436,8 +1440,9 @@ def test_get_terms_of_service():
response = request.get("/tos", cookies=cookies, allow_redirects=False) response = request.get("/tos", cookies=cookies, allow_redirects=False)
assert response.status_code == int(HTTPStatus.OK) assert response.status_code == int(HTTPStatus.OK)
accepted_term = create(AcceptedTerm, User=user, with db.begin():
Term=term, Revision=term.Revision) accepted_term = create(AcceptedTerm, User=user,
Term=term, Revision=term.Revision)
with client as request: with client as request:
response = request.get("/tos", cookies=cookies, allow_redirects=False) response = request.get("/tos", cookies=cookies, allow_redirects=False)
@ -1445,8 +1450,8 @@ def test_get_terms_of_service():
assert response.status_code == int(HTTPStatus.SEE_OTHER) assert response.status_code == int(HTTPStatus.SEE_OTHER)
# Bump the term's revision. # Bump the term's revision.
term.Revision = 2 with db.begin():
commit() term.Revision = 2
with client as request: with client as request:
response = request.get("/tos", cookies=cookies, allow_redirects=False) response = request.get("/tos", cookies=cookies, allow_redirects=False)
@ -1454,8 +1459,8 @@ def test_get_terms_of_service():
# yet been agreed to via AcceptedTerm update. # yet been agreed to via AcceptedTerm update.
assert response.status_code == int(HTTPStatus.OK) assert response.status_code == int(HTTPStatus.OK)
accepted_term.Revision = term.Revision with db.begin():
commit() accepted_term.Revision = term.Revision
with client as request: with client as request:
response = request.get("/tos", cookies=cookies, allow_redirects=False) response = request.get("/tos", cookies=cookies, allow_redirects=False)
@ -1471,8 +1476,9 @@ def test_post_terms_of_service():
cookies = {"AURSID": sid} # Auth cookie. cookies = {"AURSID": sid} # Auth cookie.
# Create a fresh Term. # Create a fresh Term.
term = create(Term, Description="Test term.", with db.begin():
URL="http://localhost", Revision=1) term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1)
# Test that the term we just created is listed. # Test that the term we just created is listed.
with client as request: with client as request:
@ -1497,8 +1503,8 @@ def test_post_terms_of_service():
assert accepted_term.Term == term assert accepted_term.Term == term
# Update the term to revision 2. # Update the term to revision 2.
term.Revision = 2 with db.begin():
commit() term.Revision = 2
# A GET request gives us the new revision to accept. # A GET request gives us the new revision to accept.
with client as request: with client as request:

View file

@ -2,6 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb import db
from aurweb.db import create from aurweb.db import create
from aurweb.models.api_rate_limit import ApiRateLimit from aurweb.models.api_rate_limit import ApiRateLimit
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -13,26 +14,28 @@ def setup():
def test_api_rate_key_creation(): def test_api_rate_key_creation():
rate = create(ApiRateLimit, IP="127.0.0.1", Requests=10, WindowStart=1) with db.begin():
rate = create(ApiRateLimit, IP="127.0.0.1", Requests=10, WindowStart=1)
assert rate.IP == "127.0.0.1" assert rate.IP == "127.0.0.1"
assert rate.Requests == 10 assert rate.Requests == 10
assert rate.WindowStart == 1 assert rate.WindowStart == 1
def test_api_rate_key_ip_default(): def test_api_rate_key_ip_default():
api_rate_limit = create(ApiRateLimit, Requests=10, WindowStart=1) with db.begin():
api_rate_limit = create(ApiRateLimit, Requests=10, WindowStart=1)
assert api_rate_limit.IP == str() assert api_rate_limit.IP == str()
def test_api_rate_key_null_requests_raises_exception(): def test_api_rate_key_null_requests_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(ApiRateLimit, IP="127.0.0.1", WindowStart=1) with db.begin():
session.rollback() create(ApiRateLimit, IP="127.0.0.1", WindowStart=1)
db.rollback()
def test_api_rate_key_null_window_start_raises_exception(): def test_api_rate_key_null_window_start_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(ApiRateLimit, IP="127.0.0.1", Requests=1) with db.begin():
session.rollback() create(ApiRateLimit, IP="127.0.0.1", Requests=1)
db.rollback()

View file

@ -4,6 +4,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb import db
from aurweb.auth import BasicAuthBackend, account_type_required, has_credential from aurweb.auth import BasicAuthBackend, account_type_required, has_credential
from aurweb.db import create, query from aurweb.db import create, query
from aurweb.models.account_type import USER, USER_ID, AccountType from aurweb.models.account_type import USER, USER_ID, AccountType
@ -23,9 +24,10 @@ def setup():
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.com", with db.begin():
RealName="Test User", Passwd="testPassword", user = create(User, Username="test", Email="test@example.com",
AccountType=account_type) RealName="Test User", Passwd="testPassword",
AccountType=account_type)
backend = BasicAuthBackend() backend = BasicAuthBackend()
request = Request() request = Request()
@ -51,14 +53,13 @@ async def test_auth_backend_invalid_sid():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_auth_backend_invalid_user_id(): async def test_auth_backend_invalid_user_id():
from aurweb.db import session
# Create a new session with a fake user id. # Create a new session with a fake user id.
now_ts = datetime.utcnow().timestamp() now_ts = datetime.utcnow().timestamp()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(Session, UsersID=666, SessionID="realSession", with db.begin():
LastUpdateTS=now_ts + 5) create(Session, UsersID=666, SessionID="realSession",
session.rollback() LastUpdateTS=now_ts + 5)
db.rollback()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -66,8 +67,9 @@ async def test_basic_auth_backend():
# This time, everything matches up. We expect the user to # This time, everything matches up. We expect the user to
# equal the real_user. # equal the real_user.
now_ts = datetime.utcnow().timestamp() now_ts = datetime.utcnow().timestamp()
create(Session, UsersID=user.ID, SessionID="realSession", with db.begin():
LastUpdateTS=now_ts + 5) create(Session, UsersID=user.ID, SessionID="realSession",
LastUpdateTS=now_ts + 5)
request.cookies["AURSID"] = "realSession" request.cookies["AURSID"] = "realSession"
_, result = await backend.authenticate(request) _, result = await backend.authenticate(request)
assert result == user assert result == user

View file

@ -9,7 +9,7 @@ from fastapi.testclient import TestClient
import aurweb.config import aurweb.config
from aurweb.asgi import app from aurweb.asgi import app
from aurweb.db import create, query from aurweb.db import begin, create, query
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.session import Session from aurweb.models.session import Session
from aurweb.models.user import User from aurweb.models.user import User
@ -32,9 +32,10 @@ def setup():
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL, with begin():
RealName="Test User", Passwd="testPassword", user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL,
AccountType=account_type) RealName="Test User", Passwd="testPassword",
AccountType=account_type)
client = TestClient(app) client = TestClient(app)

View file

@ -6,6 +6,7 @@ import pytest
from sqlalchemy import exc as sa_exc from sqlalchemy import exc as sa_exc
from aurweb import db
from aurweb.db import create from aurweb.db import create
from aurweb.models.ban import Ban, is_banned from aurweb.models.ban import Ban, is_banned
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -21,7 +22,8 @@ def setup():
setup_test_db("Bans") setup_test_db("Bans")
ts = datetime.utcnow() + timedelta(seconds=30) ts = datetime.utcnow() + timedelta(seconds=30)
ban = create(Ban, IPAddress="127.0.0.1", BanTS=ts) with db.begin():
ban = create(Ban, IPAddress="127.0.0.1", BanTS=ts)
request = Request() request = Request()
@ -35,17 +37,17 @@ def test_invalid_ban():
with pytest.raises(sa_exc.IntegrityError): with pytest.raises(sa_exc.IntegrityError):
bad_ban = Ban(BanTS=datetime.utcnow()) bad_ban = Ban(BanTS=datetime.utcnow())
session.add(bad_ban)
# We're adding a ban with no primary key; this causes an # We're adding a ban with no primary key; this causes an
# SQLAlchemy warnings when committing to the DB. # SQLAlchemy warnings when committing to the DB.
# Ignore them. # Ignore them.
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", sa_exc.SAWarning) warnings.simplefilter("ignore", sa_exc.SAWarning)
session.commit() with db.begin():
session.add(bad_ban)
# Since we got a transaction failure, we need to rollback. # Since we got a transaction failure, we need to rollback.
session.rollback() db.rollback()
def test_banned(): def test_banned():

View file

@ -278,18 +278,15 @@ def test_connection_execute_paramstyle_unsupported():
def test_create_delete(): def test_create_delete():
db.create(AccountType, AccountType="test") with db.begin():
db.create(AccountType, AccountType="test")
record = db.query(AccountType, AccountType.AccountType == "test").first() record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is not None assert record is not None
db.delete(AccountType, AccountType.AccountType == "test")
record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is None
# Create and delete a record with autocommit=False. with db.begin():
db.create(AccountType, AccountType="test", autocommit=False) db.delete(AccountType, AccountType.AccountType == "test")
db.commit()
db.delete(AccountType, AccountType.AccountType == "test", autocommit=False)
db.commit()
record = db.query(AccountType, AccountType.AccountType == "test").first() record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is None assert record is None
@ -297,8 +294,8 @@ def test_create_delete():
def test_add_commit(): def test_add_commit():
# Use db.add and db.commit to add a temporary record. # Use db.add and db.commit to add a temporary record.
account_type = AccountType(AccountType="test") account_type = AccountType(AccountType="test")
db.add(account_type) with db.begin():
db.commit() db.add(account_type)
# Assert it got created in the DB. # Assert it got created in the DB.
assert bool(account_type.ID) assert bool(account_type.ID)
@ -308,7 +305,8 @@ def test_add_commit():
assert record == account_type assert record == account_type
# Remove the record. # Remove the record.
db.delete(AccountType, AccountType.ID == account_type.ID) with db.begin():
db.delete(AccountType, AccountType.ID == account_type.ID)
def test_connection_executor_mysql_paramstyle(): def test_connection_executor_mysql_paramstyle():

View file

@ -1,6 +1,6 @@
import pytest import pytest
from aurweb.db import create, delete, query from aurweb.db import begin, create, delete, query
from aurweb.models.dependency_type import DependencyType from aurweb.models.dependency_type import DependencyType
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -19,13 +19,17 @@ def test_dependency_types():
def test_dependency_type_creation(): def test_dependency_type_creation():
dependency_type = create(DependencyType, Name="Test Type") with begin():
dependency_type = create(DependencyType, Name="Test Type")
assert bool(dependency_type.ID) assert bool(dependency_type.ID)
assert dependency_type.Name == "Test Type" assert dependency_type.Name == "Test Type"
delete(DependencyType, DependencyType.ID == dependency_type.ID) with begin():
delete(DependencyType, DependencyType.ID == dependency_type.ID)
def test_dependency_type_null_name_uses_default(): def test_dependency_type_null_name_uses_default():
dependency_type = create(DependencyType) with begin():
dependency_type = create(DependencyType)
assert dependency_type.Name == str() assert dependency_type.Name == str()
delete(DependencyType, DependencyType.ID == dependency_type.ID) with begin():
delete(DependencyType, DependencyType.ID == dependency_type.ID)

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create from aurweb import db
from aurweb.models.group import Group from aurweb.models.group import Group
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -13,13 +13,14 @@ def setup():
def test_group_creation(): def test_group_creation():
group = create(Group, Name="Test Group") with db.begin():
group = db.create(Group, Name="Test Group")
assert bool(group.ID) assert bool(group.ID)
assert group.Name == "Test Group" assert group.Name == "Test Group"
def test_group_null_name_raises_exception(): def test_group_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(Group) with db.begin():
session.rollback() db.create(Group)
db.rollback()

View file

@ -38,8 +38,10 @@ def setup():
@pytest.fixture @pytest.fixture
def user(): def user():
yield db.create(User, Username="test", Email="test@example.org", with db.begin():
Passwd="testPassword", AccountTypeID=USER_ID) user = db.create(User, Username="test", Email="test@example.org",
Passwd="testPassword", AccountTypeID=USER_ID)
yield user
@pytest.fixture @pytest.fixture
@ -68,17 +70,14 @@ def packages(user):
# For i..num_packages, create a package named pkg_{i}. # For i..num_packages, create a package named pkg_{i}.
pkgs = [] pkgs = []
now = int(datetime.utcnow().timestamp()) now = int(datetime.utcnow().timestamp())
for i in range(num_packages): with db.begin():
pkgbase = db.create(PackageBase, Name=f"pkg_{i}", for i in range(num_packages):
Maintainer=user, Packager=user, pkgbase = db.create(PackageBase, Name=f"pkg_{i}",
autocommit=False, SubmittedTS=now, Maintainer=user, Packager=user,
ModifiedTS=now) SubmittedTS=now, ModifiedTS=now)
pkg = db.create(Package, PackageBase=pkgbase, pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
Name=pkgbase.Name, autocommit=False) pkgs.append(pkg)
pkgs.append(pkg) now += 1
now += 1
db.commit()
yield pkgs yield pkgs
@ -159,10 +158,11 @@ def test_homepage_updates(redis, packages):
def test_homepage_dashboard(redis, packages, user): def test_homepage_dashboard(redis, packages, user):
# Create Comaintainer records for all of the packages. # Create Comaintainer records for all of the packages.
for pkg in packages: with db.begin():
db.create(PackageComaintainer, PackageBase=pkg.PackageBase, for pkg in packages:
User=user, Priority=1, autocommit=False) db.create(PackageComaintainer,
db.commit() PackageBase=pkg.PackageBase,
User=user, Priority=1)
cookies = {"AURSID": user.login(Request(), "testPassword")} cookies = {"AURSID": user.login(Request(), "testPassword")}
with client as request: with client as request:
@ -193,11 +193,12 @@ def test_homepage_dashboard_requests(redis, packages, user):
pkg = packages[0] pkg = packages[0]
reqtype = db.query(RequestType, RequestType.ID == DELETION_ID).first() reqtype = db.query(RequestType, RequestType.ID == DELETION_ID).first()
pkgreq = db.create(PackageRequest, PackageBase=pkg.PackageBase, with db.begin():
PackageBaseName=pkg.PackageBase.Name, pkgreq = db.create(PackageRequest, PackageBase=pkg.PackageBase,
User=user, Comments=str(), PackageBaseName=pkg.PackageBase.Name,
ClosureComment=str(), RequestTS=now, User=user, Comments=str(),
RequestType=reqtype) ClosureComment=str(), RequestTS=now,
RequestType=reqtype)
cookies = {"AURSID": user.login(Request(), "testPassword")} cookies = {"AURSID": user.login(Request(), "testPassword")}
with client as request: with client as request:
@ -213,8 +214,8 @@ def test_homepage_dashboard_requests(redis, packages, user):
def test_homepage_dashboard_flagged_packages(redis, packages, user): def test_homepage_dashboard_flagged_packages(redis, packages, user):
# Set the first Package flagged by setting its OutOfDateTS column. # Set the first Package flagged by setting its OutOfDateTS column.
pkg = packages[0] pkg = packages[0]
pkg.PackageBase.OutOfDateTS = int(datetime.utcnow().timestamp()) with db.begin():
db.commit() pkg.PackageBase.OutOfDateTS = int(datetime.utcnow().timestamp())
cookies = {"AURSID": user.login(Request(), "testPassword")} cookies = {"AURSID": user.login(Request(), "testPassword")}
with client as request: with client as request:

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create from aurweb import db
from aurweb.models.license import License from aurweb.models.license import License
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -13,13 +13,14 @@ def setup():
def test_license_creation(): def test_license_creation():
license = create(License, Name="Test License") with db.begin():
license = db.create(License, Name="Test License")
assert bool(license.ID) assert bool(license.ID)
assert license.Name == "Test License" assert license.Name == "Test License"
def test_license_null_name_raises_exception(): def test_license_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(License) with db.begin():
session.rollback() db.create(License)
db.rollback()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create from aurweb import db
from aurweb.models.official_provider import OfficialProvider from aurweb.models.official_provider import OfficialProvider
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -13,10 +13,11 @@ def setup():
def test_official_provider_creation(): def test_official_provider_creation():
oprovider = create(OfficialProvider, with db.begin():
Name="some-name", oprovider = db.create(OfficialProvider,
Repo="some-repo", Name="some-name",
Provides="some-provides") Repo="some-repo",
Provides="some-provides")
assert bool(oprovider.ID) assert bool(oprovider.ID)
assert oprovider.Name == "some-name" assert oprovider.Name == "some-name"
assert oprovider.Repo == "some-repo" assert oprovider.Repo == "some-repo"
@ -25,16 +26,18 @@ def test_official_provider_creation():
def test_official_provider_cs(): def test_official_provider_cs():
""" Test case sensitivity of the database table. """ """ Test case sensitivity of the database table. """
oprovider = create(OfficialProvider, with db.begin():
Name="some-name", oprovider = db.create(OfficialProvider,
Repo="some-repo", Name="some-name",
Provides="some-provides") Repo="some-repo",
Provides="some-provides")
assert bool(oprovider.ID) assert bool(oprovider.ID)
oprovider_cs = create(OfficialProvider, with db.begin():
Name="SOME-NAME", oprovider_cs = db.create(OfficialProvider,
Repo="SOME-REPO", Name="SOME-NAME",
Provides="SOME-PROVIDES") Repo="SOME-REPO",
Provides="SOME-PROVIDES")
assert bool(oprovider_cs.ID) assert bool(oprovider_cs.ID)
assert oprovider.ID != oprovider_cs.ID assert oprovider.ID != oprovider_cs.ID
@ -49,27 +52,27 @@ def test_official_provider_cs():
def test_official_provider_null_name_raises_exception(): def test_official_provider_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(OfficialProvider, with db.begin():
Repo="some-repo", db.create(OfficialProvider,
Provides="some-provides") Repo="some-repo",
session.rollback() Provides="some-provides")
db.rollback()
def test_official_provider_null_repo_raises_exception(): def test_official_provider_null_repo_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(OfficialProvider, with db.begin():
Name="some-name", db.create(OfficialProvider,
Provides="some-provides") Name="some-name",
session.rollback() Provides="some-provides")
db.rollback()
def test_official_provider_null_provides_raises_exception(): def test_official_provider_null_provides_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(OfficialProvider, with db.begin():
Name="some-name", db.create(OfficialProvider,
Repo="some-repo") Name="some-name",
session.rollback() Repo="some-repo")
db.rollback()

View file

@ -3,7 +3,7 @@ import pytest
from sqlalchemy import and_ from sqlalchemy import and_
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create, query from aurweb import db
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.package import Package from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase from aurweb.models.package_base import PackageBase
@ -19,25 +19,25 @@ def setup():
setup_test_db("Packages", "PackageBases", "Users") setup_test_db("Packages", "PackageBases", "Users")
account_type = query(AccountType, account_type = db.query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
pkgbase = create(PackageBase, with db.begin():
Name="beautiful-package", user = db.create(User, Username="test", Email="test@example.org",
Maintainer=user) RealName="Test User", Passwd="testPassword",
package = create(Package, AccountType=account_type)
PackageBase=pkgbase,
Name=pkgbase.Name, pkgbase = db.create(PackageBase,
Description="Test description.", Name="beautiful-package",
URL="https://test.package") Maintainer=user)
package = db.create(Package,
PackageBase=pkgbase,
Name=pkgbase.Name,
Description="Test description.",
URL="https://test.package")
def test_package(): def test_package():
from aurweb.db import session
assert pkgbase == package.PackageBase assert pkgbase == package.PackageBase
assert package.Name == "beautiful-package" assert package.Name == "beautiful-package"
assert package.Description == "Test description." assert package.Description == "Test description."
@ -45,33 +45,31 @@ def test_package():
assert package.URL == "https://test.package" assert package.URL == "https://test.package"
# Update package Version. # Update package Version.
package.Version = "1.2.3" with db.begin():
session.commit() package.Version = "1.2.3"
# Make sure it got updated in the database. # Make sure it got updated in the database.
record = query(Package, record = db.query(Package,
and_(Package.ID == package.ID, and_(Package.ID == package.ID,
Package.Version == "1.2.3")).first() Package.Version == "1.2.3")).first()
assert record is not None assert record is not None
def test_package_null_pkgbase_raises_exception(): def test_package_null_pkgbase_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(Package, with db.begin():
Name="some-package", db.create(Package,
Description="Some description.", Name="some-package",
URL="https://some.package") Description="Some description.",
session.rollback() URL="https://some.package")
db.rollback()
def test_package_null_name_raises_exception(): def test_package_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(Package, with db.begin():
PackageBase=pkgbase, db.create(Package,
Description="Some description.", PackageBase=pkgbase,
URL="https://some.package") Description="Some description.",
session.rollback() URL="https://some.package")
db.rollback()

View file

@ -4,7 +4,7 @@ from sqlalchemy.exc import IntegrityError
import aurweb.config import aurweb.config
from aurweb.db import create, query from aurweb import db
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.package_base import PackageBase from aurweb.models.package_base import PackageBase
from aurweb.models.user import User from aurweb.models.user import User
@ -19,17 +19,19 @@ def setup():
setup_test_db("Users", "PackageBases") setup_test_db("Users", "PackageBases")
account_type = query(AccountType, account_type = db.query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org", with db.begin():
RealName="Test User", Passwd="testPassword", user = db.create(User, Username="test", Email="test@example.org",
AccountType=account_type) RealName="Test User", Passwd="testPassword",
AccountType=account_type)
def test_package_base(): def test_package_base():
pkgbase = create(PackageBase, with db.begin():
Name="beautiful-package", pkgbase = db.create(PackageBase,
Maintainer=user) Name="beautiful-package",
Maintainer=user)
assert pkgbase in user.maintained_bases assert pkgbase in user.maintained_bases
assert not pkgbase.OutOfDateTS assert not pkgbase.OutOfDateTS
@ -38,7 +40,8 @@ def test_package_base():
# Set Popularity to a string, then get it by attribute to # Set Popularity to a string, then get it by attribute to
# exercise the string -> float conversion path. # exercise the string -> float conversion path.
pkgbase.Popularity = "0.0" with db.begin():
pkgbase.Popularity = "0.0"
assert pkgbase.Popularity == 0.0 assert pkgbase.Popularity == 0.0
@ -47,27 +50,28 @@ def test_package_base_ci():
if aurweb.config.get("database", "backend") == "sqlite": if aurweb.config.get("database", "backend") == "sqlite":
return None # SQLite doesn't seem handle this. return None # SQLite doesn't seem handle this.
from aurweb.db import session with db.begin():
pkgbase = db.create(PackageBase,
pkgbase = create(PackageBase, Name="beautiful-package",
Name="beautiful-package", Maintainer=user)
Maintainer=user)
assert bool(pkgbase.ID) assert bool(pkgbase.ID)
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageBase, with db.begin():
Name="Beautiful-Package", db.create(PackageBase,
Maintainer=user) Name="Beautiful-Package",
session.rollback() Maintainer=user)
db.rollback()
def test_package_base_relationships(): def test_package_base_relationships():
pkgbase = create(PackageBase, with db.begin():
Name="beautiful-package", pkgbase = db.create(PackageBase,
Flagger=user, Name="beautiful-package",
Maintainer=user, Flagger=user,
Submitter=user, Maintainer=user,
Packager=user) Submitter=user,
Packager=user)
assert pkgbase in user.flagged_bases assert pkgbase in user.flagged_bases
assert pkgbase in user.maintained_bases assert pkgbase in user.maintained_bases
assert pkgbase in user.submitted_bases assert pkgbase in user.submitted_bases
@ -75,8 +79,7 @@ def test_package_base_relationships():
def test_package_base_null_name_raises_exception(): def test_package_base_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageBase) with db.begin():
session.rollback() db.create(PackageBase)
db.rollback()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create, rollback from aurweb import db
from aurweb.models.package_base import PackageBase from aurweb.models.package_base import PackageBase
from aurweb.models.package_blacklist import PackageBlacklist from aurweb.models.package_blacklist import PackageBlacklist
from aurweb.models.user import User from aurweb.models.user import User
@ -17,18 +17,20 @@ def setup():
setup_test_db("PackageBlacklist", "PackageBases", "Users") setup_test_db("PackageBlacklist", "PackageBases", "Users")
user = create(User, Username="test", Email="test@example.org", user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword") RealName="Test User", Passwd="testPassword")
pkgbase = create(PackageBase, Name="test-package", Maintainer=user) pkgbase = db.create(PackageBase, Name="test-package", Maintainer=user)
def test_package_blacklist_creation(): def test_package_blacklist_creation():
package_blacklist = create(PackageBlacklist, Name="evil-package") with db.begin():
package_blacklist = db.create(PackageBlacklist, Name="evil-package")
assert bool(package_blacklist.ID) assert bool(package_blacklist.ID)
assert package_blacklist.Name == "evil-package" assert package_blacklist.Name == "evil-package"
def test_package_blacklist_null_name_raises_exception(): def test_package_blacklist_null_name_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageBlacklist) with db.begin():
rollback() db.create(PackageBlacklist)
db.rollback()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create, query, rollback from aurweb.db import begin, create, query, rollback
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.package_base import PackageBase from aurweb.models.package_base import PackageBase
from aurweb.models.package_comment import PackageComment from aurweb.models.package_comment import PackageComment
@ -20,45 +20,52 @@ def setup():
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org", with begin():
RealName="Test User", Passwd="testPassword", user = create(User, Username="test", Email="test@example.org",
AccountType=account_type) RealName="Test User", Passwd="testPassword",
pkgbase = create(PackageBase, Name="test-package", Maintainer=user) AccountType=account_type)
pkgbase = create(PackageBase, Name="test-package", Maintainer=user)
def test_package_comment_creation(): def test_package_comment_creation():
package_comment = create(PackageComment, with begin():
PackageBase=pkgbase, package_comment = create(PackageComment,
User=user, PackageBase=pkgbase,
Comments="Test comment.", User=user,
RenderedComment="Test rendered comment.") Comments="Test comment.",
RenderedComment="Test rendered comment.")
assert bool(package_comment.ID) assert bool(package_comment.ID)
def test_package_comment_null_package_base_raises_exception(): def test_package_comment_null_package_base_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageComment, User=user, Comments="Test comment.", with begin():
RenderedComment="Test rendered comment.") create(PackageComment, User=user, Comments="Test comment.",
RenderedComment="Test rendered comment.")
rollback() rollback()
def test_package_comment_null_user_raises_exception(): def test_package_comment_null_user_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageComment, PackageBase=pkgbase, Comments="Test comment.", with begin():
RenderedComment="Test rendered comment.") create(PackageComment, PackageBase=pkgbase,
Comments="Test comment.",
RenderedComment="Test rendered comment.")
rollback() rollback()
def test_package_comment_null_comments_raises_exception(): def test_package_comment_null_comments_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageComment, PackageBase=pkgbase, User=user, with begin():
RenderedComment="Test rendered comment.") create(PackageComment, PackageBase=pkgbase, User=user,
RenderedComment="Test rendered comment.")
rollback() rollback()
def test_package_comment_null_renderedcomment_defaults(): def test_package_comment_null_renderedcomment_defaults():
record = create(PackageComment, with begin():
PackageBase=pkgbase, record = create(PackageComment,
User=user, PackageBase=pkgbase,
Comments="Test comment.") User=user,
Comments="Test comment.")
assert record.RenderedComment == str() assert record.RenderedComment == str()

View file

@ -2,7 +2,8 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import commit, create, query from aurweb import db
from aurweb.db import create, query
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.dependency_type import DependencyType from aurweb.models.dependency_type import DependencyType
from aurweb.models.package import Package from aurweb.models.package import Package
@ -22,25 +23,28 @@ def setup():
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
pkgbase = create(PackageBase, with db.begin():
Name="test-package", user = create(User, Username="test", Email="test@example.org",
Maintainer=user) RealName="Test User", Passwd="testPassword",
package = create(Package, AccountType=account_type)
PackageBase=pkgbase, pkgbase = create(PackageBase,
Name=pkgbase.Name, Name="test-package",
Description="Test description.", Maintainer=user)
URL="https://test.package") package = create(Package,
PackageBase=pkgbase,
Name=pkgbase.Name,
Description="Test description.",
URL="https://test.package")
def test_package_dependencies(): def test_package_dependencies():
depends = query(DependencyType, DependencyType.Name == "depends").first() depends = query(DependencyType, DependencyType.Name == "depends").first()
pkgdep = create(PackageDependency, Package=package,
DependencyType=depends, with db.begin():
DepName="test-dep") pkgdep = create(PackageDependency, Package=package,
DependencyType=depends,
DepName="test-dep")
assert pkgdep.DepName == "test-dep" assert pkgdep.DepName == "test-dep"
assert pkgdep.Package == package assert pkgdep.Package == package
assert pkgdep.DependencyType == depends assert pkgdep.DependencyType == depends
@ -49,8 +53,8 @@ def test_package_dependencies():
makedepends = query(DependencyType, makedepends = query(DependencyType,
DependencyType.Name == "makedepends").first() DependencyType.Name == "makedepends").first()
pkgdep.DependencyType = makedepends with db.begin():
commit() pkgdep.DependencyType = makedepends
assert pkgdep.DepName == "test-dep" assert pkgdep.DepName == "test-dep"
assert pkgdep.Package == package assert pkgdep.Package == package
assert pkgdep.DependencyType == makedepends assert pkgdep.DependencyType == makedepends
@ -59,8 +63,8 @@ def test_package_dependencies():
checkdepends = query(DependencyType, checkdepends = query(DependencyType,
DependencyType.Name == "checkdepends").first() DependencyType.Name == "checkdepends").first()
pkgdep.DependencyType = checkdepends with db.begin():
commit() pkgdep.DependencyType = checkdepends
assert pkgdep.DepName == "test-dep" assert pkgdep.DepName == "test-dep"
assert pkgdep.Package == package assert pkgdep.Package == package
assert pkgdep.DependencyType == checkdepends assert pkgdep.DependencyType == checkdepends
@ -69,8 +73,8 @@ def test_package_dependencies():
optdepends = query(DependencyType, optdepends = query(DependencyType,
DependencyType.Name == "optdepends").first() DependencyType.Name == "optdepends").first()
pkgdep.DependencyType = optdepends with db.begin():
commit() pkgdep.DependencyType = optdepends
assert pkgdep.DepName == "test-dep" assert pkgdep.DepName == "test-dep"
assert pkgdep.Package == package assert pkgdep.Package == package
assert pkgdep.DependencyType == optdepends assert pkgdep.DependencyType == optdepends
@ -79,39 +83,37 @@ def test_package_dependencies():
assert not pkgdep.is_package() assert not pkgdep.is_package()
base = create(PackageBase, Name=pkgdep.DepName, Maintainer=user) with db.begin():
create(Package, PackageBase=base, Name=pkgdep.DepName) base = create(PackageBase, Name=pkgdep.DepName, Maintainer=user)
create(Package, PackageBase=base, Name=pkgdep.DepName)
assert pkgdep.is_package() assert pkgdep.is_package()
def test_package_dependencies_null_package_raises_exception(): def test_package_dependencies_null_package_raises_exception():
from aurweb.db import session
depends = query(DependencyType, DependencyType.Name == "depends").first() depends = query(DependencyType, DependencyType.Name == "depends").first()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageDependency, with db.begin():
DependencyType=depends, create(PackageDependency,
DepName="test-dep") DependencyType=depends,
session.rollback() DepName="test-dep")
db.rollback()
def test_package_dependencies_null_dependency_type_raises_exception(): def test_package_dependencies_null_dependency_type_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageDependency, with db.begin():
Package=package, create(PackageDependency,
DepName="test-dep") Package=package,
session.rollback() DepName="test-dep")
db.rollback()
def test_package_dependencies_null_depname_raises_exception(): def test_package_dependencies_null_depname_raises_exception():
from aurweb.db import session
depends = query(DependencyType, DependencyType.Name == "depends").first() depends = query(DependencyType, DependencyType.Name == "depends").first()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageDependency, with db.begin():
Package=package, create(PackageDependency,
DependencyType=depends) Package=package,
session.rollback() DependencyType=depends)
db.rollback()

View file

@ -2,7 +2,8 @@ import pytest
from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.exc import IntegrityError, OperationalError
from aurweb.db import commit, create, query from aurweb import db
from aurweb.db import create, query
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.package import Package from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase from aurweb.models.package_base import PackageBase
@ -22,25 +23,28 @@ def setup():
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
pkgbase = create(PackageBase, with db.begin():
Name="test-package", user = create(User, Username="test", Email="test@example.org",
Maintainer=user) RealName="Test User", Passwd="testPassword",
package = create(Package, AccountType=account_type)
PackageBase=pkgbase, pkgbase = create(PackageBase,
Name=pkgbase.Name, Name="test-package",
Description="Test description.", Maintainer=user)
URL="https://test.package") package = create(Package,
PackageBase=pkgbase,
Name=pkgbase.Name,
Description="Test description.",
URL="https://test.package")
def test_package_relation(): def test_package_relation():
conflicts = query(RelationType, RelationType.Name == "conflicts").first() conflicts = query(RelationType, RelationType.Name == "conflicts").first()
pkgrel = create(PackageRelation, Package=package,
RelationType=conflicts, with db.begin():
RelName="test-relation") pkgrel = create(PackageRelation, Package=package,
RelationType=conflicts,
RelName="test-relation")
assert pkgrel.RelName == "test-relation" assert pkgrel.RelName == "test-relation"
assert pkgrel.Package == package assert pkgrel.Package == package
assert pkgrel.RelationType == conflicts assert pkgrel.RelationType == conflicts
@ -48,8 +52,8 @@ def test_package_relation():
assert pkgrel in package.package_relations assert pkgrel in package.package_relations
provides = query(RelationType, RelationType.Name == "provides").first() provides = query(RelationType, RelationType.Name == "provides").first()
pkgrel.RelationType = provides with db.begin():
commit() pkgrel.RelationType = provides
assert pkgrel.RelName == "test-relation" assert pkgrel.RelName == "test-relation"
assert pkgrel.Package == package assert pkgrel.Package == package
assert pkgrel.RelationType == provides assert pkgrel.RelationType == provides
@ -57,8 +61,8 @@ def test_package_relation():
assert pkgrel in package.package_relations assert pkgrel in package.package_relations
replaces = query(RelationType, RelationType.Name == "replaces").first() replaces = query(RelationType, RelationType.Name == "replaces").first()
pkgrel.RelationType = replaces with db.begin():
commit() pkgrel.RelationType = replaces
assert pkgrel.RelName == "test-relation" assert pkgrel.RelName == "test-relation"
assert pkgrel.Package == package assert pkgrel.Package == package
assert pkgrel.RelationType == replaces assert pkgrel.RelationType == replaces
@ -67,36 +71,33 @@ def test_package_relation():
def test_package_relation_null_package_raises_exception(): def test_package_relation_null_package_raises_exception():
from aurweb.db import session
conflicts = query(RelationType, RelationType.Name == "conflicts").first() conflicts = query(RelationType, RelationType.Name == "conflicts").first()
assert conflicts is not None assert conflicts is not None
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageRelation, with db.begin():
RelationType=conflicts, create(PackageRelation,
RelName="test-relation") RelationType=conflicts,
session.rollback() RelName="test-relation")
db.rollback()
def test_package_relation_null_relation_type_raises_exception(): def test_package_relation_null_relation_type_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageRelation, with db.begin():
Package=package, create(PackageRelation,
RelName="test-relation") Package=package,
session.rollback() RelName="test-relation")
db.rollback()
def test_package_relation_null_relname_raises_exception(): def test_package_relation_null_relname_raises_exception():
from aurweb.db import session
depends = query(RelationType, RelationType.Name == "conflicts").first() depends = query(RelationType, RelationType.Name == "conflicts").first()
assert depends is not None assert depends is not None
with pytest.raises((OperationalError, IntegrityError)): with pytest.raises((OperationalError, IntegrityError)):
create(PackageRelation, with db.begin():
Package=package, create(PackageRelation,
RelationType=depends) Package=package,
session.rollback() RelationType=depends)
db.rollback()

View file

@ -4,7 +4,8 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import commit, create, query, rollback from aurweb import db
from aurweb.db import create, query, rollback
from aurweb.models.package_base import PackageBase from aurweb.models.package_base import PackageBase
from aurweb.models.package_request import (ACCEPTED, ACCEPTED_ID, CLOSED, CLOSED_ID, PENDING, PENDING_ID, REJECTED, from aurweb.models.package_request import (ACCEPTED, ACCEPTED_ID, CLOSED, CLOSED_ID, PENDING, PENDING_ID, REJECTED,
REJECTED_ID, PackageRequest) REJECTED_ID, PackageRequest)
@ -21,19 +22,21 @@ def setup():
setup_test_db("PackageRequests", "PackageBases", "Users") setup_test_db("PackageRequests", "PackageBases", "Users")
user = create(User, Username="test", Email="test@example.org", with db.begin():
RealName="Test User", Passwd="testPassword") user = create(User, Username="test", Email="test@example.org",
pkgbase = create(PackageBase, Name="test-package", Maintainer=user) RealName="Test User", Passwd="testPassword")
pkgbase = create(PackageBase, Name="test-package", Maintainer=user)
def test_package_request_creation(): def test_package_request_creation():
request_type = query(RequestType, RequestType.Name == "merge").first() request_type = query(RequestType, RequestType.Name == "merge").first()
assert request_type.Name == "merge" assert request_type.Name == "merge"
package_request = create(PackageRequest, RequestType=request_type, with db.begin():
User=user, PackageBase=pkgbase, package_request = create(PackageRequest, RequestType=request_type,
PackageBaseName=pkgbase.Name, User=user, PackageBase=pkgbase,
Comments=str(), ClosureComment=str()) PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str())
assert bool(package_request.ID) assert bool(package_request.ID)
assert package_request.RequestType == request_type assert package_request.RequestType == request_type
@ -54,11 +57,12 @@ def test_package_request_closed():
assert request_type.Name == "merge" assert request_type.Name == "merge"
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
package_request = create(PackageRequest, RequestType=request_type, with db.begin():
User=user, PackageBase=pkgbase, package_request = create(PackageRequest, RequestType=request_type,
PackageBaseName=pkgbase.Name, User=user, PackageBase=pkgbase,
Closer=user, ClosedTS=ts, PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str()) Closer=user, ClosedTS=ts,
Comments=str(), ClosureComment=str())
assert package_request.Closer == user assert package_request.Closer == user
assert package_request.ClosedTS == ts assert package_request.ClosedTS == ts
@ -69,54 +73,60 @@ def test_package_request_closed():
def test_package_request_null_request_type_raises_exception(): def test_package_request_null_request_type_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageRequest, User=user, PackageBase=pkgbase, with db.begin():
PackageBaseName=pkgbase.Name, create(PackageRequest, User=user, PackageBase=pkgbase,
Comments=str(), ClosureComment=str()) PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str())
rollback() rollback()
def test_package_request_null_user_raises_exception(): def test_package_request_null_user_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first() request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type, PackageBase=pkgbase, with db.begin():
PackageBaseName=pkgbase.Name, create(PackageRequest, RequestType=request_type,
Comments=str(), ClosureComment=str()) PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str())
rollback() rollback()
def test_package_request_null_package_base_raises_exception(): def test_package_request_null_package_base_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first() request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type, with db.begin():
User=user, PackageBaseName=pkgbase.Name, create(PackageRequest, RequestType=request_type,
Comments=str(), ClosureComment=str()) User=user, PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str())
rollback() rollback()
def test_package_request_null_package_base_name_raises_exception(): def test_package_request_null_package_base_name_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first() request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type, with db.begin():
User=user, PackageBase=pkgbase, create(PackageRequest, RequestType=request_type,
Comments=str(), ClosureComment=str()) User=user, PackageBase=pkgbase,
Comments=str(), ClosureComment=str())
rollback() rollback()
def test_package_request_null_comments_raises_exception(): def test_package_request_null_comments_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first() request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type, with db.begin():
User=user, PackageBase=pkgbase, PackageBaseName=pkgbase.Name, create(PackageRequest, RequestType=request_type, User=user,
ClosureComment=str()) PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
ClosureComment=str())
rollback() rollback()
def test_package_request_null_closure_comment_raises_exception(): def test_package_request_null_closure_comment_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first() request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type, with db.begin():
User=user, PackageBase=pkgbase, PackageBaseName=pkgbase.Name, create(PackageRequest, RequestType=request_type, User=user,
Comments=str()) PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
Comments=str())
rollback() rollback()
@ -124,26 +134,27 @@ def test_package_request_status_display():
""" Test status_display() based on the Status column value. """ """ Test status_display() based on the Status column value. """
request_type = query(RequestType, RequestType.Name == "merge").first() request_type = query(RequestType, RequestType.Name == "merge").first()
pkgreq = create(PackageRequest, RequestType=request_type, with db.begin():
User=user, PackageBase=pkgbase, pkgreq = create(PackageRequest, RequestType=request_type,
PackageBaseName=pkgbase.Name, User=user, PackageBase=pkgbase,
Comments=str(), ClosureComment=str(), PackageBaseName=pkgbase.Name,
Status=PENDING_ID) Comments=str(), ClosureComment=str(),
Status=PENDING_ID)
assert pkgreq.status_display() == PENDING assert pkgreq.status_display() == PENDING
pkgreq.Status = CLOSED_ID with db.begin():
commit() pkgreq.Status = CLOSED_ID
assert pkgreq.status_display() == CLOSED assert pkgreq.status_display() == CLOSED
pkgreq.Status = ACCEPTED_ID with db.begin():
commit() pkgreq.Status = ACCEPTED_ID
assert pkgreq.status_display() == ACCEPTED assert pkgreq.status_display() == ACCEPTED
pkgreq.Status = REJECTED_ID with db.begin():
commit() pkgreq.Status = REJECTED_ID
assert pkgreq.status_display() == REJECTED assert pkgreq.status_display() == REJECTED
pkgreq.Status = 124 with db.begin():
commit() pkgreq.Status = 124
with pytest.raises(KeyError): with pytest.raises(KeyError):
pkgreq.status_display() pkgreq.status_display()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create, query, rollback from aurweb.db import begin, create, query, rollback
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.package import Package from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase from aurweb.models.package_base import PackageBase
@ -21,17 +21,19 @@ def setup():
account_type = query(AccountType, account_type = query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org", with begin():
RealName="Test User", Passwd="testPassword", user = create(User, Username="test", Email="test@example.org",
AccountType=account_type) RealName="Test User", Passwd="testPassword",
pkgbase = create(PackageBase, AccountType=account_type)
Name="test-package", pkgbase = create(PackageBase,
Maintainer=user) Name="test-package",
package = create(Package, PackageBase=pkgbase, Name="test-package") Maintainer=user)
package = create(Package, PackageBase=pkgbase, Name="test-package")
def test_package_source(): def test_package_source():
pkgsource = create(PackageSource, Package=package) with begin():
pkgsource = create(PackageSource, Package=package)
assert pkgsource.Package == package assert pkgsource.Package == package
# By default, PackageSources.Source assigns the string '/dev/null'. # By default, PackageSources.Source assigns the string '/dev/null'.
assert pkgsource.Source == "/dev/null" assert pkgsource.Source == "/dev/null"
@ -40,5 +42,6 @@ def test_package_source():
def test_package_source_null_package_raises_exception(): def test_package_source_null_package_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(PackageSource) with begin():
create(PackageSource)
rollback() rollback()

View file

@ -28,31 +28,25 @@ def package_endpoint(package: Package) -> str:
return f"/packages/{package.Name}" return f"/packages/{package.Name}"
def create_package(pkgname: str, maintainer: User, def create_package(pkgname: str, maintainer: User) -> Package:
autocommit: bool = True) -> Package:
pkgbase = db.create(PackageBase, pkgbase = db.create(PackageBase,
Name=pkgname, Name=pkgname,
Maintainer=maintainer, Maintainer=maintainer)
autocommit=False) return db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase)
return db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase,
autocommit=autocommit)
def create_package_dep(package: Package, depname: str, def create_package_dep(package: Package, depname: str,
dep_type_name: str = "depends", dep_type_name: str = "depends") -> PackageDependency:
autocommit: bool = True) -> PackageDependency:
dep_type = db.query(DependencyType, dep_type = db.query(DependencyType,
DependencyType.Name == dep_type_name).first() DependencyType.Name == dep_type_name).first()
return db.create(PackageDependency, return db.create(PackageDependency,
DependencyType=dep_type, DependencyType=dep_type,
Package=package, Package=package,
DepName=depname, DepName=depname)
autocommit=autocommit)
def create_package_rel(package: Package, def create_package_rel(package: Package,
relname: str, relname: str) -> PackageRelation:
autocommit: bool = True) -> PackageRelation:
rel_type = db.query(RelationType, rel_type = db.query(RelationType,
RelationType.ID == PROVIDES_ID).first() RelationType.ID == PROVIDES_ID).first()
return db.create(PackageRelation, return db.create(PackageRelation,
@ -84,31 +78,37 @@ def client() -> TestClient:
def user() -> User: def user() -> User:
""" Yield a user. """ """ Yield a user. """
account_type = db.query(AccountType, AccountType.ID == USER_ID).first() account_type = db.query(AccountType, AccountType.ID == USER_ID).first()
yield db.create(User, Username="test", with db.begin():
Email="test@example.org", user = db.create(User, Username="test",
Passwd="testPassword", Email="test@example.org",
AccountType=account_type) Passwd="testPassword",
AccountType=account_type)
yield user
@pytest.fixture @pytest.fixture
def maintainer() -> User: def maintainer() -> User:
""" Yield a specific User used to maintain packages. """ """ Yield a specific User used to maintain packages. """
account_type = db.query(AccountType, AccountType.ID == USER_ID).first() account_type = db.query(AccountType, AccountType.ID == USER_ID).first()
yield db.create(User, Username="test_maintainer", with db.begin():
Email="test_maintainer@example.org", maintainer = db.create(User, Username="test_maintainer",
Passwd="testPassword", Email="test_maintainer@example.org",
AccountType=account_type) Passwd="testPassword",
AccountType=account_type)
yield maintainer
@pytest.fixture @pytest.fixture
def package(maintainer: User) -> Package: def package(maintainer: User) -> Package:
""" Yield a Package created by user. """ """ Yield a Package created by user. """
pkgbase = db.create(PackageBase, with db.begin():
Name="test-package", pkgbase = db.create(PackageBase,
Maintainer=maintainer) Name="test-package",
yield db.create(Package, Maintainer=maintainer)
PackageBase=pkgbase, package = db.create(Package,
Name=pkgbase.Name) PackageBase=pkgbase,
Name=pkgbase.Name)
yield package
def test_package_not_found(client: TestClient): def test_package_not_found(client: TestClient):
@ -121,10 +121,11 @@ def test_package_official_not_found(client: TestClient, package: Package):
""" When a Package has a matching OfficialProvider record, it is not """ When a Package has a matching OfficialProvider record, it is not
hosted on AUR, but in the official repositories. Getting a package hosted on AUR, but in the official repositories. Getting a package
with this kind of record should return a status code 404. """ with this kind of record should return a status code 404. """
db.create(OfficialProvider, with db.begin():
Name=package.Name, db.create(OfficialProvider,
Repo="core", Name=package.Name,
Provides=package.Name) Repo="core",
Provides=package.Name)
with client as request: with client as request:
resp = request.get(package_endpoint(package)) resp = request.get(package_endpoint(package))
@ -157,8 +158,9 @@ def test_package(client: TestClient, package: Package):
def test_package_comments(client: TestClient, user: User, package: Package): def test_package_comments(client: TestClient, user: User, package: Package):
now = (datetime.utcnow().timestamp()) now = (datetime.utcnow().timestamp())
comment = db.create(PackageComment, PackageBase=package.PackageBase, with db.begin():
User=user, Comments="Test comment", CommentTS=now) comment = db.create(PackageComment, PackageBase=package.PackageBase,
User=user, Comments="Test comment", CommentTS=now)
cookies = {"AURSID": user.login(Request(), "testPassword")} cookies = {"AURSID": user.login(Request(), "testPassword")}
with client as request: with client as request:
@ -178,11 +180,12 @@ def test_package_comments(client: TestClient, user: User, package: Package):
def test_package_requests_display(client: TestClient, user: User, def test_package_requests_display(client: TestClient, user: User,
package: Package): package: Package):
type_ = db.query(RequestType, RequestType.ID == DELETION_ID).first() type_ = db.query(RequestType, RequestType.ID == DELETION_ID).first()
db.create(PackageRequest, PackageBase=package.PackageBase, with db.begin():
PackageBaseName=package.PackageBase.Name, db.create(PackageRequest, PackageBase=package.PackageBase,
User=user, RequestType=type_, PackageBaseName=package.PackageBase.Name,
Comments="Test comment.", User=user, RequestType=type_,
ClosureComment=str()) Comments="Test comment.",
ClosureComment=str())
# Test that a single request displays "1 pending request". # Test that a single request displays "1 pending request".
with client as request: with client as request:
@ -195,11 +198,12 @@ def test_package_requests_display(client: TestClient, user: User,
assert target.text.strip() == "1 pending request" assert target.text.strip() == "1 pending request"
type_ = db.query(RequestType, RequestType.ID == DELETION_ID).first() type_ = db.query(RequestType, RequestType.ID == DELETION_ID).first()
db.create(PackageRequest, PackageBase=package.PackageBase, with db.begin():
PackageBaseName=package.PackageBase.Name, db.create(PackageRequest, PackageBase=package.PackageBase,
User=user, RequestType=type_, PackageBaseName=package.PackageBase.Name,
Comments="Test comment2.", User=user, RequestType=type_,
ClosureComment=str()) Comments="Test comment2.",
ClosureComment=str())
# Test that a two requests display "2 pending requests". # Test that a two requests display "2 pending requests".
with client as request: with client as request:
@ -271,50 +275,43 @@ def test_package_authenticated_maintainer(client: TestClient,
def test_package_dependencies(client: TestClient, maintainer: User, def test_package_dependencies(client: TestClient, maintainer: User,
package: Package): package: Package):
# Create a normal dependency of type depends. # Create a normal dependency of type depends.
dep_pkg = create_package("test-dep-1", maintainer, autocommit=False) with db.begin():
dep = create_package_dep(package, dep_pkg.Name, autocommit=False) dep_pkg = create_package("test-dep-1", maintainer)
dep.DepArch = "x86_64" dep = create_package_dep(package, dep_pkg.Name)
dep.DepArch = "x86_64"
# Also, create a makedepends. # Also, create a makedepends.
make_dep_pkg = create_package("test-dep-2", maintainer, autocommit=False) make_dep_pkg = create_package("test-dep-2", maintainer)
make_dep = create_package_dep(package, make_dep_pkg.Name, make_dep = create_package_dep(package, make_dep_pkg.Name,
dep_type_name="makedepends", dep_type_name="makedepends")
autocommit=False)
# And... a checkdepends! # And... a checkdepends!
check_dep_pkg = create_package("test-dep-3", maintainer, autocommit=False) check_dep_pkg = create_package("test-dep-3", maintainer)
check_dep = create_package_dep(package, check_dep_pkg.Name, check_dep = create_package_dep(package, check_dep_pkg.Name,
dep_type_name="checkdepends", dep_type_name="checkdepends")
autocommit=False)
# Geez. Just stop. This is optdepends. # Geez. Just stop. This is optdepends.
opt_dep_pkg = create_package("test-dep-4", maintainer, autocommit=False) opt_dep_pkg = create_package("test-dep-4", maintainer)
opt_dep = create_package_dep(package, opt_dep_pkg.Name, opt_dep = create_package_dep(package, opt_dep_pkg.Name,
dep_type_name="optdepends", dep_type_name="optdepends")
autocommit=False)
# Heh. Another optdepends to test one with a description. # Heh. Another optdepends to test one with a description.
opt_desc_dep_pkg = create_package("test-dep-5", maintainer, opt_desc_dep_pkg = create_package("test-dep-5", maintainer)
autocommit=False) opt_desc_dep = create_package_dep(package, opt_desc_dep_pkg.Name,
opt_desc_dep = create_package_dep(package, opt_desc_dep_pkg.Name, dep_type_name="optdepends")
dep_type_name="optdepends", opt_desc_dep.DepDesc = "Test description."
autocommit=False)
opt_desc_dep.DepDesc = "Test description."
broken_dep = create_package_dep(package, "test-dep-6", broken_dep = create_package_dep(package, "test-dep-6",
dep_type_name="depends", dep_type_name="depends")
autocommit=False)
# Create an official provider record. # Create an official provider record.
db.create(OfficialProvider, Name="test-dep-99", db.create(OfficialProvider, Name="test-dep-99",
Repo="core", Provides="test-dep-99", Repo="core", Provides="test-dep-99")
autocommit=False) official_dep = create_package_dep(package, "test-dep-99")
official_dep = create_package_dep(package, "test-dep-99",
autocommit=False)
# Also, create a provider who provides our test-dep-99. # Also, create a provider who provides our test-dep-99.
provider = create_package("test-provider", maintainer, autocommit=False) provider = create_package("test-provider", maintainer)
create_package_rel(provider, dep.DepName) create_package_rel(provider, dep.DepName)
with client as request: with client as request:
resp = request.get(package_endpoint(package)) resp = request.get(package_endpoint(package))
@ -358,8 +355,9 @@ def test_pkgbase_redirect(client: TestClient, package: Package):
def test_pkgbase(client: TestClient, package: Package): def test_pkgbase(client: TestClient, package: Package):
second = db.create(Package, Name="second-pkg", with db.begin():
PackageBase=package.PackageBase) second = db.create(Package, Name="second-pkg",
PackageBase=package.PackageBase)
expected = [package.Name, second.Name] expected = [package.Name, second.Name]
with client as request: with client as request:

View file

@ -26,17 +26,21 @@ def setup():
@pytest.fixture @pytest.fixture
def maintainer() -> User: def maintainer() -> User:
account_type = db.query(AccountType, AccountType.ID == USER_ID).first() account_type = db.query(AccountType, AccountType.ID == USER_ID).first()
yield db.create(User, Username="test_maintainer", with db.begin():
Email="test_maintainer@examepl.org", maintainer = db.create(User, Username="test_maintainer",
Passwd="testPassword", Email="test_maintainer@examepl.org",
AccountType=account_type) Passwd="testPassword",
AccountType=account_type)
yield maintainer
@pytest.fixture @pytest.fixture
def package(maintainer: User) -> Package: def package(maintainer: User) -> Package:
pkgbase = db.create(PackageBase, Name="test-pkg", with db.begin():
Packager=maintainer, Maintainer=maintainer) pkgbase = db.create(PackageBase, Name="test-pkg",
yield db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase) Packager=maintainer, Maintainer=maintainer)
package = db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase)
yield package
@pytest.fixture @pytest.fixture
@ -45,10 +49,11 @@ def client() -> TestClient:
def test_package_link(client: TestClient, maintainer: User, package: Package): def test_package_link(client: TestClient, maintainer: User, package: Package):
db.create(OfficialProvider, with db.begin():
Name=package.Name, db.create(OfficialProvider,
Repo="core", Name=package.Name,
Provides=package.Name) Repo="core",
Provides=package.Name)
expected = f"{OFFICIAL_BASE}/packages/?q={package.Name}" expected = f"{OFFICIAL_BASE}/packages/?q={package.Name}"
assert util.package_link(package) == expected assert util.package_link(package) == expected

View file

@ -1,6 +1,6 @@
import pytest import pytest
from aurweb.db import create, delete, query from aurweb import db
from aurweb.models.relation_type import RelationType from aurweb.models.relation_type import RelationType
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -11,22 +11,25 @@ def setup():
def test_relation_type_creation(): def test_relation_type_creation():
relation_type = create(RelationType, Name="test-relation") with db.begin():
relation_type = db.create(RelationType, Name="test-relation")
assert bool(relation_type.ID) assert bool(relation_type.ID)
assert relation_type.Name == "test-relation" assert relation_type.Name == "test-relation"
delete(RelationType, RelationType.ID == relation_type.ID) with db.begin():
db.delete(RelationType, RelationType.ID == relation_type.ID)
def test_relation_types(): def test_relation_types():
conflicts = query(RelationType, RelationType.Name == "conflicts").first() conflicts = db.query(RelationType, RelationType.Name == "conflicts").first()
assert conflicts is not None assert conflicts is not None
assert conflicts.Name == "conflicts" assert conflicts.Name == "conflicts"
provides = query(RelationType, RelationType.Name == "provides").first() provides = db.query(RelationType, RelationType.Name == "provides").first()
assert provides is not None assert provides is not None
assert provides.Name == "provides" assert provides.Name == "provides"
replaces = query(RelationType, RelationType.Name == "replaces").first() replaces = db.query(RelationType, RelationType.Name == "replaces").first()
assert replaces is not None assert replaces is not None
assert replaces.Name == "replaces" assert replaces.Name == "replaces"

View file

@ -1,6 +1,6 @@
import pytest import pytest
from aurweb.db import create, delete, query from aurweb import db
from aurweb.models.request_type import DELETION_ID, MERGE_ID, ORPHAN_ID, RequestType from aurweb.models.request_type import DELETION_ID, MERGE_ID, ORPHAN_ID, RequestType
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -11,25 +11,33 @@ def setup():
def test_request_type_creation(): def test_request_type_creation():
request_type = create(RequestType, Name="Test Request") with db.begin():
request_type = db.create(RequestType, Name="Test Request")
assert bool(request_type.ID) assert bool(request_type.ID)
assert request_type.Name == "Test Request" assert request_type.Name == "Test Request"
delete(RequestType, RequestType.ID == request_type.ID)
with db.begin():
db.delete(RequestType, RequestType.ID == request_type.ID)
def test_request_type_null_name_returns_empty_string(): def test_request_type_null_name_returns_empty_string():
request_type = create(RequestType) with db.begin():
request_type = db.create(RequestType)
assert bool(request_type.ID) assert bool(request_type.ID)
assert request_type.Name == str() assert request_type.Name == str()
delete(RequestType, RequestType.ID == request_type.ID)
with db.begin():
db.delete(RequestType, RequestType.ID == request_type.ID)
def test_request_type_name_display(): def test_request_type_name_display():
deletion = query(RequestType, RequestType.ID == DELETION_ID).first() deletion = db.query(RequestType, RequestType.ID == DELETION_ID).first()
assert deletion.name_display() == "Deletion" assert deletion.name_display() == "Deletion"
orphan = query(RequestType, RequestType.ID == ORPHAN_ID).first() orphan = db.query(RequestType, RequestType.ID == ORPHAN_ID).first()
assert orphan.name_display() == "Orphan" assert orphan.name_display() == "Orphan"
merge = query(RequestType, RequestType.ID == MERGE_ID).first() merge = db.query(RequestType, RequestType.ID == MERGE_ID).first()
assert merge.name_display() == "Merge" assert merge.name_display() == "Merge"

View file

@ -8,8 +8,8 @@ import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from aurweb import db
from aurweb.asgi import app from aurweb.asgi import app
from aurweb.db import create, query
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.user import User from aurweb.models.user import User
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -24,11 +24,13 @@ def setup():
setup_test_db("Users", "Sessions") setup_test_db("Users", "Sessions")
account_type = query(AccountType, account_type = db.query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword", with db.begin():
AccountType=account_type) user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
client = TestClient(app) client = TestClient(app)

View file

@ -49,14 +49,13 @@ def packages(user):
now = int(datetime.utcnow().timestamp()) now = int(datetime.utcnow().timestamp())
# Create 101 packages; we limit 100 on RSS feeds. # Create 101 packages; we limit 100 on RSS feeds.
for i in range(101): with db.begin():
pkgbase = db.create( for i in range(101):
PackageBase, Maintainer=user, Name=f"test-package-{i}", pkgbase = db.create(
SubmittedTS=(now + i), ModifiedTS=(now + i), autocommit=False) PackageBase, Maintainer=user, Name=f"test-package-{i}",
pkg = db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase, SubmittedTS=(now + i), ModifiedTS=(now + i))
autocommit=False) pkg = db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase)
pkgs.append(pkg) pkgs.append(pkg)
db.commit()
yield pkgs yield pkgs

View file

@ -4,7 +4,7 @@ from unittest import mock
import pytest import pytest
from aurweb.db import create, query from aurweb import db
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.session import Session, generate_unique_sid from aurweb.models.session import Session, generate_unique_sid
from aurweb.models.user import User from aurweb.models.user import User
@ -19,13 +19,16 @@ def setup():
setup_test_db("Users", "Sessions") setup_test_db("Users", "Sessions")
account_type = query(AccountType, account_type = db.query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org", with db.begin():
ResetKey="testReset", Passwd="testPassword", user = db.create(User, Username="test", Email="test@example.org",
AccountType=account_type) ResetKey="testReset", Passwd="testPassword",
session = create(Session, UsersID=user.ID, SessionID="testSession", AccountType=account_type)
LastUpdateTS=datetime.utcnow().timestamp())
with db.begin():
session = db.create(Session, UsersID=user.ID, SessionID="testSession",
LastUpdateTS=datetime.utcnow().timestamp())
def test_session(): def test_session():
@ -35,12 +38,15 @@ def test_session():
def test_session_cs(): def test_session_cs():
""" Test case sensitivity of the database table. """ """ Test case sensitivity of the database table. """
user2 = create(User, Username="test2", Email="test2@example.org", with db.begin():
ResetKey="testReset2", Passwd="testPassword", user2 = db.create(User, Username="test2", Email="test2@example.org",
AccountType=account_type) ResetKey="testReset2", Passwd="testPassword",
session_cs = create(Session, UsersID=user2.ID, AccountType=account_type)
SessionID="TESTSESSION",
LastUpdateTS=datetime.utcnow().timestamp()) with db.begin():
session_cs = db.create(Session, UsersID=user2.ID,
SessionID="TESTSESSION",
LastUpdateTS=datetime.utcnow().timestamp())
assert session_cs.SessionID == "TESTSESSION" assert session_cs.SessionID == "TESTSESSION"
assert session.SessionID == "testSession" assert session.SessionID == "testSession"

View file

@ -1,6 +1,6 @@
import pytest import pytest
from aurweb.db import create, query from aurweb import db
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint
from aurweb.models.user import User from aurweb.models.user import User
@ -19,19 +19,18 @@ def setup():
setup_test_db("Users", "SSHPubKeys") setup_test_db("Users", "SSHPubKeys")
account_type = query(AccountType, account_type = db.query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org", with db.begin():
RealName="Test User", Passwd="testPassword", user = db.create(User, Username="test", Email="test@example.org",
AccountType=account_type) RealName="Test User", Passwd="testPassword",
AccountType=account_type)
assert account_type == user.AccountType with db.begin():
assert account_type.ID == user.AccountTypeID ssh_pub_key = db.create(SSHPubKey,
UserID=user.ID,
ssh_pub_key = create(SSHPubKey, Fingerprint="testFingerprint",
UserID=user.ID, PubKey="testPubKey")
Fingerprint="testFingerprint",
PubKey="testPubKey")
def test_ssh_pub_key(): def test_ssh_pub_key():
@ -43,9 +42,10 @@ def test_ssh_pub_key():
def test_ssh_pub_key_cs(): def test_ssh_pub_key_cs():
""" Test case sensitivity of the database table. """ """ Test case sensitivity of the database table. """
ssh_pub_key_cs = create(SSHPubKey, UserID=user.ID, with db.begin():
Fingerprint="TESTFINGERPRINT", ssh_pub_key_cs = db.create(SSHPubKey, UserID=user.ID,
PubKey="TESTPUBKEY") Fingerprint="TESTFINGERPRINT",
PubKey="TESTPUBKEY")
assert ssh_pub_key_cs.Fingerprint == "TESTFINGERPRINT" assert ssh_pub_key_cs.Fingerprint == "TESTFINGERPRINT"
assert ssh_pub_key_cs.PubKey == "TESTPUBKEY" assert ssh_pub_key_cs.PubKey == "TESTPUBKEY"

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import create from aurweb import db
from aurweb.models.term import Term from aurweb.models.term import Term
from aurweb.testing import setup_test_db from aurweb.testing import setup_test_db
@ -18,8 +18,9 @@ def setup():
def test_term_creation(): def test_term_creation():
term = create(Term, Description="Term description", with db.begin():
URL="https://fake_url.io") term = db.create(Term, Description="Term description",
URL="https://fake_url.io")
assert bool(term.ID) assert bool(term.ID)
assert term.Description == "Term description" assert term.Description == "Term description"
assert term.URL == "https://fake_url.io" assert term.URL == "https://fake_url.io"
@ -27,14 +28,14 @@ def test_term_creation():
def test_term_null_description_raises_exception(): def test_term_null_description_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(Term, URL="https://fake_url.io") with db.begin():
session.rollback() db.create(Term, URL="https://fake_url.io")
db.rollback()
def test_term_null_url_raises_exception(): def test_term_null_url_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(Term, Description="Term description") with db.begin():
session.rollback() db.create(Term, Description="Term description")
db.rollback()

View file

@ -90,37 +90,37 @@ def client():
def tu_user(): def tu_user():
tu_type = db.query(AccountType, tu_type = db.query(AccountType,
AccountType.AccountType == "Trusted User").first() AccountType.AccountType == "Trusted User").first()
yield db.create(User, Username="test_tu", Email="test_tu@example.org", with db.begin():
RealName="Test TU", Passwd="testPassword", tu_user = db.create(User, Username="test_tu",
AccountType=tu_type) Email="test_tu@example.org",
RealName="Test TU", Passwd="testPassword",
AccountType=tu_type)
yield tu_user
@pytest.fixture @pytest.fixture
def user(): def user():
user_type = db.query(AccountType, user_type = db.query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
yield db.create(User, Username="test", Email="test@example.org", with db.begin():
RealName="Test User", Passwd="testPassword", user = db.create(User, Username="test", Email="test@example.org",
AccountType=user_type) RealName="Test User", Passwd="testPassword",
AccountType=user_type)
yield user
@pytest.fixture @pytest.fixture
def proposal(tu_user): def proposal(user, tu_user):
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
agenda = "Test proposal." agenda = "Test proposal."
start = ts - 5 start = ts - 5
end = ts + 1000 end = ts + 1000
user_type = db.query(AccountType, with db.begin():
AccountType.AccountType == "User").first() voteinfo = db.create(TUVoteInfo,
user = db.create(User, Username="test", Email="test@example.org", Agenda=agenda, Quorum=0.0,
RealName="Test User", Passwd="testPassword", User=user.Username, Submitter=tu_user,
AccountType=user_type) Submitted=start, End=end)
voteinfo = db.create(TUVoteInfo,
Agenda=agenda, Quorum=0.0,
User=user.Username, Submitter=tu_user,
Submitted=start, End=end)
yield (tu_user, user, voteinfo) yield (tu_user, user, voteinfo)
@ -170,20 +170,22 @@ def test_tu_index(client, tu_user):
("Test agenda 2", ts - 1000, ts - 5) # Not running anymore. ("Test agenda 2", ts - 1000, ts - 5) # Not running anymore.
] ]
vote_records = [] vote_records = []
for vote in votes: with db.begin():
agenda, start, end = vote for vote in votes:
vote_records.append( agenda, start, end = vote
db.create(TUVoteInfo, Agenda=agenda, vote_records.append(
User=tu_user.Username, db.create(TUVoteInfo, Agenda=agenda,
Submitted=start, End=end, User=tu_user.Username,
Quorum=0.0, Submitted=start, End=end,
Submitter=tu_user)) Quorum=0.0,
Submitter=tu_user))
# Vote on an ended proposal. with db.begin():
vote_record = vote_records[1] # Vote on an ended proposal.
vote_record.Yes += 1 vote_record = vote_records[1]
vote_record.ActiveTUs += 1 vote_record.Yes += 1
db.create(TUVote, VoteInfo=vote_record, User=tu_user) vote_record.ActiveTUs += 1
db.create(TUVote, VoteInfo=vote_record, User=tu_user)
cookies = {"AURSID": tu_user.login(Request(), "testPassword")} cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request: with client as request:
@ -255,22 +257,22 @@ def test_tu_index(client, tu_user):
def test_tu_index_table_paging(client, tu_user): def test_tu_index_table_paging(client, tu_user):
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
for i in range(25): with db.begin():
# Create 25 current votes. for i in range(25):
db.create(TUVoteInfo, Agenda=f"Agenda #{i}", # Create 25 current votes.
User=tu_user.Username, db.create(TUVoteInfo, Agenda=f"Agenda #{i}",
Submitted=(ts - 5), End=(ts + 1000), User=tu_user.Username,
Quorum=0.0, Submitted=(ts - 5), End=(ts + 1000),
Submitter=tu_user, autocommit=False) Quorum=0.0,
Submitter=tu_user)
for i in range(25): for i in range(25):
# Create 25 past votes. # Create 25 past votes.
db.create(TUVoteInfo, Agenda=f"Agenda #{25 + i}", db.create(TUVoteInfo, Agenda=f"Agenda #{25 + i}",
User=tu_user.Username, User=tu_user.Username,
Submitted=(ts - 1000), End=(ts - 5), Submitted=(ts - 1000), End=(ts - 5),
Quorum=0.0, Quorum=0.0,
Submitter=tu_user, autocommit=False) Submitter=tu_user)
db.commit()
cookies = {"AURSID": tu_user.login(Request(), "testPassword")} cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request: with client as request:
@ -363,18 +365,19 @@ def test_tu_index_table_paging(client, tu_user):
def test_tu_index_sorting(client, tu_user): def test_tu_index_sorting(client, tu_user):
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
for i in range(2): with db.begin():
# Create 'Agenda #1' and 'Agenda #2'. for i in range(2):
db.create(TUVoteInfo, Agenda=f"Agenda #{i + 1}", # Create 'Agenda #1' and 'Agenda #2'.
User=tu_user.Username, db.create(TUVoteInfo, Agenda=f"Agenda #{i + 1}",
Submitted=(ts + 5), End=(ts + 1000), User=tu_user.Username,
Quorum=0.0, Submitted=(ts + 5), End=(ts + 1000),
Submitter=tu_user, autocommit=False) Quorum=0.0,
Submitter=tu_user)
# Let's order each vote one day after the other. # Let's order each vote one day after the other.
# This will allow us to test the sorting nature # This will allow us to test the sorting nature
# of the tables. # of the tables.
ts += 86405 ts += 86405
# Make a default request to /tu. # Make a default request to /tu.
cookies = {"AURSID": tu_user.login(Request(), "testPassword")} cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
@ -432,18 +435,19 @@ def test_tu_index_sorting(client, tu_user):
def test_tu_index_last_votes(client, tu_user, user): def test_tu_index_last_votes(client, tu_user, user):
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
# Create a proposal which has ended. with db.begin():
voteinfo = db.create(TUVoteInfo, Agenda="Test agenda", # Create a proposal which has ended.
User=user.Username, voteinfo = db.create(TUVoteInfo, Agenda="Test agenda",
Submitted=(ts - 1000), User=user.Username,
End=(ts - 5), Submitted=(ts - 1000),
Yes=1, End=(ts - 5),
ActiveTUs=1, Yes=1,
Quorum=0.0, ActiveTUs=1,
Submitter=tu_user) Quorum=0.0,
Submitter=tu_user)
# Create a vote on it from tu_user. # Create a vote on it from tu_user.
db.create(TUVote, VoteInfo=voteinfo, User=tu_user) db.create(TUVote, VoteInfo=voteinfo, User=tu_user)
# Now, check that tu_user got populated in the .last-votes table. # Now, check that tu_user got populated in the .last-votes table.
cookies = {"AURSID": tu_user.login(Request(), "testPassword")} cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
@ -529,10 +533,10 @@ def test_tu_running_proposal(client, proposal):
assert abstain.attrib["value"] == "Abstain" assert abstain.attrib["value"] == "Abstain"
# Create a vote. # Create a vote.
db.create(TUVote, VoteInfo=voteinfo, User=tu_user) with db.begin():
voteinfo.ActiveTUs += 1 db.create(TUVote, VoteInfo=voteinfo, User=tu_user)
voteinfo.Yes += 1 voteinfo.ActiveTUs += 1
db.commit() voteinfo.Yes += 1
# Make another request now that we've voted. # Make another request now that we've voted.
with client as request: with client as request:
@ -556,8 +560,8 @@ def test_tu_ended_proposal(client, proposal):
tu_user, user, voteinfo = proposal tu_user, user, voteinfo = proposal
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
voteinfo.End = ts - 5 # 5 seconds ago. with db.begin():
db.commit() voteinfo.End = ts - 5 # 5 seconds ago.
# Initiate an authenticated GET request to /tu/{proposal_id}. # Initiate an authenticated GET request to /tu/{proposal_id}.
proposal_id = voteinfo.ID proposal_id = voteinfo.ID
@ -635,8 +639,8 @@ def test_tu_proposal_vote_unauthorized(client, proposal):
dev_type = db.query(AccountType, dev_type = db.query(AccountType,
AccountType.AccountType == "Developer").first() AccountType.AccountType == "Developer").first()
tu_user.AccountType = dev_type with db.begin():
db.commit() tu_user.AccountType = dev_type
cookies = {"AURSID": tu_user.login(Request(), "testPassword")} cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request: with client as request:
@ -664,8 +668,8 @@ def test_tu_proposal_vote_cant_self_vote(client, proposal):
tu_user, user, voteinfo = proposal tu_user, user, voteinfo = proposal
# Update voteinfo.User. # Update voteinfo.User.
voteinfo.User = tu_user.Username with db.begin():
db.commit() voteinfo.User = tu_user.Username
cookies = {"AURSID": tu_user.login(Request(), "testPassword")} cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request: with client as request:
@ -692,10 +696,10 @@ def test_tu_proposal_vote_cant_self_vote(client, proposal):
def test_tu_proposal_vote_already_voted(client, proposal): def test_tu_proposal_vote_already_voted(client, proposal):
tu_user, user, voteinfo = proposal tu_user, user, voteinfo = proposal
db.create(TUVote, VoteInfo=voteinfo, User=tu_user) with db.begin():
voteinfo.Yes += 1 db.create(TUVote, VoteInfo=voteinfo, User=tu_user)
voteinfo.ActiveTUs += 1 voteinfo.Yes += 1
db.commit() voteinfo.ActiveTUs += 1
cookies = {"AURSID": tu_user.login(Request(), "testPassword")} cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request: with client as request:

View file

@ -4,7 +4,8 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from aurweb.db import commit, create, query, rollback from aurweb import db
from aurweb.db import create, query, rollback
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.tu_voteinfo import TUVoteInfo from aurweb.models.tu_voteinfo import TUVoteInfo
from aurweb.models.user import User from aurweb.models.user import User
@ -21,19 +22,21 @@ def setup():
tu_type = query(AccountType, tu_type = query(AccountType,
AccountType.AccountType == "Trusted User").first() AccountType.AccountType == "Trusted User").first()
user = create(User, Username="test", Email="test@example.org", with db.begin():
RealName="Test User", Passwd="testPassword", user = create(User, Username="test", Email="test@example.org",
AccountType=tu_type) RealName="Test User", Passwd="testPassword",
AccountType=tu_type)
def test_tu_voteinfo_creation(): def test_tu_voteinfo_creation():
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
tu_voteinfo = create(TUVoteInfo, with db.begin():
Agenda="Blah blah.", tu_voteinfo = create(TUVoteInfo,
User=user.Username, Agenda="Blah blah.",
Submitted=ts, End=ts + 5, User=user.Username,
Quorum=0.5, Submitted=ts, End=ts + 5,
Submitter=user) Quorum=0.5,
Submitter=user)
assert bool(tu_voteinfo.ID) assert bool(tu_voteinfo.ID)
assert tu_voteinfo.Agenda == "Blah blah." assert tu_voteinfo.Agenda == "Blah blah."
assert tu_voteinfo.User == user.Username assert tu_voteinfo.User == user.Username
@ -51,32 +54,33 @@ def test_tu_voteinfo_creation():
def test_tu_voteinfo_is_running(): def test_tu_voteinfo_is_running():
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
tu_voteinfo = create(TUVoteInfo, with db.begin():
Agenda="Blah blah.", tu_voteinfo = create(TUVoteInfo,
User=user.Username, Agenda="Blah blah.",
Submitted=ts, End=ts + 1000, User=user.Username,
Quorum=0.5, Submitted=ts, End=ts + 1000,
Submitter=user) Quorum=0.5,
Submitter=user)
assert tu_voteinfo.is_running() is True assert tu_voteinfo.is_running() is True
tu_voteinfo.End = ts - 5 with db.begin():
commit() tu_voteinfo.End = ts - 5
assert tu_voteinfo.is_running() is False assert tu_voteinfo.is_running() is False
def test_tu_voteinfo_total_votes(): def test_tu_voteinfo_total_votes():
ts = int(datetime.utcnow().timestamp()) ts = int(datetime.utcnow().timestamp())
tu_voteinfo = create(TUVoteInfo, with db.begin():
Agenda="Blah blah.", tu_voteinfo = create(TUVoteInfo,
User=user.Username, Agenda="Blah blah.",
Submitted=ts, End=ts + 1000, User=user.Username,
Quorum=0.5, Submitted=ts, End=ts + 1000,
Submitter=user) Quorum=0.5,
Submitter=user)
tu_voteinfo.Yes = 1 tu_voteinfo.Yes = 1
tu_voteinfo.No = 3 tu_voteinfo.No = 3
tu_voteinfo.Abstain = 5 tu_voteinfo.Abstain = 5
commit()
# total_votes() should be the sum of Yes, No and Abstain: 1 + 3 + 5 = 9. # total_votes() should be the sum of Yes, No and Abstain: 1 + 3 + 5 = 9.
assert tu_voteinfo.total_votes() == 9 assert tu_voteinfo.total_votes() == 9
@ -84,61 +88,67 @@ def test_tu_voteinfo_total_votes():
def test_tu_voteinfo_null_submitter_raises_exception(): def test_tu_voteinfo_null_submitter_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(TUVoteInfo, with db.begin():
Agenda="Blah blah.", create(TUVoteInfo,
User=user.Username, Agenda="Blah blah.",
Submitted=0, End=0, User=user.Username,
Quorum=0.50) Submitted=0, End=0,
Quorum=0.50)
rollback() rollback()
def test_tu_voteinfo_null_agenda_raises_exception(): def test_tu_voteinfo_null_agenda_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(TUVoteInfo, with db.begin():
User=user.Username, create(TUVoteInfo,
Submitted=0, End=0, User=user.Username,
Quorum=0.50, Submitted=0, End=0,
Submitter=user) Quorum=0.50,
Submitter=user)
rollback() rollback()
def test_tu_voteinfo_null_user_raises_exception(): def test_tu_voteinfo_null_user_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(TUVoteInfo, with db.begin():
Agenda="Blah blah.", create(TUVoteInfo,
Submitted=0, End=0, Agenda="Blah blah.",
Quorum=0.50, Submitted=0, End=0,
Submitter=user) Quorum=0.50,
Submitter=user)
rollback() rollback()
def test_tu_voteinfo_null_submitted_raises_exception(): def test_tu_voteinfo_null_submitted_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(TUVoteInfo, with db.begin():
Agenda="Blah blah.", create(TUVoteInfo,
User=user.Username, Agenda="Blah blah.",
End=0, User=user.Username,
Quorum=0.50, End=0,
Submitter=user) Quorum=0.50,
Submitter=user)
rollback() rollback()
def test_tu_voteinfo_null_end_raises_exception(): def test_tu_voteinfo_null_end_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(TUVoteInfo, with db.begin():
Agenda="Blah blah.", create(TUVoteInfo,
User=user.Username, Agenda="Blah blah.",
Submitted=0, User=user.Username,
Quorum=0.50, Submitted=0,
Submitter=user) Quorum=0.50,
Submitter=user)
rollback() rollback()
def test_tu_voteinfo_null_quorum_raises_exception(): def test_tu_voteinfo_null_quorum_raises_exception():
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
create(TUVoteInfo, with db.begin():
Agenda="Blah blah.", create(TUVoteInfo,
User=user.Username, Agenda="Blah blah.",
Submitted=0, End=0, User=user.Username,
Submitter=user) Submitted=0, End=0,
Submitter=user)
rollback() rollback()

View file

@ -9,7 +9,7 @@ import pytest
import aurweb.auth import aurweb.auth
import aurweb.config import aurweb.config
from aurweb.db import commit, create, query from aurweb import db
from aurweb.models.account_type import AccountType from aurweb.models.account_type import AccountType
from aurweb.models.ban import Ban from aurweb.models.ban import Ban
from aurweb.models.package import Package from aurweb.models.package import Package
@ -40,12 +40,13 @@ def setup():
PackageNotification.__tablename__ PackageNotification.__tablename__
) )
account_type = query(AccountType, account_type = db.query(AccountType,
AccountType.AccountType == "User").first() AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org", with db.begin():
RealName="Test User", Passwd="testPassword", user = db.create(User, Username="test", Email="test@example.org",
AccountType=account_type) RealName="Test User", Passwd="testPassword",
AccountType=account_type)
def test_user_login_logout(): def test_user_login_logout():
@ -70,14 +71,14 @@ def test_user_login_logout():
assert "AURSID" in request.cookies assert "AURSID" in request.cookies
# Expect that User session relationships work right. # Expect that User session relationships work right.
user_session = query(Session, user_session = db.query(Session,
Session.UsersID == user.ID).first() Session.UsersID == user.ID).first()
assert user_session == user.session assert user_session == user.session
assert user.session.SessionID == sid assert user.session.SessionID == sid
assert user.session.User == user assert user.session.User == user
# Search for the user via query API. # Search for the user via query API.
result = query(User, User.ID == user.ID).first() result = db.query(User, User.ID == user.ID).first()
# Compare the result and our original user. # Compare the result and our original user.
assert result == user assert result == user
@ -114,7 +115,8 @@ def test_user_login_twice():
def test_user_login_banned(): def test_user_login_banned():
# Add ban for the next 30 seconds. # Add ban for the next 30 seconds.
banned_timestamp = datetime.utcnow() + timedelta(seconds=30) banned_timestamp = datetime.utcnow() + timedelta(seconds=30)
create(Ban, IPAddress="127.0.0.1", BanTS=banned_timestamp) with db.begin():
db.create(Ban, IPAddress="127.0.0.1", BanTS=banned_timestamp)
request = Request() request = Request()
request.client.host = "127.0.0.1" request.client.host = "127.0.0.1"
@ -122,18 +124,17 @@ def test_user_login_banned():
def test_user_login_suspended(): def test_user_login_suspended():
from aurweb.db import session with db.begin():
user.Suspended = True user.Suspended = True
session.commit()
assert not user.login(Request(), "testPassword") assert not user.login(Request(), "testPassword")
def test_legacy_user_authentication(): def test_legacy_user_authentication():
from aurweb.db import session with db.begin():
user.Salt = bcrypt.gensalt().decode()
user.Salt = bcrypt.gensalt().decode() user.Passwd = hashlib.md5(
user.Passwd = hashlib.md5(f"{user.Salt}testPassword".encode()).hexdigest() f"{user.Salt}testPassword".encode()
session.commit() ).hexdigest()
assert not user.valid_password("badPassword") assert not user.valid_password("badPassword")
assert user.valid_password("testPassword") assert user.valid_password("testPassword")
@ -145,8 +146,9 @@ def test_legacy_user_authentication():
def test_user_login_with_outdated_sid(): def test_user_login_with_outdated_sid():
# Make a session with a LastUpdateTS 5 seconds ago, causing # Make a session with a LastUpdateTS 5 seconds ago, causing
# user.login to update it with a new sid. # user.login to update it with a new sid.
create(Session, UsersID=user.ID, SessionID="stub", with db.begin():
LastUpdateTS=datetime.utcnow().timestamp() - 5) db.create(Session, UsersID=user.ID, SessionID="stub",
LastUpdateTS=datetime.utcnow().timestamp() - 5)
sid = user.login(Request(), "testPassword") sid = user.login(Request(), "testPassword")
assert sid and user.is_authenticated() assert sid and user.is_authenticated()
assert sid != "stub" assert sid != "stub"
@ -171,43 +173,42 @@ def test_user_has_credential():
def test_user_ssh_pub_key(): def test_user_ssh_pub_key():
assert user.ssh_pub_key is None assert user.ssh_pub_key is None
ssh_pub_key = create(SSHPubKey, UserID=user.ID, with db.begin():
Fingerprint="testFingerprint", ssh_pub_key = db.create(SSHPubKey, UserID=user.ID,
PubKey="testPubKey") Fingerprint="testFingerprint",
PubKey="testPubKey")
assert user.ssh_pub_key == ssh_pub_key assert user.ssh_pub_key == ssh_pub_key
def test_user_credential_types(): def test_user_credential_types():
from aurweb.db import session
assert aurweb.auth.user_developer_or_trusted_user(user) assert aurweb.auth.user_developer_or_trusted_user(user)
assert not aurweb.auth.trusted_user(user) assert not aurweb.auth.trusted_user(user)
assert not aurweb.auth.developer(user) assert not aurweb.auth.developer(user)
assert not aurweb.auth.trusted_user_or_dev(user) assert not aurweb.auth.trusted_user_or_dev(user)
trusted_user_type = query(AccountType, trusted_user_type = db.query(AccountType).filter(
AccountType.AccountType == "Trusted User")\ AccountType.AccountType == "Trusted User"
.first() ).first()
user.AccountType = trusted_user_type with db.begin():
session.commit() user.AccountType = trusted_user_type
assert aurweb.auth.trusted_user(user) assert aurweb.auth.trusted_user(user)
assert aurweb.auth.trusted_user_or_dev(user) assert aurweb.auth.trusted_user_or_dev(user)
developer_type = query(AccountType, developer_type = db.query(AccountType,
AccountType.AccountType == "Developer").first() AccountType.AccountType == "Developer").first()
user.AccountType = developer_type with db.begin():
session.commit() user.AccountType = developer_type
assert aurweb.auth.developer(user) assert aurweb.auth.developer(user)
assert aurweb.auth.trusted_user_or_dev(user) assert aurweb.auth.trusted_user_or_dev(user)
type_str = "Trusted User & Developer" type_str = "Trusted User & Developer"
elevated_type = query(AccountType, elevated_type = db.query(AccountType,
AccountType.AccountType == type_str).first() AccountType.AccountType == type_str).first()
user.AccountType = elevated_type with db.begin():
session.commit() user.AccountType = elevated_type
assert aurweb.auth.trusted_user(user) assert aurweb.auth.trusted_user(user)
assert aurweb.auth.developer(user) assert aurweb.auth.developer(user)
@ -233,53 +234,56 @@ def test_user_as_dict():
def test_user_is_trusted_user(): def test_user_is_trusted_user():
tu_type = query(AccountType, tu_type = db.query(AccountType,
AccountType.AccountType == "Trusted User").first() AccountType.AccountType == "Trusted User").first()
user.AccountType = tu_type with db.begin():
commit() user.AccountType = tu_type
assert user.is_trusted_user() is True assert user.is_trusted_user() is True
# Do it again with the combined role. # Do it again with the combined role.
tu_type = query( tu_type = db.query(
AccountType, AccountType,
AccountType.AccountType == "Trusted User & Developer").first() AccountType.AccountType == "Trusted User & Developer").first()
user.AccountType = tu_type with db.begin():
commit() user.AccountType = tu_type
assert user.is_trusted_user() is True assert user.is_trusted_user() is True
def test_user_is_developer(): def test_user_is_developer():
dev_type = query(AccountType, dev_type = db.query(AccountType,
AccountType.AccountType == "Developer").first() AccountType.AccountType == "Developer").first()
user.AccountType = dev_type with db.begin():
commit() user.AccountType = dev_type
assert user.is_developer() is True assert user.is_developer() is True
# Do it again with the combined role. # Do it again with the combined role.
dev_type = query( dev_type = db.query(
AccountType, AccountType,
AccountType.AccountType == "Trusted User & Developer").first() AccountType.AccountType == "Trusted User & Developer").first()
user.AccountType = dev_type with db.begin():
commit() user.AccountType = dev_type
assert user.is_developer() is True assert user.is_developer() is True
def test_user_voted_for(): def test_user_voted_for():
now = int(datetime.utcnow().timestamp()) now = int(datetime.utcnow().timestamp())
pkgbase = create(PackageBase, Name="pkg1", Maintainer=user) with db.begin():
pkg = create(Package, PackageBase=pkgbase, Name=pkgbase.Name) pkgbase = db.create(PackageBase, Name="pkg1", Maintainer=user)
create(PackageVote, PackageBase=pkgbase, User=user, VoteTS=now) pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
db.create(PackageVote, PackageBase=pkgbase, User=user, VoteTS=now)
assert user.voted_for(pkg) assert user.voted_for(pkg)
def test_user_notified(): def test_user_notified():
pkgbase = create(PackageBase, Name="pkg1", Maintainer=user) with db.begin():
pkg = create(Package, PackageBase=pkgbase, Name=pkgbase.Name) pkgbase = db.create(PackageBase, Name="pkg1", Maintainer=user)
create(PackageNotification, PackageBase=pkgbase, User=user) pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
db.create(PackageNotification, PackageBase=pkgbase, User=user)
assert user.notified(pkg) assert user.notified(pkg)
def test_user_packages(): def test_user_packages():
pkgbase = create(PackageBase, Name="pkg1", Maintainer=user) with db.begin():
pkg = create(Package, PackageBase=pkgbase, Name=pkgbase.Name) pkgbase = db.create(PackageBase, Name="pkg1", Maintainer=user)
pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
assert pkg in user.packages() assert pkg in user.packages()