fix(keycloak): preserve profile updates
This commit is contained in:
parent
ef912df950
commit
681e9aa358
@ -218,7 +218,7 @@ class ProvisioningManager:
|
|||||||
actions = full.get("requiredActions")
|
actions = full.get("requiredActions")
|
||||||
if isinstance(actions, list) and "CONFIGURE_TOTP" in actions:
|
if isinstance(actions, list) and "CONFIGURE_TOTP" in actions:
|
||||||
new_actions = [a for a in actions if a != "CONFIGURE_TOTP"]
|
new_actions = [a for a in actions if a != "CONFIGURE_TOTP"]
|
||||||
keycloak_admin.update_user(user_id, {"requiredActions": new_actions})
|
keycloak_admin.update_user_safe(user_id, {"requiredActions": new_actions})
|
||||||
if isinstance(attrs, dict):
|
if isinstance(attrs, dict):
|
||||||
existing = _extract_attr(attrs, MAILU_EMAIL_ATTR)
|
existing = _extract_attr(attrs, MAILU_EMAIL_ATTR)
|
||||||
if existing:
|
if existing:
|
||||||
|
|||||||
@ -14,6 +14,36 @@ class KeycloakAdminClient:
|
|||||||
self._expires_at: float = 0.0
|
self._expires_at: float = 0.0
|
||||||
self._group_id_cache: dict[str, str] = {}
|
self._group_id_cache: dict[str, str] = {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _safe_update_payload(full: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
payload: dict[str, Any] = {}
|
||||||
|
username = full.get("username")
|
||||||
|
if isinstance(username, str):
|
||||||
|
payload["username"] = username
|
||||||
|
enabled = full.get("enabled")
|
||||||
|
if isinstance(enabled, bool):
|
||||||
|
payload["enabled"] = enabled
|
||||||
|
email = full.get("email")
|
||||||
|
if isinstance(email, str):
|
||||||
|
payload["email"] = email
|
||||||
|
email_verified = full.get("emailVerified")
|
||||||
|
if isinstance(email_verified, bool):
|
||||||
|
payload["emailVerified"] = email_verified
|
||||||
|
first_name = full.get("firstName")
|
||||||
|
if isinstance(first_name, str):
|
||||||
|
payload["firstName"] = first_name
|
||||||
|
last_name = full.get("lastName")
|
||||||
|
if isinstance(last_name, str):
|
||||||
|
payload["lastName"] = last_name
|
||||||
|
|
||||||
|
actions = full.get("requiredActions")
|
||||||
|
if isinstance(actions, list):
|
||||||
|
payload["requiredActions"] = [a for a in actions if isinstance(a, str)]
|
||||||
|
|
||||||
|
attrs = full.get("attributes")
|
||||||
|
payload["attributes"] = attrs if isinstance(attrs, dict) else {}
|
||||||
|
return payload
|
||||||
|
|
||||||
def ready(self) -> bool:
|
def ready(self) -> bool:
|
||||||
return bool(settings.keycloak_admin_client_id and settings.keycloak_admin_client_secret)
|
return bool(settings.keycloak_admin_client_id and settings.keycloak_admin_client_secret)
|
||||||
|
|
||||||
@ -101,6 +131,21 @@ class KeycloakAdminClient:
|
|||||||
resp = client.put(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload)
|
resp = client.put(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
def update_user_safe(self, user_id: str, payload: dict[str, Any]) -> None:
|
||||||
|
full = self.get_user(user_id)
|
||||||
|
merged = self._safe_update_payload(full)
|
||||||
|
for key, value in payload.items():
|
||||||
|
if key == "attributes":
|
||||||
|
attrs = merged.get("attributes")
|
||||||
|
if not isinstance(attrs, dict):
|
||||||
|
attrs = {}
|
||||||
|
if isinstance(value, dict):
|
||||||
|
attrs.update(value)
|
||||||
|
merged["attributes"] = attrs
|
||||||
|
continue
|
||||||
|
merged[key] = value
|
||||||
|
self.update_user(user_id, merged)
|
||||||
|
|
||||||
def create_user(self, payload: dict[str, Any]) -> str:
|
def create_user(self, payload: dict[str, Any]) -> str:
|
||||||
url = f"{settings.keycloak_admin_url}/admin/realms/{settings.keycloak_realm}/users"
|
url = f"{settings.keycloak_admin_url}/admin/realms/{settings.keycloak_realm}/users"
|
||||||
with httpx.Client(timeout=10.0) as client:
|
with httpx.Client(timeout=10.0) as client:
|
||||||
@ -130,11 +175,13 @@ class KeycloakAdminClient:
|
|||||||
raise RuntimeError("user id missing")
|
raise RuntimeError("user id missing")
|
||||||
|
|
||||||
full = self.get_user(user_id)
|
full = self.get_user(user_id)
|
||||||
attrs = full.get("attributes") or {}
|
payload = self._safe_update_payload(full)
|
||||||
|
attrs = payload.get("attributes")
|
||||||
if not isinstance(attrs, dict):
|
if not isinstance(attrs, dict):
|
||||||
attrs = {}
|
attrs = {}
|
||||||
attrs[key] = [value]
|
attrs[key] = [value]
|
||||||
self.update_user(user_id, {"attributes": attrs})
|
payload["attributes"] = attrs
|
||||||
|
self.update_user(user_id, payload)
|
||||||
|
|
||||||
def get_group_id(self, group_name: str) -> str | None:
|
def get_group_id(self, group_name: str) -> str | None:
|
||||||
cached = self._group_id_cache.get(group_name)
|
cached = self._group_id_cache.get(group_name)
|
||||||
|
|||||||
79
tests/test_keycloak_admin.py
Normal file
79
tests/test_keycloak_admin.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ariadne.services.keycloak_admin import KeycloakAdminClient
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_user_attribute_preserves_profile(monkeypatch) -> None:
|
||||||
|
client = KeycloakAdminClient()
|
||||||
|
captured: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def fake_find_user(username: str) -> dict[str, Any]:
|
||||||
|
return {"id": "user-123"}
|
||||||
|
|
||||||
|
def fake_get_user(user_id: str) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": user_id,
|
||||||
|
"username": "alice",
|
||||||
|
"email": "alice@bstein.dev",
|
||||||
|
"emailVerified": True,
|
||||||
|
"enabled": True,
|
||||||
|
"firstName": "Alice",
|
||||||
|
"lastName": "Smith",
|
||||||
|
"requiredActions": ["UPDATE_PASSWORD", 123],
|
||||||
|
"attributes": {"existing": ["value"]},
|
||||||
|
}
|
||||||
|
|
||||||
|
def fake_update_user(user_id: str, payload: dict[str, Any]) -> None:
|
||||||
|
captured["user_id"] = user_id
|
||||||
|
captured["payload"] = payload
|
||||||
|
|
||||||
|
monkeypatch.setattr(client, "find_user", fake_find_user)
|
||||||
|
monkeypatch.setattr(client, "get_user", fake_get_user)
|
||||||
|
monkeypatch.setattr(client, "update_user", fake_update_user)
|
||||||
|
|
||||||
|
client.set_user_attribute("alice", "mailu_app_password", "secret")
|
||||||
|
|
||||||
|
payload = captured.get("payload") or {}
|
||||||
|
assert payload.get("username") == "alice"
|
||||||
|
assert payload.get("email") == "alice@bstein.dev"
|
||||||
|
assert payload.get("emailVerified") is True
|
||||||
|
assert payload.get("enabled") is True
|
||||||
|
assert payload.get("firstName") == "Alice"
|
||||||
|
assert payload.get("lastName") == "Smith"
|
||||||
|
assert payload.get("requiredActions") == ["UPDATE_PASSWORD"]
|
||||||
|
assert payload.get("attributes") == {
|
||||||
|
"existing": ["value"],
|
||||||
|
"mailu_app_password": ["secret"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_user_safe_merges_payload(monkeypatch) -> None:
|
||||||
|
client = KeycloakAdminClient()
|
||||||
|
captured: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def fake_get_user(user_id: str) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": user_id,
|
||||||
|
"username": "alice",
|
||||||
|
"enabled": True,
|
||||||
|
"attributes": {"existing": ["value"]},
|
||||||
|
}
|
||||||
|
|
||||||
|
def fake_update_user(user_id: str, payload: dict[str, Any]) -> None:
|
||||||
|
captured["user_id"] = user_id
|
||||||
|
captured["payload"] = payload
|
||||||
|
|
||||||
|
monkeypatch.setattr(client, "get_user", fake_get_user)
|
||||||
|
monkeypatch.setattr(client, "update_user", fake_update_user)
|
||||||
|
|
||||||
|
client.update_user_safe(
|
||||||
|
"user-123",
|
||||||
|
{"attributes": {"new": ["item"]}, "requiredActions": ["UPDATE_PASSWORD"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = captured.get("payload") or {}
|
||||||
|
assert payload.get("username") == "alice"
|
||||||
|
assert payload.get("attributes") == {"existing": ["value"], "new": ["item"]}
|
||||||
|
assert payload.get("requiredActions") == ["UPDATE_PASSWORD"]
|
||||||
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from ariadne.scheduler.cron import CronScheduler, CronTask
|
from ariadne.scheduler.cron import CronScheduler, CronTask
|
||||||
|
|
||||||
@ -25,7 +25,22 @@ def test_execute_task_records_failure() -> None:
|
|||||||
raise RuntimeError("boom")
|
raise RuntimeError("boom")
|
||||||
|
|
||||||
task = CronTask(name="test", cron_expr="*/5 * * * *", runner=runner)
|
task = CronTask(name="test", cron_expr="*/5 * * * *", runner=runner)
|
||||||
scheduler._next_run["test"] = datetime.utcnow()
|
scheduler._next_run["test"] = datetime.now(timezone.utc)
|
||||||
|
scheduler._execute_task(task)
|
||||||
|
|
||||||
|
assert storage.task_runs
|
||||||
|
assert storage.schedule_states
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_task_records_success() -> None:
|
||||||
|
storage = DummyStorage()
|
||||||
|
scheduler = CronScheduler(storage, tick_sec=0.1)
|
||||||
|
|
||||||
|
def runner():
|
||||||
|
return None
|
||||||
|
|
||||||
|
task = CronTask(name="ok-task", cron_expr="*/5 * * * *", runner=runner)
|
||||||
|
scheduler._next_run["ok-task"] = datetime.now(timezone.utc)
|
||||||
scheduler._execute_task(task)
|
scheduler._execute_task(task)
|
||||||
|
|
||||||
assert storage.task_runs
|
assert storage.task_runs
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user