From 03bf6f7d9b5c02dec17461978838adf4f776496f Mon Sep 17 00:00:00 2001 From: Brad Stein Date: Sat, 24 Jan 2026 07:12:11 -0300 Subject: [PATCH] feat: add retryable provisioning retries --- ariadne/app.py | 56 ++++++++++++++++++++++++++ ariadne/manager/provisioning.py | 71 ++++++++++++++++++++++++++++----- tests/test_app.py | 39 ++++++++++++++++++ tests/test_provisioning.py | 9 +++++ 4 files changed, 164 insertions(+), 11 deletions(-) diff --git a/ariadne/app.py b/ariadne/app.py index 04f1df1..94c36a6 100644 --- a/ariadne/app.py +++ b/ariadne/app.py @@ -628,6 +628,62 @@ async def deny_access_request( return JSONResponse({"ok": True, "request_code": row.get("request_code")}) +@app.post("/api/access/requests/{request_code}/retry") +def retry_access_request(request_code: str) -> JSONResponse: + code = (request_code or "").strip() + if not code: + raise HTTPException(status_code=400, detail="request_code is required") + if not keycloak_admin.ready(): + raise HTTPException(status_code=503, detail="server not configured") + + try: + row = portal_db.fetchone( + "SELECT status FROM access_requests WHERE request_code = %s", + (code,), + ) + except Exception: + raise HTTPException(status_code=502, detail="failed to load request") + + if not row: + raise HTTPException(status_code=404, detail="not found") + + status = (row.get("status") or "").strip() + if status not in {"accounts_building", "approved"}: + raise HTTPException(status_code=409, detail="request not retryable") + + try: + portal_db.execute( + "UPDATE access_requests SET provision_attempted_at = NULL WHERE request_code = %s", + (code,), + ) + portal_db.execute( + """ + UPDATE access_request_tasks + SET status = 'pending', + detail = 'retry requested', + updated_at = NOW() + WHERE request_code = %s AND status = 'error' + """, + (code,), + ) + except Exception: + raise HTTPException(status_code=502, detail="failed to update retry state") + + threading.Thread( + target=provisioning.provision_access_request, + args=(code,), + daemon=True, + ).start() + _record_event( + "access_request_retry", + { + "request_code": code, + "status": "ok", + }, + ) + return JSONResponse({"ok": True, "request_code": code}) + + @app.post("/api/account/mailu/rotate") def rotate_mailu_password(ctx: AuthContext = Depends(_require_auth)) -> JSONResponse: _require_account_access(ctx) diff --git a/ariadne/manager/provisioning.py b/ariadne/manager/provisioning.py index 3ab35c7..38ead9d 100644 --- a/ariadne/manager/provisioning.py +++ b/ariadne/manager/provisioning.py @@ -3,6 +3,7 @@ from __future__ import annotations from dataclasses import dataclass from datetime import datetime, timedelta, timezone import hashlib +import re import threading import time from typing import Any @@ -31,6 +32,21 @@ WGER_PASSWORD_UPDATED_ATTR = "wger_password_updated_at" FIREFLY_PASSWORD_ATTR = "firefly_password" FIREFLY_PASSWORD_UPDATED_ATTR = "firefly_password_updated_at" VAULTWARDEN_GRANDFATHERED_FLAG = "vaultwarden_grandfathered" +_RETRYABLE_HTTP_CODES = {429, 500, 502, 503, 504} +_RETRYABLE_TOKENS = ( + "timeout", + "temporar", + "rate limited", + "mailbox not ready", + "connection refused", + "connection reset", + "network is unreachable", + "dns", + "name resolution", + "service unavailable", + "bad gateway", + "gateway timeout", +) logger = get_logger(__name__) @@ -457,6 +473,39 @@ class ProvisioningManager: self._upsert_task(conn, request_code, task, "pending", detail) self._record_task(request_code, task, "pending", detail, started) + def _is_retryable_detail(self, detail: str) -> bool: + if not detail: + return False + detail_lower = detail.lower() + match = re.match(r"^http\s+(\d{3})", detail_lower) + if match: + try: + code = int(match.group(1)) + except ValueError: + code = 0 + if code in _RETRYABLE_HTTP_CODES: + return True + return any(token in detail_lower for token in _RETRYABLE_TOKENS) + + def _retryable_detail(self, detail: str) -> str: + cleaned = detail.strip() if isinstance(detail, str) else "" + if not cleaned: + return "retryable: temporary failure" + return f"retryable: {cleaned}" + + def _task_fail( + self, + conn, + request_code: str, + task: str, + detail: str, + started: datetime, + ) -> None: + if self._is_retryable_detail(detail): + self._task_pending(conn, request_code, task, self._retryable_detail(detail), started) + return + self._task_error(conn, request_code, task, detail, started) + def _vaultwarden_rate_limit_detail(self) -> tuple[str, datetime]: retry_at = datetime.now(timezone.utc) + timedelta( seconds=float(settings.vaultwarden_admin_rate_limit_backoff_sec) @@ -643,7 +692,7 @@ class ProvisioningManager: return True except Exception as exc: detail = safe_error_detail(exc, "failed to ensure user") - self._task_error(conn, ctx.request_code, "keycloak_user", detail, start) + self._task_fail(conn, ctx.request_code, "keycloak_user", detail, start) return False def _ensure_keycloak_password(self, conn, ctx: RequestContext) -> None: @@ -679,7 +728,7 @@ class ProvisioningManager: raise RuntimeError("initial password missing") except Exception as exc: detail = safe_error_detail(exc, "failed to set password") - self._task_error(conn, ctx.request_code, "keycloak_password", detail, start) + self._task_fail(conn, ctx.request_code, "keycloak_password", detail, start) def _ensure_keycloak_groups(self, conn, ctx: RequestContext) -> None: start = datetime.now(timezone.utc) @@ -694,7 +743,7 @@ class ProvisioningManager: self._task_ok(conn, ctx.request_code, "keycloak_groups", None, start) except Exception as exc: detail = safe_error_detail(exc, "failed to add groups") - self._task_error(conn, ctx.request_code, "keycloak_groups", detail, start) + self._task_fail(conn, ctx.request_code, "keycloak_groups", detail, start) def _ensure_mailu_app_password(self, conn, ctx: RequestContext) -> None: start = datetime.now(timezone.utc) @@ -707,7 +756,7 @@ class ProvisioningManager: self._task_ok(conn, ctx.request_code, "mailu_app_password", None, start) except Exception as exc: detail = safe_error_detail(exc, "failed to set mail password") - self._task_error(conn, ctx.request_code, "mailu_app_password", detail, start) + self._task_fail(conn, ctx.request_code, "mailu_app_password", detail, start) def _sync_mailu(self, conn, ctx: RequestContext) -> bool: start = datetime.now(timezone.utc) @@ -727,7 +776,7 @@ class ProvisioningManager: return True except Exception as exc: detail = safe_error_detail(exc, "failed to sync mailu") - self._task_error(conn, ctx.request_code, "mailu_sync", detail, start) + self._task_fail(conn, ctx.request_code, "mailu_sync", detail, start) return False def _sync_nextcloud_mail(self, conn, ctx: RequestContext) -> None: @@ -749,10 +798,10 @@ class ProvisioningManager: if not detail and isinstance(result, dict): detail = str(result.get("detail") or "") detail = detail or str(status_val) - self._task_error(conn, ctx.request_code, "nextcloud_mail_sync", detail, start) + self._task_fail(conn, ctx.request_code, "nextcloud_mail_sync", detail, start) except Exception as exc: detail = safe_error_detail(exc, "failed to sync nextcloud") - self._task_error(conn, ctx.request_code, "nextcloud_mail_sync", detail, start) + self._task_fail(conn, ctx.request_code, "nextcloud_mail_sync", detail, start) def _ensure_wger_account(self, conn, ctx: RequestContext) -> None: start = datetime.now(timezone.utc) @@ -779,7 +828,7 @@ class ProvisioningManager: self._task_ok(conn, ctx.request_code, "wger_account", None, start) except Exception as exc: detail = safe_error_detail(exc, "failed to provision wger") - self._task_error(conn, ctx.request_code, "wger_account", detail, start) + self._task_fail(conn, ctx.request_code, "wger_account", detail, start) def _ensure_firefly_account(self, conn, ctx: RequestContext) -> None: start = datetime.now(timezone.utc) @@ -804,7 +853,7 @@ class ProvisioningManager: self._task_ok(conn, ctx.request_code, "firefly_account", None, start) except Exception as exc: detail = safe_error_detail(exc, "failed to provision firefly") - self._task_error(conn, ctx.request_code, "firefly_account", detail, start) + self._task_fail(conn, ctx.request_code, "firefly_account", detail, start) def _handle_vaultwarden_grandfathered(self, conn, ctx: RequestContext, start: datetime) -> None: lookup = vaultwarden.find_user_by_email(ctx.contact_email) @@ -827,7 +876,7 @@ class ProvisioningManager: ) return detail = lookup.detail or lookup.status - self._task_error(conn, ctx.request_code, "vaultwarden_invite", detail, start) + self._task_fail(conn, ctx.request_code, "vaultwarden_invite", detail, start) def _ensure_vaultwarden_invite(self, conn, ctx: RequestContext) -> None: start = datetime.now(timezone.utc) @@ -859,7 +908,7 @@ class ProvisioningManager: self._set_vaultwarden_attrs(ctx.username, ctx.mailu_email, status) except Exception as exc: detail = safe_error_detail(exc, "failed to provision vaultwarden") - self._task_error(conn, ctx.request_code, "vaultwarden_invite", detail, start) + self._task_fail(conn, ctx.request_code, "vaultwarden_invite", detail, start) def _send_welcome_email(self, request_code: str, username: str, contact_email: str) -> None: if not settings.welcome_email_enabled: diff --git a/tests/test_app.py b/tests/test_app.py index b13f6c8..598e048 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -110,6 +110,45 @@ def test_account_access_allows_missing_groups(monkeypatch) -> None: assert resp.status_code != 403 +def test_retry_access_request_ok(monkeypatch) -> None: + ctx = AuthContext(username="", email="", groups=[], claims={}) + client = _client(monkeypatch, ctx) + executed = [] + invoked = {} + + monkeypatch.setattr(app_module.keycloak_admin, "ready", lambda: True) + monkeypatch.setattr(app_module.portal_db, "fetchone", lambda *_args, **_kwargs: {"status": "accounts_building"}) + monkeypatch.setattr(app_module.portal_db, "execute", lambda query, params=None: executed.append((query, params))) + monkeypatch.setattr(app_module.provisioning, "provision_access_request", lambda code: invoked.setdefault("code", code)) + monkeypatch.setattr(app_module, "_record_event", lambda *args, **kwargs: None) + + resp = client.post("/api/access/requests/REQ123/retry") + assert resp.status_code == 200 + assert resp.json()["request_code"] == "REQ123" + assert invoked["code"] == "REQ123" + assert any("provision_attempted_at" in query for query, _params in executed) + + +def test_retry_access_request_not_found(monkeypatch) -> None: + ctx = AuthContext(username="", email="", groups=[], claims={}) + client = _client(monkeypatch, ctx) + monkeypatch.setattr(app_module.keycloak_admin, "ready", lambda: True) + monkeypatch.setattr(app_module.portal_db, "fetchone", lambda *_args, **_kwargs: None) + + resp = client.post("/api/access/requests/REQ123/retry") + assert resp.status_code == 404 + + +def test_retry_access_request_not_retryable(monkeypatch) -> None: + ctx = AuthContext(username="", email="", groups=[], claims={}) + client = _client(monkeypatch, ctx) + monkeypatch.setattr(app_module.keycloak_admin, "ready", lambda: True) + monkeypatch.setattr(app_module.portal_db, "fetchone", lambda *_args, **_kwargs: {"status": "ready"}) + + resp = client.post("/api/access/requests/REQ123/retry") + assert resp.status_code == 409 + + def test_metrics_endpoint(monkeypatch) -> None: ctx = AuthContext(username="", email="", groups=[], claims={}) client = _client(monkeypatch, ctx) diff --git a/tests/test_provisioning.py b/tests/test_provisioning.py index 9bdf34d..7879f47 100644 --- a/tests/test_provisioning.py +++ b/tests/test_provisioning.py @@ -809,6 +809,15 @@ def test_provisioning_task_helpers() -> None: assert manager._all_tasks_ok(Conn(), "REQ", ["b"]) is False +def test_provisioning_retryable_detail_detection() -> None: + manager = prov.ProvisioningManager(DummyDB({}, locked=True), DummyStorage()) + assert manager._is_retryable_detail("timeout") is True + assert manager._is_retryable_detail("http 503: service unavailable") is True + assert manager._is_retryable_detail("mailbox not ready") is True + assert manager._is_retryable_detail("invalid credentials") is False + assert manager._retryable_detail("timeout").startswith("retryable:") + + def test_provisioning_ensure_task_rows_empty() -> None: manager = prov.ProvisioningManager(DummyDB({}), DummyStorage()) manager._ensure_task_rows(DummyConn({}, locked=True), "REQ", [])