490 lines
18 KiB
Python

from __future__ import annotations
import time
from functools import wraps
from typing import Any
from urllib.parse import quote
from flask import g, jsonify, request
import httpx
import jwt
from jwt import PyJWKClient
from . import settings
class KeycloakOIDC:
    """Validate end-user Keycloak bearer tokens for portal requests.

    The JWKS client is only constructed the first time a token is verified.
    """

    def __init__(self) -> None:
        # Filled in by ``_client`` on first use.
        self._jwk_client: PyJWKClient | None = None

    def _client(self) -> PyJWKClient:
        """Return the cached JWKS client, building it from settings on first call."""
        if self._jwk_client is not None:
            return self._jwk_client
        self._jwk_client = PyJWKClient(settings.KEYCLOAK_JWKS_URL)
        return self._jwk_client

    def verify(self, token: str) -> dict[str, Any]:
        """Check a bearer token's signature, issuer and owning client.

        WHY: the portal trusts Keycloak groups and usernames only after issuer
        and client ownership are checked locally.

        Args:
            token: The raw JWT taken from the ``Authorization`` header.

        Returns:
            The decoded claim set.

        Raises:
            ValueError: Keycloak is disabled, or the token was issued for a
                different client (neither ``azp`` nor ``aud`` names ours).
            Exceptions from PyJWT / the JWKS client propagate unchanged.
        """
        if not settings.KEYCLOAK_ENABLED:
            raise ValueError("keycloak not enabled")
        key = self._client().get_signing_key_from_jwt(token).key
        # Audience verification is disabled here because the client binding is
        # checked below against either ``azp`` or ``aud``.
        claims = jwt.decode(
            token,
            key,
            algorithms=["RS256"],
            issuer=settings.KEYCLOAK_ISSUER,
            options={"verify_aud": False},
        )
        raw_aud = claims.get("aud")
        if isinstance(raw_aud, str):
            audiences = [raw_aud]
        elif isinstance(raw_aud, list):
            audiences = [item for item in raw_aud if isinstance(item, str)]
        else:
            audiences = []
        client_id = settings.KEYCLOAK_CLIENT_ID
        issued_for_us = claims.get("azp") == client_id or client_id in audiences
        if not issued_for_us:
            raise ValueError("token not issued for this client")
        return claims
class KeycloakAdminClient:
"""Perform service-account Keycloak admin operations for provisioning."""
def __init__(self) -> None:
self._token: str = ""
self._expires_at: float = 0.0
self._group_id_cache: dict[str, str] = {}
@staticmethod
def _safe_update_payload(full: dict[str, Any]) -> dict[str, Any]:
"""Extract mutable fields from a full Keycloak user document.
WHY: partial updates can accidentally clear profile or attribute data,
so callers merge desired changes into a safe copy first.
"""
payload: dict[str, Any] = {}
username = full.get("username")
if isinstance(username, str):
payload["username"] = username
enabled = full.get("enabled")
if isinstance(enabled, bool):
payload["enabled"] = enabled
email = full.get("email")
if isinstance(email, str):
payload["email"] = email
email_verified = full.get("emailVerified")
if isinstance(email_verified, bool):
payload["emailVerified"] = email_verified
first_name = full.get("firstName")
if isinstance(first_name, str):
payload["firstName"] = first_name
last_name = full.get("lastName")
if isinstance(last_name, str):
payload["lastName"] = last_name
actions = full.get("requiredActions")
if isinstance(actions, list):
payload["requiredActions"] = [a for a in actions if isinstance(a, str)]
attrs = full.get("attributes")
payload["attributes"] = attrs if isinstance(attrs, dict) else {}
return payload
def ready(self) -> bool:
"""Return whether admin-client credentials are available."""
return bool(settings.KEYCLOAK_ADMIN_CLIENT_ID and settings.KEYCLOAK_ADMIN_CLIENT_SECRET)
def _get_token(self) -> str:
"""Return a cached service-account token, refreshing before expiry."""
if not self.ready():
raise RuntimeError("keycloak admin client not configured")
now = time.time()
if self._token and now < self._expires_at - 30:
return self._token
token_url = (
f"{settings.KEYCLOAK_ADMIN_URL}/realms/{settings.KEYCLOAK_ADMIN_REALM}/protocol/openid-connect/token"
)
data = {
"grant_type": "client_credentials",
"client_id": settings.KEYCLOAK_ADMIN_CLIENT_ID,
"client_secret": settings.KEYCLOAK_ADMIN_CLIENT_SECRET,
}
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.post(token_url, data=data)
resp.raise_for_status()
payload = resp.json()
token = payload.get("access_token") or ""
if not token:
raise RuntimeError("no access_token in response")
expires_in = int(payload.get("expires_in") or 60)
self._token = token
self._expires_at = now + expires_in
return token
def _headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self._get_token()}"}
def headers(self) -> dict[str, str]:
"""Return authorization headers for callers that need raw admin access."""
return self._headers()
def find_user(self, username: str) -> dict[str, Any] | None:
"""Look up one Keycloak user by exact username."""
url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users"
# Keycloak 26.x in our environment intermittently 400s on filtered user queries unless `max` is set.
# Use `max=1` and exact username match to keep admin calls reliable for portal provisioning.
params = {"username": username, "exact": "true", "max": "1"}
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.get(url, params=params, headers=self._headers())
resp.raise_for_status()
users = resp.json()
if not isinstance(users, list) or not users:
return None
user = users[0]
return user if isinstance(user, dict) else None
def find_user_by_email(self, email: str) -> dict[str, Any] | None:
"""Look up one Keycloak user by exact email address."""
email = (email or "").strip()
if not email:
return None
url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users"
# Match the portal's username query behavior: set a low `max` and post-filter for exact matches.
params = {"email": email, "exact": "true", "max": "2"}
email_norm = email.lower()
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.get(url, params=params, headers=self._headers())
resp.raise_for_status()
users = resp.json()
if not isinstance(users, list) or not users:
return None
for user in users:
if not isinstance(user, dict):
continue
candidate = user.get("email")
if isinstance(candidate, str) and candidate.strip().lower() == email_norm:
return user
return None
def get_user(self, user_id: str) -> dict[str, Any]:
"""Fetch a full Keycloak user representation by ID."""
url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users/{quote(user_id, safe='')}"
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.get(url, headers=self._headers())
resp.raise_for_status()
data = resp.json()
if not isinstance(data, dict):
raise RuntimeError("unexpected user payload")
return data
def update_user(self, user_id: str, payload: dict[str, Any]) -> None:
"""Replace a Keycloak user document with the supplied payload."""
url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users/{quote(user_id, safe='')}"
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.put(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload)
resp.raise_for_status()
def update_user_safe(self, user_id: str, payload: dict[str, Any]) -> None:
"""Merge selected user changes into the current Keycloak document."""
full = self.get_user(user_id)
merged = self._safe_update_payload(full)
for key, value in payload.items():
if key == "attributes":
attrs = merged.get("attributes")
if not isinstance(attrs, dict):
attrs = {}
if isinstance(value, dict):
attrs.update(value)
merged["attributes"] = attrs
continue
merged[key] = value
self.update_user(user_id, merged)
def create_user(self, payload: dict[str, Any]) -> str:
"""Create a Keycloak user and return the generated user ID."""
url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users"
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.post(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload)
resp.raise_for_status()
location = resp.headers.get("Location") or ""
if location:
return location.rstrip("/").split("/")[-1]
raise RuntimeError("failed to determine created user id")
def reset_password(self, user_id: str, password: str, temporary: bool = True) -> None:
"""Set a Keycloak password credential for a user."""
url = (
f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}"
f"/users/{quote(user_id, safe='')}/reset-password"
)
payload = {"type": "password", "value": password, "temporary": bool(temporary)}
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.put(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload)
resp.raise_for_status()
def set_user_attribute(self, username: str, key: str, value: str) -> None:
"""Set one single-value Keycloak user attribute by username."""
user = self.find_user(username)
if not user:
raise RuntimeError("user not found")
user_id = user.get("id") or ""
if not user_id:
raise RuntimeError("user id missing")
full = self.get_user(user_id)
payload = self._safe_update_payload(full)
attrs = payload.get("attributes")
if not isinstance(attrs, dict):
attrs = {}
attrs[key] = [value]
payload["attributes"] = attrs
# Keep profile fields intact so required actions don't re-trigger unexpectedly.
self.update_user(user_id, payload)
def get_group_id(self, group_name: str) -> str | None:
"""Resolve and cache the Keycloak ID for a group name."""
cached = self._group_id_cache.get(group_name)
if cached:
return cached
url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/groups"
params = {"search": group_name}
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.get(url, params=params, headers=self._headers())
resp.raise_for_status()
items = resp.json()
if not isinstance(items, list):
return None
for item in items:
if not isinstance(item, dict):
continue
if item.get("name") == group_name and item.get("id"):
gid = str(item["id"])
self._group_id_cache[group_name] = gid
return gid
return None
def list_group_names(self) -> list[str]:
"""Return all Keycloak group names visible to the admin client."""
url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/groups"
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.get(url, headers=self._headers())
resp.raise_for_status()
items = resp.json()
if not isinstance(items, list):
return []
names: set[str] = set()
def walk(groups: list[Any]) -> None:
"""Visit nested Keycloak group records and collect names."""
for group in groups:
if not isinstance(group, dict):
continue
name = group.get("name")
if isinstance(name, str) and name:
names.add(name)
sub = group.get("subGroups")
if isinstance(sub, list) and sub:
walk(sub)
walk(items)
return sorted(names)
def list_user_groups(self, user_id: str) -> list[str]:
"""Return group names assigned to one Keycloak user."""
url = (
f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}"
f"/users/{quote(user_id, safe='')}/groups"
)
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.get(url, headers=self._headers())
resp.raise_for_status()
items = resp.json()
if not isinstance(items, list):
return []
names: list[str] = []
for item in items:
if not isinstance(item, dict):
continue
name = item.get("name")
if isinstance(name, str) and name:
names.append(name.lstrip("/"))
return names
def add_user_to_group(self, user_id: str, group_id: str) -> None:
"""Attach one Keycloak user to one group by ID."""
url = (
f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}"
f"/users/{quote(user_id, safe='')}/groups/{quote(group_id, safe='')}"
)
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.put(url, headers=self._headers())
resp.raise_for_status()
def execute_actions_email(self, user_id: str, actions: list[str], redirect_uri: str) -> None:
"""Ask Keycloak to email required-account-action links to a user."""
url = (
f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}"
f"/users/{quote(user_id, safe='')}/execute-actions-email"
)
params = {"client_id": settings.KEYCLOAK_CLIENT_ID, "redirect_uri": redirect_uri}
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.put(
url,
params=params,
headers={**self._headers(), "Content-Type": "application/json"},
json=actions,
)
resp.raise_for_status()
def get_user_credentials(self, user_id: str) -> list[dict[str, Any]]:
"""Return credential metadata for one Keycloak user."""
url = (
f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}"
f"/users/{quote(user_id, safe='')}/credentials"
)
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.get(url, headers=self._headers())
resp.raise_for_status()
data = resp.json()
if not isinstance(data, list):
return []
return [item for item in data if isinstance(item, dict)]
# Lazily created module-level singletons, populated by oidc_client() and
# admin_client() on first call.
_OIDC: KeycloakOIDC | None = None
_ADMIN: KeycloakAdminClient | None = None
def oidc_client() -> KeycloakOIDC:
    """Return the shared module-level OIDC verifier, creating it on first call."""
    global _OIDC
    if _OIDC is not None:
        return _OIDC
    _OIDC = KeycloakOIDC()
    return _OIDC
def admin_client() -> KeycloakAdminClient:
    """Return the shared module-level Keycloak admin client, creating it on first call."""
    global _ADMIN
    if _ADMIN is not None:
        return _ADMIN
    _ADMIN = KeycloakAdminClient()
    return _ADMIN
def _normalize_groups(groups: Any) -> list[str]:
    """Convert a raw ``groups`` claim into bare group names.

    Non-list input yields ``[]``. Non-string entries are skipped, leading
    ``/`` characters are stripped, and entries left empty are dropped.
    """
    if not isinstance(groups, list):
        return []
    stripped = (entry.lstrip("/") for entry in groups if isinstance(entry, str))
    return [name for name in stripped if name]
def _extract_bearer_token() -> str | None:
    """Return the token from the current request's ``Authorization`` header.

    The scheme must be ``Bearer`` (any case). Returns None when the header is
    absent, has no credential part, uses another scheme, or the credential is
    blank.
    """
    raw = request.headers.get("Authorization", "")
    if not raw:
        return None
    pieces = raw.split(None, 1)
    if len(pieces) != 2:
        return None
    scheme, credential = pieces
    credential = credential.strip()
    if scheme.lower() == "bearer" and credential:
        return credential
    return None
def require_auth(fn):
    """Decorate a Flask route so it requires a valid Keycloak bearer token.

    On success, the verified claims, username, email and normalized groups are
    stored on ``flask.g`` and the route runs. Otherwise a JSON error is
    returned: 503 when Keycloak is disabled, 401 when the token is missing or
    fails verification.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        """Validate the request token and place normalized claims on Flask globals."""
        # Report a disabled integration as 503, like require_portal_admin and
        # require_account_access do. Previously verify()'s ValueError surfaced
        # here as a misleading 401 "invalid token".
        if not settings.KEYCLOAK_ENABLED:
            return jsonify({"error": "keycloak not enabled"}), 503
        token = _extract_bearer_token()
        if not token:
            return jsonify({"error": "missing bearer token"}), 401
        try:
            claims = oidc_client().verify(token)
        except Exception:
            # Request boundary: any verification failure (signature, issuer,
            # client binding, signing-key lookup) is reported uniformly as 401.
            return jsonify({"error": "invalid token"}), 401
        g.keycloak_claims = claims
        g.keycloak_username = claims.get("preferred_username") or ""
        g.keycloak_email = claims.get("email") or ""
        g.keycloak_groups = _normalize_groups(claims.get("groups"))
        return fn(*args, **kwargs)

    return wrapper
def require_portal_admin() -> tuple[bool, Any]:
    """Decide whether the authenticated user may perform portal admin actions.

    Access is granted when the username is listed in
    ``settings.PORTAL_ADMIN_USERS`` or any of the user's groups is listed in
    ``settings.PORTAL_ADMIN_GROUPS``.

    Returns:
        ``(True, None)`` when allowed. Otherwise ``(False, (response, status))``:
        503 when Keycloak is disabled, 403 when the user is not an admin.
    """
    if not settings.KEYCLOAK_ENABLED:
        return False, (jsonify({"error": "keycloak not enabled"}), 503)
    username = getattr(g, "keycloak_username", "") or ""
    user_groups = set(getattr(g, "keycloak_groups", []) or [])
    # Short-circuits: the group check only runs when the username check fails.
    is_admin = (username and username in settings.PORTAL_ADMIN_USERS) or (
        settings.PORTAL_ADMIN_GROUPS and user_groups.intersection(settings.PORTAL_ADMIN_GROUPS)
    )
    if is_admin:
        return True, None
    return False, (jsonify({"error": "forbidden"}), 403)
def require_account_access() -> tuple[bool, Any]:
    """Decide whether the authenticated user may use self-service account pages.

    Returns:
        ``(True, None)`` when allowed. Otherwise ``(False, (response, status))``:
        503 when Keycloak is disabled, 403 when the user's groups do not
        overlap ``settings.ACCOUNT_ALLOWED_GROUPS``.
    """
    if not settings.KEYCLOAK_ENABLED:
        return False, (jsonify({"error": "keycloak not enabled"}), 503)
    allowed_groups = settings.ACCOUNT_ALLOWED_GROUPS
    if not allowed_groups:
        # No allow-list configured: no group restriction applies.
        return True, None
    user_groups = set(getattr(g, "keycloak_groups", []) or [])
    # NOTE(review): a user with no groups at all is let through even when an
    # allow-list is configured (fail-open). Confirm this is intended, e.g. for
    # tokens issued without a groups claim.
    if not user_groups or not user_groups.isdisjoint(allowed_groups):
        return True, None
    return False, (jsonify({"error": "forbidden"}), 403)