fix(keycloak): preserve profile updates
This commit is contained in:
parent
ef912df950
commit
681e9aa358
@ -218,7 +218,7 @@ class ProvisioningManager:
|
||||
actions = full.get("requiredActions")
|
||||
if isinstance(actions, list) and "CONFIGURE_TOTP" in actions:
|
||||
new_actions = [a for a in actions if a != "CONFIGURE_TOTP"]
|
||||
keycloak_admin.update_user(user_id, {"requiredActions": new_actions})
|
||||
keycloak_admin.update_user_safe(user_id, {"requiredActions": new_actions})
|
||||
if isinstance(attrs, dict):
|
||||
existing = _extract_attr(attrs, MAILU_EMAIL_ATTR)
|
||||
if existing:
|
||||
|
||||
@ -14,6 +14,36 @@ class KeycloakAdminClient:
|
||||
self._expires_at: float = 0.0
|
||||
self._group_id_cache: dict[str, str] = {}
|
||||
|
||||
@staticmethod
|
||||
def _safe_update_payload(full: dict[str, Any]) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {}
|
||||
username = full.get("username")
|
||||
if isinstance(username, str):
|
||||
payload["username"] = username
|
||||
enabled = full.get("enabled")
|
||||
if isinstance(enabled, bool):
|
||||
payload["enabled"] = enabled
|
||||
email = full.get("email")
|
||||
if isinstance(email, str):
|
||||
payload["email"] = email
|
||||
email_verified = full.get("emailVerified")
|
||||
if isinstance(email_verified, bool):
|
||||
payload["emailVerified"] = email_verified
|
||||
first_name = full.get("firstName")
|
||||
if isinstance(first_name, str):
|
||||
payload["firstName"] = first_name
|
||||
last_name = full.get("lastName")
|
||||
if isinstance(last_name, str):
|
||||
payload["lastName"] = last_name
|
||||
|
||||
actions = full.get("requiredActions")
|
||||
if isinstance(actions, list):
|
||||
payload["requiredActions"] = [a for a in actions if isinstance(a, str)]
|
||||
|
||||
attrs = full.get("attributes")
|
||||
payload["attributes"] = attrs if isinstance(attrs, dict) else {}
|
||||
return payload
|
||||
|
||||
def ready(self) -> bool:
|
||||
return bool(settings.keycloak_admin_client_id and settings.keycloak_admin_client_secret)
|
||||
|
||||
@ -101,6 +131,21 @@ class KeycloakAdminClient:
|
||||
resp = client.put(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload)
|
||||
resp.raise_for_status()
|
||||
|
||||
def update_user_safe(self, user_id: str, payload: dict[str, Any]) -> None:
|
||||
full = self.get_user(user_id)
|
||||
merged = self._safe_update_payload(full)
|
||||
for key, value in payload.items():
|
||||
if key == "attributes":
|
||||
attrs = merged.get("attributes")
|
||||
if not isinstance(attrs, dict):
|
||||
attrs = {}
|
||||
if isinstance(value, dict):
|
||||
attrs.update(value)
|
||||
merged["attributes"] = attrs
|
||||
continue
|
||||
merged[key] = value
|
||||
self.update_user(user_id, merged)
|
||||
|
||||
def create_user(self, payload: dict[str, Any]) -> str:
|
||||
url = f"{settings.keycloak_admin_url}/admin/realms/{settings.keycloak_realm}/users"
|
||||
with httpx.Client(timeout=10.0) as client:
|
||||
@ -130,11 +175,13 @@ class KeycloakAdminClient:
|
||||
raise RuntimeError("user id missing")
|
||||
|
||||
full = self.get_user(user_id)
|
||||
attrs = full.get("attributes") or {}
|
||||
payload = self._safe_update_payload(full)
|
||||
attrs = payload.get("attributes")
|
||||
if not isinstance(attrs, dict):
|
||||
attrs = {}
|
||||
attrs[key] = [value]
|
||||
self.update_user(user_id, {"attributes": attrs})
|
||||
payload["attributes"] = attrs
|
||||
self.update_user(user_id, payload)
|
||||
|
||||
def get_group_id(self, group_name: str) -> str | None:
|
||||
cached = self._group_id_cache.get(group_name)
|
||||
|
||||
79
tests/test_keycloak_admin.py
Normal file
79
tests/test_keycloak_admin.py
Normal file
@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ariadne.services.keycloak_admin import KeycloakAdminClient
|
||||
|
||||
|
||||
def test_set_user_attribute_preserves_profile(monkeypatch) -> None:
|
||||
client = KeycloakAdminClient()
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
def fake_find_user(username: str) -> dict[str, Any]:
|
||||
return {"id": "user-123"}
|
||||
|
||||
def fake_get_user(user_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": user_id,
|
||||
"username": "alice",
|
||||
"email": "alice@bstein.dev",
|
||||
"emailVerified": True,
|
||||
"enabled": True,
|
||||
"firstName": "Alice",
|
||||
"lastName": "Smith",
|
||||
"requiredActions": ["UPDATE_PASSWORD", 123],
|
||||
"attributes": {"existing": ["value"]},
|
||||
}
|
||||
|
||||
def fake_update_user(user_id: str, payload: dict[str, Any]) -> None:
|
||||
captured["user_id"] = user_id
|
||||
captured["payload"] = payload
|
||||
|
||||
monkeypatch.setattr(client, "find_user", fake_find_user)
|
||||
monkeypatch.setattr(client, "get_user", fake_get_user)
|
||||
monkeypatch.setattr(client, "update_user", fake_update_user)
|
||||
|
||||
client.set_user_attribute("alice", "mailu_app_password", "secret")
|
||||
|
||||
payload = captured.get("payload") or {}
|
||||
assert payload.get("username") == "alice"
|
||||
assert payload.get("email") == "alice@bstein.dev"
|
||||
assert payload.get("emailVerified") is True
|
||||
assert payload.get("enabled") is True
|
||||
assert payload.get("firstName") == "Alice"
|
||||
assert payload.get("lastName") == "Smith"
|
||||
assert payload.get("requiredActions") == ["UPDATE_PASSWORD"]
|
||||
assert payload.get("attributes") == {
|
||||
"existing": ["value"],
|
||||
"mailu_app_password": ["secret"],
|
||||
}
|
||||
|
||||
|
||||
def test_update_user_safe_merges_payload(monkeypatch) -> None:
|
||||
client = KeycloakAdminClient()
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
def fake_get_user(user_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": user_id,
|
||||
"username": "alice",
|
||||
"enabled": True,
|
||||
"attributes": {"existing": ["value"]},
|
||||
}
|
||||
|
||||
def fake_update_user(user_id: str, payload: dict[str, Any]) -> None:
|
||||
captured["user_id"] = user_id
|
||||
captured["payload"] = payload
|
||||
|
||||
monkeypatch.setattr(client, "get_user", fake_get_user)
|
||||
monkeypatch.setattr(client, "update_user", fake_update_user)
|
||||
|
||||
client.update_user_safe(
|
||||
"user-123",
|
||||
{"attributes": {"new": ["item"]}, "requiredActions": ["UPDATE_PASSWORD"]},
|
||||
)
|
||||
|
||||
payload = captured.get("payload") or {}
|
||||
assert payload.get("username") == "alice"
|
||||
assert payload.get("attributes") == {"existing": ["value"], "new": ["item"]}
|
||||
assert payload.get("requiredActions") == ["UPDATE_PASSWORD"]
|
||||
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from ariadne.scheduler.cron import CronScheduler, CronTask
|
||||
|
||||
@ -25,7 +25,22 @@ def test_execute_task_records_failure() -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
task = CronTask(name="test", cron_expr="*/5 * * * *", runner=runner)
|
||||
scheduler._next_run["test"] = datetime.utcnow()
|
||||
scheduler._next_run["test"] = datetime.now(timezone.utc)
|
||||
scheduler._execute_task(task)
|
||||
|
||||
assert storage.task_runs
|
||||
assert storage.schedule_states
|
||||
|
||||
|
||||
def test_execute_task_records_success() -> None:
|
||||
storage = DummyStorage()
|
||||
scheduler = CronScheduler(storage, tick_sec=0.1)
|
||||
|
||||
def runner():
|
||||
return None
|
||||
|
||||
task = CronTask(name="ok-task", cron_expr="*/5 * * * *", runner=runner)
|
||||
scheduler._next_run["ok-task"] = datetime.now(timezone.utc)
|
||||
scheduler._execute_task(task)
|
||||
|
||||
assert storage.task_runs
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user