atlasbot: reinforce metric facts

This commit is contained in:
Brad Stein 2026-02-01 11:47:21 -03:00
parent 889e814b59
commit bb51321404
2 changed files with 98 additions and 0 deletions

View File

@ -117,6 +117,7 @@ class AnswerEngine:
"route", "route",
"decompose", "decompose",
"chunk_score", "chunk_score",
"fact_select",
"synth", "synth",
"subanswer", "subanswer",
"tool", "tool",
@ -261,6 +262,11 @@ class AnswerEngine:
selected = _select_chunks(chunks, scored, plan, keyword_tokens) selected = _select_chunks(chunks, scored, plan, keyword_tokens)
key_facts = _key_fact_lines(summary_lines, keyword_tokens) key_facts = _key_fact_lines(summary_lines, keyword_tokens)
metric_facts = [line for line in key_facts if re.search(r"\d", line)] metric_facts = [line for line in key_facts if re.search(r"\d", line)]
if classify.get("question_type") in {"metric", "diagnostic"} and not metric_facts:
metric_candidates = _metric_candidate_lines(summary_lines, keyword_tokens)
metric_facts = await _select_metric_facts(call_llm, normalized, metric_candidates, plan)
if metric_facts:
key_facts = _merge_fact_lines(metric_facts, key_facts)
if self._settings.debug_pipeline: if self._settings.debug_pipeline:
scored_preview = sorted( scored_preview = sorted(
[{"id": c["id"], "score": scored.get(c["id"], 0.0), "summary": c["summary"]} for c in chunks], [{"id": c["id"], "score": scored.get(c["id"], 0.0), "summary": c["summary"]} for c in chunks],
@ -394,6 +400,7 @@ class AnswerEngine:
model=plan.model, model=plan.model,
tag="runbook_enforce", tag="runbook_enforce",
) )
reply = _strip_unknown_entities(reply, unknown_nodes, unknown_namespaces)
if _needs_focus_fix(normalized, reply, classify): if _needs_focus_fix(normalized, reply, classify):
if observer: if observer:
@ -960,6 +967,64 @@ def _key_fact_lines(lines: list[str], keywords: list[str] | None, limit: int = 6
return matches return matches
def _merge_fact_lines(primary: list[str], fallback: list[str]) -> list[str]:
seen = set()
merged: list[str] = []
for line in primary + fallback:
if line in seen:
continue
seen.add(line)
merged.append(line)
return merged
def _metric_candidate_lines(lines: list[str], keywords: list[str] | None, limit: int = 40) -> list[str]:
if not lines:
return []
lowered = [kw.lower() for kw in (keywords or []) if kw]
candidates: list[str] = []
for line in lines:
line_lower = line.lower()
if lowered and any(kw in line_lower for kw in lowered):
candidates.append(line)
elif re.search(r"\d", line):
candidates.append(line)
if len(candidates) >= limit:
break
return candidates
async def _select_metric_facts(
    call_llm: Callable[..., Any],
    question: str,
    candidates: list[str],
    plan: ModePlan,
    max_lines: int = 2,
) -> list[str]:
    """Ask the fast model to pick up to *max_lines* candidate fact lines.

    Builds a bullet list of candidates, issues a fact_select call, and keeps
    only returned strings that exactly match a candidate, de-duplicated and
    capped at *max_lines*. Returns [] on empty input or malformed output.
    """
    if not candidates:
        return []
    bullet_list = "\n".join(f"- {line}" for line in candidates)
    prompt = (
        prompts.FACT_SELECT_PROMPT.format(max_lines=max_lines)
        + "\nQuestion: "
        + question
        + "\nCandidates:\n"
        + bullet_list
    )
    raw = await call_llm(prompts.FACT_SELECT_SYSTEM, prompt, model=plan.fast_model, tag="fact_select")
    parsed = _parse_json_block(raw, fallback={})
    proposed = parsed.get("lines") if isinstance(parsed, dict) else None
    if not isinstance(proposed, list):
        return []
    # Only accept verbatim candidate lines so the model cannot inject new facts.
    permitted = set(candidates)
    chosen: list[str] = []
    for entry in proposed:
        if not isinstance(entry, str) or entry not in permitted or entry in chosen:
            continue
        chosen.append(entry)
        if len(chosen) >= max_lines:
            break
    return chosen
def _metric_fact_guard(reply: str, metric_facts: list[str], keywords: list[str]) -> str: def _metric_fact_guard(reply: str, metric_facts: list[str], keywords: list[str]) -> str:
if not metric_facts: if not metric_facts:
return reply return reply
@ -978,6 +1043,28 @@ def _metric_fact_guard(reply: str, metric_facts: list[str], keywords: list[str])
return reply return reply
def _strip_unknown_entities(reply: str, unknown_nodes: list[str], unknown_namespaces: list[str]) -> str:
if not reply:
return reply
if not unknown_nodes and not unknown_namespaces:
return reply
sentences = [s.strip() for s in re.split(r"(?<=[.!?])\\s+", reply) if s.strip()]
if not sentences:
return reply
lowered_nodes = [node.lower() for node in unknown_nodes]
lowered_namespaces = [ns.lower() for ns in unknown_namespaces]
kept: list[str] = []
for sent in sentences:
lower = sent.lower()
if lowered_nodes and any(node in lower for node in lowered_nodes):
continue
if lowered_namespaces and any(f"namespace {ns}" in lower for ns in lowered_namespaces):
continue
kept.append(sent)
cleaned = " ".join(kept).strip()
return cleaned or reply
def _lexicon_context(summary: dict[str, Any]) -> str: def _lexicon_context(summary: dict[str, Any]) -> str:
if not isinstance(summary, dict): if not isinstance(summary, dict):
return "" return ""

View File

@ -207,6 +207,17 @@ DEDUP_PROMPT = (
"Return only the cleaned answer." "Return only the cleaned answer."
) )
# System prompt for the "fact_select" LLM call: extends the shared cluster
# persona (CLUSTER_SYSTEM) with an instruction to pick question-relevant fact
# lines and to answer with JSON only.
FACT_SELECT_SYSTEM = (
    CLUSTER_SYSTEM
    + " Select the most relevant fact lines for the question. "
    + "Return JSON only."
)
# User prompt template for fact selection. `{max_lines}` is filled via
# str.format by the caller; the model must answer with a JSON object of the
# form {"lines": [...]} drawn verbatim from the supplied candidates.
FACT_SELECT_PROMPT = (
    "Pick up to {max_lines} lines from Candidates that best answer the question. "
    "Return JSON with field: lines (list of strings). If none apply, return {\"lines\": []}."
)
SELECT_CLAIMS_PROMPT = ( SELECT_CLAIMS_PROMPT = (
"Select relevant claim ids for the follow-up. " "Select relevant claim ids for the follow-up. "
"Return JSON with field: claim_ids (list)." "Return JSON with field: claim_ids (list)."