161 lines
4.8 KiB
Python
161 lines
4.8 KiB
Python
from __future__ import annotations
|
|
|
|
import jwt
|
|
import pytest
|
|
|
|
from ariadne.auth.keycloak import Authenticator, KeycloakOIDC
|
|
|
|
|
|
def _make_token(kid: str = "test") -> str:
    """Return a placeholder HS256-signed JWT whose header carries ``kid``.

    The ``sub`` claim and the ``"secret"`` signing key are fixed dummy values.
    The ``kid`` header is what the key-lookup tests below vary.
    """
    header = {"kid": kid}
    claims = {"sub": "user"}
    return jwt.encode(claims, "secret", algorithm="HS256", headers=header)
|
|
|
|
|
|
def test_keycloak_verify_accepts_matching_audience(monkeypatch) -> None:
    """verify() returns the decoded claims when ``azp`` equals the client id."""
    token = _make_token()
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")

    def fake_get_jwks(force=False):
        # The JWKS always advertises the kid used by _make_token().
        return {"keys": [{"kid": "test"}]}

    def fake_decode(*args, **kwargs):
        # Build a fresh payload per call, as the original inline lambda did.
        return {"azp": "portal", "preferred_username": "alice", "groups": ["/admin"]}

    monkeypatch.setattr(oidc, "_get_jwks", fake_get_jwks)
    monkeypatch.setattr(jwt.algorithms.RSAAlgorithm, "from_jwk", lambda key: "dummy")
    monkeypatch.setattr(jwt, "decode", fake_decode)

    result = oidc.verify(token)
    assert result["preferred_username"] == "alice"
|
|
|
|
|
|
def test_keycloak_verify_rejects_wrong_audience(monkeypatch) -> None:
    """verify() raises ValueError when neither ``azp`` nor ``aud`` names the client."""
    token = _make_token()
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")

    def fake_get_jwks(force=False):
        return {"keys": [{"kid": "test"}]}

    def fake_decode(*args, **kwargs):
        # Both audience-related claims point at a different client id.
        return {"azp": "other", "aud": ["other"]}

    monkeypatch.setattr(oidc, "_get_jwks", fake_get_jwks)
    monkeypatch.setattr(jwt.algorithms.RSAAlgorithm, "from_jwk", lambda key: "dummy")
    monkeypatch.setattr(jwt, "decode", fake_decode)

    with pytest.raises(ValueError):
        oidc.verify(token)
|
|
|
|
|
|
def test_keycloak_verify_missing_token() -> None:
    """An empty token string is rejected with ValueError."""
    empty_token = ""
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")
    with pytest.raises(ValueError):
        oidc.verify(empty_token)
|
|
|
|
|
|
def test_keycloak_verify_missing_kid(monkeypatch) -> None:
    """A token whose unverified header has no ``kid`` is rejected with ValueError."""
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")

    def header_without_kid(token):
        return {}

    monkeypatch.setattr(jwt, "get_unverified_header", header_without_kid)

    # Structurally token-shaped, but its header is replaced by the fake above.
    opaque_token = "header.payload.sig"
    with pytest.raises(ValueError):
        oidc.verify(opaque_token)
|
|
|
|
|
|
def test_keycloak_verify_refreshes_jwks(monkeypatch) -> None:
    """A kid missing from the cached JWKS triggers one forced re-fetch.

    The first (non-forced) lookup only knows kid ``"other"``. The forced
    lookup that follows returns the token's kid ``"test"``.
    """
    token = _make_token()
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")
    force_history: list = []

    def fake_get_jwks(force=False):
        force_history.append(force)
        advertised = "test" if force else "other"
        return {"keys": [{"kid": advertised}]}

    def fake_decode(*args, **kwargs):
        # azp differs from the client id, but aud matches it as a plain string.
        return {"azp": "other", "aud": "portal", "preferred_username": "alice"}

    monkeypatch.setattr(oidc, "_get_jwks", fake_get_jwks)
    monkeypatch.setattr(jwt.algorithms.RSAAlgorithm, "from_jwk", lambda key: "dummy")
    monkeypatch.setattr(jwt, "decode", fake_decode)

    result = oidc.verify(token)
    assert force_history == [False, True]
    assert result["preferred_username"] == "alice"
|
|
|
|
|
|
def test_keycloak_verify_kid_not_found(monkeypatch) -> None:
    """verify() raises ValueError when the JWKS never contains the token's kid."""
    token = _make_token()
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")

    def empty_jwks(force=False):
        # No keys at all, whether or not a refresh is forced.
        return {"keys": []}

    monkeypatch.setattr(oidc, "_get_jwks", empty_jwks)

    with pytest.raises(ValueError):
        oidc.verify(token)
|
|
|
|
|
|
def test_authenticator_normalizes_groups(monkeypatch) -> None:
    """Group names lose their leading ``/``, and non-string entries are dropped."""
    token = _make_token()
    authenticator = Authenticator()

    def fake_verify(token):
        return {"preferred_username": "bob", "groups": ["/admin", 123, "dev"]}

    monkeypatch.setattr(authenticator._oidc, "verify", fake_verify)

    context = authenticator.authenticate(token)
    assert context.username == "bob"
    assert context.groups == ["admin", "dev"]
|
|
|
|
|
|
def test_authenticator_normalizes_groups_non_list(monkeypatch) -> None:
    """A ``groups`` claim that is not a list yields an empty group list."""
    token = _make_token()
    authenticator = Authenticator()

    def fake_verify(token):
        # A bare string rather than a list of strings.
        return {"preferred_username": "bob", "groups": "admin"}

    monkeypatch.setattr(authenticator._oidc, "verify", fake_verify)

    context = authenticator.authenticate(token)
    assert context.groups == []
|
|
|
|
|
|
def test_keycloak_get_jwks_invalid_payload(monkeypatch) -> None:
    """_get_jwks raises ValueError when the endpoint's JSON body is a list."""
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")

    class FakeResponse:
        def raise_for_status(self):
            return None

        def json(self):
            # A list instead of the object the JWKS fetch expects.
            return []

    class FakeClient:
        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def get(self, url):
            return FakeResponse()

    def client_factory(*args, **kwargs):
        return FakeClient()

    monkeypatch.setattr("ariadne.auth.keycloak.httpx.Client", client_factory)

    with pytest.raises(ValueError):
        oidc._get_jwks(force=True)
|
|
|
|
|
|
def test_keycloak_get_jwks_cached(monkeypatch) -> None:
    """A forced fetch fills the cache, so a later non-forced call skips HTTP.

    Exactly one GET is expected across the two ``_get_jwks`` calls.
    """
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")
    requested_urls: list = []

    class FakeResponse:
        def raise_for_status(self):
            return None

        def json(self):
            return {"keys": []}

    class FakeClient:
        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def get(self, url):
            requested_urls.append(url)
            return FakeResponse()

    def client_factory(*args, **kwargs):
        return FakeClient()

    monkeypatch.setattr("ariadne.auth.keycloak.httpx.Client", client_factory)

    oidc._get_jwks(force=True)
    oidc._get_jwks(force=False)
    assert len(requested_urls) == 1
|