[FastAPI] Refactor db modifications

For SQLAlchemy to automatically understand updates from the
external world, it must use an `autocommit=True` in its session.

This change breaks how we were using commit previously, as
`autocommit=True` causes SQLAlchemy to commit when a
SessionTransaction context hits __exit__.

So, a refactoring was required of our tests: All usage of
any `db.{create,delete}` must be called **within** a
SessionTransaction context, created via new `db.begin()`.

From this point forward, we're going to require:

```
with db.begin():
    db.create(...)
    db.delete(...)
    db.session.delete(object)
```

With this, we now get external DB modifications automatically
without reloading or restarting the FastAPI server, which we
absolutely need for production.

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2021-09-02 16:26:48 -07:00
parent b52059d437
commit a5943bf2ad
No known key found for this signature in database
GPG key ID: F7E46DED420788F3
37 changed files with 998 additions and 902 deletions

View file

@ -11,9 +11,9 @@ import pytest
from fastapi.testclient import TestClient
from aurweb import captcha
from aurweb import captcha, db
from aurweb.asgi import app
from aurweb.db import commit, create, query
from aurweb.db import create, query
from aurweb.models.accepted_term import AcceptedTerm
from aurweb.models.account_type import DEVELOPER_ID, TRUSTED_USER_AND_DEV_ID, TRUSTED_USER_ID, AccountType
from aurweb.models.ban import Ban
@ -57,9 +57,11 @@ def setup():
account_type = query(AccountType,
AccountType.AccountType == "User").first()
user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL,
RealName="Test UserZ", Passwd="testPassword",
IRCNick="testZ", AccountType=account_type)
with db.begin():
user = create(User, Username=TEST_USERNAME, Email=TEST_EMAIL,
RealName="Test UserZ", Passwd="testPassword",
IRCNick="testZ", AccountType=account_type)
yield user
@ -70,9 +72,10 @@ def setup():
@pytest.fixture
def tu_user():
user.AccountType = query(AccountType,
AccountType.ID == TRUSTED_USER_AND_DEV_ID).first()
commit()
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
yield user
@ -149,11 +152,9 @@ def test_post_passreset_user():
def test_post_passreset_resetkey():
from aurweb.db import session
user.session = Session(UsersID=user.ID, SessionID="blah",
LastUpdateTS=datetime.utcnow().timestamp())
session.commit()
with db.begin():
user.session = Session(UsersID=user.ID, SessionID="blah",
LastUpdateTS=datetime.utcnow().timestamp())
# Prepare a password reset.
with client as request:
@ -357,7 +358,8 @@ def test_post_register_error_invalid_captcha():
def test_post_register_error_ip_banned():
# 'testclient' is used as request.client.host via FastAPI TestClient.
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow())
with db.begin():
create(Ban, IPAddress="testclient", BanTS=datetime.utcnow())
with client as request:
response = post_register(request)
@ -576,7 +578,8 @@ def test_post_register_error_ssh_pubkey_taken():
# Take the sha256 fingerprint of the ssh public key, create it.
fp = get_fingerprint(pk)
create(SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fp)
with db.begin():
create(SSHPubKey, UserID=user.ID, PubKey=pk, Fingerprint=fp)
with client as request:
response = post_register(request, PK=pk)
@ -660,13 +663,11 @@ def test_post_account_edit():
def test_post_account_edit_dev():
from aurweb.db import session
# Modify our user to be a "Trusted User & Developer"
name = "Trusted User & Developer"
tu_or_dev = query(AccountType, AccountType.AccountType == name).first()
user.AccountType = tu_or_dev
session.commit()
with db.begin():
user.AccountType = tu_or_dev
request = Request()
sid = user.login(request, "testPassword")
@ -1001,21 +1002,19 @@ def get_rows(html):
def test_post_accounts(tu_user):
# Set a PGPKey.
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
with db.begin():
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
# Create a few more users.
users = [user]
for i in range(10):
_user = create(User, Username=f"test_{i}",
Email=f"test_{i}@example.org",
RealName=f"Test #{i}",
Passwd="testPassword",
IRCNick=f"test_#{i}",
autocommit=False)
users.append(_user)
# Commit everything to the database.
commit()
with db.begin():
for i in range(10):
_user = create(User, Username=f"test_{i}",
Email=f"test_{i}@example.org",
RealName=f"Test #{i}",
Passwd="testPassword",
IRCNick=f"test_#{i}")
users.append(_user)
sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid}
@ -1085,11 +1084,12 @@ def test_post_accounts_account_type(tu_user):
# test the `u` parameter.
account_type = query(AccountType,
AccountType.AccountType == "User").first()
create(User, Username="test_2",
Email="test_2@example.org",
RealName="Test User 2",
Passwd="testPassword",
AccountType=account_type)
with db.begin():
create(User, Username="test_2",
Email="test_2@example.org",
RealName="Test User 2",
Passwd="testPassword",
AccountType=account_type)
# Expect no entries; we marked our only user as a User type.
with client as request:
@ -1113,9 +1113,10 @@ def test_post_accounts_account_type(tu_user):
assert type.text.strip() == "User"
# Set our only user to a Trusted User.
user.AccountType = query(AccountType,
AccountType.ID == TRUSTED_USER_ID).first()
commit()
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_ID
).first()
with client as request:
response = request.post("/accounts/", cookies=cookies,
@ -1130,9 +1131,10 @@ def test_post_accounts_account_type(tu_user):
assert type.text.strip() == "Trusted User"
user.AccountType = query(AccountType,
AccountType.ID == DEVELOPER_ID).first()
commit()
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == DEVELOPER_ID
).first()
with client as request:
response = request.post("/accounts/", cookies=cookies,
@ -1147,10 +1149,10 @@ def test_post_accounts_account_type(tu_user):
assert type.text.strip() == "Developer"
user.AccountType = query(AccountType,
AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
commit()
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
with client as request:
response = request.post("/accounts/", cookies=cookies,
@ -1182,8 +1184,8 @@ def test_post_accounts_status(tu_user):
username, type, status, realname, irc, pgp_key, edit = row
assert status.text.strip() == "Active"
user.Suspended = True
commit()
with db.begin():
user.Suspended = True
with client as request:
response = request.post("/accounts/", cookies=cookies,
@ -1244,12 +1246,13 @@ def test_post_accounts_sortby(tu_user):
# Create a second user so we can compare sorts.
account_type = query(AccountType,
AccountType.ID == DEVELOPER_ID).first()
create(User, Username="test2",
Email="test2@example.org",
RealName="Test User 2",
Passwd="testPassword",
IRCNick="test2",
AccountType=account_type)
with db.begin():
create(User, Username="test2",
Email="test2@example.org",
RealName="Test User 2",
Passwd="testPassword",
IRCNick="test2",
AccountType=account_type)
sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid}
@ -1297,9 +1300,10 @@ def test_post_accounts_sortby(tu_user):
# Test the rows are reversed when ordering by RealName.
assert compare_text_values(4, first_rows, reversed(rows)) is True
user.AccountType = query(AccountType,
AccountType.ID == TRUSTED_USER_AND_DEV_ID).first()
commit()
with db.begin():
user.AccountType = query(AccountType).filter(
AccountType.ID == TRUSTED_USER_AND_DEV_ID
).first()
# Fetch first_rows again with our new AccountType ordering.
with client as request:
@ -1322,8 +1326,8 @@ def test_post_accounts_sortby(tu_user):
def test_post_accounts_pgp_key(tu_user):
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
commit()
with db.begin():
user.PGPKey = "5F18B20346188419750745D7335F2CB41F253D30"
sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid}
@ -1343,15 +1347,14 @@ def test_post_accounts_paged(tu_user):
users = [user]
account_type = query(AccountType,
AccountType.AccountType == "User").first()
for i in range(150):
_user = create(User, Username=f"test_#{i}",
Email=f"test_#{i}@example.org",
RealName=f"Test User #{i}",
Passwd="testPassword",
AccountType=account_type,
autocommit=False)
users.append(_user)
commit()
with db.begin():
for i in range(150):
_user = create(User, Username=f"test_#{i}",
Email=f"test_#{i}@example.org",
RealName=f"Test User #{i}",
Passwd="testPassword",
AccountType=account_type)
users.append(_user)
sid = user.login(Request(), "testPassword")
cookies = {"AURSID": sid}
@ -1414,8 +1417,9 @@ def test_post_accounts_paged(tu_user):
def test_get_terms_of_service():
term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1)
with db.begin():
term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1)
with client as request:
response = request.get("/tos", allow_redirects=False)
@ -1436,8 +1440,9 @@ def test_get_terms_of_service():
response = request.get("/tos", cookies=cookies, allow_redirects=False)
assert response.status_code == int(HTTPStatus.OK)
accepted_term = create(AcceptedTerm, User=user,
Term=term, Revision=term.Revision)
with db.begin():
accepted_term = create(AcceptedTerm, User=user,
Term=term, Revision=term.Revision)
with client as request:
response = request.get("/tos", cookies=cookies, allow_redirects=False)
@ -1445,8 +1450,8 @@ def test_get_terms_of_service():
assert response.status_code == int(HTTPStatus.SEE_OTHER)
# Bump the term's revision.
term.Revision = 2
commit()
with db.begin():
term.Revision = 2
with client as request:
response = request.get("/tos", cookies=cookies, allow_redirects=False)
@ -1454,8 +1459,8 @@ def test_get_terms_of_service():
# yet been agreed to via AcceptedTerm update.
assert response.status_code == int(HTTPStatus.OK)
accepted_term.Revision = term.Revision
commit()
with db.begin():
accepted_term.Revision = term.Revision
with client as request:
response = request.get("/tos", cookies=cookies, allow_redirects=False)
@ -1471,8 +1476,9 @@ def test_post_terms_of_service():
cookies = {"AURSID": sid} # Auth cookie.
# Create a fresh Term.
term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1)
with db.begin():
term = create(Term, Description="Test term.",
URL="http://localhost", Revision=1)
# Test that the term we just created is listed.
with client as request:
@ -1497,8 +1503,8 @@ def test_post_terms_of_service():
assert accepted_term.Term == term
# Update the term to revision 2.
term.Revision = 2
commit()
with db.begin():
term.Revision = 2
# A GET request gives us the new revision to accept.
with client as request: