from __future__ import annotations from contextlib import contextmanager import pytest import ariadne.db.database as db_module from ariadne.db.database import Database class DummyResult: def __init__(self, row=None, rows=None): self._row = row self._rows = rows or [] def fetchone(self): return self._row def fetchall(self): return self._rows class DummyConn: def __init__(self): self.row_factory = None self.executed = [] def execute(self, query, params=None): self.executed.append((query, params)) if "pg_try_advisory_lock" in query: return DummyResult(row=(True,)) return DummyResult() class DummyPool: def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None): self.conn = DummyConn() @contextmanager def connection(self): yield self.conn def close(self): return None def test_migrate_runs(monkeypatch) -> None: monkeypatch.setattr(db_module, "ConnectionPool", DummyPool) db = Database("postgresql://user:pass@localhost/db") db.migrate(lock_id=123) assert db._pool.conn.executed def test_fetch_and_execute(monkeypatch) -> None: monkeypatch.setattr(db_module, "ConnectionPool", DummyPool) db = Database("postgresql://user:pass@localhost/db") db.execute("SELECT 1") db.fetchone("SELECT 1") db.fetchall("SELECT 1") db.close() assert db._pool.conn.executed def test_database_requires_dsn() -> None: with pytest.raises(RuntimeError): Database("") def test_migrate_handles_lock(monkeypatch) -> None: class LockConn(DummyConn): def execute(self, query, params=None): if "CREATE TABLE" in query: raise db_module.psycopg.errors.LockNotAvailable() return super().execute(query, params) class LockPool(DummyPool): def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None): self.conn = LockConn() monkeypatch.setattr(db_module, "ConnectionPool", LockPool) db = Database("postgresql://user:pass@localhost/db") db.migrate(lock_id=123) def test_migrate_ignores_timeout_errors(monkeypatch) -> None: class TimeoutConn(DummyConn): def execute(self, query, params=None): if query.startswith("SET lock_timeout") or query.startswith("SET statement_timeout"): raise RuntimeError("boom") return super().execute(query, params) class TimeoutPool(DummyPool): def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None): self.conn = TimeoutConn() monkeypatch.setattr(db_module, "ConnectionPool", TimeoutPool) db = Database("postgresql://user:pass@localhost/db") db.migrate(lock_id=123) def test_migrate_handles_lock_on_alter(monkeypatch) -> None: class LockConn(DummyConn): def execute(self, query, params=None): if "ALTER TABLE access_requests" in query: raise db_module.psycopg.errors.QueryCanceled() return super().execute(query, params) class LockPool(DummyPool): def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None): self.conn = LockConn() monkeypatch.setattr(db_module, "ConnectionPool", LockPool) db = Database("postgresql://user:pass@localhost/db") db.migrate(lock_id=123) def test_fetchone_and_fetchall_return_dicts(monkeypatch) -> None: class RowConn(DummyConn): def execute(self, query, params=None): if "fetchone" in query: return DummyResult(row={"id": 1}) return DummyResult(row=None, rows=[{"id": 1}, {"id": 2}]) class RowPool(DummyPool): def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None): self.conn = RowConn() monkeypatch.setattr(db_module, "ConnectionPool", RowPool) db = Database("postgresql://user:pass@localhost/db") assert db.fetchone("fetchone") == {"id": 1} assert db.fetchall("fetchall") == [{"id": 1}, {"id": 2}] def test_database_passes_config_to_pool(monkeypatch) -> None: captured = {} class CapturePool(DummyPool): def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None): captured.update( { "conninfo": conninfo, "min_size": min_size, "max_size": max_size, "kwargs": kwargs, } ) super().__init__(conninfo=conninfo, min_size=min_size, max_size=max_size, kwargs=kwargs) monkeypatch.setattr(db_module, "ConnectionPool", CapturePool) config = db_module.DatabaseConfig( pool_min=1, pool_max=7, connect_timeout_sec=9, lock_timeout_sec=11, statement_timeout_sec=13, idle_in_tx_timeout_sec=15, application_name="ariadne-test", ) Database("postgresql://user:pass@localhost/db", config) assert captured["conninfo"] == "postgresql://user:pass@localhost/db" assert captured["min_size"] == 1 assert captured["max_size"] == 7 assert captured["kwargs"]["connect_timeout"] == 9 assert captured["kwargs"]["application_name"] == "ariadne-test" assert "lock_timeout=11s" in captured["kwargs"]["options"] def test_migrate_returns_when_advisory_lock_is_unavailable(monkeypatch) -> None: class NoLockConn(DummyConn): def execute(self, query, params=None): if "pg_try_advisory_lock" in query: return DummyResult(row=(False,)) return super().execute(query, params) class NoLockPool(DummyPool): def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None): self.conn = NoLockConn() monkeypatch.setattr(db_module, "ConnectionPool", NoLockPool) db = Database("postgresql://user:pass@localhost/db") db.migrate(lock_id=123) assert not any("CREATE TABLE" in query for query, _ in db._pool.conn.executed) def test_migrate_handles_missing_access_requests_table(monkeypatch) -> None: class UndefinedTableConn(DummyConn): def execute(self, query, params=None): if "ALTER TABLE access_requests" in query: raise db_module.psycopg.errors.UndefinedTable() return super().execute(query, params) class UndefinedTablePool(DummyPool): def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None): self.conn = UndefinedTableConn() monkeypatch.setattr(db_module, "ConnectionPool", UndefinedTablePool) db = Database("postgresql://user:pass@localhost/db") db.migrate(lock_id=123) def test_migrate_skip_flags(monkeypatch) -> None: monkeypatch.setattr(db_module, "ConnectionPool", DummyPool) db = Database("postgresql://user:pass@localhost/db") db.migrate(lock_id=123, include_ariadne_tables=False, include_access_requests=False) assert not any("CREATE TABLE" in query for query, _ in db._pool.conn.executed) assert not any("ALTER TABLE access_requests" in query for query, _ in db._pool.conn.executed) def test_unlock_swallows_errors(monkeypatch) -> None: class UnlockConn(DummyConn): def execute(self, query, params=None): if "pg_advisory_unlock" in query: raise RuntimeError("boom") return super().execute(query, params) class UnlockPool(DummyPool): def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None): self.conn = UnlockConn() monkeypatch.setattr(db_module, "ConnectionPool", UnlockPool) db = Database("postgresql://user:pass@localhost/db") db.migrate(lock_id=123)