#!/usr/bin/env python3 import base64 import json import os import sys import time import urllib.parse import urllib.error import urllib.request def _require_env(name: str) -> str: value = os.environ.get(name) if not value: raise SystemExit(f"missing required env var: {name}") return value def _post_form(url: str, data: dict[str, str], token: str | None = None, timeout_s: int = 30) -> dict: body = urllib.parse.urlencode(data).encode() headers = {"Content-Type": "application/x-www-form-urlencoded"} if token: headers["Authorization"] = f"Bearer {token}" req = urllib.request.Request(url, data=body, headers=headers, method="POST") try: with urllib.request.urlopen(req, timeout=timeout_s) as resp: payload = resp.read().decode() return json.loads(payload) if payload else {} except urllib.error.HTTPError as exc: raw = exc.read().decode(errors="replace") raise SystemExit(f"HTTP {exc.code} from {url}: {raw}") def _get_json(url: str, token: str, timeout_s: int = 30) -> object: req = urllib.request.Request(url, headers={"Authorization": f"Bearer {token}"}, method="GET") try: with urllib.request.urlopen(req, timeout=timeout_s) as resp: payload = resp.read().decode() return json.loads(payload) if payload else None except urllib.error.HTTPError as exc: raw = exc.read().decode(errors="replace") raise SystemExit(f"HTTP {exc.code} from {url}: {raw}") def _decode_jwt_without_verification(jwt: str) -> dict: parts = jwt.split(".") if len(parts) < 2: return {} padded = parts[1] + "=" * (-len(parts[1]) % 4) try: return json.loads(base64.urlsafe_b64decode(padded.encode()).decode()) except Exception: return {} def _is_retryable_failure(message: str) -> bool: retryable_markers = ( "HTTP 401 ", "HTTP 403 ", "HTTP 404 ", "HTTP 409 ", "HTTP 429 ", "HTTP 500 ", "HTTP 502 ", "HTTP 503 ", "HTTP 504 ", "timed out", "Temporary failure", "Connection refused", ) return any(marker in message for marker in retryable_markers) def main() -> int: keycloak_base = _require_env("KEYCLOAK_SERVER").rstrip("/") realm = os.environ.get("KEYCLOAK_REALM", "atlas") client_id = _require_env("PORTAL_E2E_CLIENT_ID") client_secret = _require_env("PORTAL_E2E_CLIENT_SECRET") target_client_id = os.environ.get("TARGET_CLIENT_ID", "bstein-dev-home") impersonate_username = os.environ.get("IMPERSONATE_USERNAME", "robotuser") token_url = f"{keycloak_base}/realms/{realm}/protocol/openid-connect/token" admin_users_url = f"{keycloak_base}/admin/realms/{realm}/users" def run_once() -> None: token_payload = _post_form( token_url, {"grant_type": "client_credentials", "client_id": client_id, "client_secret": client_secret}, ) access_token = token_payload.get("access_token") if not isinstance(access_token, str) or not access_token: raise SystemExit("client credentials token missing access_token") users = _get_json( f"{admin_users_url}?{urllib.parse.urlencode({'username': impersonate_username, 'exact': 'true'})}", access_token, ) if not isinstance(users, list) or not users: raise SystemExit(f"unable to locate user {impersonate_username!r} via admin API") user_id = users[0].get("id") if not isinstance(user_id, str) or not user_id: raise SystemExit(f"user {impersonate_username!r} missing id") exchange_payload = _post_form( token_url, { "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", "client_id": client_id, "client_secret": client_secret, "subject_token": access_token, "requested_subject": user_id, "audience": target_client_id, }, ) exchanged = exchange_payload.get("access_token") if not isinstance(exchanged, str) or not exchanged: raise SystemExit("token exchange response missing access_token") claims = _decode_jwt_without_verification(exchanged) aud = claims.get("aud") if aud is None: raise SystemExit("token exchange access_token missing aud claim") if isinstance(aud, str): aud_ok = aud == target_client_id elif isinstance(aud, list): aud_ok = target_client_id in aud else: aud_ok = False if not aud_ok: raise SystemExit(f"token exchange aud mismatch (expected {target_client_id!r})") deadline_seconds = int(os.environ.get("RETRY_DEADLINE_SECONDS", "300")) retry_interval_seconds = int(os.environ.get("RETRY_INTERVAL_SECONDS", "5")) deadline_at = time.monotonic() + deadline_seconds last_error: str | None = None while True: try: run_once() print("PASS: token exchange works") return 0 except SystemExit as exc: message = str(exc) last_error = message or last_error if time.monotonic() >= deadline_at: raise if not _is_retryable_failure(message): raise time.sleep(retry_interval_seconds) except Exception as exc: last_error = str(exc) or last_error if time.monotonic() >= deadline_at: raise SystemExit(str(exc)) time.sleep(retry_interval_seconds) if __name__ == "__main__": sys.exit(main())