"""Unit tests for ``ariadne.auth.keycloak`` token verification and group normalization.

Network access and cryptography are stubbed with ``monkeypatch`` so these tests exercise
only the control flow of ``KeycloakOIDC`` and ``Authenticator``.
"""

from __future__ import annotations

import jwt
import pytest

from ariadne.auth.keycloak import Authenticator, KeycloakOIDC

# Constructor arguments shared by every KeycloakOIDC under test. Judging by the audience
# tests below, the third value ("portal") is the client the token's audience / ``azp``
# must match.
_OIDC_ARGS = ("https://jwks", "https://issuer", "portal")


def _make_token(kid: str = "test") -> str:
    """Return an HS256-signed JWT for ``sub=user`` whose header carries ``kid``.

    The secret is a throwaway value; the tests depend on the ``kid`` header, not on the
    signature.
    """
    return jwt.encode({"sub": "user"}, "secret", algorithm="HS256", headers={"kid": kid})


def _make_oidc() -> KeycloakOIDC:
    """Build a ``KeycloakOIDC`` wired to the dummy endpoints in ``_OIDC_ARGS``."""
    return KeycloakOIDC(*_OIDC_ARGS)


def _stub_crypto(monkeypatch: pytest.MonkeyPatch, claims: dict) -> None:
    """Bypass key loading and signature checks so ``jwt.decode`` yields ``claims``.

    ``RSAAlgorithm.from_jwk`` is a ``staticmethod`` in PyJWT. The replacement is wrapped
    the same way so it behaves correctly whether it is reached through the class or an
    instance, and its parameter is named ``jwk`` to match the real signature.
    """
    monkeypatch.setattr(
        jwt.algorithms.RSAAlgorithm, "from_jwk", staticmethod(lambda jwk: "dummy")
    )
    # Hand back a new top-level dict on every call so results are not shared between calls.
    monkeypatch.setattr(jwt, "decode", lambda *args, **kwargs: dict(claims))


class _FakeResponse:
    """Stand-in for ``httpx.Response`` exposing only ``raise_for_status`` and ``json``."""

    def __init__(self, payload: object) -> None:
        self._payload = payload

    def raise_for_status(self) -> None:
        return None

    def json(self) -> object:
        return self._payload


def _patch_httpx(monkeypatch: pytest.MonkeyPatch, payload: object) -> dict[str, int]:
    """Replace ``httpx.Client`` in the keycloak module with a fake serving ``payload``.

    Returns a counter dict whose ``"count"`` entry is incremented on every ``get`` call.
    """
    calls = {"count": 0}

    class _FakeClient:
        def __enter__(self) -> _FakeClient:
            return self

        def __exit__(self, exc_type, exc, tb) -> bool:
            return False  # never suppress exceptions raised inside the ``with`` block

        def get(self, url: str) -> _FakeResponse:
            calls["count"] += 1
            return _FakeResponse(payload)

    monkeypatch.setattr(
        "ariadne.auth.keycloak.httpx.Client", lambda *args, **kwargs: _FakeClient()
    )
    return calls


def test_keycloak_verify_accepts_matching_audience(monkeypatch) -> None:
    """A token whose ``azp`` equals the configured client verifies successfully."""
    token = _make_token()
    kc = _make_oidc()
    monkeypatch.setattr(kc, "_get_jwks", lambda force=False: {"keys": [{"kid": "test"}]})
    _stub_crypto(
        monkeypatch,
        {"azp": "portal", "preferred_username": "alice", "groups": ["/admin"]},
    )

    claims = kc.verify(token)

    assert claims["preferred_username"] == "alice"


def test_keycloak_verify_rejects_wrong_audience(monkeypatch) -> None:
    """A token issued for another client raises ``ValueError``."""
    token = _make_token()
    kc = _make_oidc()
    monkeypatch.setattr(kc, "_get_jwks", lambda force=False: {"keys": [{"kid": "test"}]})
    _stub_crypto(monkeypatch, {"azp": "other", "aud": ["other"]})

    with pytest.raises(ValueError):
        kc.verify(token)


def test_keycloak_verify_missing_token() -> None:
    """An empty token string raises ``ValueError``."""
    kc = _make_oidc()
    with pytest.raises(ValueError):
        kc.verify("")


def test_keycloak_verify_missing_kid(monkeypatch) -> None:
    """A token header without ``kid`` raises ``ValueError``."""
    kc = _make_oidc()
    monkeypatch.setattr(jwt, "get_unverified_header", lambda token: {})
    with pytest.raises(ValueError):
        kc.verify("header.payload.sig")


def test_keycloak_verify_refreshes_jwks(monkeypatch) -> None:
    """An unknown ``kid`` triggers one forced JWKS fetch, after which verification succeeds."""
    token = _make_token()
    kc = _make_oidc()
    forced: list[bool] = []

    def fake_get_jwks(force=False):
        forced.append(force)
        # Non-forced lookups lack the token's kid; only a forced fetch returns it.
        return {"keys": [{"kid": "test" if force else "other"}]}

    monkeypatch.setattr(kc, "_get_jwks", fake_get_jwks)
    _stub_crypto(monkeypatch, {"azp": "other", "aud": "portal", "preferred_username": "alice"})

    claims = kc.verify(token)

    assert forced == [False, True]
    assert claims["preferred_username"] == "alice"


def test_keycloak_verify_kid_not_found(monkeypatch) -> None:
    """A ``kid`` absent from every JWKS response raises ``ValueError``."""
    token = _make_token()
    kc = _make_oidc()
    monkeypatch.setattr(kc, "_get_jwks", lambda force=False: {"keys": []})
    with pytest.raises(ValueError):
        kc.verify(token)


def test_authenticator_normalizes_groups(monkeypatch) -> None:
    """Leading slashes are stripped and non-string group entries are dropped."""
    token = _make_token()
    auth = Authenticator()
    monkeypatch.setattr(
        auth._oidc,
        "verify",
        lambda token: {"preferred_username": "bob", "groups": ["/admin", 123, "dev"]},
    )

    ctx = auth.authenticate(token)

    assert ctx.username == "bob"
    assert ctx.groups == ["admin", "dev"]


def test_authenticator_normalizes_groups_non_list(monkeypatch) -> None:
    """A ``groups`` claim that is not a list yields an empty group list."""
    token = _make_token()
    auth = Authenticator()
    monkeypatch.setattr(
        auth._oidc,
        "verify",
        lambda token: {"preferred_username": "bob", "groups": "admin"},
    )

    ctx = auth.authenticate(token)

    assert ctx.groups == []


def test_keycloak_get_jwks_invalid_payload(monkeypatch) -> None:
    """A JWKS endpoint returning a non-object JSON body raises ``ValueError``."""
    kc = _make_oidc()
    _patch_httpx(monkeypatch, [])
    with pytest.raises(ValueError):
        kc._get_jwks(force=True)


def test_keycloak_get_jwks_cached(monkeypatch) -> None:
    """A non-forced fetch after a forced one reuses the result without another HTTP call."""
    kc = _make_oidc()
    calls = _patch_httpx(monkeypatch, {"keys": []})

    kc._get_jwks(force=True)
    kc._get_jwks(force=False)

    assert calls["count"] == 1