From bb513214043577830470a32503e71a2976a71cc2 Mon Sep 17 00:00:00 2001
From: Brad Stein
Date: Sun, 1 Feb 2026 11:47:21 -0300
Subject: [PATCH] atlasbot: reinforce metric facts

---
 atlasbot/engine/answerer.py | 87 +++++++++++++++++++++++++++++++++++++
 atlasbot/llm/prompts.py     | 11 +++++
 2 files changed, 98 insertions(+)

diff --git a/atlasbot/engine/answerer.py b/atlasbot/engine/answerer.py
index ce366d5..263805f 100644
--- a/atlasbot/engine/answerer.py
+++ b/atlasbot/engine/answerer.py
@@ -117,6 +117,7 @@ class AnswerEngine:
         "route",
         "decompose",
         "chunk_score",
+        "fact_select",
         "synth",
         "subanswer",
         "tool",
@@ -261,6 +262,11 @@ class AnswerEngine:
         selected = _select_chunks(chunks, scored, plan, keyword_tokens)
         key_facts = _key_fact_lines(summary_lines, keyword_tokens)
         metric_facts = [line for line in key_facts if re.search(r"\d", line)]
+        if classify.get("question_type") in {"metric", "diagnostic"} and not metric_facts:
+            metric_candidates = _metric_candidate_lines(summary_lines, keyword_tokens)
+            metric_facts = await _select_metric_facts(call_llm, normalized, metric_candidates, plan)
+            if metric_facts:
+                key_facts = _merge_fact_lines(metric_facts, key_facts)
         if self._settings.debug_pipeline:
             scored_preview = sorted(
                 [{"id": c["id"], "score": scored.get(c["id"], 0.0), "summary": c["summary"]} for c in chunks],
@@ -394,6 +400,7 @@ class AnswerEngine:
                 model=plan.model,
                 tag="runbook_enforce",
             )
+        reply = _strip_unknown_entities(reply, unknown_nodes, unknown_namespaces)
 
         if _needs_focus_fix(normalized, reply, classify):
             if observer:
@@ -960,6 +967,64 @@ def _key_fact_lines(lines: list[str], keywords: list[str] | None, limit: int = 6
     return matches
 
 
+def _merge_fact_lines(primary: list[str], fallback: list[str]) -> list[str]:
+    seen = set()
+    merged: list[str] = []
+    for line in primary + fallback:
+        if line in seen:
+            continue
+        seen.add(line)
+        merged.append(line)
+    return merged
+
+
+def _metric_candidate_lines(lines: list[str], keywords: list[str] | None, limit: int = 40) -> list[str]:
+    if not lines:
+        return []
+    lowered = [kw.lower() for kw in (keywords or []) if kw]
+    candidates: list[str] = []
+    for line in lines:
+        line_lower = line.lower()
+        if lowered and any(kw in line_lower for kw in lowered):
+            candidates.append(line)
+        elif re.search(r"\d", line):
+            candidates.append(line)
+        if len(candidates) >= limit:
+            break
+    return candidates
+
+
+async def _select_metric_facts(
+    call_llm: Callable[..., Any],
+    question: str,
+    candidates: list[str],
+    plan: ModePlan,
+    max_lines: int = 2,
+) -> list[str]:
+    if not candidates:
+        return []
+    prompt = (
+        prompts.FACT_SELECT_PROMPT.format(max_lines=max_lines)
+        + "\nQuestion: "
+        + question
+        + "\nCandidates:\n"
+        + "\n".join([f"- {line}" for line in candidates])
+    )
+    raw = await call_llm(prompts.FACT_SELECT_SYSTEM, prompt, model=plan.fast_model, tag="fact_select")
+    data = _parse_json_block(raw, fallback={})
+    lines = data.get("lines") if isinstance(data, dict) else None
+    if not isinstance(lines, list):
+        return []
+    cleaned: list[str] = []
+    allowed = set(candidates)
+    for line in lines:
+        if isinstance(line, str) and line in allowed and line not in cleaned:
+            cleaned.append(line)
+        if len(cleaned) >= max_lines:
+            break
+    return cleaned
+
+
 def _metric_fact_guard(reply: str, metric_facts: list[str], keywords: list[str]) -> str:
     if not metric_facts:
         return reply
@@ -978,6 +1043,28 @@ def _metric_fact_guard(reply: str, metric_facts: list[str], keywords: list[str])
         return reply
 
 
+def _strip_unknown_entities(reply: str, unknown_nodes: list[str], unknown_namespaces: list[str]) -> str:
+    if not reply:
+        return reply
+    if not unknown_nodes and not unknown_namespaces:
+        return reply
+    sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", reply) if s.strip()]
+    if not sentences:
+        return reply
+    lowered_nodes = [node.lower() for node in unknown_nodes]
+    lowered_namespaces = [ns.lower() for ns in unknown_namespaces]
+    kept: list[str] = []
+    for sent in sentences:
+        lower = sent.lower()
+        if lowered_nodes and any(node in lower for node in lowered_nodes):
+            continue
+        if lowered_namespaces and any(f"namespace {ns}" in lower for ns in lowered_namespaces):
+            continue
+        kept.append(sent)
+    cleaned = " ".join(kept).strip()
+    return cleaned or reply
+
+
 def _lexicon_context(summary: dict[str, Any]) -> str:
     if not isinstance(summary, dict):
         return ""
diff --git a/atlasbot/llm/prompts.py b/atlasbot/llm/prompts.py
index 37691ea..5aebf5c 100644
--- a/atlasbot/llm/prompts.py
+++ b/atlasbot/llm/prompts.py
@@ -207,6 +207,17 @@ DEDUP_PROMPT = (
     "Return only the cleaned answer."
 )
 
+FACT_SELECT_SYSTEM = (
+    CLUSTER_SYSTEM
+    + " Select the most relevant fact lines for the question. "
+    + "Return JSON only."
+)
+
+FACT_SELECT_PROMPT = (
+    "Pick up to {max_lines} lines from Candidates that best answer the question. "
+    "Return JSON with field: lines (list of strings). If none apply, return {{\"lines\": []}}."
+)
+
 SELECT_CLAIMS_PROMPT = (
     "Select relevant claim ids for the follow-up. "
     "Return JSON with field: claim_ids (list)."
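
Two of the additions above are easy to get wrong, so a standalone sanity check is worth running: FACT_SELECT_PROMPT is a str.format template, which is why its literal JSON example escapes the braces as {{ }}, and _strip_unknown_entities depends on a lookbehind split to find sentence boundaries. The sketch below mirrors both behaviors; strip_unknown is a hypothetical stand-in for the patched helper and needs only the standard library:

    import re

    def strip_unknown(reply: str, unknown_nodes: list[str], unknown_namespaces: list[str]) -> str:
        # Split the reply into sentences on terminal punctuation, as in the patch.
        sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", reply) if s.strip()]
        nodes = [n.lower() for n in unknown_nodes]
        namespaces = [ns.lower() for ns in unknown_namespaces]
        kept = []
        for sent in sentences:
            lower = sent.lower()
            # Drop any sentence naming an unknown node or "namespace <ns>".
            if any(n in lower for n in nodes):
                continue
            if any(f"namespace {ns}" in lower for ns in namespaces):
                continue
            kept.append(sent)
        # Fall back to the original reply rather than returning an empty answer.
        return " ".join(kept).strip() or reply

    reply = "node-a is healthy. node-zz is at 91% CPU. The ops namespace looks fine."
    assert strip_unknown(reply, ["node-zz"], []) == "node-a is healthy. The ops namespace looks fine."

    # FACT_SELECT_PROMPT is rendered via str.format(max_lines=...), so the literal
    # JSON example must double its braces; single braces would raise KeyError.
    template = 'Pick up to {max_lines} lines. If none apply, return {{"lines": []}}.'
    assert template.format(max_lines=2) == 'Pick up to 2 lines. If none apply, return {"lines": []}.'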