224 lines
7.6 KiB
Python
224 lines
7.6 KiB
Python
from __future__ import annotations
|
|
|
|
from contextlib import contextmanager
|
|
|
|
import pytest
|
|
|
|
import ariadne.db.database as db_module
|
|
from ariadne.db.database import Database
|
|
|
|
|
|
class DummyResult:
    """Canned query result handed back by the fake connections in this module.

    Carries one pre-set row for ``fetchone`` and one pre-set list of rows for
    ``fetchall``.
    """

    def __init__(self, row=None, rows=None):
        self._row = row
        # Any falsy ``rows`` argument (None, empty list, ...) becomes a fresh list.
        self._rows = rows if rows else []

    def fetchone(self):
        """Return the canned single row (``None`` unless one was supplied)."""
        single = self._row
        return single

    def fetchall(self):
        """Return the canned row list (empty unless rows were supplied)."""
        many = self._rows
        return many
|
|
|
|
|
|
class DummyConn:
    """Fake connection that logs every statement and returns canned results.

    A statement mentioning ``pg_try_advisory_lock`` reports the lock as
    acquired (``(True,)``); every other statement yields an empty result.
    """

    def __init__(self):
        # (query, params) pairs, in call order, for tests to inspect.
        self.executed = []
        self.row_factory = None

    def execute(self, query, params=None):
        """Record ``(query, params)`` and return a ``DummyResult``."""
        self.executed.append((query, params))
        wants_lock = "pg_try_advisory_lock" in query
        return DummyResult(row=(True,)) if wants_lock else DummyResult()
|
|
|
|
|
|
class DummyPool:
    """Fake ``ConnectionPool`` that always hands out one shared ``DummyConn``.

    The constructor takes the keyword arguments the tests in this module
    expect ``Database`` to pass (conninfo, min_size, max_size, kwargs) and
    ignores all of them.
    """

    def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
        self.conn = DummyConn()

    @contextmanager
    def connection(self):
        """Yield the single shared connection; nothing is released on exit."""
        shared = self.conn
        yield shared

    def close(self):
        """Do nothing: the fake pool holds no real resources."""
|
|
|
|
|
|
def test_migrate_runs(monkeypatch) -> None:
    """A default migration takes the advisory lock and issues CREATE TABLE DDL.

    Pinning these specific statements, rather than only checking that
    *something* ran, is what makes the negative checks in
    ``test_migrate_skip_flags`` and
    ``test_migrate_returns_when_advisory_lock_is_unavailable`` meaningful.
    """
    monkeypatch.setattr(db_module, "ConnectionPool", DummyPool)

    db = Database("postgresql://user:pass@localhost/db")

    db.migrate(lock_id=123)

    queries = [query for query, _ in db._pool.conn.executed]
    assert queries
    assert any("pg_try_advisory_lock" in query for query in queries)
    assert any("CREATE TABLE" in query for query in queries)
|
|
|
|
|
|
def test_fetch_and_execute(monkeypatch) -> None:
    """execute, fetchone and fetchall all send their statement to the pool connection."""
    monkeypatch.setattr(db_module, "ConnectionPool", DummyPool)
    database = Database("postgresql://user:pass@localhost/db")

    for run in (database.execute, database.fetchone, database.fetchall):
        run("SELECT 1")
    database.close()

    assert database._pool.conn.executed
|
|
|
|
|
|
def test_database_requires_dsn() -> None:
    """Constructing a Database with an empty DSN raises RuntimeError."""
    blank_dsn = ""
    with pytest.raises(RuntimeError):
        Database(blank_dsn)
|
|
|
|
|
|
def test_migrate_handles_lock(monkeypatch) -> None:
    """migrate() does not propagate LockNotAvailable raised by CREATE TABLE."""

    class CreateLockedConn(DummyConn):
        def execute(self, query, params=None):
            if "CREATE TABLE" not in query:
                return super().execute(query, params)
            raise db_module.psycopg.errors.LockNotAvailable()

    class CreateLockedPool(DummyPool):
        def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
            self.conn = CreateLockedConn()

    monkeypatch.setattr(db_module, "ConnectionPool", CreateLockedPool)
    database = Database("postgresql://user:pass@localhost/db")
    database.migrate(lock_id=123)
|
|
|
|
|
|
def test_migrate_ignores_timeout_errors(monkeypatch) -> None:
    """Errors from the SET lock/statement timeout statements do not abort migrate()."""
    timeout_prefixes = ("SET lock_timeout", "SET statement_timeout")

    class FailingTimeoutConn(DummyConn):
        def execute(self, query, params=None):
            if not query.startswith(timeout_prefixes):
                return super().execute(query, params)
            raise RuntimeError("boom")

    class FailingTimeoutPool(DummyPool):
        def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
            self.conn = FailingTimeoutConn()

    monkeypatch.setattr(db_module, "ConnectionPool", FailingTimeoutPool)
    database = Database("postgresql://user:pass@localhost/db")
    database.migrate(lock_id=123)
|
|
|
|
|
|
def test_migrate_handles_lock_on_alter(monkeypatch) -> None:
    """migrate() does not propagate QueryCanceled raised by ALTER TABLE access_requests."""

    class CanceledAlterConn(DummyConn):
        def execute(self, query, params=None):
            if "ALTER TABLE access_requests" not in query:
                return super().execute(query, params)
            raise db_module.psycopg.errors.QueryCanceled()

    class CanceledAlterPool(DummyPool):
        def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
            self.conn = CanceledAlterConn()

    monkeypatch.setattr(db_module, "ConnectionPool", CanceledAlterPool)
    database = Database("postgresql://user:pass@localhost/db")
    database.migrate(lock_id=123)
|
|
|
|
|
|
def test_fetchone_and_fetchall_return_dicts(monkeypatch) -> None:
    """fetchone/fetchall return the dict rows produced by the connection."""

    class DictRowConn(DummyConn):
        def execute(self, query, params=None):
            # Route on the query text: "fetchone" gets a single row, anything else a list.
            single = "fetchone" in query
            return (
                DummyResult(row={"id": 1})
                if single
                else DummyResult(row=None, rows=[{"id": 1}, {"id": 2}])
            )

    class DictRowPool(DummyPool):
        def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
            self.conn = DictRowConn()

    monkeypatch.setattr(db_module, "ConnectionPool", DictRowPool)
    database = Database("postgresql://user:pass@localhost/db")

    expected_many = [{"id": 1}, {"id": 2}]
    assert database.fetchone("fetchone") == {"id": 1}
    assert database.fetchall("fetchall") == expected_many
|
|
|
|
|
|
def test_database_passes_config_to_pool(monkeypatch) -> None:
    """DatabaseConfig values are forwarded to the ConnectionPool constructor."""
    captured = {}

    class CapturePool(DummyPool):
        def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
            captured.update(
                conninfo=conninfo,
                min_size=min_size,
                max_size=max_size,
                kwargs=kwargs,
            )
            super().__init__(
                conninfo=conninfo,
                min_size=min_size,
                max_size=max_size,
                kwargs=kwargs,
            )

    monkeypatch.setattr(db_module, "ConnectionPool", CapturePool)

    dsn = "postgresql://user:pass@localhost/db"
    settings = db_module.DatabaseConfig(
        pool_min=1,
        pool_max=7,
        connect_timeout_sec=9,
        lock_timeout_sec=11,
        statement_timeout_sec=13,
        idle_in_tx_timeout_sec=15,
        application_name="ariadne-test",
    )
    Database(dsn, settings)

    assert captured["conninfo"] == dsn
    assert captured["min_size"] == 1
    assert captured["max_size"] == 7

    pool_kwargs = captured["kwargs"]
    assert pool_kwargs["connect_timeout"] == 9
    assert pool_kwargs["application_name"] == "ariadne-test"
    assert "lock_timeout=11s" in pool_kwargs["options"]
|
|
|
|
|
|
def test_migrate_returns_when_advisory_lock_is_unavailable(monkeypatch) -> None:
    """migrate() emits no CREATE TABLE when the advisory lock is not granted.

    The lock attempt is recorded and asserted. Without that, the
    "no CREATE TABLE" check would also pass if migrate() never tried to take
    the lock at all.
    """

    class NoLockConn(DummyConn):
        def execute(self, query, params=None):
            if "pg_try_advisory_lock" in query:
                # Log the attempt like DummyConn does, but report the lock as held elsewhere.
                self.executed.append((query, params))
                return DummyResult(row=(False,))
            return super().execute(query, params)

    class NoLockPool(DummyPool):
        def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
            self.conn = NoLockConn()

    monkeypatch.setattr(db_module, "ConnectionPool", NoLockPool)

    db = Database("postgresql://user:pass@localhost/db")

    db.migrate(lock_id=123)

    queries = [query for query, _ in db._pool.conn.executed]
    assert any("pg_try_advisory_lock" in query for query in queries)
    assert not any("CREATE TABLE" in query for query in queries)
|
|
|
|
|
|
def test_migrate_handles_missing_access_requests_table(monkeypatch) -> None:
    """migrate() does not propagate UndefinedTable raised by ALTER TABLE access_requests."""

    class MissingTableConn(DummyConn):
        def execute(self, query, params=None):
            if "ALTER TABLE access_requests" not in query:
                return super().execute(query, params)
            raise db_module.psycopg.errors.UndefinedTable()

    class MissingTablePool(DummyPool):
        def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
            self.conn = MissingTableConn()

    monkeypatch.setattr(db_module, "ConnectionPool", MissingTablePool)
    database = Database("postgresql://user:pass@localhost/db")
    database.migrate(lock_id=123)
|
|
|
|
|
|
def test_migrate_skip_flags(monkeypatch) -> None:
    """With both include flags off, no CREATE TABLE or access_requests ALTER is issued."""
    monkeypatch.setattr(db_module, "ConnectionPool", DummyPool)
    database = Database("postgresql://user:pass@localhost/db")

    database.migrate(
        lock_id=123,
        include_ariadne_tables=False,
        include_access_requests=False,
    )

    statements = [statement for statement, _ in database._pool.conn.executed]
    assert all("CREATE TABLE" not in statement for statement in statements)
    assert all("ALTER TABLE access_requests" not in statement for statement in statements)
|
|
|
|
|
|
def test_unlock_swallows_errors(monkeypatch) -> None:
    """An exception from the advisory-unlock statement does not escape migrate()."""

    class FailingUnlockConn(DummyConn):
        def execute(self, query, params=None):
            if "pg_advisory_unlock" not in query:
                return super().execute(query, params)
            raise RuntimeError("boom")

    class FailingUnlockPool(DummyPool):
        def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
            self.conn = FailingUnlockConn()

    monkeypatch.setattr(db_module, "ConnectionPool", FailingUnlockPool)
    database = Database("postgresql://user:pass@localhost/db")
    database.migrate(lock_id=123)
|