404 lines
15 KiB
Python
Raw Normal View History

from __future__ import annotations
import time
from functools import wraps
from typing import Any
from urllib.parse import quote
from flask import g, jsonify, request
import httpx
import jwt
from jwt import PyJWKClient
from . import settings
class KeycloakOIDC:
    """Verify Keycloak-issued JWT access tokens for this application's client."""

    def __init__(self) -> None:
        # Built on first use by _client(); None until then.
        self._jwk_client: PyJWKClient | None = None

    def _client(self) -> PyJWKClient:
        """Return the JWKS client for ``settings.KEYCLOAK_JWKS_URL``, creating it once."""
        client = self._jwk_client
        if client is None:
            client = PyJWKClient(settings.KEYCLOAK_JWKS_URL)
            self._jwk_client = client
        return client

    def verify(self, token: str) -> dict[str, Any]:
        """Validate *token* and return its decoded claims.

        ``jwt.decode`` checks the RS256 signature (key taken from the JWKS
        endpoint) and the issuer; its built-in audience check is disabled and
        replaced by an explicit rule: the token must name
        ``settings.KEYCLOAK_CLIENT_ID`` either as ``azp`` or among ``aud``.

        Raises:
            ValueError: Keycloak is disabled, or the token was issued for a
                different client.
            Exceptions from PyJWT key lookup / decoding propagate unchanged.
        """
        if not settings.KEYCLOAK_ENABLED:
            raise ValueError("keycloak not enabled")
        key = self._client().get_signing_key_from_jwt(token).key
        claims = jwt.decode(
            token,
            key,
            algorithms=["RS256"],
            options={"verify_aud": False},
            issuer=settings.KEYCLOAK_ISSUER,
        )
        client_id = settings.KEYCLOAK_CLIENT_ID
        raw_aud = claims.get("aud")
        # `aud` may be a single string or a list; ignore non-string list entries.
        if isinstance(raw_aud, str):
            audiences = [raw_aud]
        elif isinstance(raw_aud, list):
            audiences = [entry for entry in raw_aud if isinstance(entry, str)]
        else:
            audiences = []
        issued_for_us = claims.get("azp") == client_id or client_id in audiences
        if not issued_for_us:
            raise ValueError("token not issued for this client")
        return claims
class KeycloakAdminClient:
def __init__(self) -> None:
self._token: str = ""
self._expires_at: float = 0.0
self._group_id_cache: dict[str, str] = {}
2026-01-20 03:58:34 -03:00
@staticmethod
def _safe_update_payload(full: dict[str, Any]) -> dict[str, Any]:
payload: dict[str, Any] = {}
username = full.get("username")
if isinstance(username, str):
payload["username"] = username
enabled = full.get("enabled")
if isinstance(enabled, bool):
payload["enabled"] = enabled
email = full.get("email")
if isinstance(email, str):
payload["email"] = email
email_verified = full.get("emailVerified")
if isinstance(email_verified, bool):
payload["emailVerified"] = email_verified
first_name = full.get("firstName")
if isinstance(first_name, str):
payload["firstName"] = first_name
last_name = full.get("lastName")
if isinstance(last_name, str):
payload["lastName"] = last_name
actions = full.get("requiredActions")
if isinstance(actions, list):
payload["requiredActions"] = [a for a in actions if isinstance(a, str)]
attrs = full.get("attributes")
payload["attributes"] = attrs if isinstance(attrs, dict) else {}
return payload
def ready(self) -> bool:
return bool(settings.KEYCLOAK_ADMIN_CLIENT_ID and settings.KEYCLOAK_ADMIN_CLIENT_SECRET)
def _get_token(self) -> str:
if not self.ready():
raise RuntimeError("keycloak admin client not configured")
now = time.time()
if self._token and now < self._expires_at - 30:
return self._token
token_url = (
f"{settings.KEYCLOAK_ADMIN_URL}/realms/{settings.KEYCLOAK_ADMIN_REALM}/protocol/openid-connect/token"
)
data = {
"grant_type": "client_credentials",
"client_id": settings.KEYCLOAK_ADMIN_CLIENT_ID,
"client_secret": settings.KEYCLOAK_ADMIN_CLIENT_SECRET,
}
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.post(token_url, data=data)
resp.raise_for_status()
payload = resp.json()
token = payload.get("access_token") or ""
if not token:
raise RuntimeError("no access_token in response")
expires_in = int(payload.get("expires_in") or 60)
self._token = token
self._expires_at = now + expires_in
return token
def _headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self._get_token()}"}
def headers(self) -> dict[str, str]:
return self._headers()
def find_user(self, username: str) -> dict[str, Any] | None:
url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users"
2026-01-02 17:42:03 -03:00
# Keycloak 26.x in our environment intermittently 400s on filtered user queries unless `max` is set.
# Use `max=1` and exact username match to keep admin calls reliable for portal provisioning.
params = {"username": username, "exact": "true", "max": "1"}
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.get(url, params=params, headers=self._headers())
resp.raise_for_status()
users = resp.json()
if not isinstance(users, list) or not users:
return None
user = users[0]
return user if isinstance(user, dict) else None
def find_user_by_email(self, email: str) -> dict[str, Any] | None:
email = (email or "").strip()
if not email:
return None
url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users"
# Match the portal's username query behavior: set a low `max` and post-filter for exact matches.
params = {"email": email, "exact": "true", "max": "2"}
email_norm = email.lower()
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.get(url, params=params, headers=self._headers())
resp.raise_for_status()
users = resp.json()
if not isinstance(users, list) or not users:
return None
for user in users:
if not isinstance(user, dict):
continue
candidate = user.get("email")
if isinstance(candidate, str) and candidate.strip().lower() == email_norm:
return user
return None
def get_user(self, user_id: str) -> dict[str, Any]:
url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users/{quote(user_id, safe='')}"
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.get(url, headers=self._headers())
resp.raise_for_status()
data = resp.json()
if not isinstance(data, dict):
raise RuntimeError("unexpected user payload")
return data
def update_user(self, user_id: str, payload: dict[str, Any]) -> None:
url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users/{quote(user_id, safe='')}"
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.put(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload)
resp.raise_for_status()
2026-01-20 03:58:34 -03:00
def update_user_safe(self, user_id: str, payload: dict[str, Any]) -> None:
full = self.get_user(user_id)
merged = self._safe_update_payload(full)
for key, value in payload.items():
if key == "attributes":
attrs = merged.get("attributes")
if not isinstance(attrs, dict):
attrs = {}
if isinstance(value, dict):
attrs.update(value)
merged["attributes"] = attrs
continue
merged[key] = value
self.update_user(user_id, merged)
def create_user(self, payload: dict[str, Any]) -> str:
url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/users"
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.post(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload)
resp.raise_for_status()
location = resp.headers.get("Location") or ""
if location:
return location.rstrip("/").split("/")[-1]
raise RuntimeError("failed to determine created user id")
def reset_password(self, user_id: str, password: str, temporary: bool = True) -> None:
url = (
f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}"
f"/users/{quote(user_id, safe='')}/reset-password"
)
payload = {"type": "password", "value": password, "temporary": bool(temporary)}
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.put(url, headers={**self._headers(), "Content-Type": "application/json"}, json=payload)
resp.raise_for_status()
def set_user_attribute(self, username: str, key: str, value: str) -> None:
user = self.find_user(username)
if not user:
raise RuntimeError("user not found")
user_id = user.get("id") or ""
if not user_id:
raise RuntimeError("user id missing")
full = self.get_user(user_id)
2026-01-20 03:58:34 -03:00
payload = self._safe_update_payload(full)
attrs = payload.get("attributes")
if not isinstance(attrs, dict):
attrs = {}
attrs[key] = [value]
2026-01-20 03:58:34 -03:00
payload["attributes"] = attrs
# Keep profile fields intact so required actions don't re-trigger unexpectedly.
self.update_user(user_id, payload)
def get_group_id(self, group_name: str) -> str | None:
cached = self._group_id_cache.get(group_name)
if cached:
return cached
url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/groups"
params = {"search": group_name}
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.get(url, params=params, headers=self._headers())
resp.raise_for_status()
items = resp.json()
if not isinstance(items, list):
return None
for item in items:
if not isinstance(item, dict):
continue
if item.get("name") == group_name and item.get("id"):
gid = str(item["id"])
self._group_id_cache[group_name] = gid
return gid
return None
def list_group_names(self) -> list[str]:
url = f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}/groups"
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.get(url, headers=self._headers())
resp.raise_for_status()
items = resp.json()
if not isinstance(items, list):
return []
names: set[str] = set()
def walk(groups: list[Any]) -> None:
for group in groups:
if not isinstance(group, dict):
continue
name = group.get("name")
if isinstance(name, str) and name:
names.add(name)
sub = group.get("subGroups")
if isinstance(sub, list) and sub:
walk(sub)
walk(items)
return sorted(names)
def add_user_to_group(self, user_id: str, group_id: str) -> None:
url = (
f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}"
f"/users/{quote(user_id, safe='')}/groups/{quote(group_id, safe='')}"
)
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.put(url, headers=self._headers())
resp.raise_for_status()
def execute_actions_email(self, user_id: str, actions: list[str], redirect_uri: str) -> None:
url = (
f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}"
f"/users/{quote(user_id, safe='')}/execute-actions-email"
)
params = {"client_id": settings.KEYCLOAK_CLIENT_ID, "redirect_uri": redirect_uri}
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.put(
url,
params=params,
headers={**self._headers(), "Content-Type": "application/json"},
json=actions,
)
resp.raise_for_status()
def get_user_credentials(self, user_id: str) -> list[dict[str, Any]]:
url = (
f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}"
f"/users/{quote(user_id, safe='')}/credentials"
)
with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
resp = client.get(url, headers=self._headers())
resp.raise_for_status()
data = resp.json()
if not isinstance(data, list):
return []
return [item for item in data if isinstance(item, dict)]
# Process-wide singletons; None until first requested via oidc_client() / admin_client().
_OIDC: KeycloakOIDC | None = None
_ADMIN: KeycloakAdminClient | None = None
def oidc_client() -> KeycloakOIDC:
    """Return the module-wide token verifier, creating it on the first call."""
    global _OIDC
    if _OIDC is not None:
        return _OIDC
    _OIDC = KeycloakOIDC()
    return _OIDC
def admin_client() -> KeycloakAdminClient:
    """Return the module-wide admin API client, creating it on the first call."""
    global _ADMIN
    if _ADMIN is not None:
        return _ADMIN
    _ADMIN = KeycloakAdminClient()
    return _ADMIN
def _normalize_groups(groups: Any) -> list[str]:
if not isinstance(groups, list):
return []
cleaned: list[str] = []
for gname in groups:
if not isinstance(gname, str):
continue
cleaned.append(gname.lstrip("/"))
return [gname for gname in cleaned if gname]
def _extract_bearer_token() -> str | None:
    """Return the token from the request's ``Authorization: Bearer <token>`` header.

    The scheme is compared case-insensitively. A missing header, a header
    without a credential part, a non-bearer scheme, or a blank token all
    yield None.
    """
    header = request.headers.get("Authorization", "")
    pieces = header.split(None, 1) if header else []
    if len(pieces) != 2:
        return None
    scheme, credential = pieces
    credential = credential.strip()
    if scheme.lower() == "bearer" and credential:
        return credential
    return None
def require_auth(fn):
    """Decorate a Flask view so it runs only for requests with a valid Keycloak token.

    Responses short-circuit as:
      * 503 ``{"error": "keycloak not enabled"}`` when Keycloak is disabled
        (the same response the other guards in this module return);
      * 401 ``{"error": "missing bearer token"}`` when no bearer token is sent;
      * 401 ``{"error": "invalid token"}`` when verification fails.

    On success the claims are stored on ``flask.g`` as ``keycloak_claims``,
    together with ``keycloak_username`` (``preferred_username``),
    ``keycloak_email`` and normalized ``keycloak_groups``, before *fn* runs.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # Report a disabled integration explicitly instead of letting verify()'s
        # ValueError surface as a misleading "invalid token" 401.
        if not settings.KEYCLOAK_ENABLED:
            return jsonify({"error": "keycloak not enabled"}), 503
        token = _extract_bearer_token()
        if not token:
            return jsonify({"error": "missing bearer token"}), 401
        try:
            claims = oidc_client().verify(token)
        except Exception:
            # Deliberately broad: any verification failure (signature, expiry,
            # issuer, wrong client, or an error while fetching the signing key)
            # becomes a 401 without leaking details to the caller.
            return jsonify({"error": "invalid token"}), 401
        g.keycloak_claims = claims
        g.keycloak_username = claims.get("preferred_username") or ""
        g.keycloak_email = claims.get("email") or ""
        g.keycloak_groups = _normalize_groups(claims.get("groups"))
        return fn(*args, **kwargs)

    return wrapper
def require_portal_admin() -> tuple[bool, Any]:
    """Decide whether the current caller may use portal-admin endpoints.

    Reads ``keycloak_username`` / ``keycloak_groups`` from ``flask.g`` (the
    attributes ``require_auth`` sets). Returns ``(True, None)`` when the
    username is in ``PORTAL_ADMIN_USERS`` or the caller shares a group with
    ``PORTAL_ADMIN_GROUPS``. Otherwise returns ``(False, (response, status))``:
    503 if Keycloak is disabled, else 403.
    """
    if not settings.KEYCLOAK_ENABLED:
        return False, (jsonify({"error": "keycloak not enabled"}), 503)
    username = getattr(g, "keycloak_username", "") or ""
    member_groups = set(getattr(g, "keycloak_groups", []) or [])
    is_listed_admin = bool(username) and username in settings.PORTAL_ADMIN_USERS
    if is_listed_admin:
        return True, None
    admin_groups = settings.PORTAL_ADMIN_GROUPS
    if admin_groups and member_groups.intersection(admin_groups):
        return True, None
    return False, (jsonify({"error": "forbidden"}), 403)
def require_account_access() -> tuple[bool, Any]:
    """Decide whether the current caller may use account endpoints.

    With Keycloak enabled, access is granted when ``ACCOUNT_ALLOWED_GROUPS``
    is empty, when the caller has no groups, or when the caller belongs to
    one of the allowed groups. Returns ``(True, None)`` or
    ``(False, (response, status))`` with 503 (disabled) or 403 (forbidden).
    """
    if not settings.KEYCLOAK_ENABLED:
        return False, (jsonify({"error": "keycloak not enabled"}), 503)
    allowed = settings.ACCOUNT_ALLOWED_GROUPS
    if not allowed:
        return True, None
    member_groups = set(getattr(g, "keycloak_groups", []) or [])
    # NOTE(review): a caller with *no* groups (including tokens that lack a
    # `groups` claim entirely) is let through even though a group restriction
    # is configured, while a caller in some unrelated group is refused.
    # Confirm this fail-open behaviour is intentional.
    if not member_groups or member_groups.intersection(allowed):
        return True, None
    return False, (jsonify({"error": "forbidden"}), 403)