diff --git a/ariadne/app.py b/ariadne/app.py index 5176e40..2413544 100644 --- a/ariadne/app.py +++ b/ariadne/app.py @@ -60,8 +60,26 @@ class PasswordResetRequest: updated_attr: str error_hint: str -portal_db = Database(settings.portal_database_url) -ariadne_db = Database(settings.ariadne_database_url) +portal_db = Database( + settings.portal_database_url, + pool_min=settings.ariadne_db_pool_min, + pool_max=settings.ariadne_db_pool_max, + connect_timeout_sec=settings.ariadne_db_connect_timeout_sec, + lock_timeout_sec=settings.ariadne_db_lock_timeout_sec, + statement_timeout_sec=settings.ariadne_db_statement_timeout_sec, + idle_in_tx_timeout_sec=settings.ariadne_db_idle_in_tx_timeout_sec, + application_name="ariadne_portal", +) +ariadne_db = Database( + settings.ariadne_database_url, + pool_min=settings.ariadne_db_pool_min, + pool_max=settings.ariadne_db_pool_max, + connect_timeout_sec=settings.ariadne_db_connect_timeout_sec, + lock_timeout_sec=settings.ariadne_db_lock_timeout_sec, + statement_timeout_sec=settings.ariadne_db_statement_timeout_sec, + idle_in_tx_timeout_sec=settings.ariadne_db_idle_in_tx_timeout_sec, + application_name="ariadne", +) storage = Storage(ariadne_db, portal_db) provisioning = ProvisioningManager(portal_db, storage) scheduler = CronScheduler(storage, settings.schedule_tick_sec) @@ -229,8 +247,6 @@ def _run_password_reset(request: PasswordResetRequest) -> JSONResponse: @app.on_event("startup") def _startup() -> None: - ariadne_db.ensure_schema(include_access_requests=False) - portal_db.ensure_schema(include_ariadne_tables=False) provisioning.start() scheduler.add_task("schedule.mailu_sync", settings.mailu_sync_cron, lambda: mailu.sync("ariadne_schedule")) diff --git a/ariadne/db/database.py b/ariadne/db/database.py index f16b3c4..581693b 100644 --- a/ariadne/db/database.py +++ b/ariadne/db/database.py @@ -7,16 +7,43 @@ from typing import Any, Iterable import psycopg from psycopg_pool import ConnectionPool -from .schema import ARIADNE_ACCESS_REQUEST_ALTER, ARIADNE_TABLES_SQL +from .schema import ARIADNE_ACCESS_REQUEST_ALTER_SQL, ARIADNE_TABLES_SQL logger = logging.getLogger(__name__) class Database: - def __init__(self, dsn: str, pool_size: int = 5) -> None: + def __init__( + self, + dsn: str, + *, + pool_min: int = 0, + pool_max: int = 5, + connect_timeout_sec: int = 5, + lock_timeout_sec: int = 5, + statement_timeout_sec: int = 30, + idle_in_tx_timeout_sec: int = 10, + application_name: str = "ariadne", + ) -> None: if not dsn: raise RuntimeError("database URL is required") - self._pool = ConnectionPool(conninfo=dsn, max_size=pool_size) + options = ( + f"-c lock_timeout={lock_timeout_sec}s " + f"-c statement_timeout={statement_timeout_sec}s " + f"-c idle_in_transaction_session_timeout={idle_in_tx_timeout_sec}s" + ) + self._pool = ConnectionPool( + conninfo=dsn, + min_size=pool_min, + max_size=pool_max, + kwargs={ + "connect_timeout": connect_timeout_sec, + "application_name": application_name, + "options": options, + }, + ) + self._lock_timeout_sec = lock_timeout_sec + self._statement_timeout_sec = statement_timeout_sec @contextmanager def connection(self): @@ -24,37 +51,44 @@ class Database: conn.row_factory = psycopg.rows.dict_row yield conn - def ensure_schema( + def migrate( self, - lock_timeout_sec: int = 5, - statement_timeout_sec: int = 30, + lock_id: int, *, include_ariadne_tables: bool = True, include_access_requests: bool = True, ) -> None: with self.connection() as conn: try: - conn.execute(f"SET lock_timeout = '{lock_timeout_sec}s'") - conn.execute(f"SET statement_timeout = '{statement_timeout_sec}s'") + conn.execute(f"SET lock_timeout = '{self._lock_timeout_sec}s'") + conn.execute(f"SET statement_timeout = '{self._statement_timeout_sec}s'") except Exception: pass - if include_ariadne_tables: - for stmt in ARIADNE_TABLES_SQL: + row = conn.execute("SELECT pg_try_advisory_lock(%s)", (lock_id,)).fetchone() + locked = bool(row and row[0]) + if not locked: + return + try: + if include_ariadne_tables: + for stmt in ARIADNE_TABLES_SQL: + try: + conn.execute(stmt) + except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc: + logger.warning("schema ensure skipped due to lock timeout: %s", exc) + return + if include_access_requests: try: - conn.execute(stmt) - except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc: - logger.warning("schema ensure skipped due to lock timeout: %s", exc) - return - if include_access_requests: - for stmt in ARIADNE_ACCESS_REQUEST_ALTER: - try: - conn.execute(stmt) + conn.execute(ARIADNE_ACCESS_REQUEST_ALTER_SQL) except psycopg.errors.UndefinedTable: logger.info("access_requests table missing; skipping alter") - continue except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc: logger.warning("schema ensure skipped due to lock timeout: %s", exc) return + finally: + try: + conn.execute("SELECT pg_advisory_unlock(%s)", (lock_id,)) + except Exception: + pass def fetchone(self, query: str, params: Iterable[Any] | None = None) -> dict[str, Any] | None: with self.connection() as conn: diff --git a/ariadne/db/schema.py b/ariadne/db/schema.py index 14ba6f5..e3d4c15 100644 --- a/ariadne/db/schema.py +++ b/ariadne/db/schema.py @@ -44,11 +44,12 @@ ARIADNE_TABLES_SQL = [ """, ] -ARIADNE_ACCESS_REQUEST_ALTER = [ - "ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS welcome_email_sent_at TIMESTAMPTZ", - "ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS approval_flags TEXT[]", - "ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS approval_note TEXT", - "ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS denial_note TEXT", - "ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS first_name TEXT", - "ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS last_name TEXT", -] +ARIADNE_ACCESS_REQUEST_ALTER_SQL = """ +ALTER TABLE access_requests + ADD COLUMN IF NOT EXISTS welcome_email_sent_at TIMESTAMPTZ, + ADD COLUMN IF NOT EXISTS approval_flags TEXT[], + ADD COLUMN IF NOT EXISTS approval_note TEXT, + ADD COLUMN IF NOT EXISTS denial_note TEXT, + ADD COLUMN IF NOT EXISTS first_name TEXT, + ADD COLUMN IF NOT EXISTS last_name TEXT +""" diff --git a/ariadne/migrate.py b/ariadne/migrate.py new file mode 100644 index 0000000..e3f4435 --- /dev/null +++ b/ariadne/migrate.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from .db.database import Database +from .settings import settings + + +PORTAL_MIGRATION_LOCK_ID = 982731 +ARIADNE_MIGRATION_LOCK_ID = 982732 + + +def _build_db(dsn: str, application_name: str) -> Database: + return Database( + dsn, + pool_min=settings.ariadne_db_pool_min, + pool_max=settings.ariadne_db_pool_max, + connect_timeout_sec=settings.ariadne_db_connect_timeout_sec, + lock_timeout_sec=settings.ariadne_db_lock_timeout_sec, + statement_timeout_sec=settings.ariadne_db_statement_timeout_sec, + idle_in_tx_timeout_sec=settings.ariadne_db_idle_in_tx_timeout_sec, + application_name=application_name, + ) + + +def main() -> None: + if not settings.ariadne_run_migrations: + return + + ariadne_db = _build_db(settings.ariadne_database_url, "ariadne_migrate") + try: + ariadne_db.migrate( + ARIADNE_MIGRATION_LOCK_ID, + include_ariadne_tables=True, + include_access_requests=False, + ) + finally: + ariadne_db.close() + + if settings.portal_database_url: + portal_db = _build_db(settings.portal_database_url, "ariadne_portal_migrate") + try: + portal_db.migrate( + PORTAL_MIGRATION_LOCK_ID, + include_ariadne_tables=False, + include_access_requests=True, + ) + finally: + portal_db.close() + + +if __name__ == "__main__": + main() diff --git a/ariadne/settings.py b/ariadne/settings.py index f65f676..3af91e2 100644 --- a/ariadne/settings.py +++ b/ariadne/settings.py @@ -38,6 +38,13 @@ class Settings: portal_database_url: str portal_public_base_url: str log_level: str + ariadne_db_pool_min: int + ariadne_db_pool_max: int + ariadne_db_connect_timeout_sec: int + ariadne_db_lock_timeout_sec: int + ariadne_db_statement_timeout_sec: int + ariadne_db_idle_in_tx_timeout_sec: int + ariadne_run_migrations: bool keycloak_url: str keycloak_realm: str @@ -487,6 +494,13 @@ class Settings: portal_database_url=portal_db, portal_public_base_url=_env("PORTAL_PUBLIC_BASE_URL", "https://bstein.dev").rstrip("/"), log_level=_env("ARIADNE_LOG_LEVEL", "INFO"), + ariadne_db_pool_min=_env_int("ARIADNE_DB_POOL_MIN", 0), + ariadne_db_pool_max=_env_int("ARIADNE_DB_POOL_MAX", 5), + ariadne_db_connect_timeout_sec=_env_int("ARIADNE_DB_CONNECT_TIMEOUT_SEC", 5), + ariadne_db_lock_timeout_sec=_env_int("ARIADNE_DB_LOCK_TIMEOUT_SEC", 5), + ariadne_db_statement_timeout_sec=_env_int("ARIADNE_DB_STATEMENT_TIMEOUT_SEC", 30), + ariadne_db_idle_in_tx_timeout_sec=_env_int("ARIADNE_DB_IDLE_IN_TX_TIMEOUT_SEC", 10), + ariadne_run_migrations=_env_bool("ARIADNE_RUN_MIGRATIONS", "false"), provision_poll_interval_sec=_env_float("ARIADNE_PROVISION_POLL_INTERVAL_SEC", 5.0), provision_retry_cooldown_sec=_env_float("ARIADNE_PROVISION_RETRY_COOLDOWN_SEC", 30.0), schedule_tick_sec=_env_float("ARIADNE_SCHEDULE_TICK_SEC", 5.0), diff --git a/tests/test_app.py b/tests/test_app.py index 139158b..3af8e4a 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -15,8 +15,6 @@ import ariadne.app as app_module def _client(monkeypatch, ctx: AuthContext) -> TestClient: monkeypatch.setattr(app_module.authenticator, "authenticate", lambda token: ctx) - monkeypatch.setattr(app_module.portal_db, "ensure_schema", lambda *args, **kwargs: None) - monkeypatch.setattr(app_module.ariadne_db, "ensure_schema", lambda *args, **kwargs: None) monkeypatch.setattr(app_module.provisioning, "start", lambda: None) monkeypatch.setattr(app_module.scheduler, "start", lambda: None) monkeypatch.setattr(app_module.provisioning, "stop", lambda: None) @@ -37,8 +35,6 @@ def test_health_ok(monkeypatch) -> None: def test_startup_and_shutdown(monkeypatch) -> None: - monkeypatch.setattr(app_module.portal_db, "ensure_schema", lambda *args, **kwargs: None) - monkeypatch.setattr(app_module.ariadne_db, "ensure_schema", lambda *args, **kwargs: None) monkeypatch.setattr(app_module.provisioning, "start", lambda: None) monkeypatch.setattr(app_module.scheduler, "add_task", lambda *args, **kwargs: None) monkeypatch.setattr(app_module.scheduler, "start", lambda: None) diff --git a/tests/test_database.py b/tests/test_database.py index 39225ac..fd3dc08 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -27,11 +27,13 @@ class DummyConn: def execute(self, query, params=None): self.executed.append((query, params)) + if "pg_try_advisory_lock" in query: + return DummyResult(row=(True,)) return DummyResult() class DummyPool: - def __init__(self, conninfo=None, max_size=None): + def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None): self.conn = DummyConn() @contextmanager @@ -42,10 +44,10 @@ class DummyPool: return None -def test_ensure_schema_runs(monkeypatch) -> None: +def test_migrate_runs(monkeypatch) -> None: monkeypatch.setattr(db_module, "ConnectionPool", DummyPool) db = Database("postgresql://user:pass@localhost/db") - db.ensure_schema() + db.migrate(lock_id=123) assert db._pool.conn.executed @@ -64,7 +66,7 @@ def test_database_requires_dsn() -> None: Database("") -def test_ensure_schema_handles_lock(monkeypatch) -> None: +def test_migrate_handles_lock(monkeypatch) -> None: class LockConn(DummyConn): def execute(self, query, params=None): if "CREATE TABLE" in query: @@ -72,15 +74,15 @@ def test_ensure_schema_handles_lock(monkeypatch) -> None: return super().execute(query, params) class LockPool(DummyPool): - def __init__(self, conninfo=None, max_size=None): + def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None): self.conn = LockConn() monkeypatch.setattr(db_module, "ConnectionPool", LockPool) db = Database("postgresql://user:pass@localhost/db") - db.ensure_schema() + db.migrate(lock_id=123) -def test_ensure_schema_ignores_timeout_errors(monkeypatch) -> None: +def test_migrate_ignores_timeout_errors(monkeypatch) -> None: class TimeoutConn(DummyConn): def execute(self, query, params=None): if query.startswith("SET lock_timeout") or query.startswith("SET statement_timeout"): @@ -88,28 +90,28 @@ def test_ensure_schema_ignores_timeout_errors(monkeypatch) -> None: return super().execute(query, params) class TimeoutPool(DummyPool): - def __init__(self, conninfo=None, max_size=None): + def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None): self.conn = TimeoutConn() monkeypatch.setattr(db_module, "ConnectionPool", TimeoutPool) db = Database("postgresql://user:pass@localhost/db") - db.ensure_schema() + db.migrate(lock_id=123) -def test_ensure_schema_handles_lock_on_alter(monkeypatch) -> None: +def test_migrate_handles_lock_on_alter(monkeypatch) -> None: class LockConn(DummyConn): def execute(self, query, params=None): - if query.startswith("ALTER TABLE"): + if "ALTER TABLE access_requests" in query: raise db_module.psycopg.errors.QueryCanceled() return super().execute(query, params) class LockPool(DummyPool): - def __init__(self, conninfo=None, max_size=None): + def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None): self.conn = LockConn() monkeypatch.setattr(db_module, "ConnectionPool", LockPool) db = Database("postgresql://user:pass@localhost/db") - db.ensure_schema() + db.migrate(lock_id=123) def test_fetchone_and_fetchall_return_dicts(monkeypatch) -> None: @@ -120,7 +122,7 @@ def test_fetchone_and_fetchall_return_dicts(monkeypatch) -> None: return DummyResult(row=None, rows=[{"id": 1}, {"id": 2}]) class RowPool(DummyPool): - def __init__(self, conninfo=None, max_size=None): + def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None): self.conn = RowConn() monkeypatch.setattr(db_module, "ConnectionPool", RowPool)