[FastAPI] Refactor db modifications

For SQLAlchemy to automatically understand updates from the
external world, it must use an `autocommit=True` in its session.

This change breaks how we were using commit previously, as
`autocommit=True` causes SQLAlchemy to commit when a
SessionTransaction context hits __exit__.

So, a refactoring was required of our tests: All usage of
any `db.{create,delete}` must be called **within** a
SessionTransaction context, created via new `db.begin()`.

From this point forward, we're going to require:

```
with db.begin():
    db.create(...)
    db.delete(...)
    db.session.delete(object)
```

With this, we now get external DB modifications automatically
without reloading or restarting the FastAPI server, which we
absolutely need for production.

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2021-09-02 16:26:48 -07:00
parent b52059d437
commit a5943bf2ad
No known key found for this signature in database
GPG key ID: F7E46DED420788F3
37 changed files with 998 additions and 902 deletions

View file

@ -59,20 +59,15 @@ def query(model, *args, **kwargs):
return session.query(model).filter(*args, **kwargs)
def create(model, autocommit: bool = True, *args, **kwargs):
def create(model, *args, **kwargs):
instance = model(*args, **kwargs)
add(instance)
if autocommit is True:
commit()
return instance
return add(instance)
def delete(model, *args, autocommit: bool = True, **kwargs):
def delete(model, *args, **kwargs):
instance = session.query(model).filter(*args, **kwargs)
for record in instance:
session.delete(record)
if autocommit is True:
commit()
def rollback():
@ -84,8 +79,25 @@ def add(model):
return model
def commit():
session.commit()
def begin():
""" Begin an SQLAlchemy SessionTransaction.
This context is **required** to perform an modifications to the
database.
Example:
with db.begin():
object = db.create(...)
# On __exit__, db.commit() is run.
with db.begin():
object = db.delete(...)
# On __exit__, db.commit() is run.
:return: A new SessionTransaction based on session
"""
return session.begin()
def get_sqlalchemy_url():
@ -155,23 +167,23 @@ def get_engine(echo: bool = False):
connect_args=connect_args,
echo=echo)
Session = sessionmaker(autocommit=True, autoflush=False, bind=engine)
session = Session()
if db_backend == "sqlite":
# For SQLite, we need to add some custom functions as
# they are used in the reference graph method.
def regexp(regex, item):
return bool(re.search(regex, str(item)))
@event.listens_for(engine, "begin")
def do_begin(conn):
@event.listens_for(engine, "connect")
def do_begin(conn, record):
create_deterministic_function = functools.partial(
conn.connection.create_function,
conn.create_function,
deterministic=True
)
create_deterministic_function("REGEXP", 2, regexp)
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
return engine

View file

@ -102,7 +102,7 @@ class User(Base):
def login(self, request: Request, password: str, session_time=0):
""" Login and authenticate a request. """
from aurweb.db import session
from aurweb import db
from aurweb.models.session import Session, generate_unique_sid
if not self._login_approved(request):
@ -112,10 +112,7 @@ class User(Base):
if not self.authenticated:
return None
self.LastLogin = now_ts = datetime.utcnow().timestamp()
self.LastLoginIPAddress = request.client.host
session.commit()
now_ts = datetime.utcnow().timestamp()
session_ts = now_ts + (
session_time if session_time
else aurweb.config.getint("options", "login_timeout")
@ -123,22 +120,23 @@ class User(Base):
sid = None
if not self.session:
sid = generate_unique_sid()
self.session = Session(UsersID=self.ID, SessionID=sid,
LastUpdateTS=session_ts)
session.add(self.session)
else:
last_updated = self.session.LastUpdateTS
if last_updated and last_updated < now_ts:
self.session.SessionID = sid = generate_unique_sid()
with db.begin():
self.LastLogin = now_ts
self.LastLoginIPAddress = request.client.host
if not self.session:
sid = generate_unique_sid()
self.session = Session(UsersID=self.ID, SessionID=sid,
LastUpdateTS=session_ts)
db.add(self.session)
else:
# Session is still valid; retrieve the current SID.
sid = self.session.SessionID
last_updated = self.session.LastUpdateTS
if last_updated and last_updated < now_ts:
self.session.SessionID = sid = generate_unique_sid()
else:
# Session is still valid; retrieve the current SID.
sid = self.session.SessionID
self.session.LastUpdateTS = session_ts
session.commit()
self.session.LastUpdateTS = session_ts
request.cookies["AURSID"] = self.session.SessionID
return self.session.SessionID
@ -149,13 +147,11 @@ class User(Base):
return aurweb.auth.has_credential(self, cred, approved)
def logout(self, request):
from aurweb.db import session
del request.cookies["AURSID"]
self.authenticated = False
if self.session:
session.delete(self.session)
session.commit()
with db.begin():
db.session.delete(self.session)
def is_trusted_user(self):
return self.AccountType.ID in {

View file

@ -43,8 +43,6 @@ async def passreset_post(request: Request,
resetkey: str = Form(default=None),
password: str = Form(default=None),
confirm: str = Form(default=None)):
from aurweb.db import session
context = await make_variable_context(request, "Password Reset")
# The user parameter being required, we can match against
@ -86,12 +84,11 @@ async def passreset_post(request: Request,
# We got to this point; everything matched up. Update the password
# and remove the ResetKey.
user.ResetKey = str()
user.update_password(password)
if user.session:
session.delete(user.session)
session.commit()
with db.begin():
user.ResetKey = str()
if user.session:
db.session.delete(user.session)
user.update_password(password)
# Render ?step=complete.
return RedirectResponse(url="/passreset?step=complete",
@ -99,8 +96,8 @@ async def passreset_post(request: Request,
# If we got here, we continue with issuing a resetkey for the user.
resetkey = db.make_random_value(User, User.ResetKey)
user.ResetKey = resetkey
session.commit()
with db.begin():
user.ResetKey = resetkey
executor = db.ConnectionExecutor(db.get_engine().raw_connection())
ResetKeyNotification(executor, user.ID).send()
@ -364,8 +361,6 @@ async def account_register_post(request: Request,
ON: bool = Form(default=False),
captcha: str = Form(default=None),
captcha_salt: str = Form(...)):
from aurweb.db import session
context = await make_variable_context(request, "Register")
args = dict(await request.form())
@ -394,11 +389,13 @@ async def account_register_post(request: Request,
AccountType.AccountType == "User").first()
# Create a user given all parameters available.
user = db.create(User, Username=U, Email=E, HideEmail=H, BackupEmail=BE,
RealName=R, Homepage=HP, IRCNick=I, PGPKey=K,
LangPreference=L, Timezone=TZ, CommentNotify=CN,
UpdateNotify=UN, OwnershipNotify=ON, ResetKey=resetkey,
AccountType=account_type)
with db.begin():
user = db.create(User, Username=U,
Email=E, HideEmail=H, BackupEmail=BE,
RealName=R, Homepage=HP, IRCNick=I, PGPKey=K,
LangPreference=L, Timezone=TZ, CommentNotify=CN,
UpdateNotify=UN, OwnershipNotify=ON,
ResetKey=resetkey, AccountType=account_type)
# If a PK was given and either one does not exist or the given
# PK mismatches the existing user's SSHPubKey.PubKey.
@ -410,10 +407,10 @@ async def account_register_post(request: Request,
# Remove the host part.
pubkey = parts[0] + " " + parts[1]
fingerprint = get_fingerprint(pubkey)
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
PubKey=pubkey,
Fingerprint=fingerprint)
session.commit()
with db.begin():
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
PubKey=pubkey,
Fingerprint=fingerprint)
# Send a reset key notification to the new user.
executor = db.ConnectionExecutor(db.get_engine().raw_connection())
@ -499,63 +496,67 @@ async def account_edit_post(request: Request,
status_code=int(HTTPStatus.BAD_REQUEST))
# Set all updated fields as needed.
user.Username = U or user.Username
user.Email = E or user.Email
user.HideEmail = bool(H)
user.BackupEmail = BE or user.BackupEmail
user.RealName = R or user.RealName
user.Homepage = HP or user.Homepage
user.IRCNick = I or user.IRCNick
user.PGPKey = K or user.PGPKey
user.InactivityTS = datetime.utcnow().timestamp() if J else 0
with db.begin():
user.Username = U or user.Username
user.Email = E or user.Email
user.HideEmail = bool(H)
user.BackupEmail = BE or user.BackupEmail
user.RealName = R or user.RealName
user.Homepage = HP or user.Homepage
user.IRCNick = I or user.IRCNick
user.PGPKey = K or user.PGPKey
user.InactivityTS = datetime.utcnow().timestamp() if J else 0
# If we update the language, update the cookie as well.
if L and L != user.LangPreference:
request.cookies["AURLANG"] = L
user.LangPreference = L
with db.begin():
user.LangPreference = L
context["language"] = L
# If we update the timezone, also update the cookie.
if TZ and TZ != user.Timezone:
user.Timezone = TZ
with db.begin():
user.Timezone = TZ
request.cookies["AURTZ"] = TZ
context["timezone"] = TZ
user.CommentNotify = bool(CN)
user.UpdateNotify = bool(UN)
user.OwnershipNotify = bool(ON)
with db.begin():
user.CommentNotify = bool(CN)
user.UpdateNotify = bool(UN)
user.OwnershipNotify = bool(ON)
# If a PK is given, compare it against the target user's PK.
if PK:
# Get the second token in the public key, which is the actual key.
pubkey = PK.strip().rstrip()
parts = pubkey.split(" ")
if len(parts) == 3:
# Remove the host part.
pubkey = parts[0] + " " + parts[1]
fingerprint = get_fingerprint(pubkey)
if not user.ssh_pub_key:
# No public key exists, create one.
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
PubKey=pubkey,
Fingerprint=fingerprint)
elif user.ssh_pub_key.PubKey != pubkey:
# A public key already exists, update it.
user.ssh_pub_key.PubKey = pubkey
user.ssh_pub_key.Fingerprint = fingerprint
elif user.ssh_pub_key:
# Else, if the user has a public key already, delete it.
session.delete(user.ssh_pub_key)
# Commit changes, if any.
session.commit()
with db.begin():
if PK:
# Get the second token in the public key, which is the actual key.
pubkey = PK.strip().rstrip()
parts = pubkey.split(" ")
if len(parts) == 3:
# Remove the host part.
pubkey = parts[0] + " " + parts[1]
fingerprint = get_fingerprint(pubkey)
if not user.ssh_pub_key:
# No public key exists, create one.
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
PubKey=pubkey,
Fingerprint=fingerprint)
elif user.ssh_pub_key.PubKey != pubkey:
# A public key already exists, update it.
user.ssh_pub_key.PubKey = pubkey
user.ssh_pub_key.Fingerprint = fingerprint
elif user.ssh_pub_key:
# Else, if the user has a public key already, delete it.
session.delete(user.ssh_pub_key)
if P and not user.valid_password(P):
# Remove the fields we consumed for passwords.
context["P"] = context["C"] = str()
# If a password was given and it doesn't match the user's, update it.
user.update_password(P)
with db.begin():
user.update_password(P)
if user == request.user:
# If the target user is the request user, login with
# the updated password and update AURSID.
@ -731,21 +732,17 @@ async def terms_of_service_post(request: Request,
accept_needed = sorted(unaccepted + diffs)
return render_terms_of_service(request, context, accept_needed)
# For each term we found, query for the matching accepted term
# and update its Revision to the term's current Revision.
for term in diffs:
accepted_term = request.user.accepted_terms.filter(
AcceptedTerm.TermsID == term.ID).first()
accepted_term.Revision = term.Revision
with db.begin():
# For each term we found, query for the matching accepted term
# and update its Revision to the term's current Revision.
for term in diffs:
accepted_term = request.user.accepted_terms.filter(
AcceptedTerm.TermsID == term.ID).first()
accepted_term.Revision = term.Revision
# For each term that was never accepted, accept it!
for term in unaccepted:
db.create(AcceptedTerm, User=request.user,
Term=term, Revision=term.Revision,
autocommit=False)
if diffs or unaccepted:
# If we had any terms to update, commit the changes.
db.commit()
# For each term that was never accepted, accept it!
for term in unaccepted:
db.create(AcceptedTerm, User=request.user,
Term=term, Revision=term.Revision)
return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER))

View file

@ -44,8 +44,6 @@ async def language(request: Request,
setting the language on any page, we want to preserve query
parameters across the redirect.
"""
from aurweb.db import session
if next[0] != '/':
return HTMLResponse(b"Invalid 'next' parameter.", status_code=400)
@ -53,8 +51,8 @@ async def language(request: Request,
# If the user is authenticated, update the user's LangPreference.
if request.user.is_authenticated():
request.user.LangPreference = set_lang
session.commit()
with db.begin():
request.user.LangPreference = set_lang
# In any case, set the response's AURLANG cookie that never expires.
response = RedirectResponse(url=f"{next}{query_string}",

View file

@ -214,10 +214,9 @@ async def trusted_user_proposal_post(request: Request,
return Response("Invalid 'decision' value.",
status_code=int(HTTPStatus.BAD_REQUEST))
vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo,
autocommit=False)
voteinfo.ActiveTUs += 1
db.commit()
with db.begin():
vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo)
voteinfo.ActiveTUs += 1
context["error"] = "You've already voted for this proposal."
return render_proposal(request, context, proposal, voteinfo, voters, vote)
@ -294,12 +293,13 @@ async def trusted_user_addvote_post(request: Request,
agenda = re.sub(r'<[/]?style.*>', '', agenda)
# Create a new TUVoteInfo (proposal)!
voteinfo = db.create(TUVoteInfo,
User=user,
Agenda=agenda,
Submitted=timestamp, End=timestamp + duration,
Quorum=quorum,
Submitter=request.user)
with db.begin():
voteinfo = db.create(TUVoteInfo,
User=user,
Agenda=agenda,
Submitted=timestamp, End=timestamp + duration,
Quorum=quorum,
Submitter=request.user)
# Redirect to the new proposal.
return RedirectResponse(f"/tu/{voteinfo.ID}",

View file

@ -1,6 +1,6 @@
import pytest
from aurweb.db import create, delete, query
from aurweb.db import begin, create, delete, query
from aurweb.models.account_type import AccountType
from aurweb.models.user import User
from aurweb.testing import setup_test_db
@ -14,11 +14,13 @@ def setup():
global account_type
account_type = create(AccountType, AccountType="TestUser")
with begin():
account_type = create(AccountType, AccountType="TestUser")
yield account_type
delete(AccountType, AccountType.ID == account_type.ID)
with begin():
delete(AccountType, AccountType.ID == account_type.ID)
def test_account_type():
@ -38,12 +40,14 @@ def test_account_type():
def test_user_account_type_relationship():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
with begin():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
assert user.AccountType == account_type
# This must be deleted here to avoid foreign key issues when
# deleting the temporary AccountType in the fixture.
delete(User, User.ID == user.ID)
with begin():
delete(User, User.ID == user.ID)

View file

@ -11,9 +11,9 @@ import pytest
from fastapi.testclient import TestClient
from aurweb import captcha
from aurweb import captcha, db
from aurweb.asgi import app
from aurweb.db import commit, create, query
from aurweb.db import create, query
from aurweb.models.accepted_term import AcceptedTerm
from aurweb.models.account_type import DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID, TRUSTED_USER_ID, AccountType
from aurweb.models.ban import Ban
@ -57,9 +57,11 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL,
RealName="Test UserZ", Passwd="testPassword",
IRCNick="testZ", AccountType=account_type)
with db.begin():
user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL,
RealName="Test UserZ", Passwd="testPassword",
IRCNick="testZ", AccountType=account_type)
yield user
@ -70,9 +72,10 @@ def setup():
@pytest.fixture
def tu_user():
user.AccountType = query(AccountType,
AccountType.ID == TRUSTED_USER_AND_DEV_ID).first()
commit()
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
yield user
@ -149,11 +152,9 @@ def test_post_passreset_user():
def test_post_passreset_resetkey():
from aurweb.db import session
user.session = Session(UsersID=user.ID, SessionID="blah",
LastUpdateTS=datetime.utcnow().timestamp())
session.commit()
with db.begin():
user.session = Session(UsersID=user.ID, SessionID="blah",
LastUpdateTS=datetime.utcnow().timestamp())
# Prepare a password reset.
with client as request:
@ -357,7 +358,8 @@ def test_post_register_error_invalid_captcha():
def test_post_register_error_ip_banned():
# 'testclient' is used as request.client.host via FastAPI TestClient.
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow())
with db.begin():
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow())
with client as request:
response = post_register(request)
@ -576,7 +578,8 @@ def test_post_register_error_ssh_pubkey_taken():
# Take the sha256 fingerprint of the ssh public key, create it.
fp = get_fingerprint(pk)
create(SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fp)
with db.begin():
create(SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fp)
with client as request:
response = post_register(request, PK=pk)
@ -660,13 +663,11 @@ def test_post_account_edit():
def test_post_account_edit_dev():
from aurweb.db import session
# Modify our user to be a "Trusted User & Developer"
name = "Trusted User & Developer"
tu_or_dev = query(AccountType, AccountType.AccountType == name).first()
user.AccountType = tu_or_dev
session.commit()
with db.begin():
user.AccountType = tu_or_dev
request = Request()
sid = user.login(request, "testPassword")
@ -1001,21 +1002,19 @@ def get_rows(html):
def test_post_accounts(tu_user):
# Set a PGPKey.
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
with db.begin():
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
# Create a few more users.
users = [user]
for i in range(10):
_user = create(User, Username=f"test_{i}",
Email=f"test_{i}@example.org",
RealName=f"Test #{i}",
Passwd="testPassword",
IRCNick=f"test_#{i}",
autocommit=False)
users.append(_user)
# Commit everything to the database.
commit()
with db.begin():
for i in range(10):
_user = create(User, Username=f"test_{i}",
Email=f"test_{i}@example.org",
RealName=f"Test #{i}",
Passwd="testPassword",
IRCNick=f"test_#{i}")
users.append(_user)
sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid}
@ -1085,11 +1084,12 @@ def test_post_accounts_account_type(tu_user):
# test the `u` parameter.
account_type = query(AccountType,
AccountType.AccountType == "User").first()
create(User, Username="test_2",
Email="test_2@example.org",
RealName="Test User 2",
Passwd="testPassword",
AccountType=account_type)
with db.begin():
create(User, Username="test_2",
Email="test_2@example.org",
RealName="Test User 2",
Passwd="testPassword",
AccountType=account_type)
# Expect no entries; we marked our only user as a User type.
with client as request:
@ -1113,9 +1113,10 @@ def test_post_accounts_account_type(tu_user):
assert type.text.strip() == "User"
# Set our only user to a Trusted User.
user.AccountType = query(AccountType,
AccountType.ID == TRUSTED_USER_ID).first()
commit()
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_ID
).first()
with client as request:
response = request.post("/accounts/", cookies=cookies,
@ -1130,9 +1131,10 @@ def test_post_accounts_account_type(tu_user):
assert type.text.strip() == "Trusted User"
user.AccountType = query(AccountType,
AccountType.ID == DEVELOPER_ID).first()
commit()
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == DEVELOPER_ID
).first()
with client as request:
response = request.post("/accounts/", cookies=cookies,
@ -1147,10 +1149,10 @@ def test_post_accounts_account_type(tu_user):
assert type.text.strip() == "Developer"
user.AccountType = query(AccountType,
AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
commit()
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
with client as request:
response = request.post("/accounts/", cookies=cookies,
@ -1182,8 +1184,8 @@ def test_post_accounts_status(tu_user):
username, type, status, realname, irc, pgp_key, edit = row
assert status.text.strip() == "Active"
user.Suspended = True
commit()
with db.begin():
user.Suspended = True
with client as request:
response = request.post("/accounts/", cookies=cookies,
@ -1244,12 +1246,13 @@ def test_post_accounts_sortby(tu_user):
# Create a second user so we can compare sorts.
account_type = query(AccountType,
AccountType.ID == DEVELOPER_ID).first()
create(User, Username="test2",
Email="test2@example.org",
RealName="Test User 2",
Passwd="testPassword",
IRCNick="test2",
AccountType=account_type)
with db.begin():
create(User, Username="test2",
Email="test2@example.org",
RealName="Test User 2",
Passwd="testPassword",
IRCNick="test2",
AccountType=account_type)
sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid}
@ -1297,9 +1300,10 @@ def test_post_accounts_sortby(tu_user):
# Test the rows are reversed when ordering by RealName.
assert compare_text_values(4, first_rows, reversed(rows)) is True
user.AccountType = query(AccountType,
AccountType.ID == TRUSTED_USER_AND_DEV_ID).first()
commit()
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
# Fetch first_rows again with our new AccountType ordering.
with client as request:
@ -1322,8 +1326,8 @@ def test_post_accounts_sortby(tu_user):
def test_post_accounts_pgp_key(tu_user):
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
commit()
with db.begin():
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid}
@ -1343,15 +1347,14 @@ def test_post_accounts_paged(tu_user):
users = [user]
account_type = query(AccountType,
AccountType.AccountType == "User").first()
for i in range(150):
_user = create(User, Username=f"test_#{i}",
Email=f"test_#{i}@example.org",
RealName=f"Test User #{i}",
Passwd="testPassword",
AccountType=account_type,
autocommit=False)
users.append(_user)
commit()
with db.begin():
for i in range(150):
_user = create(User, Username=f"test_#{i}",
Email=f"test_#{i}@example.org",
RealName=f"Test User #{i}",
Passwd="testPassword",
AccountType=account_type)
users.append(_user)
sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid}
@ -1414,8 +1417,9 @@ def test_post_accounts_paged(tu_user):
def test_get_terms_of_service():
term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1)
with db.begin():
term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1)
with client as request:
response = request.get("/tos", allow_redirects=False)
@ -1436,8 +1440,9 @@ def test_get_terms_of_service():
response = request.get("/tos", cookies=cookies, allow_redirects=False)
assert response.status_code == int(HTTPStatus.OK)
accepted_term = create(AcceptedTerm, User=user,
Term=term, Revision=term.Revision)
with db.begin():
accepted_term = create(AcceptedTerm, User=user,
Term=term, Revision=term.Revision)
with client as request:
response = request.get("/tos", cookies=cookies, allow_redirects=False)
@ -1445,8 +1450,8 @@ def test_get_terms_of_service():
assert response.status_code == int(HTTPStatus.SEE_OTHER)
# Bump the term's revision.
term.Revision = 2
commit()
with db.begin():
term.Revision = 2
with client as request:
response = request.get("/tos", cookies=cookies, allow_redirects=False)
@ -1454,8 +1459,8 @@ def test_get_terms_of_service():
# yet been agreed to via AcceptedTerm update.
assert response.status_code == int(HTTPStatus.OK)
accepted_term.Revision = term.Revision
commit()
with db.begin():
accepted_term.Revision = term.Revision
with client as request:
response = request.get("/tos", cookies=cookies, allow_redirects=False)
@ -1471,8 +1476,9 @@ def test_post_terms_of_service():
cookies = {"AURSID": sid} # Auth cookie.
# Create a fresh Term.
term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1)
with db.begin():
term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1)
# Test that the term we just created is listed.
with client as request:
@ -1497,8 +1503,8 @@ def test_post_terms_of_service():
assert accepted_term.Term == term
# Update the term to revision 2.
term.Revision = 2
commit()
with db.begin():
term.Revision = 2
# A GET request gives us the new revision to accept.
with client as request:

View file

@ -2,6 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb import db
from aurweb.db import create
from aurweb.models.api_rate_limit import ApiRateLimit
from aurweb.testing import setup_test_db
@ -13,26 +14,28 @@ def setup():
def test_api_rate_key_creation():
rate = create(ApiRateLimit, IP="127.0.0.1", Requests=10, WindowStart=1)
with db.begin():
rate = create(ApiRateLimit, IP="127.0.0.1", Requests=10, WindowStart=1)
assert rate.IP == "127.0.0.1"
assert rate.Requests == 10
assert rate.WindowStart == 1
def test_api_rate_key_ip_default():
api_rate_limit = create(ApiRateLimit, Requests=10, WindowStart=1)
with db.begin():
api_rate_limit = create(ApiRateLimit, Requests=10, WindowStart=1)
assert api_rate_limit.IP == str()
def test_api_rate_key_null_requests_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(ApiRateLimit, IP="127.0.0.1", WindowStart=1)
session.rollback()
with db.begin():
create(ApiRateLimit, IP="127.0.0.1", WindowStart=1)
db.rollback()
def test_api_rate_key_null_window_start_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(ApiRateLimit, IP="127.0.0.1", Requests=1)
session.rollback()
with db.begin():
create(ApiRateLimit, IP="127.0.0.1", Requests=1)
db.rollback()

View file

@ -4,6 +4,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb import db
from aurweb.auth import BasicAuthBackend, account_type_required, has_credential
from aurweb.db import create, query
from aurweb.models.account_type import USER, USER_ID, AccountType
@ -23,9 +24,10 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.com",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
with db.begin():
user = create(User, Username="test", Email="test@example.com",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
backend = BasicAuthBackend()
request = Request()
@ -51,14 +53,13 @@ async def test_auth_backend_invalid_sid():
@pytest.mark.asyncio
async def test_auth_backend_invalid_user_id():
from aurweb.db import session
# Create a new session with a fake user id.
now_ts = datetime.utcnow().timestamp()
with pytest.raises(IntegrityError):
create(Session, UsersID=666, SessionID="realSession",
LastUpdateTS=now_ts + 5)
session.rollback()
with db.begin():
create(Session, UsersID=666, SessionID="realSession",
LastUpdateTS=now_ts + 5)
db.rollback()
@pytest.mark.asyncio
@ -66,8 +67,9 @@ async def test_basic_auth_backend():
# This time, everything matches up. We expect the user to
# equal the real_user.
now_ts = datetime.utcnow().timestamp()
create(Session, UsersID=user.ID, SessionID="realSession",
LastUpdateTS=now_ts + 5)
with db.begin():
create(Session, UsersID=user.ID, SessionID="realSession",
LastUpdateTS=now_ts + 5)
request.cookies["AURSID"] = "realSession"
_, result = await backend.authenticate(request)
assert result == user

View file

@ -9,7 +9,7 @@ from fastapi.testclient import TestClient
import aurweb.config
from aurweb.asgi import app
from aurweb.db import create, query
from aurweb.db import begin, create, query
from aurweb.models.account_type import AccountType
from aurweb.models.session import Session
from aurweb.models.user import User
@ -32,9 +32,10 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL,
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
with begin():
user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL,
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
client = TestClient(app)

View file

@ -6,6 +6,7 @@ import pytest
from sqlalchemy import exc as sa_exc
from aurweb import db
from aurweb.db import create
from aurweb.models.ban import Ban, is_banned
from aurweb.testing import setup_test_db
@ -21,7 +22,8 @@ def setup():
setup_test_db("Bans")
ts = datetime.utcnow() + timedelta(seconds=30)
ban = create(Ban, IPAddress="127.0.0.1", BanTS=ts)
with db.begin():
ban = create(Ban, IPAddress="127.0.0.1", BanTS=ts)
request = Request()
@ -35,17 +37,17 @@ def test_invalid_ban():
with pytest.raises(sa_exc.IntegrityError):
bad_ban = Ban(BanTS=datetime.utcnow())
session.add(bad_ban)
# We're adding a ban with no primary key; this causes an
# SQLAlchemy warnings when committing to the DB.
# Ignore them.
with warnings.catch_warnings():
warnings.simplefilter("ignore", sa_exc.SAWarning)
session.commit()
with db.begin():
session.add(bad_ban)
# Since we got a transaction failure, we need to rollback.
session.rollback()
db.rollback()
def test_banned():

View file

@ -278,18 +278,15 @@ def test_connection_execute_paramstyle_unsupported():
def test_create_delete():
db.create(AccountType, AccountType="test")
with db.begin():
db.create(AccountType, AccountType="test")
record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is not None
db.delete(AccountType, AccountType.AccountType == "test")
record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is None
# Create and delete a record with autocommit=False.
db.create(AccountType, AccountType="test", autocommit=False)
db.commit()
db.delete(AccountType, AccountType.AccountType == "test", autocommit=False)
db.commit()
with db.begin():
db.delete(AccountType, AccountType.AccountType == "test")
record = db.query(AccountType, AccountType.AccountType == "test").first()
assert record is None
@ -297,8 +294,8 @@ def test_create_delete():
def test_add_commit():
# Use db.add and db.commit to add a temporary record.
account_type = AccountType(AccountType="test")
db.add(account_type)
db.commit()
with db.begin():
db.add(account_type)
# Assert it got created in the DB.
assert bool(account_type.ID)
@ -308,7 +305,8 @@ def test_add_commit():
assert record == account_type
# Remove the record.
db.delete(AccountType, AccountType.ID == account_type.ID)
with db.begin():
db.delete(AccountType, AccountType.ID == account_type.ID)
def test_connection_executor_mysql_paramstyle():

View file

@ -1,6 +1,6 @@
import pytest
from aurweb.db import create, delete, query
from aurweb.db import begin, create, delete, query
from aurweb.models.dependency_type import DependencyType
from aurweb.testing import setup_test_db
@ -19,13 +19,17 @@ def test_dependency_types():
def test_dependency_type_creation():
dependency_type = create(DependencyType, Name="Test Type")
with begin():
dependency_type = create(DependencyType, Name="Test Type")
assert bool(dependency_type.ID)
assert dependency_type.Name == "Test Type"
delete(DependencyType, DependencyType.ID == dependency_type.ID)
with begin():
delete(DependencyType, DependencyType.ID == dependency_type.ID)
def test_dependency_type_null_name_uses_default():
dependency_type = create(DependencyType)
with begin():
dependency_type = create(DependencyType)
assert dependency_type.Name == str()
delete(DependencyType, DependencyType.ID == dependency_type.ID)
with begin():
delete(DependencyType, DependencyType.ID == dependency_type.ID)

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import create
from aurweb import db
from aurweb.models.group import Group
from aurweb.testing import setup_test_db
@ -13,13 +13,14 @@ def setup():
def test_group_creation():
group = create(Group, Name="Test Group")
with db.begin():
group = db.create(Group, Name="Test Group")
assert bool(group.ID)
assert group.Name == "Test Group"
def test_group_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(Group)
session.rollback()
with db.begin():
db.create(Group)
db.rollback()

View file

@ -38,8 +38,10 @@ def setup():
@pytest.fixture
def user():
yield db.create(User, Username="test", Email="test@example.org",
Passwd="testPassword", AccountTypeID=USER_ID)
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
Passwd="testPassword", AccountTypeID=USER_ID)
yield user
@pytest.fixture
@ -68,17 +70,14 @@ def packages(user):
# For i..num_packages, create a package named pkg_{i}.
pkgs = []
now = int(datetime.utcnow().timestamp())
for i in range(num_packages):
pkgbase = db.create(PackageBase, Name=f"pkg_{i}",
Maintainer=user, Packager=user,
autocommit=False, SubmittedTS=now,
ModifiedTS=now)
pkg = db.create(Package, PackageBase=pkgbase,
Name=pkgbase.Name, autocommit=False)
pkgs.append(pkg)
now += 1
db.commit()
with db.begin():
for i in range(num_packages):
pkgbase = db.create(PackageBase, Name=f"pkg_{i}",
Maintainer=user, Packager=user,
SubmittedTS=now, ModifiedTS=now)
pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
pkgs.append(pkg)
now += 1
yield pkgs
@ -159,10 +158,11 @@ def test_homepage_updates(redis, packages):
def test_homepage_dashboard(redis, packages, user):
# Create Comaintainer records for all of the packages.
for pkg in packages:
db.create(PackageComaintainer, PackageBase=pkg.PackageBase,
User=user, Priority=1, autocommit=False)
db.commit()
with db.begin():
for pkg in packages:
db.create(PackageComaintainer,
PackageBase=pkg.PackageBase,
User=user, Priority=1)
cookies = {"AURSID": user.login(Request(), "testPassword")}
with client as request:
@ -193,11 +193,12 @@ def test_homepage_dashboard_requests(redis, packages, user):
pkg = packages[0]
reqtype = db.query(RequestType, RequestType.ID == DELETION_ID).first()
pkgreq = db.create(PackageRequest, PackageBase=pkg.PackageBase,
PackageBaseName=pkg.PackageBase.Name,
User=user, Comments=str(),
ClosureComment=str(), RequestTS=now,
RequestType=reqtype)
with db.begin():
pkgreq = db.create(PackageRequest, PackageBase=pkg.PackageBase,
PackageBaseName=pkg.PackageBase.Name,
User=user, Comments=str(),
ClosureComment=str(), RequestTS=now,
RequestType=reqtype)
cookies = {"AURSID": user.login(Request(), "testPassword")}
with client as request:
@ -213,8 +214,8 @@ def test_homepage_dashboard_requests(redis, packages, user):
def test_homepage_dashboard_flagged_packages(redis, packages, user):
# Set the first Package flagged by setting its OutOfDateTS column.
pkg = packages[0]
pkg.PackageBase.OutOfDateTS = int(datetime.utcnow().timestamp())
db.commit()
with db.begin():
pkg.PackageBase.OutOfDateTS = int(datetime.utcnow().timestamp())
cookies = {"AURSID": user.login(Request(), "testPassword")}
with client as request:

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import create
from aurweb import db
from aurweb.models.license import License
from aurweb.testing import setup_test_db
@ -13,13 +13,14 @@ def setup():
def test_license_creation():
license = create(License, Name="Test License")
with db.begin():
license = db.create(License, Name="Test License")
assert bool(license.ID)
assert license.Name == "Test License"
def test_license_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(License)
session.rollback()
with db.begin():
db.create(License)
db.rollback()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import create
from aurweb import db
from aurweb.models.official_provider import OfficialProvider
from aurweb.testing import setup_test_db
@ -13,10 +13,11 @@ def setup():
def test_official_provider_creation():
oprovider = create(OfficialProvider,
Name="some-name",
Repo="some-repo",
Provides="some-provides")
with db.begin():
oprovider = db.create(OfficialProvider,
Name="some-name",
Repo="some-repo",
Provides="some-provides")
assert bool(oprovider.ID)
assert oprovider.Name == "some-name"
assert oprovider.Repo == "some-repo"
@ -25,16 +26,18 @@ def test_official_provider_creation():
def test_official_provider_cs():
""" Test case sensitivity of the database table. """
oprovider = create(OfficialProvider,
Name="some-name",
Repo="some-repo",
Provides="some-provides")
with db.begin():
oprovider = db.create(OfficialProvider,
Name="some-name",
Repo="some-repo",
Provides="some-provides")
assert bool(oprovider.ID)
oprovider_cs = create(OfficialProvider,
Name="SOME-NAME",
Repo="SOME-REPO",
Provides="SOME-PROVIDES")
with db.begin():
oprovider_cs = db.create(OfficialProvider,
Name="SOME-NAME",
Repo="SOME-REPO",
Provides="SOME-PROVIDES")
assert bool(oprovider_cs.ID)
assert oprovider.ID != oprovider_cs.ID
@ -49,27 +52,27 @@ def test_official_provider_cs():
def test_official_provider_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(OfficialProvider,
Repo="some-repo",
Provides="some-provides")
session.rollback()
with db.begin():
db.create(OfficialProvider,
Repo="some-repo",
Provides="some-provides")
db.rollback()
def test_official_provider_null_repo_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(OfficialProvider,
Name="some-name",
Provides="some-provides")
session.rollback()
with db.begin():
db.create(OfficialProvider,
Name="some-name",
Provides="some-provides")
db.rollback()
def test_official_provider_null_provides_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(OfficialProvider,
Name="some-name",
Repo="some-repo")
session.rollback()
with db.begin():
db.create(OfficialProvider,
Name="some-name",
Repo="some-repo")
db.rollback()

View file

@ -3,7 +3,7 @@ import pytest
from sqlalchemy import and_
from sqlalchemy.exc import IntegrityError
from aurweb.db import create, query
from aurweb import db
from aurweb.models.account_type import AccountType
from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase
@ -19,25 +19,25 @@ def setup():
setup_test_db("Packages", "PackageBases", "Users")
account_type = query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
account_type = db.query(AccountType,
AccountType.AccountType == "User").first()
pkgbase = create(PackageBase,
Name="beautiful-package",
Maintainer=user)
package = create(Package,
PackageBase=pkgbase,
Name=pkgbase.Name,
Description="Test description.",
URL="https://test.package")
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
pkgbase = db.create(PackageBase,
Name="beautiful-package",
Maintainer=user)
package = db.create(Package,
PackageBase=pkgbase,
Name=pkgbase.Name,
Description="Test description.",
URL="https://test.package")
def test_package():
from aurweb.db import session
assert pkgbase == package.PackageBase
assert package.Name == "beautiful-package"
assert package.Description == "Test description."
@ -45,33 +45,31 @@ def test_package():
assert package.URL == "https://test.package"
# Update package Version.
package.Version = "1.2.3"
session.commit()
with db.begin():
package.Version = "1.2.3"
# Make sure it got updated in the database.
record = query(Package,
and_(Package.ID == package.ID,
Package.Version == "1.2.3")).first()
record = db.query(Package,
and_(Package.ID == package.ID,
Package.Version == "1.2.3")).first()
assert record is not None
def test_package_null_pkgbase_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(Package,
Name="some-package",
Description="Some description.",
URL="https://some.package")
session.rollback()
with db.begin():
db.create(Package,
Name="some-package",
Description="Some description.",
URL="https://some.package")
db.rollback()
def test_package_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(Package,
PackageBase=pkgbase,
Description="Some description.",
URL="https://some.package")
session.rollback()
with db.begin():
db.create(Package,
PackageBase=pkgbase,
Description="Some description.",
URL="https://some.package")
db.rollback()

View file

@ -4,7 +4,7 @@ from sqlalchemy.exc import IntegrityError
import aurweb.config
from aurweb.db import create, query
from aurweb import db
from aurweb.models.account_type import AccountType
from aurweb.models.package_base import PackageBase
from aurweb.models.user import User
@ -19,17 +19,19 @@ def setup():
setup_test_db("Users", "PackageBases")
account_type = query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
account_type = db.query(AccountType,
AccountType.AccountType == "User").first()
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
def test_package_base():
pkgbase = create(PackageBase,
Name="beautiful-package",
Maintainer=user)
with db.begin():
pkgbase = db.create(PackageBase,
Name="beautiful-package",
Maintainer=user)
assert pkgbase in user.maintained_bases
assert not pkgbase.OutOfDateTS
@ -38,7 +40,8 @@ def test_package_base():
# Set Popularity to a string, then get it by attribute to
# exercise the string -> float conversion path.
pkgbase.Popularity = "0.0"
with db.begin():
pkgbase.Popularity = "0.0"
assert pkgbase.Popularity == 0.0
@ -47,27 +50,28 @@ def test_package_base_ci():
if aurweb.config.get("database", "backend") == "sqlite":
return None # SQLite doesn't seem handle this.
from aurweb.db import session
pkgbase = create(PackageBase,
Name="beautiful-package",
Maintainer=user)
with db.begin():
pkgbase = db.create(PackageBase,
Name="beautiful-package",
Maintainer=user)
assert bool(pkgbase.ID)
with pytest.raises(IntegrityError):
create(PackageBase,
Name="Beautiful-Package",
Maintainer=user)
session.rollback()
with db.begin():
db.create(PackageBase,
Name="Beautiful-Package",
Maintainer=user)
db.rollback()
def test_package_base_relationships():
pkgbase = create(PackageBase,
Name="beautiful-package",
Flagger=user,
Maintainer=user,
Submitter=user,
Packager=user)
with db.begin():
pkgbase = db.create(PackageBase,
Name="beautiful-package",
Flagger=user,
Maintainer=user,
Submitter=user,
Packager=user)
assert pkgbase in user.flagged_bases
assert pkgbase in user.maintained_bases
assert pkgbase in user.submitted_bases
@ -75,8 +79,7 @@ def test_package_base_relationships():
def test_package_base_null_name_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(PackageBase)
session.rollback()
with db.begin():
db.create(PackageBase)
db.rollback()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import create, rollback
from aurweb import db
from aurweb.models.package_base import PackageBase
from aurweb.models.package_blacklist import PackageBlacklist
from aurweb.models.user import User
@ -17,18 +17,20 @@ def setup():
setup_test_db("PackageBlacklist", "PackageBases", "Users")
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword")
pkgbase = create(PackageBase, Name="test-package", Maintainer=user)
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword")
pkgbase = db.create(PackageBase, Name="test-package", Maintainer=user)
def test_package_blacklist_creation():
package_blacklist = create(PackageBlacklist, Name="evil-package")
with db.begin():
package_blacklist = db.create(PackageBlacklist, Name="evil-package")
assert bool(package_blacklist.ID)
assert package_blacklist.Name == "evil-package"
def test_package_blacklist_null_name_raises_exception():
with pytest.raises(IntegrityError):
create(PackageBlacklist)
rollback()
with db.begin():
db.create(PackageBlacklist)
db.rollback()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import create, query, rollback
from aurweb.db import begin, create, query, rollback
from aurweb.models.account_type import AccountType
from aurweb.models.package_base import PackageBase
from aurweb.models.package_comment import PackageComment
@ -20,45 +20,52 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
pkgbase = create(PackageBase, Name="test-package", Maintainer=user)
with begin():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
pkgbase = create(PackageBase, Name="test-package", Maintainer=user)
def test_package_comment_creation():
package_comment = create(PackageComment,
PackageBase=pkgbase,
User=user,
Comments="Test comment.",
RenderedComment="Test rendered comment.")
with begin():
package_comment = create(PackageComment,
PackageBase=pkgbase,
User=user,
Comments="Test comment.",
RenderedComment="Test rendered comment.")
assert bool(package_comment.ID)
def test_package_comment_null_package_base_raises_exception():
with pytest.raises(IntegrityError):
create(PackageComment, User=user, Comments="Test comment.",
RenderedComment="Test rendered comment.")
with begin():
create(PackageComment, User=user, Comments="Test comment.",
RenderedComment="Test rendered comment.")
rollback()
def test_package_comment_null_user_raises_exception():
with pytest.raises(IntegrityError):
create(PackageComment, PackageBase=pkgbase, Comments="Test comment.",
RenderedComment="Test rendered comment.")
with begin():
create(PackageComment, PackageBase=pkgbase,
Comments="Test comment.",
RenderedComment="Test rendered comment.")
rollback()
def test_package_comment_null_comments_raises_exception():
with pytest.raises(IntegrityError):
create(PackageComment, PackageBase=pkgbase, User=user,
RenderedComment="Test rendered comment.")
with begin():
create(PackageComment, PackageBase=pkgbase, User=user,
RenderedComment="Test rendered comment.")
rollback()
def test_package_comment_null_renderedcomment_defaults():
record = create(PackageComment,
PackageBase=pkgbase,
User=user,
Comments="Test comment.")
with begin():
record = create(PackageComment,
PackageBase=pkgbase,
User=user,
Comments="Test comment.")
assert record.RenderedComment == str()

View file

@ -2,7 +2,8 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import commit, create, query
from aurweb import db
from aurweb.db import create, query
from aurweb.models.account_type import AccountType
from aurweb.models.dependency_type import DependencyType
from aurweb.models.package import Package
@ -22,25 +23,28 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
pkgbase = create(PackageBase,
Name="test-package",
Maintainer=user)
package = create(Package,
PackageBase=pkgbase,
Name=pkgbase.Name,
Description="Test description.",
URL="https://test.package")
with db.begin():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
pkgbase = create(PackageBase,
Name="test-package",
Maintainer=user)
package = create(Package,
PackageBase=pkgbase,
Name=pkgbase.Name,
Description="Test description.",
URL="https://test.package")
def test_package_dependencies():
depends = query(DependencyType, DependencyType.Name == "depends").first()
pkgdep = create(PackageDependency, Package=package,
DependencyType=depends,
DepName="test-dep")
with db.begin():
pkgdep = create(PackageDependency, Package=package,
DependencyType=depends,
DepName="test-dep")
assert pkgdep.DepName == "test-dep"
assert pkgdep.Package == package
assert pkgdep.DependencyType == depends
@ -49,8 +53,8 @@ def test_package_dependencies():
makedepends = query(DependencyType,
DependencyType.Name == "makedepends").first()
pkgdep.DependencyType = makedepends
commit()
with db.begin():
pkgdep.DependencyType = makedepends
assert pkgdep.DepName == "test-dep"
assert pkgdep.Package == package
assert pkgdep.DependencyType == makedepends
@ -59,8 +63,8 @@ def test_package_dependencies():
checkdepends = query(DependencyType,
DependencyType.Name == "checkdepends").first()
pkgdep.DependencyType = checkdepends
commit()
with db.begin():
pkgdep.DependencyType = checkdepends
assert pkgdep.DepName == "test-dep"
assert pkgdep.Package == package
assert pkgdep.DependencyType == checkdepends
@ -69,8 +73,8 @@ def test_package_dependencies():
optdepends = query(DependencyType,
DependencyType.Name == "optdepends").first()
pkgdep.DependencyType = optdepends
commit()
with db.begin():
pkgdep.DependencyType = optdepends
assert pkgdep.DepName == "test-dep"
assert pkgdep.Package == package
assert pkgdep.DependencyType == optdepends
@ -79,39 +83,37 @@ def test_package_dependencies():
assert not pkgdep.is_package()
base = create(PackageBase, Name=pkgdep.DepName, Maintainer=user)
create(Package, PackageBase=base, Name=pkgdep.DepName)
with db.begin():
base = create(PackageBase, Name=pkgdep.DepName, Maintainer=user)
create(Package, PackageBase=base, Name=pkgdep.DepName)
assert pkgdep.is_package()
def test_package_dependencies_null_package_raises_exception():
from aurweb.db import session
depends = query(DependencyType, DependencyType.Name == "depends").first()
with pytest.raises(IntegrityError):
create(PackageDependency,
DependencyType=depends,
DepName="test-dep")
session.rollback()
with db.begin():
create(PackageDependency,
DependencyType=depends,
DepName="test-dep")
db.rollback()
def test_package_dependencies_null_dependency_type_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(PackageDependency,
Package=package,
DepName="test-dep")
session.rollback()
with db.begin():
create(PackageDependency,
Package=package,
DepName="test-dep")
db.rollback()
def test_package_dependencies_null_depname_raises_exception():
from aurweb.db import session
depends = query(DependencyType, DependencyType.Name == "depends").first()
with pytest.raises(IntegrityError):
create(PackageDependency,
Package=package,
DependencyType=depends)
session.rollback()
with db.begin():
create(PackageDependency,
Package=package,
DependencyType=depends)
db.rollback()

View file

@ -2,7 +2,8 @@ import pytest
from sqlalchemy.exc import IntegrityError, OperationalError
from aurweb.db import commit, create, query
from aurweb import db
from aurweb.db import create, query
from aurweb.models.account_type import AccountType
from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase
@ -22,25 +23,28 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
pkgbase = create(PackageBase,
Name="test-package",
Maintainer=user)
package = create(Package,
PackageBase=pkgbase,
Name=pkgbase.Name,
Description="Test description.",
URL="https://test.package")
with db.begin():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
pkgbase = create(PackageBase,
Name="test-package",
Maintainer=user)
package = create(Package,
PackageBase=pkgbase,
Name=pkgbase.Name,
Description="Test description.",
URL="https://test.package")
def test_package_relation():
conflicts = query(RelationType, RelationType.Name == "conflicts").first()
pkgrel = create(PackageRelation, Package=package,
RelationType=conflicts,
RelName="test-relation")
with db.begin():
pkgrel = create(PackageRelation, Package=package,
RelationType=conflicts,
RelName="test-relation")
assert pkgrel.RelName == "test-relation"
assert pkgrel.Package == package
assert pkgrel.RelationType == conflicts
@ -48,8 +52,8 @@ def test_package_relation():
assert pkgrel in package.package_relations
provides = query(RelationType, RelationType.Name == "provides").first()
pkgrel.RelationType = provides
commit()
with db.begin():
pkgrel.RelationType = provides
assert pkgrel.RelName == "test-relation"
assert pkgrel.Package == package
assert pkgrel.RelationType == provides
@ -57,8 +61,8 @@ def test_package_relation():
assert pkgrel in package.package_relations
replaces = query(RelationType, RelationType.Name == "replaces").first()
pkgrel.RelationType = replaces
commit()
with db.begin():
pkgrel.RelationType = replaces
assert pkgrel.RelName == "test-relation"
assert pkgrel.Package == package
assert pkgrel.RelationType == replaces
@ -67,36 +71,33 @@ def test_package_relation():
def test_package_relation_null_package_raises_exception():
from aurweb.db import session
conflicts = query(RelationType, RelationType.Name == "conflicts").first()
assert conflicts is not None
with pytest.raises(IntegrityError):
create(PackageRelation,
RelationType=conflicts,
RelName="test-relation")
session.rollback()
with db.begin():
create(PackageRelation,
RelationType=conflicts,
RelName="test-relation")
db.rollback()
def test_package_relation_null_relation_type_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(PackageRelation,
Package=package,
RelName="test-relation")
session.rollback()
with db.begin():
create(PackageRelation,
Package=package,
RelName="test-relation")
db.rollback()
def test_package_relation_null_relname_raises_exception():
from aurweb.db import session
depends = query(RelationType, RelationType.Name == "conflicts").first()
assert depends is not None
with pytest.raises((OperationalError, IntegrityError)):
create(PackageRelation,
Package=package,
RelationType=depends)
session.rollback()
with db.begin():
create(PackageRelation,
Package=package,
RelationType=depends)
db.rollback()

View file

@ -4,7 +4,8 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import commit, create, query, rollback
from aurweb import db
from aurweb.db import create, query, rollback
from aurweb.models.package_base import PackageBase
from aurweb.models.package_request import (ACCEPTED, ACCEPTED_ID, CLOSED, CLOSED_ID, PENDING, PENDING_ID, REJECTED,
REJECTED_ID, PackageRequest)
@ -21,19 +22,21 @@ def setup():
setup_test_db("PackageRequests", "PackageBases", "Users")
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword")
pkgbase = create(PackageBase, Name="test-package", Maintainer=user)
with db.begin():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword")
pkgbase = create(PackageBase, Name="test-package", Maintainer=user)
def test_package_request_creation():
request_type = query(RequestType, RequestType.Name == "merge").first()
assert request_type.Name == "merge"
package_request = create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str())
with db.begin():
package_request = create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str())
assert bool(package_request.ID)
assert package_request.RequestType == request_type
@ -54,11 +57,12 @@ def test_package_request_closed():
assert request_type.Name == "merge"
ts = int(datetime.utcnow().timestamp())
package_request = create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
Closer=user, ClosedTS=ts,
Comments=str(), ClosureComment=str())
with db.begin():
package_request = create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
Closer=user, ClosedTS=ts,
Comments=str(), ClosureComment=str())
assert package_request.Closer == user
assert package_request.ClosedTS == ts
@ -69,54 +73,60 @@ def test_package_request_closed():
def test_package_request_null_request_type_raises_exception():
with pytest.raises(IntegrityError):
create(PackageRequest, User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str())
with db.begin():
create(PackageRequest, User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str())
rollback()
def test_package_request_null_user_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str())
with db.begin():
create(PackageRequest, RequestType=request_type,
PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str())
rollback()
def test_package_request_null_package_base_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type,
User=user, PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str())
with db.begin():
create(PackageRequest, RequestType=request_type,
User=user, PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str())
rollback()
def test_package_request_null_package_base_name_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase,
Comments=str(), ClosureComment=str())
with db.begin():
create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase,
Comments=str(), ClosureComment=str())
rollback()
def test_package_request_null_comments_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
ClosureComment=str())
with db.begin():
create(PackageRequest, RequestType=request_type, User=user,
PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
ClosureComment=str())
rollback()
def test_package_request_null_closure_comment_raises_exception():
request_type = query(RequestType, RequestType.Name == "merge").first()
with pytest.raises(IntegrityError):
create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
Comments=str())
with db.begin():
create(PackageRequest, RequestType=request_type, User=user,
PackageBase=pkgbase, PackageBaseName=pkgbase.Name,
Comments=str())
rollback()
@ -124,26 +134,27 @@ def test_package_request_status_display():
""" Test status_display() based on the Status column value. """
request_type = query(RequestType, RequestType.Name == "merge").first()
pkgreq = create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str(),
Status=PENDING_ID)
with db.begin():
pkgreq = create(PackageRequest, RequestType=request_type,
User=user, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
Comments=str(), ClosureComment=str(),
Status=PENDING_ID)
assert pkgreq.status_display() == PENDING
pkgreq.Status = CLOSED_ID
commit()
with db.begin():
pkgreq.Status = CLOSED_ID
assert pkgreq.status_display() == CLOSED
pkgreq.Status = ACCEPTED_ID
commit()
with db.begin():
pkgreq.Status = ACCEPTED_ID
assert pkgreq.status_display() == ACCEPTED
pkgreq.Status = REJECTED_ID
commit()
with db.begin():
pkgreq.Status = REJECTED_ID
assert pkgreq.status_display() == REJECTED
pkgreq.Status = 124
commit()
with db.begin():
pkgreq.Status = 124
with pytest.raises(KeyError):
pkgreq.status_display()

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import create, query, rollback
from aurweb.db import begin, create, query, rollback
from aurweb.models.account_type import AccountType
from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase
@ -21,17 +21,19 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
pkgbase = create(PackageBase,
Name="test-package",
Maintainer=user)
package = create(Package, PackageBase=pkgbase, Name="test-package")
with begin():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
pkgbase = create(PackageBase,
Name="test-package",
Maintainer=user)
package = create(Package, PackageBase=pkgbase, Name="test-package")
def test_package_source():
pkgsource = create(PackageSource, Package=package)
with begin():
pkgsource = create(PackageSource, Package=package)
assert pkgsource.Package == package
# By default, PackageSources.Source assigns the string '/dev/null'.
assert pkgsource.Source == "/dev/null"
@ -40,5 +42,6 @@ def test_package_source():
def test_package_source_null_package_raises_exception():
with pytest.raises(IntegrityError):
create(PackageSource)
with begin():
create(PackageSource)
rollback()

View file

@ -28,31 +28,25 @@ def package_endpoint(package: Package) -> str:
return f"/packages/{package.Name}"
def create_package(pkgname: str, maintainer: User,
autocommit: bool = True) -> Package:
def create_package(pkgname: str, maintainer: User) -> Package:
pkgbase = db.create(PackageBase,
Name=pkgname,
Maintainer=maintainer,
autocommit=False)
return db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase,
autocommit=autocommit)
Maintainer=maintainer)
return db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase)
def create_package_dep(package: Package, depname: str,
dep_type_name: str = "depends",
autocommit: bool = True) -> PackageDependency:
dep_type_name: str = "depends") -> PackageDependency:
dep_type = db.query(DependencyType,
DependencyType.Name == dep_type_name).first()
return db.create(PackageDependency,
DependencyType=dep_type,
Package=package,
DepName=depname,
autocommit=autocommit)
DepName=depname)
def create_package_rel(package: Package,
relname: str,
autocommit: bool = True) -> PackageRelation:
relname: str) -> PackageRelation:
rel_type = db.query(RelationType,
RelationType.ID == PROVIDES_ID).first()
return db.create(PackageRelation,
@ -84,31 +78,37 @@ def client() -> TestClient:
def user() -> User:
""" Yield a user. """
account_type = db.query(AccountType, AccountType.ID == USER_ID).first()
yield db.create(User, Username="test",
Email="test@example.org",
Passwd="testPassword",
AccountType=account_type)
with db.begin():
user = db.create(User, Username="test",
Email="test@example.org",
Passwd="testPassword",
AccountType=account_type)
yield user
@pytest.fixture
def maintainer() -> User:
""" Yield a specific User used to maintain packages. """
account_type = db.query(AccountType, AccountType.ID == USER_ID).first()
yield db.create(User, Username="test_maintainer",
Email="test_maintainer@example.org",
Passwd="testPassword",
AccountType=account_type)
with db.begin():
maintainer = db.create(User, Username="test_maintainer",
Email="test_maintainer@example.org",
Passwd="testPassword",
AccountType=account_type)
yield maintainer
@pytest.fixture
def package(maintainer: User) -> Package:
""" Yield a Package created by user. """
pkgbase = db.create(PackageBase,
Name="test-package",
Maintainer=maintainer)
yield db.create(Package,
PackageBase=pkgbase,
Name=pkgbase.Name)
with db.begin():
pkgbase = db.create(PackageBase,
Name="test-package",
Maintainer=maintainer)
package = db.create(Package,
PackageBase=pkgbase,
Name=pkgbase.Name)
yield package
def test_package_not_found(client: TestClient):
@ -121,10 +121,11 @@ def test_package_official_not_found(client: TestClient, package: Package):
""" When a Package has a matching OfficialProvider record, it is not
hosted on AUR, but in the official repositories. Getting a package
with this kind of record should return a status code 404. """
db.create(OfficialProvider,
Name=package.Name,
Repo="core",
Provides=package.Name)
with db.begin():
db.create(OfficialProvider,
Name=package.Name,
Repo="core",
Provides=package.Name)
with client as request:
resp = request.get(package_endpoint(package))
@ -157,8 +158,9 @@ def test_package(client: TestClient, package: Package):
def test_package_comments(client: TestClient, user: User, package: Package):
now = (datetime.utcnow().timestamp())
comment = db.create(PackageComment, PackageBase=package.PackageBase,
User=user, Comments="Test comment", CommentTS=now)
with db.begin():
comment = db.create(PackageComment, PackageBase=package.PackageBase,
User=user, Comments="Test comment", CommentTS=now)
cookies = {"AURSID": user.login(Request(), "testPassword")}
with client as request:
@ -178,11 +180,12 @@ def test_package_comments(client: TestClient, user: User, package: Package):
def test_package_requests_display(client: TestClient, user: User,
package: Package):
type_ = db.query(RequestType, RequestType.ID == DELETION_ID).first()
db.create(PackageRequest, PackageBase=package.PackageBase,
PackageBaseName=package.PackageBase.Name,
User=user, RequestType=type_,
Comments="Test comment.",
ClosureComment=str())
with db.begin():
db.create(PackageRequest, PackageBase=package.PackageBase,
PackageBaseName=package.PackageBase.Name,
User=user, RequestType=type_,
Comments="Test comment.",
ClosureComment=str())
# Test that a single request displays "1 pending request".
with client as request:
@ -195,11 +198,12 @@ def test_package_requests_display(client: TestClient, user: User,
assert target.text.strip() == "1 pending request"
type_ = db.query(RequestType, RequestType.ID == DELETION_ID).first()
db.create(PackageRequest, PackageBase=package.PackageBase,
PackageBaseName=package.PackageBase.Name,
User=user, RequestType=type_,
Comments="Test comment2.",
ClosureComment=str())
with db.begin():
db.create(PackageRequest, PackageBase=package.PackageBase,
PackageBaseName=package.PackageBase.Name,
User=user, RequestType=type_,
Comments="Test comment2.",
ClosureComment=str())
# Test that a two requests display "2 pending requests".
with client as request:
@ -271,50 +275,43 @@ def test_package_authenticated_maintainer(client: TestClient,
def test_package_dependencies(client: TestClient, maintainer: User,
package: Package):
# Create a normal dependency of type depends.
dep_pkg = create_package("test-dep-1", maintainer, autocommit=False)
dep = create_package_dep(package, dep_pkg.Name, autocommit=False)
dep.DepArch = "x86_64"
with db.begin():
dep_pkg = create_package("test-dep-1", maintainer)
dep = create_package_dep(package, dep_pkg.Name)
dep.DepArch = "x86_64"
# Also, create a makedepends.
make_dep_pkg = create_package("test-dep-2", maintainer, autocommit=False)
make_dep = create_package_dep(package, make_dep_pkg.Name,
dep_type_name="makedepends",
autocommit=False)
# Also, create a makedepends.
make_dep_pkg = create_package("test-dep-2", maintainer)
make_dep = create_package_dep(package, make_dep_pkg.Name,
dep_type_name="makedepends")
# And... a checkdepends!
check_dep_pkg = create_package("test-dep-3", maintainer, autocommit=False)
check_dep = create_package_dep(package, check_dep_pkg.Name,
dep_type_name="checkdepends",
autocommit=False)
# And... a checkdepends!
check_dep_pkg = create_package("test-dep-3", maintainer)
check_dep = create_package_dep(package, check_dep_pkg.Name,
dep_type_name="checkdepends")
# Geez. Just stop. This is optdepends.
opt_dep_pkg = create_package("test-dep-4", maintainer, autocommit=False)
opt_dep = create_package_dep(package, opt_dep_pkg.Name,
dep_type_name="optdepends",
autocommit=False)
# Geez. Just stop. This is optdepends.
opt_dep_pkg = create_package("test-dep-4", maintainer)
opt_dep = create_package_dep(package, opt_dep_pkg.Name,
dep_type_name="optdepends")
# Heh. Another optdepends to test one with a description.
opt_desc_dep_pkg = create_package("test-dep-5", maintainer,
autocommit=False)
opt_desc_dep = create_package_dep(package, opt_desc_dep_pkg.Name,
dep_type_name="optdepends",
autocommit=False)
opt_desc_dep.DepDesc = "Test description."
# Heh. Another optdepends to test one with a description.
opt_desc_dep_pkg = create_package("test-dep-5", maintainer)
opt_desc_dep = create_package_dep(package, opt_desc_dep_pkg.Name,
dep_type_name="optdepends")
opt_desc_dep.DepDesc = "Test description."
broken_dep = create_package_dep(package, "test-dep-6",
dep_type_name="depends",
autocommit=False)
broken_dep = create_package_dep(package, "test-dep-6",
dep_type_name="depends")
# Create an official provider record.
db.create(OfficialProvider, Name="test-dep-99",
Repo="core", Provides="test-dep-99",
autocommit=False)
official_dep = create_package_dep(package, "test-dep-99",
autocommit=False)
# Create an official provider record.
db.create(OfficialProvider, Name="test-dep-99",
Repo="core", Provides="test-dep-99")
official_dep = create_package_dep(package, "test-dep-99")
# Also, create a provider who provides our test-dep-99.
provider = create_package("test-provider", maintainer, autocommit=False)
create_package_rel(provider, dep.DepName)
# Also, create a provider who provides our test-dep-99.
provider = create_package("test-provider", maintainer)
create_package_rel(provider, dep.DepName)
with client as request:
resp = request.get(package_endpoint(package))
@ -358,8 +355,9 @@ def test_pkgbase_redirect(client: TestClient, package: Package):
def test_pkgbase(client: TestClient, package: Package):
second = db.create(Package, Name="second-pkg",
PackageBase=package.PackageBase)
with db.begin():
second = db.create(Package, Name="second-pkg",
PackageBase=package.PackageBase)
expected = [package.Name, second.Name]
with client as request:

View file

@ -26,17 +26,21 @@ def setup():
@pytest.fixture
def maintainer() -> User:
account_type = db.query(AccountType, AccountType.ID == USER_ID).first()
yield db.create(User, Username="test_maintainer",
Email="test_maintainer@examepl.org",
Passwd="testPassword",
AccountType=account_type)
with db.begin():
maintainer = db.create(User, Username="test_maintainer",
Email="test_maintainer@examepl.org",
Passwd="testPassword",
AccountType=account_type)
yield maintainer
@pytest.fixture
def package(maintainer: User) -> Package:
pkgbase = db.create(PackageBase, Name="test-pkg",
Packager=maintainer, Maintainer=maintainer)
yield db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase)
with db.begin():
pkgbase = db.create(PackageBase, Name="test-pkg",
Packager=maintainer, Maintainer=maintainer)
package = db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase)
yield package
@pytest.fixture
@ -45,10 +49,11 @@ def client() -> TestClient:
def test_package_link(client: TestClient, maintainer: User, package: Package):
db.create(OfficialProvider,
Name=package.Name,
Repo="core",
Provides=package.Name)
with db.begin():
db.create(OfficialProvider,
Name=package.Name,
Repo="core",
Provides=package.Name)
expected = f"{OFFICIAL_BASE}/packages/?q={package.Name}"
assert util.package_link(package) == expected

View file

@ -1,6 +1,6 @@
import pytest
from aurweb.db import create, delete, query
from aurweb import db
from aurweb.models.relation_type import RelationType
from aurweb.testing import setup_test_db
@ -11,22 +11,25 @@ def setup():
def test_relation_type_creation():
relation_type = create(RelationType, Name="test-relation")
with db.begin():
relation_type = db.create(RelationType, Name="test-relation")
assert bool(relation_type.ID)
assert relation_type.Name == "test-relation"
delete(RelationType, RelationType.ID == relation_type.ID)
with db.begin():
db.delete(RelationType, RelationType.ID == relation_type.ID)
def test_relation_types():
conflicts = query(RelationType, RelationType.Name == "conflicts").first()
conflicts = db.query(RelationType, RelationType.Name == "conflicts").first()
assert conflicts is not None
assert conflicts.Name == "conflicts"
provides = query(RelationType, RelationType.Name == "provides").first()
provides = db.query(RelationType, RelationType.Name == "provides").first()
assert provides is not None
assert provides.Name == "provides"
replaces = query(RelationType, RelationType.Name == "replaces").first()
replaces = db.query(RelationType, RelationType.Name == "replaces").first()
assert replaces is not None
assert replaces.Name == "replaces"

View file

@ -1,6 +1,6 @@
import pytest
from aurweb.db import create, delete, query
from aurweb import db
from aurweb.models.request_type import DELETION_ID, MERGE_ID, ORPHAN_ID, RequestType
from aurweb.testing import setup_test_db
@ -11,25 +11,33 @@ def setup():
def test_request_type_creation():
request_type = create(RequestType, Name="Test Request")
with db.begin():
request_type = db.create(RequestType, Name="Test Request")
assert bool(request_type.ID)
assert request_type.Name == "Test Request"
delete(RequestType, RequestType.ID == request_type.ID)
with db.begin():
db.delete(RequestType, RequestType.ID == request_type.ID)
def test_request_type_null_name_returns_empty_string():
request_type = create(RequestType)
with db.begin():
request_type = db.create(RequestType)
assert bool(request_type.ID)
assert request_type.Name == str()
delete(RequestType, RequestType.ID == request_type.ID)
with db.begin():
db.delete(RequestType, RequestType.ID == request_type.ID)
def test_request_type_name_display():
deletion = query(RequestType, RequestType.ID == DELETION_ID).first()
deletion = db.query(RequestType, RequestType.ID == DELETION_ID).first()
assert deletion.name_display() == "Deletion"
orphan = query(RequestType, RequestType.ID == ORPHAN_ID).first()
orphan = db.query(RequestType, RequestType.ID == ORPHAN_ID).first()
assert orphan.name_display() == "Orphan"
merge = query(RequestType, RequestType.ID == MERGE_ID).first()
merge = db.query(RequestType, RequestType.ID == MERGE_ID).first()
assert merge.name_display() == "Merge"

View file

@ -8,8 +8,8 @@ import pytest
from fastapi.testclient import TestClient
from aurweb import db
from aurweb.asgi import app
from aurweb.db import create, query
from aurweb.models.account_type import AccountType
from aurweb.models.user import User
from aurweb.testing import setup_test_db
@ -24,11 +24,13 @@ def setup():
setup_test_db("Users", "Sessions")
account_type = query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
account_type = db.query(AccountType,
AccountType.AccountType == "User").first()
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
client = TestClient(app)

View file

@ -49,14 +49,13 @@ def packages(user):
now = int(datetime.utcnow().timestamp())
# Create 101 packages; we limit 100 on RSS feeds.
for i in range(101):
pkgbase = db.create(
PackageBase, Maintainer=user, Name=f"test-package-{i}",
SubmittedTS=(now + i), ModifiedTS=(now + i), autocommit=False)
pkg = db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase,
autocommit=False)
pkgs.append(pkg)
db.commit()
with db.begin():
for i in range(101):
pkgbase = db.create(
PackageBase, Maintainer=user, Name=f"test-package-{i}",
SubmittedTS=(now + i), ModifiedTS=(now + i))
pkg = db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase)
pkgs.append(pkg)
yield pkgs

View file

@ -4,7 +4,7 @@ from unittest import mock
import pytest
from aurweb.db import create, query
from aurweb import db
from aurweb.models.account_type import AccountType
from aurweb.models.session import Session, generate_unique_sid
from aurweb.models.user import User
@ -19,13 +19,16 @@ def setup():
setup_test_db("Users", "Sessions")
account_type = query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
ResetKey="testReset", Passwd="testPassword",
AccountType=account_type)
session = create(Session, UsersID=user.ID, SessionID="testSession",
LastUpdateTS=datetime.utcnow().timestamp())
account_type = db.query(AccountType,
AccountType.AccountType == "User").first()
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
ResetKey="testReset", Passwd="testPassword",
AccountType=account_type)
with db.begin():
session = db.create(Session, UsersID=user.ID, SessionID="testSession",
LastUpdateTS=datetime.utcnow().timestamp())
def test_session():
@ -35,12 +38,15 @@ def test_session():
def test_session_cs():
""" Test case sensitivity of the database table. """
user2 = create(User, Username="test2", Email="test2@example.org",
ResetKey="testReset2", Passwd="testPassword",
AccountType=account_type)
session_cs = create(Session, UsersID=user2.ID,
SessionID="TESTSESSION",
LastUpdateTS=datetime.utcnow().timestamp())
with db.begin():
user2 = db.create(User, Username="test2", Email="test2@example.org",
ResetKey="testReset2", Passwd="testPassword",
AccountType=account_type)
with db.begin():
session_cs = db.create(Session, UsersID=user2.ID,
SessionID="TESTSESSION",
LastUpdateTS=datetime.utcnow().timestamp())
assert session_cs.SessionID == "TESTSESSION"
assert session.SessionID == "testSession"

View file

@ -1,6 +1,6 @@
import pytest
from aurweb.db import create, query
from aurweb import db
from aurweb.models.account_type import AccountType
from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint
from aurweb.models.user import User
@ -19,19 +19,18 @@ def setup():
setup_test_db("Users", "SSHPubKeys")
account_type = query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
account_type = db.query(AccountType,
AccountType.AccountType == "User").first()
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
assert account_type == user.AccountType
assert account_type.ID == user.AccountTypeID
ssh_pub_key = create(SSHPubKey,
UserID=user.ID,
Fingerprint="testFingerprint",
PubKey="testPubKey")
with db.begin():
ssh_pub_key = db.create(SSHPubKey,
UserID=user.ID,
Fingerprint="testFingerprint",
PubKey="testPubKey")
def test_ssh_pub_key():
@ -43,9 +42,10 @@ def test_ssh_pub_key():
def test_ssh_pub_key_cs():
""" Test case sensitivity of the database table. """
ssh_pub_key_cs = create(SSHPubKey, UserID=user.ID,
Fingerprint="TESTFINGERPRINT",
PubKey="TESTPUBKEY")
with db.begin():
ssh_pub_key_cs = db.create(SSHPubKey, UserID=user.ID,
Fingerprint="TESTFINGERPRINT",
PubKey="TESTPUBKEY")
assert ssh_pub_key_cs.Fingerprint == "TESTFINGERPRINT"
assert ssh_pub_key_cs.PubKey == "TESTPUBKEY"

View file

@ -2,7 +2,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import create
from aurweb import db
from aurweb.models.term import Term
from aurweb.testing import setup_test_db
@ -18,8 +18,9 @@ def setup():
def test_term_creation():
term = create(Term, Description="Term description",
URL="https://fake_url.io")
with db.begin():
term = db.create(Term, Description="Term description",
URL="https://fake_url.io")
assert bool(term.ID)
assert term.Description == "Term description"
assert term.URL == "https://fake_url.io"
@ -27,14 +28,14 @@ def test_term_creation():
def test_term_null_description_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(Term, URL="https://fake_url.io")
session.rollback()
with db.begin():
db.create(Term, URL="https://fake_url.io")
db.rollback()
def test_term_null_url_raises_exception():
from aurweb.db import session
with pytest.raises(IntegrityError):
create(Term, Description="Term description")
session.rollback()
with db.begin():
db.create(Term, Description="Term description")
db.rollback()

View file

@ -90,37 +90,37 @@ def client():
def tu_user():
tu_type = db.query(AccountType,
AccountType.AccountType == "Trusted User").first()
yield db.create(User, Username="test_tu", Email="test_tu@example.org",
RealName="Test TU", Passwd="testPassword",
AccountType=tu_type)
with db.begin():
tu_user = db.create(User, Username="test_tu",
Email="test_tu@example.org",
RealName="Test TU", Passwd="testPassword",
AccountType=tu_type)
yield tu_user
@pytest.fixture
def user():
user_type = db.query(AccountType,
AccountType.AccountType == "User").first()
yield db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=user_type)
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=user_type)
yield user
@pytest.fixture
def proposal(tu_user):
def proposal(user, tu_user):
ts = int(datetime.utcnow().timestamp())
agenda = "Test proposal."
start = ts - 5
end = ts + 1000
user_type = db.query(AccountType,
AccountType.AccountType == "User").first()
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=user_type)
voteinfo = db.create(TUVoteInfo,
Agenda=agenda, Quorum=0.0,
User=user.Username, Submitter=tu_user,
Submitted=start, End=end)
with db.begin():
voteinfo = db.create(TUVoteInfo,
Agenda=agenda, Quorum=0.0,
User=user.Username, Submitter=tu_user,
Submitted=start, End=end)
yield (tu_user, user, voteinfo)
@ -170,20 +170,22 @@ def test_tu_index(client, tu_user):
("Test agenda 2", ts - 1000, ts - 5) # Not running anymore.
]
vote_records = []
for vote in votes:
agenda, start, end = vote
vote_records.append(
db.create(TUVoteInfo, Agenda=agenda,
User=tu_user.Username,
Submitted=start, End=end,
Quorum=0.0,
Submitter=tu_user))
with db.begin():
for vote in votes:
agenda, start, end = vote
vote_records.append(
db.create(TUVoteInfo, Agenda=agenda,
User=tu_user.Username,
Submitted=start, End=end,
Quorum=0.0,
Submitter=tu_user))
# Vote on an ended proposal.
vote_record = vote_records[1]
vote_record.Yes += 1
vote_record.ActiveTUs += 1
db.create(TUVote, VoteInfo=vote_record, User=tu_user)
with db.begin():
# Vote on an ended proposal.
vote_record = vote_records[1]
vote_record.Yes += 1
vote_record.ActiveTUs += 1
db.create(TUVote, VoteInfo=vote_record, User=tu_user)
cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request:
@ -255,22 +257,22 @@ def test_tu_index(client, tu_user):
def test_tu_index_table_paging(client, tu_user):
ts = int(datetime.utcnow().timestamp())
for i in range(25):
# Create 25 current votes.
db.create(TUVoteInfo, Agenda=f"Agenda #{i}",
User=tu_user.Username,
Submitted=(ts - 5), End=(ts + 1000),
Quorum=0.0,
Submitter=tu_user, autocommit=False)
with db.begin():
for i in range(25):
# Create 25 current votes.
db.create(TUVoteInfo, Agenda=f"Agenda #{i}",
User=tu_user.Username,
Submitted=(ts - 5), End=(ts + 1000),
Quorum=0.0,
Submitter=tu_user)
for i in range(25):
# Create 25 past votes.
db.create(TUVoteInfo, Agenda=f"Agenda #{25 + i}",
User=tu_user.Username,
Submitted=(ts - 1000), End=(ts - 5),
Quorum=0.0,
Submitter=tu_user, autocommit=False)
db.commit()
for i in range(25):
# Create 25 past votes.
db.create(TUVoteInfo, Agenda=f"Agenda #{25 + i}",
User=tu_user.Username,
Submitted=(ts - 1000), End=(ts - 5),
Quorum=0.0,
Submitter=tu_user)
cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request:
@ -363,18 +365,19 @@ def test_tu_index_table_paging(client, tu_user):
def test_tu_index_sorting(client, tu_user):
ts = int(datetime.utcnow().timestamp())
for i in range(2):
# Create 'Agenda #1' and 'Agenda #2'.
db.create(TUVoteInfo, Agenda=f"Agenda #{i + 1}",
User=tu_user.Username,
Submitted=(ts + 5), End=(ts + 1000),
Quorum=0.0,
Submitter=tu_user, autocommit=False)
with db.begin():
for i in range(2):
# Create 'Agenda #1' and 'Agenda #2'.
db.create(TUVoteInfo, Agenda=f"Agenda #{i + 1}",
User=tu_user.Username,
Submitted=(ts + 5), End=(ts + 1000),
Quorum=0.0,
Submitter=tu_user)
# Let's order each vote one day after the other.
# This will allow us to test the sorting nature
# of the tables.
ts += 86405
# Let's order each vote one day after the other.
# This will allow us to test the sorting nature
# of the tables.
ts += 86405
# Make a default request to /tu.
cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
@ -432,18 +435,19 @@ def test_tu_index_sorting(client, tu_user):
def test_tu_index_last_votes(client, tu_user, user):
ts = int(datetime.utcnow().timestamp())
# Create a proposal which has ended.
voteinfo = db.create(TUVoteInfo, Agenda="Test agenda",
User=user.Username,
Submitted=(ts - 1000),
End=(ts - 5),
Yes=1,
ActiveTUs=1,
Quorum=0.0,
Submitter=tu_user)
with db.begin():
# Create a proposal which has ended.
voteinfo = db.create(TUVoteInfo, Agenda="Test agenda",
User=user.Username,
Submitted=(ts - 1000),
End=(ts - 5),
Yes=1,
ActiveTUs=1,
Quorum=0.0,
Submitter=tu_user)
# Create a vote on it from tu_user.
db.create(TUVote, VoteInfo=voteinfo, User=tu_user)
# Create a vote on it from tu_user.
db.create(TUVote, VoteInfo=voteinfo, User=tu_user)
# Now, check that tu_user got populated in the .last-votes table.
cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
@ -529,10 +533,10 @@ def test_tu_running_proposal(client, proposal):
assert abstain.attrib["value"] == "Abstain"
# Create a vote.
db.create(TUVote, VoteInfo=voteinfo, User=tu_user)
voteinfo.ActiveTUs += 1
voteinfo.Yes += 1
db.commit()
with db.begin():
db.create(TUVote, VoteInfo=voteinfo, User=tu_user)
voteinfo.ActiveTUs += 1
voteinfo.Yes += 1
# Make another request now that we've voted.
with client as request:
@ -556,8 +560,8 @@ def test_tu_ended_proposal(client, proposal):
tu_user, user, voteinfo = proposal
ts = int(datetime.utcnow().timestamp())
voteinfo.End = ts - 5 # 5 seconds ago.
db.commit()
with db.begin():
voteinfo.End = ts - 5 # 5 seconds ago.
# Initiate an authenticated GET request to /tu/{proposal_id}.
proposal_id = voteinfo.ID
@ -635,8 +639,8 @@ def test_tu_proposal_vote_unauthorized(client, proposal):
dev_type = db.query(AccountType,
AccountType.AccountType == "Developer").first()
tu_user.AccountType = dev_type
db.commit()
with db.begin():
tu_user.AccountType = dev_type
cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request:
@ -664,8 +668,8 @@ def test_tu_proposal_vote_cant_self_vote(client, proposal):
tu_user, user, voteinfo = proposal
# Update voteinfo.User.
voteinfo.User = tu_user.Username
db.commit()
with db.begin():
voteinfo.User = tu_user.Username
cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request:
@ -692,10 +696,10 @@ def test_tu_proposal_vote_cant_self_vote(client, proposal):
def test_tu_proposal_vote_already_voted(client, proposal):
tu_user, user, voteinfo = proposal
db.create(TUVote, VoteInfo=voteinfo, User=tu_user)
voteinfo.Yes += 1
voteinfo.ActiveTUs += 1
db.commit()
with db.begin():
db.create(TUVote, VoteInfo=voteinfo, User=tu_user)
voteinfo.Yes += 1
voteinfo.ActiveTUs += 1
cookies = {"AURSID": tu_user.login(Request(), "testPassword")}
with client as request:

View file

@ -4,7 +4,8 @@ import pytest
from sqlalchemy.exc import IntegrityError
from aurweb.db import commit, create, query, rollback
from aurweb import db
from aurweb.db import create, query, rollback
from aurweb.models.account_type import AccountType
from aurweb.models.tu_voteinfo import TUVoteInfo
from aurweb.models.user import User
@ -21,19 +22,21 @@ def setup():
tu_type = query(AccountType,
AccountType.AccountType == "Trusted User").first()
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=tu_type)
with db.begin():
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=tu_type)
def test_tu_voteinfo_creation():
ts = int(datetime.utcnow().timestamp())
tu_voteinfo = create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
Submitted=ts, End=ts + 5,
Quorum=0.5,
Submitter=user)
with db.begin():
tu_voteinfo = create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
Submitted=ts, End=ts + 5,
Quorum=0.5,
Submitter=user)
assert bool(tu_voteinfo.ID)
assert tu_voteinfo.Agenda == "Blah blah."
assert tu_voteinfo.User == user.Username
@ -51,32 +54,33 @@ def test_tu_voteinfo_creation():
def test_tu_voteinfo_is_running():
ts = int(datetime.utcnow().timestamp())
tu_voteinfo = create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
Submitted=ts, End=ts + 1000,
Quorum=0.5,
Submitter=user)
with db.begin():
tu_voteinfo = create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
Submitted=ts, End=ts + 1000,
Quorum=0.5,
Submitter=user)
assert tu_voteinfo.is_running() is True
tu_voteinfo.End = ts - 5
commit()
with db.begin():
tu_voteinfo.End = ts - 5
assert tu_voteinfo.is_running() is False
def test_tu_voteinfo_total_votes():
ts = int(datetime.utcnow().timestamp())
tu_voteinfo = create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
Submitted=ts, End=ts + 1000,
Quorum=0.5,
Submitter=user)
with db.begin():
tu_voteinfo = create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
Submitted=ts, End=ts + 1000,
Quorum=0.5,
Submitter=user)
tu_voteinfo.Yes = 1
tu_voteinfo.No = 3
tu_voteinfo.Abstain = 5
commit()
tu_voteinfo.Yes = 1
tu_voteinfo.No = 3
tu_voteinfo.Abstain = 5
# total_votes() should be the sum of Yes, No and Abstain: 1 + 3 + 5 = 9.
assert tu_voteinfo.total_votes() == 9
@ -84,61 +88,67 @@ def test_tu_voteinfo_total_votes():
def test_tu_voteinfo_null_submitter_raises_exception():
with pytest.raises(IntegrityError):
create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
Submitted=0, End=0,
Quorum=0.50)
with db.begin():
create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
Submitted=0, End=0,
Quorum=0.50)
rollback()
def test_tu_voteinfo_null_agenda_raises_exception():
with pytest.raises(IntegrityError):
create(TUVoteInfo,
User=user.Username,
Submitted=0, End=0,
Quorum=0.50,
Submitter=user)
with db.begin():
create(TUVoteInfo,
User=user.Username,
Submitted=0, End=0,
Quorum=0.50,
Submitter=user)
rollback()
def test_tu_voteinfo_null_user_raises_exception():
with pytest.raises(IntegrityError):
create(TUVoteInfo,
Agenda="Blah blah.",
Submitted=0, End=0,
Quorum=0.50,
Submitter=user)
with db.begin():
create(TUVoteInfo,
Agenda="Blah blah.",
Submitted=0, End=0,
Quorum=0.50,
Submitter=user)
rollback()
def test_tu_voteinfo_null_submitted_raises_exception():
with pytest.raises(IntegrityError):
create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
End=0,
Quorum=0.50,
Submitter=user)
with db.begin():
create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
End=0,
Quorum=0.50,
Submitter=user)
rollback()
def test_tu_voteinfo_null_end_raises_exception():
with pytest.raises(IntegrityError):
create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
Submitted=0,
Quorum=0.50,
Submitter=user)
with db.begin():
create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
Submitted=0,
Quorum=0.50,
Submitter=user)
rollback()
def test_tu_voteinfo_null_quorum_raises_exception():
with pytest.raises(IntegrityError):
create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
Submitted=0, End=0,
Submitter=user)
with db.begin():
create(TUVoteInfo,
Agenda="Blah blah.",
User=user.Username,
Submitted=0, End=0,
Submitter=user)
rollback()

View file

@ -9,7 +9,7 @@ import pytest
import aurweb.auth
import aurweb.config
from aurweb.db import commit, create, query
from aurweb import db
from aurweb.models.account_type import AccountType
from aurweb.models.ban import Ban
from aurweb.models.package import Package
@ -40,12 +40,13 @@ def setup():
PackageNotification.__tablename__
)
account_type = query(AccountType,
AccountType.AccountType == "User").first()
account_type = db.query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
with db.begin():
user = db.create(User, Username="test", Email="test@example.org",
RealName="Test User", Passwd="testPassword",
AccountType=account_type)
def test_user_login_logout():
@ -70,14 +71,14 @@ def test_user_login_logout():
assert "AURSID" in request.cookies
# Expect that User session relationships work right.
user_session = query(Session,
Session.UsersID == user.ID).first()
user_session = db.query(Session,
Session.UsersID == user.ID).first()
assert user_session == user.session
assert user.session.SessionID == sid
assert user.session.User == user
# Search for the user via query API.
result = query(User, User.ID == user.ID).first()
result = db.query(User, User.ID == user.ID).first()
# Compare the result and our original user.
assert result == user
@ -114,7 +115,8 @@ def test_user_login_twice():
def test_user_login_banned():
# Add ban for the next 30 seconds.
banned_timestamp = datetime.utcnow() + timedelta(seconds=30)
create(Ban, IPAddress="127.0.0.1", BanTS=banned_timestamp)
with db.begin():
db.create(Ban, IPAddress="127.0.0.1", BanTS=banned_timestamp)
request = Request()
request.client.host = "127.0.0.1"
@ -122,18 +124,17 @@ def test_user_login_banned():
def test_user_login_suspended():
from aurweb.db import session
user.Suspended = True
session.commit()
with db.begin():
user.Suspended = True
assert not user.login(Request(), "testPassword")
def test_legacy_user_authentication():
from aurweb.db import session
user.Salt = bcrypt.gensalt().decode()
user.Passwd = hashlib.md5(f"{user.Salt}testPassword".encode()).hexdigest()
session.commit()
with db.begin():
user.Salt = bcrypt.gensalt().decode()
user.Passwd = hashlib.md5(
f"{user.Salt}testPassword".encode()
).hexdigest()
assert not user.valid_password("badPassword")
assert user.valid_password("testPassword")
@ -145,8 +146,9 @@ def test_legacy_user_authentication():
def test_user_login_with_outdated_sid():
# Make a session with a LastUpdateTS 5 seconds ago, causing
# user.login to update it with a new sid.
create(Session, UsersID=user.ID, SessionID="stub",
LastUpdateTS=datetime.utcnow().timestamp() - 5)
with db.begin():
db.create(Session, UsersID=user.ID, SessionID="stub",
LastUpdateTS=datetime.utcnow().timestamp() - 5)
sid = user.login(Request(), "testPassword")
assert sid and user.is_authenticated()
assert sid != "stub"
@ -171,43 +173,42 @@ def test_user_has_credential():
def test_user_ssh_pub_key():
assert user.ssh_pub_key is None
ssh_pub_key = create(SSHPubKey, UserID=user.ID,
Fingerprint="testFingerprint",
PubKey="testPubKey")
with db.begin():
ssh_pub_key = db.create(SSHPubKey, UserID=user.ID,
Fingerprint="testFingerprint",
PubKey="testPubKey")
assert user.ssh_pub_key == ssh_pub_key
def test_user_credential_types():
from aurweb.db import session
assert aurweb.auth.user_developer_or_trusted_user(user)
assert not aurweb.auth.trusted_user(user)
assert not aurweb.auth.developer(user)
assert not aurweb.auth.trusted_user_or_dev(user)
trusted_user_type = query(AccountType,
AccountType.AccountType == "Trusted User")\
.first()
user.AccountType = trusted_user_type
session.commit()
trusted_user_type = db.query(AccountType).filter(
AccountType.AccountType == "Trusted User"
).first()
with db.begin():
user.AccountType = trusted_user_type
assert aurweb.auth.trusted_user(user)
assert aurweb.auth.trusted_user_or_dev(user)
developer_type = query(AccountType,
AccountType.AccountType == "Developer").first()
user.AccountType = developer_type
session.commit()
developer_type = db.query(AccountType,
AccountType.AccountType == "Developer").first()
with db.begin():
user.AccountType = developer_type
assert aurweb.auth.developer(user)
assert aurweb.auth.trusted_user_or_dev(user)
type_str = "Trusted User & Developer"
elevated_type = query(AccountType,
AccountType.AccountType == type_str).first()
user.AccountType = elevated_type
session.commit()
elevated_type = db.query(AccountType,
AccountType.AccountType == type_str).first()
with db.begin():
user.AccountType = elevated_type
assert aurweb.auth.trusted_user(user)
assert aurweb.auth.developer(user)
@ -233,53 +234,56 @@ def test_user_as_dict():
def test_user_is_trusted_user():
tu_type = query(AccountType,
AccountType.AccountType == "Trusted User").first()
user.AccountType = tu_type
commit()
tu_type = db.query(AccountType,
AccountType.AccountType == "Trusted User").first()
with db.begin():
user.AccountType = tu_type
assert user.is_trusted_user() is True
# Do it again with the combined role.
tu_type = query(
tu_type = db.query(
AccountType,
AccountType.AccountType == "Trusted User & Developer").first()
user.AccountType = tu_type
commit()
with db.begin():
user.AccountType = tu_type
assert user.is_trusted_user() is True
def test_user_is_developer():
dev_type = query(AccountType,
AccountType.AccountType == "Developer").first()
user.AccountType = dev_type
commit()
dev_type = db.query(AccountType,
AccountType.AccountType == "Developer").first()
with db.begin():
user.AccountType = dev_type
assert user.is_developer() is True
# Do it again with the combined role.
dev_type = query(
dev_type = db.query(
AccountType,
AccountType.AccountType == "Trusted User & Developer").first()
user.AccountType = dev_type
commit()
with db.begin():
user.AccountType = dev_type
assert user.is_developer() is True
def test_user_voted_for():
now = int(datetime.utcnow().timestamp())
pkgbase = create(PackageBase, Name="pkg1", Maintainer=user)
pkg = create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
create(PackageVote, PackageBase=pkgbase, User=user, VoteTS=now)
with db.begin():
pkgbase = db.create(PackageBase, Name="pkg1", Maintainer=user)
pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
db.create(PackageVote, PackageBase=pkgbase, User=user, VoteTS=now)
assert user.voted_for(pkg)
def test_user_notified():
pkgbase = create(PackageBase, Name="pkg1", Maintainer=user)
pkg = create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
create(PackageNotification, PackageBase=pkgbase, User=user)
with db.begin():
pkgbase = db.create(PackageBase, Name="pkg1", Maintainer=user)
pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
db.create(PackageNotification, PackageBase=pkgbase, User=user)
assert user.notified(pkg)
def test_user_packages():
pkgbase = create(PackageBase, Name="pkg1", Maintainer=user)
pkg = create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
with db.begin():
pkgbase = db.create(PackageBase, Name="pkg1", Maintainer=user)
pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name)
assert pkg in user.packages()