diff --git a/ariadne/manager/provisioning.py b/ariadne/manager/provisioning.py index 47bff9e..6ef9111 100644 --- a/ariadne/manager/provisioning.py +++ b/ariadne/manager/provisioning.py @@ -218,7 +218,7 @@ class ProvisioningManager: actions = full.get("requiredActions") if isinstance(actions, list) and "CONFIGURE_TOTP" in actions: new_actions = [a for a in actions if a != "CONFIGURE_TOTP"] - keycloak_admin.update_user(user_id, {"requiredActions": new_actions}) + keycloak_admin.update_user_safe(user_id, {"requiredActions": new_actions}) if isinstance(attrs, dict): existing = _extract_attr(attrs, MAILU_EMAIL_ATTR) if existing: diff --git a/ariadne/services/keycloak_admin.py b/ariadne/services/keycloak_admin.py index f60f397..f02d8a5 100644 --- a/ariadne/services/keycloak_admin.py +++ b/ariadne/services/keycloak_admin.py @@ -14,6 +14,36 @@ class KeycloakAdminClient: self._expires_at: float = 0.0 self._group_id_cache: dict[str, str] = {} + @staticmethod + def _safe_update_payload(full: dict[str, Any]) -> dict[str, Any]: + payload: dict[str, Any] = {} + username = full.get("username") + if isinstance(username, str): + payload["username"] = username + enabled = full.get("enabled") + if isinstance(enabled, bool): + payload["enabled"] = enabled + email = full.get("email") + if isinstance(email, str): + payload["email"] = email + email_verified = full.get("emailVerified") + if isinstance(email_verified, bool): + payload["emailVerified"] = email_verified + first_name = full.get("firstName") + if isinstance(first_name, str): + payload["firstName"] = first_name + last_name = full.get("lastName") + if isinstance(last_name, str): + payload["lastName"] = last_name + + actions = full.get("requiredActions") + if isinstance(actions, list): + payload["requiredActions"] = [a for a in actions if isinstance(a, str)] + + attrs = full.get("attributes") + payload["attributes"] = attrs if isinstance(attrs, dict) else {} + return payload + def ready(self) -> bool: return bool(settings.keycloak_admin_client_id and settings.keycloak_admin_client_secret) @@ -101,6 +131,21 @@ class KeycloakAdminClient: resp = client.put(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload) resp.raise_for_status() + def update_user_safe(self, user_id: str, payload: dict[str, Any]) -> None: + full = self.get_user(user_id) + merged = self._safe_update_payload(full) + for key, value in payload.items(): + if key == "attributes": + attrs = merged.get("attributes") + if not isinstance(attrs, dict): + attrs = {} + if isinstance(value, dict): + attrs.update(value) + merged["attributes"] = attrs + continue + merged[key] = value + self.update_user(user_id, merged) + def create_user(self, payload: dict[str, Any]) -> str: url = f"{settings.keycloak_admin_url}/admin/realms/{settings.keycloak_realm}/users" with httpx.Client(timeout=10.0) as client: @@ -130,11 +175,13 @@ class KeycloakAdminClient: raise RuntimeError("user id missing") full = self.get_user(user_id) - attrs = full.get("attributes") or {} + payload = self._safe_update_payload(full) + attrs = payload.get("attributes") if not isinstance(attrs, dict): attrs = {} attrs[key] = [value] - self.update_user(user_id, {"attributes": attrs}) + payload["attributes"] = attrs + self.update_user(user_id, payload) def get_group_id(self, group_name: str) -> str | None: cached = self._group_id_cache.get(group_name) diff --git a/tests/test_keycloak_admin.py b/tests/test_keycloak_admin.py new file mode 100644 index 0000000..4068c4a --- /dev/null +++ b/tests/test_keycloak_admin.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import Any + +from ariadne.services.keycloak_admin import KeycloakAdminClient + + +def test_set_user_attribute_preserves_profile(monkeypatch) -> None: + client = KeycloakAdminClient() + captured: dict[str, Any] = {} + + def fake_find_user(username: str) -> dict[str, Any]: + return {"id": "user-123"} + + def fake_get_user(user_id: str) -> dict[str, Any]: + return { + "id": user_id, + "username": "alice", + "email": "alice@bstein.dev", + "emailVerified": True, + "enabled": True, + "firstName": "Alice", + "lastName": "Smith", + "requiredActions": ["UPDATE_PASSWORD", 123], + "attributes": {"existing": ["value"]}, + } + + def fake_update_user(user_id: str, payload: dict[str, Any]) -> None: + captured["user_id"] = user_id + captured["payload"] = payload + + monkeypatch.setattr(client, "find_user", fake_find_user) + monkeypatch.setattr(client, "get_user", fake_get_user) + monkeypatch.setattr(client, "update_user", fake_update_user) + + client.set_user_attribute("alice", "mailu_app_password", "secret") + + payload = captured.get("payload") or {} + assert payload.get("username") == "alice" + assert payload.get("email") == "alice@bstein.dev" + assert payload.get("emailVerified") is True + assert payload.get("enabled") is True + assert payload.get("firstName") == "Alice" + assert payload.get("lastName") == "Smith" + assert payload.get("requiredActions") == ["UPDATE_PASSWORD"] + assert payload.get("attributes") == { + "existing": ["value"], + "mailu_app_password": ["secret"], + } + + +def test_update_user_safe_merges_payload(monkeypatch) -> None: + client = KeycloakAdminClient() + captured: dict[str, Any] = {} + + def fake_get_user(user_id: str) -> dict[str, Any]: + return { + "id": user_id, + "username": "alice", + "enabled": True, + "attributes": {"existing": ["value"]}, + } + + def fake_update_user(user_id: str, payload: dict[str, Any]) -> None: + captured["user_id"] = user_id + captured["payload"] = payload + + monkeypatch.setattr(client, "get_user", fake_get_user) + monkeypatch.setattr(client, "update_user", fake_update_user) + + client.update_user_safe( + "user-123", + {"attributes": {"new": ["item"]}, "requiredActions": ["UPDATE_PASSWORD"]}, + ) + + payload = captured.get("payload") or {} + assert payload.get("username") == "alice" + assert payload.get("attributes") == {"existing": ["value"], "new": ["item"]} + assert payload.get("requiredActions") == ["UPDATE_PASSWORD"] diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 9e0795c..b39e549 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime +from datetime import datetime, timezone from ariadne.scheduler.cron import CronScheduler, CronTask @@ -25,7 +25,22 @@ def test_execute_task_records_failure() -> None: raise RuntimeError("boom") task = CronTask(name="test", cron_expr="*/5 * * * *", runner=runner) - scheduler._next_run["test"] = datetime.utcnow() + scheduler._next_run["test"] = datetime.now(timezone.utc) + scheduler._execute_task(task) + + assert storage.task_runs + assert storage.schedule_states + + +def test_execute_task_records_success() -> None: + storage = DummyStorage() + scheduler = CronScheduler(storage, tick_sec=0.1) + + def runner(): + return None + + task = CronTask(name="ok-task", cron_expr="*/5 * * * *", runner=runner) + scheduler._next_run["ok-task"] = datetime.now(timezone.utc) scheduler._execute_task(task) assert storage.task_runs