ariadne/ariadne/auth/keycloak.py

144 lines
4.7 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import time
import httpx
import jwt
from ..settings import settings
@dataclass(frozen=True)
class AuthContext:
    """Identity of the caller behind a successfully verified bearer token.

    Instances are immutable (``frozen=True``), so handlers can pass them
    around without worrying about accidental reassignment of fields.
    ``Authenticator.authenticate`` is the producer in this module.
    """

    # Value of the token's ``preferred_username`` claim ("" when absent).
    username: str
    # Value of the token's ``email`` claim ("" when absent).
    email: str
    # Group names from the ``groups`` claim with leading "/" removed.
    groups: list[str]
    # The full verified claim set, for handlers that need extra claims.
    claims: dict[str, Any]
class KeycloakOIDC:
    """Validate Keycloak-issued access tokens for Ariadne API requests.

    Inputs: the JWKS URL, expected issuer, and client identifier.
    Outputs: verified token claims after signature, issuer, expiry and
    audience checks so the API can make authorization decisions safely.

    The JWKS document is cached for ``jwks_ttl_sec`` seconds. A token whose
    ``kid`` is not in the cached JWKS triggers one forced refresh (to pick up
    key rotation), but only when the cached copy is at least
    ``jwks_refresh_cooldown_sec`` old. This stops a stream of tokens carrying
    bogus ``kid`` values from turning every request into an outbound fetch.
    """

    def __init__(
        self,
        jwks_url: str,
        issuer: str,
        client_id: str,
        *,
        jwks_ttl_sec: float = 300.0,
        jwks_refresh_cooldown_sec: float = 10.0,
    ) -> None:
        """Configure the verifier; no network I/O happens here.

        Args:
            jwks_url: URL of the realm's JWKS document.
            issuer: Expected ``iss`` claim.
            client_id: Client id that must appear in ``azp`` or ``aud``.
            jwks_ttl_sec: Seconds a fetched JWKS is reused before refetching.
            jwks_refresh_cooldown_sec: Minimum age (seconds) of the cached
                JWKS before an unknown ``kid`` may force a refresh. ``0``
                refreshes on every unknown ``kid``.
        """
        self._jwks_url = jwks_url
        self._issuer = issuer
        self._client_id = client_id
        self._jwks: dict[str, Any] | None = None
        # time.monotonic() of the last successful fetch. Wall-clock time can
        # jump and would keep a stale JWKS alive or expire it early.
        self._jwks_fetched_at: float = 0.0
        self._jwks_ttl_sec = jwks_ttl_sec
        self._jwks_refresh_cooldown_sec = jwks_refresh_cooldown_sec

    def _get_kid(self, token: str) -> str:
        """Return the ``kid`` from the token's (unverified) JOSE header.

        Raises:
            ValueError: if the header has no string ``kid``.
        """
        header = jwt.get_unverified_header(token)
        kid = header.get("kid")
        if not isinstance(kid, str):
            raise ValueError("token missing kid")
        return kid

    def _find_key(self, jwks: dict[str, Any], kid: str) -> dict[str, Any] | None:
        """Return the JWK in ``jwks`` whose ``kid`` matches, or ``None``."""
        keys = jwks.get("keys") if isinstance(jwks, dict) else None
        if not isinstance(keys, list):
            # Malformed JWKS (missing or non-list "keys"): treat as no match
            # rather than raising TypeError while iterating.
            return None
        for candidate in keys:
            if isinstance(candidate, dict) and candidate.get("kid") == kid:
                return candidate
        return None

    def _resolve_key(self, kid: str) -> dict[str, Any]:
        """Return the JWK for ``kid``, refreshing the JWKS once if it is unknown.

        Raises:
            ValueError: if ``kid`` is not found, either after a refresh or
                because the cached JWKS is too fresh to refresh yet.
        """
        key = self._find_key(self._get_jwks(), kid)
        if key is not None:
            return key
        # Unknown kid: Keycloak may have rotated its signing key. Do NOT clear
        # the cache first. If the refresh fails, tokens with known kids must
        # keep validating against the cached JWKS.
        cache_age = time.monotonic() - self._jwks_fetched_at
        if cache_age >= self._jwks_refresh_cooldown_sec:
            key = self._find_key(self._get_jwks(force=True), kid)
        if key is None:
            raise ValueError("token kid not found")
        return key

    def _decode_claims(self, token: str, key: dict[str, Any]) -> dict[str, Any]:
        """Verify the RS256 signature, issuer and expiry; return the claims."""
        return jwt.decode(
            token,
            key=jwt.algorithms.RSAAlgorithm.from_jwk(key),
            algorithms=["RS256"],
            options={
                # Audience is checked by _validate_audience, which also
                # accepts a matching ``azp``.
                "verify_aud": False,
                # A signed token without an expiry would otherwise be
                # accepted forever.
                "require": ["exp"],
            },
            issuer=self._issuer,
        )

    def _validate_audience(self, claims: dict[str, Any]) -> None:
        """Require the token to be meant for this client.

        The token is accepted when ``azp`` equals the client id, or when the
        client id appears in ``aud``. ``aud`` may be a string or a list;
        non-string list entries are ignored.

        Raises:
            ValueError: if neither condition holds.
        """
        aud = claims.get("aud")
        if isinstance(aud, str):
            audiences = [aud]
        elif isinstance(aud, list):
            audiences = [item for item in aud if isinstance(item, str)]
        else:
            audiences = []
        if claims.get("azp") != self._client_id and self._client_id not in audiences:
            raise ValueError("token not issued for expected client")

    def verify(self, token: str) -> dict[str, Any]:
        """Fully validate ``token`` and return its claims.

        Raises:
            ValueError: for an empty token, a missing or unknown ``kid``, or a
                token not issued for this client.
            Errors raised by PyJWT (signature, issuer, expiry) and by httpx
            (JWKS fetch) propagate unchanged.
        """
        if not token:
            raise ValueError("missing token")
        kid = self._get_kid(token)
        key = self._resolve_key(kid)
        claims = self._decode_claims(token, key)
        self._validate_audience(claims)
        return claims

    def _get_jwks(self, force: bool = False) -> dict[str, Any]:
        """Return the cached JWKS, fetching it when absent, expired or forced.

        The cache is only replaced after a successful fetch.
        """
        cached = self._jwks
        age = time.monotonic() - self._jwks_fetched_at
        if not force and cached and age < self._jwks_ttl_sec:
            return cached
        payload = self._fetch_jwks()
        self._jwks = payload
        self._jwks_fetched_at = time.monotonic()
        return payload

    def _fetch_jwks(self) -> dict[str, Any]:
        """Download and minimally validate the JWKS document.

        Raises:
            httpx.HTTPError: on transport failure or a non-2xx response.
            ValueError: if the body is not a JSON object.
        """
        with httpx.Client(timeout=5.0) as client:
            resp = client.get(self._jwks_url)
            resp.raise_for_status()
            payload = resp.json()
        if not isinstance(payload, dict):
            raise ValueError("jwks payload invalid")
        return payload
class Authenticator:
    """Convert bearer tokens into normalized Ariadne auth contexts.

    Inputs: raw bearer tokens from incoming API requests.
    Outputs: an ``AuthContext`` with string usernames and emails and cleaned
    group names, so endpoint handlers can stay focused on business logic.
    """

    def __init__(self) -> None:
        # Token verification is delegated to a Keycloak verifier configured
        # from the application settings.
        self._oidc = KeycloakOIDC(
            settings.keycloak_jwks_url,
            settings.keycloak_issuer,
            settings.keycloak_client_id,
        )

    @staticmethod
    def _normalize_groups(groups: Any) -> list[str]:
        """Return group names with leading ``/`` characters removed.

        Non-list input yields ``[]``. Non-string entries, and names that are
        empty after stripping, are dropped. Order is preserved.
        """
        if not isinstance(groups, list):
            return []
        stripped = (name.lstrip("/") for name in groups if isinstance(name, str))
        return [name for name in stripped if name]

    @staticmethod
    def _string_claim(claims: dict[str, Any], name: str) -> str:
        """Return claim ``name`` if it is a string, otherwise ``""``.

        This keeps ``AuthContext.username``/``email`` true to their ``str``
        type even if a mapper emits a non-string value.
        """
        value = claims.get(name)
        return value if isinstance(value, str) else ""

    def authenticate(self, token: str) -> AuthContext:
        """Verify ``token`` and wrap its claims in an ``AuthContext``.

        Raises whatever ``KeycloakOIDC.verify`` raises for invalid tokens.
        """
        claims = self._oidc.verify(token)
        return AuthContext(
            username=self._string_claim(claims, "preferred_username"),
            email=self._string_claim(claims, "email"),
            groups=self._normalize_groups(claims.get("groups")),
            claims=claims,
        )
# Module-level singleton. Constructing it reads the Keycloak settings once,
# when this module is first imported.
authenticator: Authenticator = Authenticator()