144 lines
4.7 KiB
Python
144 lines
4.7 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
import time
|
|
|
|
import httpx
|
|
import jwt
|
|
|
|
from ..settings import settings
|
|
|
|
|
|
@dataclass(frozen=True)
class AuthContext:
    """Immutable identity record built from a verified bearer token.

    Inputs: claim values taken from a token that has already been validated.
    Outputs: a small read-only object handlers can consult instead of parsing
    the token again.

    The dataclass is frozen, so rebinding any field after construction raises
    ``dataclasses.FrozenInstanceError``. The ``groups`` list and ``claims``
    dict themselves are ordinary mutable containers.
    """

    # Login name; populated from the ``preferred_username`` claim.
    username: str
    # E-mail address; populated from the ``email`` claim.
    email: str
    # Group names with any leading "/" removed.
    groups: list[str]
    # The full verified claim set, kept for callers that need other claims.
    claims: dict[str, Any]
|
|
|
|
|
|
class KeycloakOIDC:
    """Validate Keycloak-issued access tokens for Ariadne API requests.

    Inputs: the JWKS URL, expected issuer, and client identifier.
    Outputs: verified token claims after signature, issuer, and client checks
    so the API can make authorization decisions safely.

    The JWKS document is cached on the instance for ``_jwks_ttl_sec`` seconds
    and is fetched again once when a token names a key id that the cached
    document does not contain.
    """

    def __init__(self, jwks_url: str, issuer: str, client_id: str) -> None:
        """Store the verifier configuration and start with an empty JWKS cache."""
        self._jwks_url = jwks_url
        self._issuer = issuer
        self._client_id = client_id
        self._jwks: dict[str, Any] | None = None
        self._jwks_fetched_at: float = 0.0
        self._jwks_ttl_sec = 300.0

    def _get_kid(self, token: str) -> str:
        """Return the ``kid`` from the token header without verifying the token.

        Raises:
            ValueError: if the header carries no string ``kid``.
        """
        header = jwt.get_unverified_header(token)
        kid = header.get("kid")
        if not isinstance(kid, str):
            raise ValueError("token missing kid")
        return kid

    def _find_key(self, jwks: dict[str, Any], kid: str) -> dict[str, Any] | None:
        """Return the JWK entry whose ``kid`` equals *kid*, or ``None``.

        Malformed input is treated as "no match" rather than an error: a
        non-dict *jwks*, a ``keys`` member that is not a list/tuple (for
        example JSON ``null``), and non-dict entries are all skipped.
        """
        if not isinstance(jwks, dict):
            return None
        keys = jwks.get("keys")
        if not isinstance(keys, (list, tuple)):
            return None
        for candidate in keys:
            if isinstance(candidate, dict) and candidate.get("kid") == kid:
                return candidate
        return None

    def _resolve_key(self, kid: str) -> dict[str, Any]:
        """Look up the signing key for *kid*, refreshing the JWKS once on a miss.

        Raises:
            ValueError: if the key is still absent after the forced refresh.
        """
        key = self._find_key(self._get_jwks(), kid)
        if key:
            return key
        # The kid may belong to a key published after the cached fetch, so
        # bypass the cache once. The cached document is deliberately left in
        # place: ``force=True`` already skips it, and if the refresh raises,
        # tokens signed with already-known keys can still be verified.
        # NOTE(review): every unknown kid triggers one upstream fetch;
        # consider throttling if untrusted callers can reach this path.
        key = self._find_key(self._get_jwks(force=True), kid)
        if not key:
            raise ValueError("token kid not found")
        return key

    def _decode_claims(self, token: str, key: dict[str, Any]) -> dict[str, Any]:
        """Verify the RS256 signature and issuer of *token* and return its claims.

        PyJWT's own audience check is switched off here; ``verify`` applies
        ``_validate_audience`` to the decoded claims instead.
        """
        return jwt.decode(
            token,
            key=jwt.algorithms.RSAAlgorithm.from_jwk(key),
            algorithms=["RS256"],
            options={"verify_aud": False},
            issuer=self._issuer,
        )

    def _validate_audience(self, claims: dict[str, Any]) -> None:
        """Require that the token was issued for the configured client.

        The token is accepted when ``azp`` equals the client id, or when the
        client id appears in ``aud`` (given as a string or a list of strings).

        Raises:
            ValueError: if neither condition holds.
        """
        azp = claims.get("azp")
        aud = claims.get("aud")
        aud_list: list[str] = []
        if isinstance(aud, str):
            aud_list = [aud]
        elif isinstance(aud, list):
            # Ignore non-string members instead of failing on them.
            aud_list = [item for item in aud if isinstance(item, str)]
        if azp != self._client_id and self._client_id not in aud_list:
            raise ValueError("token not issued for expected client")

    def verify(self, token: str) -> dict[str, Any]:
        """Validate *token* and return its claims.

        Steps: read the ``kid`` header, resolve the matching JWK, decode the
        token (signature and issuer), then check the client binding.

        Raises:
            ValueError: for an empty token, a missing or unknown ``kid``, an
                invalid JWKS payload, or a token issued for another client.
            Exceptions raised by ``jwt`` and ``httpx`` propagate unchanged.
        """
        if not token:
            raise ValueError("missing token")
        kid = self._get_kid(token)
        key = self._resolve_key(kid)
        claims = self._decode_claims(token, key)
        self._validate_audience(claims)
        return claims

    def _get_jwks(self, force: bool = False) -> dict[str, Any]:
        """Return the JWKS document, using the cache while it is fresh.

        Args:
            force: when true, skip the cache and always fetch.

        Raises:
            ValueError: if the fetched JSON payload is not an object.
            HTTP and transport errors from ``httpx`` propagate unchanged; the
            cache is only overwritten after a successful fetch.
        """
        now = time.time()
        if not force and self._jwks and now - self._jwks_fetched_at < self._jwks_ttl_sec:
            return self._jwks
        with httpx.Client(timeout=5.0) as client:
            resp = client.get(self._jwks_url)
            resp.raise_for_status()
            payload = resp.json()
        if not isinstance(payload, dict):
            raise ValueError("jwks payload invalid")
        self._jwks = payload
        self._jwks_fetched_at = now
        return payload
|
|
|
|
|
|
class Authenticator:
    """Turn raw bearer tokens into ``AuthContext`` objects for Ariadne handlers.

    Inputs: the bearer token string taken from an incoming API request.
    Outputs: an ``AuthContext`` carrying the username, e-mail, cleaned group
    names, and the full verified claim set.
    """

    def __init__(self) -> None:
        """Build the OIDC verifier from the Keycloak values in ``settings``."""
        self._oidc = KeycloakOIDC(
            settings.keycloak_jwks_url,
            settings.keycloak_issuer,
            settings.keycloak_client_id,
        )

    @staticmethod
    def _normalize_groups(groups: Any) -> list[str]:
        """Return group names with leading slashes removed.

        Anything other than a list yields ``[]``. Non-string entries are
        skipped, and so are names that are empty once the leading "/"
        characters are stripped.
        """
        if not isinstance(groups, list):
            return []
        stripped = (entry.lstrip("/") for entry in groups if isinstance(entry, str))
        return [entry for entry in stripped if entry]

    def authenticate(self, token: str) -> AuthContext:
        """Verify *token* and wrap its claims in an ``AuthContext``.

        Missing or falsy ``preferred_username`` / ``email`` claims become
        empty strings. Errors raised by the verifier propagate to the caller.
        """
        verified = self._oidc.verify(token)
        return AuthContext(
            username=verified.get("preferred_username") or "",
            email=verified.get("email") or "",
            groups=self._normalize_groups(verified.get("groups")),
            claims=verified,
        )
|
|
|
|
|
|
# Module-level instance created at import time; constructing it reads the
# Keycloak JWKS URL, issuer, and client id from `settings`.
authenticator = Authenticator()
|