diff --git a/ariadne/manager/provisioning.py b/ariadne/manager/provisioning.py index c202032..3ab35c7 100644 --- a/ariadne/manager/provisioning.py +++ b/ariadne/manager/provisioning.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone import hashlib import threading import time @@ -30,6 +30,7 @@ WGER_PASSWORD_ATTR = "wger_password" WGER_PASSWORD_UPDATED_ATTR = "wger_password_updated_at" FIREFLY_PASSWORD_ATTR = "firefly_password" FIREFLY_PASSWORD_UPDATED_ATTR = "firefly_password_updated_at" +VAULTWARDEN_GRANDFATHERED_FLAG = "vaultwarden_grandfathered" logger = get_logger(__name__) @@ -445,6 +446,70 @@ class ProvisioningManager: self._upsert_task(conn, request_code, task, "error", detail) self._record_task(request_code, task, "error", detail, started) + def _task_pending( + self, + conn, + request_code: str, + task: str, + detail: str, + started: datetime, + ) -> None: + self._upsert_task(conn, request_code, task, "pending", detail) + self._record_task(request_code, task, "pending", detail, started) + + def _vaultwarden_rate_limit_detail(self) -> tuple[str, datetime]: + retry_at = datetime.now(timezone.utc) + timedelta( + seconds=float(settings.vaultwarden_admin_rate_limit_backoff_sec) + ) + retry_iso = retry_at.strftime("%Y-%m-%dT%H:%M:%SZ") + return f"rate limited until {retry_iso}", retry_at + + @staticmethod + def _parse_retry_at(detail: str) -> datetime | None: + prefix = "rate limited until " + if not isinstance(detail, str) or not detail.startswith(prefix): + return None + ts = detail[len(prefix) :].strip() + for fmt in ("%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M:%S%z"): + try: + parsed = datetime.strptime(ts, fmt) + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) + return parsed + except ValueError: + continue + return None + + def _vaultwarden_retry_due(self, conn, request_code: str) -> bool: + row = conn.execute( + """ + SELECT status, detail + FROM access_request_tasks + WHERE request_code = %s AND task = 'vaultwarden_invite' + """, + (request_code,), + ).fetchone() + if not isinstance(row, dict): + return True + if row.get("status") != "pending": + return True + retry_at = self._parse_retry_at(row.get("detail") or "") + if not retry_at: + return True + return datetime.now(timezone.utc) >= retry_at + + @staticmethod + def _set_vaultwarden_attrs(username: str, email: str, status: str) -> None: + if not username or not email or not status: + return + try: + now_iso = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + keycloak_admin.set_user_attribute(username, "vaultwarden_email", email) + keycloak_admin.set_user_attribute(username, "vaultwarden_status", status) + keycloak_admin.set_user_attribute(username, "vaultwarden_synced_at", now_iso) + except Exception: + return + def _ready_for_retry(self, ctx: RequestContext) -> bool: if ctx.status != "accounts_building": return True @@ -741,9 +806,37 @@ class ProvisioningManager: detail = safe_error_detail(exc, "failed to provision firefly") self._task_error(conn, ctx.request_code, "firefly_account", detail, start) + def _handle_vaultwarden_grandfathered(self, conn, ctx: RequestContext, start: datetime) -> None: + lookup = vaultwarden.find_user_by_email(ctx.contact_email) + if lookup.status == "rate_limited": + detail, _ = self._vaultwarden_rate_limit_detail() + self._task_pending(conn, ctx.request_code, "vaultwarden_invite", detail, start) + self._set_vaultwarden_attrs(ctx.username, ctx.contact_email, "rate_limited") + return + if lookup.ok and lookup.status == "present": + self._task_ok(conn, ctx.request_code, "vaultwarden_invite", "grandfathered", start) + self._set_vaultwarden_attrs(ctx.username, ctx.contact_email, "grandfathered") + return + if lookup.ok and lookup.status == "missing": + self._task_error( + conn, + ctx.request_code, + "vaultwarden_invite", + "vaultwarden account not found for recovery email", + start, + ) + return + detail = lookup.detail or lookup.status + self._task_error(conn, ctx.request_code, "vaultwarden_invite", detail, start) + def _ensure_vaultwarden_invite(self, conn, ctx: RequestContext) -> None: start = datetime.now(timezone.utc) try: + if not self._vaultwarden_retry_due(conn, ctx.request_code): + return + if VAULTWARDEN_GRANDFATHERED_FLAG in ctx.approval_flags: + self._handle_vaultwarden_grandfathered(conn, ctx, start) + return if not mailu.wait_for_mailbox(ctx.mailu_email, settings.mailu_mailbox_wait_timeout_sec): try: mailu.sync(reason="ariadne_vaultwarden_retry", force=True) @@ -755,17 +848,15 @@ class ProvisioningManager: result = vaultwarden.invite_user(ctx.mailu_email) if result.ok: self._task_ok(conn, ctx.request_code, "vaultwarden_invite", result.status, start) + elif result.status == "rate_limited": + detail, _ = self._vaultwarden_rate_limit_detail() + self._task_pending(conn, ctx.request_code, "vaultwarden_invite", detail, start) else: detail = result.detail or result.status self._task_error(conn, ctx.request_code, "vaultwarden_invite", detail, start) - try: - now_iso = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") - keycloak_admin.set_user_attribute(ctx.username, "vaultwarden_email", ctx.mailu_email) - keycloak_admin.set_user_attribute(ctx.username, "vaultwarden_status", result.status) - keycloak_admin.set_user_attribute(ctx.username, "vaultwarden_synced_at", now_iso) - except Exception: - pass + status = result.status if result.status != "rate_limited" else "rate_limited" + self._set_vaultwarden_attrs(ctx.username, ctx.mailu_email, status) except Exception as exc: detail = safe_error_detail(exc, "failed to provision vaultwarden") self._task_error(conn, ctx.request_code, "vaultwarden_invite", detail, start) diff --git a/ariadne/services/vaultwarden.py b/ariadne/services/vaultwarden.py index 8f34222..b7c5f6c 100644 --- a/ariadne/services/vaultwarden.py +++ b/ariadne/services/vaultwarden.py @@ -25,6 +25,13 @@ class VaultwardenInvite: detail: str = "" +@dataclass(frozen=True) +class VaultwardenLookup: + ok: bool + status: str + detail: str = "" + + class VaultwardenService: def __init__(self) -> None: self._admin_lock = threading.Lock() @@ -102,6 +109,33 @@ class VaultwardenService: result = VaultwardenInvite(ok=False, status="error", detail=message) return result + def _lookup_via(self, base_url: str, email: str) -> VaultwardenLookup | None: + if not base_url: + return None + try: + session = self._admin_session(base_url) + resp = session.get("/admin/users") + if resp.status_code == HTTP_TOO_MANY_REQUESTS: + self._rate_limited_until = time.time() + float(settings.vaultwarden_admin_rate_limit_backoff_sec) + return VaultwardenLookup(ok=False, status="rate_limited", detail="vaultwarden rate limited") + resp.raise_for_status() + users = resp.json() + if not isinstance(users, list): + return VaultwardenLookup(ok=False, status="error", detail="unexpected users response") + target = email.lower() + for entry in users: + if not isinstance(entry, dict): + continue + user_email = entry.get("email") + if isinstance(user_email, str) and user_email.lower() == target: + return VaultwardenLookup(ok=True, status="present", detail="user found") + return VaultwardenLookup(ok=True, status="missing", detail="user missing") + except Exception as exc: + message = str(exc) + if "rate limited" in message.lower(): + return VaultwardenLookup(ok=False, status="rate_limited", detail="vaultwarden rate limited") + return VaultwardenLookup(ok=False, status="error", detail=message) + def invite_user(self, email: str) -> VaultwardenInvite: email = self._normalize_email(email) if not email: @@ -122,6 +156,30 @@ class VaultwardenService: return VaultwardenInvite(ok=False, status="error", detail=last_error or "failed to invite") + def find_user_by_email(self, email: str) -> VaultwardenLookup: + email = self._normalize_email(email) + if not email: + return VaultwardenLookup(ok=False, status="invalid_email", detail="email invalid") + if self._rate_limited_until and time.time() < self._rate_limited_until: + return VaultwardenLookup(ok=False, status="rate_limited", detail="vaultwarden rate limited") + + last_error = "" + for candidate in self._candidate_urls(): + result = self._lookup_via(candidate, email) + if not result: + continue + if result.ok: + return result + if result.status == "rate_limited": + return result + last_error = result.detail or last_error + + return VaultwardenLookup( + ok=False, + status="error", + detail=last_error or "failed to lookup user", + ) + def _admin_session(self, base_url: str) -> httpx.Client: now = time.time() with self._admin_lock: diff --git a/tests/test_provisioning.py b/tests/test_provisioning.py index 0a30714..68bed6c 100644 --- a/tests/test_provisioning.py +++ b/tests/test_provisioning.py @@ -5,7 +5,7 @@ from datetime import datetime, timezone import types from ariadne.manager import provisioning as prov -from ariadne.services.vaultwarden import VaultwardenInvite +from ariadne.services.vaultwarden import VaultwardenInvite, VaultwardenLookup class DummyResult: @@ -1507,6 +1507,110 @@ def test_provisioning_vaultwarden_attribute_failure(monkeypatch) -> None: assert outcome.status == "accounts_building" +def test_provisioning_vaultwarden_rate_limited(monkeypatch) -> None: + dummy_settings = types.SimpleNamespace( + mailu_domain="bstein.dev", + mailu_sync_url="", + mailu_mailbox_wait_timeout_sec=1.0, + vaultwarden_admin_rate_limit_backoff_sec=600.0, + provision_retry_cooldown_sec=0.0, + ) + monkeypatch.setattr(prov, "settings", dummy_settings) + monkeypatch.setattr(prov, "keycloak_admin", DummyAdmin()) + monkeypatch.setattr(prov.mailu, "wait_for_mailbox", lambda *_args, **_kwargs: True) + monkeypatch.setattr( + prov.vaultwarden, + "invite_user", + lambda _email: VaultwardenInvite(False, "rate_limited", "vaultwarden rate limited"), + ) + + row = { + "username": "alice", + "contact_email": "alice@example.com", + "email_verified_at": datetime.now(timezone.utc), + "status": "accounts_building", + "initial_password": "temp", + "initial_password_revealed_at": None, + "provision_attempted_at": None, + "approval_flags": [], + } + manager = prov.ProvisioningManager(DummyDB(row), DummyStorage()) + ctx = prov.RequestContext( + request_code="REQ_RATE", + username="alice", + first_name="", + last_name="", + contact_email="alice@example.com", + email_verified_at=row["email_verified_at"], + status="accounts_building", + initial_password="temp", + revealed_at=None, + attempted_at=None, + approval_flags=[], + user_id="1", + mailu_email="alice@bstein.dev", + ) + conn = DummyConn(row) + manager._ensure_vaultwarden_invite(conn, ctx) + + inserts = [params for query, params in conn.executed if "access_request_tasks" in query] + assert any(params[2] == "pending" and "rate limited until" in (params[3] or "") for params in inserts) + + +def test_provisioning_vaultwarden_grandfathered(monkeypatch) -> None: + dummy_settings = types.SimpleNamespace( + mailu_domain="bstein.dev", + mailu_sync_url="", + mailu_mailbox_wait_timeout_sec=1.0, + vaultwarden_admin_rate_limit_backoff_sec=600.0, + provision_retry_cooldown_sec=0.0, + ) + monkeypatch.setattr(prov, "settings", dummy_settings) + monkeypatch.setattr(prov, "keycloak_admin", DummyAdmin()) + monkeypatch.setattr( + prov.vaultwarden, + "find_user_by_email", + lambda _email: VaultwardenLookup(True, "present", "user found"), + ) + monkeypatch.setattr( + prov.vaultwarden, + "invite_user", + lambda _email: (_ for _ in ()).throw(RuntimeError("invite should not run")), + ) + + row = { + "username": "legacy", + "contact_email": "legacy@example.com", + "email_verified_at": datetime.now(timezone.utc), + "status": "accounts_building", + "initial_password": "temp", + "initial_password_revealed_at": None, + "provision_attempted_at": None, + "approval_flags": [prov.VAULTWARDEN_GRANDFATHERED_FLAG], + } + manager = prov.ProvisioningManager(DummyDB(row), DummyStorage()) + ctx = prov.RequestContext( + request_code="REQ_GRANDFATHER", + username="legacy", + first_name="", + last_name="", + contact_email="legacy@example.com", + email_verified_at=row["email_verified_at"], + status="accounts_building", + initial_password="temp", + revealed_at=None, + attempted_at=None, + approval_flags=row["approval_flags"], + user_id="1", + mailu_email="legacy@bstein.dev", + ) + conn = DummyConn(row) + manager._ensure_vaultwarden_invite(conn, ctx) + + inserts = [params for query, params in conn.executed if "access_request_tasks" in query] + assert any(params[2] == "ok" and params[3] == "grandfathered" for params in inserts) + + def test_provisioning_complete_event_failure(monkeypatch) -> None: dummy_settings = types.SimpleNamespace( mailu_domain="bstein.dev", diff --git a/tests/test_services.py b/tests/test_services.py index e87bbe8..fc85b5a 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -31,13 +31,19 @@ class DummyExecutor: class DummyResponse: - def __init__(self, status_code=200, text=""): + def __init__(self, status_code=200, text="", json_data=None): self.status_code = status_code self.text = text + self._json_data = json_data def raise_for_status(self): return None + def json(self): + if self._json_data is None: + raise ValueError("no json data") + return self._json_data + class DummyVaultwardenClient: def __init__(self): @@ -51,6 +57,13 @@ class DummyVaultwardenClient: resp = DummyResponse(200, "") return resp + def get(self, path): + self.calls.append((path, None, None)) + resp = self.responses.get(path) + if resp is None: + resp = DummyResponse(200, "", json_data=[]) + return resp + def close(self): return None @@ -925,6 +938,60 @@ def test_vaultwarden_invite_rate_limited_short_circuit() -> None: assert result.status == "rate_limited" +def test_vaultwarden_lookup_user_present(monkeypatch) -> None: + dummy_settings = types.SimpleNamespace( + vaultwarden_namespace="vaultwarden", + vaultwarden_admin_secret_name="vaultwarden-admin", + vaultwarden_admin_secret_key="ADMIN_TOKEN", + vaultwarden_admin_rate_limit_backoff_sec=600, + vaultwarden_admin_session_ttl_sec=900, + vaultwarden_service_host="vaultwarden-service.vaultwarden.svc.cluster.local", + vaultwarden_pod_label="app=vaultwarden", + vaultwarden_pod_port=80, + ) + client = DummyVaultwardenClient() + client.responses["/admin/users"] = DummyResponse(200, "", json_data=[{"email": "alice@bstein.dev"}]) + + monkeypatch.setattr("ariadne.services.vaultwarden.settings", dummy_settings) + monkeypatch.setattr("ariadne.services.vaultwarden.get_secret_value", lambda *args, **kwargs: "token") + monkeypatch.setattr("ariadne.services.vaultwarden.httpx.Client", lambda *args, **kwargs: client) + monkeypatch.setattr( + "ariadne.services.vaultwarden.VaultwardenService._find_pod_ip", + staticmethod(lambda *args, **kwargs: "127.0.0.1"), + ) + + svc = VaultwardenService() + result = svc.find_user_by_email("alice@bstein.dev") + assert result.status == "present" + + +def test_vaultwarden_lookup_user_missing(monkeypatch) -> None: + dummy_settings = types.SimpleNamespace( + vaultwarden_namespace="vaultwarden", + vaultwarden_admin_secret_name="vaultwarden-admin", + vaultwarden_admin_secret_key="ADMIN_TOKEN", + vaultwarden_admin_rate_limit_backoff_sec=600, + vaultwarden_admin_session_ttl_sec=900, + vaultwarden_service_host="vaultwarden-service.vaultwarden.svc.cluster.local", + vaultwarden_pod_label="app=vaultwarden", + vaultwarden_pod_port=80, + ) + client = DummyVaultwardenClient() + client.responses["/admin/users"] = DummyResponse(200, "", json_data=[{"email": "bob@bstein.dev"}]) + + monkeypatch.setattr("ariadne.services.vaultwarden.settings", dummy_settings) + monkeypatch.setattr("ariadne.services.vaultwarden.get_secret_value", lambda *args, **kwargs: "token") + monkeypatch.setattr("ariadne.services.vaultwarden.httpx.Client", lambda *args, **kwargs: client) + monkeypatch.setattr( + "ariadne.services.vaultwarden.VaultwardenService._find_pod_ip", + staticmethod(lambda *args, **kwargs: "127.0.0.1"), + ) + + svc = VaultwardenService() + result = svc.find_user_by_email("alice@bstein.dev") + assert result.status == "missing" + + def test_vaultwarden_invite_handles_admin_exception(monkeypatch) -> None: dummy_settings = types.SimpleNamespace( vaultwarden_namespace="vaultwarden",