377 lines
14 KiB
Python
377 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import time
|
|
from functools import wraps
|
|
from typing import Any
|
|
from urllib.parse import quote
|
|
|
|
from flask import g, jsonify, request
|
|
import httpx
|
|
import jwt
|
|
from jwt import PyJWKClient
|
|
|
|
from . import settings
|
|
|
|
|
|
class KeycloakOIDC:
    """Verify RS256 bearer tokens issued by the configured Keycloak realm.

    Signing keys are looked up from ``settings.KEYCLOAK_JWKS_URL`` via a lazily
    created ``PyJWKClient``.
    """

    def __init__(self) -> None:
        # Created on first verification rather than at construction time.
        self._jwk_client: PyJWKClient | None = None

    def _client(self) -> PyJWKClient:
        """Return the cached JWKS client, creating it on first use."""
        if self._jwk_client is None:
            self._jwk_client = PyJWKClient(settings.KEYCLOAK_JWKS_URL)
        return self._jwk_client

    def verify(self, token: str) -> dict[str, Any]:
        """Validate *token* and return its claims.

        The token must be RS256-signed by a key from the realm JWKS, carry the
        configured issuer, include an ``exp`` claim (and not be expired), and name
        ``settings.KEYCLOAK_CLIENT_ID`` either as ``azp`` or within ``aud``.

        Raises:
            ValueError: if Keycloak is disabled or the token was not issued for
                this client.
            Exceptions from PyJWT / ``PyJWKClient`` propagate unchanged when the
            signature, issuer, expiry or key lookup fails.
        """
        if not settings.KEYCLOAK_ENABLED:
            raise ValueError("keycloak not enabled")

        signing_key = self._client().get_signing_key_from_jwt(token).key
        claims = jwt.decode(
            token,
            signing_key,
            algorithms=["RS256"],
            # Audience is checked manually below so the token is accepted when this
            # client appears either as `azp` or in `aud`. `exp` is required because
            # PyJWT only validates expiry when the claim is present; without this a
            # signed token lacking `exp` would never expire.
            options={"verify_aud": False, "require": ["exp"]},
            issuer=settings.KEYCLOAK_ISSUER,
        )

        # `aud` may be a single string or a list; keep only string entries.
        aud = claims.get("aud")
        if isinstance(aud, str):
            audiences = [aud]
        elif isinstance(aud, list):
            audiences = [a for a in aud if isinstance(a, str)]
        else:
            audiences = []

        client_id = settings.KEYCLOAK_CLIENT_ID
        if claims.get("azp") != client_id and client_id not in audiences:
            raise ValueError("token not issued for this client")

        return claims
|
|
|
|
|
|
class KeycloakAdminClient:
    """Client for the Keycloak Admin REST API, authenticated with client credentials.

    Admin requests target ``settings.KEYCLOAK_REALM``. The access token is obtained
    from ``settings.KEYCLOAK_ADMIN_REALM`` and cached on the instance until shortly
    before it expires. Each request opens a short-lived ``httpx.Client`` with
    ``settings.HTTP_CHECK_TIMEOUT_SEC`` and raises ``httpx.HTTPStatusError`` on a
    non-2xx response.
    """

    # Upper bound, in seconds, on how long before expiry a cached token is refreshed.
    _MAX_REFRESH_MARGIN_SEC = 30.0

    # (field, required type) pairs copied by _safe_update_payload, in payload order.
    _SAFE_SCALAR_FIELDS = (
        ("username", str),
        ("enabled", bool),
        ("email", str),
        ("emailVerified", bool),
        ("firstName", str),
        ("lastName", str),
    )

    def __init__(self) -> None:
        self._token: str = ""
        # Absolute expiry (epoch seconds) of the cached token.
        self._expires_at: float = 0.0
        # Seconds before _expires_at at which the cached token stops being reused.
        self._refresh_margin: float = self._MAX_REFRESH_MARGIN_SEC
        # Group name -> group id; only successful lookups are cached.
        self._group_id_cache: dict[str, str] = {}

    @staticmethod
    def _safe_update_payload(full: dict[str, Any]) -> dict[str, Any]:
        """Build a PUT-safe user representation from a user fetched via GET.

        Only whitelisted profile fields are copied, and only when they have the
        expected type. ``requiredActions`` keeps its string entries, and
        ``attributes`` is always present (``{}`` when missing or not a dict).
        """
        payload: dict[str, Any] = {}
        for field, expected in KeycloakAdminClient._SAFE_SCALAR_FIELDS:
            value = full.get(field)
            if isinstance(value, expected):
                payload[field] = value

        actions = full.get("requiredActions")
        if isinstance(actions, list):
            payload["requiredActions"] = [a for a in actions if isinstance(a, str)]

        attrs = full.get("attributes")
        payload["attributes"] = attrs if isinstance(attrs, dict) else {}
        return payload

    @classmethod
    def _refresh_margin_for(cls, expires_in: float) -> float:
        """Return how many seconds before expiry a token of lifetime *expires_in* is refreshed.

        The margin is capped at ``_MAX_REFRESH_MARGIN_SEC`` but never exceeds half the
        lifetime. Short-lived tokens are therefore still reused, instead of being
        refetched on every call.
        """
        return min(cls._MAX_REFRESH_MARGIN_SEC, max(expires_in, 0) / 2)

    def ready(self) -> bool:
        """Return True when the admin client id and secret are both configured."""
        return bool(settings.KEYCLOAK_ADMIN_CLIENT_ID and settings.KEYCLOAK_ADMIN_CLIENT_SECRET)

    def _get_token(self) -> str:
        """Return an admin access token, fetching a new one when the cache is stale.

        Raises:
            RuntimeError: if the admin client is not configured, or the token endpoint
                returns a non-object body or no ``access_token``.
            httpx.HTTPStatusError: if the token endpoint responds with a non-2xx status.
        """
        if not self.ready():
            raise RuntimeError("keycloak admin client not configured")

        now = time.time()
        if self._token and now < self._expires_at - self._refresh_margin:
            return self._token

        token_url = (
            f"{settings.KEYCLOAK_ADMIN_URL}/realms/{settings.KEYCLOAK_ADMIN_REALM}"
            "/protocol/openid-connect/token"
        )
        data = {
            "grant_type": "client_credentials",
            "client_id": settings.KEYCLOAK_ADMIN_CLIENT_ID,
            "client_secret": settings.KEYCLOAK_ADMIN_CLIENT_SECRET,
        }
        with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
            resp = client.post(token_url, data=data)
        resp.raise_for_status()

        payload = resp.json()
        if not isinstance(payload, dict):
            raise RuntimeError("unexpected token response")
        token = payload.get("access_token") or ""
        if not token:
            raise RuntimeError("no access_token in response")
        try:
            expires_in = int(payload.get("expires_in") or 60)
        except (TypeError, ValueError):
            # Fall back to the same default used when the field is absent.
            expires_in = 60

        self._token = token
        # `now` was taken before the request, so the stored expiry errs on the early side.
        self._expires_at = now + expires_in
        self._refresh_margin = self._refresh_margin_for(expires_in)
        return token

    def _headers(self) -> dict[str, str]:
        """Return a fresh Authorization header dict carrying the admin bearer token."""
        return {"Authorization": f"Bearer {self._get_token()}"}

    def headers(self) -> dict[str, str]:
        """Public alias of ``_headers()``."""
        return self._headers()

    def _realm_url(self, path: str = "") -> str:
        """Return the admin REST URL of the configured realm with *path* appended."""
        return f"{settings.KEYCLOAK_ADMIN_URL}/admin/realms/{settings.KEYCLOAK_REALM}{path}"

    def _user_url(self, user_id: str, suffix: str = "") -> str:
        """Return the admin URL of one user (id percent-encoded) with *suffix* appended."""
        return self._realm_url(f"/users/{quote(user_id, safe='')}{suffix}")

    def _request(
        self,
        method: str,
        url: str,
        *,
        params: dict[str, str] | None = None,
        json_body: Any = None,
    ) -> httpx.Response:
        """Send an authenticated admin request and raise on a non-2xx status.

        When *json_body* is not None, it is sent as JSON with an explicit
        ``Content-Type: application/json`` header.
        """
        headers = self._headers()
        kwargs: dict[str, Any] = {"headers": headers}
        if params is not None:
            kwargs["params"] = params
        if json_body is not None:
            headers["Content-Type"] = "application/json"
            kwargs["json"] = json_body
        with httpx.Client(timeout=settings.HTTP_CHECK_TIMEOUT_SEC) as client:
            resp = client.request(method, url, **kwargs)
        resp.raise_for_status()
        return resp

    def find_user(self, username: str) -> dict[str, Any] | None:
        """Return the user whose username exactly matches *username*, or None."""
        # Keycloak 26.x in our environment intermittently 400s on filtered user queries unless `max` is set.
        # Use `max=1` and exact username match to keep admin calls reliable for portal provisioning.
        params = {"username": username, "exact": "true", "max": "1"}
        users = self._request("GET", self._realm_url("/users"), params=params).json()
        if not isinstance(users, list) or not users:
            return None
        user = users[0]
        return user if isinstance(user, dict) else None

    def find_user_by_email(self, email: str) -> dict[str, Any] | None:
        """Return the user whose email matches *email* (case-insensitive), or None.

        A blank *email* returns None without contacting Keycloak.
        """
        email = (email or "").strip()
        if not email:
            return None

        # Match the portal's username query behavior: set a low `max` and post-filter for exact matches.
        params = {"email": email, "exact": "true", "max": "2"}
        email_norm = email.lower()
        users = self._request("GET", self._realm_url("/users"), params=params).json()
        if not isinstance(users, list) or not users:
            return None
        for user in users:
            if not isinstance(user, dict):
                continue
            candidate = user.get("email")
            if isinstance(candidate, str) and candidate.strip().lower() == email_norm:
                return user
        return None

    def get_user(self, user_id: str) -> dict[str, Any]:
        """Fetch the full representation of *user_id*.

        Raises:
            RuntimeError: if the response body is not a JSON object.
        """
        data = self._request("GET", self._user_url(user_id)).json()
        if not isinstance(data, dict):
            raise RuntimeError("unexpected user payload")
        return data

    def update_user(self, user_id: str, payload: dict[str, Any]) -> None:
        """PUT *payload* as the representation of *user_id*."""
        self._request("PUT", self._user_url(user_id), json_body=payload)

    def update_user_safe(self, user_id: str, payload: dict[str, Any]) -> None:
        """Apply *payload* on top of the user's current safe fields and PUT the result.

        An ``attributes`` dict in *payload* is merged key-by-key into the existing
        attributes. Every other key overwrites the existing value.
        """
        full = self.get_user(user_id)
        merged = self._safe_update_payload(full)
        for key, value in payload.items():
            if key == "attributes":
                attrs = merged.get("attributes")
                if not isinstance(attrs, dict):
                    attrs = {}
                if isinstance(value, dict):
                    attrs.update(value)
                merged["attributes"] = attrs
                continue
            merged[key] = value
        self.update_user(user_id, merged)

    def create_user(self, payload: dict[str, Any]) -> str:
        """Create a user and return its id, taken from the ``Location`` response header.

        Raises:
            RuntimeError: if the response carries no ``Location`` header.
        """
        resp = self._request("POST", self._realm_url("/users"), json_body=payload)
        location = resp.headers.get("Location") or ""
        if location:
            return location.rstrip("/").split("/")[-1]
        raise RuntimeError("failed to determine created user id")

    def reset_password(self, user_id: str, password: str, temporary: bool = True) -> None:
        """Set the password of *user_id*; *temporary* controls the ``temporary`` flag sent."""
        payload = {"type": "password", "value": password, "temporary": bool(temporary)}
        self._request("PUT", self._user_url(user_id, "/reset-password"), json_body=payload)

    def set_user_attribute(self, username: str, key: str, value: str) -> None:
        """Set attribute *key* to ``[value]`` on the user named *username*.

        Raises:
            RuntimeError: if the user cannot be found or has no id.
        """
        user = self.find_user(username)
        if not user:
            raise RuntimeError("user not found")
        user_id = user.get("id") or ""
        if not user_id:
            raise RuntimeError("user id missing")

        # Keep profile fields intact so required actions don't re-trigger unexpectedly.
        self.update_user_safe(user_id, {"attributes": {key: [value]}})

    def get_group_id(self, group_name: str) -> str | None:
        """Return the id of the group named exactly *group_name*, or None.

        Only groups at the top level of the search response are compared by name.
        Successful lookups are cached per instance.
        """
        cached = self._group_id_cache.get(group_name)
        if cached:
            return cached

        items = self._request("GET", self._realm_url("/groups"), params={"search": group_name}).json()
        if not isinstance(items, list):
            return None
        for item in items:
            if not isinstance(item, dict):
                continue
            if item.get("name") == group_name and item.get("id"):
                gid = str(item["id"])
                self._group_id_cache[group_name] = gid
                return gid
        return None

    def add_user_to_group(self, user_id: str, group_id: str) -> None:
        """Add *user_id* to *group_id* (PUT with no body)."""
        self._request("PUT", self._user_url(user_id, f"/groups/{quote(group_id, safe='')}"))

    def execute_actions_email(self, user_id: str, actions: list[str], redirect_uri: str) -> None:
        """Ask Keycloak to email *user_id* a link for performing *actions*."""
        params = {"client_id": settings.KEYCLOAK_CLIENT_ID, "redirect_uri": redirect_uri}
        self._request(
            "PUT",
            self._user_url(user_id, "/execute-actions-email"),
            params=params,
            json_body=actions,
        )

    def get_user_credentials(self, user_id: str) -> list[dict[str, Any]]:
        """Return the credential entries of *user_id*; non-dict items are dropped."""
        data = self._request("GET", self._user_url(user_id, "/credentials")).json()
        if not isinstance(data, list):
            return []
        return [item for item in data if isinstance(item, dict)]
|
|
|
|
|
|
# Module-level singletons; stay None until the oidc_client() / admin_client()
# accessors first build them.
_OIDC: KeycloakOIDC | None = None
_ADMIN: KeycloakAdminClient | None = None
|
|
|
|
|
|
def oidc_client() -> KeycloakOIDC:
    """Return the module-wide KeycloakOIDC instance, building it on the first call."""
    global _OIDC
    if _OIDC is not None:
        return _OIDC
    _OIDC = KeycloakOIDC()
    return _OIDC
|
|
|
|
|
|
def admin_client() -> KeycloakAdminClient:
    """Return the module-wide KeycloakAdminClient instance, building it on the first call."""
    global _ADMIN
    if _ADMIN is not None:
        return _ADMIN
    _ADMIN = KeycloakAdminClient()
    return _ADMIN
|
|
|
|
|
|
def _normalize_groups(groups: Any) -> list[str]:
|
|
if not isinstance(groups, list):
|
|
return []
|
|
cleaned: list[str] = []
|
|
for gname in groups:
|
|
if not isinstance(gname, str):
|
|
continue
|
|
cleaned.append(gname.lstrip("/"))
|
|
return [gname for gname in cleaned if gname]
|
|
|
|
|
|
def _extract_bearer_token() -> str | None:
    """Return the token from the request's ``Authorization: Bearer <token>`` header.

    The scheme is matched case-insensitively, and whitespace around the token is
    ignored. Returns None when the header is absent, malformed, uses another
    scheme, or carries an empty token.
    """
    raw = request.headers.get("Authorization", "")
    pieces = raw.split(None, 1) if raw else []
    if len(pieces) != 2:
        return None
    scheme, credential = pieces[0].lower(), pieces[1].strip()
    return credential if scheme == "bearer" and credential else None
|
|
|
|
|
|
def require_auth(fn):
    """Decorate a Flask view so that it requires a verified Keycloak bearer token.

    Responds with a 401 JSON error when no bearer token is present, or when
    verification raises. On success, stores the claims and the derived username,
    email and normalized group list on ``flask.g``, then calls *fn*.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        bearer = _extract_bearer_token()
        if not bearer:
            return jsonify({"error": "missing bearer token"}), 401

        try:
            verified = oidc_client().verify(bearer)
        except Exception:
            # Every verification failure, including the ValueError raised when
            # Keycloak is disabled, is reported uniformly as 401.
            return jsonify({"error": "invalid token"}), 401

        g.keycloak_claims = verified
        g.keycloak_username = verified.get("preferred_username") or ""
        g.keycloak_email = verified.get("email") or ""
        g.keycloak_groups = _normalize_groups(verified.get("groups"))
        return fn(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
def require_portal_admin() -> tuple[bool, Any]:
    """Check whether the current user is a portal administrator.

    A user qualifies when their username is listed in ``settings.PORTAL_ADMIN_USERS``,
    or when they share a group with a non-empty ``settings.PORTAL_ADMIN_GROUPS``.
    Returns ``(True, None)`` when allowed. Otherwise returns
    ``(False, (response, status))``: 503 if Keycloak is disabled, 403 if not an admin.
    Identity is read from ``g.keycloak_username`` / ``g.keycloak_groups``, with missing
    values treated as empty.
    """
    if not settings.KEYCLOAK_ENABLED:
        return False, (jsonify({"error": "keycloak not enabled"}), 503)

    user = getattr(g, "keycloak_username", "") or ""
    member_of = set(getattr(g, "keycloak_groups", []) or [])

    is_listed_user = bool(user) and user in settings.PORTAL_ADMIN_USERS
    if is_listed_user:
        return True, None

    admin_groups = settings.PORTAL_ADMIN_GROUPS
    if admin_groups and member_of.intersection(admin_groups):
        return True, None

    return False, (jsonify({"error": "forbidden"}), 403)
|
|
|
|
|
|
def require_account_access() -> tuple[bool, Any]:
    """Check whether the current user may use account endpoints.

    An empty ``settings.ACCOUNT_ALLOWED_GROUPS`` admits everyone. Otherwise the user
    must belong to at least one listed group, as read from ``g.keycloak_groups``
    (missing treated as no groups). Returns ``(True, None)`` when allowed, else
    ``(False, (response, status))``: 503 if Keycloak is disabled, 403 if denied.
    """
    if not settings.KEYCLOAK_ENABLED:
        return False, (jsonify({"error": "keycloak not enabled"}), 503)

    allowed = settings.ACCOUNT_ALLOWED_GROUPS
    if not allowed:
        return True, None

    member_of = set(getattr(g, "keycloak_groups", []) or [])
    if not member_of.intersection(allowed):
        return False, (jsonify({"error": "forbidden"}), 403)
    return True, None
|