db: move migrations to cli and cap pools

Brad Stein 2026-01-22 14:11:21 -03:00
parent 46ff29ae1c
commit 9fc1b41f73
7 changed files with 163 additions and 49 deletions

ariadne/app.py

@@ -60,8 +60,26 @@ class PasswordResetRequest:
updated_attr: str
error_hint: str
portal_db = Database(settings.portal_database_url)
ariadne_db = Database(settings.ariadne_database_url)
portal_db = Database(
settings.portal_database_url,
pool_min=settings.ariadne_db_pool_min,
pool_max=settings.ariadne_db_pool_max,
connect_timeout_sec=settings.ariadne_db_connect_timeout_sec,
lock_timeout_sec=settings.ariadne_db_lock_timeout_sec,
statement_timeout_sec=settings.ariadne_db_statement_timeout_sec,
idle_in_tx_timeout_sec=settings.ariadne_db_idle_in_tx_timeout_sec,
application_name="ariadne_portal",
)
ariadne_db = Database(
settings.ariadne_database_url,
pool_min=settings.ariadne_db_pool_min,
pool_max=settings.ariadne_db_pool_max,
connect_timeout_sec=settings.ariadne_db_connect_timeout_sec,
lock_timeout_sec=settings.ariadne_db_lock_timeout_sec,
statement_timeout_sec=settings.ariadne_db_statement_timeout_sec,
idle_in_tx_timeout_sec=settings.ariadne_db_idle_in_tx_timeout_sec,
application_name="ariadne",
)
storage = Storage(ariadne_db, portal_db)
provisioning = ProvisioningManager(portal_db, storage)
scheduler = CronScheduler(storage, settings.schedule_tick_sec)
@@ -229,8 +247,6 @@ def _run_password_reset(request: PasswordResetRequest) -> JSONResponse:
@app.on_event("startup")
def _startup() -> None:
ariadne_db.ensure_schema(include_access_requests=False)
portal_db.ensure_schema(include_ariadne_tables=False)
provisioning.start()
scheduler.add_task("schedule.mailu_sync", settings.mailu_sync_cron, lambda: mailu.sync("ariadne_schedule"))
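Note: with the ensure_schema calls dropped from _startup, the app assumes the schema already exists when it boots. A minimal sketch of the new split, assuming migrations run as a one-shot step before the server comes up (the deployment wiring is not part of this commit):

# Hedged sketch, not part of this commit: invoke the new migration entry point
# before serving traffic. main() is a no-op unless ARIADNE_RUN_MIGRATIONS=true,
# and the __main__ guard in migrate.py below also allows `python -m ariadne.migrate`.
import os

os.environ.setdefault("ARIADNE_RUN_MIGRATIONS", "true")

from ariadne import migrate

migrate.main()  # takes a try-lock per database; exits early if another replica holds it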

ariadne/db/database.py

@@ -7,16 +7,43 @@ from typing import Any, Iterable
import psycopg
from psycopg_pool import ConnectionPool
from .schema import ARIADNE_ACCESS_REQUEST_ALTER, ARIADNE_TABLES_SQL
from .schema import ARIADNE_ACCESS_REQUEST_ALTER_SQL, ARIADNE_TABLES_SQL
logger = logging.getLogger(__name__)
class Database:
def __init__(self, dsn: str, pool_size: int = 5) -> None:
def __init__(
self,
dsn: str,
*,
pool_min: int = 0,
pool_max: int = 5,
connect_timeout_sec: int = 5,
lock_timeout_sec: int = 5,
statement_timeout_sec: int = 30,
idle_in_tx_timeout_sec: int = 10,
application_name: str = "ariadne",
) -> None:
if not dsn:
raise RuntimeError("database URL is required")
self._pool = ConnectionPool(conninfo=dsn, max_size=pool_size)
options = (
f"-c lock_timeout={lock_timeout_sec}s "
f"-c statement_timeout={statement_timeout_sec}s "
f"-c idle_in_transaction_session_timeout={idle_in_tx_timeout_sec}s"
)
self._pool = ConnectionPool(
conninfo=dsn,
min_size=pool_min,
max_size=pool_max,
kwargs={
"connect_timeout": connect_timeout_sec,
"application_name": application_name,
"options": options,
},
)
self._lock_timeout_sec = lock_timeout_sec
self._statement_timeout_sec = statement_timeout_sec
@contextmanager
def connection(self):
@@ -24,37 +51,44 @@ class Database:
conn.row_factory = psycopg.rows.dict_row
yield conn
def ensure_schema(
def migrate(
self,
lock_timeout_sec: int = 5,
statement_timeout_sec: int = 30,
lock_id: int,
*,
include_ariadne_tables: bool = True,
include_access_requests: bool = True,
) -> None:
with self.connection() as conn:
try:
conn.execute(f"SET lock_timeout = '{lock_timeout_sec}s'")
conn.execute(f"SET statement_timeout = '{statement_timeout_sec}s'")
conn.execute(f"SET lock_timeout = '{self._lock_timeout_sec}s'")
conn.execute(f"SET statement_timeout = '{self._statement_timeout_sec}s'")
except Exception:
pass
if include_ariadne_tables:
for stmt in ARIADNE_TABLES_SQL:
row = conn.execute("SELECT pg_try_advisory_lock(%s)", (lock_id,)).fetchone()
locked = bool(row and row[0])
if not locked:
return
try:
if include_ariadne_tables:
for stmt in ARIADNE_TABLES_SQL:
try:
conn.execute(stmt)
except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc:
logger.warning("schema ensure skipped due to lock timeout: %s", exc)
return
if include_access_requests:
try:
conn.execute(stmt)
except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc:
logger.warning("schema ensure skipped due to lock timeout: %s", exc)
return
if include_access_requests:
for stmt in ARIADNE_ACCESS_REQUEST_ALTER:
try:
conn.execute(stmt)
conn.execute(ARIADNE_ACCESS_REQUEST_ALTER_SQL)
except psycopg.errors.UndefinedTable:
logger.info("access_requests table missing; skipping alter")
continue
except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc:
logger.warning("schema ensure skipped due to lock timeout: %s", exc)
return
finally:
try:
conn.execute("SELECT pg_advisory_unlock(%s)", (lock_id,))
except Exception:
pass
def fetchone(self, query: str, params: Iterable[Any] | None = None) -> dict[str, Any] | None:
with self.connection() as conn:
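Note: the conninfo options string means every pooled connection starts with lock_timeout, statement_timeout, and idle_in_transaction_session_timeout already set; the SET statements in migrate() only reassert them. A quick check, as a sketch assuming a reachable DSN (placeholder below):

from ariadne.db.database import Database

# Hedged sketch: confirm the per-connection GUCs applied via the options string.
# The DSN is a placeholder; timeout values mirror the new defaults.
db = Database("postgresql://user:pass@localhost/ariadne", pool_max=2)
with db.connection() as conn:
    print(conn.execute("SHOW statement_timeout").fetchone())  # {'statement_timeout': '30s'}
    print(conn.execute("SHOW lock_timeout").fetchone())       # {'lock_timeout': '5s'}
db.close()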

ariadne/db/schema.py

@@ -44,11 +44,12 @@ ARIADNE_TABLES_SQL = [
""",
]
ARIADNE_ACCESS_REQUEST_ALTER = [
"ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS welcome_email_sent_at TIMESTAMPTZ",
"ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS approval_flags TEXT[]",
"ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS approval_note TEXT",
"ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS denial_note TEXT",
"ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS first_name TEXT",
"ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS last_name TEXT",
]
ARIADNE_ACCESS_REQUEST_ALTER_SQL = """
ALTER TABLE access_requests
ADD COLUMN IF NOT EXISTS welcome_email_sent_at TIMESTAMPTZ,
ADD COLUMN IF NOT EXISTS approval_flags TEXT[],
ADD COLUMN IF NOT EXISTS approval_note TEXT,
ADD COLUMN IF NOT EXISTS denial_note TEXT,
ADD COLUMN IF NOT EXISTS first_name TEXT,
ADD COLUMN IF NOT EXISTS last_name TEXT
"""

ariadne/migrate.py (new file, 51 lines)

@@ -0,0 +1,51 @@
from __future__ import annotations
from .db.database import Database
from .settings import settings
PORTAL_MIGRATION_LOCK_ID = 982731
ARIADNE_MIGRATION_LOCK_ID = 982732
def _build_db(dsn: str, application_name: str) -> Database:
return Database(
dsn,
pool_min=settings.ariadne_db_pool_min,
pool_max=settings.ariadne_db_pool_max,
connect_timeout_sec=settings.ariadne_db_connect_timeout_sec,
lock_timeout_sec=settings.ariadne_db_lock_timeout_sec,
statement_timeout_sec=settings.ariadne_db_statement_timeout_sec,
idle_in_tx_timeout_sec=settings.ariadne_db_idle_in_tx_timeout_sec,
application_name=application_name,
)
def main() -> None:
if not settings.ariadne_run_migrations:
return
ariadne_db = _build_db(settings.ariadne_database_url, "ariadne_migrate")
try:
ariadne_db.migrate(
ARIADNE_MIGRATION_LOCK_ID,
include_ariadne_tables=True,
include_access_requests=False,
)
finally:
ariadne_db.close()
if settings.portal_database_url:
portal_db = _build_db(settings.portal_database_url, "ariadne_portal_migrate")
try:
portal_db.migrate(
PORTAL_MIGRATION_LOCK_ID,
include_ariadne_tables=False,
include_access_requests=True,
)
finally:
portal_db.close()
if __name__ == "__main__":
main()
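Note: because migrate() uses pg_try_advisory_lock, concurrent migrators do not queue; the loser returns immediately and leaves the DDL to the winner. A small demonstration of that contract with plain psycopg, assuming a placeholder DSN:

import psycopg

LOCK_ID = 982732  # ARIADNE_MIGRATION_LOCK_ID
DSN = "postgresql://user:pass@localhost/ariadne"  # placeholder

# Hedged sketch: two sessions contend for the same session-level advisory lock;
# the second try-lock returns False, the case where migrate() bails out.
with psycopg.connect(DSN) as first, psycopg.connect(DSN) as second:
    assert first.execute("SELECT pg_try_advisory_lock(%s)", (LOCK_ID,)).fetchone()[0]
    assert not second.execute("SELECT pg_try_advisory_lock(%s)", (LOCK_ID,)).fetchone()[0]
    first.execute("SELECT pg_advisory_unlock(%s)", (LOCK_ID,))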

ariadne/settings.py

@@ -38,6 +38,13 @@ class Settings:
portal_database_url: str
portal_public_base_url: str
log_level: str
ariadne_db_pool_min: int
ariadne_db_pool_max: int
ariadne_db_connect_timeout_sec: int
ariadne_db_lock_timeout_sec: int
ariadne_db_statement_timeout_sec: int
ariadne_db_idle_in_tx_timeout_sec: int
ariadne_run_migrations: bool
keycloak_url: str
keycloak_realm: str
@@ -487,6 +494,13 @@ class Settings:
portal_database_url=portal_db,
portal_public_base_url=_env("PORTAL_PUBLIC_BASE_URL", "https://bstein.dev").rstrip("/"),
log_level=_env("ARIADNE_LOG_LEVEL", "INFO"),
ariadne_db_pool_min=_env_int("ARIADNE_DB_POOL_MIN", 0),
ariadne_db_pool_max=_env_int("ARIADNE_DB_POOL_MAX", 5),
ariadne_db_connect_timeout_sec=_env_int("ARIADNE_DB_CONNECT_TIMEOUT_SEC", 5),
ariadne_db_lock_timeout_sec=_env_int("ARIADNE_DB_LOCK_TIMEOUT_SEC", 5),
ariadne_db_statement_timeout_sec=_env_int("ARIADNE_DB_STATEMENT_TIMEOUT_SEC", 30),
ariadne_db_idle_in_tx_timeout_sec=_env_int("ARIADNE_DB_IDLE_IN_TX_TIMEOUT_SEC", 10),
ariadne_run_migrations=_env_bool("ARIADNE_RUN_MIGRATIONS", "false"),
provision_poll_interval_sec=_env_float("ARIADNE_PROVISION_POLL_INTERVAL_SEC", 5.0),
provision_retry_cooldown_sec=_env_float("ARIADNE_PROVISION_RETRY_COOLDOWN_SEC", 30.0),
schedule_tick_sec=_env_float("ARIADNE_SCHEDULE_TICK_SEC", 5.0),
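Note: for reference, the new knobs as environment variables. A sketch of a local override, using the defaults wired in above except the migration flag, which defaults to "false":

import os

# Hedged sketch: defaults from settings.py; only ARIADNE_RUN_MIGRATIONS deviates.
os.environ.update({
    "ARIADNE_DB_POOL_MIN": "0",
    "ARIADNE_DB_POOL_MAX": "5",
    "ARIADNE_DB_CONNECT_TIMEOUT_SEC": "5",
    "ARIADNE_DB_LOCK_TIMEOUT_SEC": "5",
    "ARIADNE_DB_STATEMENT_TIMEOUT_SEC": "30",
    "ARIADNE_DB_IDLE_IN_TX_TIMEOUT_SEC": "10",
    "ARIADNE_RUN_MIGRATIONS": "true",  # opt in; the CLI exits early otherwise
})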


@@ -15,8 +15,6 @@ import ariadne.app as app_module
def _client(monkeypatch, ctx: AuthContext) -> TestClient:
monkeypatch.setattr(app_module.authenticator, "authenticate", lambda token: ctx)
monkeypatch.setattr(app_module.portal_db, "ensure_schema", lambda *args, **kwargs: None)
monkeypatch.setattr(app_module.ariadne_db, "ensure_schema", lambda *args, **kwargs: None)
monkeypatch.setattr(app_module.provisioning, "start", lambda: None)
monkeypatch.setattr(app_module.scheduler, "start", lambda: None)
monkeypatch.setattr(app_module.provisioning, "stop", lambda: None)
@@ -37,8 +35,6 @@ def test_health_ok(monkeypatch) -> None:
def test_startup_and_shutdown(monkeypatch) -> None:
monkeypatch.setattr(app_module.portal_db, "ensure_schema", lambda *args, **kwargs: None)
monkeypatch.setattr(app_module.ariadne_db, "ensure_schema", lambda *args, **kwargs: None)
monkeypatch.setattr(app_module.provisioning, "start", lambda: None)
monkeypatch.setattr(app_module.scheduler, "add_task", lambda *args, **kwargs: None)
monkeypatch.setattr(app_module.scheduler, "start", lambda: None)


@@ -27,11 +27,13 @@ class DummyConn:
def execute(self, query, params=None):
self.executed.append((query, params))
if "pg_try_advisory_lock" in query:
return DummyResult(row=(True,))
return DummyResult()
class DummyPool:
def __init__(self, conninfo=None, max_size=None):
def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
self.conn = DummyConn()
@contextmanager
@@ -42,10 +44,10 @@ class DummyPool:
return None
def test_ensure_schema_runs(monkeypatch) -> None:
def test_migrate_runs(monkeypatch) -> None:
monkeypatch.setattr(db_module, "ConnectionPool", DummyPool)
db = Database("postgresql://user:pass@localhost/db")
db.ensure_schema()
db.migrate(lock_id=123)
assert db._pool.conn.executed
@@ -64,7 +66,7 @@ def test_database_requires_dsn() -> None:
Database("")
def test_ensure_schema_handles_lock(monkeypatch) -> None:
def test_migrate_handles_lock(monkeypatch) -> None:
class LockConn(DummyConn):
def execute(self, query, params=None):
if "CREATE TABLE" in query:
@@ -72,15 +74,15 @@ def test_ensure_schema_handles_lock(monkeypatch) -> None:
return super().execute(query, params)
class LockPool(DummyPool):
def __init__(self, conninfo=None, max_size=None):
def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
self.conn = LockConn()
monkeypatch.setattr(db_module, "ConnectionPool", LockPool)
db = Database("postgresql://user:pass@localhost/db")
db.ensure_schema()
db.migrate(lock_id=123)
def test_ensure_schema_ignores_timeout_errors(monkeypatch) -> None:
def test_migrate_ignores_timeout_errors(monkeypatch) -> None:
class TimeoutConn(DummyConn):
def execute(self, query, params=None):
if query.startswith("SET lock_timeout") or query.startswith("SET statement_timeout"):
@@ -88,28 +90,28 @@ def test_ensure_schema_ignores_timeout_errors(monkeypatch) -> None:
return super().execute(query, params)
class TimeoutPool(DummyPool):
def __init__(self, conninfo=None, max_size=None):
def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
self.conn = TimeoutConn()
monkeypatch.setattr(db_module, "ConnectionPool", TimeoutPool)
db = Database("postgresql://user:pass@localhost/db")
db.ensure_schema()
db.migrate(lock_id=123)
def test_ensure_schema_handles_lock_on_alter(monkeypatch) -> None:
def test_migrate_handles_lock_on_alter(monkeypatch) -> None:
class LockConn(DummyConn):
def execute(self, query, params=None):
if query.startswith("ALTER TABLE"):
if "ALTER TABLE access_requests" in query:
raise db_module.psycopg.errors.QueryCanceled()
return super().execute(query, params)
class LockPool(DummyPool):
def __init__(self, conninfo=None, max_size=None):
def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
self.conn = LockConn()
monkeypatch.setattr(db_module, "ConnectionPool", LockPool)
db = Database("postgresql://user:pass@localhost/db")
db.ensure_schema()
db.migrate(lock_id=123)
def test_fetchone_and_fetchall_return_dicts(monkeypatch) -> None:
@@ -120,7 +122,7 @@ def test_fetchone_and_fetchall_return_dicts(monkeypatch) -> None:
return DummyResult(row=None, rows=[{"id": 1}, {"id": 2}])
class RowPool(DummyPool):
def __init__(self, conninfo=None, max_size=None):
def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
self.conn = RowConn()
monkeypatch.setattr(db_module, "ConnectionPool", RowPool)