atlasbot/atlasbot/llm/client.py

79 lines
2.8 KiB
Python
Raw Normal View History

2026-01-28 11:46:52 -03:00
import json
import logging
from typing import Any
import httpx
from atlasbot.config import Settings
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# HTTP status code on which LLMClient.chat switches the request over to the
# configured fallback model (when one is set).
FALLBACK_STATUS_CODE = 404
class LLMError(RuntimeError):
    """Raised when a call to the LLM backend fails or yields no usable reply."""
class LLMClient:
    """Async client for an Ollama-style ``/api/chat`` HTTP endpoint.

    The URL, model names, timeout, retry count and optional API key are all
    read from the ``Settings`` object handed to the constructor.
    """

    def __init__(self, settings: Settings) -> None:
        """Keep *settings* and precompute the request timeout and headers."""
        self._settings = settings
        self._timeout = settings.ollama_timeout_sec
        self._headers = {"Content-Type": "application/json"}
        if settings.ollama_api_key:
            # Only send the key header when a key is actually configured.
            self._headers["x-api-key"] = settings.ollama_api_key

    def _endpoint(self) -> str:
        """Return the chat URL, appending ``/api/chat`` unless already present."""
        base = self._settings.ollama_url.rstrip("/")
        if base.endswith("/api/chat"):
            return base
        return base + "/api/chat"

    async def _post(self, payload: dict[str, Any]) -> httpx.Response:
        """POST *payload* as JSON to the chat endpoint and return the response."""
        async with httpx.AsyncClient(timeout=self._timeout) as client:
            return await client.post(self._endpoint(), json=payload, headers=self._headers)

    @staticmethod
    def _extract_content(data: Any) -> str:
        """Pull the reply text out of a decoded JSON response body.

        Prefers ``data["message"]["content"]``; otherwise falls back to
        ``data["response"]``, ``data["reply"]`` and finally the whole body.

        Raises:
            LLMError: if the selected content is empty or missing.
        """
        content: Any = data
        if isinstance(data, dict):
            message = data.get("message")
            if isinstance(message, dict):
                content = message.get("content")
            else:
                content = data.get("response") or data.get("reply") or data
        # A non-dict body (e.g. a bare JSON string) is used as-is instead of
        # failing with AttributeError on ``.get``.
        if not content:
            raise LLMError("empty response")
        return str(content)

    async def chat(self, messages: list[dict[str, str]], *, model: str | None = None) -> str:
        """Send *messages* to the chat endpoint and return the reply text.

        Args:
            messages: Chat messages, each a dict of string keys and values.
            model: Model name to use; defaults to ``settings.ollama_model``.

        Returns:
            The reply content converted to ``str``.

        Raises:
            LLMError: when every attempt failed; the last underlying error is
                chained as ``__cause__``.
        """
        payload: dict[str, Any] = {
            "model": model or self._settings.ollama_model,
            "messages": messages,
            "stream": False,
        }
        fallback_model = self._settings.ollama_fallback_model
        # Always make at least one attempt, even if retries is zero or negative.
        attempts = max(1, self._settings.ollama_retries + 1)
        last_error: Exception | None = None

        for attempt in range(1, attempts + 1):
            try:
                resp = await self._post(payload)
                if (
                    resp.status_code == FALLBACK_STATUS_CODE
                    and fallback_model
                    and payload["model"] != fallback_model
                ):
                    # Switch to the fallback model and retry right away, so the
                    # switch does not use up a retry attempt (otherwise the
                    # fallback is never tried when retries == 0).
                    payload["model"] = fallback_model
                    resp = await self._post(payload)
                resp.raise_for_status()
                return self._extract_content(resp.json())
            except Exception as exc:
                # Deliberately broad: transport errors, HTTP status errors,
                # undecodable bodies and empty replies are all retried.
                log.warning("ollama call failed", extra={"extra": {"attempt": attempt, "error": str(exc)}})
                last_error = exc

        raise LLMError(str(last_error)) from last_error
def build_messages(system: str, prompt: str, *, context: str | None = None) -> list[dict[str, str]]:
system_content = system
2026-01-28 11:46:52 -03:00
if context:
system_content = system_content + "\n\nContext (grounded facts):\n" + context
messages: list[dict[str, str]] = [{"role": "system", "content": system_content}]
2026-01-28 11:46:52 -03:00
messages.append({"role": "user", "content": prompt})
return messages
def parse_json(text: str, *, fallback: dict[str, Any] | None = None) -> dict[str, Any]:
try:
raw = text.strip()
if raw.startswith("`"):
raw = raw.strip("`")
return json.loads(raw)
except Exception:
return fallback or {}