from __future__ import annotations from dataclasses import dataclass from typing import Any import time import httpx import jwt from ..settings import settings @dataclass(frozen=True) class AuthContext: """Authenticated user details returned by the OIDC verifier. Inputs: normalized claims extracted from a validated bearer token. Outputs: a compact object that downstream handlers can trust without repeating token parsing logic. """ username: str email: str groups: list[str] claims: dict[str, Any] class KeycloakOIDC: """Validate Keycloak-issued access tokens for Ariadne API requests. Inputs: the JWKS URL, expected issuer, and client identifier. Outputs: verified token claims after signature and audience checks so the API can make authorization decisions safely. """ def __init__(self, jwks_url: str, issuer: str, client_id: str) -> None: self._jwks_url = jwks_url self._issuer = issuer self._client_id = client_id self._jwks: dict[str, Any] | None = None self._jwks_fetched_at: float = 0.0 self._jwks_ttl_sec = 300.0 def _get_kid(self, token: str) -> str: header = jwt.get_unverified_header(token) kid = header.get("kid") if not isinstance(kid, str): raise ValueError("token missing kid") return kid def _find_key(self, jwks: dict[str, Any], kid: str) -> dict[str, Any] | None: for candidate in jwks.get("keys", []) if isinstance(jwks, dict) else []: if isinstance(candidate, dict) and candidate.get("kid") == kid: return candidate return None def _resolve_key(self, kid: str) -> dict[str, Any]: jwks = self._get_jwks() key = self._find_key(jwks, kid) if key: return key self._jwks = None jwks = self._get_jwks(force=True) key = self._find_key(jwks, kid) if not key: raise ValueError("token kid not found") return key def _decode_claims(self, token: str, key: dict[str, Any]) -> dict[str, Any]: return jwt.decode( token, key=jwt.algorithms.RSAAlgorithm.from_jwk(key), algorithms=["RS256"], options={"verify_aud": False}, issuer=self._issuer, ) def _validate_audience(self, claims: dict[str, Any]) -> None: azp = claims.get("azp") aud = claims.get("aud") aud_list: list[str] = [] if isinstance(aud, str): aud_list = [aud] elif isinstance(aud, list): aud_list = [item for item in aud if isinstance(item, str)] if azp != self._client_id and self._client_id not in aud_list: raise ValueError("token not issued for expected client") def verify(self, token: str) -> dict[str, Any]: if not token: raise ValueError("missing token") kid = self._get_kid(token) key = self._resolve_key(kid) claims = self._decode_claims(token, key) self._validate_audience(claims) return claims def _get_jwks(self, force: bool = False) -> dict[str, Any]: now = time.time() if not force and self._jwks and now - self._jwks_fetched_at < self._jwks_ttl_sec: return self._jwks with httpx.Client(timeout=5.0) as client: resp = client.get(self._jwks_url) resp.raise_for_status() payload = resp.json() if not isinstance(payload, dict): raise ValueError("jwks payload invalid") self._jwks = payload self._jwks_fetched_at = now return payload class Authenticator: """Convert bearer tokens into normalized Ariadne auth contexts. Inputs: raw bearer tokens from incoming API requests. Outputs: an `AuthContext` with cleaned usernames, emails, and groups so endpoint handlers can stay focused on business logic. """ def __init__(self) -> None: self._oidc = KeycloakOIDC(settings.keycloak_jwks_url, settings.keycloak_issuer, settings.keycloak_client_id) @staticmethod def _normalize_groups(groups: Any) -> list[str]: if not isinstance(groups, list): return [] cleaned: list[str] = [] for name in groups: if not isinstance(name, str): continue cleaned.append(name.lstrip("/")) return [name for name in cleaned if name] def authenticate(self, token: str) -> AuthContext: claims = self._oidc.verify(token) username = claims.get("preferred_username") or "" email = claims.get("email") or "" groups = self._normalize_groups(claims.get("groups")) return AuthContext(username=username, email=email, groups=groups, claims=claims) authenticator = Authenticator()