From 89301619260235ba6581bbe00a615d14c281aee9 Mon Sep 17 00:00:00 2001 From: Brad Stein Date: Wed, 21 Jan 2026 03:38:33 -0300 Subject: [PATCH] db: split portal and ariadne connections --- ariadne/app.py | 17 ++++++++++------- ariadne/db/database.py | 43 +++++++++++++++++++++++++----------------- ariadne/settings.py | 7 ++++++- tests/test_app.py | 36 +++++++++++++++++++++++------------ 4 files changed, 66 insertions(+), 37 deletions(-) diff --git a/ariadne/app.py b/ariadne/app.py index eb06450..193b866 100644 --- a/ariadne/app.py +++ b/ariadne/app.py @@ -60,9 +60,10 @@ class PasswordResetRequest: updated_attr: str error_hint: str -db = Database(settings.portal_database_url) -storage = Storage(db) -provisioning = ProvisioningManager(db, storage) +portal_db = Database(settings.portal_database_url) +ariadne_db = Database(settings.ariadne_database_url) +storage = Storage(ariadne_db) +provisioning = ProvisioningManager(portal_db, storage) scheduler = CronScheduler(storage, settings.schedule_tick_sec) @@ -228,7 +229,8 @@ def _run_password_reset(request: PasswordResetRequest) -> JSONResponse: @app.on_event("startup") def _startup() -> None: - db.ensure_schema() + ariadne_db.ensure_schema(include_access_requests=False) + portal_db.ensure_schema(include_ariadne_tables=False) provisioning.start() scheduler.add_task("schedule.mailu_sync", settings.mailu_sync_cron, lambda: mailu.sync("ariadne_schedule")) @@ -346,7 +348,8 @@ def _startup() -> None: def _shutdown() -> None: scheduler.stop() provisioning.stop() - db.close() + portal_db.close() + ariadne_db.close() logger.info("ariadne stopped", extra={"event": "shutdown"}) @@ -474,7 +477,7 @@ async def approve_access_request( decided_by = ctx.username or "" try: - row = db.fetchone( + row = portal_db.fetchone( """ UPDATE access_requests SET status = 'approved', @@ -550,7 +553,7 @@ async def deny_access_request( decided_by = ctx.username or "" try: - row = db.fetchone( + row = portal_db.fetchone( """ UPDATE access_requests SET status = 'denied', diff --git a/ariadne/db/database.py b/ariadne/db/database.py index e601274..f16b3c4 100644 --- a/ariadne/db/database.py +++ b/ariadne/db/database.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) class Database: def __init__(self, dsn: str, pool_size: int = 5) -> None: if not dsn: - raise RuntimeError("PORTAL_DATABASE_URL is required") + raise RuntimeError("database URL is required") self._pool = ConnectionPool(conninfo=dsn, max_size=pool_size) @contextmanager @@ -24,28 +24,37 @@ class Database: conn.row_factory = psycopg.rows.dict_row yield conn - def ensure_schema(self, lock_timeout_sec: int = 5, statement_timeout_sec: int = 30) -> None: + def ensure_schema( + self, + lock_timeout_sec: int = 5, + statement_timeout_sec: int = 30, + *, + include_ariadne_tables: bool = True, + include_access_requests: bool = True, + ) -> None: with self.connection() as conn: try: conn.execute(f"SET lock_timeout = '{lock_timeout_sec}s'") conn.execute(f"SET statement_timeout = '{statement_timeout_sec}s'") except Exception: pass - for stmt in ARIADNE_TABLES_SQL: - try: - conn.execute(stmt) - except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc: - logger.warning("schema ensure skipped due to lock timeout: %s", exc) - return - for stmt in ARIADNE_ACCESS_REQUEST_ALTER: - try: - conn.execute(stmt) - except psycopg.errors.UndefinedTable: - logger.info("access_requests table missing; skipping alter") - continue - except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc: - logger.warning("schema ensure skipped due to lock timeout: %s", exc) - return + if include_ariadne_tables: + for stmt in ARIADNE_TABLES_SQL: + try: + conn.execute(stmt) + except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc: + logger.warning("schema ensure skipped due to lock timeout: %s", exc) + return + if include_access_requests: + for stmt in ARIADNE_ACCESS_REQUEST_ALTER: + try: + conn.execute(stmt) + except psycopg.errors.UndefinedTable: + logger.info("access_requests table missing; skipping alter") + continue + except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc: + logger.warning("schema ensure skipped due to lock timeout: %s", exc) + return def fetchone(self, query: str, params: Iterable[Any] | None = None) -> dict[str, Any] | None: with self.connection() as conn: diff --git a/ariadne/settings.py b/ariadne/settings.py index 814f4a1..f0ee0ae 100644 --- a/ariadne/settings.py +++ b/ariadne/settings.py @@ -34,6 +34,7 @@ class Settings: app_name: str bind_host: str bind_port: int + ariadne_database_url: str portal_database_url: str portal_public_base_url: str log_level: str @@ -473,11 +474,15 @@ class Settings: schedule_cfg = cls._schedule_config() opensearch_cfg = cls._opensearch_config() + portal_db = _env("PORTAL_DATABASE_URL", "") + ariadne_db = _env("ARIADNE_DATABASE_URL", portal_db) + return cls( app_name=_env("ARIADNE_APP_NAME", "ariadne"), bind_host=_env("ARIADNE_BIND_HOST", "0.0.0.0"), bind_port=_env_int("ARIADNE_BIND_PORT", 8080), - portal_database_url=_env("ARIADNE_DATABASE_URL", _env("PORTAL_DATABASE_URL", "")), + ariadne_database_url=ariadne_db, + portal_database_url=portal_db, portal_public_base_url=_env("PORTAL_PUBLIC_BASE_URL", "https://bstein.dev").rstrip("/"), log_level=_env("ARIADNE_LOG_LEVEL", "INFO"), provision_poll_interval_sec=_env_float("ARIADNE_PROVISION_POLL_INTERVAL_SEC", 5.0), diff --git a/tests/test_app.py b/tests/test_app.py index 91174bc..139158b 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -15,12 +15,14 @@ import ariadne.app as app_module def _client(monkeypatch, ctx: AuthContext) -> TestClient: monkeypatch.setattr(app_module.authenticator, "authenticate", lambda token: ctx) - monkeypatch.setattr(app_module.db, "ensure_schema", lambda: None) + monkeypatch.setattr(app_module.portal_db, "ensure_schema", lambda *args, **kwargs: None) + monkeypatch.setattr(app_module.ariadne_db, "ensure_schema", lambda *args, **kwargs: None) monkeypatch.setattr(app_module.provisioning, "start", lambda: None) monkeypatch.setattr(app_module.scheduler, "start", lambda: None) monkeypatch.setattr(app_module.provisioning, "stop", lambda: None) monkeypatch.setattr(app_module.scheduler, "stop", lambda: None) - monkeypatch.setattr(app_module.db, "close", lambda: None) + monkeypatch.setattr(app_module.portal_db, "close", lambda: None) + monkeypatch.setattr(app_module.ariadne_db, "close", lambda: None) monkeypatch.setattr(app_module.storage, "record_event", lambda *args, **kwargs: None) monkeypatch.setattr(app_module.storage, "record_task_run", lambda *args, **kwargs: None) return TestClient(app_module.app) @@ -35,13 +37,15 @@ def test_health_ok(monkeypatch) -> None: def test_startup_and_shutdown(monkeypatch) -> None: - monkeypatch.setattr(app_module.db, "ensure_schema", lambda: None) + monkeypatch.setattr(app_module.portal_db, "ensure_schema", lambda *args, **kwargs: None) + monkeypatch.setattr(app_module.ariadne_db, "ensure_schema", lambda *args, **kwargs: None) monkeypatch.setattr(app_module.provisioning, "start", lambda: None) monkeypatch.setattr(app_module.scheduler, "add_task", lambda *args, **kwargs: None) monkeypatch.setattr(app_module.scheduler, "start", lambda: None) monkeypatch.setattr(app_module.scheduler, "stop", lambda: None) monkeypatch.setattr(app_module.provisioning, "stop", lambda: None) - monkeypatch.setattr(app_module.db, "close", lambda: None) + monkeypatch.setattr(app_module.portal_db, "close", lambda: None) + monkeypatch.setattr(app_module.ariadne_db, "close", lambda: None) app_module._startup() app_module._shutdown() @@ -290,7 +294,7 @@ def test_access_request_approve(monkeypatch) -> None: captured["flags"] = params[1] return {"request_code": "REQ1"} - monkeypatch.setattr(app_module.db, "fetchone", fake_fetchone) + monkeypatch.setattr(app_module.portal_db, "fetchone", fake_fetchone) monkeypatch.setattr(app_module.provisioning, "provision_access_request", lambda code: None) monkeypatch.setattr(app_module.keycloak_admin, "ready", lambda: True) monkeypatch.setattr(app_module.keycloak_admin, "list_group_names", lambda **kwargs: ["demo"]) @@ -308,7 +312,7 @@ def test_access_request_approve(monkeypatch) -> None: def test_access_request_approve_bad_json(monkeypatch) -> None: ctx = AuthContext(username="bstein", email="", groups=["admin"], claims={}) client = _client(monkeypatch, ctx) - monkeypatch.setattr(app_module.db, "fetchone", lambda *args, **kwargs: {"request_code": "REQ1"}) + monkeypatch.setattr(app_module.portal_db, "fetchone", lambda *args, **kwargs: {"request_code": "REQ1"}) resp = client.post( "/api/admin/access/requests/alice/approve", @@ -321,7 +325,11 @@ def test_access_request_approve_bad_json(monkeypatch) -> None: def test_access_request_approve_db_error(monkeypatch) -> None: ctx = AuthContext(username="bstein", email="", groups=["admin"], claims={}) client = _client(monkeypatch, ctx) - monkeypatch.setattr(app_module.db, "fetchone", lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError("fail"))) + monkeypatch.setattr( + app_module.portal_db, + "fetchone", + lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError("fail")), + ) resp = client.post( "/api/admin/access/requests/alice/approve", @@ -335,7 +343,7 @@ def test_access_request_approve_skipped(monkeypatch) -> None: ctx = AuthContext(username="bstein", email="", groups=["admin"], claims={}) client = _client(monkeypatch, ctx) - monkeypatch.setattr(app_module.db, "fetchone", lambda *args, **kwargs: None) + monkeypatch.setattr(app_module.portal_db, "fetchone", lambda *args, **kwargs: None) resp = client.post( "/api/admin/access/requests/alice/approve", @@ -350,7 +358,7 @@ def test_access_request_deny(monkeypatch) -> None: ctx = AuthContext(username="bstein", email="", groups=["admin"], claims={}) client = _client(monkeypatch, ctx) - monkeypatch.setattr(app_module.db, "fetchone", lambda *args, **kwargs: {"request_code": "REQ2"}) + monkeypatch.setattr(app_module.portal_db, "fetchone", lambda *args, **kwargs: {"request_code": "REQ2"}) resp = client.post( "/api/admin/access/requests/alice/deny", @@ -364,7 +372,11 @@ def test_access_request_deny(monkeypatch) -> None: def test_access_request_deny_db_error(monkeypatch) -> None: ctx = AuthContext(username="bstein", email="", groups=["admin"], claims={}) client = _client(monkeypatch, ctx) - monkeypatch.setattr(app_module.db, "fetchone", lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError("fail"))) + monkeypatch.setattr( + app_module.portal_db, + "fetchone", + lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError("fail")), + ) resp = client.post( "/api/admin/access/requests/alice/deny", @@ -377,7 +389,7 @@ def test_access_request_deny_db_error(monkeypatch) -> None: def test_access_request_deny_skipped(monkeypatch) -> None: ctx = AuthContext(username="bstein", email="", groups=["admin"], claims={}) client = _client(monkeypatch, ctx) - monkeypatch.setattr(app_module.db, "fetchone", lambda *args, **kwargs: None) + monkeypatch.setattr(app_module.portal_db, "fetchone", lambda *args, **kwargs: None) resp = client.post( "/api/admin/access/requests/alice/deny", @@ -504,7 +516,7 @@ def test_require_account_access_allows_when_disabled(monkeypatch) -> None: def test_access_request_deny_bad_json(monkeypatch) -> None: ctx = AuthContext(username="bstein", email="", groups=["admin"], claims={}) client = _client(monkeypatch, ctx) - monkeypatch.setattr(app_module.db, "fetchone", lambda *args, **kwargs: {"request_code": "REQ2"}) + monkeypatch.setattr(app_module.portal_db, "fetchone", lambda *args, **kwargs: {"request_code": "REQ2"}) resp = client.post( "/api/admin/access/requests/alice/deny",