diff --git a/test/test_db.py b/test/test_db.py index 293fb096..9ece25ea 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -44,38 +44,55 @@ class DBConnection: pass -def make_temp_config(config_file, *replacements): +def make_temp_config(*replacements): """ Generate a temporary config file with a set of replacements. :param *replacements: A variable number of tuple regex replacement pairs :return: A tuple containing (temp directory, temp config file) """ + aurwebdir = aurweb.config.get("options", "aurwebdir") + config_file = os.path.join(aurwebdir, "conf", "config.dev") + config_defaults = os.path.join(aurwebdir, "conf", "config.defaults") + + db_name = aurweb.config.get("database", "name") + db_host = aurweb.config.get_with_fallback("database", "host", "localhost") + db_port = aurweb.config.get_with_fallback("database", "port", "3306") + db_user = aurweb.config.get_with_fallback("database", "user", "root") + db_password = aurweb.config.get_with_fallback("database", "password", None) + + # Replacements to perform before *replacements. + # These serve as generic replacements in config.dev + perform = ( + (r"name = .+", f"name = {db_name}"), + (r"host = .+", f"host = {db_host}"), + (r";port = .+", f";port = {db_port}"), + (r"user = .+", f"user = {db_user}"), + (r"password = .+", f"password = {db_password}"), + ("YOUR_AUR_ROOT", aurwebdir), + ) + tmpdir = tempfile.TemporaryDirectory() tmp = os.path.join(tmpdir.name, "config.tmp") with open(config_file) as f: config = f.read() - for repl in list(replacements): + for repl in tuple(perform + replacements): config = re.sub(repl[0], repl[1], config) with open(tmp, "w") as o: o.write(config) - aurwebdir = aurweb.config.get("options", "aurwebdir") - defaults = os.path.join(aurwebdir, "conf/config.defaults") - with open(defaults) as i: + with open(config_defaults) as i: with open(f"{tmp}.defaults", "w") as o: o.write(i.read()) return tmpdir, tmp -def make_temp_sqlite_config(config_file): - return make_temp_config(config_file, - (r"backend = .*", "backend = sqlite"), +def make_temp_sqlite_config(): + return make_temp_config((r"backend = .*", "backend = sqlite"), (r"name = .*", "name = /tmp/aurweb.sqlite3")) -def make_temp_mysql_config(config_file): - return make_temp_config(config_file, - (r"backend = .*", "backend = mysql"), - (r"name = .*", "name = aurweb")) +def make_temp_mysql_config(): + return make_temp_config((r"backend = .*", "backend = mysql"), + (r"name = .*", "name = aurweb_test")) @pytest.fixture(autouse=True) @@ -91,7 +108,7 @@ def setup_db(): def test_sqlalchemy_sqlite_url(): - tmpctx, tmp = make_temp_sqlite_config("conf/config") + tmpctx, tmp = make_temp_sqlite_config() with tmpctx: with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): aurweb.config.rehash() @@ -100,7 +117,7 @@ def test_sqlalchemy_sqlite_url(): def test_sqlalchemy_mysql_url(): - tmpctx, tmp = make_temp_mysql_config("conf/config") + tmpctx, tmp = make_temp_mysql_config() with tmpctx: with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): aurweb.config.rehash() @@ -109,8 +126,7 @@ def test_sqlalchemy_mysql_url(): def test_sqlalchemy_mysql_port_url(): - tmpctx, tmp = make_temp_config("conf/config", - (r";port = 3306", "port = 3306")) + tmpctx, tmp = make_temp_config((r";port = 3306", "port = 3306")) with tmpctx: with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): @@ -120,8 +136,7 @@ def test_sqlalchemy_mysql_port_url(): def test_sqlalchemy_mysql_socket_url(): - tmpctx, tmp = make_temp_config("conf/config", - (r"[;]?port = 3306", ";port = 3306")) + tmpctx, tmp = make_temp_config() with tmpctx: with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): @@ -131,8 +146,7 @@ def test_sqlalchemy_mysql_socket_url(): def test_sqlalchemy_unknown_backend(): - tmpctx, tmp = make_temp_config("conf/config", - (r"backend = .+", "backend = blah")) + tmpctx, tmp = make_temp_config((r"backend = .+", "backend = blah")) with tmpctx: with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): @@ -149,7 +163,7 @@ def test_db_connects_without_fail(): def test_connection_class_sqlite_without_fail(): - tmpctx, tmp = make_temp_sqlite_config("conf/config") + tmpctx, tmp = make_temp_sqlite_config() with tmpctx: with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): aurweb.config.rehash() @@ -166,8 +180,7 @@ def test_connection_class_sqlite_without_fail(): def test_connection_class_unsupported_backend(): - tmpctx, tmp = make_temp_config("conf/config", - (r"backend = .+", "backend = blah")) + tmpctx, tmp = make_temp_config((r"backend = .+", "backend = blah")) with tmpctx: with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): @@ -179,12 +192,9 @@ def test_connection_class_unsupported_backend(): @mock.patch("MySQLdb.connect", mock.MagicMock(return_value=True)) def test_connection_mysql(): - tmpctx, tmp = make_temp_mysql_config("conf/config") + tmpctx, tmp = make_temp_mysql_config() with tmpctx: - with mock.patch.dict(os.environ, { - "AUR_CONFIG": tmp, - "AUR_CONFIG_DEFAULTS": "conf/config.defaults" - }): + with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): aurweb.config.rehash() db.Connection() aurweb.config.rehash() @@ -199,13 +209,10 @@ def test_connection_sqlite(): @mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection())) @mock.patch.object(sqlite3, "paramstyle", "format") def test_connection_execute_paramstyle_format(): - tmpctx, tmp = make_temp_sqlite_config("conf/config") + tmpctx, tmp = make_temp_sqlite_config() with tmpctx: - with mock.patch.dict(os.environ, { - "AUR_CONFIG": tmp, - "AUR_CONFIG_DEFAULTS": "conf/config.defaults" - }): + with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): aurweb.config.rehash() aurweb.db.kill_engine() @@ -235,13 +242,10 @@ def test_connection_execute_paramstyle_format(): @mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection())) @mock.patch.object(sqlite3, "paramstyle", "qmark") def test_connection_execute_paramstyle_qmark(): - tmpctx, tmp = make_temp_sqlite_config("conf/config") + tmpctx, tmp = make_temp_sqlite_config() with tmpctx: - with mock.patch.dict(os.environ, { - "AUR_CONFIG": tmp, - "AUR_CONFIG_DEFAULTS": "conf/config.defaults" - }): + with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): aurweb.config.rehash() aurweb.db.kill_engine() @@ -260,12 +264,9 @@ def test_connection_execute_paramstyle_qmark(): @mock.patch("sqlite3.connect", mock.MagicMock(return_value=DBConnection())) @mock.patch.object(sqlite3, "paramstyle", "unsupported") def test_connection_execute_paramstyle_unsupported(): - tmpctx, tmp = make_temp_sqlite_config("conf/config") + tmpctx, tmp = make_temp_sqlite_config() with tmpctx: - with mock.patch.dict(os.environ, { - "AUR_CONFIG": tmp, - "AUR_CONFIG_DEFAULTS": "conf/config.defaults" - }): + with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}): aurweb.config.rehash() conn = db.Connection() with pytest.raises(ValueError, match="unsupported paramstyle"):