# 358 lines
# 12 KiB
# Python

from __future__ import annotations
import json
from types import SimpleNamespace
from unittest import TestCase, mock
import pytest
from atlas_portal.app_factory import create_app
from atlas_portal.routes import ai
class AiRouteTests(TestCase):
    """Tests for the ``/api/chat`` and ``/api/ai/info`` routes and ``_atlasbot_answer``.

    One application and one test client are created per class and shared by
    all test methods.
    """

    @classmethod
    def setUpClass(cls):
        """Create the application and its test client once for the whole class."""
        cls.app = create_app()
        cls.client = cls.app.test_client()

    def test_chat_routes_profiles_to_modes(self):
        """Each ``atlas-*`` profile reaches ``_atlasbot_answer`` under its short mode name."""
        seen: list[tuple[str, str]] = []

        def fake_atlasbot_answer(message: str, mode: str, conversation_id: str) -> str:
            # Record what the route forwarded and echo it back as the reply.
            seen.append((mode, conversation_id))
            return f"{mode}:{conversation_id}"

        cases = (
            ("atlas-quick", "quick"),
            ("atlas-smart", "smart"),
            ("atlas-genius", "genius"),
        )
        with mock.patch.object(ai, "_atlasbot_answer", side_effect=fake_atlasbot_answer):
            for profile, expected_mode in cases:
                # subTest keeps one failing profile from hiding the remaining ones.
                with self.subTest(profile=profile):
                    resp = self.client.post(
                        "/api/chat",
                        data=json.dumps(
                            {
                                "message": "How is Titan doing?",
                                "profile": profile,
                                "conversation_id": f"conv-{profile}",
                            }
                        ),
                        content_type="application/json",
                    )
                    data = resp.get_json()
                    self.assertEqual(resp.status_code, 200)
                    self.assertEqual(data.get("source"), f"atlas-{expected_mode}")
                    self.assertEqual(data.get("reply"), f"{expected_mode}:conv-{profile}")

        # The stub must have been called once per profile, in request order.
        self.assertEqual(
            seen,
            [
                ("quick", "conv-atlas-quick"),
                ("smart", "conv-atlas-smart"),
                ("genius", "conv-atlas-genius"),
            ],
        )

    def test_info_endpoint_exposes_profile_specific_model(self):
        """``/api/ai/info`` reports the model configured for the requested profile."""
        with mock.patch.object(ai.settings, "AI_ATLASBOT_MODEL_GENIUS", "genius-model"):
            resp = self.client.get("/api/ai/info?profile=atlas-genius")
            data = resp.get_json()

        self.assertEqual(resp.status_code, 200)
        self.assertEqual(data.get("profile"), "atlas-genius")
        self.assertEqual(data.get("model"), "genius-model")

    def test_atlasbot_answer_uses_profile_specific_timeout(self):
        """``_atlasbot_answer`` uses the genius timeout, payload and token header."""
        captured: dict[str, object] = {}

        class DummyResponse:
            status_code = 200

            def json(self):
                return {"reply": "atlas reply"}

        class DummyClient:
            """Stand-in for ``httpx.Client`` that records how it was used."""

            def __init__(self, timeout):
                captured["timeout"] = timeout

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                return False

            def post(self, endpoint, json=None, headers=None):
                captured["endpoint"] = endpoint
                captured["json"] = json
                captured["headers"] = headers or {}
                return DummyResponse()

        with (
            mock.patch.object(ai.httpx, "Client", DummyClient),
            mock.patch.object(ai.settings, "AI_ATLASBOT_ENDPOINT", "http://atlasbot.invalid/v1/answer"),
            mock.patch.object(ai.settings, "AI_ATLASBOT_TOKEN", "internal-token"),
        ):
            reply = ai._atlasbot_answer("How is Titan doing?", "genius", "conv-1")

        self.assertEqual(reply, "atlas reply")
        self.assertEqual(captured["timeout"], ai.settings.AI_ATLASBOT_TIMEOUT_GENIUS_SEC)
        self.assertEqual(
            captured["json"],
            {"prompt": "How is Titan doing?", "mode": "genius", "conversation_id": "conv-1"},
        )
        self.assertEqual(captured["headers"], {"X-Internal-Token": "internal-token"})

    def test_chat_returns_fallback_when_atlasbot_returns_empty(self):
        """An empty Atlasbot answer yields a profile-specific fallback reply with HTTP 200."""
        cases = (
            ("atlas-quick", "Quick mode hit"),
            ("atlas-smart", "Smart mode hit"),
            ("atlas-genius", "Atlas genius mode timed out"),
        )
        for profile, expected in cases:
            # subTest keeps one failing profile from hiding the remaining ones.
            with self.subTest(profile=profile):
                with mock.patch.object(ai, "_atlasbot_answer", return_value=""):
                    resp = self.client.post(
                        "/api/chat",
                        data=json.dumps(
                            {
                                "message": "How is Titan doing?",
                                "profile": profile,
                            }
                        ),
                        content_type="application/json",
                    )
                data = resp.get_json()
                self.assertEqual(resp.status_code, 200)
                self.assertIn(expected, data.get("reply", ""))
def test_chat_requires_message() -> None:
    """Posting an empty ``message`` to ``/api/chat`` is rejected with HTTP 400."""
    payload = json.dumps({"message": ""})
    http = create_app().test_client()

    response = http.post(
        "/api/chat",
        data=payload,
        content_type="application/json",
    )

    assert response.status_code == 400
    body = response.get_json()
    assert body["error"] == "message required"
def test_atlasbot_answer_soft_failure_paths(monkeypatch) -> None:
    """``_atlasbot_answer`` returns ``""`` for each soft failure instead of raising.

    Covered cases: no endpoint configured, a non-200 response, and a 200
    response whose ``json()`` raises ``ValueError``. The timeout helper is
    also checked for the ``smart`` and ``quick`` modes.
    """

    def raise_bad_json():
        raise ValueError("bad")

    class StubClient:
        """Context-manager shell shared by the client stand-ins below."""

        def __init__(self, timeout):
            pass

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

    class NonOkClient(StubClient):
        def post(self, endpoint, json=None, headers=None):
            return SimpleNamespace(status_code=503)

    class BadJsonClient(StubClient):
        def post(self, endpoint, json=None, headers=None):
            return SimpleNamespace(status_code=200, json=raise_bad_json)

    settings = ai.settings

    # Case 1: no endpoint configured.
    monkeypatch.setattr(settings, "AI_ATLASBOT_ENDPOINT", "")
    assert ai._atlasbot_answer("hello", "quick", "") == ""

    # Case 2: backend answers with a non-200 status.
    monkeypatch.setattr(settings, "AI_ATLASBOT_ENDPOINT", "http://atlasbot")
    monkeypatch.setattr(ai.httpx, "Client", NonOkClient)
    assert ai._atlasbot_answer("hello", "smart", "") == ""

    # Case 3: backend answers 200 but the body cannot be decoded.
    monkeypatch.setattr(ai.httpx, "Client", BadJsonClient)
    assert ai._atlasbot_answer("hello", "quick", "") == ""

    # Timeout lookup per mode.
    assert ai._atlasbot_timeout_sec("smart") == settings.AI_ATLASBOT_TIMEOUT_SMART_SEC
    assert ai._atlasbot_timeout_sec("quick") == settings.AI_ATLASBOT_TIMEOUT_QUICK_SEC
def test_discover_ai_meta_reads_pod_annotations(monkeypatch) -> None:
    """``_discover_ai_meta`` reports node, GPU and model from a stubbed pod listing.

    First a ``Pending`` pod carrying the GPU and model annotations is served
    and the annotated values are expected back. Then a ``Running`` pod with
    no annotations is served, and the model is expected to be
    ``model-from-image``, matching the tag of the container image.
    """
    file_contents = {"token": "token", "namespace": "ai"}
    container_image = "registry/atlasbot:model-from-image"

    class FakePath:
        """Replacement for ``ai.Path`` whose files always exist."""

        def __init__(self, value):
            self.value = value

        def __truediv__(self, child):
            return FakePath(child)

        def exists(self):
            return True

        def read_text(self):
            # Anything other than the token/namespace files reads as "ca".
            return file_contents.get(self.value, "ca")

        def __str__(self):
            return self.value

    def pod(phase, node_name, annotations):
        return {
            "status": {"phase": phase},
            "spec": {
                "nodeName": node_name,
                "containers": [{"image": container_image}],
            },
            "metadata": {"annotations": annotations},
        }

    def listing(build_pod):
        # The pod dict is rebuilt on every json() call, as a fresh payload.
        return SimpleNamespace(
            raise_for_status=lambda: None,
            json=lambda: {"items": [build_pod()]},
        )

    class PodClient:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def get(self, url):
            return listing(
                lambda: pod(
                    "Pending",
                    "titan-24",
                    {
                        ai.settings.AI_GPU_ANNOTATION: "RTX 3090",
                        ai.settings.AI_MODEL_ANNOTATION: "annotated-model",
                    },
                )
            )

    class ImageOnlyClient(PodClient):
        def get(self, url):
            return listing(lambda: pod("Running", "titan-22", {}))

    monkeypatch.setattr(ai, "Path", FakePath)
    monkeypatch.setattr(ai.httpx, "Client", PodClient)

    annotated = ai._discover_ai_meta("atlas-quick")
    assert annotated["node"] == "titan-24"
    assert annotated["gpu"] == "RTX 3090"
    assert annotated["model"] == "annotated-model"

    monkeypatch.setattr(ai.httpx, "Client", ImageOnlyClient)

    from_image = ai._discover_ai_meta("atlas-smart")
    assert from_image["endpoint"] == "/api/ai/chat"
    assert from_image["model"] == "model-from-image"
def test_discover_ai_meta_handles_probe_errors(monkeypatch) -> None:
    """``_discover_ai_meta`` still returns the default endpoint when probing fails.

    Two failures are simulated: the files reached through ``ai.Path`` do not
    exist, and the files exist but the HTTP client raises on ``get``.
    """
    default_endpoint = "/api/ai/chat"

    class MissingPath:
        """Replacement for ``ai.Path`` whose files never exist."""

        def __init__(self, value):
            self.value = value

        def __truediv__(self, child):
            # type(self) makes joined paths keep the (sub)class they came from.
            return type(self)(child)

        def exists(self):
            return False

    class ExistingPath(MissingPath):
        """Replacement for ``ai.Path`` whose files exist and all read as "token"."""

        def exists(self):
            return True

        def read_text(self):
            return "token"

        def __str__(self):
            return self.value

    class FailingClient:
        def __init__(self, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def get(self, url):
            raise RuntimeError("offline")

    # Failure 1: nothing exists on disk.
    monkeypatch.setattr(ai, "Path", MissingPath)
    missing_meta = ai._discover_ai_meta("quick")
    assert missing_meta["endpoint"] == default_endpoint

    # Failure 2: files exist but the request raises.
    monkeypatch.setattr(ai, "Path", ExistingPath)
    monkeypatch.setattr(ai.httpx, "Client", FailingClient)
    offline_meta = ai._discover_ai_meta("atlas-genius")
    assert offline_meta["endpoint"] == default_endpoint
def test_start_keep_warm_disabled_and_loop(monkeypatch) -> None:
    """Drive ``_start_keep_warm`` disabled, then enabled with good and failing clients.

    The thread class is replaced so the target runs inline in the caller, and
    ``time.sleep`` is replaced by a stub that raises ``KeyboardInterrupt`` on
    its second call, which is how each enabled run is stopped.
    """
    recorded_posts: list[dict] = []

    def interrupt_on_second_sleep():
        # Each returned stub owns its own call counter.
        calls = 0

        def _sleep(seconds):
            nonlocal calls
            calls += 1
            if calls > 1:
                raise KeyboardInterrupt()

        return _sleep

    class WarmClient:
        def __init__(self, timeout):
            self.timeout = timeout

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def post(self, url, json=None):
            recorded_posts.append({"url": url, "json": json})

    class RaisingWarmClient(WarmClient):
        def post(self, url, json=None):
            raise RuntimeError("keep-warm backend unavailable")

    class InlineThread:
        """Thread stand-in that runs its target synchronously on ``start``."""

        def __init__(self, target, daemon, name):
            self.target = target

        def start(self):
            self.target()

    # Disabled: the call must simply return.
    monkeypatch.setattr(ai.settings, "AI_WARM_ENABLED", False)
    ai._start_keep_warm()

    # Enabled with a working client: at least one warm-up post is recorded.
    monkeypatch.setattr(ai.settings, "AI_WARM_ENABLED", True)
    monkeypatch.setattr(ai.settings, "AI_WARM_INTERVAL_SEC", 1)
    monkeypatch.setattr(ai.time, "sleep", interrupt_on_second_sleep())
    monkeypatch.setattr(ai.httpx, "Client", WarmClient)
    monkeypatch.setattr(ai.threading, "Thread", InlineThread)
    with pytest.raises(KeyboardInterrupt):
        ai._start_keep_warm()
    assert recorded_posts

    # Enabled with a client whose post raises: only the sleep stub's
    # KeyboardInterrupt is expected to escape.
    monkeypatch.setattr(ai.time, "sleep", interrupt_on_second_sleep())
    monkeypatch.setattr(ai.httpx, "Client", RaisingWarmClient)
    with pytest.raises(KeyboardInterrupt):
        ai._start_keep_warm()