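"""Answer engine: classify intent, draft angles, generate candidates, score, synthesize.

Illustrative usage (a sketch; how `settings`, `llm`, `kb`, and `snapshot` are
constructed lives elsewhere in atlasbot):

    engine = AnswerEngine(settings, llm, kb, snapshot)
    result = await engine.answer("why are the api pods restarting?", mode="quick")
    print(result.reply, result.scores.confidence)

`mode="stock"` bypasses the pipeline; `mode="quick"` uses the fast limits from
Settings, and any other mode uses the smart limits.
"""
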
import asyncio
import logging
import re
import time
from dataclasses import dataclass
from typing import Any, Callable

from atlasbot.config import Settings
from atlasbot.knowledge.loader import KnowledgeBase
from atlasbot.llm import prompts
from atlasbot.llm.client import LLMClient, build_messages, parse_json
from atlasbot.snapshot.builder import SnapshotProvider, summary_text

log = logging.getLogger(__name__)


@dataclass
class AnswerScores:
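    """Self-reported quality scores for a single candidate answer."""
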
    confidence: int
    relevance: int
    satisfaction: int
    hallucination_risk: str


@dataclass
class AnswerResult:
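    """Final reply together with its scores and pipeline metadata."""
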
    reply: str
    scores: AnswerScores
    meta: dict[str, Any]


class AnswerEngine:
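    """Orchestrates the classify → angles → candidates → select → synthesize pipeline."""
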
    def __init__(
        self,
        settings: Settings,
        llm: LLMClient,
        kb: KnowledgeBase,
        snapshot: SnapshotProvider,
    ) -> None:
        self._settings = settings
        self._llm = llm
        self._kb = kb
        self._snapshot = snapshot

    async def answer(
        self,
        question: str,
        *,
        mode: str,
        observer: Callable[[str, str], None] | None = None,
    ) -> AnswerResult:
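        """Run the staged pipeline for `question`; `observer` gets per-stage progress."""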
        question = (question or "").strip()
        if not question:
            return AnswerResult("I need a question to answer.", _default_scores(), {"mode": mode})
        if mode == "stock":
            return await self._answer_stock(question)

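        # Assemble context shared by every stage: KB summary, runbook titles,
        # and a text rendering of the current cluster snapshot.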
        snapshot = self._snapshot.get()
        kb_summary = self._kb.summary()
        runbooks = self._kb.runbook_titles(limit=4)
        snapshot_ctx = summary_text(snapshot)
        base_context = _join_context([
            kb_summary,
            runbooks,
            f"ClusterSnapshot:{snapshot_ctx}" if snapshot_ctx else "",
        ])

        started = time.monotonic()
        if observer:
            observer("classify", "classifying intent")
        classify = await self._classify(question, base_context)
        log.info(
            "atlasbot_classify",
            extra={"extra": {"mode": mode, "elapsed_sec": round(time.monotonic() - started, 2), "classify": classify}},
        )
        if observer:
            observer("angles", "drafting angles")
        angles = await self._angles(question, classify, mode)
        log.info(
            "atlasbot_angles",
            extra={"extra": {"mode": mode, "count": len(angles)}},
        )
        if observer:
            observer("candidates", "drafting answers")
        candidates = await self._candidates(question, angles, base_context, mode)
        log.info(
            "atlasbot_candidates",
            extra={"extra": {"mode": mode, "count": len(candidates)}},
        )
        if observer:
            observer("select", "scoring candidates")
        best, scores = await self._select_best(question, candidates)
        log.info(
            "atlasbot_selection",
            extra={"extra": {"mode": mode, "selected": len(best), "scores": scores.__dict__}},
        )
        if observer:
            observer("synthesize", "synthesizing reply")
        reply = await self._synthesize(question, best, base_context)
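        # Expose intermediate artifacts so callers can log or render them.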
        meta = {
            "mode": mode,
            "angles": angles,
            "scores": scores.__dict__,
            "classify": classify,
            "candidates": len(candidates),
        }
        return AnswerResult(reply, scores, meta)

    async def _answer_stock(self, question: str) -> AnswerResult:
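        """Bypass the cluster pipeline and ask the base model directly."""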
        messages = build_messages(prompts.STOCK_SYSTEM, question)
        reply = await self._llm.chat(messages, model=self._settings.ollama_model)
        return AnswerResult(reply, _default_scores(), {"mode": "stock"})

    async def _classify(self, question: str, context: str) -> dict[str, Any]:
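        """Classify intent on the fast model; on parse failure assume a snapshot is needed."""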
        prompt = prompts.CLASSIFY_PROMPT + "\nQuestion: " + question
        messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=context)
        raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
        return _parse_json_block(raw, fallback={"needs_snapshot": True})

    async def _angles(self, question: str, classify: dict[str, Any], mode: str) -> list[dict[str, Any]]:
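        """Draft up to `max_angles` sub-questions; fall back to the original question."""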
        max_angles = self._settings.fast_max_angles if mode == "quick" else self._settings.smart_max_angles
        prompt = prompts.ANGLE_PROMPT.format(max_angles=max_angles) + "\nQuestion: " + question
        messages = build_messages(prompts.CLUSTER_SYSTEM, prompt)
        raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
        angles = _parse_json_list(raw)
        if not angles:
            return [{"name": "primary", "question": question, "relevance": 100}]
        return angles[:max_angles]

    async def _candidates(
        self,
        question: str,
        angles: list[dict[str, Any]],
        context: str,
        mode: str,
    ) -> list[dict[str, Any]]:
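        """Draft one candidate per angle, fanning out concurrently on the smart model."""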
        limit = self._settings.fast_max_candidates if mode == "quick" else self._settings.smart_max_candidates
        selected = angles[:limit]
        tasks = []
        for angle in selected:
            angle_q = angle.get("question") or question
            prompt = prompts.CANDIDATE_PROMPT + "\nQuestion: " + angle_q
            messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=context)
            tasks.append(self._llm.chat(messages, model=self._settings.ollama_model_smart))
        replies = await asyncio.gather(*tasks)
        candidates = []
        for angle, reply in zip(selected, replies, strict=False):
            candidates.append({"angle": angle, "reply": reply})
        return candidates

    async def _select_best(
        self, question: str, candidates: list[dict[str, Any]]
    ) -> tuple[list[dict[str, Any]], AnswerScores]:
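        """Score each candidate on the fast model and keep the top three."""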
        if not candidates:
            return [], _default_scores()
        scored: list[tuple[dict[str, Any], AnswerScores]] = []
        for entry in candidates:
            prompt = prompts.SCORE_PROMPT + "\nQuestion: " + question + "\nAnswer: " + entry["reply"]
            messages = build_messages(prompts.CLUSTER_SYSTEM, prompt)
            raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
            data = _parse_json_block(raw, fallback={})
            scores = _scores_from_json(data)
            scored.append((entry, scores))
        scored.sort(key=lambda item: (item[1].relevance, item[1].confidence), reverse=True)
        best = [entry for entry, _scores in scored[:3]]
        return best, scored[0][1]

    async def _synthesize(self, question: str, best: list[dict[str, Any]], context: str) -> str:
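        """Merge the selected candidates into one reply on the smart model."""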
        if not best:
            return "I do not have enough information to answer that yet."
        parts = []
        for item in best:
            parts.append(f"- {item['reply']}")
        prompt = (
            prompts.SYNTHESIZE_PROMPT
            + "\nQuestion: "
            + question
            + "\nCandidate answers:\n"
            + "\n".join(parts)
        )
        messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=context)
        reply = await self._llm.chat(messages, model=self._settings.ollama_model_smart)
        return reply


def _join_context(parts: list[str]) -> str:
    text = "\n".join([p for p in parts if p])
    return text.strip()


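# LLM output may wrap JSON in prose; pull out the outermost {...} block before parsing.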
def _parse_json_block(text: str, *, fallback: dict[str, Any]) -> dict[str, Any]:
    raw = text.strip()
    match = re.search(r"\{.*\}", raw, flags=re.S)
    if match:
        return parse_json(match.group(0), fallback=fallback)
    return parse_json(raw, fallback=fallback)


def _parse_json_list(text: str) -> list[dict[str, Any]]:
    raw = text.strip()
    match = re.search(r"\[.*\]", raw, flags=re.S)
    data = parse_json(match.group(0), fallback={}) if match else parse_json(raw, fallback={})
    if isinstance(data, list):
        return [entry for entry in data if isinstance(entry, dict)]
    return []


def _scores_from_json(data: dict[str, Any]) -> AnswerScores:
    return AnswerScores(
        confidence=_coerce_int(data.get("confidence"), 60),
        relevance=_coerce_int(data.get("relevance"), 60),
        satisfaction=_coerce_int(data.get("satisfaction"), 60),
        hallucination_risk=str(data.get("hallucination_risk") or "medium"),
    )


def _coerce_int(value: Any, default: int) -> int:
    try:
        return int(float(value))
    except (TypeError, ValueError):
        return default


def _default_scores() -> AnswerScores:
    return AnswerScores(confidence=60, relevance=60, satisfaction=60, hallucination_risk="medium")