ariadne/ariadne/db/database.py

144 lines
5.2 KiB
Python

from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
import logging
from typing import Any, Iterable
import psycopg
from psycopg_pool import ConnectionPool
from .schema import ARIADNE_ACCESS_REQUEST_ALTER_SQL, ARIADNE_TABLES_SQL
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class DatabaseConfig:
    """Connection-pool and timeout settings for a database client.

    Inputs: pool sizing and timeout values supplied by application settings.
    Outputs: a single immutable config object so database construction remains
    explicit and easy to test.
    """

    # Minimum number of connections the pool keeps open when idle.
    pool_min: int = 0
    # Maximum number of connections the pool will open concurrently.
    pool_max: int = 5
    # Seconds allowed for establishing a new server connection.
    connect_timeout_sec: int = 5
    # Server-side lock_timeout applied to each pooled session (see Database).
    lock_timeout_sec: int = 5
    # Server-side statement_timeout applied to each pooled session.
    statement_timeout_sec: int = 30
    # Server-side idle_in_transaction_session_timeout for pooled sessions.
    idle_in_tx_timeout_sec: int = 10
    # application_name reported to Postgres (visible in pg_stat_activity).
    application_name: str = "ariadne"
class Database:
    """Thin wrapper around a psycopg connection pool for Ariadne storage.

    Inputs: a Postgres DSN plus optional pool and timeout configuration.
    Outputs: helper methods for migrations and common query patterns while
    centralizing timeout, locking, and row-format behavior.
    """

    def __init__(self, dsn: str, config: DatabaseConfig | None = None) -> None:
        """Build the connection pool from *dsn* and *config*.

        Raises:
            RuntimeError: if *dsn* is empty or falsy.
        """
        if not dsn:
            raise RuntimeError("database URL is required")
        cfg = config or DatabaseConfig()
        # Session GUCs are passed through libpq's "options" so every pooled
        # connection starts life with the configured timeouts already set.
        session_options = " ".join(
            (
                f"-c lock_timeout={cfg.lock_timeout_sec}s",
                f"-c statement_timeout={cfg.statement_timeout_sec}s",
                f"-c idle_in_transaction_session_timeout={cfg.idle_in_tx_timeout_sec}s",
            )
        )
        # Kept so _configure_timeouts can re-apply the same values later.
        self._lock_timeout_sec = cfg.lock_timeout_sec
        self._statement_timeout_sec = cfg.statement_timeout_sec
        self._pool = ConnectionPool(
            conninfo=dsn,
            min_size=cfg.pool_min,
            max_size=cfg.pool_max,
            kwargs={
                "connect_timeout": cfg.connect_timeout_sec,
                "application_name": cfg.application_name,
                "options": session_options,
            },
        )
@contextmanager
def connection(self):
with self._pool.connection() as conn:
conn.row_factory = psycopg.rows.dict_row
yield conn
def _configure_timeouts(self, conn: psycopg.Connection) -> None:
try:
conn.execute(f"SET lock_timeout = '{self._lock_timeout_sec}s'")
conn.execute(f"SET statement_timeout = '{self._statement_timeout_sec}s'")
except Exception:
pass
def _try_advisory_lock(self, conn: psycopg.Connection, lock_id: int) -> bool:
row = conn.execute("SELECT pg_try_advisory_lock(%s)", (lock_id,)).fetchone()
if isinstance(row, dict):
return bool(row.get("pg_try_advisory_lock"))
return bool(row and row[0])
def _apply_ariadne_tables(self, conn: psycopg.Connection) -> bool:
for stmt in ARIADNE_TABLES_SQL:
try:
conn.execute(stmt)
except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc:
logger.warning("schema ensure skipped due to lock timeout: %s", exc)
return False
return True
def _apply_access_requests(self, conn: psycopg.Connection) -> bool:
try:
conn.execute(ARIADNE_ACCESS_REQUEST_ALTER_SQL)
except psycopg.errors.UndefinedTable:
logger.info("access_requests table missing; skipping alter")
return True
except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc:
logger.warning("schema ensure skipped due to lock timeout: %s", exc)
return False
return True
def _unlock(self, conn: psycopg.Connection, lock_id: int) -> None:
try:
conn.execute("SELECT pg_advisory_unlock(%s)", (lock_id,))
except Exception:
pass
def migrate(
self,
lock_id: int,
*,
include_ariadne_tables: bool = True,
include_access_requests: bool = True,
) -> None:
with self.connection() as conn:
self._configure_timeouts(conn)
if not self._try_advisory_lock(conn, lock_id):
return
try:
if include_ariadne_tables and not self._apply_ariadne_tables(conn):
return
if include_access_requests and not self._apply_access_requests(conn):
return
finally:
self._unlock(conn, lock_id)
def fetchone(self, query: str, params: Iterable[Any] | None = None) -> dict[str, Any] | None:
with self.connection() as conn:
row = conn.execute(query, params or ()).fetchone()
return dict(row) if row else None
def fetchall(self, query: str, params: Iterable[Any] | None = None) -> list[dict[str, Any]]:
with self.connection() as conn:
rows = conn.execute(query, params or ()).fetchall()
return [dict(row) for row in rows]
def execute(self, query: str, params: Iterable[Any] | None = None) -> None:
with self.connection() as conn:
conn.execute(query, params or ())
    def close(self) -> None:
        """Close the underlying connection pool and its connections."""
        self._pool.close()