From a5943bf2add0231925d7836e2e0b587a4f5c7f05 Mon Sep 17 00:00:00 2001 From: Kevin Morris Date: Thu, 2 Sep 2021 16:26:48 -0700 Subject: [PATCH] [FastAPI] Refactor db modifications For SQLAlchemy to automatically understand updates from the external world, it must use an `autocommit=True` in its session. This change breaks how we were using commit previously, as `autocommit=True` causes SQLAlchemy to commit when a SessionTransaction context hits __exit__. So, a refactoring was required of our tests: All usage of any `db.{create,delete}` must be called **within** a SessionTransaction context, created via new `db.begin()`. From this point forward, we're going to require: ``` with db.begin(): db.create(...) db.delete(...) db.session.delete(object) ``` With this, we now get external DB modifications automatically without reloading or restarting the FastAPI server, which we absolutely need for production. Signed-off-by: Kevin Morris --- aurweb/db.py | 44 +++++--- aurweb/models/user.py | 42 ++++---- aurweb/routers/accounts.py | 145 +++++++++++++-------------- aurweb/routers/html.py | 6 +- aurweb/routers/trusted_user.py | 20 ++-- test/test_account_type.py | 18 ++-- test/test_accounts_routes.py | 166 ++++++++++++++++--------------- test/test_api_rate_limit.py | 19 ++-- test/test_auth.py | 22 ++-- test/test_auth_routes.py | 9 +- test/test_ban.py | 10 +- test/test_db.py | 22 ++-- test/test_dependency_type.py | 14 ++- test/test_group.py | 11 +- test/test_homepage.py | 49 ++++----- test/test_license.py | 11 +- test/test_official_provider.py | 59 +++++------ test/test_package.py | 68 ++++++------- test/test_package_base.py | 61 ++++++------ test/test_package_blacklist.py | 16 +-- test/test_package_comment.py | 47 +++++---- test/test_package_dependency.py | 84 ++++++++-------- test/test_package_relation.py | 75 +++++++------- test/test_package_request.py | 99 ++++++++++-------- test/test_package_source.py | 23 +++-- test/test_packages_routes.py | 160 +++++++++++++++-------------- test/test_packages_util.py | 27 +++-- test/test_relation_type.py | 15 +-- test/test_request_type.py | 24 +++-- test/test_routes.py | 14 +-- test/test_rss.py | 15 ++- test/test_session.py | 34 ++++--- test/test_ssh_pub_key.py | 32 +++--- test/test_term.py | 19 ++-- test/test_trusted_user_routes.py | 166 ++++++++++++++++--------------- test/test_tu_voteinfo.py | 130 +++++++++++++----------- test/test_user.py | 124 ++++++++++++----------- 37 files changed, 998 insertions(+), 902 deletions(-) diff --git a/aurweb/db.py b/aurweb/db.py index c0147720..ea6b6918 100644 --- a/aurweb/db.py +++ b/aurweb/db.py @@ -59,20 +59,15 @@ def query(model, *args, **kwargs): return session.query(model).filter(*args, **kwargs) -def create(model, autocommit: bool = True, *args, **kwargs): +def create(model, *args, **kwargs): instance = model(*args, **kwargs) - add(instance) - if autocommit is True: - commit() - return instance + return add(instance) -def delete(model, *args, autocommit: bool = True, **kwargs): +def delete(model, *args, **kwargs): instance = session.query(model).filter(*args, **kwargs) for record in instance: session.delete(record) - if autocommit is True: - commit() def rollback(): @@ -84,8 +79,25 @@ def add(model): return model -def commit(): - session.commit() +def begin(): + """ Begin an SQLAlchemy SessionTransaction. + + This context is **required** to perform an modifications to the + database. + + Example: + + with db.begin(): + object = db.create(...) + # On __exit__, db.commit() is run. + + with db.begin(): + object = db.delete(...) + # On __exit__, db.commit() is run. + + :return: A new SessionTransaction based on session + """ + return session.begin() def get_sqlalchemy_url(): @@ -155,23 +167,23 @@ def get_engine(echo: bool = False): connect_args=connect_args, echo=echo) + Session = sessionmaker(autocommit=True, autoflush=False, bind=engine) + session = Session() + if db_backend == "sqlite": # For SQLite, we need to add some custom functions as # they are used in the reference graph method. def regexp(regex, item): return bool(re.search(regex, str(item))) - @event.listens_for(engine, "begin") - def do_begin(conn): + @event.listens_for(engine, "connect") + def do_begin(conn, record): create_deterministic_function = functools.partial( - conn.connection.create_function, + conn.create_function, deterministic=True ) create_deterministic_function("REGEXP", 2, regexp) - Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) - session = Session() - return engine diff --git a/aurweb/models/user.py b/aurweb/models/user.py index 0ccf7329..70d15f88 100644 --- a/aurweb/models/user.py +++ b/aurweb/models/user.py @@ -102,7 +102,7 @@ class User(Base): def login(self, request: Request, password: str, session_time=0): """ Login and authenticate a request. """ - from aurweb.db import session + from aurweb import db from aurweb.models.session import Session, generate_unique_sid if not self._login_approved(request): @@ -112,10 +112,7 @@ class User(Base): if not self.authenticated: return None - self.LastLogin = now_ts = datetime.utcnow().timestamp() - self.LastLoginIPAddress = request.client.host - session.commit() - + now_ts = datetime.utcnow().timestamp() session_ts = now_ts + ( session_time if session_time else aurweb.config.getint("options", "login_timeout") @@ -123,22 +120,23 @@ class User(Base): sid = None - if not self.session: - sid = generate_unique_sid() - self.session = Session(UsersID=self.ID, SessionID=sid, - LastUpdateTS=session_ts) - session.add(self.session) - else: - last_updated = self.session.LastUpdateTS - if last_updated and last_updated < now_ts: - self.session.SessionID = sid = generate_unique_sid() + with db.begin(): + self.LastLogin = now_ts + self.LastLoginIPAddress = request.client.host + if not self.session: + sid = generate_unique_sid() + self.session = Session(UsersID=self.ID, SessionID=sid, + LastUpdateTS=session_ts) + db.add(self.session) else: - # Session is still valid; retrieve the current SID. - sid = self.session.SessionID + last_updated = self.session.LastUpdateTS + if last_updated and last_updated < now_ts: + self.session.SessionID = sid = generate_unique_sid() + else: + # Session is still valid; retrieve the current SID. + sid = self.session.SessionID - self.session.LastUpdateTS = session_ts - - session.commit() + self.session.LastUpdateTS = session_ts request.cookies["AURSID"] = self.session.SessionID return self.session.SessionID @@ -149,13 +147,11 @@ class User(Base): return aurweb.auth.has_credential(self, cred, approved) def logout(self, request): - from aurweb.db import session - del request.cookies["AURSID"] self.authenticated = False if self.session: - session.delete(self.session) - session.commit() + with db.begin(): + db.session.delete(self.session) def is_trusted_user(self): return self.AccountType.ID in { diff --git a/aurweb/routers/accounts.py b/aurweb/routers/accounts.py index 466d129d..ef4b99af 100644 --- a/aurweb/routers/accounts.py +++ b/aurweb/routers/accounts.py @@ -43,8 +43,6 @@ async def passreset_post(request: Request, resetkey: str = Form(default=None), password: str = Form(default=None), confirm: str = Form(default=None)): - from aurweb.db import session - context = await make_variable_context(request, "Password Reset") # The user parameter being required, we can match against @@ -86,12 +84,11 @@ async def passreset_post(request: Request, # We got to this point; everything matched up. Update the password # and remove the ResetKey. - user.ResetKey = str() - user.update_password(password) - - if user.session: - session.delete(user.session) - session.commit() + with db.begin(): + user.ResetKey = str() + if user.session: + db.session.delete(user.session) + user.update_password(password) # Render ?step=complete. return RedirectResponse(url="/passreset?step=complete", @@ -99,8 +96,8 @@ async def passreset_post(request: Request, # If we got here, we continue with issuing a resetkey for the user. resetkey = db.make_random_value(User, User.ResetKey) - user.ResetKey = resetkey - session.commit() + with db.begin(): + user.ResetKey = resetkey executor = db.ConnectionExecutor(db.get_engine().raw_connection()) ResetKeyNotification(executor, user.ID).send() @@ -364,8 +361,6 @@ async def account_register_post(request: Request, ON: bool = Form(default=False), captcha: str = Form(default=None), captcha_salt: str = Form(...)): - from aurweb.db import session - context = await make_variable_context(request, "Register") args = dict(await request.form()) @@ -394,11 +389,13 @@ async def account_register_post(request: Request, AccountType.AccountType == "User").first() # Create a user given all parameters available. - user = db.create(User, Username=U, Email=E, HideEmail=H, BackupEmail=BE, - RealName=R, Homepage=HP, IRCNick=I, PGPKey=K, - LangPreference=L, Timezone=TZ, CommentNotify=CN, - UpdateNotify=UN, OwnershipNotify=ON, ResetKey=resetkey, - AccountType=account_type) + with db.begin(): + user = db.create(User, Username=U, + Email=E, HideEmail=H, BackupEmail=BE, + RealName=R, Homepage=HP, IRCNick=I, PGPKey=K, + LangPreference=L, Timezone=TZ, CommentNotify=CN, + UpdateNotify=UN, OwnershipNotify=ON, + ResetKey=resetkey, AccountType=account_type) # If a PK was given and either one does not exist or the given # PK mismatches the existing user's SSHPubKey.PubKey. @@ -410,10 +407,10 @@ async def account_register_post(request: Request, # Remove the host part. pubkey = parts[0] + " " + parts[1] fingerprint = get_fingerprint(pubkey) - user.ssh_pub_key = SSHPubKey(UserID=user.ID, - PubKey=pubkey, - Fingerprint=fingerprint) - session.commit() + with db.begin(): + user.ssh_pub_key = SSHPubKey(UserID=user.ID, + PubKey=pubkey, + Fingerprint=fingerprint) # Send a reset key notification to the new user. executor = db.ConnectionExecutor(db.get_engine().raw_connection()) @@ -499,63 +496,67 @@ async def account_edit_post(request: Request, status_code=int(HTTPStatus.BAD_REQUEST)) # Set all updated fields as needed. - user.Username = U or user.Username - user.Email = E or user.Email - user.HideEmail = bool(H) - user.BackupEmail = BE or user.BackupEmail - user.RealName = R or user.RealName - user.Homepage = HP or user.Homepage - user.IRCNick = I or user.IRCNick - user.PGPKey = K or user.PGPKey - user.InactivityTS = datetime.utcnow().timestamp() if J else 0 + with db.begin(): + user.Username = U or user.Username + user.Email = E or user.Email + user.HideEmail = bool(H) + user.BackupEmail = BE or user.BackupEmail + user.RealName = R or user.RealName + user.Homepage = HP or user.Homepage + user.IRCNick = I or user.IRCNick + user.PGPKey = K or user.PGPKey + user.InactivityTS = datetime.utcnow().timestamp() if J else 0 # If we update the language, update the cookie as well. if L and L != user.LangPreference: request.cookies["AURLANG"] = L - user.LangPreference = L + with db.begin(): + user.LangPreference = L context["language"] = L # If we update the timezone, also update the cookie. if TZ and TZ != user.Timezone: - user.Timezone = TZ + with db.begin(): + user.Timezone = TZ request.cookies["AURTZ"] = TZ context["timezone"] = TZ - user.CommentNotify = bool(CN) - user.UpdateNotify = bool(UN) - user.OwnershipNotify = bool(ON) + with db.begin(): + user.CommentNotify = bool(CN) + user.UpdateNotify = bool(UN) + user.OwnershipNotify = bool(ON) # If a PK is given, compare it against the target user's PK. - if PK: - # Get the second token in the public key, which is the actual key. - pubkey = PK.strip().rstrip() - parts = pubkey.split(" ") - if len(parts) == 3: - # Remove the host part. - pubkey = parts[0] + " " + parts[1] - fingerprint = get_fingerprint(pubkey) - if not user.ssh_pub_key: - # No public key exists, create one. - user.ssh_pub_key = SSHPubKey(UserID=user.ID, - PubKey=pubkey, - Fingerprint=fingerprint) - elif user.ssh_pub_key.PubKey != pubkey: - # A public key already exists, update it. - user.ssh_pub_key.PubKey = pubkey - user.ssh_pub_key.Fingerprint = fingerprint - elif user.ssh_pub_key: - # Else, if the user has a public key already, delete it. - session.delete(user.ssh_pub_key) - - # Commit changes, if any. - session.commit() + with db.begin(): + if PK: + # Get the second token in the public key, which is the actual key. + pubkey = PK.strip().rstrip() + parts = pubkey.split(" ") + if len(parts) == 3: + # Remove the host part. + pubkey = parts[0] + " " + parts[1] + fingerprint = get_fingerprint(pubkey) + if not user.ssh_pub_key: + # No public key exists, create one. + user.ssh_pub_key = SSHPubKey(UserID=user.ID, + PubKey=pubkey, + Fingerprint=fingerprint) + elif user.ssh_pub_key.PubKey != pubkey: + # A public key already exists, update it. + user.ssh_pub_key.PubKey = pubkey + user.ssh_pub_key.Fingerprint = fingerprint + elif user.ssh_pub_key: + # Else, if the user has a public key already, delete it. + session.delete(user.ssh_pub_key) if P and not user.valid_password(P): # Remove the fields we consumed for passwords. context["P"] = context["C"] = str() # If a password was given and it doesn't match the user's, update it. - user.update_password(P) + with db.begin(): + user.update_password(P) + if user == request.user: # If the target user is the request user, login with # the updated password and update AURSID. @@ -731,21 +732,17 @@ async def terms_of_service_post(request: Request, accept_needed = sorted(unaccepted + diffs) return render_terms_of_service(request, context, accept_needed) - # For each term we found, query for the matching accepted term - # and update its Revision to the term's current Revision. - for term in diffs: - accepted_term = request.user.accepted_terms.filter( - AcceptedTerm.TermsID == term.ID).first() - accepted_term.Revision = term.Revision + with db.begin(): + # For each term we found, query for the matching accepted term + # and update its Revision to the term's current Revision. + for term in diffs: + accepted_term = request.user.accepted_terms.filter( + AcceptedTerm.TermsID == term.ID).first() + accepted_term.Revision = term.Revision - # For each term that was never accepted, accept it! - for term in unaccepted: - db.create(AcceptedTerm, User=request.user, - Term=term, Revision=term.Revision, - autocommit=False) - - if diffs or unaccepted: - # If we had any terms to update, commit the changes. - db.commit() + # For each term that was never accepted, accept it! + for term in unaccepted: + db.create(AcceptedTerm, User=request.user, + Term=term, Revision=term.Revision) return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER)) diff --git a/aurweb/routers/html.py b/aurweb/routers/html.py index c2375f69..c3fd3db1 100644 --- a/aurweb/routers/html.py +++ b/aurweb/routers/html.py @@ -44,8 +44,6 @@ async def language(request: Request, setting the language on any page, we want to preserve query parameters across the redirect. """ - from aurweb.db import session - if next[0] != '/': return HTMLResponse(b"Invalid 'next' parameter.", status_code=400) @@ -53,8 +51,8 @@ async def language(request: Request, # If the user is authenticated, update the user's LangPreference. if request.user.is_authenticated(): - request.user.LangPreference = set_lang - session.commit() + with db.begin(): + request.user.LangPreference = set_lang # In any case, set the response's AURLANG cookie that never expires. response = RedirectResponse(url=f"{next}{query_string}", diff --git a/aurweb/routers/trusted_user.py b/aurweb/routers/trusted_user.py index 61cfec6c..a977b31a 100644 --- a/aurweb/routers/trusted_user.py +++ b/aurweb/routers/trusted_user.py @@ -214,10 +214,9 @@ async def trusted_user_proposal_post(request: Request, return Response("Invalid 'decision' value.", status_code=int(HTTPStatus.BAD_REQUEST)) - vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo, - autocommit=False) - voteinfo.ActiveTUs += 1 - db.commit() + with db.begin(): + vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo) + voteinfo.ActiveTUs += 1 context["error"] = "You've already voted for this proposal." return render_proposal(request, context, proposal, voteinfo, voters, vote) @@ -294,12 +293,13 @@ async def trusted_user_addvote_post(request: Request, agenda = re.sub(r'<[/]?style.*>', '', agenda) # Create a new TUVoteInfo (proposal)! - voteinfo = db.create(TUVoteInfo, - User=user, - Agenda=agenda, - Submitted=timestamp, End=timestamp + duration, - Quorum=quorum, - Submitter=request.user) + with db.begin(): + voteinfo = db.create(TUVoteInfo, + User=user, + Agenda=agenda, + Submitted=timestamp, End=timestamp + duration, + Quorum=quorum, + Submitter=request.user) # Redirect to the new proposal. return RedirectResponse(f"/tu/{voteinfo.ID}", diff --git a/test/test_account_type.py b/test/test_account_type.py index fa4bc5ad..86e68253 100644 --- a/test/test_account_type.py +++ b/test/test_account_type.py @@ -1,6 +1,6 @@ import pytest -from aurweb.db import create, delete, query +from aurweb.db import begin, create, delete, query from aurweb.models.account_type import AccountType from aurweb.models.user import User from aurweb.testing import setup_test_db @@ -14,11 +14,13 @@ def setup(): global account_type - account_type = create(AccountType, AccountType="TestUser") + with begin(): + account_type = create(AccountType, AccountType="TestUser") yield account_type - delete(AccountType, AccountType.ID == account_type.ID) + with begin(): + delete(AccountType, AccountType.ID == account_type.ID) def test_account_type(): @@ -38,12 +40,14 @@ def test_account_type(): def test_user_account_type_relationship(): - user = create(User, Username="test", Email="test@example.org", - RealName="Test User", Passwd="testPassword", - AccountType=account_type) + with begin(): + user = create(User, Username="test", Email="test@example.org", + RealName="Test User", Passwd="testPassword", + AccountType=account_type) assert user.AccountType == account_type # This must be deleted here to avoid foreign key issues when # deleting the temporary AccountType in the fixture. - delete(User, User.ID == user.ID) + with begin(): + delete(User, User.ID == user.ID) diff --git a/test/test_accounts_routes.py b/test/test_accounts_routes.py index 567b3426..9120f23f 100644 --- a/test/test_accounts_routes.py +++ b/test/test_accounts_routes.py @@ -11,9 +11,9 @@ import pytest from fastapi.testclient import TestClient -from aurweb import captcha +from aurweb import captcha, db from aurweb.asgi import app -from aurweb.db import commit, create, query +from aurweb.db import create, query from aurweb.models.accepted_term import AcceptedTerm from aurweb.models.account_type import DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID, TRUSTED_USER_ID, AccountType from aurweb.models.ban import Ban @@ -57,9 +57,11 @@ def setup(): account_type = query(AccountType, AccountType.AccountType == "User").first() - user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL, - RealName="Test UserZ", Passwd="testPassword", - IRCNick="testZ", AccountType=account_type) + + with db.begin(): + user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL, + RealName="Test UserZ", Passwd="testPassword", + IRCNick="testZ", AccountType=account_type) yield user @@ -70,9 +72,10 @@ def setup(): @pytest.fixture def tu_user(): - user.AccountType = query(AccountType, - AccountType.ID == TRUSTED_USER_AND_DEV_ID).first() - commit() + with db.begin(): + user.AccountType = query(AccountType).filter( + AccountType.ID == TRUSTED_USER_AND_DEV_ID + ).first() yield user @@ -149,11 +152,9 @@ def test_post_passreset_user(): def test_post_passreset_resetkey(): - from aurweb.db import session - - user.session = Session(UsersID=user.ID, SessionID="blah", - LastUpdateTS=datetime.utcnow().timestamp()) - session.commit() + with db.begin(): + user.session = Session(UsersID=user.ID, SessionID="blah", + LastUpdateTS=datetime.utcnow().timestamp()) # Prepare a password reset. with client as request: @@ -357,7 +358,8 @@ def test_post_register_error_invalid_captcha(): def test_post_register_error_ip_banned(): # 'testclient' is used as request.client.host via FastAPI TestClient. - create(Ban, IPAddress="testclient", BanTS=datetime.utcnow()) + with db.begin(): + create(Ban, IPAddress="testclient", BanTS=datetime.utcnow()) with client as request: response = post_register(request) @@ -576,7 +578,8 @@ def test_post_register_error_ssh_pubkey_taken(): # Take the sha256 fingerprint of the ssh public key, create it. fp = get_fingerprint(pk) - create(SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fp) + with db.begin(): + create(SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fp) with client as request: response = post_register(request, PK=pk) @@ -660,13 +663,11 @@ def test_post_account_edit(): def test_post_account_edit_dev(): - from aurweb.db import session - # Modify our user to be a "Trusted User & Developer" name = "Trusted User & Developer" tu_or_dev = query(AccountType, AccountType.AccountType == name).first() - user.AccountType = tu_or_dev - session.commit() + with db.begin(): + user.AccountType = tu_or_dev request = Request() sid = user.login(request, "testPassword") @@ -1001,21 +1002,19 @@ def get_rows(html): def test_post_accounts(tu_user): # Set a PGPKey. - user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30" + with db.begin(): + user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30" # Create a few more users. users = [user] - for i in range(10): - _user = create(User, Username=f"test_{i}", - Email=f"test_{i}@example.org", - RealName=f"Test #{i}", - Passwd="testPassword", - IRCNick=f"test_#{i}", - autocommit=False) - users.append(_user) - - # Commit everything to the database. - commit() + with db.begin(): + for i in range(10): + _user = create(User, Username=f"test_{i}", + Email=f"test_{i}@example.org", + RealName=f"Test #{i}", + Passwd="testPassword", + IRCNick=f"test_#{i}") + users.append(_user) sid = user.login(Request(), "testPassword") cookies = {"AURSID": sid} @@ -1085,11 +1084,12 @@ def test_post_accounts_account_type(tu_user): # test the `u` parameter. account_type = query(AccountType, AccountType.AccountType == "User").first() - create(User, Username="test_2", - Email="test_2@example.org", - RealName="Test User 2", - Passwd="testPassword", - AccountType=account_type) + with db.begin(): + create(User, Username="test_2", + Email="test_2@example.org", + RealName="Test User 2", + Passwd="testPassword", + AccountType=account_type) # Expect no entries; we marked our only user as a User type. with client as request: @@ -1113,9 +1113,10 @@ def test_post_accounts_account_type(tu_user): assert type.text.strip() == "User" # Set our only user to a Trusted User. - user.AccountType = query(AccountType, - AccountType.ID == TRUSTED_USER_ID).first() - commit() + with db.begin(): + user.AccountType = query(AccountType).filter( + AccountType.ID == TRUSTED_USER_ID + ).first() with client as request: response = request.post("/accounts/", cookies=cookies, @@ -1130,9 +1131,10 @@ def test_post_accounts_account_type(tu_user): assert type.text.strip() == "Trusted User" - user.AccountType = query(AccountType, - AccountType.ID == DEVELOPER_ID).first() - commit() + with db.begin(): + user.AccountType = query(AccountType).filter( + AccountType.ID == DEVELOPER_ID + ).first() with client as request: response = request.post("/accounts/", cookies=cookies, @@ -1147,10 +1149,10 @@ def test_post_accounts_account_type(tu_user): assert type.text.strip() == "Developer" - user.AccountType = query(AccountType, - AccountType.ID == TRUSTED_USER_AND_DEV_ID - ).first() - commit() + with db.begin(): + user.AccountType = query(AccountType).filter( + AccountType.ID == TRUSTED_USER_AND_DEV_ID + ).first() with client as request: response = request.post("/accounts/", cookies=cookies, @@ -1182,8 +1184,8 @@ def test_post_accounts_status(tu_user): username, type, status, realname, irc, pgp_key, edit = row assert status.text.strip() == "Active" - user.Suspended = True - commit() + with db.begin(): + user.Suspended = True with client as request: response = request.post("/accounts/", cookies=cookies, @@ -1244,12 +1246,13 @@ def test_post_accounts_sortby(tu_user): # Create a second user so we can compare sorts. account_type = query(AccountType, AccountType.ID == DEVELOPER_ID).first() - create(User, Username="test2", - Email="test2@example.org", - RealName="Test User 2", - Passwd="testPassword", - IRCNick="test2", - AccountType=account_type) + with db.begin(): + create(User, Username="test2", + Email="test2@example.org", + RealName="Test User 2", + Passwd="testPassword", + IRCNick="test2", + AccountType=account_type) sid = user.login(Request(), "testPassword") cookies = {"AURSID": sid} @@ -1297,9 +1300,10 @@ def test_post_accounts_sortby(tu_user): # Test the rows are reversed when ordering by RealName. assert compare_text_values(4, first_rows, reversed(rows)) is True - user.AccountType = query(AccountType, - AccountType.ID == TRUSTED_USER_AND_DEV_ID).first() - commit() + with db.begin(): + user.AccountType = query(AccountType).filter( + AccountType.ID == TRUSTED_USER_AND_DEV_ID + ).first() # Fetch first_rows again with our new AccountType ordering. with client as request: @@ -1322,8 +1326,8 @@ def test_post_accounts_sortby(tu_user): def test_post_accounts_pgp_key(tu_user): - user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30" - commit() + with db.begin(): + user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30" sid = user.login(Request(), "testPassword") cookies = {"AURSID": sid} @@ -1343,15 +1347,14 @@ def test_post_accounts_paged(tu_user): users = [user] account_type = query(AccountType, AccountType.AccountType == "User").first() - for i in range(150): - _user = create(User, Username=f"test_#{i}", - Email=f"test_#{i}@example.org", - RealName=f"Test User #{i}", - Passwd="testPassword", - AccountType=account_type, - autocommit=False) - users.append(_user) - commit() + with db.begin(): + for i in range(150): + _user = create(User, Username=f"test_#{i}", + Email=f"test_#{i}@example.org", + RealName=f"Test User #{i}", + Passwd="testPassword", + AccountType=account_type) + users.append(_user) sid = user.login(Request(), "testPassword") cookies = {"AURSID": sid} @@ -1414,8 +1417,9 @@ def test_post_accounts_paged(tu_user): def test_get_terms_of_service(): - term = create(Term, Description="Test term.", - URL="http://localhost", Revision=1) + with db.begin(): + term = create(Term, Description="Test term.", + URL="http://localhost", Revision=1) with client as request: response = request.get("/tos", allow_redirects=False) @@ -1436,8 +1440,9 @@ def test_get_terms_of_service(): response = request.get("/tos", cookies=cookies, allow_redirects=False) assert response.status_code == int(HTTPStatus.OK) - accepted_term = create(AcceptedTerm, User=user, - Term=term, Revision=term.Revision) + with db.begin(): + accepted_term = create(AcceptedTerm, User=user, + Term=term, Revision=term.Revision) with client as request: response = request.get("/tos", cookies=cookies, allow_redirects=False) @@ -1445,8 +1450,8 @@ def test_get_terms_of_service(): assert response.status_code == int(HTTPStatus.SEE_OTHER) # Bump the term's revision. - term.Revision = 2 - commit() + with db.begin(): + term.Revision = 2 with client as request: response = request.get("/tos", cookies=cookies, allow_redirects=False) @@ -1454,8 +1459,8 @@ def test_get_terms_of_service(): # yet been agreed to via AcceptedTerm update. assert response.status_code == int(HTTPStatus.OK) - accepted_term.Revision = term.Revision - commit() + with db.begin(): + accepted_term.Revision = term.Revision with client as request: response = request.get("/tos", cookies=cookies, allow_redirects=False) @@ -1471,8 +1476,9 @@ def test_post_terms_of_service(): cookies = {"AURSID": sid} # Auth cookie. # Create a fresh Term. - term = create(Term, Description="Test term.", - URL="http://localhost", Revision=1) + with db.begin(): + term = create(Term, Description="Test term.", + URL="http://localhost", Revision=1) # Test that the term we just created is listed. with client as request: @@ -1497,8 +1503,8 @@ def test_post_terms_of_service(): assert accepted_term.Term == term # Update the term to revision 2. - term.Revision = 2 - commit() + with db.begin(): + term.Revision = 2 # A GET request gives us the new revision to accept. with client as request: diff --git a/test/test_api_rate_limit.py b/test/test_api_rate_limit.py index 536e3841..25cb3e0f 100644 --- a/test/test_api_rate_limit.py +++ b/test/test_api_rate_limit.py @@ -2,6 +2,7 @@ import pytest from sqlalchemy.exc import IntegrityError +from aurweb import db from aurweb.db import create from aurweb.models.api_rate_limit import ApiRateLimit from aurweb.testing import setup_test_db @@ -13,26 +14,28 @@ def setup(): def test_api_rate_key_creation(): - rate = create(ApiRateLimit, IP="127.0.0.1", Requests=10, WindowStart=1) + with db.begin(): + rate = create(ApiRateLimit, IP="127.0.0.1", Requests=10, WindowStart=1) assert rate.IP == "127.0.0.1" assert rate.Requests == 10 assert rate.WindowStart == 1 def test_api_rate_key_ip_default(): - api_rate_limit = create(ApiRateLimit, Requests=10, WindowStart=1) + with db.begin(): + api_rate_limit = create(ApiRateLimit, Requests=10, WindowStart=1) assert api_rate_limit.IP == str() def test_api_rate_key_null_requests_raises_exception(): - from aurweb.db import session with pytest.raises(IntegrityError): - create(ApiRateLimit, IP="127.0.0.1", WindowStart=1) - session.rollback() + with db.begin(): + create(ApiRateLimit, IP="127.0.0.1", WindowStart=1) + db.rollback() def test_api_rate_key_null_window_start_raises_exception(): - from aurweb.db import session with pytest.raises(IntegrityError): - create(ApiRateLimit, IP="127.0.0.1", Requests=1) - session.rollback() + with db.begin(): + create(ApiRateLimit, IP="127.0.0.1", Requests=1) + db.rollback() diff --git a/test/test_auth.py b/test/test_auth.py index b386bea1..caa39468 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -4,6 +4,7 @@ import pytest from sqlalchemy.exc import IntegrityError +from aurweb import db from aurweb.auth import BasicAuthBackend, account_type_required, has_credential from aurweb.db import create, query from aurweb.models.account_type import USER, USER_ID, AccountType @@ -23,9 +24,10 @@ def setup(): account_type = query(AccountType, AccountType.AccountType == "User").first() - user = create(User, Username="test", Email="test@example.com", - RealName="Test User", Passwd="testPassword", - AccountType=account_type) + with db.begin(): + user = create(User, Username="test", Email="test@example.com", + RealName="Test User", Passwd="testPassword", + AccountType=account_type) backend = BasicAuthBackend() request = Request() @@ -51,14 +53,13 @@ async def test_auth_backend_invalid_sid(): @pytest.mark.asyncio async def test_auth_backend_invalid_user_id(): - from aurweb.db import session - # Create a new session with a fake user id. now_ts = datetime.utcnow().timestamp() with pytest.raises(IntegrityError): - create(Session, UsersID=666, SessionID="realSession", - LastUpdateTS=now_ts + 5) - session.rollback() + with db.begin(): + create(Session, UsersID=666, SessionID="realSession", + LastUpdateTS=now_ts + 5) + db.rollback() @pytest.mark.asyncio @@ -66,8 +67,9 @@ async def test_basic_auth_backend(): # This time, everything matches up. We expect the user to # equal the real_user. now_ts = datetime.utcnow().timestamp() - create(Session, UsersID=user.ID, SessionID="realSession", - LastUpdateTS=now_ts + 5) + with db.begin(): + create(Session, UsersID=user.ID, SessionID="realSession", + LastUpdateTS=now_ts + 5) request.cookies["AURSID"] = "realSession" _, result = await backend.authenticate(request) assert result == user diff --git a/test/test_auth_routes.py b/test/test_auth_routes.py index b0dd5648..1d8f9cbe 100644 --- a/test/test_auth_routes.py +++ b/test/test_auth_routes.py @@ -9,7 +9,7 @@ from fastapi.testclient import TestClient import aurweb.config from aurweb.asgi import app -from aurweb.db import create, query +from aurweb.db import begin, create, query from aurweb.models.account_type import AccountType from aurweb.models.session import Session from aurweb.models.user import User @@ -32,9 +32,10 @@ def setup(): account_type = query(AccountType, AccountType.AccountType == "User").first() - user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL, - RealName="Test User", Passwd="testPassword", - AccountType=account_type) + with begin(): + user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL, + RealName="Test User", Passwd="testPassword", + AccountType=account_type) client = TestClient(app) diff --git a/test/test_ban.py b/test/test_ban.py index b728644b..f96e9d14 100644 --- a/test/test_ban.py +++ b/test/test_ban.py @@ -6,6 +6,7 @@ import pytest from sqlalchemy import exc as sa_exc +from aurweb import db from aurweb.db import create from aurweb.models.ban import Ban, is_banned from aurweb.testing import setup_test_db @@ -21,7 +22,8 @@ def setup(): setup_test_db("Bans") ts = datetime.utcnow() + timedelta(seconds=30) - ban = create(Ban, IPAddress="127.0.0.1", BanTS=ts) + with db.begin(): + ban = create(Ban, IPAddress="127.0.0.1", BanTS=ts) request = Request() @@ -35,17 +37,17 @@ def test_invalid_ban(): with pytest.raises(sa_exc.IntegrityError): bad_ban = Ban(BanTS=datetime.utcnow()) - session.add(bad_ban) # We're adding a ban with no primary key; this causes an # SQLAlchemy warnings when committing to the DB. # Ignore them. with warnings.catch_warnings(): warnings.simplefilter("ignore", sa_exc.SAWarning) - session.commit() + with db.begin(): + session.add(bad_ban) # Since we got a transaction failure, we need to rollback. - session.rollback() + db.rollback() def test_banned(): diff --git a/test/test_db.py b/test/test_db.py index 9ece25ea..7798d2f6 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -278,18 +278,15 @@ def test_connection_execute_paramstyle_unsupported(): def test_create_delete(): - db.create(AccountType, AccountType="test") + with db.begin(): + db.create(AccountType, AccountType="test") + record = db.query(AccountType, AccountType.AccountType == "test").first() assert record is not None - db.delete(AccountType, AccountType.AccountType == "test") - record = db.query(AccountType, AccountType.AccountType == "test").first() - assert record is None - # Create and delete a record with autocommit=False. - db.create(AccountType, AccountType="test", autocommit=False) - db.commit() - db.delete(AccountType, AccountType.AccountType == "test", autocommit=False) - db.commit() + with db.begin(): + db.delete(AccountType, AccountType.AccountType == "test") + record = db.query(AccountType, AccountType.AccountType == "test").first() assert record is None @@ -297,8 +294,8 @@ def test_create_delete(): def test_add_commit(): # Use db.add and db.commit to add a temporary record. account_type = AccountType(AccountType="test") - db.add(account_type) - db.commit() + with db.begin(): + db.add(account_type) # Assert it got created in the DB. assert bool(account_type.ID) @@ -308,7 +305,8 @@ def test_add_commit(): assert record == account_type # Remove the record. - db.delete(AccountType, AccountType.ID == account_type.ID) + with db.begin(): + db.delete(AccountType, AccountType.ID == account_type.ID) def test_connection_executor_mysql_paramstyle(): diff --git a/test/test_dependency_type.py b/test/test_dependency_type.py index 6c37cc58..4d555123 100644 --- a/test/test_dependency_type.py +++ b/test/test_dependency_type.py @@ -1,6 +1,6 @@ import pytest -from aurweb.db import create, delete, query +from aurweb.db import begin, create, delete, query from aurweb.models.dependency_type import DependencyType from aurweb.testing import setup_test_db @@ -19,13 +19,17 @@ def test_dependency_types(): def test_dependency_type_creation(): - dependency_type = create(DependencyType, Name="Test Type") + with begin(): + dependency_type = create(DependencyType, Name="Test Type") assert bool(dependency_type.ID) assert dependency_type.Name == "Test Type" - delete(DependencyType, DependencyType.ID == dependency_type.ID) + with begin(): + delete(DependencyType, DependencyType.ID == dependency_type.ID) def test_dependency_type_null_name_uses_default(): - dependency_type = create(DependencyType) + with begin(): + dependency_type = create(DependencyType) assert dependency_type.Name == str() - delete(DependencyType, DependencyType.ID == dependency_type.ID) + with begin(): + delete(DependencyType, DependencyType.ID == dependency_type.ID) diff --git a/test/test_group.py b/test/test_group.py index da017a96..cea69b68 100644 --- a/test/test_group.py +++ b/test/test_group.py @@ -2,7 +2,7 @@ import pytest from sqlalchemy.exc import IntegrityError -from aurweb.db import create +from aurweb import db from aurweb.models.group import Group from aurweb.testing import setup_test_db @@ -13,13 +13,14 @@ def setup(): def test_group_creation(): - group = create(Group, Name="Test Group") + with db.begin(): + group = db.create(Group, Name="Test Group") assert bool(group.ID) assert group.Name == "Test Group" def test_group_null_name_raises_exception(): - from aurweb.db import session with pytest.raises(IntegrityError): - create(Group) - session.rollback() + with db.begin(): + db.create(Group) + db.rollback() diff --git a/test/test_homepage.py b/test/test_homepage.py index 2cd6682f..fef3532d 100644 --- a/test/test_homepage.py +++ b/test/test_homepage.py @@ -38,8 +38,10 @@ def setup(): @pytest.fixture def user(): - yield db.create(User, Username="test", Email="test@example.org", - Passwd="testPassword", AccountTypeID=USER_ID) + with db.begin(): + user = db.create(User, Username="test", Email="test@example.org", + Passwd="testPassword", AccountTypeID=USER_ID) + yield user @pytest.fixture @@ -68,17 +70,14 @@ def packages(user): # For i..num_packages, create a package named pkg_{i}. pkgs = [] now = int(datetime.utcnow().timestamp()) - for i in range(num_packages): - pkgbase = db.create(PackageBase, Name=f"pkg_{i}", - Maintainer=user, Packager=user, - autocommit=False, SubmittedTS=now, - ModifiedTS=now) - pkg = db.create(Package, PackageBase=pkgbase, - Name=pkgbase.Name, autocommit=False) - pkgs.append(pkg) - now += 1 - - db.commit() + with db.begin(): + for i in range(num_packages): + pkgbase = db.create(PackageBase, Name=f"pkg_{i}", + Maintainer=user, Packager=user, + SubmittedTS=now, ModifiedTS=now) + pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name) + pkgs.append(pkg) + now += 1 yield pkgs @@ -159,10 +158,11 @@ def test_homepage_updates(redis, packages): def test_homepage_dashboard(redis, packages, user): # Create Comaintainer records for all of the packages. - for pkg in packages: - db.create(PackageComaintainer, PackageBase=pkg.PackageBase, - User=user, Priority=1, autocommit=False) - db.commit() + with db.begin(): + for pkg in packages: + db.create(PackageComaintainer, + PackageBase=pkg.PackageBase, + User=user, Priority=1) cookies = {"AURSID": user.login(Request(), "testPassword")} with client as request: @@ -193,11 +193,12 @@ def test_homepage_dashboard_requests(redis, packages, user): pkg = packages[0] reqtype = db.query(RequestType, RequestType.ID == DELETION_ID).first() - pkgreq = db.create(PackageRequest, PackageBase=pkg.PackageBase, - PackageBaseName=pkg.PackageBase.Name, - User=user, Comments=str(), - ClosureComment=str(), RequestTS=now, - RequestType=reqtype) + with db.begin(): + pkgreq = db.create(PackageRequest, PackageBase=pkg.PackageBase, + PackageBaseName=pkg.PackageBase.Name, + User=user, Comments=str(), + ClosureComment=str(), RequestTS=now, + RequestType=reqtype) cookies = {"AURSID": user.login(Request(), "testPassword")} with client as request: @@ -213,8 +214,8 @@ def test_homepage_dashboard_requests(redis, packages, user): def test_homepage_dashboard_flagged_packages(redis, packages, user): # Set the first Package flagged by setting its OutOfDateTS column. pkg = packages[0] - pkg.PackageBase.OutOfDateTS = int(datetime.utcnow().timestamp()) - db.commit() + with db.begin(): + pkg.PackageBase.OutOfDateTS = int(datetime.utcnow().timestamp()) cookies = {"AURSID": user.login(Request(), "testPassword")} with client as request: diff --git a/test/test_license.py b/test/test_license.py index feb7a396..2c52f058 100644 --- a/test/test_license.py +++ b/test/test_license.py @@ -2,7 +2,7 @@ import pytest from sqlalchemy.exc import IntegrityError -from aurweb.db import create +from aurweb import db from aurweb.models.license import License from aurweb.testing import setup_test_db @@ -13,13 +13,14 @@ def setup(): def test_license_creation(): - license = create(License, Name="Test License") + with db.begin(): + license = db.create(License, Name="Test License") assert bool(license.ID) assert license.Name == "Test License" def test_license_null_name_raises_exception(): - from aurweb.db import session with pytest.raises(IntegrityError): - create(License) - session.rollback() + with db.begin(): + db.create(License) + db.rollback() diff --git a/test/test_official_provider.py b/test/test_official_provider.py index a1d3d54a..0aa4f1d1 100644 --- a/test/test_official_provider.py +++ b/test/test_official_provider.py @@ -2,7 +2,7 @@ import pytest from sqlalchemy.exc import IntegrityError -from aurweb.db import create +from aurweb import db from aurweb.models.official_provider import OfficialProvider from aurweb.testing import setup_test_db @@ -13,10 +13,11 @@ def setup(): def test_official_provider_creation(): - oprovider = create(OfficialProvider, - Name="some-name", - Repo="some-repo", - Provides="some-provides") + with db.begin(): + oprovider = db.create(OfficialProvider, + Name="some-name", + Repo="some-repo", + Provides="some-provides") assert bool(oprovider.ID) assert oprovider.Name == "some-name" assert oprovider.Repo == "some-repo" @@ -25,16 +26,18 @@ def test_official_provider_creation(): def test_official_provider_cs(): """ Test case sensitivity of the database table. """ - oprovider = create(OfficialProvider, - Name="some-name", - Repo="some-repo", - Provides="some-provides") + with db.begin(): + oprovider = db.create(OfficialProvider, + Name="some-name", + Repo="some-repo", + Provides="some-provides") assert bool(oprovider.ID) - oprovider_cs = create(OfficialProvider, - Name="SOME-NAME", - Repo="SOME-REPO", - Provides="SOME-PROVIDES") + with db.begin(): + oprovider_cs = db.create(OfficialProvider, + Name="SOME-NAME", + Repo="SOME-REPO", + Provides="SOME-PROVIDES") assert bool(oprovider_cs.ID) assert oprovider.ID != oprovider_cs.ID @@ -49,27 +52,27 @@ def test_official_provider_cs(): def test_official_provider_null_name_raises_exception(): - from aurweb.db import session with pytest.raises(IntegrityError): - create(OfficialProvider, - Repo="some-repo", - Provides="some-provides") - session.rollback() + with db.begin(): + db.create(OfficialProvider, + Repo="some-repo", + Provides="some-provides") + db.rollback() def test_official_provider_null_repo_raises_exception(): - from aurweb.db import session with pytest.raises(IntegrityError): - create(OfficialProvider, - Name="some-name", - Provides="some-provides") - session.rollback() + with db.begin(): + db.create(OfficialProvider, + Name="some-name", + Provides="some-provides") + db.rollback() def test_official_provider_null_provides_raises_exception(): - from aurweb.db import session with pytest.raises(IntegrityError): - create(OfficialProvider, - Name="some-name", - Repo="some-repo") - session.rollback() + with db.begin(): + db.create(OfficialProvider, + Name="some-name", + Repo="some-repo") + db.rollback() diff --git a/test/test_package.py b/test/test_package.py index 1e940164..112ca9b4 100644 --- a/test/test_package.py +++ b/test/test_package.py @@ -3,7 +3,7 @@ import pytest from sqlalchemy import and_ from sqlalchemy.exc import IntegrityError -from aurweb.db import create, query +from aurweb import db from aurweb.models.account_type import AccountType from aurweb.models.package import Package from aurweb.models.package_base import PackageBase @@ -19,25 +19,25 @@ def setup(): setup_test_db("Packages", "PackageBases", "Users") - account_type = query(AccountType, - AccountType.AccountType == "User").first() - user = create(User, Username="test", Email="test@example.org", - RealName="Test User", Passwd="testPassword", - AccountType=account_type) + account_type = db.query(AccountType, + AccountType.AccountType == "User").first() - pkgbase = create(PackageBase, - Name="beautiful-package", - Maintainer=user) - package = create(Package, - PackageBase=pkgbase, - Name=pkgbase.Name, - Description="Test description.", - URL="https://test.package") + with db.begin(): + user = db.create(User, Username="test", Email="test@example.org", + RealName="Test User", Passwd="testPassword", + AccountType=account_type) + + pkgbase = db.create(PackageBase, + Name="beautiful-package", + Maintainer=user) + package = db.create(Package, + PackageBase=pkgbase, + Name=pkgbase.Name, + Description="Test description.", + URL="https://test.package") def test_package(): - from aurweb.db import session - assert pkgbase == package.PackageBase assert package.Name == "beautiful-package" assert package.Description == "Test description." @@ -45,33 +45,31 @@ def test_package(): assert package.URL == "https://test.package" # Update package Version. - package.Version = "1.2.3" - session.commit() + with db.begin(): + package.Version = "1.2.3" # Make sure it got updated in the database. - record = query(Package, - and_(Package.ID == package.ID, - Package.Version == "1.2.3")).first() + record = db.query(Package, + and_(Package.ID == package.ID, + Package.Version == "1.2.3")).first() assert record is not None def test_package_null_pkgbase_raises_exception(): - from aurweb.db import session - with pytest.raises(IntegrityError): - create(Package, - Name="some-package", - Description="Some description.", - URL="https://some.package") - session.rollback() + with db.begin(): + db.create(Package, + Name="some-package", + Description="Some description.", + URL="https://some.package") + db.rollback() def test_package_null_name_raises_exception(): - from aurweb.db import session - with pytest.raises(IntegrityError): - create(Package, - PackageBase=pkgbase, - Description="Some description.", - URL="https://some.package") - session.rollback() + with db.begin(): + db.create(Package, + PackageBase=pkgbase, + Description="Some description.", + URL="https://some.package") + db.rollback() diff --git a/test/test_package_base.py b/test/test_package_base.py index 0c0d0526..2bc6278f 100644 --- a/test/test_package_base.py +++ b/test/test_package_base.py @@ -4,7 +4,7 @@ from sqlalchemy.exc import IntegrityError import aurweb.config -from aurweb.db import create, query +from aurweb import db from aurweb.models.account_type import AccountType from aurweb.models.package_base import PackageBase from aurweb.models.user import User @@ -19,17 +19,19 @@ def setup(): setup_test_db("Users", "PackageBases") - account_type = query(AccountType, - AccountType.AccountType == "User").first() - user = create(User, Username="test", Email="test@example.org", - RealName="Test User", Passwd="testPassword", - AccountType=account_type) + account_type = db.query(AccountType, + AccountType.AccountType == "User").first() + with db.begin(): + user = db.create(User, Username="test", Email="test@example.org", + RealName="Test User", Passwd="testPassword", + AccountType=account_type) def test_package_base(): - pkgbase = create(PackageBase, - Name="beautiful-package", - Maintainer=user) + with db.begin(): + pkgbase = db.create(PackageBase, + Name="beautiful-package", + Maintainer=user) assert pkgbase in user.maintained_bases assert not pkgbase.OutOfDateTS @@ -38,7 +40,8 @@ def test_package_base(): # Set Popularity to a string, then get it by attribute to # exercise the string -> float conversion path. - pkgbase.Popularity = "0.0" + with db.begin(): + pkgbase.Popularity = "0.0" assert pkgbase.Popularity == 0.0 @@ -47,27 +50,28 @@ def test_package_base_ci(): if aurweb.config.get("database", "backend") == "sqlite": return None # SQLite doesn't seem handle this. - from aurweb.db import session - - pkgbase = create(PackageBase, - Name="beautiful-package", - Maintainer=user) + with db.begin(): + pkgbase = db.create(PackageBase, + Name="beautiful-package", + Maintainer=user) assert bool(pkgbase.ID) with pytest.raises(IntegrityError): - create(PackageBase, - Name="Beautiful-Package", - Maintainer=user) - session.rollback() + with db.begin(): + db.create(PackageBase, + Name="Beautiful-Package", + Maintainer=user) + db.rollback() def test_package_base_relationships(): - pkgbase = create(PackageBase, - Name="beautiful-package", - Flagger=user, - Maintainer=user, - Submitter=user, - Packager=user) + with db.begin(): + pkgbase = db.create(PackageBase, + Name="beautiful-package", + Flagger=user, + Maintainer=user, + Submitter=user, + Packager=user) assert pkgbase in user.flagged_bases assert pkgbase in user.maintained_bases assert pkgbase in user.submitted_bases @@ -75,8 +79,7 @@ def test_package_base_relationships(): def test_package_base_null_name_raises_exception(): - from aurweb.db import session - with pytest.raises(IntegrityError): - create(PackageBase) - session.rollback() + with db.begin(): + db.create(PackageBase) + db.rollback() diff --git a/test/test_package_blacklist.py b/test/test_package_blacklist.py index 3c64cc21..93f15de7 100644 --- a/test/test_package_blacklist.py +++ b/test/test_package_blacklist.py @@ -2,7 +2,7 @@ import pytest from sqlalchemy.exc import IntegrityError -from aurweb.db import create, rollback +from aurweb import db from aurweb.models.package_base import PackageBase from aurweb.models.package_blacklist import PackageBlacklist from aurweb.models.user import User @@ -17,18 +17,20 @@ def setup(): setup_test_db("PackageBlacklist", "PackageBases", "Users") - user = create(User, Username="test", Email="test@example.org", - RealName="Test User", Passwd="testPassword") - pkgbase = create(PackageBase, Name="test-package", Maintainer=user) + user = db.create(User, Username="test", Email="test@example.org", + RealName="Test User", Passwd="testPassword") + pkgbase = db.create(PackageBase, Name="test-package", Maintainer=user) def test_package_blacklist_creation(): - package_blacklist = create(PackageBlacklist, Name="evil-package") + with db.begin(): + package_blacklist = db.create(PackageBlacklist, Name="evil-package") assert bool(package_blacklist.ID) assert package_blacklist.Name == "evil-package" def test_package_blacklist_null_name_raises_exception(): with pytest.raises(IntegrityError): - create(PackageBlacklist) - rollback() + with db.begin(): + db.create(PackageBlacklist) + db.rollback() diff --git a/test/test_package_comment.py b/test/test_package_comment.py index ca77b511..60f0333d 100644 --- a/test/test_package_comment.py +++ b/test/test_package_comment.py @@ -2,7 +2,7 @@ import pytest from sqlalchemy.exc import IntegrityError -from aurweb.db import create, query, rollback +from aurweb.db import begin, create, query, rollback from aurweb.models.account_type import AccountType from aurweb.models.package_base import PackageBase from aurweb.models.package_comment import PackageComment @@ -20,45 +20,52 @@ def setup(): account_type = query(AccountType, AccountType.AccountType == "User").first() - user = create(User, Username="test", Email="test@example.org", - RealName="Test User", Passwd="testPassword", - AccountType=account_type) - pkgbase = create(PackageBase, Name="test-package", Maintainer=user) + with begin(): + user = create(User, Username="test", Email="test@example.org", + RealName="Test User", Passwd="testPassword", + AccountType=account_type) + pkgbase = create(PackageBase, Name="test-package", Maintainer=user) def test_package_comment_creation(): - package_comment = create(PackageComment, - PackageBase=pkgbase, - User=user, - Comments="Test comment.", - RenderedComment="Test rendered comment.") + with begin(): + package_comment = create(PackageComment, + PackageBase=pkgbase, + User=user, + Comments="Test comment.", + RenderedComment="Test rendered comment.") assert bool(package_comment.ID) def test_package_comment_null_package_base_raises_exception(): with pytest.raises(IntegrityError): - create(PackageComment, User=user, Comments="Test comment.", - RenderedComment="Test rendered comment.") + with begin(): + create(PackageComment, User=user, Comments="Test comment.", + RenderedComment="Test rendered comment.") rollback() def test_package_comment_null_user_raises_exception(): with pytest.raises(IntegrityError): - create(PackageComment, PackageBase=pkgbase, Comments="Test comment.", - RenderedComment="Test rendered comment.") + with begin(): + create(PackageComment, PackageBase=pkgbase, + Comments="Test comment.", + RenderedComment="Test rendered comment.") rollback() def test_package_comment_null_comments_raises_exception(): with pytest.raises(IntegrityError): - create(PackageComment, PackageBase=pkgbase, User=user, - RenderedComment="Test rendered comment.") + with begin(): + create(PackageComment, PackageBase=pkgbase, User=user, + RenderedComment="Test rendered comment.") rollback() def test_package_comment_null_renderedcomment_defaults(): - record = create(PackageComment, - PackageBase=pkgbase, - User=user, - Comments="Test comment.") + with begin(): + record = create(PackageComment, + PackageBase=pkgbase, + User=user, + Comments="Test comment.") assert record.RenderedComment == str() diff --git a/test/test_package_dependency.py b/test/test_package_dependency.py index e28f1781..2ddef68e 100644 --- a/test/test_package_dependency.py +++ b/test/test_package_dependency.py @@ -2,7 +2,8 @@ import pytest from sqlalchemy.exc import IntegrityError -from aurweb.db import commit, create, query +from aurweb import db +from aurweb.db import create, query from aurweb.models.account_type import AccountType from aurweb.models.dependency_type import DependencyType from aurweb.models.package import Package @@ -22,25 +23,28 @@ def setup(): account_type = query(AccountType, AccountType.AccountType == "User").first() - user = create(User, Username="test", Email="test@example.org", - RealName="Test User", Passwd="testPassword", - AccountType=account_type) - pkgbase = create(PackageBase, - Name="test-package", - Maintainer=user) - package = create(Package, - PackageBase=pkgbase, - Name=pkgbase.Name, - Description="Test description.", - URL="https://test.package") + with db.begin(): + user = create(User, Username="test", Email="test@example.org", + RealName="Test User", Passwd="testPassword", + AccountType=account_type) + pkgbase = create(PackageBase, + Name="test-package", + Maintainer=user) + package = create(Package, + PackageBase=pkgbase, + Name=pkgbase.Name, + Description="Test description.", + URL="https://test.package") def test_package_dependencies(): depends = query(DependencyType, DependencyType.Name == "depends").first() - pkgdep = create(PackageDependency, Package=package, - DependencyType=depends, - DepName="test-dep") + + with db.begin(): + pkgdep = create(PackageDependency, Package=package, + DependencyType=depends, + DepName="test-dep") assert pkgdep.DepName == "test-dep" assert pkgdep.Package == package assert pkgdep.DependencyType == depends @@ -49,8 +53,8 @@ def test_package_dependencies(): makedepends = query(DependencyType, DependencyType.Name == "makedepends").first() - pkgdep.DependencyType = makedepends - commit() + with db.begin(): + pkgdep.DependencyType = makedepends assert pkgdep.DepName == "test-dep" assert pkgdep.Package == package assert pkgdep.DependencyType == makedepends @@ -59,8 +63,8 @@ def test_package_dependencies(): checkdepends = query(DependencyType, DependencyType.Name == "checkdepends").first() - pkgdep.DependencyType = checkdepends - commit() + with db.begin(): + pkgdep.DependencyType = checkdepends assert pkgdep.DepName == "test-dep" assert pkgdep.Package == package assert pkgdep.DependencyType == checkdepends @@ -69,8 +73,8 @@ def test_package_dependencies(): optdepends = query(DependencyType, DependencyType.Name == "optdepends").first() - pkgdep.DependencyType = optdepends - commit() + with db.begin(): + pkgdep.DependencyType = optdepends assert pkgdep.DepName == "test-dep" assert pkgdep.Package == package assert pkgdep.DependencyType == optdepends @@ -79,39 +83,37 @@ def test_package_dependencies(): assert not pkgdep.is_package() - base = create(PackageBase, Name=pkgdep.DepName, Maintainer=user) - create(Package, PackageBase=base, Name=pkgdep.DepName) + with db.begin(): + base = create(PackageBase, Name=pkgdep.DepName, Maintainer=user) + create(Package, PackageBase=base, Name=pkgdep.DepName) assert pkgdep.is_package() def test_package_dependencies_null_package_raises_exception(): - from aurweb.db import session - depends = query(DependencyType, DependencyType.Name == "depends").first() with pytest.raises(IntegrityError): - create(PackageDependency, - DependencyType=depends, - DepName="test-dep") - session.rollback() + with db.begin(): + create(PackageDependency, + DependencyType=depends, + DepName="test-dep") + db.rollback() def test_package_dependencies_null_dependency_type_raises_exception(): - from aurweb.db import session - with pytest.raises(IntegrityError): - create(PackageDependency, - Package=package, - DepName="test-dep") - session.rollback() + with db.begin(): + create(PackageDependency, + Package=package, + DepName="test-dep") + db.rollback() def test_package_dependencies_null_depname_raises_exception(): - from aurweb.db import session - depends = query(DependencyType, DependencyType.Name == "depends").first() with pytest.raises(IntegrityError): - create(PackageDependency, - Package=package, - DependencyType=depends) - session.rollback() + with db.begin(): + create(PackageDependency, + Package=package, + DependencyType=depends) + db.rollback() diff --git a/test/test_package_relation.py b/test/test_package_relation.py index 766d0017..edb67078 100644 --- a/test/test_package_relation.py +++ b/test/test_package_relation.py @@ -2,7 +2,8 @@ import pytest from sqlalchemy.exc import IntegrityError, OperationalError -from aurweb.db import commit, create, query +from aurweb import db +from aurweb.db import create, query from aurweb.models.account_type import AccountType from aurweb.models.package import Package from aurweb.models.package_base import PackageBase @@ -22,25 +23,28 @@ def setup(): account_type = query(AccountType, AccountType.AccountType == "User").first() - user = create(User, Username="test", Email="test@example.org", - RealName="Test User", Passwd="testPassword", - AccountType=account_type) - pkgbase = create(PackageBase, - Name="test-package", - Maintainer=user) - package = create(Package, - PackageBase=pkgbase, - Name=pkgbase.Name, - Description="Test description.", - URL="https://test.package") + with db.begin(): + user = create(User, Username="test", Email="test@example.org", + RealName="Test User", Passwd="testPassword", + AccountType=account_type) + pkgbase = create(PackageBase, + Name="test-package", + Maintainer=user) + package = create(Package, + PackageBase=pkgbase, + Name=pkgbase.Name, + Description="Test description.", + URL="https://test.package") def test_package_relation(): conflicts = query(RelationType, RelationType.Name == "conflicts").first() - pkgrel = create(PackageRelation, Package=package, - RelationType=conflicts, - RelName="test-relation") + + with db.begin(): + pkgrel = create(PackageRelation, Package=package, + RelationType=conflicts, + RelName="test-relation") assert pkgrel.RelName == "test-relation" assert pkgrel.Package == package assert pkgrel.RelationType == conflicts @@ -48,8 +52,8 @@ def test_package_relation(): assert pkgrel in package.package_relations provides = query(RelationType, RelationType.Name == "provides").first() - pkgrel.RelationType = provides - commit() + with db.begin(): + pkgrel.RelationType = provides assert pkgrel.RelName == "test-relation" assert pkgrel.Package == package assert pkgrel.RelationType == provides @@ -57,8 +61,8 @@ def test_package_relation(): assert pkgrel in package.package_relations replaces = query(RelationType, RelationType.Name == "replaces").first() - pkgrel.RelationType = replaces - commit() + with db.begin(): + pkgrel.RelationType = replaces assert pkgrel.RelName == "test-relation" assert pkgrel.Package == package assert pkgrel.RelationType == replaces @@ -67,36 +71,33 @@ def test_package_relation(): def test_package_relation_null_package_raises_exception(): - from aurweb.db import session - conflicts = query(RelationType, RelationType.Name == "conflicts").first() assert conflicts is not None with pytest.raises(IntegrityError): - create(PackageRelation, - RelationType=conflicts, - RelName="test-relation") - session.rollback() + with db.begin(): + create(PackageRelation, + RelationType=conflicts, + RelName="test-relation") + db.rollback() def test_package_relation_null_relation_type_raises_exception(): - from aurweb.db import session - with pytest.raises(IntegrityError): - create(PackageRelation, - Package=package, - RelName="test-relation") - session.rollback() + with db.begin(): + create(PackageRelation, + Package=package, + RelName="test-relation") + db.rollback() def test_package_relation_null_relname_raises_exception(): - from aurweb.db import session - depends = query(RelationType, RelationType.Name == "conflicts").first() assert depends is not None with pytest.raises((OperationalError, IntegrityError)): - create(PackageRelation, - Package=package, - RelationType=depends) - session.rollback() + with db.begin(): + create(PackageRelation, + Package=package, + RelationType=depends) + db.rollback() diff --git a/test/test_package_request.py b/test/test_package_request.py index c28af6bd..1589ffc2 100644 --- a/test/test_package_request.py +++ b/test/test_package_request.py @@ -4,7 +4,8 @@ import pytest from sqlalchemy.exc import IntegrityError -from aurweb.db import commit, create, query, rollback +from aurweb import db +from aurweb.db import create, query, rollback from aurweb.models.package_base import PackageBase from aurweb.models.package_request import (ACCEPTED, ACCEPTED_ID, CLOSED, CLOSED_ID, PENDING, PENDING_ID, REJECTED, REJECTED_ID, PackageRequest) @@ -21,19 +22,21 @@ def setup(): setup_test_db("PackageRequests", "PackageBases", "Users") - user = create(User, Username="test", Email="test@example.org", - RealName="Test User", Passwd="testPassword") - pkgbase = create(PackageBase, Name="test-package", Maintainer=user) + with db.begin(): + user = create(User, Username="test", Email="test@example.org", + RealName="Test User", Passwd="testPassword") + pkgbase = create(PackageBase, Name="test-package", Maintainer=user) def test_package_request_creation(): request_type = query(RequestType, RequestType.Name == "merge").first() assert request_type.Name == "merge" - package_request = create(PackageRequest, RequestType=request_type, - User=user, PackageBase=pkgbase, - PackageBaseName=pkgbase.Name, - Comments=str(), ClosureComment=str()) + with db.begin(): + package_request = create(PackageRequest, RequestType=request_type, + User=user, PackageBase=pkgbase, + PackageBaseName=pkgbase.Name, + Comments=str(), ClosureComment=str()) assert bool(package_request.ID) assert package_request.RequestType == request_type @@ -54,11 +57,12 @@ def test_package_request_closed(): assert request_type.Name == "merge" ts = int(datetime.utcnow().timestamp()) - package_request = create(PackageRequest, RequestType=request_type, - User=user, PackageBase=pkgbase, - PackageBaseName=pkgbase.Name, - Closer=user, ClosedTS=ts, - Comments=str(), ClosureComment=str()) + with db.begin(): + package_request = create(PackageRequest, RequestType=request_type, + User=user, PackageBase=pkgbase, + PackageBaseName=pkgbase.Name, + Closer=user, ClosedTS=ts, + Comments=str(), ClosureComment=str()) assert package_request.Closer == user assert package_request.ClosedTS == ts @@ -69,54 +73,60 @@ def test_package_request_closed(): def test_package_request_null_request_type_raises_exception(): with pytest.raises(IntegrityError): - create(PackageRequest, User=user, PackageBase=pkgbase, - PackageBaseName=pkgbase.Name, - Comments=str(), ClosureComment=str()) + with db.begin(): + create(PackageRequest, User=user, PackageBase=pkgbase, + PackageBaseName=pkgbase.Name, + Comments=str(), ClosureComment=str()) rollback() def test_package_request_null_user_raises_exception(): request_type = query(RequestType, RequestType.Name == "merge").first() with pytest.raises(IntegrityError): - create(PackageRequest, RequestType=request_type, PackageBase=pkgbase, - PackageBaseName=pkgbase.Name, - Comments=str(), ClosureComment=str()) + with db.begin(): + create(PackageRequest, RequestType=request_type, + PackageBase=pkgbase, PackageBaseName=pkgbase.Name, + Comments=str(), ClosureComment=str()) rollback() def test_package_request_null_package_base_raises_exception(): request_type = query(RequestType, RequestType.Name == "merge").first() with pytest.raises(IntegrityError): - create(PackageRequest, RequestType=request_type, - User=user, PackageBaseName=pkgbase.Name, - Comments=str(), ClosureComment=str()) + with db.begin(): + create(PackageRequest, RequestType=request_type, + User=user, PackageBaseName=pkgbase.Name, + Comments=str(), ClosureComment=str()) rollback() def test_package_request_null_package_base_name_raises_exception(): request_type = query(RequestType, RequestType.Name == "merge").first() with pytest.raises(IntegrityError): - create(PackageRequest, RequestType=request_type, - User=user, PackageBase=pkgbase, - Comments=str(), ClosureComment=str()) + with db.begin(): + create(PackageRequest, RequestType=request_type, + User=user, PackageBase=pkgbase, + Comments=str(), ClosureComment=str()) rollback() def test_package_request_null_comments_raises_exception(): request_type = query(RequestType, RequestType.Name == "merge").first() with pytest.raises(IntegrityError): - create(PackageRequest, RequestType=request_type, - User=user, PackageBase=pkgbase, PackageBaseName=pkgbase.Name, - ClosureComment=str()) + with db.begin(): + create(PackageRequest, RequestType=request_type, User=user, + PackageBase=pkgbase, PackageBaseName=pkgbase.Name, + ClosureComment=str()) rollback() def test_package_request_null_closure_comment_raises_exception(): request_type = query(RequestType, RequestType.Name == "merge").first() with pytest.raises(IntegrityError): - create(PackageRequest, RequestType=request_type, - User=user, PackageBase=pkgbase, PackageBaseName=pkgbase.Name, - Comments=str()) + with db.begin(): + create(PackageRequest, RequestType=request_type, User=user, + PackageBase=pkgbase, PackageBaseName=pkgbase.Name, + Comments=str()) rollback() @@ -124,26 +134,27 @@ def test_package_request_status_display(): """ Test status_display() based on the Status column value. """ request_type = query(RequestType, RequestType.Name == "merge").first() - pkgreq = create(PackageRequest, RequestType=request_type, - User=user, PackageBase=pkgbase, - PackageBaseName=pkgbase.Name, - Comments=str(), ClosureComment=str(), - Status=PENDING_ID) + with db.begin(): + pkgreq = create(PackageRequest, RequestType=request_type, + User=user, PackageBase=pkgbase, + PackageBaseName=pkgbase.Name, + Comments=str(), ClosureComment=str(), + Status=PENDING_ID) assert pkgreq.status_display() == PENDING - pkgreq.Status = CLOSED_ID - commit() + with db.begin(): + pkgreq.Status = CLOSED_ID assert pkgreq.status_display() == CLOSED - pkgreq.Status = ACCEPTED_ID - commit() + with db.begin(): + pkgreq.Status = ACCEPTED_ID assert pkgreq.status_display() == ACCEPTED - pkgreq.Status = REJECTED_ID - commit() + with db.begin(): + pkgreq.Status = REJECTED_ID assert pkgreq.status_display() == REJECTED - pkgreq.Status = 124 - commit() + with db.begin(): + pkgreq.Status = 124 with pytest.raises(KeyError): pkgreq.status_display() diff --git a/test/test_package_source.py b/test/test_package_source.py index 7453f756..d1adcf9c 100644 --- a/test/test_package_source.py +++ b/test/test_package_source.py @@ -2,7 +2,7 @@ import pytest from sqlalchemy.exc import IntegrityError -from aurweb.db import create, query, rollback +from aurweb.db import begin, create, query, rollback from aurweb.models.account_type import AccountType from aurweb.models.package import Package from aurweb.models.package_base import PackageBase @@ -21,17 +21,19 @@ def setup(): account_type = query(AccountType, AccountType.AccountType == "User").first() - user = create(User, Username="test", Email="test@example.org", - RealName="Test User", Passwd="testPassword", - AccountType=account_type) - pkgbase = create(PackageBase, - Name="test-package", - Maintainer=user) - package = create(Package, PackageBase=pkgbase, Name="test-package") + with begin(): + user = create(User, Username="test", Email="test@example.org", + RealName="Test User", Passwd="testPassword", + AccountType=account_type) + pkgbase = create(PackageBase, + Name="test-package", + Maintainer=user) + package = create(Package, PackageBase=pkgbase, Name="test-package") def test_package_source(): - pkgsource = create(PackageSource, Package=package) + with begin(): + pkgsource = create(PackageSource, Package=package) assert pkgsource.Package == package # By default, PackageSources.Source assigns the string '/dev/null'. assert pkgsource.Source == "/dev/null" @@ -40,5 +42,6 @@ def test_package_source(): def test_package_source_null_package_raises_exception(): with pytest.raises(IntegrityError): - create(PackageSource) + with begin(): + create(PackageSource) rollback() diff --git a/test/test_packages_routes.py b/test/test_packages_routes.py index ad07ec17..8a468c15 100644 --- a/test/test_packages_routes.py +++ b/test/test_packages_routes.py @@ -28,31 +28,25 @@ def package_endpoint(package: Package) -> str: return f"/packages/{package.Name}" -def create_package(pkgname: str, maintainer: User, - autocommit: bool = True) -> Package: +def create_package(pkgname: str, maintainer: User) -> Package: pkgbase = db.create(PackageBase, Name=pkgname, - Maintainer=maintainer, - autocommit=False) - return db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase, - autocommit=autocommit) + Maintainer=maintainer) + return db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase) def create_package_dep(package: Package, depname: str, - dep_type_name: str = "depends", - autocommit: bool = True) -> PackageDependency: + dep_type_name: str = "depends") -> PackageDependency: dep_type = db.query(DependencyType, DependencyType.Name == dep_type_name).first() return db.create(PackageDependency, DependencyType=dep_type, Package=package, - DepName=depname, - autocommit=autocommit) + DepName=depname) def create_package_rel(package: Package, - relname: str, - autocommit: bool = True) -> PackageRelation: + relname: str) -> PackageRelation: rel_type = db.query(RelationType, RelationType.ID == PROVIDES_ID).first() return db.create(PackageRelation, @@ -84,31 +78,37 @@ def client() -> TestClient: def user() -> User: """ Yield a user. """ account_type = db.query(AccountType, AccountType.ID == USER_ID).first() - yield db.create(User, Username="test", - Email="test@example.org", - Passwd="testPassword", - AccountType=account_type) + with db.begin(): + user = db.create(User, Username="test", + Email="test@example.org", + Passwd="testPassword", + AccountType=account_type) + yield user @pytest.fixture def maintainer() -> User: """ Yield a specific User used to maintain packages. """ account_type = db.query(AccountType, AccountType.ID == USER_ID).first() - yield db.create(User, Username="test_maintainer", - Email="test_maintainer@example.org", - Passwd="testPassword", - AccountType=account_type) + with db.begin(): + maintainer = db.create(User, Username="test_maintainer", + Email="test_maintainer@example.org", + Passwd="testPassword", + AccountType=account_type) + yield maintainer @pytest.fixture def package(maintainer: User) -> Package: """ Yield a Package created by user. """ - pkgbase = db.create(PackageBase, - Name="test-package", - Maintainer=maintainer) - yield db.create(Package, - PackageBase=pkgbase, - Name=pkgbase.Name) + with db.begin(): + pkgbase = db.create(PackageBase, + Name="test-package", + Maintainer=maintainer) + package = db.create(Package, + PackageBase=pkgbase, + Name=pkgbase.Name) + yield package def test_package_not_found(client: TestClient): @@ -121,10 +121,11 @@ def test_package_official_not_found(client: TestClient, package: Package): """ When a Package has a matching OfficialProvider record, it is not hosted on AUR, but in the official repositories. Getting a package with this kind of record should return a status code 404. """ - db.create(OfficialProvider, - Name=package.Name, - Repo="core", - Provides=package.Name) + with db.begin(): + db.create(OfficialProvider, + Name=package.Name, + Repo="core", + Provides=package.Name) with client as request: resp = request.get(package_endpoint(package)) @@ -157,8 +158,9 @@ def test_package(client: TestClient, package: Package): def test_package_comments(client: TestClient, user: User, package: Package): now = (datetime.utcnow().timestamp()) - comment = db.create(PackageComment, PackageBase=package.PackageBase, - User=user, Comments="Test comment", CommentTS=now) + with db.begin(): + comment = db.create(PackageComment, PackageBase=package.PackageBase, + User=user, Comments="Test comment", CommentTS=now) cookies = {"AURSID": user.login(Request(), "testPassword")} with client as request: @@ -178,11 +180,12 @@ def test_package_comments(client: TestClient, user: User, package: Package): def test_package_requests_display(client: TestClient, user: User, package: Package): type_ = db.query(RequestType, RequestType.ID == DELETION_ID).first() - db.create(PackageRequest, PackageBase=package.PackageBase, - PackageBaseName=package.PackageBase.Name, - User=user, RequestType=type_, - Comments="Test comment.", - ClosureComment=str()) + with db.begin(): + db.create(PackageRequest, PackageBase=package.PackageBase, + PackageBaseName=package.PackageBase.Name, + User=user, RequestType=type_, + Comments="Test comment.", + ClosureComment=str()) # Test that a single request displays "1 pending request". with client as request: @@ -195,11 +198,12 @@ def test_package_requests_display(client: TestClient, user: User, assert target.text.strip() == "1 pending request" type_ = db.query(RequestType, RequestType.ID == DELETION_ID).first() - db.create(PackageRequest, PackageBase=package.PackageBase, - PackageBaseName=package.PackageBase.Name, - User=user, RequestType=type_, - Comments="Test comment2.", - ClosureComment=str()) + with db.begin(): + db.create(PackageRequest, PackageBase=package.PackageBase, + PackageBaseName=package.PackageBase.Name, + User=user, RequestType=type_, + Comments="Test comment2.", + ClosureComment=str()) # Test that a two requests display "2 pending requests". with client as request: @@ -271,50 +275,43 @@ def test_package_authenticated_maintainer(client: TestClient, def test_package_dependencies(client: TestClient, maintainer: User, package: Package): # Create a normal dependency of type depends. - dep_pkg = create_package("test-dep-1", maintainer, autocommit=False) - dep = create_package_dep(package, dep_pkg.Name, autocommit=False) - dep.DepArch = "x86_64" + with db.begin(): + dep_pkg = create_package("test-dep-1", maintainer) + dep = create_package_dep(package, dep_pkg.Name) + dep.DepArch = "x86_64" - # Also, create a makedepends. - make_dep_pkg = create_package("test-dep-2", maintainer, autocommit=False) - make_dep = create_package_dep(package, make_dep_pkg.Name, - dep_type_name="makedepends", - autocommit=False) + # Also, create a makedepends. + make_dep_pkg = create_package("test-dep-2", maintainer) + make_dep = create_package_dep(package, make_dep_pkg.Name, + dep_type_name="makedepends") - # And... a checkdepends! - check_dep_pkg = create_package("test-dep-3", maintainer, autocommit=False) - check_dep = create_package_dep(package, check_dep_pkg.Name, - dep_type_name="checkdepends", - autocommit=False) + # And... a checkdepends! + check_dep_pkg = create_package("test-dep-3", maintainer) + check_dep = create_package_dep(package, check_dep_pkg.Name, + dep_type_name="checkdepends") - # Geez. Just stop. This is optdepends. - opt_dep_pkg = create_package("test-dep-4", maintainer, autocommit=False) - opt_dep = create_package_dep(package, opt_dep_pkg.Name, - dep_type_name="optdepends", - autocommit=False) + # Geez. Just stop. This is optdepends. + opt_dep_pkg = create_package("test-dep-4", maintainer) + opt_dep = create_package_dep(package, opt_dep_pkg.Name, + dep_type_name="optdepends") - # Heh. Another optdepends to test one with a description. - opt_desc_dep_pkg = create_package("test-dep-5", maintainer, - autocommit=False) - opt_desc_dep = create_package_dep(package, opt_desc_dep_pkg.Name, - dep_type_name="optdepends", - autocommit=False) - opt_desc_dep.DepDesc = "Test description." + # Heh. Another optdepends to test one with a description. + opt_desc_dep_pkg = create_package("test-dep-5", maintainer) + opt_desc_dep = create_package_dep(package, opt_desc_dep_pkg.Name, + dep_type_name="optdepends") + opt_desc_dep.DepDesc = "Test description." - broken_dep = create_package_dep(package, "test-dep-6", - dep_type_name="depends", - autocommit=False) + broken_dep = create_package_dep(package, "test-dep-6", + dep_type_name="depends") - # Create an official provider record. - db.create(OfficialProvider, Name="test-dep-99", - Repo="core", Provides="test-dep-99", - autocommit=False) - official_dep = create_package_dep(package, "test-dep-99", - autocommit=False) + # Create an official provider record. + db.create(OfficialProvider, Name="test-dep-99", + Repo="core", Provides="test-dep-99") + official_dep = create_package_dep(package, "test-dep-99") - # Also, create a provider who provides our test-dep-99. - provider = create_package("test-provider", maintainer, autocommit=False) - create_package_rel(provider, dep.DepName) + # Also, create a provider who provides our test-dep-99. + provider = create_package("test-provider", maintainer) + create_package_rel(provider, dep.DepName) with client as request: resp = request.get(package_endpoint(package)) @@ -358,8 +355,9 @@ def test_pkgbase_redirect(client: TestClient, package: Package): def test_pkgbase(client: TestClient, package: Package): - second = db.create(Package, Name="second-pkg", - PackageBase=package.PackageBase) + with db.begin(): + second = db.create(Package, Name="second-pkg", + PackageBase=package.PackageBase) expected = [package.Name, second.Name] with client as request: diff --git a/test/test_packages_util.py b/test/test_packages_util.py index bc6a941c..754e3b8d 100644 --- a/test/test_packages_util.py +++ b/test/test_packages_util.py @@ -26,17 +26,21 @@ def setup(): @pytest.fixture def maintainer() -> User: account_type = db.query(AccountType, AccountType.ID == USER_ID).first() - yield db.create(User, Username="test_maintainer", - Email="test_maintainer@examepl.org", - Passwd="testPassword", - AccountType=account_type) + with db.begin(): + maintainer = db.create(User, Username="test_maintainer", + Email="test_maintainer@examepl.org", + Passwd="testPassword", + AccountType=account_type) + yield maintainer @pytest.fixture def package(maintainer: User) -> Package: - pkgbase = db.create(PackageBase, Name="test-pkg", - Packager=maintainer, Maintainer=maintainer) - yield db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase) + with db.begin(): + pkgbase = db.create(PackageBase, Name="test-pkg", + Packager=maintainer, Maintainer=maintainer) + package = db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase) + yield package @pytest.fixture @@ -45,10 +49,11 @@ def client() -> TestClient: def test_package_link(client: TestClient, maintainer: User, package: Package): - db.create(OfficialProvider, - Name=package.Name, - Repo="core", - Provides=package.Name) + with db.begin(): + db.create(OfficialProvider, + Name=package.Name, + Repo="core", + Provides=package.Name) expected = f"{OFFICIAL_BASE}/packages/?q={package.Name}" assert util.package_link(package) == expected diff --git a/test/test_relation_type.py b/test/test_relation_type.py index bf23505c..fbc22c71 100644 --- a/test/test_relation_type.py +++ b/test/test_relation_type.py @@ -1,6 +1,6 @@ import pytest -from aurweb.db import create, delete, query +from aurweb import db from aurweb.models.relation_type import RelationType from aurweb.testing import setup_test_db @@ -11,22 +11,25 @@ def setup(): def test_relation_type_creation(): - relation_type = create(RelationType, Name="test-relation") + with db.begin(): + relation_type = db.create(RelationType, Name="test-relation") + assert bool(relation_type.ID) assert relation_type.Name == "test-relation" - delete(RelationType, RelationType.ID == relation_type.ID) + with db.begin(): + db.delete(RelationType, RelationType.ID == relation_type.ID) def test_relation_types(): - conflicts = query(RelationType, RelationType.Name == "conflicts").first() + conflicts = db.query(RelationType, RelationType.Name == "conflicts").first() assert conflicts is not None assert conflicts.Name == "conflicts" - provides = query(RelationType, RelationType.Name == "provides").first() + provides = db.query(RelationType, RelationType.Name == "provides").first() assert provides is not None assert provides.Name == "provides" - replaces = query(RelationType, RelationType.Name == "replaces").first() + replaces = db.query(RelationType, RelationType.Name == "replaces").first() assert replaces is not None assert replaces.Name == "replaces" diff --git a/test/test_request_type.py b/test/test_request_type.py index a3b3ccb8..8d21c2d9 100644 --- a/test/test_request_type.py +++ b/test/test_request_type.py @@ -1,6 +1,6 @@ import pytest -from aurweb.db import create, delete, query +from aurweb import db from aurweb.models.request_type import DELETION_ID, MERGE_ID, ORPHAN_ID, RequestType from aurweb.testing import setup_test_db @@ -11,25 +11,33 @@ def setup(): def test_request_type_creation(): - request_type = create(RequestType, Name="Test Request") + with db.begin(): + request_type = db.create(RequestType, Name="Test Request") + assert bool(request_type.ID) assert request_type.Name == "Test Request" - delete(RequestType, RequestType.ID == request_type.ID) + + with db.begin(): + db.delete(RequestType, RequestType.ID == request_type.ID) def test_request_type_null_name_returns_empty_string(): - request_type = create(RequestType) + with db.begin(): + request_type = db.create(RequestType) + assert bool(request_type.ID) assert request_type.Name == str() - delete(RequestType, RequestType.ID == request_type.ID) + + with db.begin(): + db.delete(RequestType, RequestType.ID == request_type.ID) def test_request_type_name_display(): - deletion = query(RequestType, RequestType.ID == DELETION_ID).first() + deletion = db.query(RequestType, RequestType.ID == DELETION_ID).first() assert deletion.name_display() == "Deletion" - orphan = query(RequestType, RequestType.ID == ORPHAN_ID).first() + orphan = db.query(RequestType, RequestType.ID == ORPHAN_ID).first() assert orphan.name_display() == "Orphan" - merge = query(RequestType, RequestType.ID == MERGE_ID).first() + merge = db.query(RequestType, RequestType.ID == MERGE_ID).first() assert merge.name_display() == "Merge" diff --git a/test/test_routes.py b/test/test_routes.py index a2d1786e..e3f69d7a 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -8,8 +8,8 @@ import pytest from fastapi.testclient import TestClient +from aurweb import db from aurweb.asgi import app -from aurweb.db import create, query from aurweb.models.account_type import AccountType from aurweb.models.user import User from aurweb.testing import setup_test_db @@ -24,11 +24,13 @@ def setup(): setup_test_db("Users", "Sessions") - account_type = query(AccountType, - AccountType.AccountType == "User").first() - user = create(User, Username="test", Email="test@example.org", - RealName="Test User", Passwd="testPassword", - AccountType=account_type) + account_type = db.query(AccountType, + AccountType.AccountType == "User").first() + + with db.begin(): + user = db.create(User, Username="test", Email="test@example.org", + RealName="Test User", Passwd="testPassword", + AccountType=account_type) client = TestClient(app) diff --git a/test/test_rss.py b/test/test_rss.py index 7dd5bb47..ce3bc71f 100644 --- a/test/test_rss.py +++ b/test/test_rss.py @@ -49,14 +49,13 @@ def packages(user): now = int(datetime.utcnow().timestamp()) # Create 101 packages; we limit 100 on RSS feeds. - for i in range(101): - pkgbase = db.create( - PackageBase, Maintainer=user, Name=f"test-package-{i}", - SubmittedTS=(now + i), ModifiedTS=(now + i), autocommit=False) - pkg = db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase, - autocommit=False) - pkgs.append(pkg) - db.commit() + with db.begin(): + for i in range(101): + pkgbase = db.create( + PackageBase, Maintainer=user, Name=f"test-package-{i}", + SubmittedTS=(now + i), ModifiedTS=(now + i)) + pkg = db.create(Package, Name=pkgbase.Name, PackageBase=pkgbase) + pkgs.append(pkg) yield pkgs diff --git a/test/test_session.py b/test/test_session.py index 1ba11556..4e6f4db4 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -4,7 +4,7 @@ from unittest import mock import pytest -from aurweb.db import create, query +from aurweb import db from aurweb.models.account_type import AccountType from aurweb.models.session import Session, generate_unique_sid from aurweb.models.user import User @@ -19,13 +19,16 @@ def setup(): setup_test_db("Users", "Sessions") - account_type = query(AccountType, - AccountType.AccountType == "User").first() - user = create(User, Username="test", Email="test@example.org", - ResetKey="testReset", Passwd="testPassword", - AccountType=account_type) - session = create(Session, UsersID=user.ID, SessionID="testSession", - LastUpdateTS=datetime.utcnow().timestamp()) + account_type = db.query(AccountType, + AccountType.AccountType == "User").first() + with db.begin(): + user = db.create(User, Username="test", Email="test@example.org", + ResetKey="testReset", Passwd="testPassword", + AccountType=account_type) + + with db.begin(): + session = db.create(Session, UsersID=user.ID, SessionID="testSession", + LastUpdateTS=datetime.utcnow().timestamp()) def test_session(): @@ -35,12 +38,15 @@ def test_session(): def test_session_cs(): """ Test case sensitivity of the database table. """ - user2 = create(User, Username="test2", Email="test2@example.org", - ResetKey="testReset2", Passwd="testPassword", - AccountType=account_type) - session_cs = create(Session, UsersID=user2.ID, - SessionID="TESTSESSION", - LastUpdateTS=datetime.utcnow().timestamp()) + with db.begin(): + user2 = db.create(User, Username="test2", Email="test2@example.org", + ResetKey="testReset2", Passwd="testPassword", + AccountType=account_type) + + with db.begin(): + session_cs = db.create(Session, UsersID=user2.ID, + SessionID="TESTSESSION", + LastUpdateTS=datetime.utcnow().timestamp()) assert session_cs.SessionID == "TESTSESSION" assert session.SessionID == "testSession" diff --git a/test/test_ssh_pub_key.py b/test/test_ssh_pub_key.py index 0793199a..12a3e1ce 100644 --- a/test/test_ssh_pub_key.py +++ b/test/test_ssh_pub_key.py @@ -1,6 +1,6 @@ import pytest -from aurweb.db import create, query +from aurweb import db from aurweb.models.account_type import AccountType from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint from aurweb.models.user import User @@ -19,19 +19,18 @@ def setup(): setup_test_db("Users", "SSHPubKeys") - account_type = query(AccountType, - AccountType.AccountType == "User").first() - user = create(User, Username="test", Email="test@example.org", - RealName="Test User", Passwd="testPassword", - AccountType=account_type) + account_type = db.query(AccountType, + AccountType.AccountType == "User").first() + with db.begin(): + user = db.create(User, Username="test", Email="test@example.org", + RealName="Test User", Passwd="testPassword", + AccountType=account_type) - assert account_type == user.AccountType - assert account_type.ID == user.AccountTypeID - - ssh_pub_key = create(SSHPubKey, - UserID=user.ID, - Fingerprint="testFingerprint", - PubKey="testPubKey") + with db.begin(): + ssh_pub_key = db.create(SSHPubKey, + UserID=user.ID, + Fingerprint="testFingerprint", + PubKey="testPubKey") def test_ssh_pub_key(): @@ -43,9 +42,10 @@ def test_ssh_pub_key(): def test_ssh_pub_key_cs(): """ Test case sensitivity of the database table. """ - ssh_pub_key_cs = create(SSHPubKey, UserID=user.ID, - Fingerprint="TESTFINGERPRINT", - PubKey="TESTPUBKEY") + with db.begin(): + ssh_pub_key_cs = db.create(SSHPubKey, UserID=user.ID, + Fingerprint="TESTFINGERPRINT", + PubKey="TESTPUBKEY") assert ssh_pub_key_cs.Fingerprint == "TESTFINGERPRINT" assert ssh_pub_key_cs.PubKey == "TESTPUBKEY" diff --git a/test/test_term.py b/test/test_term.py index 25108419..3f28311f 100644 --- a/test/test_term.py +++ b/test/test_term.py @@ -2,7 +2,7 @@ import pytest from sqlalchemy.exc import IntegrityError -from aurweb.db import create +from aurweb import db from aurweb.models.term import Term from aurweb.testing import setup_test_db @@ -18,8 +18,9 @@ def setup(): def test_term_creation(): - term = create(Term, Description="Term description", - URL="https://fake_url.io") + with db.begin(): + term = db.create(Term, Description="Term description", + URL="https://fake_url.io") assert bool(term.ID) assert term.Description == "Term description" assert term.URL == "https://fake_url.io" @@ -27,14 +28,14 @@ def test_term_creation(): def test_term_null_description_raises_exception(): - from aurweb.db import session with pytest.raises(IntegrityError): - create(Term, URL="https://fake_url.io") - session.rollback() + with db.begin(): + db.create(Term, URL="https://fake_url.io") + db.rollback() def test_term_null_url_raises_exception(): - from aurweb.db import session with pytest.raises(IntegrityError): - create(Term, Description="Term description") - session.rollback() + with db.begin(): + db.create(Term, Description="Term description") + db.rollback() diff --git a/test/test_trusted_user_routes.py b/test/test_trusted_user_routes.py index 0c33f958..67181db3 100644 --- a/test/test_trusted_user_routes.py +++ b/test/test_trusted_user_routes.py @@ -90,37 +90,37 @@ def client(): def tu_user(): tu_type = db.query(AccountType, AccountType.AccountType == "Trusted User").first() - yield db.create(User, Username="test_tu", Email="test_tu@example.org", - RealName="Test TU", Passwd="testPassword", - AccountType=tu_type) + with db.begin(): + tu_user = db.create(User, Username="test_tu", + Email="test_tu@example.org", + RealName="Test TU", Passwd="testPassword", + AccountType=tu_type) + yield tu_user @pytest.fixture def user(): user_type = db.query(AccountType, AccountType.AccountType == "User").first() - yield db.create(User, Username="test", Email="test@example.org", - RealName="Test User", Passwd="testPassword", - AccountType=user_type) + with db.begin(): + user = db.create(User, Username="test", Email="test@example.org", + RealName="Test User", Passwd="testPassword", + AccountType=user_type) + yield user @pytest.fixture -def proposal(tu_user): +def proposal(user, tu_user): ts = int(datetime.utcnow().timestamp()) agenda = "Test proposal." start = ts - 5 end = ts + 1000 - user_type = db.query(AccountType, - AccountType.AccountType == "User").first() - user = db.create(User, Username="test", Email="test@example.org", - RealName="Test User", Passwd="testPassword", - AccountType=user_type) - - voteinfo = db.create(TUVoteInfo, - Agenda=agenda, Quorum=0.0, - User=user.Username, Submitter=tu_user, - Submitted=start, End=end) + with db.begin(): + voteinfo = db.create(TUVoteInfo, + Agenda=agenda, Quorum=0.0, + User=user.Username, Submitter=tu_user, + Submitted=start, End=end) yield (tu_user, user, voteinfo) @@ -170,20 +170,22 @@ def test_tu_index(client, tu_user): ("Test agenda 2", ts - 1000, ts - 5) # Not running anymore. ] vote_records = [] - for vote in votes: - agenda, start, end = vote - vote_records.append( - db.create(TUVoteInfo, Agenda=agenda, - User=tu_user.Username, - Submitted=start, End=end, - Quorum=0.0, - Submitter=tu_user)) + with db.begin(): + for vote in votes: + agenda, start, end = vote + vote_records.append( + db.create(TUVoteInfo, Agenda=agenda, + User=tu_user.Username, + Submitted=start, End=end, + Quorum=0.0, + Submitter=tu_user)) - # Vote on an ended proposal. - vote_record = vote_records[1] - vote_record.Yes += 1 - vote_record.ActiveTUs += 1 - db.create(TUVote, VoteInfo=vote_record, User=tu_user) + with db.begin(): + # Vote on an ended proposal. + vote_record = vote_records[1] + vote_record.Yes += 1 + vote_record.ActiveTUs += 1 + db.create(TUVote, VoteInfo=vote_record, User=tu_user) cookies = {"AURSID": tu_user.login(Request(), "testPassword")} with client as request: @@ -255,22 +257,22 @@ def test_tu_index(client, tu_user): def test_tu_index_table_paging(client, tu_user): ts = int(datetime.utcnow().timestamp()) - for i in range(25): - # Create 25 current votes. - db.create(TUVoteInfo, Agenda=f"Agenda #{i}", - User=tu_user.Username, - Submitted=(ts - 5), End=(ts + 1000), - Quorum=0.0, - Submitter=tu_user, autocommit=False) + with db.begin(): + for i in range(25): + # Create 25 current votes. + db.create(TUVoteInfo, Agenda=f"Agenda #{i}", + User=tu_user.Username, + Submitted=(ts - 5), End=(ts + 1000), + Quorum=0.0, + Submitter=tu_user) - for i in range(25): - # Create 25 past votes. - db.create(TUVoteInfo, Agenda=f"Agenda #{25 + i}", - User=tu_user.Username, - Submitted=(ts - 1000), End=(ts - 5), - Quorum=0.0, - Submitter=tu_user, autocommit=False) - db.commit() + for i in range(25): + # Create 25 past votes. + db.create(TUVoteInfo, Agenda=f"Agenda #{25 + i}", + User=tu_user.Username, + Submitted=(ts - 1000), End=(ts - 5), + Quorum=0.0, + Submitter=tu_user) cookies = {"AURSID": tu_user.login(Request(), "testPassword")} with client as request: @@ -363,18 +365,19 @@ def test_tu_index_table_paging(client, tu_user): def test_tu_index_sorting(client, tu_user): ts = int(datetime.utcnow().timestamp()) - for i in range(2): - # Create 'Agenda #1' and 'Agenda #2'. - db.create(TUVoteInfo, Agenda=f"Agenda #{i + 1}", - User=tu_user.Username, - Submitted=(ts + 5), End=(ts + 1000), - Quorum=0.0, - Submitter=tu_user, autocommit=False) + with db.begin(): + for i in range(2): + # Create 'Agenda #1' and 'Agenda #2'. + db.create(TUVoteInfo, Agenda=f"Agenda #{i + 1}", + User=tu_user.Username, + Submitted=(ts + 5), End=(ts + 1000), + Quorum=0.0, + Submitter=tu_user) - # Let's order each vote one day after the other. - # This will allow us to test the sorting nature - # of the tables. - ts += 86405 + # Let's order each vote one day after the other. + # This will allow us to test the sorting nature + # of the tables. + ts += 86405 # Make a default request to /tu. cookies = {"AURSID": tu_user.login(Request(), "testPassword")} @@ -432,18 +435,19 @@ def test_tu_index_sorting(client, tu_user): def test_tu_index_last_votes(client, tu_user, user): ts = int(datetime.utcnow().timestamp()) - # Create a proposal which has ended. - voteinfo = db.create(TUVoteInfo, Agenda="Test agenda", - User=user.Username, - Submitted=(ts - 1000), - End=(ts - 5), - Yes=1, - ActiveTUs=1, - Quorum=0.0, - Submitter=tu_user) + with db.begin(): + # Create a proposal which has ended. + voteinfo = db.create(TUVoteInfo, Agenda="Test agenda", + User=user.Username, + Submitted=(ts - 1000), + End=(ts - 5), + Yes=1, + ActiveTUs=1, + Quorum=0.0, + Submitter=tu_user) - # Create a vote on it from tu_user. - db.create(TUVote, VoteInfo=voteinfo, User=tu_user) + # Create a vote on it from tu_user. + db.create(TUVote, VoteInfo=voteinfo, User=tu_user) # Now, check that tu_user got populated in the .last-votes table. cookies = {"AURSID": tu_user.login(Request(), "testPassword")} @@ -529,10 +533,10 @@ def test_tu_running_proposal(client, proposal): assert abstain.attrib["value"] == "Abstain" # Create a vote. - db.create(TUVote, VoteInfo=voteinfo, User=tu_user) - voteinfo.ActiveTUs += 1 - voteinfo.Yes += 1 - db.commit() + with db.begin(): + db.create(TUVote, VoteInfo=voteinfo, User=tu_user) + voteinfo.ActiveTUs += 1 + voteinfo.Yes += 1 # Make another request now that we've voted. with client as request: @@ -556,8 +560,8 @@ def test_tu_ended_proposal(client, proposal): tu_user, user, voteinfo = proposal ts = int(datetime.utcnow().timestamp()) - voteinfo.End = ts - 5 # 5 seconds ago. - db.commit() + with db.begin(): + voteinfo.End = ts - 5 # 5 seconds ago. # Initiate an authenticated GET request to /tu/{proposal_id}. proposal_id = voteinfo.ID @@ -635,8 +639,8 @@ def test_tu_proposal_vote_unauthorized(client, proposal): dev_type = db.query(AccountType, AccountType.AccountType == "Developer").first() - tu_user.AccountType = dev_type - db.commit() + with db.begin(): + tu_user.AccountType = dev_type cookies = {"AURSID": tu_user.login(Request(), "testPassword")} with client as request: @@ -664,8 +668,8 @@ def test_tu_proposal_vote_cant_self_vote(client, proposal): tu_user, user, voteinfo = proposal # Update voteinfo.User. - voteinfo.User = tu_user.Username - db.commit() + with db.begin(): + voteinfo.User = tu_user.Username cookies = {"AURSID": tu_user.login(Request(), "testPassword")} with client as request: @@ -692,10 +696,10 @@ def test_tu_proposal_vote_cant_self_vote(client, proposal): def test_tu_proposal_vote_already_voted(client, proposal): tu_user, user, voteinfo = proposal - db.create(TUVote, VoteInfo=voteinfo, User=tu_user) - voteinfo.Yes += 1 - voteinfo.ActiveTUs += 1 - db.commit() + with db.begin(): + db.create(TUVote, VoteInfo=voteinfo, User=tu_user) + voteinfo.Yes += 1 + voteinfo.ActiveTUs += 1 cookies = {"AURSID": tu_user.login(Request(), "testPassword")} with client as request: diff --git a/test/test_tu_voteinfo.py b/test/test_tu_voteinfo.py index 494300c5..b60e2e6a 100644 --- a/test/test_tu_voteinfo.py +++ b/test/test_tu_voteinfo.py @@ -4,7 +4,8 @@ import pytest from sqlalchemy.exc import IntegrityError -from aurweb.db import commit, create, query, rollback +from aurweb import db +from aurweb.db import create, query, rollback from aurweb.models.account_type import AccountType from aurweb.models.tu_voteinfo import TUVoteInfo from aurweb.models.user import User @@ -21,19 +22,21 @@ def setup(): tu_type = query(AccountType, AccountType.AccountType == "Trusted User").first() - user = create(User, Username="test", Email="test@example.org", - RealName="Test User", Passwd="testPassword", - AccountType=tu_type) + with db.begin(): + user = create(User, Username="test", Email="test@example.org", + RealName="Test User", Passwd="testPassword", + AccountType=tu_type) def test_tu_voteinfo_creation(): ts = int(datetime.utcnow().timestamp()) - tu_voteinfo = create(TUVoteInfo, - Agenda="Blah blah.", - User=user.Username, - Submitted=ts, End=ts + 5, - Quorum=0.5, - Submitter=user) + with db.begin(): + tu_voteinfo = create(TUVoteInfo, + Agenda="Blah blah.", + User=user.Username, + Submitted=ts, End=ts + 5, + Quorum=0.5, + Submitter=user) assert bool(tu_voteinfo.ID) assert tu_voteinfo.Agenda == "Blah blah." assert tu_voteinfo.User == user.Username @@ -51,32 +54,33 @@ def test_tu_voteinfo_creation(): def test_tu_voteinfo_is_running(): ts = int(datetime.utcnow().timestamp()) - tu_voteinfo = create(TUVoteInfo, - Agenda="Blah blah.", - User=user.Username, - Submitted=ts, End=ts + 1000, - Quorum=0.5, - Submitter=user) + with db.begin(): + tu_voteinfo = create(TUVoteInfo, + Agenda="Blah blah.", + User=user.Username, + Submitted=ts, End=ts + 1000, + Quorum=0.5, + Submitter=user) assert tu_voteinfo.is_running() is True - tu_voteinfo.End = ts - 5 - commit() + with db.begin(): + tu_voteinfo.End = ts - 5 assert tu_voteinfo.is_running() is False def test_tu_voteinfo_total_votes(): ts = int(datetime.utcnow().timestamp()) - tu_voteinfo = create(TUVoteInfo, - Agenda="Blah blah.", - User=user.Username, - Submitted=ts, End=ts + 1000, - Quorum=0.5, - Submitter=user) + with db.begin(): + tu_voteinfo = create(TUVoteInfo, + Agenda="Blah blah.", + User=user.Username, + Submitted=ts, End=ts + 1000, + Quorum=0.5, + Submitter=user) - tu_voteinfo.Yes = 1 - tu_voteinfo.No = 3 - tu_voteinfo.Abstain = 5 - commit() + tu_voteinfo.Yes = 1 + tu_voteinfo.No = 3 + tu_voteinfo.Abstain = 5 # total_votes() should be the sum of Yes, No and Abstain: 1 + 3 + 5 = 9. assert tu_voteinfo.total_votes() == 9 @@ -84,61 +88,67 @@ def test_tu_voteinfo_total_votes(): def test_tu_voteinfo_null_submitter_raises_exception(): with pytest.raises(IntegrityError): - create(TUVoteInfo, - Agenda="Blah blah.", - User=user.Username, - Submitted=0, End=0, - Quorum=0.50) + with db.begin(): + create(TUVoteInfo, + Agenda="Blah blah.", + User=user.Username, + Submitted=0, End=0, + Quorum=0.50) rollback() def test_tu_voteinfo_null_agenda_raises_exception(): with pytest.raises(IntegrityError): - create(TUVoteInfo, - User=user.Username, - Submitted=0, End=0, - Quorum=0.50, - Submitter=user) + with db.begin(): + create(TUVoteInfo, + User=user.Username, + Submitted=0, End=0, + Quorum=0.50, + Submitter=user) rollback() def test_tu_voteinfo_null_user_raises_exception(): with pytest.raises(IntegrityError): - create(TUVoteInfo, - Agenda="Blah blah.", - Submitted=0, End=0, - Quorum=0.50, - Submitter=user) + with db.begin(): + create(TUVoteInfo, + Agenda="Blah blah.", + Submitted=0, End=0, + Quorum=0.50, + Submitter=user) rollback() def test_tu_voteinfo_null_submitted_raises_exception(): with pytest.raises(IntegrityError): - create(TUVoteInfo, - Agenda="Blah blah.", - User=user.Username, - End=0, - Quorum=0.50, - Submitter=user) + with db.begin(): + create(TUVoteInfo, + Agenda="Blah blah.", + User=user.Username, + End=0, + Quorum=0.50, + Submitter=user) rollback() def test_tu_voteinfo_null_end_raises_exception(): with pytest.raises(IntegrityError): - create(TUVoteInfo, - Agenda="Blah blah.", - User=user.Username, - Submitted=0, - Quorum=0.50, - Submitter=user) + with db.begin(): + create(TUVoteInfo, + Agenda="Blah blah.", + User=user.Username, + Submitted=0, + Quorum=0.50, + Submitter=user) rollback() def test_tu_voteinfo_null_quorum_raises_exception(): with pytest.raises(IntegrityError): - create(TUVoteInfo, - Agenda="Blah blah.", - User=user.Username, - Submitted=0, End=0, - Submitter=user) + with db.begin(): + create(TUVoteInfo, + Agenda="Blah blah.", + User=user.Username, + Submitted=0, End=0, + Submitter=user) rollback() diff --git a/test/test_user.py b/test/test_user.py index 7756cff3..70eac079 100644 --- a/test/test_user.py +++ b/test/test_user.py @@ -9,7 +9,7 @@ import pytest import aurweb.auth import aurweb.config -from aurweb.db import commit, create, query +from aurweb import db from aurweb.models.account_type import AccountType from aurweb.models.ban import Ban from aurweb.models.package import Package @@ -40,12 +40,13 @@ def setup(): PackageNotification.__tablename__ ) - account_type = query(AccountType, - AccountType.AccountType == "User").first() + account_type = db.query(AccountType, + AccountType.AccountType == "User").first() - user = create(User, Username="test", Email="test@example.org", - RealName="Test User", Passwd="testPassword", - AccountType=account_type) + with db.begin(): + user = db.create(User, Username="test", Email="test@example.org", + RealName="Test User", Passwd="testPassword", + AccountType=account_type) def test_user_login_logout(): @@ -70,14 +71,14 @@ def test_user_login_logout(): assert "AURSID" in request.cookies # Expect that User session relationships work right. - user_session = query(Session, - Session.UsersID == user.ID).first() + user_session = db.query(Session, + Session.UsersID == user.ID).first() assert user_session == user.session assert user.session.SessionID == sid assert user.session.User == user # Search for the user via query API. - result = query(User, User.ID == user.ID).first() + result = db.query(User, User.ID == user.ID).first() # Compare the result and our original user. assert result == user @@ -114,7 +115,8 @@ def test_user_login_twice(): def test_user_login_banned(): # Add ban for the next 30 seconds. banned_timestamp = datetime.utcnow() + timedelta(seconds=30) - create(Ban, IPAddress="127.0.0.1", BanTS=banned_timestamp) + with db.begin(): + db.create(Ban, IPAddress="127.0.0.1", BanTS=banned_timestamp) request = Request() request.client.host = "127.0.0.1" @@ -122,18 +124,17 @@ def test_user_login_banned(): def test_user_login_suspended(): - from aurweb.db import session - user.Suspended = True - session.commit() + with db.begin(): + user.Suspended = True assert not user.login(Request(), "testPassword") def test_legacy_user_authentication(): - from aurweb.db import session - - user.Salt = bcrypt.gensalt().decode() - user.Passwd = hashlib.md5(f"{user.Salt}testPassword".encode()).hexdigest() - session.commit() + with db.begin(): + user.Salt = bcrypt.gensalt().decode() + user.Passwd = hashlib.md5( + f"{user.Salt}testPassword".encode() + ).hexdigest() assert not user.valid_password("badPassword") assert user.valid_password("testPassword") @@ -145,8 +146,9 @@ def test_legacy_user_authentication(): def test_user_login_with_outdated_sid(): # Make a session with a LastUpdateTS 5 seconds ago, causing # user.login to update it with a new sid. - create(Session, UsersID=user.ID, SessionID="stub", - LastUpdateTS=datetime.utcnow().timestamp() - 5) + with db.begin(): + db.create(Session, UsersID=user.ID, SessionID="stub", + LastUpdateTS=datetime.utcnow().timestamp() - 5) sid = user.login(Request(), "testPassword") assert sid and user.is_authenticated() assert sid != "stub" @@ -171,43 +173,42 @@ def test_user_has_credential(): def test_user_ssh_pub_key(): assert user.ssh_pub_key is None - ssh_pub_key = create(SSHPubKey, UserID=user.ID, - Fingerprint="testFingerprint", - PubKey="testPubKey") + with db.begin(): + ssh_pub_key = db.create(SSHPubKey, UserID=user.ID, + Fingerprint="testFingerprint", + PubKey="testPubKey") assert user.ssh_pub_key == ssh_pub_key def test_user_credential_types(): - from aurweb.db import session - assert aurweb.auth.user_developer_or_trusted_user(user) assert not aurweb.auth.trusted_user(user) assert not aurweb.auth.developer(user) assert not aurweb.auth.trusted_user_or_dev(user) - trusted_user_type = query(AccountType, - AccountType.AccountType == "Trusted User")\ - .first() - user.AccountType = trusted_user_type - session.commit() + trusted_user_type = db.query(AccountType).filter( + AccountType.AccountType == "Trusted User" + ).first() + with db.begin(): + user.AccountType = trusted_user_type assert aurweb.auth.trusted_user(user) assert aurweb.auth.trusted_user_or_dev(user) - developer_type = query(AccountType, - AccountType.AccountType == "Developer").first() - user.AccountType = developer_type - session.commit() + developer_type = db.query(AccountType, + AccountType.AccountType == "Developer").first() + with db.begin(): + user.AccountType = developer_type assert aurweb.auth.developer(user) assert aurweb.auth.trusted_user_or_dev(user) type_str = "Trusted User & Developer" - elevated_type = query(AccountType, - AccountType.AccountType == type_str).first() - user.AccountType = elevated_type - session.commit() + elevated_type = db.query(AccountType, + AccountType.AccountType == type_str).first() + with db.begin(): + user.AccountType = elevated_type assert aurweb.auth.trusted_user(user) assert aurweb.auth.developer(user) @@ -233,53 +234,56 @@ def test_user_as_dict(): def test_user_is_trusted_user(): - tu_type = query(AccountType, - AccountType.AccountType == "Trusted User").first() - user.AccountType = tu_type - commit() + tu_type = db.query(AccountType, + AccountType.AccountType == "Trusted User").first() + with db.begin(): + user.AccountType = tu_type assert user.is_trusted_user() is True # Do it again with the combined role. - tu_type = query( + tu_type = db.query( AccountType, AccountType.AccountType == "Trusted User & Developer").first() - user.AccountType = tu_type - commit() + with db.begin(): + user.AccountType = tu_type assert user.is_trusted_user() is True def test_user_is_developer(): - dev_type = query(AccountType, - AccountType.AccountType == "Developer").first() - user.AccountType = dev_type - commit() + dev_type = db.query(AccountType, + AccountType.AccountType == "Developer").first() + with db.begin(): + user.AccountType = dev_type assert user.is_developer() is True # Do it again with the combined role. - dev_type = query( + dev_type = db.query( AccountType, AccountType.AccountType == "Trusted User & Developer").first() - user.AccountType = dev_type - commit() + with db.begin(): + user.AccountType = dev_type assert user.is_developer() is True def test_user_voted_for(): now = int(datetime.utcnow().timestamp()) - pkgbase = create(PackageBase, Name="pkg1", Maintainer=user) - pkg = create(Package, PackageBase=pkgbase, Name=pkgbase.Name) - create(PackageVote, PackageBase=pkgbase, User=user, VoteTS=now) + with db.begin(): + pkgbase = db.create(PackageBase, Name="pkg1", Maintainer=user) + pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name) + db.create(PackageVote, PackageBase=pkgbase, User=user, VoteTS=now) assert user.voted_for(pkg) def test_user_notified(): - pkgbase = create(PackageBase, Name="pkg1", Maintainer=user) - pkg = create(Package, PackageBase=pkgbase, Name=pkgbase.Name) - create(PackageNotification, PackageBase=pkgbase, User=user) + with db.begin(): + pkgbase = db.create(PackageBase, Name="pkg1", Maintainer=user) + pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name) + db.create(PackageNotification, PackageBase=pkgbase, User=user) assert user.notified(pkg) def test_user_packages(): - pkgbase = create(PackageBase, Name="pkg1", Maintainer=user) - pkg = create(Package, PackageBase=pkgbase, Name=pkgbase.Name) + with db.begin(): + pkgbase = db.create(PackageBase, Name="pkg1", Maintainer=user) + pkg = db.create(Package, PackageBase=pkgbase, Name=pkgbase.Name) assert pkg in user.packages()