from __future__ import annotations import json from types import SimpleNamespace from unittest import TestCase, mock import pytest from atlas_portal.app_factory import create_app from atlas_portal.routes import ai class AiRouteTests(TestCase): @classmethod def setUpClass(cls): cls.app = create_app() cls.client = cls.app.test_client() def test_chat_routes_profiles_to_modes(self): seen: list[tuple[str, str]] = [] def fake_atlasbot_answer(message: str, mode: str, conversation_id: str) -> str: seen.append((mode, conversation_id)) return f"{mode}:{conversation_id}" with mock.patch.object(ai, "_atlasbot_answer", side_effect=fake_atlasbot_answer): for profile, expected_mode in ( ("atlas-quick", "quick"), ("atlas-smart", "smart"), ("atlas-genius", "genius"), ): resp = self.client.post( "/api/chat", data=json.dumps( { "message": "How is Titan doing?", "profile": profile, "conversation_id": f"conv-{profile}", } ), content_type="application/json", ) data = resp.get_json() self.assertEqual(resp.status_code, 200) self.assertEqual(data.get("source"), f"atlas-{expected_mode}") self.assertEqual(data.get("reply"), f"{expected_mode}:conv-{profile}") self.assertEqual( seen, [ ("quick", "conv-atlas-quick"), ("smart", "conv-atlas-smart"), ("genius", "conv-atlas-genius"), ], ) def test_info_endpoint_exposes_profile_specific_model(self): with mock.patch.object(ai.settings, "AI_ATLASBOT_MODEL_GENIUS", "genius-model"): resp = self.client.get("/api/ai/info?profile=atlas-genius") data = resp.get_json() self.assertEqual(resp.status_code, 200) self.assertEqual(data.get("profile"), "atlas-genius") self.assertEqual(data.get("model"), "genius-model") def test_atlasbot_answer_uses_profile_specific_timeout(self): captured: dict[str, object] = {} class DummyResponse: status_code = 200 def json(self): return {"reply": "atlas reply"} class DummyClient: def __init__(self, timeout): captured["timeout"] = timeout def __enter__(self): return self def __exit__(self, exc_type, exc, tb): return False def post(self, endpoint, json=None, headers=None): captured["endpoint"] = endpoint captured["json"] = json captured["headers"] = headers or {} return DummyResponse() with ( mock.patch.object(ai.httpx, "Client", DummyClient), mock.patch.object(ai.settings, "AI_ATLASBOT_ENDPOINT", "http://atlasbot.invalid/v1/answer"), mock.patch.object(ai.settings, "AI_ATLASBOT_TOKEN", "internal-token"), ): reply = ai._atlasbot_answer("How is Titan doing?", "genius", "conv-1") self.assertEqual(reply, "atlas reply") self.assertEqual(captured["timeout"], ai.settings.AI_ATLASBOT_TIMEOUT_GENIUS_SEC) self.assertEqual(captured["json"], {"prompt": "How is Titan doing?", "mode": "genius", "conversation_id": "conv-1"}) self.assertEqual(captured["headers"], {"X-Internal-Token": "internal-token"}) def test_chat_returns_fallback_when_atlasbot_returns_empty(self): for profile, expected in ( ("atlas-quick", "Quick mode hit"), ("atlas-smart", "Smart mode hit"), ("atlas-genius", "Atlas genius mode timed out"), ): with mock.patch.object(ai, "_atlasbot_answer", return_value=""): resp = self.client.post( "/api/chat", data=json.dumps( { "message": "How is Titan doing?", "profile": profile, } ), content_type="application/json", ) data = resp.get_json() self.assertEqual(resp.status_code, 200) self.assertIn(expected, data.get("reply", "")) def test_chat_requires_message() -> None: client = create_app().test_client() response = client.post("/api/chat", data=json.dumps({"message": ""}), content_type="application/json") assert response.status_code == 400 assert response.get_json()["error"] == "message required" def test_atlasbot_answer_soft_failure_paths(monkeypatch) -> None: monkeypatch.setattr(ai.settings, "AI_ATLASBOT_ENDPOINT", "") assert ai._atlasbot_answer("hello", "quick", "") == "" class NonOkClient: def __init__(self, timeout): pass def __enter__(self): return self def __exit__(self, exc_type, exc, tb): return False def post(self, endpoint, json=None, headers=None): return SimpleNamespace(status_code=503) monkeypatch.setattr(ai.settings, "AI_ATLASBOT_ENDPOINT", "http://atlasbot") monkeypatch.setattr(ai.httpx, "Client", NonOkClient) assert ai._atlasbot_answer("hello", "smart", "") == "" class BadJsonClient(NonOkClient): def post(self, endpoint, json=None, headers=None): return SimpleNamespace(status_code=200, json=lambda: (_ for _ in ()).throw(ValueError("bad"))) monkeypatch.setattr(ai.httpx, "Client", BadJsonClient) assert ai._atlasbot_answer("hello", "quick", "") == "" assert ai._atlasbot_timeout_sec("smart") == ai.settings.AI_ATLASBOT_TIMEOUT_SMART_SEC assert ai._atlasbot_timeout_sec("quick") == ai.settings.AI_ATLASBOT_TIMEOUT_QUICK_SEC def test_discover_ai_meta_reads_pod_annotations(monkeypatch) -> None: class FakePath: def __init__(self, value): self.value = value def __truediv__(self, child): return FakePath(child) def exists(self): return True def read_text(self): if self.value == "token": return "token" if self.value == "namespace": return "ai" return "ca" def __str__(self): return self.value class PodClient: def __init__(self, **kwargs): self.kwargs = kwargs def __enter__(self): return self def __exit__(self, exc_type, exc, tb): return False def get(self, url): return SimpleNamespace( raise_for_status=lambda: None, json=lambda: { "items": [ { "status": {"phase": "Pending"}, "spec": { "nodeName": "titan-24", "containers": [{"image": "registry/atlasbot:model-from-image"}], }, "metadata": { "annotations": { ai.settings.AI_GPU_ANNOTATION: "RTX 3090", ai.settings.AI_MODEL_ANNOTATION: "annotated-model", } }, } ] }, ) monkeypatch.setattr(ai, "Path", FakePath) monkeypatch.setattr(ai.httpx, "Client", PodClient) meta = ai._discover_ai_meta("atlas-quick") assert meta["node"] == "titan-24" assert meta["gpu"] == "RTX 3090" assert meta["model"] == "annotated-model" class ImageOnlyClient(PodClient): def get(self, url): return SimpleNamespace( raise_for_status=lambda: None, json=lambda: { "items": [ { "status": {"phase": "Running"}, "spec": { "nodeName": "titan-22", "containers": [{"image": "registry/atlasbot:model-from-image"}], }, "metadata": {"annotations": {}}, } ] }, ) monkeypatch.setattr(ai.httpx, "Client", ImageOnlyClient) image_meta = ai._discover_ai_meta("atlas-smart") assert image_meta["endpoint"] == "/api/ai/chat" assert image_meta["model"] == "model-from-image" def test_discover_ai_meta_handles_probe_errors(monkeypatch) -> None: class MissingPath: def __init__(self, value): self.value = value def __truediv__(self, child): return MissingPath(child) def exists(self): return False monkeypatch.setattr(ai, "Path", MissingPath) assert ai._discover_ai_meta("quick")["endpoint"] == "/api/ai/chat" class ExistingPath(MissingPath): def __truediv__(self, child): return ExistingPath(child) def exists(self): return True def read_text(self): return "token" def __str__(self): return self.value class FailingClient: def __init__(self, **kwargs): pass def __enter__(self): return self def __exit__(self, exc_type, exc, tb): return False def get(self, url): raise RuntimeError("offline") monkeypatch.setattr(ai, "Path", ExistingPath) monkeypatch.setattr(ai.httpx, "Client", FailingClient) assert ai._discover_ai_meta("atlas-genius")["endpoint"] == "/api/ai/chat" def test_start_keep_warm_disabled_and_loop(monkeypatch) -> None: monkeypatch.setattr(ai.settings, "AI_WARM_ENABLED", False) ai._start_keep_warm() posts: list[dict] = [] class WarmClient: def __init__(self, timeout): self.timeout = timeout def __enter__(self): return self def __exit__(self, exc_type, exc, tb): return False def post(self, url, json=None): posts.append({"url": url, "json": json}) sleeps = {"count": 0} def fake_sleep(seconds): sleeps["count"] += 1 if sleeps["count"] > 1: raise KeyboardInterrupt() class InlineThread: def __init__(self, target, daemon, name): self.target = target def start(self): self.target() monkeypatch.setattr(ai.settings, "AI_WARM_ENABLED", True) monkeypatch.setattr(ai.settings, "AI_WARM_INTERVAL_SEC", 1) monkeypatch.setattr(ai.time, "sleep", fake_sleep) monkeypatch.setattr(ai.httpx, "Client", WarmClient) monkeypatch.setattr(ai.threading, "Thread", InlineThread) with pytest.raises(KeyboardInterrupt): ai._start_keep_warm() assert posts class RaisingWarmClient(WarmClient): def post(self, url, json=None): raise RuntimeError("keep-warm backend unavailable") loop_sleeps = {"count": 0} def stop_after_exception(seconds): loop_sleeps["count"] += 1 if loop_sleeps["count"] > 1: raise KeyboardInterrupt() monkeypatch.setattr(ai.time, "sleep", stop_after_exception) monkeypatch.setattr(ai.httpx, "Client", RaisingWarmClient) with pytest.raises(KeyboardInterrupt): ai._start_keep_warm()