"""Unit tests for ``ariadne.services.vault`` (``VaultClient`` and ``VaultService``)."""

from __future__ import annotations

import types

import pytest

from ariadne.services import vault as vault_module
from ariadne.services.vault import VaultClient, VaultService, _read_file


class DummyResponse:
    """Minimal stand-in for an ``httpx`` response used by the request fakes below."""

    def __init__(self, payload=None, status_code=200):
        # A missing payload becomes an empty dict so ``json()`` always returns a dict.
        self._payload = payload or {}
        self.status_code = status_code

    def json(self):
        return self._payload

    def raise_for_status(self):
        if self.status_code >= 400:
            raise RuntimeError("status error")


def vault_settings(**overrides):
    """Return a settings namespace with test defaults; ``overrides`` replace individual fields."""
    base = {
        "vault_addr": "http://vault",
        "vault_token": "",
        "vault_k8s_role": "vault",
        "vault_k8s_role_ttl": "1h",
        "vault_k8s_token_reviewer_jwt": "",
        "vault_k8s_token_reviewer_jwt_file": "",
        "vault_oidc_discovery_url": "http://oidc",
        "vault_oidc_client_id": "client",
        "vault_oidc_client_secret": "secret",
        "vault_oidc_default_role": "admin",
        "vault_oidc_scopes": "",
        "vault_oidc_user_claim": "",
        "vault_oidc_groups_claim": "",
        "vault_oidc_token_policies": "",
        "vault_oidc_admin_group": "admin",
        "vault_oidc_admin_policies": "default,vault-admin",
        "vault_oidc_dev_group": "dev",
        "vault_oidc_dev_policies": "default,dev-kv",
        "vault_oidc_user_group": "",
        "vault_oidc_user_policies": "",
        "vault_oidc_redirect_uris": "https://secret.bstein.dev/ui/vault/auth/oidc/oidc/callback",
        "vault_oidc_bound_audiences": "",
        "vault_oidc_bound_claims_type": "",
        "k8s_api_timeout_sec": 5.0,
    }
    base.update(overrides)
    return types.SimpleNamespace(**base)


def _raising(exc_type, message):
    """Return a callable that raises a fresh ``exc_type(message)`` whatever it is called with.

    Used to stub service methods that must fail with a known error.
    """

    def _fail(*_args, **_kwargs):
        raise exc_type(message)

    return _fail


def test_read_file_returns_stripped_contents(tmp_path) -> None:
    """``_read_file`` strips surrounding whitespace from the file contents."""
    token_file = tmp_path / "token"
    token_file.write_text(" jwt\n", encoding="utf-8")
    assert _read_file(str(token_file)) == "jwt"


def test_vault_client_attaches_token_header(monkeypatch) -> None:
    """The client joins base URL and path and sends the ``X-Vault-Token`` header."""
    monkeypatch.setattr(vault_module, "settings", vault_settings())
    captured = {}

    def fake_request(method, url, headers=None, json=None, timeout=None):
        captured.update(method=method, url=url, headers=headers, json=json, timeout=timeout)
        return DummyResponse()

    monkeypatch.setattr(vault_module.httpx, "request", fake_request)

    # Trailing slash on the base URL must not produce a double slash.
    VaultClient("http://vault/", "tok").request("POST", "/v1/sys", json={"ok": True})

    assert captured == {
        "method": "POST",
        "url": "http://vault/v1/sys",
        "headers": {"X-Vault-Token": "tok"},
        "json": {"ok": True},
        "timeout": 5.0,
    }


def test_ensure_token_prefers_cached_token(monkeypatch) -> None:
    """An already-cached token is returned as-is."""
    monkeypatch.setattr(vault_module, "settings", vault_settings())
    svc = VaultService()
    svc._token = "cached"
    assert svc._ensure_token() == "cached"


def test_ensure_token_reads_configured_jwt_file(monkeypatch, tmp_path) -> None:
    """The reviewer JWT file contents (stripped) are used for the login request."""
    jwt_file = tmp_path / "reviewer.jwt"
    jwt_file.write_text("reviewer-jwt\n", encoding="utf-8")
    monkeypatch.setattr(
        vault_module,
        "settings",
        vault_settings(vault_k8s_token_reviewer_jwt_file=str(jwt_file)),
    )
    posted = {}

    def fake_post(url, json=None, timeout=None):
        posted.update(url=url, json=json, timeout=timeout)
        return DummyResponse({"auth": {"client_token": "vault-token"}})

    monkeypatch.setattr(vault_module.httpx, "post", fake_post)

    assert VaultService()._ensure_token() == "vault-token"
    assert posted["json"]["jwt"] == "reviewer-jwt"


def test_ensure_token_falls_back_to_service_account_jwt(monkeypatch) -> None:
    """Without a configured JWT, the value returned by ``_read_file`` is used."""
    monkeypatch.setattr(vault_module, "settings", vault_settings())
    monkeypatch.setattr(vault_module, "_read_file", lambda path: "service-account-jwt")
    posted = {}

    def fake_post(_url, json=None, timeout=None):
        posted["json"] = json
        return DummyResponse({"auth": {"client_token": "vault-token"}})

    monkeypatch.setattr(vault_module.httpx, "post", fake_post)

    assert VaultService()._ensure_token() == "vault-token"
    assert posted["json"]["jwt"] == "service-account-jwt"


def test_ensure_token_requires_a_jwt(monkeypatch) -> None:
    """An empty JWT from every source raises ``RuntimeError``."""
    monkeypatch.setattr(vault_module, "settings", vault_settings())
    monkeypatch.setattr(vault_module, "_read_file", lambda path: "")
    with pytest.raises(RuntimeError, match="vault auth jwt missing"):
        VaultService()._ensure_token()


def test_ensure_token_requires_login_token(monkeypatch) -> None:
    """A login response without ``client_token`` raises ``RuntimeError``."""
    monkeypatch.setattr(vault_module, "settings", vault_settings(vault_k8s_token_reviewer_jwt="jwt"))
    monkeypatch.setattr(
        vault_module.httpx,
        "post",
        lambda *args, **kwargs: DummyResponse({"auth": {}}),
    )
    with pytest.raises(RuntimeError, match="vault login token missing"):
        VaultService()._ensure_token()


def test_vault_ready_reports_health_error(monkeypatch) -> None:
    """A failing health probe yields an ``error`` result carrying the message."""
    monkeypatch.setattr(vault_module, "settings", vault_settings())
    svc = VaultService()
    monkeypatch.setattr(svc, "_health", _raising(RuntimeError, "boom"))

    result = svc._vault_ready()

    assert result is not None
    assert result.status == "error"
    assert result.detail == "boom"


def test_vault_ready_skips_uninitialized_and_sealed(monkeypatch) -> None:
    """Uninitialized and sealed vaults are reported with distinct details."""
    monkeypatch.setattr(vault_module, "settings", vault_settings())
    svc = VaultService()

    monkeypatch.setattr(svc, "_health", lambda _client: {"initialized": False, "sealed": False})
    assert svc._vault_ready().detail == "vault not initialized"

    monkeypatch.setattr(svc, "_health", lambda _client: {"initialized": True, "sealed": True})
    assert svc._vault_ready().detail == "vault sealed"


def test_validate_oidc_settings_requires_credentials(monkeypatch) -> None:
    """A missing OIDC client secret is reported by validation."""
    monkeypatch.setattr(vault_module, "settings", vault_settings(vault_oidc_client_secret=""))
    assert VaultService()._validate_oidc_settings() == "oidc client credentials missing"


def test_tune_oidc_listing_ignores_best_effort_failure(monkeypatch) -> None:
    """A client error during the tune call must not propagate."""
    monkeypatch.setattr(vault_module, "settings", vault_settings())

    class FailingClient:
        def request(self, *args, **kwargs):
            raise RuntimeError("tune unsupported")

    VaultService()._tune_oidc_listing(FailingClient())


def test_oidc_payload_skips_empty_groups_or_policies(monkeypatch) -> None:
    """No role payload is built when the group or the policy list is empty."""
    monkeypatch.setattr(vault_module, "settings", vault_settings())
    svc = VaultService()
    context = svc._oidc_context()
    assert svc._oidc_role_payload(context, "", "default") is None
    assert svc._oidc_role_payload(context, "dev", "") is None


def test_sync_k8s_auth_reports_health_failures(monkeypatch) -> None:
    """A failing health probe makes ``sync_k8s_auth`` report ``error``."""
    monkeypatch.setattr(vault_module, "settings", vault_settings())
    svc = VaultService()
    monkeypatch.setattr(svc, "_health", _raising(RuntimeError, "offline"))
    assert svc.sync_k8s_auth()["status"] == "error"


def test_sync_k8s_auth_skips_when_vault_unavailable(monkeypatch) -> None:
    """Uninitialized and sealed vaults are surfaced in the sync result detail."""
    monkeypatch.setattr(vault_module, "settings", vault_settings())
    svc = VaultService()

    monkeypatch.setattr(svc, "_health", lambda _client: {"initialized": False, "sealed": False})
    assert svc.sync_k8s_auth()["detail"] == "vault not initialized"

    monkeypatch.setattr(svc, "_health", lambda _client: {"initialized": True, "sealed": True})
    assert svc.sync_k8s_auth()["detail"] == "vault sealed"


def test_sync_k8s_auth_reads_reviewer_jwt_file(monkeypatch, tmp_path) -> None:
    """The reviewer JWT file contents are written into the kubernetes auth config."""
    jwt_file = tmp_path / "reviewer.jwt"
    jwt_file.write_text("reviewer-jwt", encoding="utf-8")
    monkeypatch.setattr(
        vault_module,
        "settings",
        vault_settings(vault_token="token", vault_k8s_token_reviewer_jwt_file=str(jwt_file)),
    )
    configs = []

    def fake_request(self, method, path, json=None):
        if path == "/v1/sys/health":
            return DummyResponse({"initialized": True, "sealed": False})
        if path == "/v1/sys/auth":
            # Report the kubernetes mount as already enabled.
            return DummyResponse({"kubernetes/": {}})
        if path == "/v1/auth/kubernetes/config":
            configs.append(json)
        return DummyResponse()

    monkeypatch.setattr(vault_module.VaultClient, "request", fake_request)

    assert VaultService().sync_k8s_auth()["status"] == "ok"
    # Fail with a clear message (not IndexError) if the config was never written.
    assert configs, "kubernetes auth config was not written"
    assert configs[0]["token_reviewer_jwt"] == "reviewer-jwt"


def test_sync_oidc_returns_readiness_status(monkeypatch) -> None:
    """A non-ready vault short-circuits ``sync_oidc`` with the readiness result."""
    monkeypatch.setattr(vault_module, "settings", vault_settings())
    svc = VaultService()
    monkeypatch.setattr(svc, "_vault_ready", lambda: vault_module.VaultResult("skip", "vault sealed"))
    assert svc.sync_oidc() == {"status": "skip", "detail": "vault sealed"}


def test_sync_oidc_skips_roles_without_payload(monkeypatch) -> None:
    """Roles whose payload is ``None`` (here: empty group) are not written."""
    monkeypatch.setattr(vault_module, "settings", vault_settings(vault_token="token"))
    calls = []

    def fake_request(self, method, path, json=None):
        calls.append((method, path, json))
        if path == "/v1/sys/health":
            return DummyResponse({"initialized": True, "sealed": False})
        if path == "/v1/sys/auth":
            # Report the oidc mount as already enabled.
            return DummyResponse({"oidc/": {}})
        return DummyResponse()

    monkeypatch.setattr(vault_module.VaultClient, "request", fake_request)
    svc = VaultService()
    monkeypatch.setattr(
        svc,
        "_oidc_roles",
        lambda: [("empty", "", "default"), ("dev", "dev", "default")],
    )

    assert svc.sync_oidc()["status"] == "ok"
    role_writes = [path for _, path, _ in calls if "/v1/auth/oidc/role/" in path]
    assert role_writes == ["/v1/auth/oidc/role/dev"]