# bstein-dev-home/backend/atlas_portal/provisioning_tasks.py

from __future__ import annotations
"""Task-row helpers for access request provisioning."""
import httpx
# Names of the provisioning tasks tracked per access request. Each name is
# prefixed with the service it targets; grouped below by that prefix.
REQUIRED_PROVISION_TASKS: tuple[str, ...] = (
    # keycloak_*
    "keycloak_user",
    "keycloak_password",
    "keycloak_groups",
    # mail-related
    "mailu_app_password",
    "mailu_sync",
    "nextcloud_mail_sync",
    # per-application accounts / invites
    "wger_account",
    "firefly_account",
    "vaultwarden_invite",
)
def upsert_task(conn, request_code: str, task: str, status: str, detail: str | None = None) -> None:
"""Persist the latest status for one provisioning task.
WHY: provisioning is retried across requests, so task rows need to be
idempotent and update in place rather than accumulating duplicates.
"""
conn.execute(
"""
INSERT INTO access_request_tasks (request_code, task, status, detail, updated_at)
VALUES (%s, %s, %s, %s, NOW())
ON CONFLICT (request_code, task)
DO UPDATE SET status = EXCLUDED.status, detail = EXCLUDED.detail, updated_at = NOW()
""",
(request_code, task, status, detail),
)
def ensure_task_rows(conn, request_code: str, tasks: list[str]) -> None:
"""Create pending task rows for any provisioning work not yet tracked.
Args:
conn: Database connection with an ``execute`` method.
request_code: Access request identifier.
tasks: Task names that must exist before provisioning continues.
Returns:
None.
"""
if not tasks:
return
conn.execute(
"""
INSERT INTO access_request_tasks (request_code, task, status, detail, updated_at)
SELECT %s, task, 'pending', NULL, NOW()
FROM UNNEST(%s::text[]) AS task
ON CONFLICT (request_code, task) DO NOTHING
""",
(request_code, tasks),
)
def safe_error_detail(exc: Exception, fallback: str) -> str:
"""Return a bounded, operator-useful detail string for task failures.
WHY: task detail is shown back through the portal UI, so upstream errors
need to be specific enough to act on without dumping unbounded responses.
"""
if isinstance(exc, RuntimeError):
msg = str(exc).strip()
if msg:
return msg
if isinstance(exc, httpx.HTTPStatusError):
detail = f"http {exc.response.status_code}"
try:
payload = exc.response.json()
msg: str | None = None
if isinstance(payload, dict):
raw = payload.get("errorMessage") or payload.get("error") or payload.get("message")
if isinstance(raw, str) and raw.strip():
msg = raw.strip()
elif isinstance(payload, str) and payload.strip():
msg = payload.strip()
if msg:
msg = " ".join(msg.split())
detail = f"{detail}: {msg[:200]}"
except Exception:
text = (exc.response.text or "").strip()
if text:
text = " ".join(text.split())
detail = f"{detail}: {text[:200]}"
return detail
if isinstance(exc, httpx.TimeoutException):
return "timeout"
return fallback
def task_statuses(conn, request_code: str) -> dict[str, str]:
    """Load current task statuses for one access request, keyed by task name.

    Rows may be dicts with ``task``/``status`` keys (dict-style row factory)
    or positional rows (tuple/list, including namedtuples) in ``SELECT``
    column order ``task, status``. Any row whose two values are not both
    strings is skipped.

    Args:
        conn: Database connection whose ``execute`` returns an object with
            ``fetchall``.
        request_code: Access request identifier.

    Returns:
        Mapping of task name to its stored status string.
    """
    rows = conn.execute(
        "SELECT task, status FROM access_request_tasks WHERE request_code = %s",
        (request_code,),
    ).fetchall()
    output: dict[str, str] = {}
    for row in rows:
        if isinstance(row, dict):
            task = row.get("task")
            status = row.get("status")
        elif isinstance(row, (tuple, list)) and len(row) >= 2:
            # Positional rows arrive in SELECT order. Previously these were
            # silently dropped, which made every task look untracked.
            task, status = row[0], row[1]
        else:
            continue
        if isinstance(task, str) and isinstance(status, str):
            output[task] = status
    return output
def all_tasks_ok(conn, request_code: str, tasks: list[str]) -> bool:
    """Report whether each of ``tasks`` currently has status ``"ok"``.

    Statuses are loaded once via ``task_statuses``. A task with no stored
    row counts as not ok. An empty ``tasks`` list yields ``True``.
    """
    current = task_statuses(conn, request_code)
    return all(current.get(name) == "ok" for name in tasks)