from __future__ import annotations

import time
from functools import wraps
from typing import Any
from urllib.parse import quote

from flask import g, jsonify, request
import httpx
import jwt
from jwt import PyJWKClient

from . import settings


class KeycloakOIDC:
    """Verify user-facing Keycloak tokens for portal requests."""

    def __init__(self) -> None:
        # Built lazily by `_client()` so constructing the verifier does no work.
        self._jwk_client: PyJWKClient | None = None

    def _client(self) -> PyJWKClient:
        """Return the JWKS client, creating it on first use from `settings.KEYCLOAK_JWKS_URL`."""
        if self._jwk_client is None:
            self._jwk_client = PyJWKClient(settings.KEYCLOAK_JWKS_URL)
        return self._jwk_client

    def verify(self, token: str) -> dict[str, Any]:
        """Validate a bearer token and return decoded claims.

        WHY: the portal trusts Keycloak groups and usernames only after issuer
        and client ownership are checked locally.

        Args:
            token: The raw encoded JWT taken from the request.

        Returns:
            The decoded claim set.

        Raises:
            ValueError: If Keycloak is disabled in settings, or if neither the
                `azp` claim nor any `aud` entry matches
                `settings.KEYCLOAK_CLIENT_ID`.

        Errors raised by PyJWT while resolving the signing key or decoding the
        token are not caught here and propagate to the caller.
        """
        if not settings.KEYCLOAK_ENABLED:
            raise ValueError("keycloak not enabled")

        signing_key = self._client().get_signing_key_from_jwt(token).key
        claims = jwt.decode(
            token,
            signing_key,
            algorithms=["RS256"],
            # PyJWT's audience check is switched off; client ownership is
            # enforced below against either `azp` or `aud`.
            options={"verify_aud": False},
            issuer=settings.KEYCLOAK_ISSUER,
        )

        azp = claims.get("azp")
        aud = claims.get("aud")
        # `aud` may be a single string or a list; normalise to a list of strings.
        aud_list: list[str] = []
        if isinstance(aud, str):
            aud_list = [aud]
        elif isinstance(aud, list):
            aud_list = [a for a in aud if isinstance(a, str)]

        if azp != settings.KEYCLOAK_CLIENT_ID and settings.KEYCLOAK_CLIENT_ID not in aud_list:
            raise ValueError("token not issued for this client")
        return claims


class KeycloakAdminClient:
    """Perform service-account Keycloak admin operations for provisioning."""

    def __init__(self) -> None:
        # Cached service-account access token; empty string means "none yet".
        self._token: str = ""
        # `time.time()` value after which the cached token is considered expired.
        self._expires_at: float = 0.0
        # group name -> group id, filled by `get_group_id`.
        self._group_id_cache: dict[str, str] = {}

    @staticmethod
    def _safe_update_payload(full: dict[str, Any]) -> dict[str, Any]:
        """Extract mutable fields from a full Keycloak user document.

        WHY: partial updates can accidentally clear profile or attribute data,
        so callers merge desired changes into a safe copy first.

        Only fields whose value has the expected type are carried over.
        `attributes` is always present in the result and is a shallow copy, so
        callers that merge into it do not mutate `full`.

        Args:
            full: A user document as returned by `get_user`.

        Returns:
            A new dict suitable as the body of a user update.
        """
        # (field name, required type), in the order the keys are emitted.
        scalar_fields: tuple[tuple[str, type], ...] = (
            ("username", str),
            ("enabled", bool),
            ("email", str),
            ("emailVerified", bool),
            ("firstName", str),
            ("lastName", str),
        )

        payload: dict[str, Any] = {}
        for field, expected in scalar_fields:
            value = full.get(field)
            if isinstance(value, expected):
                payload[field] = value

        actions = full.get("requiredActions")
        if isinstance(actions, list):
            payload["requiredActions"] = [a for a in actions if isinstance(a, str)]

        attrs = full.get("attributes")
        # Copy so that later `attrs.update(...)` / item assignment by callers
        # cannot leak back into the document they fetched.
        payload["attributes"] = dict(attrs) if isinstance(attrs, dict) else {}
        return payload

    def ready(self) -> bool:
        """Return whether admin-client credentials are available."""
        return bool(settings.KEYCLOAK_ADMIN_CLIENT_ID and settings.KEYCLOAK_ADMIN_CLIENT_SECRET)

    def _get_token(self) -> str:
        """Return a cached service-account token, refreshing before expiry.

        The cached token is reused until 30 seconds before its recorded expiry;
        after that a new one is requested with the client-credentials grant.

        Raises:
            RuntimeError: If the admin client is not configured, or the token
                response carries no `access_token`.
            httpx.HTTPStatusError: If the token endpoint returns an error status.
        """
        if not self.ready():
            raise RuntimeError("keycloak admin client not configured")

        now = time.time()
        if self._token and now < self._expires_at - 30:
            return self._token

        token_url = (
            f"{settings.KEYCLOAK_ADMIN_URL}/realms/{settings.KEYCLOAK_ADMIN_REALM}/protocol/openid-connect/token"
        )
        data = {
            "grant_type": "client_credentials",
            "client_id": settings.KEYCLOAK_ADMIN_CLIENT_ID,
            "client_secret": settings.KEYCLOAK_ADMIN_CLIENT_SECRET,
        }
        with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
            resp = client.post(token_url, data=data)
            resp.raise_for_status()
            payload = resp.json()

        token = payload.get("access_token") or ""
        if not token:
            raise RuntimeError("no access_token in response")
        # Fall back to 60 seconds when the response omits `expires_in`.
        expires_in = int(payload.get("expires_in") or 60)

        self._token = token
        # Measured from before the request was sent, so the cache errs on the
        # side of refreshing early.
        self._expires_at = now + expires_in
        return token

    def _headers(self) -> dict[str, str]:
        """Return the bearer `Authorization` header for the current admin token."""
        return {"Authorization": f"Bearer {self._get_token()}"}

    def headers(self) -> dict[str, str]:
        """Return authorization headers for callers that need raw admin access."""
        return self._headers()

    def find_user(self, username: str) -> dict[str, Any] | None:
        """Look up one Keycloak user by exact username.

        Returns:
            The first entry of the response when it is a dict, otherwise `None`
            (also `None` when the response is empty or not a list).

        Raises:
            httpx.HTTPStatusError: If Keycloak returns an error status.
        """
        url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users"
        # Keycloak 26.x in our environment intermittently 400s on filtered user queries unless `max` is set.
        # Use `max=1` and exact username match to keep admin calls reliable for portal provisioning.
        params = {"username": username, "exact": "true", "max": "1"}
        with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
            resp = client.get(url, params=params, headers=self._headers())
            resp.raise_for_status()
            users = resp.json()

        if not isinstance(users, list) or not users:
            return None
        user = users[0]
        return user if isinstance(user, dict) else None

    def find_user_by_email(self, email: str) -> dict[str, Any] | None:
        """Look up one Keycloak user by exact email address.

        The input is stripped; an empty value returns `None` without calling
        Keycloak. Returned candidates are compared case-insensitively (after
        stripping) against the requested address.

        Raises:
            httpx.HTTPStatusError: If Keycloak returns an error status.
        """
        email = (email or "").strip()
        if not email:
            return None

        url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users"
        # Match the portal's username query behavior: set a low `max` and post-filter for exact matches.
        params = {"email": email, "exact": "true", "max": "2"}
        email_norm = email.lower()
        with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
            resp = client.get(url, params=params, headers=self._headers())
            resp.raise_for_status()
            users = resp.json()

        if not isinstance(users, list) or not users:
            return None
        for user in users:
            if not isinstance(user, dict):
                continue
            candidate = user.get("email")
            if isinstance(candidate, str) and candidate.strip().lower() == email_norm:
                return user
        return None

    def get_user(self, user_id: str) -> dict[str, Any]:
        """Fetch a full Keycloak user representation by ID.

        Raises:
            RuntimeError: If the response body is not a JSON object.
            httpx.HTTPStatusError: If Keycloak returns an error status.
        """
        url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users/{quote(user_id, safe='')}"
        with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
            resp = client.get(url, headers=self._headers())
            resp.raise_for_status()
            data = resp.json()

        if not isinstance(data, dict):
            raise RuntimeError("unexpected user payload")
        return data

    def update_user(self, user_id: str, payload: dict[str, Any]) -> None:
        """Replace a Keycloak user document with the supplied payload.

        Sends `payload` as-is in a PUT; use `update_user_safe` to merge changes
        into the existing document first.

        Raises:
            httpx.HTTPStatusError: If Keycloak returns an error status.
        """
        url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users/{quote(user_id, safe='')}"
        with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
            resp = client.put(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload)
            resp.raise_for_status()

    def update_user_safe(self, user_id: str, payload: dict[str, Any]) -> None:
        """Merge selected user changes into the current Keycloak document.

        Fetches the user, reduces it with `_safe_update_payload`, overlays
        `payload`, and PUTs the result. The `attributes` key is merged
        key-by-key when its value is a dict (and ignored otherwise); every
        other key in `payload` overwrites the current value.
        """
        full = self.get_user(user_id)
        merged = self._safe_update_payload(full)
        for key, value in payload.items():
            if key == "attributes":
                attrs = merged.get("attributes")
                if not isinstance(attrs, dict):
                    attrs = {}
                if isinstance(value, dict):
                    attrs.update(value)
                merged["attributes"] = attrs
                continue
            merged[key] = value
        self.update_user(user_id, merged)

    def create_user(self, payload: dict[str, Any]) -> str:
        """Create a Keycloak user and return the generated user ID.

        The ID is taken from the last path segment of the response's
        `Location` header.

        Raises:
            RuntimeError: If the response has no `Location` header.
            httpx.HTTPStatusError: If Keycloak returns an error status.
        """
        url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users"
        with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
            resp = client.post(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload)
            resp.raise_for_status()
            location = resp.headers.get("Location") or ""

        if location:
            return location.rstrip("/").split("/")[-1]
        raise RuntimeError("failed to determine created user id")

    def reset_password(self, user_id: str, password: str, temporary: bool = True) -> None:
        """Set a Keycloak password credential for a user.

        Args:
            user_id: Keycloak user ID.
            password: The new password value.
            temporary: Sent as the credential's `temporary` flag.

        Raises:
            httpx.HTTPStatusError: If Keycloak returns an error status.
        """
        url = (
            f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}"
            f"/users/{quote(user_id, safe='')}/reset-password"
        )
        payload = {"type": "password", "value": password, "temporary": bool(temporary)}
        with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
            resp = client.put(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload)
            resp.raise_for_status()

    def set_user_attribute(self, username: str, key: str, value: str) -> None:
        """Set one single-value Keycloak user attribute by username.

        The value is stored as a one-element list under `attributes[key]`.

        Raises:
            RuntimeError: If no user matches `username`, or the match has no `id`.
            httpx.HTTPStatusError: If Keycloak returns an error status.
        """
        user = self.find_user(username)
        if not user:
            raise RuntimeError("user not found")
        user_id = user.get("id") or ""
        if not user_id:
            raise RuntimeError("user id missing")

        # Keep profile fields intact so required actions don't re-trigger unexpectedly.
        # `update_user_safe` re-fetches the full document and merges into it.
        self.update_user_safe(user_id, {"attributes": {key: [value]}})

    def get_group_id(self, group_name: str) -> str | None:
        """Resolve and cache the Keycloak ID for a group name.

        Only entries at the top level of the search response whose `name`
        equals `group_name` exactly are considered. Successful lookups are
        cached on the instance; misses are not.

        Raises:
            httpx.HTTPStatusError: If Keycloak returns an error status.
        """
        cached = self._group_id_cache.get(group_name)
        if cached:
            return cached

        url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/groups"
        params = {"search": group_name}
        with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
            resp = client.get(url, params=params, headers=self._headers())
            resp.raise_for_status()
            items = resp.json()

        if not isinstance(items, list):
            return None
        for item in items:
            if not isinstance(item, dict):
                continue
            if item.get("name") == group_name and item.get("id"):
                gid = str(item["id"])
                self._group_id_cache[group_name] = gid
                return gid
        return None

    def list_group_names(self) -> list[str]:
        """Return all Keycloak group names visible to the admin client.

        Names are collected from the response and from any nested `subGroups`
        lists it contains, de-duplicated, and returned sorted.

        Raises:
            httpx.HTTPStatusError: If Keycloak returns an error status.
        """
        url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/groups"
        with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
            resp = client.get(url, headers=self._headers())
            resp.raise_for_status()
            items = resp.json()

        if not isinstance(items, list):
            return []

        names: set[str] = set()

        def walk(groups: list[Any]) -> None:
            """Visit nested Keycloak group records and collect names."""
            for group in groups:
                if not isinstance(group, dict):
                    continue
                name = group.get("name")
                if isinstance(name, str) and name:
                    names.add(name)
                sub = group.get("subGroups")
                if isinstance(sub, list) and sub:
                    walk(sub)

        walk(items)
        return sorted(names)

    def list_user_groups(self, user_id: str) -> list[str]:
        """Return group names assigned to one Keycloak user.

        Leading `/` characters are stripped from each name; order follows the
        response.

        Raises:
            httpx.HTTPStatusError: If Keycloak returns an error status.
        """
        url = (
            f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}"
            f"/users/{quote(user_id, safe='')}/groups"
        )
        with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
            resp = client.get(url, headers=self._headers())
            resp.raise_for_status()
            items = resp.json()

        if not isinstance(items, list):
            return []
        names: list[str] = []
        for item in items:
            if not isinstance(item, dict):
                continue
            name = item.get("name")
            if isinstance(name, str) and name:
                names.append(name.lstrip("/"))
        return names

    def add_user_to_group(self, user_id: str, group_id: str) -> None:
        """Attach one Keycloak user to one group by ID.

        Raises:
            httpx.HTTPStatusError: If Keycloak returns an error status.
        """
        url = (
            f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}"
            f"/users/{quote(user_id, safe='')}/groups/{quote(group_id, safe='')}"
        )
        with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
            resp = client.put(url, headers=self._headers())
            resp.raise_for_status()

    def execute_actions_email(self, user_id: str, actions: list[str], redirect_uri: str) -> None:
        """Ask Keycloak to email required-account-action links to a user.

        Args:
            user_id: Keycloak user ID.
            actions: Action names, sent as the JSON request body.
            redirect_uri: Sent as the `redirect_uri` query parameter alongside
                `client_id=settings.KEYCLOAK_CLIENT_ID`.

        Raises:
            httpx.HTTPStatusError: If Keycloak returns an error status.
        """
        url = (
            f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}"
            f"/users/{quote(user_id, safe='')}/execute-actions-email"
        )
        params = {"client_id": settings.KEYCLOAK_CLIENT_ID, "redirect_uri": redirect_uri}
        with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
            resp = client.put(
                url,
                params=params,
                headers={**self._headers(), "Content-Type": "application/json"},
                json=actions,
            )
            resp.raise_for_status()

    def get_user_credentials(self, user_id: str) -> list[dict[str, Any]]:
        """Return credential metadata for one Keycloak user.

        Non-dict entries in the response are dropped; a non-list response
        yields an empty list.

        Raises:
            httpx.HTTPStatusError: If Keycloak returns an error status.
        """
        url = (
            f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}"
            f"/users/{quote(user_id, safe='')}/credentials"
        )
        with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
            resp = client.get(url, headers=self._headers())
            resp.raise_for_status()
            data = resp.json()

        if not isinstance(data, list):
            return []
        return [item for item in data if isinstance(item, dict)]


# Process-wide singletons, created on first use by the accessors below.
_OIDC: KeycloakOIDC | None = None
_ADMIN: KeycloakAdminClient | None = None


def oidc_client() -> KeycloakOIDC:
    """Return the singleton OIDC verifier."""
    global _OIDC
    if _OIDC is None:
        _OIDC = KeycloakOIDC()
    return _OIDC


def admin_client() -> KeycloakAdminClient:
    """Return the singleton Keycloak admin client."""
    global _ADMIN
    if _ADMIN is None:
        _ADMIN = KeycloakAdminClient()
    return _ADMIN


def _normalize_groups(groups: Any) -> list[str]:
    """Return string group names with leading `/` removed and empties dropped.

    Anything that is not a list yields an empty list; non-string entries are
    skipped.
    """
    if not isinstance(groups, list):
        return []
    stripped = (name.lstrip("/") for name in groups if isinstance(name, str))
    return [name for name in stripped if name]


def _extract_bearer_token() -> str | None:
    """Extract a bearer token from the current Flask request.

    Returns `None` when the `Authorization` header is absent, is not of the
    form `<scheme> <token>`, uses a scheme other than `bearer`
    (case-insensitive), or carries an empty token.
    """
    header = request.headers.get("Authorization", "")
    if not header:
        return None
    parts = header.split(None, 1)
    if len(parts) != 2:
        return None
    scheme, token = parts[0].lower(), parts[1].strip()
    if scheme != "bearer" or not token:
        return None
    return token


def require_auth(fn):
    """Decorate a Flask route so it requires a valid Keycloak bearer token.

    On success the wrapped view runs with `g.keycloak_claims`,
    `g.keycloak_username`, `g.keycloak_email` and `g.keycloak_groups` set.
    On failure the view is not called and a 401 JSON response is returned.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        """Validate the request token and place normalized claims on Flask globals."""
        token = _extract_bearer_token()
        if not token:
            return jsonify({"error": "missing bearer token"}), 401
        try:
            claims = oidc_client().verify(token)
        except Exception:
            # Any verification failure is reported as 401.
            # NOTE(review): this also covers non-token problems (e.g. Keycloak
            # disabled, signing-key lookup errors) — confirm that is intended.
            return jsonify({"error": "invalid token"}), 401

        g.keycloak_claims = claims
        g.keycloak_username = claims.get("preferred_username") or ""
        g.keycloak_email = claims.get("email") or ""
        g.keycloak_groups = _normalize_groups(claims.get("groups"))
        return fn(*args, **kwargs)

    return wrapper


def require_portal_admin() -> tuple[bool, Any]:
    """Return whether the authenticated user can use portal admin actions.

    Returns:
        `(True, None)` when the username is in `settings.PORTAL_ADMIN_USERS` or
        the user's groups intersect `settings.PORTAL_ADMIN_GROUPS`; otherwise
        `(False, response)` where `response` is a `(json, status)` pair — 503
        if Keycloak is disabled, 403 if the user is not an admin.
    """
    if not settings.KEYCLOAK_ENABLED:
        return False, (jsonify({"error": "keycloak not enabled"}), 503)

    username = getattr(g, "keycloak_username", "") or ""
    groups = set(getattr(g, "keycloak_groups", []) or [])

    if username and username in settings.PORTAL_ADMIN_USERS:
        return True, None
    if settings.PORTAL_ADMIN_GROUPS and groups.intersection(settings.PORTAL_ADMIN_GROUPS):
        return True, None
    return False, (jsonify({"error": "forbidden"}), 403)


def require_account_access() -> tuple[bool, Any]:
    """Return whether the authenticated user can use self-service account pages.

    Returns:
        `(True, None)` when access is allowed; otherwise `(False, response)`
        where `response` is a `(json, status)` pair — 503 if Keycloak is
        disabled, 403 if the user's groups do not intersect
        `settings.ACCOUNT_ALLOWED_GROUPS`.
    """
    if not settings.KEYCLOAK_ENABLED:
        return False, (jsonify({"error": "keycloak not enabled"}), 503)
    # No allow-list configured: everyone authenticated may proceed.
    if not settings.ACCOUNT_ALLOWED_GROUPS:
        return True, None

    groups = set(getattr(g, "keycloak_groups", []) or [])
    # NOTE(review): a user with no groups at all is allowed even when an
    # allow-list is configured — confirm this is intentional.
    if not groups:
        return True, None
    if groups.intersection(settings.ACCOUNT_ALLOWED_GROUPS):
        return True, None
    return False, (jsonify({"error": "forbidden"}), 403)