diff --git a/aurweb/db.py b/aurweb/db.py index 04c8653a..c0147720 100644 --- a/aurweb/db.py +++ b/aurweb/db.py @@ -59,24 +59,35 @@ def query(model, *args, **kwargs): return session.query(model).filter(*args, **kwargs) -def create(model, *args, **kwargs): +def create(model, autocommit: bool = True, *args, **kwargs): instance = model(*args, **kwargs) - session.add(instance) - session.commit() + add(instance) + if autocommit is True: + commit() return instance -def delete(model, *args, **kwargs): +def delete(model, *args, autocommit: bool = True, **kwargs): instance = session.query(model).filter(*args, **kwargs) for record in instance: session.delete(record) - session.commit() + if autocommit is True: + commit() def rollback(): session.rollback() +def add(model): + session.add(model) + return model + + +def commit(): + session.commit() + + def get_sqlalchemy_url(): """ Build an SQLAlchemy for use with create_engine based on the aurweb configuration. diff --git a/test/test_db.py b/test/test_db.py index 9298c53d..d7a91813 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -273,6 +273,31 @@ def test_create_delete(): record = db.query(AccountType, AccountType.AccountType == "test").first() assert record is None + # Create and delete a record with autocommit=False. + db.create(AccountType, AccountType="test", autocommit=False) + db.commit() + db.delete(AccountType, AccountType.AccountType == "test", autocommit=False) + db.commit() + record = db.query(AccountType, AccountType.AccountType == "test").first() + assert record is None + + +def test_add_commit(): + # Use db.add and db.commit to add a temporary record. + account_type = AccountType(AccountType="test") + db.add(account_type) + db.commit() + + # Assert it got created in the DB. + assert bool(account_type.ID) + + # Query the DB for it and compare the record with our object. + record = db.query(AccountType, AccountType.AccountType == "test").first() + assert record == account_type + + # Remove the record. + db.delete(AccountType, AccountType.ID == account_type.ID) + def test_connection_executor_mysql_paramstyle(): executor = db.ConnectionExecutor(None, backend="mysql")