atlasbot: harden retrieval and evidence fixes

This commit is contained in:
Brad Stein 2026-02-01 01:12:43 -03:00
parent 81e2c65a21
commit b3e53a7fd7
2 changed files with 148 additions and 2 deletions

View File

@ -151,9 +151,12 @@ class AnswerEngine:
if self._settings.snapshot_pin_enabled and state and state.snapshot:
snapshot_used = state.snapshot
summary = build_summary(snapshot_used)
allowed_nodes = _allowed_nodes(summary)
allowed_namespaces = _allowed_namespaces(summary)
summary_lines = _summary_lines(snapshot_used)
kb_summary = self._kb.summary()
runbooks = self._kb.runbook_titles(limit=6)
runbook_paths = self._kb.runbook_paths(limit=10)
history_ctx = _format_history(history)
lexicon_ctx = _lexicon_context(summary)
@ -178,6 +181,7 @@ class AnswerEngine:
normalized = str(normalize.get("normalized") or question).strip() or question
keywords = normalize.get("keywords") or []
_debug_log("normalize_parsed", {"normalized": normalized, "keywords": keywords})
keyword_tokens = _extract_keywords(normalized, sub_questions=[], keywords=keywords)
if observer:
observer("route", "routing")
@ -204,6 +208,18 @@ class AnswerEngine:
"workload",
"k8s",
"kubernetes",
"postgres",
"database",
"db",
"connections",
"cpu",
"ram",
"memory",
"network",
"io",
"disk",
"pvc",
"storage",
)
if any(term in normalized.lower() for term in cluster_terms):
classify["needs_snapshot"] = True
@ -229,6 +245,7 @@ class AnswerEngine:
parts = _parse_json_list(decompose_raw)
sub_questions = _select_subquestions(parts, normalized, plan.max_subquestions)
_debug_log("decompose_parsed", {"sub_questions": sub_questions})
keyword_tokens = _extract_keywords(normalized, sub_questions=sub_questions, keywords=keywords)
snapshot_context = ""
if classify.get("needs_snapshot"):
@ -236,7 +253,7 @@ class AnswerEngine:
observer("retrieve", "scoring chunks")
chunks = _chunk_lines(summary_lines, plan.chunk_lines)
scored = await _score_chunks(call_llm, chunks, normalized, sub_questions, plan)
selected = _select_chunks(chunks, scored, plan)
selected = _select_chunks(chunks, scored, plan, keyword_tokens)
if self._settings.debug_pipeline:
scored_preview = sorted(
[{"id": c["id"], "score": scored.get(c["id"], 0.0), "summary": c["summary"]} for c in chunks],
@ -275,15 +292,30 @@ class AnswerEngine:
observer("synthesize", "synthesizing")
reply = await self._synthesize_answer(normalized, subanswers, context, classify, plan, call_llm)
if snapshot_context and _needs_evidence_fix(reply, classify):
unknown_nodes = _find_unknown_nodes(reply, allowed_nodes)
unknown_namespaces = _find_unknown_namespaces(reply, allowed_namespaces)
runbook_fix = _needs_runbook_fix(reply, runbook_paths)
if snapshot_context and (_needs_evidence_fix(reply, classify) or unknown_nodes or unknown_namespaces or runbook_fix):
if observer:
observer("evidence_fix", "repairing missing evidence")
extra_bits = []
if unknown_nodes:
extra_bits.append("UnknownNodes: " + ", ".join(sorted(unknown_nodes)))
if unknown_namespaces:
extra_bits.append("UnknownNamespaces: " + ", ".join(sorted(unknown_namespaces)))
if runbook_paths:
extra_bits.append("AllowedRunbooks: " + ", ".join(runbook_paths))
if allowed_nodes:
extra_bits.append("AllowedNodes: " + ", ".join(allowed_nodes))
if allowed_namespaces:
extra_bits.append("AllowedNamespaces: " + ", ".join(allowed_namespaces))
fix_prompt = (
prompts.EVIDENCE_FIX_PROMPT
+ "\nQuestion: "
+ normalized
+ "\nDraft: "
+ reply
+ ("\n" + "\n".join(extra_bits) if extra_bits else "")
)
reply = await call_llm(
prompts.EVIDENCE_FIX_SYSTEM,
@ -667,6 +699,7 @@ def _select_chunks(
chunks: list[dict[str, Any]],
scores: dict[str, float],
plan: ModePlan,
keywords: list[str] | None = None,
) -> list[dict[str, Any]]:
if not chunks:
return []
@ -674,12 +707,25 @@ def _select_chunks(
selected: list[dict[str, Any]] = []
head = chunks[0]
selected.append(head)
keyword_hits: list[dict[str, Any]] = []
if keywords:
lowered = [kw.lower() for kw in keywords if kw]
for item in ranked:
text = item.get("text", "").lower()
if any(kw in text for kw in lowered):
keyword_hits.append(item)
for item in ranked:
if len(selected) >= plan.chunk_top:
break
if item is head:
continue
selected.append(item)
for item in keyword_hits:
if len(selected) >= plan.chunk_top:
break
if item in selected:
continue
selected.append(item)
return selected
@ -808,6 +854,93 @@ def _needs_evidence_fix(reply: str, classify: dict[str, Any]) -> bool:
return False
def _extract_keywords(normalized: str, sub_questions: list[str], keywords: list[Any] | None) -> list[str]:
stopwords = {
"the",
"and",
"for",
"with",
"that",
"this",
"what",
"which",
"when",
"where",
"who",
"why",
"how",
"tell",
"show",
"list",
"give",
"about",
"right",
"now",
}
tokens: list[str] = []
for source in [normalized, *sub_questions]:
for part in re.split(r"[^a-zA-Z0-9_-]+", source.lower()):
if len(part) < 3 or part in stopwords:
continue
tokens.append(part)
if keywords:
for kw in keywords:
if isinstance(kw, str):
part = kw.strip().lower()
if part and part not in stopwords and part not in tokens:
tokens.append(part)
return list(dict.fromkeys(tokens))[:12]
def _allowed_nodes(summary: dict[str, Any]) -> list[str]:
hardware = summary.get("hardware_by_node") if isinstance(summary.get("hardware_by_node"), dict) else {}
if hardware:
return sorted([node for node in hardware.keys() if isinstance(node, str)])
return []
def _allowed_namespaces(summary: dict[str, Any]) -> list[str]:
namespaces: list[str] = []
for entry in summary.get("namespace_pods") or []:
if isinstance(entry, dict):
name = entry.get("namespace")
if name:
namespaces.append(str(name))
return sorted(set(namespaces))
def _find_unknown_nodes(reply: str, allowed: list[str]) -> list[str]:
if not reply or not allowed:
return []
pattern = re.compile(r"\b(titan-[0-9a-z]+|node\d+)\b", re.IGNORECASE)
found = {m.group(1) for m in pattern.finditer(reply)}
if not found:
return []
allowed_set = {a.lower() for a in allowed}
return sorted({item for item in found if item.lower() not in allowed_set})
def _find_unknown_namespaces(reply: str, allowed: list[str]) -> list[str]:
if not reply or not allowed:
return []
pattern = re.compile(r"\bnamespace\s+([a-z0-9-]+)\b", re.IGNORECASE)
found = {m.group(1) for m in pattern.finditer(reply)}
if not found:
return []
allowed_set = {a.lower() for a in allowed}
return sorted({item for item in found if item.lower() not in allowed_set})
def _needs_runbook_fix(reply: str, allowed: list[str]) -> bool:
if not reply or not allowed:
return False
paths = set(re.findall(r"runbooks/[A-Za-z0-9._-]+", reply))
if not paths:
return False
allowed_set = {p.lower() for p in allowed}
return any(path.lower() not in allowed_set for path in paths)
def _resolve_path(data: Any, path: str) -> Any | None:
cursor = data
for part in re.split(r"\.(?![^\[]*\])", path):

View File

@ -58,3 +58,16 @@ class KnowledgeBase:
if not titles:
return ""
return "Relevant runbooks:\n" + "\n".join(titles[:limit])
def runbook_paths(self, *, limit: int = 10) -> list[str]:
self.load()
if not self._runbooks:
return []
paths: list[str] = []
for entry in self._runbooks:
if not isinstance(entry, dict):
continue
path = entry.get("path")
if path:
paths.append(str(path))
return paths[:limit]