"""Tests for ariadne.db.database.Database.

The real psycopg connection pool is swapped out for in-memory fakes so these
tests run without a PostgreSQL server.
"""

from __future__ import annotations

from contextlib import contextmanager

import pytest

import ariadne.db.database as db_module
from ariadne.db.database import Database


class DummyResult:
    """Fake query result with canned rows, mimicking a psycopg cursor."""

    def __init__(self, row=None, rows=None):
        self._row = row
        self._rows = rows or []

    def fetchone(self):
        return self._row

    def fetchall(self):
        return self._rows


class DummyConn:
    """Fake connection that records every statement passed to execute()."""

    def __init__(self):
        self.row_factory = None
        self.executed = []

    def execute(self, query, params=None):
        self.executed.append((query, params))
        return DummyResult()


class DummyPool:
    """Fake connection pool exposing the interface Database relies on."""

    def __init__(self, conninfo=None, max_size=None):
        self.conn = DummyConn()

    @contextmanager
    def connection(self):
        yield self.conn

    def close(self):
        return None


def test_ensure_schema_runs(monkeypatch) -> None:
    """ensure_schema() issues its statements through the pooled connection."""
    monkeypatch.setattr(db_module, "ConnectionPool", DummyPool)
    db = Database("postgresql://user:pass@localhost/db")
    db.ensure_schema()
    assert db._pool.conn.executed


def test_fetch_and_execute(monkeypatch) -> None:
    """execute/fetchone/fetchall/close all go through the pooled connection."""
    monkeypatch.setattr(db_module, "ConnectionPool", DummyPool)
    db = Database("postgresql://user:pass@localhost/db")
    db.execute("SELECT 1")
    db.fetchone("SELECT 1")
    db.fetchall("SELECT 1")
    db.close()
    assert db._pool.conn.executed


def test_database_requires_dsn() -> None:
    """An empty DSN is rejected with RuntimeError."""
    with pytest.raises(RuntimeError):
        Database("")


def test_ensure_schema_handles_lock(monkeypatch) -> None:
    """ensure_schema() survives LockNotAvailable raised by CREATE TABLE."""

    class LockConn(DummyConn):
        def execute(self, query, params=None):
            if "CREATE TABLE" in query:
                raise db_module.psycopg.errors.LockNotAvailable()
            return super().execute(query, params)

    class LockPool(DummyPool):
        def __init__(self, conninfo=None, max_size=None):
            self.conn = LockConn()

    monkeypatch.setattr(db_module, "ConnectionPool", LockPool)
    db = Database("postgresql://user:pass@localhost/db")
    db.ensure_schema()


def test_ensure_schema_ignores_timeout_errors(monkeypatch) -> None:
    """Failures while setting lock/statement timeouts do not abort ensure_schema()."""

    class TimeoutConn(DummyConn):
        def execute(self, query, params=None):
            if query.startswith("SET lock_timeout") or query.startswith("SET statement_timeout"):
                raise RuntimeError("boom")
            return super().execute(query, params)

    class TimeoutPool(DummyPool):
        def __init__(self, conninfo=None, max_size=None):
            self.conn = TimeoutConn()

    monkeypatch.setattr(db_module, "ConnectionPool", TimeoutPool)
    db = Database("postgresql://user:pass@localhost/db")
    db.ensure_schema()


def test_ensure_schema_handles_lock_on_alter(monkeypatch) -> None:
    """ensure_schema() survives QueryCanceled raised by ALTER TABLE."""

    class LockConn(DummyConn):
        def execute(self, query, params=None):
            if query.startswith("ALTER TABLE"):
                raise db_module.psycopg.errors.QueryCanceled()
            return super().execute(query, params)

    class LockPool(DummyPool):
        def __init__(self, conninfo=None, max_size=None):
            self.conn = LockConn()

    monkeypatch.setattr(db_module, "ConnectionPool", LockPool)
    db = Database("postgresql://user:pass@localhost/db")
    db.ensure_schema()


def test_fetchone_and_fetchall_return_dicts(monkeypatch) -> None:
    """fetchone/fetchall pass through the dict rows produced by the connection."""

    class RowConn(DummyConn):
        def execute(self, query, params=None):
            if "fetchone" in query:
                return DummyResult(row={"id": 1})
            return DummyResult(row=None, rows=[{"id": 1}, {"id": 2}])

    class RowPool(DummyPool):
        def __init__(self, conninfo=None, max_size=None):
            self.conn = RowConn()

    monkeypatch.setattr(db_module, "ConnectionPool", RowPool)
    db = Database("postgresql://user:pass@localhost/db")
    assert db.fetchone("fetchone") == {"id": 1}
    assert db.fetchall("fetchall") == [{"id": 1}, {"id": 2}]