From d18cfad63eeaab54fd21d386a769d85067153f8f Mon Sep 17 00:00:00 2001 From: Kevin Morris Date: Thu, 10 Jun 2021 14:18:39 -0700 Subject: [PATCH] use djangos method of wiping sqlite3 tables Django uses a reference graph to determine the order in table deletions that occur. Do the same here. This commit also adds in the `REGEXP` sqlite function, exactly how Django uses it in its reference graphing. Signed-off-by: Kevin Morris --- aurweb/db.py | 24 +++++++++++++++++++++++- aurweb/testing/__init__.py | 36 ++++++++++++++++++++++++++++++++++++ test/test_db.py | 3 +++ 3 files changed, 62 insertions(+), 1 deletion(-) diff --git a/aurweb/db.py b/aurweb/db.py index 9837c746..04c8653a 100644 --- a/aurweb/db.py +++ b/aurweb/db.py @@ -1,4 +1,8 @@ +import functools import math +import re + +from sqlalchemy import event import aurweb.config import aurweb.util @@ -129,13 +133,31 @@ def get_engine(echo: bool = False): if engine is None: connect_args = dict() - if aurweb.config.get("database", "backend") == "sqlite": + + db_backend = aurweb.config.get("database", "backend") + if db_backend == "sqlite": # check_same_thread is for a SQLite technicality # https://fastapi.tiangolo.com/tutorial/sql-databases/#note connect_args["check_same_thread"] = False + engine = create_engine(get_sqlalchemy_url(), connect_args=connect_args, echo=echo) + + if db_backend == "sqlite": + # For SQLite, we need to add some custom functions as + # they are used in the reference graph method. + def regexp(regex, item): + return bool(re.search(regex, str(item))) + + @event.listens_for(engine, "begin") + def do_begin(conn): + create_deterministic_function = functools.partial( + conn.connection.create_function, + deterministic=True + ) + create_deterministic_function("REGEXP", 2, regexp) + Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) session = Session() diff --git a/aurweb/testing/__init__.py b/aurweb/testing/__init__.py index 02c21a4c..90d46720 100644 --- a/aurweb/testing/__init__.py +++ b/aurweb/testing/__init__.py @@ -1,6 +1,28 @@ +from itertools import chain + import aurweb.db +def references_graph(table): + """ Taken from Django's sqlite3/operations.py. """ + query = """ + WITH tables AS ( + SELECT :table name + UNION + SELECT sqlite_master.name + FROM sqlite_master + JOIN tables ON (sql REGEXP :regexp_1 || tables.name || :regexp_2) + ) SELECT name FROM tables; + """ + params = { + "table": table, + "regexp_1": r'(?i)\s+references\s+("|\')?', + "regexp_2": r'("|\')?\s*\(', + } + cursor = aurweb.db.session.execute(query, params=params) + return [row[0] for row in cursor.fetchall()] + + def setup_test_db(*args): """ This function is to be used to setup a test database before using it. It takes a variable number of table strings, and for @@ -25,8 +47,22 @@ def setup_test_db(*args): aurweb.db.get_engine() tables = list(args) + + db_backend = aurweb.config.get("database", "backend") + + if db_backend != "sqlite": + aurweb.db.session.execute("SET FOREIGN_KEY_CHECKS = 0") + else: + # We're using sqlite, setup tables to be deleted without violating + # foreign key constraints by graphing references. + tables = set(chain.from_iterable( + references_graph(table) for table in tables)) + for table in tables: aurweb.db.session.execute(f"DELETE FROM {table}") + if db_backend != "sqlite": + aurweb.db.session.execute("SET FOREIGN_KEY_CHECKS = 1") + # Expunge all objects from SQLAlchemy's IdentityMap. aurweb.db.session.expunge_all() diff --git a/test/test_db.py b/test/test_db.py index 3911134f..9298c53d 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -200,6 +200,9 @@ def test_connection_execute_paramstyle_format(): aurweb.db.kill_engine() aurweb.initdb.run(Args()) + # Test SQLite route of clearing tables. + setup_test_db("Users", "Bans") + conn = db.Connection() # First, test ? to %s format replacement.