vaultwarden: handle rate limits and grandfathered

This commit is contained in:
Brad Stein 2026-01-24 02:20:16 -03:00
parent f50ec538db
commit d21595aaac
4 changed files with 330 additions and 10 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timedelta, timezone
import hashlib import hashlib
import threading import threading
import time import time
@ -30,6 +30,7 @@ WGER_PASSWORD_ATTR = "wger_password"
WGER_PASSWORD_UPDATED_ATTR = "wger_password_updated_at" WGER_PASSWORD_UPDATED_ATTR = "wger_password_updated_at"
FIREFLY_PASSWORD_ATTR = "firefly_password" FIREFLY_PASSWORD_ATTR = "firefly_password"
FIREFLY_PASSWORD_UPDATED_ATTR = "firefly_password_updated_at" FIREFLY_PASSWORD_UPDATED_ATTR = "firefly_password_updated_at"
VAULTWARDEN_GRANDFATHERED_FLAG = "vaultwarden_grandfathered"
logger = get_logger(__name__) logger = get_logger(__name__)
@ -445,6 +446,70 @@ class ProvisioningManager:
self._upsert_task(conn, request_code, task, "error", detail) self._upsert_task(conn, request_code, task, "error", detail)
self._record_task(request_code, task, "error", detail, started) self._record_task(request_code, task, "error", detail, started)
def _task_pending(
self,
conn,
request_code: str,
task: str,
detail: str,
started: datetime,
) -> None:
self._upsert_task(conn, request_code, task, "pending", detail)
self._record_task(request_code, task, "pending", detail, started)
def _vaultwarden_rate_limit_detail(self) -> tuple[str, datetime]:
retry_at = datetime.now(timezone.utc) + timedelta(
seconds=float(settings.vaultwarden_admin_rate_limit_backoff_sec)
)
retry_iso = retry_at.strftime("%Y-%m-%dT%H:%M:%SZ")
return f"rate limited until {retry_iso}", retry_at
@staticmethod
def _parse_retry_at(detail: str) -> datetime | None:
prefix = "rate limited until "
if not isinstance(detail, str) or not detail.startswith(prefix):
return None
ts = detail[len(prefix) :].strip()
for fmt in ("%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M:%S%z"):
try:
parsed = datetime.strptime(ts, fmt)
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed
except ValueError:
continue
return None
def _vaultwarden_retry_due(self, conn, request_code: str) -> bool:
row = conn.execute(
"""
SELECT status, detail
FROM access_request_tasks
WHERE request_code = %s AND task = 'vaultwarden_invite'
""",
(request_code,),
).fetchone()
if not isinstance(row, dict):
return True
if row.get("status") != "pending":
return True
retry_at = self._parse_retry_at(row.get("detail") or "")
if not retry_at:
return True
return datetime.now(timezone.utc) >= retry_at
@staticmethod
def _set_vaultwarden_attrs(username: str, email: str, status: str) -> None:
if not username or not email or not status:
return
try:
now_iso = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
keycloak_admin.set_user_attribute(username, "vaultwarden_email", email)
keycloak_admin.set_user_attribute(username, "vaultwarden_status", status)
keycloak_admin.set_user_attribute(username, "vaultwarden_synced_at", now_iso)
except Exception:
return
def _ready_for_retry(self, ctx: RequestContext) -> bool: def _ready_for_retry(self, ctx: RequestContext) -> bool:
if ctx.status != "accounts_building": if ctx.status != "accounts_building":
return True return True
@ -741,9 +806,37 @@ class ProvisioningManager:
detail = safe_error_detail(exc, "failed to provision firefly") detail = safe_error_detail(exc, "failed to provision firefly")
self._task_error(conn, ctx.request_code, "firefly_account", detail, start) self._task_error(conn, ctx.request_code, "firefly_account", detail, start)
def _handle_vaultwarden_grandfathered(self, conn, ctx: RequestContext, start: datetime) -> None:
lookup = vaultwarden.find_user_by_email(ctx.contact_email)
if lookup.status == "rate_limited":
detail, _ = self._vaultwarden_rate_limit_detail()
self._task_pending(conn, ctx.request_code, "vaultwarden_invite", detail, start)
self._set_vaultwarden_attrs(ctx.username, ctx.contact_email, "rate_limited")
return
if lookup.ok and lookup.status == "present":
self._task_ok(conn, ctx.request_code, "vaultwarden_invite", "grandfathered", start)
self._set_vaultwarden_attrs(ctx.username, ctx.contact_email, "grandfathered")
return
if lookup.ok and lookup.status == "missing":
self._task_error(
conn,
ctx.request_code,
"vaultwarden_invite",
"vaultwarden account not found for recovery email",
start,
)
return
detail = lookup.detail or lookup.status
self._task_error(conn, ctx.request_code, "vaultwarden_invite", detail, start)
def _ensure_vaultwarden_invite(self, conn, ctx: RequestContext) -> None: def _ensure_vaultwarden_invite(self, conn, ctx: RequestContext) -> None:
start = datetime.now(timezone.utc) start = datetime.now(timezone.utc)
try: try:
if not self._vaultwarden_retry_due(conn, ctx.request_code):
return
if VAULTWARDEN_GRANDFATHERED_FLAG in ctx.approval_flags:
self._handle_vaultwarden_grandfathered(conn, ctx, start)
return
if not mailu.wait_for_mailbox(ctx.mailu_email, settings.mailu_mailbox_wait_timeout_sec): if not mailu.wait_for_mailbox(ctx.mailu_email, settings.mailu_mailbox_wait_timeout_sec):
try: try:
mailu.sync(reason="ariadne_vaultwarden_retry", force=True) mailu.sync(reason="ariadne_vaultwarden_retry", force=True)
@ -755,17 +848,15 @@ class ProvisioningManager:
result = vaultwarden.invite_user(ctx.mailu_email) result = vaultwarden.invite_user(ctx.mailu_email)
if result.ok: if result.ok:
self._task_ok(conn, ctx.request_code, "vaultwarden_invite", result.status, start) self._task_ok(conn, ctx.request_code, "vaultwarden_invite", result.status, start)
elif result.status == "rate_limited":
detail, _ = self._vaultwarden_rate_limit_detail()
self._task_pending(conn, ctx.request_code, "vaultwarden_invite", detail, start)
else: else:
detail = result.detail or result.status detail = result.detail or result.status
self._task_error(conn, ctx.request_code, "vaultwarden_invite", detail, start) self._task_error(conn, ctx.request_code, "vaultwarden_invite", detail, start)
try: status = result.status if result.status != "rate_limited" else "rate_limited"
now_iso = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") self._set_vaultwarden_attrs(ctx.username, ctx.mailu_email, status)
keycloak_admin.set_user_attribute(ctx.username, "vaultwarden_email", ctx.mailu_email)
keycloak_admin.set_user_attribute(ctx.username, "vaultwarden_status", result.status)
keycloak_admin.set_user_attribute(ctx.username, "vaultwarden_synced_at", now_iso)
except Exception:
pass
except Exception as exc: except Exception as exc:
detail = safe_error_detail(exc, "failed to provision vaultwarden") detail = safe_error_detail(exc, "failed to provision vaultwarden")
self._task_error(conn, ctx.request_code, "vaultwarden_invite", detail, start) self._task_error(conn, ctx.request_code, "vaultwarden_invite", detail, start)

View File

@ -25,6 +25,13 @@ class VaultwardenInvite:
detail: str = "" detail: str = ""
@dataclass(frozen=True)
class VaultwardenLookup:
ok: bool
status: str
detail: str = ""
class VaultwardenService: class VaultwardenService:
def __init__(self) -> None: def __init__(self) -> None:
self._admin_lock = threading.Lock() self._admin_lock = threading.Lock()
@ -102,6 +109,33 @@ class VaultwardenService:
result = VaultwardenInvite(ok=False, status="error", detail=message) result = VaultwardenInvite(ok=False, status="error", detail=message)
return result return result
def _lookup_via(self, base_url: str, email: str) -> VaultwardenLookup | None:
if not base_url:
return None
try:
session = self._admin_session(base_url)
resp = session.get("/admin/users")
if resp.status_code == HTTP_TOO_MANY_REQUESTS:
self._rate_limited_until = time.time() + float(settings.vaultwarden_admin_rate_limit_backoff_sec)
return VaultwardenLookup(ok=False, status="rate_limited", detail="vaultwarden rate limited")
resp.raise_for_status()
users = resp.json()
if not isinstance(users, list):
return VaultwardenLookup(ok=False, status="error", detail="unexpected users response")
target = email.lower()
for entry in users:
if not isinstance(entry, dict):
continue
user_email = entry.get("email")
if isinstance(user_email, str) and user_email.lower() == target:
return VaultwardenLookup(ok=True, status="present", detail="user found")
return VaultwardenLookup(ok=True, status="missing", detail="user missing")
except Exception as exc:
message = str(exc)
if "rate limited" in message.lower():
return VaultwardenLookup(ok=False, status="rate_limited", detail="vaultwarden rate limited")
return VaultwardenLookup(ok=False, status="error", detail=message)
def invite_user(self, email: str) -> VaultwardenInvite: def invite_user(self, email: str) -> VaultwardenInvite:
email = self._normalize_email(email) email = self._normalize_email(email)
if not email: if not email:
@ -122,6 +156,30 @@ class VaultwardenService:
return VaultwardenInvite(ok=False, status="error", detail=last_error or "failed to invite") return VaultwardenInvite(ok=False, status="error", detail=last_error or "failed to invite")
def find_user_by_email(self, email: str) -> VaultwardenLookup:
email = self._normalize_email(email)
if not email:
return VaultwardenLookup(ok=False, status="invalid_email", detail="email invalid")
if self._rate_limited_until and time.time() < self._rate_limited_until:
return VaultwardenLookup(ok=False, status="rate_limited", detail="vaultwarden rate limited")
last_error = ""
for candidate in self._candidate_urls():
result = self._lookup_via(candidate, email)
if not result:
continue
if result.ok:
return result
if result.status == "rate_limited":
return result
last_error = result.detail or last_error
return VaultwardenLookup(
ok=False,
status="error",
detail=last_error or "failed to lookup user",
)
def _admin_session(self, base_url: str) -> httpx.Client: def _admin_session(self, base_url: str) -> httpx.Client:
now = time.time() now = time.time()
with self._admin_lock: with self._admin_lock:

View File

@ -5,7 +5,7 @@ from datetime import datetime, timezone
import types import types
from ariadne.manager import provisioning as prov from ariadne.manager import provisioning as prov
from ariadne.services.vaultwarden import VaultwardenInvite from ariadne.services.vaultwarden import VaultwardenInvite, VaultwardenLookup
class DummyResult: class DummyResult:
@ -1507,6 +1507,110 @@ def test_provisioning_vaultwarden_attribute_failure(monkeypatch) -> None:
assert outcome.status == "accounts_building" assert outcome.status == "accounts_building"
def test_provisioning_vaultwarden_rate_limited(monkeypatch) -> None:
dummy_settings = types.SimpleNamespace(
mailu_domain="bstein.dev",
mailu_sync_url="",
mailu_mailbox_wait_timeout_sec=1.0,
vaultwarden_admin_rate_limit_backoff_sec=600.0,
provision_retry_cooldown_sec=0.0,
)
monkeypatch.setattr(prov, "settings", dummy_settings)
monkeypatch.setattr(prov, "keycloak_admin", DummyAdmin())
monkeypatch.setattr(prov.mailu, "wait_for_mailbox", lambda *_args, **_kwargs: True)
monkeypatch.setattr(
prov.vaultwarden,
"invite_user",
lambda _email: VaultwardenInvite(False, "rate_limited", "vaultwarden rate limited"),
)
row = {
"username": "alice",
"contact_email": "alice@example.com",
"email_verified_at": datetime.now(timezone.utc),
"status": "accounts_building",
"initial_password": "temp",
"initial_password_revealed_at": None,
"provision_attempted_at": None,
"approval_flags": [],
}
manager = prov.ProvisioningManager(DummyDB(row), DummyStorage())
ctx = prov.RequestContext(
request_code="REQ_RATE",
username="alice",
first_name="",
last_name="",
contact_email="alice@example.com",
email_verified_at=row["email_verified_at"],
status="accounts_building",
initial_password="temp",
revealed_at=None,
attempted_at=None,
approval_flags=[],
user_id="1",
mailu_email="alice@bstein.dev",
)
conn = DummyConn(row)
manager._ensure_vaultwarden_invite(conn, ctx)
inserts = [params for query, params in conn.executed if "access_request_tasks" in query]
assert any(params[2] == "pending" and "rate limited until" in (params[3] or "") for params in inserts)
def test_provisioning_vaultwarden_grandfathered(monkeypatch) -> None:
dummy_settings = types.SimpleNamespace(
mailu_domain="bstein.dev",
mailu_sync_url="",
mailu_mailbox_wait_timeout_sec=1.0,
vaultwarden_admin_rate_limit_backoff_sec=600.0,
provision_retry_cooldown_sec=0.0,
)
monkeypatch.setattr(prov, "settings", dummy_settings)
monkeypatch.setattr(prov, "keycloak_admin", DummyAdmin())
monkeypatch.setattr(
prov.vaultwarden,
"find_user_by_email",
lambda _email: VaultwardenLookup(True, "present", "user found"),
)
monkeypatch.setattr(
prov.vaultwarden,
"invite_user",
lambda _email: (_ for _ in ()).throw(RuntimeError("invite should not run")),
)
row = {
"username": "legacy",
"contact_email": "legacy@example.com",
"email_verified_at": datetime.now(timezone.utc),
"status": "accounts_building",
"initial_password": "temp",
"initial_password_revealed_at": None,
"provision_attempted_at": None,
"approval_flags": [prov.VAULTWARDEN_GRANDFATHERED_FLAG],
}
manager = prov.ProvisioningManager(DummyDB(row), DummyStorage())
ctx = prov.RequestContext(
request_code="REQ_GRANDFATHER",
username="legacy",
first_name="",
last_name="",
contact_email="legacy@example.com",
email_verified_at=row["email_verified_at"],
status="accounts_building",
initial_password="temp",
revealed_at=None,
attempted_at=None,
approval_flags=row["approval_flags"],
user_id="1",
mailu_email="legacy@bstein.dev",
)
conn = DummyConn(row)
manager._ensure_vaultwarden_invite(conn, ctx)
inserts = [params for query, params in conn.executed if "access_request_tasks" in query]
assert any(params[2] == "ok" and params[3] == "grandfathered" for params in inserts)
def test_provisioning_complete_event_failure(monkeypatch) -> None: def test_provisioning_complete_event_failure(monkeypatch) -> None:
dummy_settings = types.SimpleNamespace( dummy_settings = types.SimpleNamespace(
mailu_domain="bstein.dev", mailu_domain="bstein.dev",

View File

@ -31,13 +31,19 @@ class DummyExecutor:
class DummyResponse: class DummyResponse:
def __init__(self, status_code=200, text=""): def __init__(self, status_code=200, text="", json_data=None):
self.status_code = status_code self.status_code = status_code
self.text = text self.text = text
self._json_data = json_data
def raise_for_status(self): def raise_for_status(self):
return None return None
def json(self):
if self._json_data is None:
raise ValueError("no json data")
return self._json_data
class DummyVaultwardenClient: class DummyVaultwardenClient:
def __init__(self): def __init__(self):
@ -51,6 +57,13 @@ class DummyVaultwardenClient:
resp = DummyResponse(200, "") resp = DummyResponse(200, "")
return resp return resp
def get(self, path):
self.calls.append((path, None, None))
resp = self.responses.get(path)
if resp is None:
resp = DummyResponse(200, "", json_data=[])
return resp
def close(self): def close(self):
return None return None
@ -925,6 +938,60 @@ def test_vaultwarden_invite_rate_limited_short_circuit() -> None:
assert result.status == "rate_limited" assert result.status == "rate_limited"
def test_vaultwarden_lookup_user_present(monkeypatch) -> None:
dummy_settings = types.SimpleNamespace(
vaultwarden_namespace="vaultwarden",
vaultwarden_admin_secret_name="vaultwarden-admin",
vaultwarden_admin_secret_key="ADMIN_TOKEN",
vaultwarden_admin_rate_limit_backoff_sec=600,
vaultwarden_admin_session_ttl_sec=900,
vaultwarden_service_host="vaultwarden-service.vaultwarden.svc.cluster.local",
vaultwarden_pod_label="app=vaultwarden",
vaultwarden_pod_port=80,
)
client = DummyVaultwardenClient()
client.responses["/admin/users"] = DummyResponse(200, "", json_data=[{"email": "alice@bstein.dev"}])
monkeypatch.setattr("ariadne.services.vaultwarden.settings", dummy_settings)
monkeypatch.setattr("ariadne.services.vaultwarden.get_secret_value", lambda *args, **kwargs: "token")
monkeypatch.setattr("ariadne.services.vaultwarden.httpx.Client", lambda *args, **kwargs: client)
monkeypatch.setattr(
"ariadne.services.vaultwarden.VaultwardenService._find_pod_ip",
staticmethod(lambda *args, **kwargs: "127.0.0.1"),
)
svc = VaultwardenService()
result = svc.find_user_by_email("alice@bstein.dev")
assert result.status == "present"
def test_vaultwarden_lookup_user_missing(monkeypatch) -> None:
dummy_settings = types.SimpleNamespace(
vaultwarden_namespace="vaultwarden",
vaultwarden_admin_secret_name="vaultwarden-admin",
vaultwarden_admin_secret_key="ADMIN_TOKEN",
vaultwarden_admin_rate_limit_backoff_sec=600,
vaultwarden_admin_session_ttl_sec=900,
vaultwarden_service_host="vaultwarden-service.vaultwarden.svc.cluster.local",
vaultwarden_pod_label="app=vaultwarden",
vaultwarden_pod_port=80,
)
client = DummyVaultwardenClient()
client.responses["/admin/users"] = DummyResponse(200, "", json_data=[{"email": "bob@bstein.dev"}])
monkeypatch.setattr("ariadne.services.vaultwarden.settings", dummy_settings)
monkeypatch.setattr("ariadne.services.vaultwarden.get_secret_value", lambda *args, **kwargs: "token")
monkeypatch.setattr("ariadne.services.vaultwarden.httpx.Client", lambda *args, **kwargs: client)
monkeypatch.setattr(
"ariadne.services.vaultwarden.VaultwardenService._find_pod_ip",
staticmethod(lambda *args, **kwargs: "127.0.0.1"),
)
svc = VaultwardenService()
result = svc.find_user_by_email("alice@bstein.dev")
assert result.status == "missing"
def test_vaultwarden_invite_handles_admin_exception(monkeypatch) -> None: def test_vaultwarden_invite_handles_admin_exception(monkeypatch) -> None:
dummy_settings = types.SimpleNamespace( dummy_settings = types.SimpleNamespace(
vaultwarden_namespace="vaultwarden", vaultwarden_namespace="vaultwarden",