from __future__ import annotations
|
|
|
|
import json
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from flask import jsonify, request
|
|
import httpx
|
|
|
|
from .. import settings
|
|
|
|
|
|
def register(app) -> None:
    """Register the AI chat endpoints on *app* and start the keep-warm loop.

    Routes registered:
      POST /api/chat, /api/ai/chat  -> ai_chat
      GET  /api/chat/info, /api/ai/info -> ai_info
    """

    @app.route("/api/chat", methods=["POST"])
    @app.route("/api/ai/chat", methods=["POST"])
    def ai_chat() -> Any:
        """Answer a chat message.

        Tries the internal AtlasBot endpoint first; on no answer, forwards the
        conversation to the configured chat-completion API. Returns JSON with
        ``reply`` and ``latency_ms`` (plus ``source`` on the AtlasBot path),
        400 when ``message`` is missing, 502 on upstream failure.
        """
        # BUG FIX: `started` was previously assigned only after the AtlasBot
        # fast-path that reads it, causing UnboundLocalError whenever AtlasBot
        # returned a reply. Start the clock at request entry instead.
        started = time.time()

        payload = request.get_json(silent=True) or {}
        user_message = (payload.get("message") or "").strip()
        history = payload.get("history") or []

        if not user_message:
            return jsonify({"error": "message required"}), 400

        # Prefer the internal AtlasBot answer when it produces one.
        atlasbot_reply = _atlasbot_answer(user_message)
        if atlasbot_reply:
            elapsed_ms = int((time.time() - started) * 1000)
            return jsonify({"reply": atlasbot_reply, "latency_ms": elapsed_ms, "source": "atlasbot"})

        messages: list[dict[str, str]] = []
        if settings.AI_CHAT_SYSTEM_PROMPT:
            messages.append({"role": "system", "content": settings.AI_CHAT_SYSTEM_PROMPT})

        # Replay prior turns, keeping only well-formed user/assistant entries.
        # isinstance guard: history comes from untrusted JSON and may contain
        # non-dict items, which previously raised AttributeError (HTTP 500).
        for item in history:
            if not isinstance(item, dict):
                continue
            role = item.get("role")
            content = (item.get("content") or "").strip()
            if role in ("user", "assistant") and content:
                messages.append({"role": role, "content": content})

        messages.append({"role": "user", "content": user_message})

        body = {"model": settings.AI_CHAT_MODEL, "messages": messages, "stream": False}

        try:
            with httpx.Client(timeout=settings.AI_CHAT_TIMEOUT_SEC) as client:
                resp = client.post(f"{settings.AI_CHAT_API}/api/chat", json=body)
                resp.raise_for_status()
                data = resp.json()
            reply = (data.get("message") or {}).get("content") or ""
            elapsed_ms = int((time.time() - started) * 1000)
            return jsonify({"reply": reply, "latency_ms": elapsed_ms})
        except (httpx.RequestError, httpx.HTTPStatusError, ValueError) as exc:
            # RequestError: network/timeout; HTTPStatusError: non-2xx;
            # ValueError: unparseable JSON body.
            return jsonify({"error": str(exc)}), 502

    @app.route("/api/chat/info", methods=["GET"])
    @app.route("/api/ai/info", methods=["GET"])
    def ai_info() -> Any:
        """Return metadata (node, gpu, model, endpoint) about the serving backend."""
        meta = _discover_ai_meta()
        return jsonify(meta)

    _start_keep_warm()
|
|
|
|
|
|
def _atlasbot_answer(message: str) -> str:
    """Ask the internal AtlasBot endpoint for an answer to *message*.

    Best-effort: returns "" when the endpoint is unconfigured, responds with a
    non-200 status, returns unparseable JSON, or the request fails.
    """
    target = settings.AI_ATLASBOT_ENDPOINT
    if not target:
        return ""

    token = settings.AI_ATLASBOT_TOKEN
    request_headers: dict[str, str] = {"X-Internal-Token": token} if token else {}

    try:
        with httpx.Client(timeout=settings.AI_ATLASBOT_TIMEOUT_SEC) as client:
            response = client.post(target, json={"prompt": message}, headers=request_headers)
            if response.status_code != 200:
                return ""
            document = response.json()
            return (document.get("answer") or "").strip()
    except (httpx.RequestError, ValueError):
        # Network failure or malformed JSON — treat as "no answer".
        return ""
|
|
|
|
|
|
def _discover_ai_meta() -> dict[str, str]:
    """Describe the AI backend: node name, GPU, model, and public endpoint.

    Starts from static settings. When running in-cluster (the serviceaccount
    token/ca/namespace files are all present), queries the Kubernetes API for
    pods matching ``settings.AI_K8S_LABEL`` and overrides node/gpu/model from
    the first matching pod. Any failure during discovery is swallowed and the
    static values are returned unchanged.
    """
    info = {
        "node": settings.AI_NODE_NAME,
        "gpu": settings.AI_GPU_DESC,
        "model": settings.AI_CHAT_MODEL,
        "endpoint": settings.AI_PUBLIC_ENDPOINT or "/api/chat",
    }

    secrets_dir = Path("/var/run/secrets/kubernetes.io/serviceaccount")
    token_file = secrets_dir / "token"
    ca_file = secrets_dir / "ca.crt"
    namespace_file = secrets_dir / "namespace"
    # Not in a cluster (or incomplete serviceaccount mount): static metadata only.
    if not (token_file.exists() and ca_file.exists() and namespace_file.exists()):
        return info

    try:
        bearer = token_file.read_text().strip()
        # NOTE: the namespace comes from settings; the mounted namespace file is
        # only used as an in-cluster presence check above.
        pods_url = (
            "https://kubernetes.default.svc/api/v1/namespaces/"
            f"{settings.AI_K8S_NAMESPACE}/pods?labelSelector={settings.AI_K8S_LABEL}"
        )

        with httpx.Client(
            verify=str(ca_file),
            timeout=settings.HTTP_CHECK_TIMEOUT_SEC,
            headers={"Authorization": f"Bearer {bearer}"},
        ) as client:
            response = client.get(pods_url)
            response.raise_for_status()
            pods = response.json().get("items") or []

        # Prefer Running pods, but fall back to any pod that matched the label.
        candidates = [p for p in pods if p.get("status", {}).get("phase") == "Running"] or pods
        if candidates:
            pod = candidates[0]
            info["node"] = pod.get("spec", {}).get("nodeName") or info["node"]

            annotations = pod.get("metadata", {}).get("annotations") or {}
            # GPU description: configured annotation first, then legacy keys.
            gpu_hint = (
                annotations.get(settings.AI_GPU_ANNOTATION)
                or annotations.get("ai.gpu/description")
                or annotations.get("gpu/description")
            )
            if gpu_hint:
                info["gpu"] = gpu_hint

            # Model: annotation, else derive from the first container's image tag.
            model_hint = annotations.get(settings.AI_MODEL_ANNOTATION)
            if not model_hint:
                containers = pod.get("spec", {}).get("containers") or []
                if containers:
                    image = containers[0].get("image") or ""
                    model_hint = image.split(":")[-1] if ":" in image else image
            if model_hint:
                info["model"] = model_hint
    except Exception:
        # Best-effort discovery: on any failure keep the static metadata.
        pass

    return info
|
|
|
|
|
|
def _start_keep_warm() -> None:
    """Spawn a daemon thread that periodically pings the chat model.

    Sends a one-message "ping" completion request every
    ``settings.AI_WARM_INTERVAL_SEC`` seconds so the model stays loaded.
    Disabled when ``AI_WARM_ENABLED`` is false or the interval is <= 0.
    """
    if not settings.AI_WARM_ENABLED or settings.AI_WARM_INTERVAL_SEC <= 0:
        return

    def _warm_forever() -> None:
        while True:
            time.sleep(settings.AI_WARM_INTERVAL_SEC)
            try:
                # Rebuilt each tick so runtime settings changes take effect.
                ping = {
                    "model": settings.AI_CHAT_MODEL,
                    "messages": [{"role": "user", "content": "ping"}],
                    "stream": False,
                }
                # Cap the warm-up timeout at 15s regardless of the chat timeout.
                with httpx.Client(timeout=min(settings.AI_CHAT_TIMEOUT_SEC, 15)) as client:
                    client.post(f"{settings.AI_CHAT_API}/api/chat", json=ping)
            except Exception:
                # Best-effort: ignore transient failures and retry next tick.
                pass

    threading.Thread(target=_warm_forever, daemon=True, name="ai-keep-warm").start()
|