fix(keycloak): preserve profile updates

This commit is contained in:
Brad Stein 2026-01-20 03:58:56 -03:00
parent ef912df950
commit 681e9aa358
4 changed files with 146 additions and 5 deletions

View File

@ -218,7 +218,7 @@ class ProvisioningManager:
actions = full.get("requiredActions")
if isinstance(actions, list) and "CONFIGURE_TOTP" in actions:
new_actions = [a for a in actions if a != "CONFIGURE_TOTP"]
keycloak_admin.update_user(user_id, {"requiredActions": new_actions})
keycloak_admin.update_user_safe(user_id, {"requiredActions": new_actions})
if isinstance(attrs, dict):
existing = _extract_attr(attrs, MAILU_EMAIL_ATTR)
if existing:

View File

@ -14,6 +14,36 @@ class KeycloakAdminClient:
self._expires_at: float = 0.0
self._group_id_cache: dict[str, str] = {}
@staticmethod
def _safe_update_payload(full: dict[str, Any]) -> dict[str, Any]:
payload: dict[str, Any] = {}
username = full.get("username")
if isinstance(username, str):
payload["username"] = username
enabled = full.get("enabled")
if isinstance(enabled, bool):
payload["enabled"] = enabled
email = full.get("email")
if isinstance(email, str):
payload["email"] = email
email_verified = full.get("emailVerified")
if isinstance(email_verified, bool):
payload["emailVerified"] = email_verified
first_name = full.get("firstName")
if isinstance(first_name, str):
payload["firstName"] = first_name
last_name = full.get("lastName")
if isinstance(last_name, str):
payload["lastName"] = last_name
actions = full.get("requiredActions")
if isinstance(actions, list):
payload["requiredActions"] = [a for a in actions if isinstance(a, str)]
attrs = full.get("attributes")
payload["attributes"] = attrs if isinstance(attrs, dict) else {}
return payload
def ready(self) -> bool:
return bool(settings.keycloak_admin_client_id and settings.keycloak_admin_client_secret)
@ -101,6 +131,21 @@ class KeycloakAdminClient:
resp = client.put(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload)
resp.raise_for_status()
def update_user_safe(self, user_id: str, payload: dict[str, Any]) -> None:
full = self.get_user(user_id)
merged = self._safe_update_payload(full)
for key, value in payload.items():
if key == "attributes":
attrs = merged.get("attributes")
if not isinstance(attrs, dict):
attrs = {}
if isinstance(value, dict):
attrs.update(value)
merged["attributes"] = attrs
continue
merged[key] = value
self.update_user(user_id, merged)
def create_user(self, payload: dict[str, Any]) -> str:
url = f"{settings.keycloak_admin_url}/admin/realms/{settings.keycloak_realm}/users"
with httpx.Client(timeout=10.0) as client:
@ -130,11 +175,13 @@ class KeycloakAdminClient:
raise RuntimeError("user id missing")
full = self.get_user(user_id)
attrs = full.get("attributes") or {}
payload = self._safe_update_payload(full)
attrs = payload.get("attributes")
if not isinstance(attrs, dict):
attrs = {}
attrs[key] = [value]
self.update_user(user_id, {"attributes": attrs})
payload["attributes"] = attrs
self.update_user(user_id, payload)
def get_group_id(self, group_name: str) -> str | None:
cached = self._group_id_cache.get(group_name)

View File

@ -0,0 +1,79 @@
from __future__ import annotations
from typing import Any
from ariadne.services.keycloak_admin import KeycloakAdminClient
def test_set_user_attribute_preserves_profile(monkeypatch) -> None:
client = KeycloakAdminClient()
captured: dict[str, Any] = {}
def fake_find_user(username: str) -> dict[str, Any]:
return {"id": "user-123"}
def fake_get_user(user_id: str) -> dict[str, Any]:
return {
"id": user_id,
"username": "alice",
"email": "alice@bstein.dev",
"emailVerified": True,
"enabled": True,
"firstName": "Alice",
"lastName": "Smith",
"requiredActions": ["UPDATE_PASSWORD", 123],
"attributes": {"existing": ["value"]},
}
def fake_update_user(user_id: str, payload: dict[str, Any]) -> None:
captured["user_id"] = user_id
captured["payload"] = payload
monkeypatch.setattr(client, "find_user", fake_find_user)
monkeypatch.setattr(client, "get_user", fake_get_user)
monkeypatch.setattr(client, "update_user", fake_update_user)
client.set_user_attribute("alice", "mailu_app_password", "secret")
payload = captured.get("payload") or {}
assert payload.get("username") == "alice"
assert payload.get("email") == "alice@bstein.dev"
assert payload.get("emailVerified") is True
assert payload.get("enabled") is True
assert payload.get("firstName") == "Alice"
assert payload.get("lastName") == "Smith"
assert payload.get("requiredActions") == ["UPDATE_PASSWORD"]
assert payload.get("attributes") == {
"existing": ["value"],
"mailu_app_password": ["secret"],
}
def test_update_user_safe_merges_payload(monkeypatch) -> None:
client = KeycloakAdminClient()
captured: dict[str, Any] = {}
def fake_get_user(user_id: str) -> dict[str, Any]:
return {
"id": user_id,
"username": "alice",
"enabled": True,
"attributes": {"existing": ["value"]},
}
def fake_update_user(user_id: str, payload: dict[str, Any]) -> None:
captured["user_id"] = user_id
captured["payload"] = payload
monkeypatch.setattr(client, "get_user", fake_get_user)
monkeypatch.setattr(client, "update_user", fake_update_user)
client.update_user_safe(
"user-123",
{"attributes": {"new": ["item"]}, "requiredActions": ["UPDATE_PASSWORD"]},
)
payload = captured.get("payload") or {}
assert payload.get("username") == "alice"
assert payload.get("attributes") == {"existing": ["value"], "new": ["item"]}
assert payload.get("requiredActions") == ["UPDATE_PASSWORD"]

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from datetime import datetime
from datetime import datetime, timezone
from ariadne.scheduler.cron import CronScheduler, CronTask
@ -25,7 +25,22 @@ def test_execute_task_records_failure() -> None:
raise RuntimeError("boom")
task = CronTask(name="test", cron_expr="*/5 * * * *", runner=runner)
scheduler._next_run["test"] = datetime.utcnow()
scheduler._next_run["test"] = datetime.now(timezone.utc)
scheduler._execute_task(task)
assert storage.task_runs
assert storage.schedule_states
def test_execute_task_records_success() -> None:
storage = DummyStorage()
scheduler = CronScheduler(storage, tick_sec=0.1)
def runner():
return None
task = CronTask(name="ok-task", cron_expr="*/5 * * * *", runner=runner)
scheduler._next_run["ok-task"] = datetime.now(timezone.utc)
scheduler._execute_task(task)
assert storage.task_runs