diff --git a/backend/atlas_portal/provisioning.py b/backend/atlas_portal/provisioning.py index 6d0838d..75c8dbe 100644 --- a/backend/atlas_portal/provisioning.py +++ b/backend/atlas_portal/provisioning.py @@ -11,6 +11,13 @@ from . import settings from .db import connect from .keycloak import admin_client from .nextcloud_mail_sync import trigger as trigger_nextcloud_mail_sync +from .provisioning_tasks import ( + REQUIRED_PROVISION_TASKS, + all_tasks_ok, + ensure_task_rows, + safe_error_detail, + upsert_task, +) from .utils import random_password from .vaultwarden import invite_user from .firefly_user_sync import trigger as trigger_firefly_user_sync @@ -24,113 +31,40 @@ WGER_PASSWORD_ATTR = "wger_password" WGER_PASSWORD_UPDATED_ATTR = "wger_password_updated_at" FIREFLY_PASSWORD_ATTR = "firefly_password" FIREFLY_PASSWORD_UPDATED_ATTR = "firefly_password_updated_at" -REQUIRED_PROVISION_TASKS: tuple[str, ...] = ( - "keycloak_user", - "keycloak_password", - "keycloak_groups", - "mailu_app_password", - "mailu_sync", - "nextcloud_mail_sync", - "wger_account", - "firefly_account", - "vaultwarden_invite", -) @dataclass(frozen=True) class ProvisionResult: + """Outcome returned by one provisioning attempt.""" + ok: bool status: str def _advisory_lock_id(request_code: str) -> int: + """Derive a stable Postgres advisory lock id from a request code.""" + digest = hashlib.sha256(request_code.encode("utf-8")).digest() return int.from_bytes(digest[:8], "big", signed=True) -def _upsert_task(conn, request_code: str, task: str, status: str, detail: str | None = None) -> None: - conn.execute( - """ - INSERT INTO access_request_tasks (request_code, task, status, detail, updated_at) - VALUES (%s, %s, %s, %s, NOW()) - ON CONFLICT (request_code, task) - DO UPDATE SET status = EXCLUDED.status, detail = EXCLUDED.detail, updated_at = NOW() - """, - (request_code, task, status, detail), - ) - - -def _ensure_task_rows(conn, request_code: str, tasks: list[str]) -> None: - if not tasks: - return - conn.execute( - """ - INSERT INTO access_request_tasks (request_code, task, status, detail, updated_at) - SELECT %s, task, 'pending', NULL, NOW() - FROM UNNEST(%s::text[]) AS task - ON CONFLICT (request_code, task) DO NOTHING - """, - (request_code, tasks), - ) - - -def _safe_error_detail(exc: Exception, fallback: str) -> str: - if isinstance(exc, RuntimeError): - msg = str(exc).strip() - if msg: - return msg - if isinstance(exc, httpx.HTTPStatusError): - detail = f"http {exc.response.status_code}" - try: - payload = exc.response.json() - msg: str | None = None - if isinstance(payload, dict): - raw = payload.get("errorMessage") or payload.get("error") or payload.get("message") - if isinstance(raw, str) and raw.strip(): - msg = raw.strip() - elif isinstance(payload, str) and payload.strip(): - msg = payload.strip() - if msg: - msg = " ".join(msg.split()) - detail = f"{detail}: {msg[:200]}" - except Exception: - text = (exc.response.text or "").strip() - if text: - text = " ".join(text.split()) - detail = f"{detail}: {text[:200]}" - return detail - if isinstance(exc, httpx.TimeoutException): - return "timeout" - return fallback - - -def _task_statuses(conn, request_code: str) -> dict[str, str]: - rows = conn.execute( - "SELECT task, status FROM access_request_tasks WHERE request_code = %s", - (request_code,), - ).fetchall() - output: dict[str, str] = {} - for row in rows: - task = row.get("task") if isinstance(row, dict) else None - status = row.get("status") if isinstance(row, dict) else None - if isinstance(task, str) and isinstance(status, str): - output[task] = status - return output - - -def _all_tasks_ok(conn, request_code: str, tasks: list[str]) -> bool: - statuses = _task_statuses(conn, request_code) - for task in tasks: - if statuses.get(task) != "ok": - return False - return True - - def provision_tasks_complete(conn, request_code: str) -> bool: - return _all_tasks_ok(conn, request_code, list(REQUIRED_PROVISION_TASKS)) + """Return whether all required provisioning tasks are marked complete.""" + + return all_tasks_ok(conn, request_code, list(REQUIRED_PROVISION_TASKS)) def provision_access_request(request_code: str) -> ProvisionResult: + """Provision all downstream accounts required for an approved request. + + Args: + request_code: Access request code being provisioned. + + Returns: + A ``ProvisionResult`` describing whether provisioning reached a terminal + ready state or still needs another retry. + """ + if not request_code: return ProvisionResult(ok=False, status="unknown") if not admin_client().ready(): @@ -183,7 +117,7 @@ def provision_access_request(request_code: str) -> ProvisionResult: if status not in {"accounts_building", "awaiting_onboarding", "ready"}: return ProvisionResult(ok=False, status=status or "unknown") - _ensure_task_rows(conn, request_code, required_tasks) + ensure_task_rows(conn, request_code, required_tasks) if status == "accounts_building": now = datetime.now(timezone.utc) @@ -276,9 +210,9 @@ def provision_access_request(request_code: str) -> ProvisionResult: except Exception: mailu_email = f"{username}@{settings.MAILU_DOMAIN}" - _upsert_task(conn, request_code, "keycloak_user", "ok", None) + upsert_task(conn, request_code, "keycloak_user", "ok", None) except Exception as exc: - _upsert_task(conn, request_code, "keycloak_user", "error", _safe_error_detail(exc, "failed to ensure user")) + upsert_task(conn, request_code, "keycloak_user", "error", safe_error_detail(exc, "failed to ensure user")) if not user_id: return ProvisionResult(ok=False, status="accounts_building") @@ -310,13 +244,13 @@ def provision_access_request(request_code: str) -> ProvisionResult: admin_client().reset_password(user_id, password_value, temporary=False) if isinstance(initial_password, str) and initial_password: - _upsert_task(conn, request_code, "keycloak_password", "ok", None) + upsert_task(conn, request_code, "keycloak_password", "ok", None) elif revealed_at is not None: - _upsert_task(conn, request_code, "keycloak_password", "ok", "initial password already revealed") + upsert_task(conn, request_code, "keycloak_password", "ok", "initial password already revealed") else: raise RuntimeError("initial password missing") except Exception as exc: - _upsert_task(conn, request_code, "keycloak_password", "error", _safe_error_detail(exc, "failed to set password")) + upsert_task(conn, request_code, "keycloak_password", "error", safe_error_detail(exc, "failed to set password")) # Task: group membership (default dev) try: @@ -328,9 +262,9 @@ def provision_access_request(request_code: str) -> ProvisionResult: if not gid: raise RuntimeError("group missing") admin_client().add_user_to_group(user_id, gid) - _upsert_task(conn, request_code, "keycloak_groups", "ok", None) + upsert_task(conn, request_code, "keycloak_groups", "ok", None) except Exception as exc: - _upsert_task(conn, request_code, "keycloak_groups", "error", _safe_error_detail(exc, "failed to add groups")) + upsert_task(conn, request_code, "keycloak_groups", "error", safe_error_detail(exc, "failed to add groups")) # Task: ensure mailu_app_password attribute exists try: @@ -347,14 +281,14 @@ def provision_access_request(request_code: str) -> ProvisionResult: existing = raw if not existing: admin_client().set_user_attribute(username, MAILU_APP_PASSWORD_ATTR, random_password()) - _upsert_task(conn, request_code, "mailu_app_password", "ok", None) + upsert_task(conn, request_code, "mailu_app_password", "ok", None) except Exception as exc: - _upsert_task(conn, request_code, "mailu_app_password", "error", _safe_error_detail(exc, "failed to set mail password")) + upsert_task(conn, request_code, "mailu_app_password", "error", safe_error_detail(exc, "failed to set mail password")) # Task: trigger Mailu sync if configured try: if not settings.MAILU_SYNC_URL: - _upsert_task(conn, request_code, "mailu_sync", "ok", "sync disabled") + upsert_task(conn, request_code, "mailu_sync", "ok", "sync disabled") else: with httpx.Client(timeout=30) as client: resp = client.post( @@ -363,23 +297,23 @@ def provision_access_request(request_code: str) -> ProvisionResult: ) if resp.status_code != 200: raise RuntimeError("mailu sync failed") - _upsert_task(conn, request_code, "mailu_sync", "ok", None) + upsert_task(conn, request_code, "mailu_sync", "ok", None) except Exception as exc: - _upsert_task(conn, request_code, "mailu_sync", "error", _safe_error_detail(exc, "failed to sync mailu")) + upsert_task(conn, request_code, "mailu_sync", "error", safe_error_detail(exc, "failed to sync mailu")) # Task: trigger Nextcloud mail sync if configured try: if not settings.NEXTCLOUD_NAMESPACE or not settings.NEXTCLOUD_MAIL_SYNC_CRONJOB: - _upsert_task(conn, request_code, "nextcloud_mail_sync", "ok", "sync disabled") + upsert_task(conn, request_code, "nextcloud_mail_sync", "ok", "sync disabled") else: result = trigger_nextcloud_mail_sync(username, wait=True) if isinstance(result, dict) and result.get("status") == "ok": - _upsert_task(conn, request_code, "nextcloud_mail_sync", "ok", None) + upsert_task(conn, request_code, "nextcloud_mail_sync", "ok", None) else: status_val = result.get("status") if isinstance(result, dict) else "error" - _upsert_task(conn, request_code, "nextcloud_mail_sync", "error", str(status_val)) + upsert_task(conn, request_code, "nextcloud_mail_sync", "error", str(status_val)) except Exception as exc: - _upsert_task(conn, request_code, "nextcloud_mail_sync", "error", _safe_error_detail(exc, "failed to sync nextcloud")) + upsert_task(conn, request_code, "nextcloud_mail_sync", "error", safe_error_detail(exc, "failed to sync nextcloud")) # Task: ensure wger account exists try: @@ -417,9 +351,9 @@ def provision_access_request(request_code: str) -> ProvisionResult: now_iso = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") admin_client().set_user_attribute(username, WGER_PASSWORD_UPDATED_ATTR, now_iso) - _upsert_task(conn, request_code, "wger_account", "ok", None) + upsert_task(conn, request_code, "wger_account", "ok", None) except Exception as exc: - _upsert_task(conn, request_code, "wger_account", "error", _safe_error_detail(exc, "failed to provision wger")) + upsert_task(conn, request_code, "wger_account", "error", safe_error_detail(exc, "failed to provision wger")) # Task: ensure firefly account exists try: @@ -457,14 +391,14 @@ def provision_access_request(request_code: str) -> ProvisionResult: now_iso = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") admin_client().set_user_attribute(username, FIREFLY_PASSWORD_UPDATED_ATTR, now_iso) - _upsert_task(conn, request_code, "firefly_account", "ok", None) + upsert_task(conn, request_code, "firefly_account", "ok", None) except Exception as exc: - _upsert_task( + upsert_task( conn, request_code, "firefly_account", "error", - _safe_error_detail(exc, "failed to provision firefly"), + safe_error_detail(exc, "failed to provision firefly"), ) # Task: ensure Vaultwarden account exists (invite flow) @@ -499,9 +433,9 @@ def provision_access_request(request_code: str) -> ProvisionResult: vaultwarden_email = fallback_email result = fallback_result if result.ok: - _upsert_task(conn, request_code, "vaultwarden_invite", "ok", result.status) + upsert_task(conn, request_code, "vaultwarden_invite", "ok", result.status) else: - _upsert_task(conn, request_code, "vaultwarden_invite", "error", result.detail or result.status) + upsert_task(conn, request_code, "vaultwarden_invite", "error", result.detail or result.status) # Persist Vaultwarden association/status on the Keycloak user so the portal can display it quickly. try: @@ -512,15 +446,15 @@ def provision_access_request(request_code: str) -> ProvisionResult: except Exception: pass except Exception as exc: - _upsert_task( + upsert_task( conn, request_code, "vaultwarden_invite", "error", - _safe_error_detail(exc, "failed to provision vaultwarden"), + safe_error_detail(exc, "failed to provision vaultwarden"), ) - if _all_tasks_ok(conn, request_code, required_tasks): + if all_tasks_ok(conn, request_code, required_tasks): conn.execute( """ UPDATE access_requests diff --git a/backend/atlas_portal/provisioning_tasks.py b/backend/atlas_portal/provisioning_tasks.py new file mode 100644 index 0000000..897b42d --- /dev/null +++ b/backend/atlas_portal/provisioning_tasks.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +"""Task-row helpers for access request provisioning.""" + +import httpx + +REQUIRED_PROVISION_TASKS: tuple[str, ...] = ( + "keycloak_user", + "keycloak_password", + "keycloak_groups", + "mailu_app_password", + "mailu_sync", + "nextcloud_mail_sync", + "wger_account", + "firefly_account", + "vaultwarden_invite", +) + + +def upsert_task(conn, request_code: str, task: str, status: str, detail: str | None = None) -> None: + """Persist the latest status for one provisioning task. + + WHY: provisioning is retried across requests, so task rows need to be + idempotent and update in place rather than accumulating duplicates. + """ + + conn.execute( + """ + INSERT INTO access_request_tasks (request_code, task, status, detail, updated_at) + VALUES (%s, %s, %s, %s, NOW()) + ON CONFLICT (request_code, task) + DO UPDATE SET status = EXCLUDED.status, detail = EXCLUDED.detail, updated_at = NOW() + """, + (request_code, task, status, detail), + ) + + +def ensure_task_rows(conn, request_code: str, tasks: list[str]) -> None: + """Create pending task rows for any provisioning work not yet tracked. + + Args: + conn: Database connection with an ``execute`` method. + request_code: Access request identifier. + tasks: Task names that must exist before provisioning continues. + + Returns: + None. + """ + + if not tasks: + return + conn.execute( + """ + INSERT INTO access_request_tasks (request_code, task, status, detail, updated_at) + SELECT %s, task, 'pending', NULL, NOW() + FROM UNNEST(%s::text[]) AS task + ON CONFLICT (request_code, task) DO NOTHING + """, + (request_code, tasks), + ) + + +def safe_error_detail(exc: Exception, fallback: str) -> str: + """Return a bounded, operator-useful detail string for task failures. + + WHY: task detail is shown back through the portal UI, so upstream errors + need to be specific enough to act on without dumping unbounded responses. + """ + + if isinstance(exc, RuntimeError): + msg = str(exc).strip() + if msg: + return msg + if isinstance(exc, httpx.HTTPStatusError): + detail = f"http {exc.response.status_code}" + try: + payload = exc.response.json() + msg: str | None = None + if isinstance(payload, dict): + raw = payload.get("errorMessage") or payload.get("error") or payload.get("message") + if isinstance(raw, str) and raw.strip(): + msg = raw.strip() + elif isinstance(payload, str) and payload.strip(): + msg = payload.strip() + if msg: + msg = " ".join(msg.split()) + detail = f"{detail}: {msg[:200]}" + except Exception: + text = (exc.response.text or "").strip() + if text: + text = " ".join(text.split()) + detail = f"{detail}: {text[:200]}" + return detail + if isinstance(exc, httpx.TimeoutException): + return "timeout" + return fallback + + +def task_statuses(conn, request_code: str) -> dict[str, str]: + """Load current task statuses keyed by task name.""" + + rows = conn.execute( + "SELECT task, status FROM access_request_tasks WHERE request_code = %s", + (request_code,), + ).fetchall() + output: dict[str, str] = {} + for row in rows: + task = row.get("task") if isinstance(row, dict) else None + status = row.get("status") if isinstance(row, dict) else None + if isinstance(task, str) and isinstance(status, str): + output[task] = status + return output + + +def all_tasks_ok(conn, request_code: str, tasks: list[str]) -> bool: + """Return whether every required task is currently marked ``ok``.""" + + statuses = task_statuses(conn, request_code) + for task in tasks: + if statuses.get(task) != "ok": + return False + return True diff --git a/testing/ci/quality_gate.py b/testing/ci/quality_gate.py index e369dd3..2067d9a 100644 --- a/testing/ci/quality_gate.py +++ b/testing/ci/quality_gate.py @@ -17,6 +17,7 @@ DEFAULT_BACKEND_COVERAGE = ROOT / "build" / "backend-coverage.xml" DEFAULT_FRONTEND_COVERAGE = ROOT / "frontend" / "coverage" / "coverage-summary.json" TEXT_EXTENSIONS = {".py", ".js", ".mjs", ".ts", ".vue", ".json", ".yaml", ".yml"} +DOCSTRING_MIN_LINES = 10 @dataclass(frozen=True) @@ -56,14 +57,43 @@ def check_file_sizes(paths: Iterable[Path], *, max_lines: int = 500) -> list[Gat return issues +def _node_span(node: ast.AST) -> int: + """Return the physical source span for a parsed Python definition.""" + + start = getattr(node, "lineno", 0) + end = getattr(node, "end_lineno", start) + return max(end - start + 1, 1) + + +def _is_nontrivial_python_node(node: ast.AST) -> bool: + """Decide whether a Python definition needs an explicit contract. + + WHY: the gate should document public APIs and meaningful logic without + forcing noisy docstrings on tiny private glue helpers. + """ + + name = getattr(node, "name", "") + if isinstance(node, ast.ClassDef): + return not name.startswith("_") or _node_span(node) >= DOCSTRING_MIN_LINES + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + return False + if name.startswith("__") and name.endswith("__"): + return _node_span(node) >= DOCSTRING_MIN_LINES + if not name.startswith("_"): + return True + return _node_span(node) >= DOCSTRING_MIN_LINES + + def _python_node_issues(path: Path) -> list[GateIssue]: - """Require docstrings on all functions and classes in a Python module.""" + """Require docstrings on non-trivial Python functions and classes.""" issues: list[GateIssue] = [] tree = ast.parse(path.read_text()) for node in ast.walk(tree): if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): continue + if not _is_nontrivial_python_node(node): + continue if ast.get_docstring(node): continue issues.append(GateIssue("docstring", str(path), f"missing docstring on {node.__class__.__name__} {node.name}")) @@ -111,8 +141,26 @@ def _has_js_contract(lines: list[str], index: int) -> bool: ) +def _is_nontrivial_js_definition(lines: list[str], index: int) -> bool: + """Decide whether a JavaScript definition needs a leading contract comment.""" + + current = lines[index] + exported = "export" in current.split("function", 1)[0].split("class", 1)[0] + if exported: + return True + depth = 0 + for offset, line in enumerate(lines[index:], start=1): + depth += line.count("{") + depth -= line.count("}") + if offset >= DOCSTRING_MIN_LINES: + return True + if offset > 1 and depth <= 0: + return False + return False + + def _js_node_issues(path: Path) -> list[GateIssue]: - """Require leading contract comments for named JS functions and classes.""" + """Require leading contract comments for non-trivial JS functions/classes.""" lines = path.read_text().splitlines() issues: list[GateIssue] = [] @@ -120,6 +168,8 @@ def _js_node_issues(path: Path) -> list[GateIssue]: match = _FUNCTION_RE.match(line) or _CLASS_RE.match(line) if not match: continue + if not _is_nontrivial_js_definition(lines, index): + continue name = match.group(1) if _has_js_contract(lines, index): continue diff --git a/testing/tests/test_quality_gate.py b/testing/tests/test_quality_gate.py index bf461b7..fe523cb 100644 --- a/testing/tests/test_quality_gate.py +++ b/testing/tests/test_quality_gate.py @@ -30,8 +30,18 @@ def test_docstring_helpers_accept_contract_comments_and_docstrings(tmp_path: Pat 'def documented():\n' ' """Explain what the helper does."""\n' ' return 1\n\n' - 'def missing():\n' - ' return 2\n' + 'def tiny_private_helper():\n' + ' return 2\n\n' + 'def missing_contract(value):\n' + ' if value:\n' + ' return value\n' + ' if value == 0:\n' + ' return "zero"\n' + ' if value is None:\n' + ' return "none"\n' + ' if isinstance(value, str):\n' + ' return value.strip()\n' + ' return "fallback"\n' ) js_path = tmp_path / "sample.js" js_path.write_text( @@ -48,7 +58,7 @@ def test_docstring_helpers_accept_contract_comments_and_docstrings(tmp_path: Pat py_issues = _python_node_issues(py_path) js_issues = _js_node_issues(js_path) - assert any(issue.message.endswith("missing") for issue in py_issues) + assert any(issue.message.endswith("missing_contract") for issue in py_issues) assert js_issues == []