atlasbot/atlasbot/engine/answerer.py

224 lines
8.3 KiB
Python

import asyncio
import logging
import re
import time
from dataclasses import dataclass
from typing import Any, Callable
from atlasbot.config import Settings
from atlasbot.knowledge.loader import KnowledgeBase
from atlasbot.llm.client import LLMClient, build_messages, parse_json
from atlasbot.llm import prompts
from atlasbot.snapshot.builder import SnapshotProvider, summary_text
log = logging.getLogger(__name__)
@dataclass
class AnswerScores:
    """Quality metrics assigned to an answer by the scoring stage."""

    confidence: int
    relevance: int
    satisfaction: int
    hallucination_risk: str
@dataclass
class AnswerResult:
    """Final output of the engine: the reply text, its scores, and run metadata."""

    reply: str
    scores: AnswerScores
    meta: dict[str, Any]
class AnswerEngine:
    """Multi-stage question-answering pipeline.

    Stages: classify the question, draft angles, generate one candidate
    answer per angle (concurrently), score and select the best candidates,
    then synthesize a final reply.  "stock" mode bypasses the pipeline and
    asks the base model directly.
    """

    def __init__(
        self,
        settings: Settings,
        llm: LLMClient,
        kb: KnowledgeBase,
        snapshot: SnapshotProvider,
    ) -> None:
        self._settings = settings
        self._llm = llm
        self._kb = kb
        self._snapshot = snapshot

    async def answer(
        self,
        question: str,
        *,
        mode: str,
        observer: Callable[[str, str], None] | None = None,
    ) -> AnswerResult:
        """Answer *question* under *mode* ("stock", "quick", or a smart mode).

        *observer*, when provided, is called with (stage, description)
        before each pipeline stage so callers can surface progress.
        """
        question = (question or "").strip()
        if not question:
            return AnswerResult("I need a question to answer.", _default_scores(), {"mode": mode})
        if mode == "stock":
            return await self._answer_stock(question)

        # Shared context for every prompt: KB summary, runbook titles,
        # and the current cluster snapshot (when one is available).
        snapshot = self._snapshot.get()
        kb_summary = self._kb.summary()
        runbooks = self._kb.runbook_titles(limit=4)
        snapshot_ctx = summary_text(snapshot)
        base_context = _join_context([
            kb_summary,
            runbooks,
            f"ClusterSnapshot:{snapshot_ctx}" if snapshot_ctx else "",
        ])

        started = time.monotonic()
        if observer:
            observer("classify", "classifying intent")
        classify = await self._classify(question, base_context)
        log.info(
            "atlasbot_classify",
            extra={"extra": {"mode": mode, "elapsed_sec": round(time.monotonic() - started, 2), "classify": classify}},
        )

        if observer:
            observer("angles", "drafting angles")
        angles = await self._angles(question, classify, mode)
        log.info(
            "atlasbot_angles",
            extra={"extra": {"mode": mode, "count": len(angles)}},
        )

        if observer:
            observer("candidates", "drafting answers")
        candidates = await self._candidates(question, angles, base_context, mode)
        log.info(
            "atlasbot_candidates",
            extra={"extra": {"mode": mode, "count": len(candidates)}},
        )

        if observer:
            observer("select", "scoring candidates")
        best, scores = await self._select_best(question, candidates)
        log.info(
            "atlasbot_selection",
            extra={"extra": {"mode": mode, "selected": len(best), "scores": scores.__dict__}},
        )

        if observer:
            observer("synthesize", "synthesizing reply")
        reply = await self._synthesize(question, best, base_context)
        meta = {
            "mode": mode,
            "angles": angles,
            "scores": scores.__dict__,
            "classify": classify,
            "candidates": len(candidates),
        }
        return AnswerResult(reply, scores, meta)

    async def _answer_stock(self, question: str) -> AnswerResult:
        """Single-shot answer from the base model, with no cluster context."""
        messages = build_messages(prompts.STOCK_SYSTEM, question)
        reply = await self._llm.chat(messages, model=self._settings.ollama_model)
        return AnswerResult(reply, _default_scores(), {"mode": "stock"})

    async def _classify(self, question: str, context: str) -> dict[str, Any]:
        """Classify the question with the fast model.

        Falls back to {"needs_snapshot": True} when the reply is not
        parseable JSON, so downstream stages stay conservative.
        """
        prompt = prompts.CLASSIFY_PROMPT + "\nQuestion: " + question
        messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=context)
        raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
        return _parse_json_block(raw, fallback={"needs_snapshot": True})

    async def _angles(self, question: str, classify: dict[str, Any], mode: str) -> list[dict[str, Any]]:
        """Draft up to a mode-dependent number of angles on the question."""
        max_angles = self._settings.fast_max_angles if mode == "quick" else self._settings.smart_max_angles
        prompt = prompts.ANGLE_PROMPT.format(max_angles=max_angles) + "\nQuestion: " + question
        messages = build_messages(prompts.CLUSTER_SYSTEM, prompt)
        raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
        angles = _parse_json_list(raw)
        if not angles:
            # Unparseable/empty reply: treat the raw question as the one angle.
            return [{"name": "primary", "question": question, "relevance": 100}]
        return angles[:max_angles]

    async def _candidates(
        self,
        question: str,
        angles: list[dict[str, Any]],
        context: str,
        mode: str,
    ) -> list[dict[str, Any]]:
        """Draft one candidate answer per angle, concurrently, with the smart model."""
        limit = self._settings.fast_max_candidates if mode == "quick" else self._settings.smart_max_candidates
        selected = angles[:limit]
        tasks = []
        for angle in selected:
            angle_q = angle.get("question") or question
            prompt = prompts.CANDIDATE_PROMPT + "\nQuestion: " + angle_q
            messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=context)
            tasks.append(self._llm.chat(messages, model=self._settings.ollama_model_smart))
        replies = await asyncio.gather(*tasks)
        return [
            {"angle": angle, "reply": reply}
            for angle, reply in zip(selected, replies, strict=False)
        ]

    async def _select_best(
        self,
        question: str,
        candidates: list[dict[str, Any]],
    ) -> tuple[list[dict[str, Any]], AnswerScores]:
        """Score every candidate and return (top-3 candidates, best scores).

        Scoring calls now fan out via asyncio.gather (they previously ran
        sequentially), matching the concurrency pattern of _candidates.
        gather preserves input order, and list.sort is stable, so the
        selection is identical to the sequential version.
        """
        if not candidates:
            return ([], _default_scores())

        async def _score(entry: dict[str, Any]) -> AnswerScores:
            # One fast-model scoring call per candidate answer.
            prompt = prompts.SCORE_PROMPT + "\nQuestion: " + question + "\nAnswer: " + entry["reply"]
            messages = build_messages(prompts.CLUSTER_SYSTEM, prompt)
            raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
            return _scores_from_json(_parse_json_block(raw, fallback={}))

        all_scores = await asyncio.gather(*(_score(entry) for entry in candidates))
        scored = list(zip(candidates, all_scores, strict=False))
        scored.sort(key=lambda item: (item[1].relevance, item[1].confidence), reverse=True)
        best = [entry for entry, _scores in scored[:3]]
        return best, scored[0][1]

    async def _synthesize(self, question: str, best: list[dict[str, Any]], context: str) -> str:
        """Merge the best candidate answers into one reply with the smart model."""
        if not best:
            return "I do not have enough information to answer that yet."
        parts = [f"- {item['reply']}" for item in best]
        prompt = (
            prompts.SYNTHESIZE_PROMPT
            + "\nQuestion: "
            + question
            + "\nCandidate answers:\n"
            + "\n".join(parts)
        )
        messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=context)
        return await self._llm.chat(messages, model=self._settings.ollama_model_smart)
def _join_context(parts: list[str]) -> str:
text = "\n".join([p for p in parts if p])
return text.strip()
def _parse_json_block(text: str, *, fallback: dict[str, Any]) -> dict[str, Any]:
    """Parse the first {...} span found in *text*; parse the whole text if none.

    Delegates failure handling to parse_json, which returns *fallback*.
    """
    stripped = text.strip()
    found = re.search(r"\{.*\}", stripped, flags=re.S)
    candidate = found.group(0) if found else stripped
    return parse_json(candidate, fallback=fallback)
def _parse_json_list(text: str) -> list[dict[str, Any]]:
    """Parse a JSON array out of *text*, keeping only dict entries.

    Returns [] when no list can be recovered.
    """
    stripped = text.strip()
    found = re.search(r"\[.*\]", stripped, flags=re.S)
    payload = found.group(0) if found else stripped
    data = parse_json(payload, fallback={})
    if not isinstance(data, list):
        return []
    return [item for item in data if isinstance(item, dict)]
def _scores_from_json(data: dict[str, Any]) -> AnswerScores:
    """Build AnswerScores from a (possibly partial) LLM JSON payload.

    Missing or non-numeric fields default to 60; missing risk to "medium".
    """
    risk = data.get("hallucination_risk") or "medium"
    return AnswerScores(
        confidence=_coerce_int(data.get("confidence"), 60),
        relevance=_coerce_int(data.get("relevance"), 60),
        satisfaction=_coerce_int(data.get("satisfaction"), 60),
        hallucination_risk=str(risk),
    )
def _coerce_int(value: Any, default: int) -> int:
try:
return int(float(value))
except (TypeError, ValueError):
return default
def _default_scores() -> AnswerScores:
    """Neutral mid-range scores used when no real scoring took place."""
    return AnswerScores(
        confidence=60,
        relevance=60,
        satisfaction=60,
        hallucination_risk="medium",
    )