db: move migrations to cli and cap pools
This commit is contained in:
parent
46ff29ae1c
commit
9fc1b41f73
@ -60,8 +60,26 @@ class PasswordResetRequest:
|
||||
updated_attr: str
|
||||
error_hint: str
|
||||
|
||||
portal_db = Database(settings.portal_database_url)
|
||||
ariadne_db = Database(settings.ariadne_database_url)
|
||||
portal_db = Database(
|
||||
settings.portal_database_url,
|
||||
pool_min=settings.ariadne_db_pool_min,
|
||||
pool_max=settings.ariadne_db_pool_max,
|
||||
connect_timeout_sec=settings.ariadne_db_connect_timeout_sec,
|
||||
lock_timeout_sec=settings.ariadne_db_lock_timeout_sec,
|
||||
statement_timeout_sec=settings.ariadne_db_statement_timeout_sec,
|
||||
idle_in_tx_timeout_sec=settings.ariadne_db_idle_in_tx_timeout_sec,
|
||||
application_name="ariadne_portal",
|
||||
)
|
||||
ariadne_db = Database(
|
||||
settings.ariadne_database_url,
|
||||
pool_min=settings.ariadne_db_pool_min,
|
||||
pool_max=settings.ariadne_db_pool_max,
|
||||
connect_timeout_sec=settings.ariadne_db_connect_timeout_sec,
|
||||
lock_timeout_sec=settings.ariadne_db_lock_timeout_sec,
|
||||
statement_timeout_sec=settings.ariadne_db_statement_timeout_sec,
|
||||
idle_in_tx_timeout_sec=settings.ariadne_db_idle_in_tx_timeout_sec,
|
||||
application_name="ariadne",
|
||||
)
|
||||
storage = Storage(ariadne_db, portal_db)
|
||||
provisioning = ProvisioningManager(portal_db, storage)
|
||||
scheduler = CronScheduler(storage, settings.schedule_tick_sec)
|
||||
@ -229,8 +247,6 @@ def _run_password_reset(request: PasswordResetRequest) -> JSONResponse:
|
||||
|
||||
@app.on_event("startup")
|
||||
def _startup() -> None:
|
||||
ariadne_db.ensure_schema(include_access_requests=False)
|
||||
portal_db.ensure_schema(include_ariadne_tables=False)
|
||||
provisioning.start()
|
||||
|
||||
scheduler.add_task("schedule.mailu_sync", settings.mailu_sync_cron, lambda: mailu.sync("ariadne_schedule"))
|
||||
|
||||
@ -7,16 +7,43 @@ from typing import Any, Iterable
|
||||
import psycopg
|
||||
from psycopg_pool import ConnectionPool
|
||||
|
||||
from .schema import ARIADNE_ACCESS_REQUEST_ALTER, ARIADNE_TABLES_SQL
|
||||
from .schema import ARIADNE_ACCESS_REQUEST_ALTER_SQL, ARIADNE_TABLES_SQL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, dsn: str, pool_size: int = 5) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
dsn: str,
|
||||
*,
|
||||
pool_min: int = 0,
|
||||
pool_max: int = 5,
|
||||
connect_timeout_sec: int = 5,
|
||||
lock_timeout_sec: int = 5,
|
||||
statement_timeout_sec: int = 30,
|
||||
idle_in_tx_timeout_sec: int = 10,
|
||||
application_name: str = "ariadne",
|
||||
) -> None:
|
||||
if not dsn:
|
||||
raise RuntimeError("database URL is required")
|
||||
self._pool = ConnectionPool(conninfo=dsn, max_size=pool_size)
|
||||
options = (
|
||||
f"-c lock_timeout={lock_timeout_sec}s "
|
||||
f"-c statement_timeout={statement_timeout_sec}s "
|
||||
f"-c idle_in_transaction_session_timeout={idle_in_tx_timeout_sec}s"
|
||||
)
|
||||
self._pool = ConnectionPool(
|
||||
conninfo=dsn,
|
||||
min_size=pool_min,
|
||||
max_size=pool_max,
|
||||
kwargs={
|
||||
"connect_timeout": connect_timeout_sec,
|
||||
"application_name": application_name,
|
||||
"options": options,
|
||||
},
|
||||
)
|
||||
self._lock_timeout_sec = lock_timeout_sec
|
||||
self._statement_timeout_sec = statement_timeout_sec
|
||||
|
||||
@contextmanager
|
||||
def connection(self):
|
||||
@ -24,37 +51,44 @@ class Database:
|
||||
conn.row_factory = psycopg.rows.dict_row
|
||||
yield conn
|
||||
|
||||
def ensure_schema(
|
||||
def migrate(
|
||||
self,
|
||||
lock_timeout_sec: int = 5,
|
||||
statement_timeout_sec: int = 30,
|
||||
lock_id: int,
|
||||
*,
|
||||
include_ariadne_tables: bool = True,
|
||||
include_access_requests: bool = True,
|
||||
) -> None:
|
||||
with self.connection() as conn:
|
||||
try:
|
||||
conn.execute(f"SET lock_timeout = '{lock_timeout_sec}s'")
|
||||
conn.execute(f"SET statement_timeout = '{statement_timeout_sec}s'")
|
||||
conn.execute(f"SET lock_timeout = '{self._lock_timeout_sec}s'")
|
||||
conn.execute(f"SET statement_timeout = '{self._statement_timeout_sec}s'")
|
||||
except Exception:
|
||||
pass
|
||||
if include_ariadne_tables:
|
||||
for stmt in ARIADNE_TABLES_SQL:
|
||||
row = conn.execute("SELECT pg_try_advisory_lock(%s)", (lock_id,)).fetchone()
|
||||
locked = bool(row and row[0])
|
||||
if not locked:
|
||||
return
|
||||
try:
|
||||
if include_ariadne_tables:
|
||||
for stmt in ARIADNE_TABLES_SQL:
|
||||
try:
|
||||
conn.execute(stmt)
|
||||
except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc:
|
||||
logger.warning("schema ensure skipped due to lock timeout: %s", exc)
|
||||
return
|
||||
if include_access_requests:
|
||||
try:
|
||||
conn.execute(stmt)
|
||||
except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc:
|
||||
logger.warning("schema ensure skipped due to lock timeout: %s", exc)
|
||||
return
|
||||
if include_access_requests:
|
||||
for stmt in ARIADNE_ACCESS_REQUEST_ALTER:
|
||||
try:
|
||||
conn.execute(stmt)
|
||||
conn.execute(ARIADNE_ACCESS_REQUEST_ALTER_SQL)
|
||||
except psycopg.errors.UndefinedTable:
|
||||
logger.info("access_requests table missing; skipping alter")
|
||||
continue
|
||||
except (psycopg.errors.LockNotAvailable, psycopg.errors.QueryCanceled) as exc:
|
||||
logger.warning("schema ensure skipped due to lock timeout: %s", exc)
|
||||
return
|
||||
finally:
|
||||
try:
|
||||
conn.execute("SELECT pg_advisory_unlock(%s)", (lock_id,))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def fetchone(self, query: str, params: Iterable[Any] | None = None) -> dict[str, Any] | None:
|
||||
with self.connection() as conn:
|
||||
|
||||
@ -44,11 +44,12 @@ ARIADNE_TABLES_SQL = [
|
||||
""",
|
||||
]
|
||||
|
||||
ARIADNE_ACCESS_REQUEST_ALTER = [
|
||||
"ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS welcome_email_sent_at TIMESTAMPTZ",
|
||||
"ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS approval_flags TEXT[]",
|
||||
"ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS approval_note TEXT",
|
||||
"ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS denial_note TEXT",
|
||||
"ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS first_name TEXT",
|
||||
"ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS last_name TEXT",
|
||||
]
|
||||
ARIADNE_ACCESS_REQUEST_ALTER_SQL = """
|
||||
ALTER TABLE access_requests
|
||||
ADD COLUMN IF NOT EXISTS welcome_email_sent_at TIMESTAMPTZ,
|
||||
ADD COLUMN IF NOT EXISTS approval_flags TEXT[],
|
||||
ADD COLUMN IF NOT EXISTS approval_note TEXT,
|
||||
ADD COLUMN IF NOT EXISTS denial_note TEXT,
|
||||
ADD COLUMN IF NOT EXISTS first_name TEXT,
|
||||
ADD COLUMN IF NOT EXISTS last_name TEXT
|
||||
"""
|
||||
|
||||
51
ariadne/migrate.py
Normal file
51
ariadne/migrate.py
Normal file
@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .db.database import Database
|
||||
from .settings import settings
|
||||
|
||||
|
||||
PORTAL_MIGRATION_LOCK_ID = 982731
|
||||
ARIADNE_MIGRATION_LOCK_ID = 982732
|
||||
|
||||
|
||||
def _build_db(dsn: str, application_name: str) -> Database:
|
||||
return Database(
|
||||
dsn,
|
||||
pool_min=settings.ariadne_db_pool_min,
|
||||
pool_max=settings.ariadne_db_pool_max,
|
||||
connect_timeout_sec=settings.ariadne_db_connect_timeout_sec,
|
||||
lock_timeout_sec=settings.ariadne_db_lock_timeout_sec,
|
||||
statement_timeout_sec=settings.ariadne_db_statement_timeout_sec,
|
||||
idle_in_tx_timeout_sec=settings.ariadne_db_idle_in_tx_timeout_sec,
|
||||
application_name=application_name,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if not settings.ariadne_run_migrations:
|
||||
return
|
||||
|
||||
ariadne_db = _build_db(settings.ariadne_database_url, "ariadne_migrate")
|
||||
try:
|
||||
ariadne_db.migrate(
|
||||
ARIADNE_MIGRATION_LOCK_ID,
|
||||
include_ariadne_tables=True,
|
||||
include_access_requests=False,
|
||||
)
|
||||
finally:
|
||||
ariadne_db.close()
|
||||
|
||||
if settings.portal_database_url:
|
||||
portal_db = _build_db(settings.portal_database_url, "ariadne_portal_migrate")
|
||||
try:
|
||||
portal_db.migrate(
|
||||
PORTAL_MIGRATION_LOCK_ID,
|
||||
include_ariadne_tables=False,
|
||||
include_access_requests=True,
|
||||
)
|
||||
finally:
|
||||
portal_db.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -38,6 +38,13 @@ class Settings:
|
||||
portal_database_url: str
|
||||
portal_public_base_url: str
|
||||
log_level: str
|
||||
ariadne_db_pool_min: int
|
||||
ariadne_db_pool_max: int
|
||||
ariadne_db_connect_timeout_sec: int
|
||||
ariadne_db_lock_timeout_sec: int
|
||||
ariadne_db_statement_timeout_sec: int
|
||||
ariadne_db_idle_in_tx_timeout_sec: int
|
||||
ariadne_run_migrations: bool
|
||||
|
||||
keycloak_url: str
|
||||
keycloak_realm: str
|
||||
@ -487,6 +494,13 @@ class Settings:
|
||||
portal_database_url=portal_db,
|
||||
portal_public_base_url=_env("PORTAL_PUBLIC_BASE_URL", "https://bstein.dev").rstrip("/"),
|
||||
log_level=_env("ARIADNE_LOG_LEVEL", "INFO"),
|
||||
ariadne_db_pool_min=_env_int("ARIADNE_DB_POOL_MIN", 0),
|
||||
ariadne_db_pool_max=_env_int("ARIADNE_DB_POOL_MAX", 5),
|
||||
ariadne_db_connect_timeout_sec=_env_int("ARIADNE_DB_CONNECT_TIMEOUT_SEC", 5),
|
||||
ariadne_db_lock_timeout_sec=_env_int("ARIADNE_DB_LOCK_TIMEOUT_SEC", 5),
|
||||
ariadne_db_statement_timeout_sec=_env_int("ARIADNE_DB_STATEMENT_TIMEOUT_SEC", 30),
|
||||
ariadne_db_idle_in_tx_timeout_sec=_env_int("ARIADNE_DB_IDLE_IN_TX_TIMEOUT_SEC", 10),
|
||||
ariadne_run_migrations=_env_bool("ARIADNE_RUN_MIGRATIONS", "false"),
|
||||
provision_poll_interval_sec=_env_float("ARIADNE_PROVISION_POLL_INTERVAL_SEC", 5.0),
|
||||
provision_retry_cooldown_sec=_env_float("ARIADNE_PROVISION_RETRY_COOLDOWN_SEC", 30.0),
|
||||
schedule_tick_sec=_env_float("ARIADNE_SCHEDULE_TICK_SEC", 5.0),
|
||||
|
||||
@ -15,8 +15,6 @@ import ariadne.app as app_module
|
||||
|
||||
def _client(monkeypatch, ctx: AuthContext) -> TestClient:
|
||||
monkeypatch.setattr(app_module.authenticator, "authenticate", lambda token: ctx)
|
||||
monkeypatch.setattr(app_module.portal_db, "ensure_schema", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(app_module.ariadne_db, "ensure_schema", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(app_module.provisioning, "start", lambda: None)
|
||||
monkeypatch.setattr(app_module.scheduler, "start", lambda: None)
|
||||
monkeypatch.setattr(app_module.provisioning, "stop", lambda: None)
|
||||
@ -37,8 +35,6 @@ def test_health_ok(monkeypatch) -> None:
|
||||
|
||||
|
||||
def test_startup_and_shutdown(monkeypatch) -> None:
|
||||
monkeypatch.setattr(app_module.portal_db, "ensure_schema", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(app_module.ariadne_db, "ensure_schema", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(app_module.provisioning, "start", lambda: None)
|
||||
monkeypatch.setattr(app_module.scheduler, "add_task", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(app_module.scheduler, "start", lambda: None)
|
||||
|
||||
@ -27,11 +27,13 @@ class DummyConn:
|
||||
|
||||
def execute(self, query, params=None):
|
||||
self.executed.append((query, params))
|
||||
if "pg_try_advisory_lock" in query:
|
||||
return DummyResult(row=(True,))
|
||||
return DummyResult()
|
||||
|
||||
|
||||
class DummyPool:
|
||||
def __init__(self, conninfo=None, max_size=None):
|
||||
def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
|
||||
self.conn = DummyConn()
|
||||
|
||||
@contextmanager
|
||||
@ -42,10 +44,10 @@ class DummyPool:
|
||||
return None
|
||||
|
||||
|
||||
def test_ensure_schema_runs(monkeypatch) -> None:
|
||||
def test_migrate_runs(monkeypatch) -> None:
|
||||
monkeypatch.setattr(db_module, "ConnectionPool", DummyPool)
|
||||
db = Database("postgresql://user:pass@localhost/db")
|
||||
db.ensure_schema()
|
||||
db.migrate(lock_id=123)
|
||||
assert db._pool.conn.executed
|
||||
|
||||
|
||||
@ -64,7 +66,7 @@ def test_database_requires_dsn() -> None:
|
||||
Database("")
|
||||
|
||||
|
||||
def test_ensure_schema_handles_lock(monkeypatch) -> None:
|
||||
def test_migrate_handles_lock(monkeypatch) -> None:
|
||||
class LockConn(DummyConn):
|
||||
def execute(self, query, params=None):
|
||||
if "CREATE TABLE" in query:
|
||||
@ -72,15 +74,15 @@ def test_ensure_schema_handles_lock(monkeypatch) -> None:
|
||||
return super().execute(query, params)
|
||||
|
||||
class LockPool(DummyPool):
|
||||
def __init__(self, conninfo=None, max_size=None):
|
||||
def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
|
||||
self.conn = LockConn()
|
||||
|
||||
monkeypatch.setattr(db_module, "ConnectionPool", LockPool)
|
||||
db = Database("postgresql://user:pass@localhost/db")
|
||||
db.ensure_schema()
|
||||
db.migrate(lock_id=123)
|
||||
|
||||
|
||||
def test_ensure_schema_ignores_timeout_errors(monkeypatch) -> None:
|
||||
def test_migrate_ignores_timeout_errors(monkeypatch) -> None:
|
||||
class TimeoutConn(DummyConn):
|
||||
def execute(self, query, params=None):
|
||||
if query.startswith("SET lock_timeout") or query.startswith("SET statement_timeout"):
|
||||
@ -88,28 +90,28 @@ def test_ensure_schema_ignores_timeout_errors(monkeypatch) -> None:
|
||||
return super().execute(query, params)
|
||||
|
||||
class TimeoutPool(DummyPool):
|
||||
def __init__(self, conninfo=None, max_size=None):
|
||||
def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
|
||||
self.conn = TimeoutConn()
|
||||
|
||||
monkeypatch.setattr(db_module, "ConnectionPool", TimeoutPool)
|
||||
db = Database("postgresql://user:pass@localhost/db")
|
||||
db.ensure_schema()
|
||||
db.migrate(lock_id=123)
|
||||
|
||||
|
||||
def test_ensure_schema_handles_lock_on_alter(monkeypatch) -> None:
|
||||
def test_migrate_handles_lock_on_alter(monkeypatch) -> None:
|
||||
class LockConn(DummyConn):
|
||||
def execute(self, query, params=None):
|
||||
if query.startswith("ALTER TABLE"):
|
||||
if "ALTER TABLE access_requests" in query:
|
||||
raise db_module.psycopg.errors.QueryCanceled()
|
||||
return super().execute(query, params)
|
||||
|
||||
class LockPool(DummyPool):
|
||||
def __init__(self, conninfo=None, max_size=None):
|
||||
def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
|
||||
self.conn = LockConn()
|
||||
|
||||
monkeypatch.setattr(db_module, "ConnectionPool", LockPool)
|
||||
db = Database("postgresql://user:pass@localhost/db")
|
||||
db.ensure_schema()
|
||||
db.migrate(lock_id=123)
|
||||
|
||||
|
||||
def test_fetchone_and_fetchall_return_dicts(monkeypatch) -> None:
|
||||
@ -120,7 +122,7 @@ def test_fetchone_and_fetchall_return_dicts(monkeypatch) -> None:
|
||||
return DummyResult(row=None, rows=[{"id": 1}, {"id": 2}])
|
||||
|
||||
class RowPool(DummyPool):
|
||||
def __init__(self, conninfo=None, max_size=None):
|
||||
def __init__(self, conninfo=None, min_size=None, max_size=None, kwargs=None):
|
||||
self.conn = RowConn()
|
||||
|
||||
monkeypatch.setattr(db_module, "ConnectionPool", RowPool)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user