123 lines
3.9 KiB
Python
123 lines
3.9 KiB
Python
|
|
"""Task-row helpers for access request provisioning.

Helpers that create, update, and read per-task status rows in the
``access_request_tasks`` table, plus formatting of task failure details.
"""

# The docstring must be the first statement to become the module ``__doc__``;
# ``from __future__`` imports may legally follow it.
from __future__ import annotations

import httpx
|
||
|
|
|
||
|
|
# Names of the provisioning tasks tracked per access request. Presumably these
# are the tasks that must all reach status "ok" for provisioning to count as
# complete (cf. ``all_tasks_ok``) — confirm against callers.
REQUIRED_PROVISION_TASKS: tuple[str, ...] = (
    # Keycloak-related tasks.
    "keycloak_user",
    "keycloak_password",
    "keycloak_groups",
    # Mail-related tasks.
    "mailu_app_password",
    "mailu_sync",
    "nextcloud_mail_sync",
    # Per-application account tasks.
    "wger_account",
    "firefly_account",
    "vaultwarden_invite",
)
|
||
|
|
|
||
|
|
|
||
|
|
def upsert_task(conn, request_code: str, task: str, status: str, detail: str | None = None) -> None:
|
||
|
|
"""Persist the latest status for one provisioning task.
|
||
|
|
|
||
|
|
WHY: provisioning is retried across requests, so task rows need to be
|
||
|
|
idempotent and update in place rather than accumulating duplicates.
|
||
|
|
"""
|
||
|
|
|
||
|
|
conn.execute(
|
||
|
|
"""
|
||
|
|
INSERT INTO access_request_tasks (request_code, task, status, detail, updated_at)
|
||
|
|
VALUES (%s, %s, %s, %s, NOW())
|
||
|
|
ON CONFLICT (request_code, task)
|
||
|
|
DO UPDATE SET status = EXCLUDED.status, detail = EXCLUDED.detail, updated_at = NOW()
|
||
|
|
""",
|
||
|
|
(request_code, task, status, detail),
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def ensure_task_rows(conn, request_code: str, tasks: list[str]) -> None:
|
||
|
|
"""Create pending task rows for any provisioning work not yet tracked.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
conn: Database connection with an ``execute`` method.
|
||
|
|
request_code: Access request identifier.
|
||
|
|
tasks: Task names that must exist before provisioning continues.
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
None.
|
||
|
|
"""
|
||
|
|
|
||
|
|
if not tasks:
|
||
|
|
return
|
||
|
|
conn.execute(
|
||
|
|
"""
|
||
|
|
INSERT INTO access_request_tasks (request_code, task, status, detail, updated_at)
|
||
|
|
SELECT %s, task, 'pending', NULL, NOW()
|
||
|
|
FROM UNNEST(%s::text[]) AS task
|
||
|
|
ON CONFLICT (request_code, task) DO NOTHING
|
||
|
|
""",
|
||
|
|
(request_code, tasks),
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def safe_error_detail(exc: Exception, fallback: str) -> str:
|
||
|
|
"""Return a bounded, operator-useful detail string for task failures.
|
||
|
|
|
||
|
|
WHY: task detail is shown back through the portal UI, so upstream errors
|
||
|
|
need to be specific enough to act on without dumping unbounded responses.
|
||
|
|
"""
|
||
|
|
|
||
|
|
if isinstance(exc, RuntimeError):
|
||
|
|
msg = str(exc).strip()
|
||
|
|
if msg:
|
||
|
|
return msg
|
||
|
|
if isinstance(exc, httpx.HTTPStatusError):
|
||
|
|
detail = f"http {exc.response.status_code}"
|
||
|
|
try:
|
||
|
|
payload = exc.response.json()
|
||
|
|
msg: str | None = None
|
||
|
|
if isinstance(payload, dict):
|
||
|
|
raw = payload.get("errorMessage") or payload.get("error") or payload.get("message")
|
||
|
|
if isinstance(raw, str) and raw.strip():
|
||
|
|
msg = raw.strip()
|
||
|
|
elif isinstance(payload, str) and payload.strip():
|
||
|
|
msg = payload.strip()
|
||
|
|
if msg:
|
||
|
|
msg = " ".join(msg.split())
|
||
|
|
detail = f"{detail}: {msg[:200]}"
|
||
|
|
except Exception:
|
||
|
|
text = (exc.response.text or "").strip()
|
||
|
|
if text:
|
||
|
|
text = " ".join(text.split())
|
||
|
|
detail = f"{detail}: {text[:200]}"
|
||
|
|
return detail
|
||
|
|
if isinstance(exc, httpx.TimeoutException):
|
||
|
|
return "timeout"
|
||
|
|
return fallback
|
||
|
|
|
||
|
|
|
||
|
|
def task_statuses(conn, request_code: str) -> dict[str, str]:
    """Load current task statuses keyed by task name.

    Rows may arrive in either of two forms:

    * dicts with ``task`` and ``status`` keys (e.g. from a dict row factory);
    * positional tuples/lists in ``SELECT`` column order ``(task, status)``
      (e.g. from a default tuple row factory).

    Previously, positional rows were silently dropped, which made every task
    look untracked. Rows whose task or status is not a string, and rows of any
    other shape, are skipped.

    Args:
        conn: Database connection whose ``execute(...)`` result has ``fetchall()``.
        request_code: Access request identifier.

    Returns:
        Mapping of task name to its stored status.
    """
    rows = conn.execute(
        "SELECT task, status FROM access_request_tasks WHERE request_code = %s",
        (request_code,),
    ).fetchall()
    output: dict[str, str] = {}
    for row in rows:
        if isinstance(row, dict):
            task, status = row.get("task"), row.get("status")
        elif isinstance(row, (tuple, list)) and len(row) >= 2:
            # Positional rows follow the column order of the SELECT above.
            task, status = row[0], row[1]
        else:
            continue
        if isinstance(task, str) and isinstance(status, str):
            output[task] = status
    return output
|
||
|
|
|
||
|
|
|
||
|
|
def all_tasks_ok(conn, request_code: str, tasks: list[str]) -> bool:
    """Report whether every task in ``tasks`` currently has status ``"ok"``.

    A task with no stored row for ``request_code`` counts as not ok. An empty
    ``tasks`` list yields ``True``. Statuses are always loaded once, even when
    ``tasks`` is empty.

    Args:
        conn: Database connection passed through to ``task_statuses``.
        request_code: Access request identifier.
        tasks: Task names to check.

    Returns:
        ``True`` only if each listed task is ``"ok"``.
    """
    current = task_statuses(conn, request_code)
    return all(current.get(name) == "ok" for name in tasks)
|