# atlasbot/atlasbot/engine/answerer.py
import asyncio
import json
import logging
import re
import time
from dataclasses import dataclass
from typing import Any, Callable

from atlasbot.config import Settings
from atlasbot.knowledge.loader import KnowledgeBase
from atlasbot.llm import prompts
from atlasbot.llm.client import LLMClient, build_messages, parse_json
from atlasbot.snapshot.builder import SnapshotProvider, build_summary, summary_text
log = logging.getLogger(__name__)
@dataclass
class AnswerScores:
    """Quality self-assessment for a generated answer.

    Numeric fields default to 60 elsewhere in this module; presumably they are
    0-100 scales produced by the scoring prompt — TODO confirm against SCORE_PROMPT.
    """

    confidence: int
    relevance: int
    satisfaction: int
    # NOTE(review): defaults elsewhere use "medium"; vocabulary looks like
    # low/medium/high — verify against the scoring prompt.
    hallucination_risk: str
@dataclass
class AnswerResult:
    """Final reply text together with its scores and pipeline metadata."""

    reply: str
    scores: AnswerScores
    # Free-form diagnostics: mode, classify output, counts, etc.
    meta: dict[str, Any]
2026-01-30 16:41:04 -03:00
@dataclass
class EvidenceItem:
path: str
reason: str
value: Any | None = None
value_at_claim: Any | None = None
@dataclass
class ClaimItem:
    """One factual claim extracted from an answer, with its supporting evidence."""

    id: str
    claim: str
    evidence: list[EvidenceItem]
@dataclass
class ConversationState:
    """Per-conversation memory used for follow-up answering.

    Entries are evicted by AnswerEngine._cleanup_state after a TTL.
    """

    # time.monotonic() of the last store; drives TTL eviction.
    updated_at: float
    claims: list[ClaimItem]
    snapshot_id: str | None = None
    # Full pinned snapshot; populated only when snapshot_pin_enabled.
    snapshot: dict[str, Any] | None = None
class AnswerEngine:
    """Orchestrates the staged answering pipeline over cluster knowledge.

    Pipeline: classify -> angles -> candidates -> select -> synthesize,
    with a claim-backed shortcut for follow-up questions.
    """

    def __init__(
        self,
        settings: Settings,
        llm: LLMClient,
        kb: KnowledgeBase,
        snapshot: SnapshotProvider,
    ) -> None:
        self._settings = settings
        self._llm = llm
        self._kb = kb
        self._snapshot = snapshot
        # Follow-up state per conversation id; pruned lazily by _cleanup_state.
        self._state: dict[str, ConversationState] = {}
async def answer(
self,
question: str,
*,
mode: str,
history: list[dict[str, str]] | None = None,
2026-01-28 11:46:52 -03:00
observer: Callable[[str, str], None] | None = None,
2026-01-30 16:41:04 -03:00
conversation_id: str | None = None,
2026-01-28 11:46:52 -03:00
) -> AnswerResult:
question = (question or "").strip()
if not question:
return AnswerResult("I need a question to answer.", _default_scores(), {"mode": mode})
if mode == "stock":
return await self._answer_stock(question)
2026-01-30 16:41:04 -03:00
state = self._get_state(conversation_id)
2026-01-28 11:46:52 -03:00
snapshot = self._snapshot.get()
2026-01-30 16:41:04 -03:00
snapshot_used = snapshot
if self._settings.snapshot_pin_enabled and state and state.snapshot:
snapshot_used = state.snapshot
summary = build_summary(snapshot_used)
2026-01-28 11:46:52 -03:00
kb_summary = self._kb.summary()
runbooks = self._kb.runbook_titles(limit=4)
2026-01-30 16:41:04 -03:00
snapshot_ctx = summary_text(snapshot_used)
history_ctx = _format_history(history)
2026-01-28 11:46:52 -03:00
base_context = _join_context([
kb_summary,
runbooks,
f"ClusterSnapshot:{snapshot_ctx}" if snapshot_ctx else "",
history_ctx,
2026-01-28 11:46:52 -03:00
])
started = time.monotonic()
if observer:
observer("classify", "classifying intent")
classify = await self._classify(question, base_context)
log.info(
"atlasbot_classify",
extra={"extra": {"mode": mode, "elapsed_sec": round(time.monotonic() - started, 2), "classify": classify}},
)
if observer:
observer("angles", "drafting angles")
angles = await self._angles(question, classify, mode)
log.info(
"atlasbot_angles",
extra={"extra": {"mode": mode, "count": len(angles)}},
)
if observer:
observer("candidates", "drafting answers")
2026-01-29 20:53:28 -03:00
candidates = await self._candidates(question, angles, base_context, classify, mode)
2026-01-28 11:46:52 -03:00
log.info(
"atlasbot_candidates",
extra={"extra": {"mode": mode, "count": len(candidates)}},
)
if observer:
observer("select", "scoring candidates")
best, scores = await self._select_best(question, candidates)
log.info(
"atlasbot_selection",
extra={"extra": {"mode": mode, "selected": len(best), "scores": scores.__dict__}},
)
2026-01-30 16:41:04 -03:00
if classify.get("follow_up") and state and state.claims:
if observer:
observer("followup", "answering follow-up")
reply = await self._answer_followup(question, state, summary, classify, mode)
meta = {
"mode": mode,
"follow_up": True,
"classify": classify,
}
return AnswerResult(reply, scores, meta)
2026-01-28 11:46:52 -03:00
if observer:
observer("synthesize", "synthesizing reply")
reply = await self._synthesize(question, best, base_context, classify, mode)
2026-01-30 16:41:04 -03:00
claims = await self._extract_claims(question, reply, summary)
if conversation_id and claims:
self._store_state(conversation_id, claims, summary, snapshot_used)
2026-01-28 11:46:52 -03:00
meta = {
"mode": mode,
"angles": angles,
"scores": scores.__dict__,
"classify": classify,
"candidates": len(candidates),
2026-01-30 16:41:04 -03:00
"claims": len(claims),
2026-01-28 11:46:52 -03:00
}
return AnswerResult(reply, scores, meta)
async def _answer_stock(self, question: str) -> AnswerResult:
messages = build_messages(prompts.STOCK_SYSTEM, question)
reply = await self._llm.chat(messages, model=self._settings.ollama_model)
return AnswerResult(reply, _default_scores(), {"mode": "stock"})
async def _classify(self, question: str, context: str) -> dict[str, Any]:
prompt = prompts.CLASSIFY_PROMPT + "\nQuestion: " + question
messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=context)
raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
data = _parse_json_block(raw, fallback={"needs_snapshot": True})
if "answer_style" not in data:
data["answer_style"] = "direct"
2026-01-30 16:41:04 -03:00
if "follow_up_kind" not in data:
data["follow_up_kind"] = "other"
return data
2026-01-28 11:46:52 -03:00
async def _angles(self, question: str, classify: dict[str, Any], mode: str) -> list[dict[str, Any]]:
max_angles = _angles_limit(self._settings, mode)
2026-01-28 11:46:52 -03:00
prompt = prompts.ANGLE_PROMPT.format(max_angles=max_angles) + "\nQuestion: " + question
messages = build_messages(prompts.CLUSTER_SYSTEM, prompt)
raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
angles = _parse_json_list(raw)
if not angles:
return [{"name": "primary", "question": question, "relevance": 100}]
2026-01-29 20:53:28 -03:00
if classify.get("answer_style") == "insightful":
if not any("implication" in (a.get("name") or "").lower() for a in angles):
angles.append({"name": "implications", "question": f"What are the implications of the data for: {question}", "relevance": 85})
2026-01-28 11:46:52 -03:00
return angles[:max_angles]
async def _candidates(
self,
question: str,
angles: list[dict[str, Any]],
context: str,
2026-01-29 20:53:28 -03:00
classify: dict[str, Any],
2026-01-28 11:46:52 -03:00
mode: str,
) -> list[dict[str, Any]]:
limit = _candidates_limit(self._settings, mode)
2026-01-28 11:46:52 -03:00
selected = angles[:limit]
tasks = []
model = _candidate_model(self._settings, mode)
2026-01-28 11:46:52 -03:00
for angle in selected:
angle_q = angle.get("question") or question
prompt = prompts.CANDIDATE_PROMPT + "\nQuestion: " + angle_q
2026-01-29 20:53:28 -03:00
if classify.get("answer_style"):
prompt += f"\nAnswerStyle: {classify.get('answer_style')}"
2026-01-28 11:46:52 -03:00
messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=context)
tasks.append(self._llm.chat(messages, model=model))
2026-01-28 11:46:52 -03:00
replies = await asyncio.gather(*tasks)
candidates = []
for angle, reply in zip(selected, replies, strict=False):
candidates.append({"angle": angle, "reply": reply})
return candidates
async def _select_best(self, question: str, candidates: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], AnswerScores]:
if not candidates:
return ([], _default_scores())
scored: list[tuple[dict[str, Any], AnswerScores]] = []
for entry in candidates:
prompt = prompts.SCORE_PROMPT + "\nQuestion: " + question + "\nAnswer: " + entry["reply"]
messages = build_messages(prompts.CLUSTER_SYSTEM, prompt)
raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
data = _parse_json_block(raw, fallback={})
scores = _scores_from_json(data)
scored.append((entry, scores))
scored.sort(key=lambda item: (item[1].relevance, item[1].confidence), reverse=True)
best = [entry for entry, _scores in scored[:3]]
return best, scored[0][1]
async def _synthesize(
self,
question: str,
best: list[dict[str, Any]],
context: str,
classify: dict[str, Any],
mode: str,
) -> str:
2026-01-28 11:46:52 -03:00
if not best:
return "I do not have enough information to answer that yet."
parts = []
for item in best:
parts.append(f"- {item['reply']}")
style = classify.get("answer_style") if isinstance(classify, dict) else None
intent = classify.get("intent") if isinstance(classify, dict) else None
ambiguity = classify.get("ambiguity") if isinstance(classify, dict) else None
style_line = f"AnswerStyle: {style}" if style else "AnswerStyle: default"
if intent:
style_line = f"{style_line}; Intent: {intent}"
if ambiguity is not None:
style_line = f"{style_line}; Ambiguity: {ambiguity}"
2026-01-28 11:46:52 -03:00
prompt = (
prompts.SYNTHESIZE_PROMPT
+ "\n"
+ style_line
2026-01-28 11:46:52 -03:00
+ "\nQuestion: "
+ question
+ "\nCandidate answers:\n"
+ "\n".join(parts)
)
messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=context)
model = _synthesis_model(self._settings, mode)
reply = await self._llm.chat(messages, model=model)
2026-01-29 20:53:28 -03:00
needs_refine = _needs_refine(reply, classify)
if not needs_refine:
return reply
refine_prompt = prompts.REFINE_PROMPT + "\nQuestion: " + question + "\nDraft: " + reply
refine_messages = build_messages(prompts.CLUSTER_SYSTEM, refine_prompt, context=context)
return await self._llm.chat(refine_messages, model=model)
2026-01-28 11:46:52 -03:00
2026-01-30 16:41:04 -03:00
async def _extract_claims(
self,
question: str,
reply: str,
summary: dict[str, Any],
) -> list[ClaimItem]:
if not reply or not summary:
return []
summary_json = _json_excerpt(summary)
prompt = prompts.CLAIM_MAP_PROMPT + "\nQuestion: " + question + "\nAnswer: " + reply
messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=f"SnapshotSummaryJson:{summary_json}")
raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
data = _parse_json_block(raw, fallback={})
claims_raw = data.get("claims") if isinstance(data, dict) else None
claims: list[ClaimItem] = []
if isinstance(claims_raw, list):
for entry in claims_raw:
if not isinstance(entry, dict):
continue
claim_text = str(entry.get("claim") or "").strip()
claim_id = str(entry.get("id") or "").strip() or f"c{len(claims)+1}"
evidence_items: list[EvidenceItem] = []
for ev in entry.get("evidence") or []:
if not isinstance(ev, dict):
continue
path = str(ev.get("path") or "").strip()
if not path:
continue
reason = str(ev.get("reason") or "").strip()
value = _resolve_path(summary, path)
evidence_items.append(EvidenceItem(path=path, reason=reason, value=value, value_at_claim=value))
if claim_text and evidence_items:
claims.append(ClaimItem(id=claim_id, claim=claim_text, evidence=evidence_items))
return claims
async def _answer_followup(
self,
question: str,
state: ConversationState,
summary: dict[str, Any],
classify: dict[str, Any],
mode: str,
) -> str:
follow_kind = classify.get("follow_up_kind") if isinstance(classify, dict) else "other"
claim_ids = await self._select_claims(question, state.claims)
selected = [claim for claim in state.claims if claim.id in claim_ids] if claim_ids else state.claims[:2]
evidence_lines = []
for claim in selected:
evidence_lines.append(f"Claim: {claim.claim}")
for ev in claim.evidence:
current = _resolve_path(summary, ev.path)
ev.value = current
delta_note = ""
if ev.value_at_claim is not None and current is not None and current != ev.value_at_claim:
delta_note = f" (now {current})"
evidence_lines.append(f"- {ev.path}: {ev.value_at_claim}{delta_note} {('- ' + ev.reason) if ev.reason else ''}")
evidence_ctx = "\n".join(evidence_lines)
prompt = prompts.FOLLOWUP_EVIDENCE_PROMPT
if follow_kind in {"next_steps", "change"}:
prompt = prompts.FOLLOWUP_ACTION_PROMPT
prompt = prompt + "\nFollow-up: " + question + "\nEvidence:\n" + evidence_ctx
messages = build_messages(prompts.CLUSTER_SYSTEM, prompt)
model = _synthesis_model(self._settings, mode)
return await self._llm.chat(messages, model=model)
async def _select_claims(self, question: str, claims: list[ClaimItem]) -> list[str]:
if not claims:
return []
claims_brief = [{"id": claim.id, "claim": claim.claim} for claim in claims]
prompt = prompts.SELECT_CLAIMS_PROMPT + "\nFollow-up: " + question + "\nClaims: " + json.dumps(claims_brief)
messages = build_messages(prompts.CLUSTER_SYSTEM, prompt)
raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
data = _parse_json_block(raw, fallback={})
ids = data.get("claim_ids") if isinstance(data, dict) else []
if isinstance(ids, list):
return [str(item) for item in ids if item]
return []
def _get_state(self, conversation_id: str | None) -> ConversationState | None:
if not conversation_id:
return None
self._cleanup_state()
return self._state.get(conversation_id)
def _store_state(
self,
conversation_id: str,
claims: list[ClaimItem],
summary: dict[str, Any],
snapshot: dict[str, Any] | None,
) -> None:
snapshot_id = _snapshot_id(summary)
pinned_snapshot = snapshot if self._settings.snapshot_pin_enabled else None
self._state[conversation_id] = ConversationState(
updated_at=time.monotonic(),
claims=claims,
snapshot_id=snapshot_id,
snapshot=pinned_snapshot,
)
self._cleanup_state()
def _cleanup_state(self) -> None:
ttl = max(60, self._settings.conversation_ttl_sec)
now = time.monotonic()
expired = [key for key, state in self._state.items() if now - state.updated_at > ttl]
for key in expired:
self._state.pop(key, None)
2026-01-28 11:46:52 -03:00
def _join_context(parts: list[str]) -> str:
text = "\n".join([p for p in parts if p])
return text.strip()
def _format_history(history: list[dict[str, str]] | None) -> str:
if not history:
return ""
lines = ["Recent conversation:"]
for entry in history[-4:]:
2026-01-30 16:41:04 -03:00
if not isinstance(entry, dict):
continue
question = entry.get("q")
answer = entry.get("a")
role = entry.get("role")
content = entry.get("content")
if question:
lines.append(f"Q: {question}")
if answer:
lines.append(f"A: {answer}")
2026-01-30 16:41:04 -03:00
if role and content:
prefix = "Q" if role == "user" else "A"
lines.append(f"{prefix}: {content}")
return "\n".join(lines)
def _angles_limit(settings: Settings, mode: str) -> int:
    """Per-mode cap on the number of generated angles."""
    if mode == "genius":
        return settings.genius_max_angles
    if mode == "quick":
        return settings.fast_max_angles
    # Default ("smart") tier for any other mode value.
    return settings.smart_max_angles
def _candidates_limit(settings: Settings, mode: str) -> int:
    """Per-mode cap on the number of drafted candidate answers."""
    if mode == "genius":
        return settings.genius_max_candidates
    if mode == "quick":
        return settings.fast_max_candidates
    # Default ("smart") tier for any other mode value.
    return settings.smart_max_candidates
def _candidate_model(settings: Settings, mode: str) -> str:
    """Model used for candidate drafting: genius tier or the smart default."""
    return settings.ollama_model_genius if mode == "genius" else settings.ollama_model_smart
def _synthesis_model(settings: Settings, mode: str) -> str:
    """Model used for synthesis/refine: genius tier or the smart default."""
    return settings.ollama_model_genius if mode == "genius" else settings.ollama_model_smart
def _parse_json_block(text: str, *, fallback: dict[str, Any]) -> dict[str, Any]:
    """Parse the first {...} span found in LLM output (or the whole text).

    Delegates malformed-JSON handling to ``parse_json`` via *fallback*.
    """
    raw = text.strip()
    match = re.search(r"\{.*\}", raw, flags=re.S)
    payload = match.group(0) if match else raw
    return parse_json(payload, fallback=fallback)
def _needs_refine(reply: str, classify: dict[str, Any]) -> bool:
if not reply:
return False
style = classify.get("answer_style") if isinstance(classify, dict) else None
if style != "insightful":
return False
metric_markers = ["cpu", "ram", "pods", "connections", "%"]
lower = reply.lower()
metric_hits = sum(1 for m in metric_markers if m in lower)
sentence_count = reply.count(".") + reply.count("!") + reply.count("?")
return metric_hits >= 2 and sentence_count <= 2
2026-01-28 11:46:52 -03:00
def _parse_json_list(text: str) -> list[dict[str, Any]]:
    """Parse the first [...] span from LLM output, keeping only dict entries."""
    raw = text.strip()
    match = re.search(r"\[.*\]", raw, flags=re.S)
    data = parse_json(match.group(0) if match else raw, fallback={})
    if not isinstance(data, list):
        return []
    return [entry for entry in data if isinstance(entry, dict)]
def _scores_from_json(data: dict[str, Any]) -> AnswerScores:
    """Build AnswerScores from a possibly-partial dict (defaults: 60 / medium)."""
    return AnswerScores(
        confidence=_coerce_int(data.get("confidence"), 60),
        relevance=_coerce_int(data.get("relevance"), 60),
        satisfaction=_coerce_int(data.get("satisfaction"), 60),
        hallucination_risk=str(data.get("hallucination_risk") or "medium"),
    )
def _coerce_int(value: Any, default: int) -> int:
try:
return int(float(value))
except (TypeError, ValueError):
return default
def _default_scores() -> AnswerScores:
    """Neutral scores used when no model-produced scores are available."""
    return AnswerScores(confidence=60, relevance=60, satisfaction=60, hallucination_risk="medium")
def _resolve_path(data: Any, path: str) -> Any | None:
cursor = data
for part in re.split(r"\\.(?![^\\[]*\\])", path):
if not part:
continue
match = re.match(r"^(\\w+)(?:\\[(\\d+)\\])?$", part)
if not match:
return None
key = match.group(1)
index = match.group(2)
if isinstance(cursor, dict):
cursor = cursor.get(key)
else:
return None
if index is not None:
try:
idx = int(index)
if isinstance(cursor, list) and 0 <= idx < len(cursor):
cursor = cursor[idx]
else:
return None
except ValueError:
return None
return cursor
def _snapshot_id(summary: dict[str, Any]) -> str | None:
if not summary:
return None
for key in ("generated_at", "snapshot_ts", "snapshot_id"):
value = summary.get(key)
if isinstance(value, str) and value:
return value
return None
def _json_excerpt(summary: dict[str, Any], max_chars: int = 12000) -> str:
raw = json.dumps(summary, ensure_ascii=False)
return raw[:max_chars]