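"""Staged answer workflow for AtlasBot.

Standard modes answer from a compact fact sheet; otherwise the pipeline runs
normalize -> route -> decompose -> retrieve -> subanswer -> synthesize, under
per-mode LLM call caps and wall-clock time budgets.
"""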
from __future__ import annotations

import asyncio
import json
import math
import re
import time
from collections.abc import Callable
from typing import Any

from atlasbot.engine.intent_router import route_intent
from atlasbot.llm import prompts
from atlasbot.llm.client import build_messages
from atlasbot.snapshot.builder import build_summary

from ._base import *
from .common import *
from .factsheet import *
from .post import *
from .post_ext import *
from .retrieval import *
from .retrieval_ext import *
from .spine import *
from .workflow_post import finalize_answer

async def run_answer(  # noqa: C901
    engine: Any,
    question: str,
    *,
    mode: str,
    history: list[dict[str, str]] | None = None,
    observer: Callable[[str, str], None] | None = None,
    conversation_id: str | None = None,
    snapshot_pin: bool | None = None,
) -> AnswerResult:
"""Answer a question using the staged reasoning pipeline."""
settings = engine._settings
question = (question or "").strip()
if not question:
return AnswerResult("I need a question to answer.", _default_scores(), {"mode": mode})
if mode == "stock":
return await engine._answer_stock(question)
limitless = "run limitless" in question.lower()
if limitless:
question = re.sub(r"(?i)run limitless", "", question).strip()
plan = _mode_plan(settings, mode)
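    # Budget bookkeeping: every LLM call counts against call_cap, and elapsed
    # wall-clock time counts against the per-mode time budget. "Run limitless"
    # disables both limits.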
    call_limit = _llm_call_limit(settings, mode)
    call_cap = math.ceil(call_limit * settings.llm_limit_multiplier)
    call_count = 0
    limit_hit = False
    time_budget_hit = False
    started = time.monotonic()
    time_budget_sec = 0.0 if limitless else _mode_time_budget(settings, mode)
    debug_tags = {
        "route",
        "decompose",
        "chunk_score",
        "chunk_select",
        "fact_select",
        "synth",
        "subanswer",
        "tool",
        "followup",
        "select_claims",
        "evidence_fix",
    }
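    # Single funnel for all model traffic: enforces the call cap and the
    # remaining time budget before each request, then logs the call.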
    async def call_llm(
        system: str,
        prompt: str,
        *,
        context: str | None = None,
        model: str | None = None,
        tag: str = "",
    ) -> str:
        nonlocal call_count, limit_hit, time_budget_hit
        if not limitless and call_count >= call_cap:
            limit_hit = True
            raise LLMLimitReached("llm_limit")
        timeout_sec = None
        if not limitless and time_budget_sec > 0:
            time_left = time_budget_sec - (time.monotonic() - started)
            if time_left <= 0:
                time_budget_hit = True
                raise LLMTimeBudgetExceeded("time_budget")
            timeout_sec = min(settings.ollama_timeout_sec, time_left)
        call_count += 1
        messages = build_messages(system, prompt, context=context)
        try:
            llm_call = engine._llm.chat(messages, model=model or plan.model, timeout_sec=timeout_sec)
            if timeout_sec is not None:
                response = await asyncio.wait_for(llm_call, timeout=max(0.001, timeout_sec))
            else:
                response = await llm_call
        except (TimeoutError, asyncio.TimeoutError) as exc:  # asyncio.TimeoutError for Python < 3.11
            time_budget_hit = True
            raise LLMTimeBudgetExceeded("time_budget") from exc
        log.info(
            "atlasbot_llm_call",
            extra={"extra": {"mode": mode, "tag": tag, "call": call_count, "limit": call_cap}},
        )
        if settings.debug_pipeline and tag in debug_tags:
            _debug_pipeline_log(settings, f"llm_raw_{tag}", str(response)[:1200])
        return response
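    # Gather the inputs every stage shares: the cluster snapshot (optionally
    # pinned to this conversation), knowledge-base summaries, and chat history.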
    state = engine._get_state(conversation_id)
    pin_snapshot = bool(snapshot_pin) or settings.snapshot_pin_enabled
    snapshot = engine._snapshot.get()
    snapshot_used = state.snapshot if pin_snapshot and state and state.snapshot else snapshot
    summary = build_summary(snapshot_used)
    summary_lines = _summary_lines(snapshot_used)
    allowed_nodes = _allowed_nodes(summary)
    allowed_namespaces = _allowed_namespaces(summary)
    spine = _spine_from_summary(summary) or _spine_lines(summary_lines)
    metric_tokens = _metric_key_tokens(summary_lines)
    global_facts = _global_facts(summary_lines)
    kb_summary = engine._kb.summary()
    runbooks = engine._kb.runbook_titles(limit=6)
    runbook_paths = engine._kb.runbook_paths(limit=10)
    history_ctx = _format_history(history)
    lexicon_ctx = _lexicon_context(summary)
    key_facts: list[str] = []
    metric_facts: list[str] = []
    facts_used: list[str] = []
    reply = ""
    scores = _default_scores()
    claims: list[ClaimItem] = []
    classify: dict[str, Any] = {}
    tool_hint: dict[str, Any] | None = None
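    # Standard modes (quick/fast/smart/genius) answer straight from a compact
    # fact sheet and return early; the staged pipeline below runs otherwise
    # (notably for "run limitless" requests).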
    try:
        if mode in {"quick", "fast", "smart", "genius"} and not limitless:
            if observer:
                observer("factsheet", "building fact sheet")
            if _is_plain_math_question(question):
                reply = (
                    "I focus on Titan cluster operations. Ask me about cluster health, nodes, workloads, "
                    "namespaces, storage, or alerts."
                )
                return AnswerResult(
                    reply,
                    _default_scores(),
                    _build_meta(mode, call_count, call_cap, limit_hit, time_budget_hit, time_budget_sec, classify, tool_hint, started),
                )
            kb_lines = (
                engine._kb.chunk_lines(max_files=plan.kb_max_files, max_chars=_factsheet_kb_chars(mode, plan.kb_max_chars))
                if engine._kb
                else []
            )
            fact_lines = _quick_fact_sheet_lines(question, summary_lines, kb_lines, limit=_factsheet_line_limit(mode))
            classify = {
                "needs_snapshot": True,
                "needs_kb": bool(kb_lines),
                "question_type": f"{mode}_factsheet",
                "answer_style": "direct" if mode in {"quick", "fast"} else "concise",
                "follow_up": False,
            }
            heuristic_reply = _quick_fact_sheet_heuristic_answer(question, fact_lines)
            if heuristic_reply:
                return AnswerResult(
                    heuristic_reply,
                    _default_scores(),
                    _build_meta(mode, call_count, call_cap, limit_hit, time_budget_hit, time_budget_sec, classify, tool_hint, started),
                )
            if observer:
                observer("quick", "answering from fact sheet")
            quick_context = _quick_fact_sheet_text(fact_lines)
            quick_prompt = "Question: " + question + "\nAnswer using only the Fact Sheet. " + _factsheet_instruction(mode)
            reply = await call_llm(
                prompts.ANSWER_SYSTEM,
                quick_prompt,
                context=quick_context,
                model=_factsheet_model(mode, plan),
                tag=f"{mode}_factsheet",
            )
            reply = _strip_followup_meta(reply)
            return AnswerResult(
                reply,
                _default_scores(),
                _build_meta(mode, call_count, call_cap, limit_hit, time_budget_hit, time_budget_sec, classify, tool_hint, started),
            )
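        # Full pipeline: normalize the wording against the cluster lexicon so
        # downstream routing and retrieval see canonical entity/metric names.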
        if observer:
            observer("normalize", "normalizing")
        normalize_prompt = prompts.NORMALIZE_PROMPT + "\nQuestion: " + question
        normalize_raw = await call_llm(
            prompts.NORMALIZE_SYSTEM,
            normalize_prompt,
            context=lexicon_ctx,
            model=plan.fast_model,
            tag="normalize",
        )
        normalize = _parse_json_block(normalize_raw, fallback={"normalized": question, "keywords": []})
        normalized = str(normalize.get("normalized") or question).strip() or question
        keywords = normalize.get("keywords") or []
        _debug_pipeline_log(settings, "normalize_parsed", {"normalized": normalized, "keywords": keywords})
        keyword_tokens = _extract_keywords(question, normalized, sub_questions=[], keywords=keywords)
        question_tokens = _extract_question_tokens(normalized)
        if observer:
            observer("route", "routing")
        route_prompt = prompts.ROUTE_PROMPT + "\nQuestion: " + normalized + "\nKeywords: " + json.dumps(keywords)
        route_raw = await call_llm(
            prompts.ROUTE_SYSTEM,
            route_prompt,
            context=_join_context([kb_summary, lexicon_ctx]),
            model=plan.fast_model,
            tag="route",
        )
        classify = _parse_json_block(route_raw, fallback={})
        classify.setdefault("needs_snapshot", True)
        classify.setdefault("answer_style", "direct")
        classify.setdefault("follow_up", False)
        classify.setdefault("focus_entity", "unknown")
        classify.setdefault("focus_metric", "unknown")
        if metric_tokens and keyword_tokens and any(token in metric_tokens for token in keyword_tokens):
            classify["needs_snapshot"] = True
        intent = route_intent(normalized)
        if intent:
            classify["needs_snapshot"] = True
            classify["question_type"] = "metric"
        _debug_pipeline_log(settings, "route_parsed", {"classify": classify, "normalized": normalized})
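        # Heuristic overrides: obvious count/metric phrasings force metric routing
        # even when the LLM's route classification misses them.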
        lowered_question = f"{question} {normalized}".lower()
        force_metric = bool(re.search(r"\bhow many\b|\bcount\b|\btotal\b", lowered_question))
        if any(term in lowered_question for term in ("postgres", "connections", "pvc", "ready")):
            force_metric = True
        if intent:
            spine_line = spine.get(intent.kind) if isinstance(spine, dict) else None
            if not spine_line:
                spine_line = _spine_fallback(intent, summary_lines)
            spine_answer = _spine_answer(intent, spine_line)
            if spine_line:
                key_facts = _merge_fact_lines([spine_line], key_facts)
                metric_facts = _merge_fact_lines([spine_line], metric_facts)
            if spine_answer and mode in {"fast", "quick"}:
                return AnswerResult(
                    spine_answer,
                    _default_scores(),
                    _build_meta(mode, call_count, call_cap, limit_hit, time_budget_hit, time_budget_sec, classify, tool_hint, started),
                )
        cluster_terms = (
            "atlas",
            "cluster",
            "node",
            "nodes",
            "namespace",
            "pod",
            "workload",
            "k8s",
            "kubernetes",
            "postgres",
            "database",
            "db",
            "connections",
            "cpu",
            "ram",
            "memory",
            "network",
            "io",
            "disk",
            "pvc",
            "storage",
        )
        has_cluster_terms = any(term in lowered_question for term in cluster_terms)
        if has_cluster_terms:
            classify["needs_snapshot"] = True
        lowered_norm = normalized.lower()
        if (
            ("namespace" in lowered_norm and ("pod" in lowered_norm or "pods" in lowered_norm))
            or re.search(r"\bmost\s+pods\b", lowered_norm)
            or re.search(r"\bpods\s+running\b", lowered_norm)
        ):
            classify["question_type"] = "metric"
            classify["needs_snapshot"] = True
        if re.search(r"\b(how many|count|number of|list)\b", lowered_question):
            classify["question_type"] = "metric"
        if any(term in lowered_question for term in ("postgres", "connections", "db")):
            classify["question_type"] = "metric"
            classify["needs_snapshot"] = True
        if any(term in lowered_question for term in ("pvc", "persistentvolume", "persistent volume", "storage")):
            if classify.get("question_type") not in {"metric", "diagnostic"}:
                classify["question_type"] = "metric"
            classify["needs_snapshot"] = True
        if "ready" in lowered_question and classify.get("question_type") not in {"metric", "diagnostic"}:
            classify["question_type"] = "diagnostic"
        hottest_terms = ("hottest", "highest", "lowest", "most")
        metric_terms = ("cpu", "ram", "memory", "net", "network", "io", "disk", "load", "usage", "pod", "pods", "namespace")
        if any(term in lowered_question for term in hottest_terms) and any(term in lowered_question for term in metric_terms):
            classify["question_type"] = "metric"
        baseline_terms = ("baseline", "delta", "trend", "increase", "decrease", "drop", "spike", "regression", "change")
        if any(term in lowered_question for term in baseline_terms) and any(term in lowered_question for term in metric_terms):
            classify["question_type"] = "metric"
            classify["needs_snapshot"] = True
        if not classify.get("follow_up") and state and state.claims:
            follow_terms = ("there", "that", "those", "these", "it", "them", "that one", "this", "former", "latter")
            is_metric_query = force_metric or classify.get("question_type") in {"metric", "diagnostic"}
            if not is_metric_query and (
                any(term in lowered_question for term in follow_terms)
                or (len(normalized.split()) <= FOLLOWUP_SHORT_WORDS and not has_cluster_terms)
            ):
                classify["follow_up"] = True
        if classify.get("follow_up") and state and state.claims:
            if observer:
                observer("followup", "answering follow-up")
            reply = await engine._answer_followup(question, state, summary, classify, plan, call_llm)
            scores = await engine._score_answer(question, reply, plan, call_llm)
            return AnswerResult(
                reply,
                scores,
                _build_meta(mode, call_count, call_cap, limit_hit, time_budget_hit, time_budget_sec, classify, tool_hint, started),
            )
        if observer:
            observer("decompose", "decomposing")
        decompose_prompt = prompts.DECOMPOSE_PROMPT.format(max_parts=plan.max_subquestions * 2)
        decompose_raw = await call_llm(
            prompts.DECOMPOSE_SYSTEM,
            decompose_prompt + "\nQuestion: " + normalized,
            context=lexicon_ctx,
            model=plan.fast_model if mode == "quick" else plan.model,
            tag="decompose",
        )
        parts = _parse_json_list(decompose_raw)
        sub_questions = _select_subquestions(parts, normalized, plan.max_subquestions)
        _debug_pipeline_log(settings, "decompose_parsed", {"sub_questions": sub_questions})
        keyword_tokens = _extract_keywords(question, normalized, sub_questions=sub_questions, keywords=keywords)
        snapshot_context = ""
        signal_tokens: list[str] = []
        # Hoisted so metric_keys is always bound when finalize_answer runs, even
        # when the needs_snapshot branch below is skipped.
        metric_keys: list[str] = []
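        # Retrieval: chunk the snapshot summary (plus raw snapshot and KB files
        # when the plan allows), score chunks against the question, and keep the
        # best-scoring ones as grounding facts.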
        if classify.get("needs_snapshot"):
            if observer:
                observer("retrieve", "scoring chunks")
            chunks = _chunk_lines(summary_lines, plan.chunk_lines)
            if plan.use_raw_snapshot:
                raw_chunks = _raw_snapshot_chunks(snapshot_used)
                if raw_chunks:
                    chunks.extend(raw_chunks)
            kb_lines = engine._kb.chunk_lines(max_files=plan.kb_max_files, max_chars=plan.kb_max_chars) if engine._kb else []
            if kb_lines:
                kb_chunks = _chunk_lines(kb_lines, plan.chunk_lines)
                for idx, chunk in enumerate(kb_chunks):
                    chunk["id"] = f"k{idx}"
                chunks.extend(kb_chunks)
            must_chunk_ids: list[str] = []
            metric_task = None
            if (classify.get("question_type") in {"metric", "diagnostic"} or force_metric) and summary_lines:
                metric_ctx = {
                    "question": normalized,
                    "sub_questions": sub_questions,
                    "keywords": keywords,
                    "keyword_tokens": keyword_tokens,
                    "summary_lines": summary_lines,
                }
                metric_task = asyncio.create_task(_select_metric_chunks(call_llm, metric_ctx, chunks, plan))
            scored_task = asyncio.create_task(_score_chunks(call_llm, chunks, normalized, sub_questions, plan))
            if metric_task:
                metric_keys, must_chunk_ids = await metric_task
            scored = await scored_task
            selected = _select_chunks(chunks, scored, plan, keyword_tokens, must_chunk_ids)
            fact_candidates = _collect_fact_candidates(selected, limit=plan.max_subquestions * 12)
            key_facts = await _select_fact_lines(call_llm, normalized, fact_candidates, plan, max_lines=max(4, plan.max_subquestions * 2))
            metric_facts = []
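            # Metric/diagnostic questions get a stricter second retrieval pass so
            # the answer can cite exact metric lines instead of paraphrases.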
            if classify.get("question_type") in {"metric", "diagnostic"} or force_metric:
                global_metric_facts: list[str] = []
                if global_facts:
                    global_metric_facts = await _select_fact_lines(call_llm, normalized, global_facts, plan, max_lines=min(2, max(1, plan.max_subquestions)))
                if not global_metric_facts and (keyword_tokens or question_tokens):
                    tokens = {tok for tok in (keyword_tokens or question_tokens) if tok and tok not in GENERIC_METRIC_TOKENS}
                    global_metric_facts = _rank_metric_lines(global_facts, tokens, max_lines=2)
                if global_metric_facts:
                    key_facts = _merge_fact_lines(global_metric_facts, key_facts)
                all_tokens = _merge_tokens(signal_tokens, keyword_tokens, question_tokens)
                if plan.use_deep_retrieval:
                    if observer:
                        observer("retrieve", "extracting fact types")
                    fact_types = await _extract_fact_types(call_llm, normalized, keyword_tokens, plan)
                    if observer:
                        observer("retrieve", "deriving signals")
                    signals = await _derive_signals(call_llm, normalized, fact_types, plan)
                    if isinstance(signals, list):
                        signal_tokens = [str(item) for item in signals if item]
                        all_tokens = _merge_tokens(signal_tokens, keyword_tokens, question_tokens)
                    if observer:
                        observer("retrieve", "scanning chunks")
                    candidate_lines: list[str] = []
                    if signals:
                        for chunk in selected:
                            chunk_lines = chunk["text"].splitlines()
                            if not chunk_lines:
                                continue
                            hits = await _scan_chunk_for_signals(call_llm, normalized, signals, chunk_lines, plan)
                            if hits:
                                candidate_lines.extend(hits)
                    candidate_lines = list(dict.fromkeys(candidate_lines))
                    if candidate_lines:
                        if observer:
                            observer("retrieve", "pruning candidates")
                        metric_facts = await _prune_metric_candidates(call_llm, normalized, candidate_lines, plan, plan.metric_retries)
                        if metric_facts:
                            key_facts = _merge_fact_lines(metric_facts, key_facts)
                        if settings.debug_pipeline:
                            _debug_pipeline_log(settings, "metric_facts_selected", {"facts": metric_facts})
                if not metric_facts:
                    if observer:
                        observer("retrieve", "fallback metric selection")
                    token_set = {tok for tok in all_tokens if tok and tok not in GENERIC_METRIC_TOKENS}
                    fallback_candidates = _rank_metric_lines(summary_lines, token_set, max_lines=200)
                    if fallback_candidates:
                        metric_facts = await _select_fact_lines(call_llm, normalized, fallback_candidates, plan, max_lines=max(2, plan.max_subquestions))
                    if not metric_facts and fallback_candidates:
                        metric_facts = fallback_candidates[: max(2, plan.max_subquestions)]
                if metric_keys:
                    key_lines = _lines_for_metric_keys(summary_lines, metric_keys, max_lines=plan.max_subquestions * 3)
                    if key_lines:
                        metric_facts = _merge_fact_lines(key_lines, metric_facts)
                if metric_facts:
                    metric_cover_tokens = [tok for tok in keyword_tokens if tok and tok not in GENERIC_METRIC_TOKENS]
                    if not metric_cover_tokens:
                        metric_cover_tokens = [tok for tok in question_tokens if tok and tok not in GENERIC_METRIC_TOKENS]
                    metric_facts = _ensure_token_coverage(metric_facts, metric_cover_tokens or all_tokens, summary_lines, max_add=plan.max_subquestions)
                    if metric_cover_tokens:
                        ranked_metric_lines = _rank_metric_lines(summary_lines, set(metric_cover_tokens), max_lines=max(1, plan.max_subquestions))
                        if ranked_metric_lines:
                            metric_facts = _merge_fact_lines(ranked_metric_lines, metric_facts)
                if metric_facts and not _has_keyword_overlap(metric_facts, keyword_tokens):
                    best_line = _best_keyword_line(summary_lines, keyword_tokens)
                    if best_line:
                        metric_facts = _merge_fact_lines([best_line], metric_facts)
                if metric_facts:
                    key_facts = _merge_fact_lines(metric_facts, key_facts)
                if global_metric_facts:
                    metric_facts = _merge_fact_lines(global_metric_facts, metric_facts)
            if (classify.get("question_type") in {"metric", "diagnostic"} or force_metric) and not metric_facts and key_facts:
                metric_facts = key_facts
            if key_facts:
                key_facts = _ensure_token_coverage(key_facts, _merge_tokens(keyword_tokens, question_tokens), summary_lines, max_add=plan.max_subquestions)
            facts_used = list(dict.fromkeys(key_facts)) if key_facts else list(dict.fromkeys(metric_facts))
            snapshot_context = "ClusterSnapshot:\n" + "\n".join([chunk["text"] for chunk in selected])
            combined_facts = _merge_fact_lines(global_facts, key_facts) if global_facts else key_facts
            if combined_facts:
                snapshot_context = "KeyFacts:\n" + "\n".join(combined_facts) + "\n\n" + snapshot_context
        context = _join_context([kb_summary, _format_runbooks(runbooks), snapshot_context, history_ctx if classify.get("follow_up") else ""])
        if plan.use_tool and classify.get("needs_tool"):
            if observer:
                observer("tool", "suggesting tools")
            tool_prompt = prompts.TOOL_PROMPT + "\nQuestion: " + normalized
            tool_raw = await call_llm(prompts.TOOL_SYSTEM, tool_prompt, context=context, model=plan.fast_model, tag="tool")
            tool_hint = _parse_json_block(tool_raw, fallback={})
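        # Draft one focused subanswer per sub-question; when the plan allows
        # retries, sample several candidates and keep the best-ranked one.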
        if observer:
            observer("subanswers", "drafting subanswers")

        async def _subanswer_for(subq: str) -> str:
            sub_prompt = prompts.SUBANSWER_PROMPT + "\nQuestion: " + subq
            if plan.subanswer_retries > 1:
                candidates = await _gather_limited(
                    [
                        call_llm(prompts.ANSWER_SYSTEM, sub_prompt, context=context, model=plan.model, tag="subanswer")
                        for _ in range(plan.subanswer_retries)
                    ],
                    plan.parallelism,
                )
                best_idx = await _select_best_candidate(call_llm, subq, candidates, plan, "subanswer_select")
                return candidates[best_idx]
            return await call_llm(prompts.ANSWER_SYSTEM, sub_prompt, context=context, model=plan.model, tag="subanswer")

        subanswers: list[str] = []
        if plan.parallelism > 1 and len(sub_questions) > 1:
            subanswers = await _gather_limited([_subanswer_for(subq) for subq in sub_questions], plan.parallelism)
        else:
            for subq in sub_questions:
                subanswers.append(await _subanswer_for(subq))
        if observer:
            observer("synthesize", "synthesizing")
        reply, scores, claims = await finalize_answer(
            engine=engine,
            call_llm=call_llm,
            normalized=normalized,
            subanswers=subanswers,
            context=context,
            classify=classify,
            plan=plan,
            summary=summary,
            summary_lines=summary_lines,
            metric_facts=metric_facts,
            key_facts=key_facts,
            facts_used=facts_used,
            allowed_nodes=allowed_nodes,
            allowed_namespaces=allowed_namespaces,
            runbook_paths=runbook_paths,
            lowered_question=lowered_question,
            force_metric=force_metric,
            keyword_tokens=keyword_tokens,
            question_tokens=question_tokens,
            snapshot_context=snapshot_context,
            observer=observer,
            mode=mode,
            metric_keys=metric_keys or None,
        )
    except LLMTimeBudgetExceeded:
        time_budget_hit = True
        if not reply:
            budget = max(1, round(time_budget_sec)) if time_budget_sec > 0 else 0
            budget_text = f"its {budget}s" if budget else "its configured"
            if mode in {"quick", "fast"}:
                reply = f"Quick mode hit {budget_text} time budget before finishing. Try atlas-smart for a deeper answer."
            elif mode == "smart":
                reply = f"Smart mode hit {budget_text} time budget before finishing. Try atlas-genius or ask a narrower follow-up."
            else:
                reply = "I ran out of time before I could finish this answer."
        scores = _default_scores()
    except LLMLimitReached:
        if not reply:
            reply = "I started working on this but hit my reasoning limit. Ask again with 'Run limitless' for a deeper pass."
        scores = _default_scores()
    finally:
        elapsed = round(time.monotonic() - started, 2)
        log.info(
            "atlasbot_answer",
            extra={
                "extra": {
                    "mode": mode,
                    "seconds": elapsed,
                    "llm_calls": call_count,
                    "limit": call_cap,
                    "limit_hit": limit_hit,
                    "time_budget_sec": time_budget_sec,
                    "time_budget_hit": time_budget_hit,
                }
            },
        )
    if limit_hit and "run limitless" not in reply.lower():
        reply = reply.rstrip() + "\n\nNote: I hit my reasoning limit. Ask again with 'Run limitless' for a deeper pass."
    if conversation_id and claims:
        engine._store_state(conversation_id, claims, summary, snapshot_used, pin_snapshot)
    return AnswerResult(
        reply,
        scores,
        _build_meta(mode, call_count, call_cap, limit_hit, time_budget_hit, time_budget_sec, classify, tool_hint, started),
    )