158 lines
5.6 KiB
Python
158 lines
5.6 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
import base64
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
import sys
|
||
|
|
import time
|
||
|
|
import urllib.parse
|
||
|
|
import urllib.error
|
||
|
|
import urllib.request
|
||
|
|
|
||
|
|
|
||
|
|
def _require_env(name: str) -> str:
|
||
|
|
value = os.environ.get(name)
|
||
|
|
if not value:
|
||
|
|
raise SystemExit(f"missing required env var: {name}")
|
||
|
|
return value
|
||
|
|
|
||
|
|
|
||
|
|
def _post_form(url: str, data: dict[str, str], token: str | None = None, timeout_s: int = 30) -> dict:
|
||
|
|
body = urllib.parse.urlencode(data).encode()
|
||
|
|
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||
|
|
if token:
|
||
|
|
headers["Authorization"] = f"Bearer {token}"
|
||
|
|
req = urllib.request.Request(url, data=body, headers=headers, method="POST")
|
||
|
|
try:
|
||
|
|
with urllib.request.urlopen(req, timeout=timeout_s) as resp:
|
||
|
|
payload = resp.read().decode()
|
||
|
|
return json.loads(payload) if payload else {}
|
||
|
|
except urllib.error.HTTPError as exc:
|
||
|
|
raw = exc.read().decode(errors="replace")
|
||
|
|
raise SystemExit(f"HTTP {exc.code} from {url}: {raw}")
|
||
|
|
|
||
|
|
|
||
|
|
def _get_json(url: str, token: str, timeout_s: int = 30) -> object:
|
||
|
|
req = urllib.request.Request(url, headers={"Authorization": f"Bearer {token}"}, method="GET")
|
||
|
|
try:
|
||
|
|
with urllib.request.urlopen(req, timeout=timeout_s) as resp:
|
||
|
|
payload = resp.read().decode()
|
||
|
|
return json.loads(payload) if payload else None
|
||
|
|
except urllib.error.HTTPError as exc:
|
||
|
|
raw = exc.read().decode(errors="replace")
|
||
|
|
raise SystemExit(f"HTTP {exc.code} from {url}: {raw}")
|
||
|
|
|
||
|
|
|
||
|
|
def _decode_jwt_without_verification(jwt: str) -> dict:
|
||
|
|
parts = jwt.split(".")
|
||
|
|
if len(parts) < 2:
|
||
|
|
return {}
|
||
|
|
padded = parts[1] + "=" * (-len(parts[1]) % 4)
|
||
|
|
try:
|
||
|
|
return json.loads(base64.urlsafe_b64decode(padded.encode()).decode())
|
||
|
|
except Exception:
|
||
|
|
return {}
|
||
|
|
|
||
|
|
def _is_retryable_failure(message: str) -> bool:
|
||
|
|
retryable_markers = (
|
||
|
|
"HTTP 401 ",
|
||
|
|
"HTTP 403 ",
|
||
|
|
"HTTP 404 ",
|
||
|
|
"HTTP 409 ",
|
||
|
|
"HTTP 429 ",
|
||
|
|
"HTTP 500 ",
|
||
|
|
"HTTP 502 ",
|
||
|
|
"HTTP 503 ",
|
||
|
|
"HTTP 504 ",
|
||
|
|
"timed out",
|
||
|
|
"Temporary failure",
|
||
|
|
"Connection refused",
|
||
|
|
)
|
||
|
|
return any(marker in message for marker in retryable_markers)
|
||
|
|
|
||
|
|
|
||
|
|
def main() -> int:
    """Smoke-test Keycloak token exchange for a target client, retrying until a deadline.

    One attempt (``run_once``):
      1. Obtain an access token for ``PORTAL_E2E_CLIENT_ID`` via the
         ``client_credentials`` grant.
      2. Look up ``IMPERSONATE_USERNAME`` (exact match) through the admin users API.
      3. Call the token endpoint with the token-exchange grant, passing that
         user's id as ``requested_subject`` and ``TARGET_CLIENT_ID`` as ``audience``.
      4. Check that the exchanged access token's ``aud`` claim names the target client.

    Environment:
        KEYCLOAK_SERVER (required), KEYCLOAK_REALM (default "atlas"),
        PORTAL_E2E_CLIENT_ID and PORTAL_E2E_CLIENT_SECRET (required),
        TARGET_CLIENT_ID (default "bstein-dev-home"),
        IMPERSONATE_USERNAME (default "robotuser"),
        RETRY_DEADLINE_SECONDS (default 300), RETRY_INTERVAL_SECONDS (default 5).

    Returns:
        0 once an attempt succeeds (after printing a PASS line).

    Raises:
        SystemExit: If a required env var is missing, on a failure that
            ``_is_retryable_failure`` rejects, or with the last failure's
            message once the retry deadline has passed.
    """
    keycloak_base = _require_env("KEYCLOAK_SERVER").rstrip("/")
    realm = os.environ.get("KEYCLOAK_REALM", "atlas")
    client_id = _require_env("PORTAL_E2E_CLIENT_ID")
    client_secret = _require_env("PORTAL_E2E_CLIENT_SECRET")
    target_client_id = os.environ.get("TARGET_CLIENT_ID", "bstein-dev-home")
    impersonate_username = os.environ.get("IMPERSONATE_USERNAME", "robotuser")

    token_url = f"{keycloak_base}/realms/{realm}/protocol/openid-connect/token"
    admin_users_url = f"{keycloak_base}/admin/realms/{realm}/users"

    def run_once() -> None:
        """Run one full attempt; raise SystemExit (or another exception) on failure."""
        token_payload = _post_form(
            token_url,
            {"grant_type": "client_credentials", "client_id": client_id, "client_secret": client_secret},
        )
        access_token = token_payload.get("access_token")
        if not isinstance(access_token, str) or not access_token:
            raise SystemExit("client credentials token missing access_token")

        query = urllib.parse.urlencode({"username": impersonate_username, "exact": "true"})
        users = _get_json(f"{admin_users_url}?{query}", access_token)
        if not isinstance(users, list) or not users:
            raise SystemExit(f"unable to locate user {impersonate_username!r} via admin API")
        first_user = users[0]
        # Report a malformed admin-API entry clearly instead of via an AttributeError.
        user_id = first_user.get("id") if isinstance(first_user, dict) else None
        if not isinstance(user_id, str) or not user_id:
            raise SystemExit(f"user {impersonate_username!r} missing id")

        exchange_payload = _post_form(
            token_url,
            {
                "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
                "client_id": client_id,
                "client_secret": client_secret,
                "subject_token": access_token,
                "requested_subject": user_id,
                "audience": target_client_id,
            },
        )
        exchanged = exchange_payload.get("access_token")
        if not isinstance(exchanged, str) or not exchanged:
            raise SystemExit("token exchange response missing access_token")

        # Only the aud claim is inspected; the token signature is not verified here.
        claims = _decode_jwt_without_verification(exchanged)
        aud = claims.get("aud") if isinstance(claims, dict) else None
        if aud is None:
            raise SystemExit("token exchange access_token missing aud claim")
        # "aud" may be a single string or a list of audiences.
        if isinstance(aud, str):
            aud_ok = aud == target_client_id
        elif isinstance(aud, list):
            aud_ok = target_client_id in aud
        else:
            aud_ok = False
        if not aud_ok:
            raise SystemExit(f"token exchange aud mismatch (expected {target_client_id!r})")

    deadline_seconds = int(os.environ.get("RETRY_DEADLINE_SECONDS", "300"))
    # A negative interval would make time.sleep() raise ValueError; treat it as "don't wait".
    retry_interval_seconds = max(0, int(os.environ.get("RETRY_INTERVAL_SECONDS", "5")))
    deadline_at = time.monotonic() + deadline_seconds

    while True:
        try:
            run_once()
        except SystemExit as exc:
            # Failures raised by this script's own checks/helpers: retry only known-transient ones.
            message = str(exc)
            if time.monotonic() >= deadline_at or not _is_retryable_failure(message):
                raise
        except Exception as exc:
            # Anything else (e.g. connection errors, non-JSON bodies) is retried until the deadline.
            message = str(exc) or type(exc).__name__
            if time.monotonic() >= deadline_at:
                raise SystemExit(message) from exc
        else:
            print("PASS: token exchange works")
            return 0
        print(f"retrying after failure: {message}", file=sys.stderr)
        # Never sleep past the deadline; one final attempt then runs at (about) the deadline.
        remaining = deadline_at - time.monotonic()
        time.sleep(max(0.0, min(retry_interval_seconds, remaining)))
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # main()'s return value becomes the process exit status (same as sys.exit(main())).
    raise SystemExit(main())
|