ariadne/tests/test_auth.py

161 lines
4.8 KiB
Python
Raw Normal View History

2026-01-19 19:01:32 -03:00
from __future__ import annotations
import jwt
import pytest
from ariadne.auth.keycloak import Authenticator, KeycloakOIDC
def _make_token(kid: str = "test") -> str:
    """Return a throwaway HS256-signed JWT whose header carries ``kid``.

    The payload is the minimal ``{"sub": "user"}`` and the signing key is the
    fixed placeholder ``"secret"``; only the header ``kid`` is parameterized.
    """
    claims = {"sub": "user"}
    extra_headers = {"kid": kid}
    return jwt.encode(claims, "secret", algorithm="HS256", headers=extra_headers)
def test_keycloak_verify_accepts_matching_audience(monkeypatch) -> None:
    """verify() returns the decoded claims when ``azp`` equals ``"portal"``.

    ``"portal"`` is the third argument given to KeycloakOIDC (presumably the
    expected client id / audience). Key lookup, JWK parsing and decoding are
    all stubbed so no network or real signature check is involved.
    """
    token = _make_token()
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")

    def fake_get_jwks(force=False):
        # The token built by _make_token() has kid "test", so it is found here.
        return {"keys": [{"kid": "test"}]}

    def fake_from_jwk(key):
        return "dummy"

    def fake_decode(*args, **kwargs):
        return {"azp": "portal", "preferred_username": "alice", "groups": ["/admin"]}

    monkeypatch.setattr(oidc, "_get_jwks", fake_get_jwks)
    monkeypatch.setattr(jwt.algorithms.RSAAlgorithm, "from_jwk", fake_from_jwk)
    monkeypatch.setattr(jwt, "decode", fake_decode)

    verified = oidc.verify(token)
    assert verified["preferred_username"] == "alice"
def test_keycloak_verify_rejects_wrong_audience(monkeypatch) -> None:
    """verify() raises ValueError when neither ``azp`` nor ``aud`` is ``"portal"``."""
    token = _make_token()
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")

    def fake_get_jwks(force=False):
        return {"keys": [{"kid": "test"}]}

    def fake_from_jwk(key):
        return "dummy"

    def fake_decode(*args, **kwargs):
        # Both authorized-party and audience point at some other client.
        return {"azp": "other", "aud": ["other"]}

    monkeypatch.setattr(oidc, "_get_jwks", fake_get_jwks)
    monkeypatch.setattr(jwt.algorithms.RSAAlgorithm, "from_jwk", fake_from_jwk)
    monkeypatch.setattr(jwt, "decode", fake_decode)

    with pytest.raises(ValueError):
        oidc.verify(token)
def test_keycloak_verify_missing_token() -> None:
    """An empty token string is rejected with ValueError."""
    empty_token = ""
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")
    with pytest.raises(ValueError):
        oidc.verify(empty_token)
2026-01-19 19:01:32 -03:00
def test_keycloak_verify_missing_kid(monkeypatch) -> None:
    """A token whose unverified header lacks ``kid`` is rejected with ValueError."""
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")

    def header_without_kid(token):
        return {}

    monkeypatch.setattr(jwt, "get_unverified_header", header_without_kid)

    # The token text is opaque here because header parsing is stubbed out.
    with pytest.raises(ValueError):
        oidc.verify("header.payload.sig")
def test_keycloak_verify_refreshes_jwks(monkeypatch) -> None:
    """An unknown ``kid`` triggers exactly one forced JWKS refetch.

    The first (non-forced) lookup only knows kid ``"other"``; the forced one
    knows ``"test"``, the kid carried by the token. The decoded claims have a
    mismatching ``azp`` but ``aud == "portal"``, and verification succeeds.
    """
    token = _make_token()
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")
    force_flags = []

    def fake_get_jwks(force=False):
        force_flags.append(force)
        known_kid = "test" if force else "other"
        return {"keys": [{"kid": known_kid}]}

    def fake_from_jwk(key):
        return "dummy"

    def fake_decode(*args, **kwargs):
        return {"azp": "other", "aud": "portal", "preferred_username": "alice"}

    monkeypatch.setattr(oidc, "_get_jwks", fake_get_jwks)
    monkeypatch.setattr(jwt.algorithms.RSAAlgorithm, "from_jwk", fake_from_jwk)
    monkeypatch.setattr(jwt, "decode", fake_decode)

    verified = oidc.verify(token)
    assert force_flags == [False, True]
    assert verified["preferred_username"] == "alice"
def test_keycloak_verify_kid_not_found(monkeypatch) -> None:
    """verify() raises ValueError when the JWKS never contains the token's kid."""
    token = _make_token()
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")

    def empty_jwks(force=False):
        # Empty key set regardless of whether a refresh is forced.
        return {"keys": []}

    monkeypatch.setattr(oidc, "_get_jwks", empty_jwks)
    with pytest.raises(ValueError):
        oidc.verify(token)
2026-01-19 19:01:32 -03:00
def test_authenticator_normalizes_groups(monkeypatch) -> None:
    """Authenticator strips leading ``/`` from groups and drops non-strings.

    The OIDC ``verify`` step is stubbed to return claims whose ``groups`` mix
    a slash-prefixed name, an integer and a plain name; only the two string
    entries survive, without the leading slash.
    """
    token = _make_token()
    auth = Authenticator()
    monkeypatch.setattr(
        auth._oidc,
        "verify",
        lambda token: {"preferred_username": "bob", "groups": ["/admin", 123, "dev"]},
    )

    ctx = auth.authenticate(token)
    assert ctx.username == "bob"
    assert ctx.groups == ["admin", "dev"]
def test_authenticator_normalizes_groups_non_list(monkeypatch) -> None:
    """A ``groups`` claim that is a bare string (not a list) yields no groups."""
    token = _make_token()
    authenticator = Authenticator()

    def fake_verify(token):
        return {"preferred_username": "bob", "groups": "admin"}

    monkeypatch.setattr(authenticator._oidc, "verify", fake_verify)
    assert authenticator.authenticate(token).groups == []
def test_keycloak_get_jwks_invalid_payload(monkeypatch) -> None:
    """A JWKS response whose JSON body is a list (not an object) raises ValueError.

    ``httpx.Client`` inside ``ariadne.auth.keycloak`` is replaced by a stub
    context manager whose ``get`` returns a response with a successful status
    and ``json() == []``.
    """
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")

    class Resp:
        def raise_for_status(self):
            return None

        def json(self):
            return []

    class DummyClient:
        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def get(self, url):
            return Resp()

    def build_client(*args, **kwargs):
        return DummyClient()

    monkeypatch.setattr("ariadne.auth.keycloak.httpx.Client", build_client)
    with pytest.raises(ValueError):
        oidc._get_jwks(force=True)
def test_keycloak_get_jwks_cached(monkeypatch) -> None:
    """A non-forced fetch after a forced one reuses the cache (one HTTP GET).

    Every ``get`` call on the stub client is recorded; after
    ``_get_jwks(force=True)`` followed by ``_get_jwks(force=False)`` exactly one
    request must have been made.
    """
    oidc = KeycloakOIDC("https://jwks", "https://issuer", "portal")
    requested_urls = []

    class Resp:
        def raise_for_status(self):
            return None

        def json(self):
            return {"keys": []}

    class DummyClient:
        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def get(self, url):
            requested_urls.append(url)
            return Resp()

    def build_client(*args, **kwargs):
        return DummyClient()

    monkeypatch.setattr("ariadne.auth.keycloak.httpx.Client", build_client)
    oidc._get_jwks(force=True)
    oidc._get_jwks(force=False)
    assert len(requested_urls) == 1