From b3e53a7fd72f956ee4f588b75147dad2ce413316 Mon Sep 17 00:00:00 2001
From: Brad Stein
Date: Sun, 1 Feb 2026 01:12:43 -0300
Subject: [PATCH] atlasbot: harden retrieval and evidence fixes

---
 atlasbot/engine/answerer.py  | 142 +++++++++++++++++++++++++++++++++++-
 atlasbot/knowledge/loader.py |  13 ++++
 2 files changed, 152 insertions(+), 3 deletions(-)

diff --git a/atlasbot/engine/answerer.py b/atlasbot/engine/answerer.py
index 2fe728b..f431112 100644
--- a/atlasbot/engine/answerer.py
+++ b/atlasbot/engine/answerer.py
@@ -151,9 +151,12 @@ class AnswerEngine:
         if self._settings.snapshot_pin_enabled and state and state.snapshot:
             snapshot_used = state.snapshot
         summary = build_summary(snapshot_used)
+        allowed_nodes = _allowed_nodes(summary)
+        allowed_namespaces = _allowed_namespaces(summary)
         summary_lines = _summary_lines(snapshot_used)
         kb_summary = self._kb.summary()
         runbooks = self._kb.runbook_titles(limit=6)
+        runbook_paths = self._kb.runbook_paths(limit=10)
         history_ctx = _format_history(history)
         lexicon_ctx = _lexicon_context(summary)
 
@@ -178,6 +181,7 @@ class AnswerEngine:
         normalized = str(normalize.get("normalized") or question).strip() or question
         keywords = normalize.get("keywords") or []
         _debug_log("normalize_parsed", {"normalized": normalized, "keywords": keywords})
+        keyword_tokens = _extract_keywords(normalized, sub_questions=[], keywords=keywords)
 
         if observer:
             observer("route", "routing")
@@ -204,6 +208,18 @@
             "workload",
             "k8s",
             "kubernetes",
+            "postgres",
+            "database",
+            "db",
+            "connections",
+            "cpu",
+            "ram",
+            "memory",
+            "network",
+            "io",
+            "disk",
+            "pvc",
+            "storage",
         )
         if any(term in normalized.lower() for term in cluster_terms):
             classify["needs_snapshot"] = True
@@ -229,6 +245,7 @@
             parts = _parse_json_list(decompose_raw)
             sub_questions = _select_subquestions(parts, normalized, plan.max_subquestions)
             _debug_log("decompose_parsed", {"sub_questions": sub_questions})
+            keyword_tokens = _extract_keywords(normalized, sub_questions=sub_questions, keywords=keywords)
 
         snapshot_context = ""
         if classify.get("needs_snapshot"):
@@ -236,7 +253,7 @@
                 observer("retrieve", "scoring chunks")
             chunks = _chunk_lines(summary_lines, plan.chunk_lines)
             scored = await _score_chunks(call_llm, chunks, normalized, sub_questions, plan)
-            selected = _select_chunks(chunks, scored, plan)
+            selected = _select_chunks(chunks, scored, plan, keyword_tokens)
             if self._settings.debug_pipeline:
                 scored_preview = sorted(
                     [{"id": c["id"], "score": scored.get(c["id"], 0.0), "summary": c["summary"]} for c in chunks],
@@ -275,15 +292,30 @@
             observer("synthesize", "synthesizing")
         reply = await self._synthesize_answer(normalized, subanswers, context, classify, plan, call_llm)
 
-        if snapshot_context and _needs_evidence_fix(reply, classify):
+        unknown_nodes = _find_unknown_nodes(reply, allowed_nodes)
+        unknown_namespaces = _find_unknown_namespaces(reply, allowed_namespaces)
+        runbook_fix = _needs_runbook_fix(reply, runbook_paths)
+        if snapshot_context and (_needs_evidence_fix(reply, classify) or unknown_nodes or unknown_namespaces or runbook_fix):
             if observer:
                 observer("evidence_fix", "repairing missing evidence")
+            extra_bits = []
+            if unknown_nodes:
+                extra_bits.append("UnknownNodes: " + ", ".join(sorted(unknown_nodes)))
+            if unknown_namespaces:
+                extra_bits.append("UnknownNamespaces: " + ", ".join(sorted(unknown_namespaces)))
+            if runbook_paths:
+                extra_bits.append("AllowedRunbooks: " + ", ".join(runbook_paths))
+            if allowed_nodes:
extra_bits.append("AllowedNodes: " + ", ".join(allowed_nodes)) + if allowed_namespaces: + extra_bits.append("AllowedNamespaces: " + ", ".join(allowed_namespaces)) fix_prompt = ( prompts.EVIDENCE_FIX_PROMPT + "\nQuestion: " + normalized + "\nDraft: " + reply + + ("\n" + "\n".join(extra_bits) if extra_bits else "") ) reply = await call_llm( prompts.EVIDENCE_FIX_SYSTEM, @@ -667,6 +699,7 @@ def _select_chunks( chunks: list[dict[str, Any]], scores: dict[str, float], plan: ModePlan, + keywords: list[str] | None = None, ) -> list[dict[str, Any]]: if not chunks: return [] @@ -674,12 +707,25 @@ def _select_chunks( selected: list[dict[str, Any]] = [] head = chunks[0] selected.append(head) + keyword_hits: list[dict[str, Any]] = [] + if keywords: + lowered = [kw.lower() for kw in keywords if kw] + for item in ranked: + text = item.get("text", "").lower() + if any(kw in text for kw in lowered): + keyword_hits.append(item) for item in ranked: if len(selected) >= plan.chunk_top: break if item is head: continue selected.append(item) + for item in keyword_hits: + if len(selected) >= plan.chunk_top: + break + if item in selected: + continue + selected.append(item) return selected @@ -808,6 +854,93 @@ def _needs_evidence_fix(reply: str, classify: dict[str, Any]) -> bool: return False +def _extract_keywords(normalized: str, sub_questions: list[str], keywords: list[Any] | None) -> list[str]: + stopwords = { + "the", + "and", + "for", + "with", + "that", + "this", + "what", + "which", + "when", + "where", + "who", + "why", + "how", + "tell", + "show", + "list", + "give", + "about", + "right", + "now", + } + tokens: list[str] = [] + for source in [normalized, *sub_questions]: + for part in re.split(r"[^a-zA-Z0-9_-]+", source.lower()): + if len(part) < 3 or part in stopwords: + continue + tokens.append(part) + if keywords: + for kw in keywords: + if isinstance(kw, str): + part = kw.strip().lower() + if part and part not in stopwords and part not in tokens: + tokens.append(part) + return list(dict.fromkeys(tokens))[:12] + + +def _allowed_nodes(summary: dict[str, Any]) -> list[str]: + hardware = summary.get("hardware_by_node") if isinstance(summary.get("hardware_by_node"), dict) else {} + if hardware: + return sorted([node for node in hardware.keys() if isinstance(node, str)]) + return [] + + +def _allowed_namespaces(summary: dict[str, Any]) -> list[str]: + namespaces: list[str] = [] + for entry in summary.get("namespace_pods") or []: + if isinstance(entry, dict): + name = entry.get("namespace") + if name: + namespaces.append(str(name)) + return sorted(set(namespaces)) + + +def _find_unknown_nodes(reply: str, allowed: list[str]) -> list[str]: + if not reply or not allowed: + return [] + pattern = re.compile(r"\b(titan-[0-9a-z]+|node\d+)\b", re.IGNORECASE) + found = {m.group(1) for m in pattern.finditer(reply)} + if not found: + return [] + allowed_set = {a.lower() for a in allowed} + return sorted({item for item in found if item.lower() not in allowed_set}) + + +def _find_unknown_namespaces(reply: str, allowed: list[str]) -> list[str]: + if not reply or not allowed: + return [] + pattern = re.compile(r"\bnamespace\s+([a-z0-9-]+)\b", re.IGNORECASE) + found = {m.group(1) for m in pattern.finditer(reply)} + if not found: + return [] + allowed_set = {a.lower() for a in allowed} + return sorted({item for item in found if item.lower() not in allowed_set}) + + +def _needs_runbook_fix(reply: str, allowed: list[str]) -> bool: + if not reply or not allowed: + return False + paths = set(re.findall(r"runbooks/[A-Za-z0-9._-]+", 
+    if not paths:
+        return False
+    allowed_set = {p.lower() for p in allowed}
+    return any(path.lower() not in allowed_set for path in paths)
+
+
 def _resolve_path(data: Any, path: str) -> Any | None:
     cursor = data
     for part in re.split(r"\.(?![^\[]*\])", path):
diff --git a/atlasbot/knowledge/loader.py b/atlasbot/knowledge/loader.py
index 8585183..a6d544b 100644
--- a/atlasbot/knowledge/loader.py
+++ b/atlasbot/knowledge/loader.py
@@ -58,3 +58,16 @@ class KnowledgeBase:
         if not titles:
             return ""
         return "Relevant runbooks:\n" + "\n".join(titles[:limit])
+
+    def runbook_paths(self, *, limit: int = 10) -> list[str]:
+        self.load()
+        if not self._runbooks:
+            return []
+        paths: list[str] = []
+        for entry in self._runbooks:
+            if not isinstance(entry, dict):
+                continue
+            path = entry.get("path")
+            if path:
+                paths.append(str(path))
+        return paths[:limit]
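
A standalone sketch of the selection behavior the patch aims for, runnable outside the codebase. The chunk dicts, the query string, and the trimmed stopword set are illustrative stand-ins for ModePlan and the real summary chunks, not code from the repository:

    import re

    # Stand-in chunks; the {"id", "text"} shape is assumed from the diff.
    chunks = [
        {"id": "c0", "text": "cluster overview: 3 nodes ready"},
        {"id": "c1", "text": "postgres connections at 92% of max"},
        {"id": "c2", "text": "ingress traffic nominal"},
    ]

    # Mirrors _extract_keywords, with a trimmed stopword set.
    stopwords = {"the", "and", "for", "what", "about", "right", "now"}
    question = "what about postgres connections right now"
    tokens = [
        part
        for part in re.split(r"[^a-zA-Z0-9_-]+", question.lower())
        if len(part) >= 3 and part not in stopwords
    ]
    tokens = list(dict.fromkeys(tokens))[:12]
    print(tokens)  # ['postgres', 'connections']

    # Mirrors the keyword boost in _select_chunks: c1 mentions both tokens,
    # so it gets seated even if an LLM scorer ranked c2 above it.
    hits = [c for c in chunks if any(t in c["text"] for t in tokens)]
    print([c["id"] for c in hits])  # ['c1']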
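
The node and namespace guards are plain regex heuristics. A sketch with made-up names showing what they catch; the two patterns are the ones added in _find_unknown_nodes and _find_unknown_namespaces:

    import re

    allowed_nodes = ["titan-a1", "titan-b2"]       # illustrative names
    allowed_namespaces = ["kube-system", "media"]
    reply = "Pod stuck on titan-zz9 in namespace observability."

    node_pat = re.compile(r"\b(titan-[0-9a-z]+|node\d+)\b", re.IGNORECASE)
    ns_pat = re.compile(r"\bnamespace\s+([a-z0-9-]+)\b", re.IGNORECASE)

    node_ok = {n.lower() for n in allowed_nodes}
    ns_ok = {n.lower() for n in allowed_namespaces}
    unknown_nodes = sorted(
        {m.group(1) for m in node_pat.finditer(reply) if m.group(1).lower() not in node_ok}
    )
    unknown_ns = sorted(
        {m.group(1) for m in ns_pat.finditer(reply) if m.group(1).lower() not in ns_ok}
    )
    print(unknown_nodes, unknown_ns)  # ['titan-zz9'] ['observability']

Either list being non-empty now forces the evidence-fix pass even when _needs_evidence_fix is satisfied with the draft.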
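
Finally, the shape of the extra prompt material the fix pass receives. EVIDENCE_FIX_PROMPT itself lives in prompts and is not reproduced here, and the runbook paths below are made up:

    unknown_nodes = ["titan-zz9"]
    runbook_paths = ["runbooks/postgres-failover.md", "runbooks/node-drain.md"]

    # Mirrors the extra_bits assembly in the evidence-fix branch.
    extra_bits: list[str] = []
    if unknown_nodes:
        extra_bits.append("UnknownNodes: " + ", ".join(sorted(unknown_nodes)))
    if runbook_paths:
        extra_bits.append("AllowedRunbooks: " + ", ".join(runbook_paths))
    print("\n".join(extra_bits))
    # UnknownNodes: titan-zz9
    # AllowedRunbooks: runbooks/postgres-failover.md, runbooks/node-drain.md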