79 lines
2.8 KiB
Python
79 lines
2.8 KiB
Python
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from atlasbot.config import Settings
|
|
|
|
# HTTP status from the chat endpoint that makes LLMClient.chat retry the
# request with settings.ollama_fallback_model (when one is configured).
FALLBACK_STATUS_CODE: int = 404

# Module-level logger; LLMClient.chat reports failed attempts through it.
log: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LLMError(RuntimeError):
    """Raised by ``LLMClient.chat`` when a request fails or yields no content."""
|
|
|
|
|
|
class LLMClient:
    """Async client for an Ollama-style ``/api/chat`` endpoint.

    URL, model, timeout, retry count, optional API key and optional fallback
    model are all read from the ``Settings`` object given at construction.
    """

    def __init__(self, settings: Settings) -> None:
        """Store *settings* and precompute the request timeout and headers.

        An ``x-api-key`` header is sent only when ``settings.ollama_api_key``
        is set.
        """
        self._settings = settings
        self._timeout = settings.ollama_timeout_sec
        self._headers = {"Content-Type": "application/json"}
        if settings.ollama_api_key:
            self._headers["x-api-key"] = settings.ollama_api_key

    def _endpoint(self) -> str:
        """Return the chat URL, appending ``/api/chat`` unless already present."""
        base = self._settings.ollama_url.rstrip("/")
        if base.endswith("/api/chat"):
            return base
        return base + "/api/chat"

    @staticmethod
    def _extract_content(data: Any) -> str:
        """Pull the reply text out of a decoded response body.

        Prefers ``data["message"]["content"]``, then ``data["response"]`` /
        ``data["reply"]``, and finally falls back to the whole decoded body.
        A body that is not a dict is used as-is.

        Raises:
            LLMError: if the extracted content is empty/falsy.
        """
        if isinstance(data, dict):
            message = data.get("message")
            if isinstance(message, dict):
                content = message.get("content")
            else:
                content = data.get("response") or data.get("reply") or data
        else:
            # Non-dict bodies (e.g. a bare JSON string) have no ``.get``; use
            # them directly, mirroring the ``or data`` fallback above.
            content = data
        if not content:
            raise LLMError("empty response")
        return str(content)

    async def chat(self, messages: list[dict[str, str]], *, model: str | None = None) -> str:
        """Send *messages* to the chat endpoint and return the reply text.

        Args:
            messages: Chat messages, e.g. as produced by ``build_messages``.
            model: Model name; defaults to ``settings.ollama_model``.

        Returns:
            The reply content as a string.

        Raises:
            LLMError: when every attempt fails (the last error is chained).

        Any exception during an attempt (HTTP error, undecodable body, empty
        reply) is logged and retried up to ``settings.ollama_retries`` extra
        times. A ``FALLBACK_STATUS_CODE`` response switches the request to
        ``settings.ollama_fallback_model`` once, without consuming a retry,
        so the fallback is tried even when no retries are configured.
        """
        payload: dict[str, Any] = {
            "model": model or self._settings.ollama_model,
            "messages": messages,
            "stream": False,
        }
        attempts = max(1, self._settings.ollama_retries + 1)
        fallback_model = self._settings.ollama_fallback_model
        attempt = 0
        while attempt < attempts:
            try:
                async with httpx.AsyncClient(timeout=self._timeout) as client:
                    resp = await client.post(self._endpoint(), json=payload, headers=self._headers)
                    # Switch to the fallback at most once: after switching,
                    # payload["model"] equals the fallback, so a second 404
                    # falls through to raise_for_status and is reported.
                    if (
                        resp.status_code == FALLBACK_STATUS_CODE
                        and fallback_model
                        and payload["model"] != fallback_model
                    ):
                        log.info(
                            "ollama model unavailable; switching to fallback",
                            extra={"extra": {"model": payload["model"], "fallback": fallback_model}},
                        )
                        payload["model"] = fallback_model
                        continue
                    resp.raise_for_status()
                    return self._extract_content(resp.json())
            except Exception as exc:
                attempt += 1
                log.warning("ollama call failed", extra={"extra": {"attempt": attempt, "error": str(exc)}})
                if attempt >= attempts:
                    raise LLMError(str(exc)) from exc
        # Defensive: the loop above always returns or raises.
        raise LLMError("ollama retries exhausted")
|
|
|
|
|
|
def build_messages(system: str, prompt: str, *, context: str | None = None) -> list[dict[str, str]]:
|
|
system_content = system
|
|
if context:
|
|
system_content = system_content + "\n\nContext (grounded facts):\n" + context
|
|
messages: list[dict[str, str]] = [{"role": "system", "content": system_content}]
|
|
messages.append({"role": "user", "content": prompt})
|
|
return messages
|
|
|
|
|
|
def parse_json(text: str, *, fallback: dict[str, Any] | None = None) -> dict[str, Any]:
|
|
try:
|
|
raw = text.strip()
|
|
if raw.startswith("`"):
|
|
raw = raw.strip("`")
|
|
return json.loads(raw)
|
|
except Exception:
|
|
return fallback or {}
|