diff --git a/tests/unit/test_migrate.py b/tests/unit/test_migrate.py new file mode 100644 index 0000000..9a2881c --- /dev/null +++ b/tests/unit/test_migrate.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import types + +import pytest + +from ariadne import migrate as migrate_module + + +def _settings(**overrides): + values = { + "ariadne_run_migrations": True, + "ariadne_database_url": "postgresql://ariadne", + "portal_database_url": "", + "ariadne_db_pool_min": 1, + "ariadne_db_pool_max": 4, + "ariadne_db_connect_timeout_sec": 5, + "ariadne_db_lock_timeout_sec": 7, + "ariadne_db_statement_timeout_sec": 11, + "ariadne_db_idle_in_tx_timeout_sec": 13, + } + values.update(overrides) + return types.SimpleNamespace(**values) + + +def test_build_db_uses_runtime_settings(monkeypatch) -> None: + created: list[tuple[str, migrate_module.DatabaseConfig]] = [] + + class FakeDatabase: + def __init__(self, dsn, config): + created.append((dsn, config)) + + monkeypatch.setattr(migrate_module, "settings", _settings()) + monkeypatch.setattr(migrate_module, "Database", FakeDatabase) + + db = migrate_module._build_db("postgresql://app", "app_migrate") + + assert isinstance(db, FakeDatabase) + assert created[0][0] == "postgresql://app" + assert created[0][1].pool_min == 1 + assert created[0][1].pool_max == 4 + assert created[0][1].application_name == "app_migrate" + + +def test_main_skips_when_migrations_disabled(monkeypatch) -> None: + monkeypatch.setattr(migrate_module, "settings", _settings(ariadne_run_migrations=False)) + monkeypatch.setattr(migrate_module, "_build_db", lambda *_args: (_ for _ in ()).throw(AssertionError("should not build db"))) + + migrate_module.main() + + +def test_main_runs_ariadne_and_portal_migrations(monkeypatch) -> None: + created: list[FakeDatabase] = [] + + class FakeDatabase: + def __init__(self, dsn: str, application_name: str): + self.dsn = dsn + self.application_name = application_name + self.migrations: list[tuple[int, bool, bool]] = [] + self.closed = False + created.append(self) + + def migrate(self, lock_id: int, *, include_ariadne_tables: bool, include_access_requests: bool) -> None: + self.migrations.append((lock_id, include_ariadne_tables, include_access_requests)) + + def close(self) -> None: + self.closed = True + + monkeypatch.setattr( + migrate_module, + "settings", + _settings(portal_database_url="postgresql://portal"), + ) + monkeypatch.setattr(migrate_module, "_build_db", lambda dsn, application_name: FakeDatabase(dsn, application_name)) + + migrate_module.main() + + assert [(db.dsn, db.application_name, db.closed) for db in created] == [ + ("postgresql://ariadne", "ariadne_migrate", True), + ("postgresql://portal", "ariadne_portal_migrate", True), + ] + assert created[0].migrations == [(migrate_module.ARIADNE_MIGRATION_LOCK_ID, True, False)] + assert created[1].migrations == [(migrate_module.PORTAL_MIGRATION_LOCK_ID, False, True)] + + +def test_main_closes_ariadne_db_when_migration_raises(monkeypatch) -> None: + class FakeDatabase: + closed = False + + def migrate(self, *_args, **_kwargs) -> None: + raise RuntimeError("migration failed") + + def close(self) -> None: + self.closed = True + + db = FakeDatabase() + monkeypatch.setattr(migrate_module, "settings", _settings()) + monkeypatch.setattr(migrate_module, "_build_db", lambda *_args: db) + + with pytest.raises(RuntimeError, match="migration failed"): + migrate_module.main() + + assert db.closed is True