from __future__ import annotations """Task-row helpers for access request provisioning.""" import httpx REQUIRED_PROVISION_TASKS: tuple[str, ...] = ( "keycloak_user", "keycloak_password", "keycloak_groups", "mailu_app_password", "mailu_sync", "nextcloud_mail_sync", "wger_account", "firefly_account", "vaultwarden_invite", ) def upsert_task(conn, request_code: str, task: str, status: str, detail: str | None = None) -> None: """Persist the latest status for one provisioning task. WHY: provisioning is retried across requests, so task rows need to be idempotent and update in place rather than accumulating duplicates. """ conn.execute( """ INSERT INTO access_request_tasks (request_code, task, status, detail, updated_at) VALUES (%s, %s, %s, %s, NOW()) ON CONFLICT (request_code, task) DO UPDATE SET status = EXCLUDED.status, detail = EXCLUDED.detail, updated_at = NOW() """, (request_code, task, status, detail), ) def ensure_task_rows(conn, request_code: str, tasks: list[str]) -> None: """Create pending task rows for any provisioning work not yet tracked. Args: conn: Database connection with an ``execute`` method. request_code: Access request identifier. tasks: Task names that must exist before provisioning continues. Returns: None. """ if not tasks: return conn.execute( """ INSERT INTO access_request_tasks (request_code, task, status, detail, updated_at) SELECT %s, task, 'pending', NULL, NOW() FROM UNNEST(%s::text[]) AS task ON CONFLICT (request_code, task) DO NOTHING """, (request_code, tasks), ) def safe_error_detail(exc: Exception, fallback: str) -> str: """Return a bounded, operator-useful detail string for task failures. WHY: task detail is shown back through the portal UI, so upstream errors need to be specific enough to act on without dumping unbounded responses. """ if isinstance(exc, RuntimeError): msg = str(exc).strip() if msg: return msg if isinstance(exc, httpx.HTTPStatusError): detail = f"http {exc.response.status_code}" try: payload = exc.response.json() msg: str | None = None if isinstance(payload, dict): raw = payload.get("errorMessage") or payload.get("error") or payload.get("message") if isinstance(raw, str) and raw.strip(): msg = raw.strip() elif isinstance(payload, str) and payload.strip(): msg = payload.strip() if msg: msg = " ".join(msg.split()) detail = f"{detail}: {msg[:200]}" except Exception: text = (exc.response.text or "").strip() if text: text = " ".join(text.split()) detail = f"{detail}: {text[:200]}" return detail if isinstance(exc, httpx.TimeoutException): return "timeout" return fallback def task_statuses(conn, request_code: str) -> dict[str, str]: """Load current task statuses keyed by task name.""" rows = conn.execute( "SELECT task, status FROM access_request_tasks WHERE request_code = %s", (request_code,), ).fetchall() output: dict[str, str] = {} for row in rows: task = row.get("task") if isinstance(row, dict) else None status = row.get("status") if isinstance(row, dict) else None if isinstance(task, str) and isinstance(status, str): output[task] = status return output def all_tasks_ok(conn, request_code: str, tasks: list[str]) -> bool: """Return whether every required task is currently marked ``ok``.""" statuses = task_statuses(conn, request_code) for task in tasks: if statuses.get(task) != "ok": return False return True