test(bstein-home): cover ai route metadata
This commit is contained in:
parent
e6d807ed3f
commit
2dfbf86a93
@ -1,8 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from types import SimpleNamespace
|
||||||
from unittest import TestCase, mock
|
from unittest import TestCase, mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from atlas_portal.app_factory import create_app
|
from atlas_portal.app_factory import create_app
|
||||||
from atlas_portal.routes import ai
|
from atlas_portal.routes import ai
|
||||||
|
|
||||||
@ -88,21 +91,28 @@ class AiRouteTests(TestCase):
|
|||||||
with (
|
with (
|
||||||
mock.patch.object(ai.httpx, "Client", DummyClient),
|
mock.patch.object(ai.httpx, "Client", DummyClient),
|
||||||
mock.patch.object(ai.settings, "AI_ATLASBOT_ENDPOINT", "http://atlasbot.invalid/v1/answer"),
|
mock.patch.object(ai.settings, "AI_ATLASBOT_ENDPOINT", "http://atlasbot.invalid/v1/answer"),
|
||||||
|
mock.patch.object(ai.settings, "AI_ATLASBOT_TOKEN", "internal-token"),
|
||||||
):
|
):
|
||||||
reply = ai._atlasbot_answer("How is Titan doing?", "genius", "conv-1")
|
reply = ai._atlasbot_answer("How is Titan doing?", "genius", "conv-1")
|
||||||
|
|
||||||
self.assertEqual(reply, "atlas reply")
|
self.assertEqual(reply, "atlas reply")
|
||||||
self.assertEqual(captured["timeout"], ai.settings.AI_ATLASBOT_TIMEOUT_GENIUS_SEC)
|
self.assertEqual(captured["timeout"], ai.settings.AI_ATLASBOT_TIMEOUT_GENIUS_SEC)
|
||||||
self.assertEqual(captured["json"], {"prompt": "How is Titan doing?", "mode": "genius", "conversation_id": "conv-1"})
|
self.assertEqual(captured["json"], {"prompt": "How is Titan doing?", "mode": "genius", "conversation_id": "conv-1"})
|
||||||
|
self.assertEqual(captured["headers"], {"X-Internal-Token": "internal-token"})
|
||||||
|
|
||||||
def test_chat_returns_fallback_when_atlasbot_returns_empty(self):
|
def test_chat_returns_fallback_when_atlasbot_returns_empty(self):
|
||||||
|
for profile, expected in (
|
||||||
|
("atlas-quick", "Quick mode hit"),
|
||||||
|
("atlas-smart", "Smart mode hit"),
|
||||||
|
("atlas-genius", "Atlas genius mode timed out"),
|
||||||
|
):
|
||||||
with mock.patch.object(ai, "_atlasbot_answer", return_value=""):
|
with mock.patch.object(ai, "_atlasbot_answer", return_value=""):
|
||||||
resp = self.client.post(
|
resp = self.client.post(
|
||||||
"/api/chat",
|
"/api/chat",
|
||||||
data=json.dumps(
|
data=json.dumps(
|
||||||
{
|
{
|
||||||
"message": "How is Titan doing?",
|
"message": "How is Titan doing?",
|
||||||
"profile": "atlas-quick",
|
"profile": profile,
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
content_type="application/json",
|
content_type="application/json",
|
||||||
@ -110,5 +120,238 @@ class AiRouteTests(TestCase):
|
|||||||
|
|
||||||
data = resp.get_json()
|
data = resp.get_json()
|
||||||
self.assertEqual(resp.status_code, 200)
|
self.assertEqual(resp.status_code, 200)
|
||||||
self.assertEqual(data.get("source"), "atlas-quick")
|
self.assertIn(expected, data.get("reply", ""))
|
||||||
self.assertIn("Quick mode hit", data.get("reply", ""))
|
|
||||||
|
|
||||||
|
def test_chat_requires_message() -> None:
|
||||||
|
client = create_app().test_client()
|
||||||
|
|
||||||
|
response = client.post("/api/chat", data=json.dumps({"message": ""}), content_type="application/json")
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert response.get_json()["error"] == "message required"
|
||||||
|
|
||||||
|
|
||||||
|
def test_atlasbot_answer_soft_failure_paths(monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(ai.settings, "AI_ATLASBOT_ENDPOINT", "")
|
||||||
|
assert ai._atlasbot_answer("hello", "quick", "") == ""
|
||||||
|
|
||||||
|
class NonOkClient:
|
||||||
|
def __init__(self, timeout):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def post(self, endpoint, json=None, headers=None):
|
||||||
|
return SimpleNamespace(status_code=503)
|
||||||
|
|
||||||
|
monkeypatch.setattr(ai.settings, "AI_ATLASBOT_ENDPOINT", "http://atlasbot")
|
||||||
|
monkeypatch.setattr(ai.httpx, "Client", NonOkClient)
|
||||||
|
assert ai._atlasbot_answer("hello", "smart", "") == ""
|
||||||
|
|
||||||
|
class BadJsonClient(NonOkClient):
|
||||||
|
def post(self, endpoint, json=None, headers=None):
|
||||||
|
return SimpleNamespace(status_code=200, json=lambda: (_ for _ in ()).throw(ValueError("bad")))
|
||||||
|
|
||||||
|
monkeypatch.setattr(ai.httpx, "Client", BadJsonClient)
|
||||||
|
assert ai._atlasbot_answer("hello", "quick", "") == ""
|
||||||
|
assert ai._atlasbot_timeout_sec("smart") == ai.settings.AI_ATLASBOT_TIMEOUT_SMART_SEC
|
||||||
|
assert ai._atlasbot_timeout_sec("quick") == ai.settings.AI_ATLASBOT_TIMEOUT_QUICK_SEC
|
||||||
|
|
||||||
|
|
||||||
|
def test_discover_ai_meta_reads_pod_annotations(monkeypatch) -> None:
|
||||||
|
class FakePath:
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __truediv__(self, child):
|
||||||
|
return FakePath(child)
|
||||||
|
|
||||||
|
def exists(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def read_text(self):
|
||||||
|
if self.value == "token":
|
||||||
|
return "token"
|
||||||
|
if self.value == "namespace":
|
||||||
|
return "ai"
|
||||||
|
return "ca"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
class PodClient:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get(self, url):
|
||||||
|
return SimpleNamespace(
|
||||||
|
raise_for_status=lambda: None,
|
||||||
|
json=lambda: {
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"status": {"phase": "Pending"},
|
||||||
|
"spec": {
|
||||||
|
"nodeName": "titan-24",
|
||||||
|
"containers": [{"image": "registry/atlasbot:model-from-image"}],
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"annotations": {
|
||||||
|
ai.settings.AI_GPU_ANNOTATION: "RTX 3090",
|
||||||
|
ai.settings.AI_MODEL_ANNOTATION: "annotated-model",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(ai, "Path", FakePath)
|
||||||
|
monkeypatch.setattr(ai.httpx, "Client", PodClient)
|
||||||
|
|
||||||
|
meta = ai._discover_ai_meta("atlas-quick")
|
||||||
|
|
||||||
|
assert meta["node"] == "titan-24"
|
||||||
|
assert meta["gpu"] == "RTX 3090"
|
||||||
|
assert meta["model"] == "annotated-model"
|
||||||
|
|
||||||
|
class ImageOnlyClient(PodClient):
|
||||||
|
def get(self, url):
|
||||||
|
return SimpleNamespace(
|
||||||
|
raise_for_status=lambda: None,
|
||||||
|
json=lambda: {
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"status": {"phase": "Running"},
|
||||||
|
"spec": {
|
||||||
|
"nodeName": "titan-22",
|
||||||
|
"containers": [{"image": "registry/atlasbot:model-from-image"}],
|
||||||
|
},
|
||||||
|
"metadata": {"annotations": {}},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(ai.httpx, "Client", ImageOnlyClient)
|
||||||
|
image_meta = ai._discover_ai_meta("atlas-smart")
|
||||||
|
|
||||||
|
assert image_meta["endpoint"] == "/api/ai/chat"
|
||||||
|
assert image_meta["model"] == "model-from-image"
|
||||||
|
|
||||||
|
|
||||||
|
def test_discover_ai_meta_handles_probe_errors(monkeypatch) -> None:
|
||||||
|
class MissingPath:
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __truediv__(self, child):
|
||||||
|
return MissingPath(child)
|
||||||
|
|
||||||
|
def exists(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
monkeypatch.setattr(ai, "Path", MissingPath)
|
||||||
|
assert ai._discover_ai_meta("quick")["endpoint"] == "/api/ai/chat"
|
||||||
|
|
||||||
|
class ExistingPath(MissingPath):
|
||||||
|
def __truediv__(self, child):
|
||||||
|
return ExistingPath(child)
|
||||||
|
|
||||||
|
def exists(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def read_text(self):
|
||||||
|
return "token"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
class FailingClient:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get(self, url):
|
||||||
|
raise RuntimeError("offline")
|
||||||
|
|
||||||
|
monkeypatch.setattr(ai, "Path", ExistingPath)
|
||||||
|
monkeypatch.setattr(ai.httpx, "Client", FailingClient)
|
||||||
|
assert ai._discover_ai_meta("atlas-genius")["endpoint"] == "/api/ai/chat"
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_keep_warm_disabled_and_loop(monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(ai.settings, "AI_WARM_ENABLED", False)
|
||||||
|
ai._start_keep_warm()
|
||||||
|
|
||||||
|
posts: list[dict] = []
|
||||||
|
|
||||||
|
class WarmClient:
|
||||||
|
def __init__(self, timeout):
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def post(self, url, json=None):
|
||||||
|
posts.append({"url": url, "json": json})
|
||||||
|
|
||||||
|
sleeps = {"count": 0}
|
||||||
|
|
||||||
|
def fake_sleep(seconds):
|
||||||
|
sleeps["count"] += 1
|
||||||
|
if sleeps["count"] > 1:
|
||||||
|
raise KeyboardInterrupt()
|
||||||
|
|
||||||
|
class InlineThread:
|
||||||
|
def __init__(self, target, daemon, name):
|
||||||
|
self.target = target
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self.target()
|
||||||
|
|
||||||
|
monkeypatch.setattr(ai.settings, "AI_WARM_ENABLED", True)
|
||||||
|
monkeypatch.setattr(ai.settings, "AI_WARM_INTERVAL_SEC", 1)
|
||||||
|
monkeypatch.setattr(ai.time, "sleep", fake_sleep)
|
||||||
|
monkeypatch.setattr(ai.httpx, "Client", WarmClient)
|
||||||
|
monkeypatch.setattr(ai.threading, "Thread", InlineThread)
|
||||||
|
|
||||||
|
with pytest.raises(KeyboardInterrupt):
|
||||||
|
ai._start_keep_warm()
|
||||||
|
|
||||||
|
assert posts
|
||||||
|
|
||||||
|
class RaisingWarmClient(WarmClient):
|
||||||
|
def post(self, url, json=None):
|
||||||
|
raise RuntimeError("keep-warm backend unavailable")
|
||||||
|
|
||||||
|
loop_sleeps = {"count": 0}
|
||||||
|
|
||||||
|
def stop_after_exception(seconds):
|
||||||
|
loop_sleeps["count"] += 1
|
||||||
|
if loop_sleeps["count"] > 1:
|
||||||
|
raise KeyboardInterrupt()
|
||||||
|
|
||||||
|
monkeypatch.setattr(ai.time, "sleep", stop_after_exception)
|
||||||
|
monkeypatch.setattr(ai.httpx, "Client", RaisingWarmClient)
|
||||||
|
|
||||||
|
with pytest.raises(KeyboardInterrupt):
|
||||||
|
ai._start_keep_warm()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user