ariadne/ariadne/auth/keycloak.py

133 lines
4.4 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import Any
import time
import httpx
import jwt
from ..settings import settings
@dataclass(frozen=True)
class AuthContext:
    """Identity and group membership derived from a verified bearer token.

    Attribute assignment is blocked (``frozen=True``), but ``groups`` and
    ``claims`` are ordinary list/dict objects that are stored as given, not
    copied. Because those fields are unhashable containers, calling ``hash()``
    on an instance raises TypeError.
    """

    # Login name; Authenticator.authenticate fills it from ``preferred_username``
    # (or "" when that claim is missing/falsy).
    username: str
    # E-mail address; Authenticator.authenticate fills it from ``email`` (or "").
    email: str
    # Group names with leading "/" removed, as produced by Authenticator.
    groups: list[str]
    # The full verified claim set the context was built from.
    claims: dict[str, Any]
class KeycloakOIDC:
"""Validate Keycloak-issued OIDC tokens and return trusted claims."""
def __init__(self, jwks_url: str, issuer: str, client_id: str) -> None:
self._jwks_url = jwks_url
self._issuer = issuer
self._client_id = client_id
self._jwks: dict[str, Any] | None = None
self._jwks_fetched_at: float = 0.0
self._jwks_ttl_sec = 300.0
def _get_kid(self, token: str) -> str:
header = jwt.get_unverified_header(token)
kid = header.get("kid")
if not isinstance(kid, str):
raise ValueError("token missing kid")
return kid
def _find_key(self, jwks: dict[str, Any], kid: str) -> dict[str, Any] | None:
for candidate in jwks.get("keys", []) if isinstance(jwks, dict) else []:
if isinstance(candidate, dict) and candidate.get("kid") == kid:
return candidate
return None
def _resolve_key(self, kid: str) -> dict[str, Any]:
jwks = self._get_jwks()
key = self._find_key(jwks, kid)
if key:
return key
self._jwks = None
jwks = self._get_jwks(force=True)
key = self._find_key(jwks, kid)
if not key:
raise ValueError("token kid not found")
return key
def _decode_claims(self, token: str, key: dict[str, Any]) -> dict[str, Any]:
return jwt.decode(
token,
key=self._key_from_jwk(key),
algorithms=["RS256"],
options={"verify_aud": False},
issuer=self._issuer,
)
def _key_from_jwk(self, key: dict[str, Any]) -> Any:
algorithm = getattr(jwt.algorithms, "RSAAlgorithm", None)
if algorithm and hasattr(algorithm, "from_jwk"):
return algorithm.from_jwk(key)
return jwt.PyJWK.from_dict(key).key
def _validate_audience(self, claims: dict[str, Any]) -> None:
azp = claims.get("azp")
aud = claims.get("aud")
aud_list: list[str] = []
if isinstance(aud, str):
aud_list = [aud]
elif isinstance(aud, list):
aud_list = [item for item in aud if isinstance(item, str)]
if azp != self._client_id and self._client_id not in aud_list:
raise ValueError("token not issued for expected client")
def verify(self, token: str) -> dict[str, Any]:
if not token:
raise ValueError("missing token")
kid = self._get_kid(token)
key = self._resolve_key(kid)
claims = self._decode_claims(token, key)
self._validate_audience(claims)
return claims
def _get_jwks(self, force: bool = False) -> dict[str, Any]:
now = time.time()
if not force and self._jwks and now - self._jwks_fetched_at < self._jwks_ttl_sec:
return self._jwks
with httpx.Client(timeout=5.0) as client:
resp = client.get(self._jwks_url)
resp.raise_for_status()
payload = resp.json()
if not isinstance(payload, dict):
raise ValueError("jwks payload invalid")
self._jwks = payload
self._jwks_fetched_at = now
return payload
class Authenticator:
"""Translate bearer tokens into Ariadne authorization context."""
def __init__(self) -> None:
self._oidc = KeycloakOIDC(settings.keycloak_jwks_url, settings.keycloak_issuer, settings.keycloak_client_id)
@staticmethod
def _normalize_groups(groups: Any) -> list[str]:
if not isinstance(groups, list):
return []
cleaned: list[str] = []
for name in groups:
if not isinstance(name, str):
continue
cleaned.append(name.lstrip("/"))
return [name for name in cleaned if name]
def authenticate(self, token: str) -> AuthContext:
claims = self._oidc.verify(token)
username = claims.get("preferred_username") or ""
email = claims.get("email") or ""
groups = self._normalize_groups(claims.get("groups"))
return AuthContext(username=username, email=email, groups=groups, claims=claims)
# Module-level Authenticator built once when this module is imported.
# Construction reads the keycloak_* values from settings, so settings must be
# loadable at import time. NOTE(review): presumably shared by request handlers
# that import this name; confirm against callers.
authenticator = Authenticator()