diff --git a/backend/tests/test_ai.py b/backend/tests/test_ai.py index eb2a3a2..d3cb6cd 100644 --- a/backend/tests/test_ai.py +++ b/backend/tests/test_ai.py @@ -1,8 +1,11 @@ from __future__ import annotations import json +from types import SimpleNamespace from unittest import TestCase, mock +import pytest + from atlas_portal.app_factory import create_app from atlas_portal.routes import ai @@ -88,27 +91,267 @@ class AiRouteTests(TestCase): with ( mock.patch.object(ai.httpx, "Client", DummyClient), mock.patch.object(ai.settings, "AI_ATLASBOT_ENDPOINT", "http://atlasbot.invalid/v1/answer"), + mock.patch.object(ai.settings, "AI_ATLASBOT_TOKEN", "internal-token"), ): reply = ai._atlasbot_answer("How is Titan doing?", "genius", "conv-1") self.assertEqual(reply, "atlas reply") self.assertEqual(captured["timeout"], ai.settings.AI_ATLASBOT_TIMEOUT_GENIUS_SEC) self.assertEqual(captured["json"], {"prompt": "How is Titan doing?", "mode": "genius", "conversation_id": "conv-1"}) + self.assertEqual(captured["headers"], {"X-Internal-Token": "internal-token"}) def test_chat_returns_fallback_when_atlasbot_returns_empty(self): - with mock.patch.object(ai, "_atlasbot_answer", return_value=""): - resp = self.client.post( - "/api/chat", - data=json.dumps( - { - "message": "How is Titan doing?", - "profile": "atlas-quick", - } - ), - content_type="application/json", + for profile, expected in ( + ("atlas-quick", "Quick mode hit"), + ("atlas-smart", "Smart mode hit"), + ("atlas-genius", "Atlas genius mode timed out"), + ): + with mock.patch.object(ai, "_atlasbot_answer", return_value=""): + resp = self.client.post( + "/api/chat", + data=json.dumps( + { + "message": "How is Titan doing?", + "profile": profile, + } + ), + content_type="application/json", + ) + + data = resp.get_json() + self.assertEqual(resp.status_code, 200) + self.assertIn(expected, data.get("reply", "")) + + +def test_chat_requires_message() -> None: + client = create_app().test_client() + + response = client.post("/api/chat", data=json.dumps({"message": ""}), content_type="application/json") + + assert response.status_code == 400 + assert response.get_json()["error"] == "message required" + + +def test_atlasbot_answer_soft_failure_paths(monkeypatch) -> None: + monkeypatch.setattr(ai.settings, "AI_ATLASBOT_ENDPOINT", "") + assert ai._atlasbot_answer("hello", "quick", "") == "" + + class NonOkClient: + def __init__(self, timeout): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def post(self, endpoint, json=None, headers=None): + return SimpleNamespace(status_code=503) + + monkeypatch.setattr(ai.settings, "AI_ATLASBOT_ENDPOINT", "http://atlasbot") + monkeypatch.setattr(ai.httpx, "Client", NonOkClient) + assert ai._atlasbot_answer("hello", "smart", "") == "" + + class BadJsonClient(NonOkClient): + def post(self, endpoint, json=None, headers=None): + return SimpleNamespace(status_code=200, json=lambda: (_ for _ in ()).throw(ValueError("bad"))) + + monkeypatch.setattr(ai.httpx, "Client", BadJsonClient) + assert ai._atlasbot_answer("hello", "quick", "") == "" + assert ai._atlasbot_timeout_sec("smart") == ai.settings.AI_ATLASBOT_TIMEOUT_SMART_SEC + assert ai._atlasbot_timeout_sec("quick") == ai.settings.AI_ATLASBOT_TIMEOUT_QUICK_SEC + + +def test_discover_ai_meta_reads_pod_annotations(monkeypatch) -> None: + class FakePath: + def __init__(self, value): + self.value = value + + def __truediv__(self, child): + return FakePath(child) + + def exists(self): + return True + + def read_text(self): + if self.value == "token": + return "token" + if self.value == "namespace": + return "ai" + return "ca" + + def __str__(self): + return self.value + + class PodClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def get(self, url): + return SimpleNamespace( + raise_for_status=lambda: None, + json=lambda: { + "items": [ + { + "status": {"phase": "Pending"}, + "spec": { + "nodeName": "titan-24", + "containers": [{"image": "registry/atlasbot:model-from-image"}], + }, + "metadata": { + "annotations": { + ai.settings.AI_GPU_ANNOTATION: "RTX 3090", + ai.settings.AI_MODEL_ANNOTATION: "annotated-model", + } + }, + } + ] + }, ) - data = resp.get_json() - self.assertEqual(resp.status_code, 200) - self.assertEqual(data.get("source"), "atlas-quick") - self.assertIn("Quick mode hit", data.get("reply", "")) + monkeypatch.setattr(ai, "Path", FakePath) + monkeypatch.setattr(ai.httpx, "Client", PodClient) + + meta = ai._discover_ai_meta("atlas-quick") + + assert meta["node"] == "titan-24" + assert meta["gpu"] == "RTX 3090" + assert meta["model"] == "annotated-model" + + class ImageOnlyClient(PodClient): + def get(self, url): + return SimpleNamespace( + raise_for_status=lambda: None, + json=lambda: { + "items": [ + { + "status": {"phase": "Running"}, + "spec": { + "nodeName": "titan-22", + "containers": [{"image": "registry/atlasbot:model-from-image"}], + }, + "metadata": {"annotations": {}}, + } + ] + }, + ) + + monkeypatch.setattr(ai.httpx, "Client", ImageOnlyClient) + image_meta = ai._discover_ai_meta("atlas-smart") + + assert image_meta["endpoint"] == "/api/ai/chat" + assert image_meta["model"] == "model-from-image" + + +def test_discover_ai_meta_handles_probe_errors(monkeypatch) -> None: + class MissingPath: + def __init__(self, value): + self.value = value + + def __truediv__(self, child): + return MissingPath(child) + + def exists(self): + return False + + monkeypatch.setattr(ai, "Path", MissingPath) + assert ai._discover_ai_meta("quick")["endpoint"] == "/api/ai/chat" + + class ExistingPath(MissingPath): + def __truediv__(self, child): + return ExistingPath(child) + + def exists(self): + return True + + def read_text(self): + return "token" + + def __str__(self): + return self.value + + class FailingClient: + def __init__(self, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def get(self, url): + raise RuntimeError("offline") + + monkeypatch.setattr(ai, "Path", ExistingPath) + monkeypatch.setattr(ai.httpx, "Client", FailingClient) + assert ai._discover_ai_meta("atlas-genius")["endpoint"] == "/api/ai/chat" + + +def test_start_keep_warm_disabled_and_loop(monkeypatch) -> None: + monkeypatch.setattr(ai.settings, "AI_WARM_ENABLED", False) + ai._start_keep_warm() + + posts: list[dict] = [] + + class WarmClient: + def __init__(self, timeout): + self.timeout = timeout + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def post(self, url, json=None): + posts.append({"url": url, "json": json}) + + sleeps = {"count": 0} + + def fake_sleep(seconds): + sleeps["count"] += 1 + if sleeps["count"] > 1: + raise KeyboardInterrupt() + + class InlineThread: + def __init__(self, target, daemon, name): + self.target = target + + def start(self): + self.target() + + monkeypatch.setattr(ai.settings, "AI_WARM_ENABLED", True) + monkeypatch.setattr(ai.settings, "AI_WARM_INTERVAL_SEC", 1) + monkeypatch.setattr(ai.time, "sleep", fake_sleep) + monkeypatch.setattr(ai.httpx, "Client", WarmClient) + monkeypatch.setattr(ai.threading, "Thread", InlineThread) + + with pytest.raises(KeyboardInterrupt): + ai._start_keep_warm() + + assert posts + + class RaisingWarmClient(WarmClient): + def post(self, url, json=None): + raise RuntimeError("keep-warm backend unavailable") + + loop_sleeps = {"count": 0} + + def stop_after_exception(seconds): + loop_sleeps["count"] += 1 + if loop_sleeps["count"] > 1: + raise KeyboardInterrupt() + + monkeypatch.setattr(ai.time, "sleep", stop_after_exception) + monkeypatch.setattr(ai.httpx, "Client", RaisingWarmClient) + + with pytest.raises(KeyboardInterrupt): + ai._start_keep_warm()