144 lines
5.2 KiB
Python
144 lines
5.2 KiB
Python
from __future__ import annotations
|
|
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
import logging
|
|
from typing import Any, Iterable
|
|
|
|
import psycopg
|
|
from psycopg_pool import ConnectionPool
|
|
|
|
from .schema import ARIADNE_ACCESS_REQUEST_ALTER_SQL, ARIADNE_TABLES_SQL
|
|
|
|
# Module-scoped logger, named after this module, used by Database for
# migration/lock diagnostics.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass(frozen=True)
class DatabaseConfig:
    """Immutable bundle of pool-sizing and timeout settings for ``Database``.

    Callers build one of these from application settings and pass it to
    ``Database``, which keeps database construction explicit and easy to test.
    Every ``*_sec`` field is a whole number of seconds.
    """

    # Lower / upper bound on pooled connections (ConnectionPool min_size / max_size).
    pool_min: int = 0
    pool_max: int = 5
    # Seconds allowed for establishing a new server connection.
    connect_timeout_sec: int = 5
    # Per-session server limits, in seconds.
    lock_timeout_sec: int = 5
    statement_timeout_sec: int = 30
    idle_in_tx_timeout_sec: int = 10
    # Name the sessions report to Postgres as application_name.
    application_name: str = "ariadne"
|
|
|
|
|
|
class Database:
    """Thin wrapper around a psycopg connection pool for Ariadne storage.

    Inputs: a Postgres DSN plus optional pool and timeout configuration.

    Outputs: helper methods for migrations and common query patterns while
    centralizing timeout, locking, and row-format behavior.
    """

    def __init__(self, dsn: str, config: DatabaseConfig | None = None) -> None:
        """Create the underlying connection pool.

        Args:
            dsn: Postgres connection string; must be non-empty.
            config: pool sizing and timeout settings; ``DatabaseConfig()``
                defaults are used when omitted.

        Raises:
            RuntimeError: if ``dsn`` is empty.
        """
        if not dsn:
            raise RuntimeError("database URL is required")
        config = config or DatabaseConfig()
        # Session timeouts are passed as libpq startup options, so every
        # connection the pool opens starts with them already applied.
        options = (
            f"-c lock_timeout={config.lock_timeout_sec}s "
            f"-c statement_timeout={config.statement_timeout_sec}s "
            f"-c idle_in_transaction_session_timeout={config.idle_in_tx_timeout_sec}s"
        )
        self._pool = ConnectionPool(
            conninfo=dsn,
            min_size=config.pool_min,
            max_size=config.pool_max,
            kwargs={
                "connect_timeout": config.connect_timeout_sec,
                "application_name": config.application_name,
                "options": options,
            },
        )
        self._lock_timeout_sec = config.lock_timeout_sec
        self._statement_timeout_sec = config.statement_timeout_sec

    @contextmanager
    def connection(self):
        """Borrow a pooled connection that returns rows as dicts.

        The pool's context manager commits the connection's transaction on a
        clean exit, rolls it back on an exception, and then returns the
        connection to the pool.
        """
        with self._pool.connection() as conn:
            conn.row_factory = psycopg.rows.dict_row
            yield conn

    def _configure_timeouts(self, conn: psycopg.Connection) -> None:
        """Re-apply lock/statement timeouts on ``conn`` (best effort).

        The same values are already sent as startup options; this presumably
        guards against an environment that ignores them (TODO confirm intent).
        """
        try:
            # set_config(..., false) is the session-level equivalent of SET but,
            # unlike SET, accepts bound parameters, so no SQL is string-built.
            conn.execute(
                "SELECT set_config('lock_timeout', %s, false), "
                "set_config('statement_timeout', %s, false)",
                (f"{self._lock_timeout_sec}s", f"{self._statement_timeout_sec}s"),
            )
        except psycopg.Error as exc:
            logger.warning("could not apply migration timeouts: %s", exc)
            # The failed statement aborted the implicit transaction; clear it so
            # the advisory-lock query that follows can still run.
            conn.rollback()

    def _try_advisory_lock(self, conn: psycopg.Connection, lock_id: int) -> bool:
        """Try to take session-level advisory lock ``lock_id`` without waiting.

        Returns:
            True if the lock was acquired, False if another session holds it.
        """
        row = conn.execute("SELECT pg_try_advisory_lock(%s)", (lock_id,)).fetchone()
        # connection() installs dict_row, but tolerate tuple rows as well.
        if isinstance(row, dict):
            return bool(row.get("pg_try_advisory_lock"))
        return bool(row and row[0])

    def _apply_ariadne_tables(self, conn: psycopg.Connection) -> bool:
        """Run every statement in ``ARIADNE_TABLES_SQL`` as one savepoint.

        Returns:
            True on success; False if a lock or statement timeout interrupted
            the step. On failure none of the statements in this step persist,
            but the surrounding transaction stays usable.
        """
        try:
            # Inside an already-open transaction, conn.transaction() is a
            # savepoint: an error rolls back only this step instead of
            # aborting the whole transaction.
            with conn.transaction():
                for stmt in ARIADNE_TABLES_SQL:
                    conn.execute(stmt)
        except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc:
            logger.warning("schema ensure skipped due to lock timeout: %s", exc)
            return False
        return True

    def _apply_access_requests(self, conn: psycopg.Connection) -> bool:
        """Apply ``ARIADNE_ACCESS_REQUEST_ALTER_SQL`` inside its own savepoint.

        Returns:
            True on success or when the target table does not exist (treated
            as nothing to do); False on a lock or statement timeout.
        """
        try:
            # Savepoint so a tolerated UndefinedTable does not abort the outer
            # transaction and silently discard the table DDL run before it.
            with conn.transaction():
                conn.execute(ARIADNE_ACCESS_REQUEST_ALTER_SQL)
        except psycopg.errors.UndefinedTable:
            logger.info("access_requests table missing; skipping alter")
            return True
        except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc:
            logger.warning("schema ensure skipped due to lock timeout: %s", exc)
            return False
        return True

    def _unlock(self, conn: psycopg.Connection, lock_id: int) -> None:
        """Release advisory lock ``lock_id``; failures are logged, not raised.

        Session-level advisory locks outlive transactions, and ``conn`` goes
        back to the pool afterwards, so a lock not released here would stay
        held for as long as the pooled connection lives.
        """
        try:
            if conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR:
                # An aborted transaction rejects every statement, including
                # the unlock; roll back first so the lock cannot leak.
                conn.rollback()
            conn.execute("SELECT pg_advisory_unlock(%s)", (lock_id,))
        except Exception as exc:  # noqa: BLE001 - runs in a finally; must not mask the original error
            logger.warning("failed to release advisory lock %s: %s", lock_id, exc)

    def migrate(
        self,
        lock_id: int,
        *,
        include_ariadne_tables: bool = True,
        include_access_requests: bool = True,
    ) -> None:
        """Apply Ariadne schema changes while holding advisory lock ``lock_id``.

        If another session holds the lock, nothing is done. Lock/statement
        timeouts skip the remaining steps (logged) rather than raising. Steps
        that completed are committed before the lock is released.

        Args:
            lock_id: advisory-lock key shared by all migrators.
            include_ariadne_tables: run the ``ARIADNE_TABLES_SQL`` step.
            include_access_requests: run the access-requests ALTER step; it
                is skipped if the tables step timed out.
        """
        with self.connection() as conn:
            self._configure_timeouts(conn)
            if not self._try_advisory_lock(conn, lock_id):
                logger.info("schema migration lock %s is held elsewhere; skipping", lock_id)
                return
            try:
                tables_ok = not include_ariadne_tables or self._apply_ariadne_tables(conn)
                if tables_ok and include_access_requests:
                    self._apply_access_requests(conn)
                # Commit before unlocking so the next lock holder sees the
                # finished schema rather than racing our uncommitted DDL.
                conn.commit()
            finally:
                self._unlock(conn, lock_id)

    def fetchone(self, query: str, params: Iterable[Any] | None = None) -> dict[str, Any] | None:
        """Run ``query`` and return its first row as a dict, or None if empty."""
        with self.connection() as conn:
            row = conn.execute(query, params or ()).fetchone()
            return dict(row) if row else None

    def fetchall(self, query: str, params: Iterable[Any] | None = None) -> list[dict[str, Any]]:
        """Run ``query`` and return every row as a dict."""
        with self.connection() as conn:
            rows = conn.execute(query, params or ()).fetchall()
            return [dict(row) for row in rows]

    def execute(self, query: str, params: Iterable[Any] | None = None) -> None:
        """Run ``query``; it is committed when the pooled connection is returned."""
        with self.connection() as conn:
            conn.execute(query, params or ())

    def close(self) -> None:
        """Close the pool and all of its connections."""
        self._pool.close()
|