from __future__ import annotations import time from functools import wraps from typing import Any from urllib.parse import quote from flask import g, jsonify, request import httpx import jwt from jwt import PyJWKClient from . import settings class KeycloakOIDC: def __init__(self) -> None: self._jwk_client: PyJWKClient | None = None def _client(self) -> PyJWKClient: if self._jwk_client is None: self._jwk_client = PyJWKClient(settings.KEYCLOAK_JWKS_URL) return self._jwk_client def verify(self, token: str) -> dict[str, Any]: if not settings.KEYCLOAK_ENABLED: raise ValueError("keycloak not enabled") signing_key = self._client().get_signing_key_from_jwt(token).key claims = jwt.decode( token, signing_key, algorithms=["RS256"], options={"verify_aud": False}, issuer=settings.KEYCLOAK_ISSUER, ) azp = claims.get("azp") aud = claims.get("aud") aud_list: list[str] = [] if isinstance(aud, str): aud_list = [aud] elif isinstance(aud, list): aud_list = [a for a in aud if isinstance(a, str)] if azp != settings.KEYCLOAK_CLIENT_ID and settings.KEYCLOAK_CLIENT_ID not in aud_list: raise ValueError("token not issued for this client") return claims class KeycloakAdminClient: def __init__(self) -> None: self._token: str = "" self._expires_at: float = 0.0 self._group_id_cache: dict[str, str] = {} @staticmethod def _safe_update_payload(full: dict[str, Any]) -> dict[str, Any]: payload: dict[str, Any] = {} username = full.get("username") if isinstance(username, str): payload["username"] = username enabled = full.get("enabled") if isinstance(enabled, bool): payload["enabled"] = enabled email = full.get("email") if isinstance(email, str): payload["email"] = email email_verified = full.get("emailVerified") if isinstance(email_verified, bool): payload["emailVerified"] = email_verified first_name = full.get("firstName") if isinstance(first_name, str): payload["firstName"] = first_name last_name = full.get("lastName") if isinstance(last_name, str): payload["lastName"] = last_name actions = full.get("requiredActions") if isinstance(actions, list): payload["requiredActions"] = [a for a in actions if isinstance(a, str)] attrs = full.get("attributes") payload["attributes"] = attrs if isinstance(attrs, dict) else {} return payload def ready(self) -> bool: return bool(settings.KEYCLOAK_ADMIN_CLIENT_ID and settings.KEYCLOAK_ADMIN_CLIENT_SECRET) def _get_token(self) -> str: if not self.ready(): raise RuntimeError("keycloak admin client not configured") now = time.time() if self._token and now < self._expires_at - 30: return self._token token_url = ( f"{settings.KEYCLOAK_ADMIN_URL}/realms/{settings.KEYCLOAK_ADMIN_REALM}/protocol/openid-connect/token" ) data = { "grant_type": "client_credentials", "client_id": settings.KEYCLOAK_ADMIN_CLIENT_ID, "client_secret": settings.KEYCLOAK_ADMIN_CLIENT_SECRET, } with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client: resp = client.post(token_url, data=data) resp.raise_for_status() payload = resp.json() token = payload.get("access_token") or "" if not token: raise RuntimeError("no access_token in response") expires_in = int(payload.get("expires_in") or 60) self._token = token self._expires_at = now + expires_in return token def _headers(self) -> dict[str, str]: return {"Authorization": f"Bearer {self._get_token()}"} def headers(self) -> dict[str, str]: return self._headers() def find_user(self, username: str) -> dict[str, Any] | None: url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users" # Keycloak 26.x in our environment intermittently 400s on filtered user queries unless `max` is set. # Use `max=1` and exact username match to keep admin calls reliable for portal provisioning. params = {"username": username, "exact": "true", "max": "1"} with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client: resp = client.get(url, params=params, headers=self._headers()) resp.raise_for_status() users = resp.json() if not isinstance(users, list) or not users: return None user = users[0] return user if isinstance(user, dict) else None def find_user_by_email(self, email: str) -> dict[str, Any] | None: email = (email or "").strip() if not email: return None url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users" # Match the portal's username query behavior: set a low `max` and post-filter for exact matches. params = {"email": email, "exact": "true", "max": "2"} email_norm = email.lower() with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client: resp = client.get(url, params=params, headers=self._headers()) resp.raise_for_status() users = resp.json() if not isinstance(users, list) or not users: return None for user in users: if not isinstance(user, dict): continue candidate = user.get("email") if isinstance(candidate, str) and candidate.strip().lower() == email_norm: return user return None def get_user(self, user_id: str) -> dict[str, Any]: url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users/{quote(user_id, safe='')}" with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client: resp = client.get(url, headers=self._headers()) resp.raise_for_status() data = resp.json() if not isinstance(data, dict): raise RuntimeError("unexpected user payload") return data def update_user(self, user_id: str, payload: dict[str, Any]) -> None: url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users/{quote(user_id, safe='')}" with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client: resp = client.put(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload) resp.raise_for_status() def update_user_safe(self, user_id: str, payload: dict[str, Any]) -> None: full = self.get_user(user_id) merged = self._safe_update_payload(full) for key, value in payload.items(): if key == "attributes": attrs = merged.get("attributes") if not isinstance(attrs, dict): attrs = {} if isinstance(value, dict): attrs.update(value) merged["attributes"] = attrs continue merged[key] = value self.update_user(user_id, merged) def create_user(self, payload: dict[str, Any]) -> str: url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users" with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client: resp = client.post(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload) resp.raise_for_status() location = resp.headers.get("Location") or "" if location: return location.rstrip("/").split("/")[-1] raise RuntimeError("failed to determine created user id") def reset_password(self, user_id: str, password: str, temporary: bool = True) -> None: url = ( f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}" f"/users/{quote(user_id, safe='')}/reset-password" ) payload = {"type": "password", "value": password, "temporary": bool(temporary)} with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client: resp = client.put(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload) resp.raise_for_status() def set_user_attribute(self, username: str, key: str, value: str) -> None: user = self.find_user(username) if not user: raise RuntimeError("user not found") user_id = user.get("id") or "" if not user_id: raise RuntimeError("user id missing") full = self.get_user(user_id) payload = self._safe_update_payload(full) attrs = payload.get("attributes") if not isinstance(attrs, dict): attrs = {} attrs[key] = [value] payload["attributes"] = attrs # Keep profile fields intact so required actions don't re-trigger unexpectedly. self.update_user(user_id, payload) def get_group_id(self, group_name: str) -> str | None: cached = self._group_id_cache.get(group_name) if cached: return cached url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/groups" params = {"search": group_name} with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client: resp = client.get(url, params=params, headers=self._headers()) resp.raise_for_status() items = resp.json() if not isinstance(items, list): return None for item in items: if not isinstance(item, dict): continue if item.get("name") == group_name and item.get("id"): gid = str(item["id"]) self._group_id_cache[group_name] = gid return gid return None def list_group_names(self) -> list[str]: url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/groups" with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client: resp = client.get(url, headers=self._headers()) resp.raise_for_status() items = resp.json() if not isinstance(items, list): return [] names: set[str] = set() def walk(groups: list[Any]) -> None: for group in groups: if not isinstance(group, dict): continue name = group.get("name") if isinstance(name, str) and name: names.add(name) sub = group.get("subGroups") if isinstance(sub, list) and sub: walk(sub) walk(items) return sorted(names) def add_user_to_group(self, user_id: str, group_id: str) -> None: url = ( f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}" f"/users/{quote(user_id, safe='')}/groups/{quote(group_id, safe='')}" ) with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client: resp = client.put(url, headers=self._headers()) resp.raise_for_status() def execute_actions_email(self, user_id: str, actions: list[str], redirect_uri: str) -> None: url = ( f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}" f"/users/{quote(user_id, safe='')}/execute-actions-email" ) params = {"client_id": settings.KEYCLOAK_CLIENT_ID, "redirect_uri": redirect_uri} with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client: resp = client.put( url, params=params, headers={**self._headers(), "Content-Type": "application/json"}, json=actions, ) resp.raise_for_status() def get_user_credentials(self, user_id: str) -> list[dict[str, Any]]: url = ( f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}" f"/users/{quote(user_id, safe='')}/credentials" ) with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client: resp = client.get(url, headers=self._headers()) resp.raise_for_status() data = resp.json() if not isinstance(data, list): return [] return [item for item in data if isinstance(item, dict)] _OIDC: KeycloakOIDC | None = None _ADMIN: KeycloakAdminClient | None = None def oidc_client() -> KeycloakOIDC: global _OIDC if _OIDC is None: _OIDC = KeycloakOIDC() return _OIDC def admin_client() -> KeycloakAdminClient: global _ADMIN if _ADMIN is None: _ADMIN = KeycloakAdminClient() return _ADMIN def _normalize_groups(groups: Any) -> list[str]: if not isinstance(groups, list): return [] cleaned: list[str] = [] for gname in groups: if not isinstance(gname, str): continue cleaned.append(gname.lstrip("/")) return [gname for gname in cleaned if gname] def _extract_bearer_token() -> str | None: header = request.headers.get("Authorization", "") if not header: return None parts = header.split(None, 1) if len(parts) != 2: return None scheme, token = parts[0].lower(), parts[1].strip() if scheme != "bearer" or not token: return None return token def require_auth(fn): @wraps(fn) def wrapper(*args, **kwargs): token = _extract_bearer_token() if not token: return jsonify({"error": "missing bearer token"}), 401 try: claims = oidc_client().verify(token) except Exception: return jsonify({"error": "invalid token"}), 401 g.keycloak_claims = claims g.keycloak_username = claims.get("preferred_username") or "" g.keycloak_email = claims.get("email") or "" g.keycloak_groups = _normalize_groups(claims.get("groups")) return fn(*args, **kwargs) return wrapper def require_portal_admin() -> tuple[bool, Any]: if not settings.KEYCLOAK_ENABLED: return False, (jsonify({"error": "keycloak not enabled"}), 503) username = getattr(g, "keycloak_username", "") or "" groups = set(getattr(g, "keycloak_groups", []) or []) if username and username in settings.PORTAL_ADMIN_USERS: return True, None if settings.PORTAL_ADMIN_GROUPS and groups.intersection(settings.PORTAL_ADMIN_GROUPS): return True, None return False, (jsonify({"error": "forbidden"}), 403) def require_account_access() -> tuple[bool, Any]: if not settings.KEYCLOAK_ENABLED: return False, (jsonify({"error": "keycloak not enabled"}), 503) if not settings.ACCOUNT_ALLOWED_GROUPS: return True, None groups = set(getattr(g, "keycloak_groups", []) or []) if groups.intersection(settings.ACCOUNT_ALLOWED_GROUPS): return True, None return False, (jsonify({"error": "forbidden"}), 403)