atlasbot: reinforce metric facts
This commit is contained in:
parent
889e814b59
commit
bb51321404
@ -117,6 +117,7 @@ class AnswerEngine:
|
||||
"route",
|
||||
"decompose",
|
||||
"chunk_score",
|
||||
"fact_select",
|
||||
"synth",
|
||||
"subanswer",
|
||||
"tool",
|
||||
@ -261,6 +262,11 @@ class AnswerEngine:
|
||||
selected = _select_chunks(chunks, scored, plan, keyword_tokens)
|
||||
key_facts = _key_fact_lines(summary_lines, keyword_tokens)
|
||||
metric_facts = [line for line in key_facts if re.search(r"\d", line)]
|
||||
if classify.get("question_type") in {"metric", "diagnostic"} and not metric_facts:
|
||||
metric_candidates = _metric_candidate_lines(summary_lines, keyword_tokens)
|
||||
metric_facts = await _select_metric_facts(call_llm, normalized, metric_candidates, plan)
|
||||
if metric_facts:
|
||||
key_facts = _merge_fact_lines(metric_facts, key_facts)
|
||||
if self._settings.debug_pipeline:
|
||||
scored_preview = sorted(
|
||||
[{"id": c["id"], "score": scored.get(c["id"], 0.0), "summary": c["summary"]} for c in chunks],
|
||||
@ -394,6 +400,7 @@ class AnswerEngine:
|
||||
model=plan.model,
|
||||
tag="runbook_enforce",
|
||||
)
|
||||
reply = _strip_unknown_entities(reply, unknown_nodes, unknown_namespaces)
|
||||
|
||||
if _needs_focus_fix(normalized, reply, classify):
|
||||
if observer:
|
||||
@ -960,6 +967,64 @@ def _key_fact_lines(lines: list[str], keywords: list[str] | None, limit: int = 6
|
||||
return matches
|
||||
|
||||
|
||||
def _merge_fact_lines(primary: list[str], fallback: list[str]) -> list[str]:
|
||||
seen = set()
|
||||
merged: list[str] = []
|
||||
for line in primary + fallback:
|
||||
if line in seen:
|
||||
continue
|
||||
seen.add(line)
|
||||
merged.append(line)
|
||||
return merged
|
||||
|
||||
|
||||
def _metric_candidate_lines(lines: list[str], keywords: list[str] | None, limit: int = 40) -> list[str]:
|
||||
if not lines:
|
||||
return []
|
||||
lowered = [kw.lower() for kw in (keywords or []) if kw]
|
||||
candidates: list[str] = []
|
||||
for line in lines:
|
||||
line_lower = line.lower()
|
||||
if lowered and any(kw in line_lower for kw in lowered):
|
||||
candidates.append(line)
|
||||
elif re.search(r"\d", line):
|
||||
candidates.append(line)
|
||||
if len(candidates) >= limit:
|
||||
break
|
||||
return candidates
|
||||
|
||||
|
||||
async def _select_metric_facts(
    call_llm: Callable[..., Any],
    question: str,
    candidates: list[str],
    plan: ModePlan,
    max_lines: int = 2,
) -> list[str]:
    """Ask the fast model to choose the candidate fact lines that best answer *question*.

    Returns at most *max_lines* lines, each guaranteed to be an exact
    member of *candidates* (hallucinated lines are dropped). Returns []
    when there are no candidates or the model reply does not parse into
    a list of strings.
    """
    if not candidates:
        return []
    bullets = "\n".join(f"- {entry}" for entry in candidates)
    prompt = (
        prompts.FACT_SELECT_PROMPT.format(max_lines=max_lines)
        + f"\nQuestion: {question}\nCandidates:\n{bullets}"
    )
    raw = await call_llm(prompts.FACT_SELECT_SYSTEM, prompt, model=plan.fast_model, tag="fact_select")
    parsed = _parse_json_block(raw, fallback={})
    proposed = parsed.get("lines") if isinstance(parsed, dict) else None
    if not isinstance(proposed, list):
        return []
    # Only accept lines that are verbatim candidates, deduped, capped.
    allowed = set(candidates)
    chosen: list[str] = []
    for item in proposed:
        if isinstance(item, str) and item in allowed and item not in chosen:
            chosen.append(item)
        if len(chosen) >= max_lines:
            break
    return chosen
|
||||
|
||||
|
||||
def _metric_fact_guard(reply: str, metric_facts: list[str], keywords: list[str]) -> str:
|
||||
if not metric_facts:
|
||||
return reply
|
||||
@ -978,6 +1043,28 @@ def _metric_fact_guard(reply: str, metric_facts: list[str], keywords: list[str])
|
||||
return reply
|
||||
|
||||
|
||||
def _strip_unknown_entities(reply: str, unknown_nodes: list[str], unknown_namespaces: list[str]) -> str:
|
||||
if not reply:
|
||||
return reply
|
||||
if not unknown_nodes and not unknown_namespaces:
|
||||
return reply
|
||||
sentences = [s.strip() for s in re.split(r"(?<=[.!?])\\s+", reply) if s.strip()]
|
||||
if not sentences:
|
||||
return reply
|
||||
lowered_nodes = [node.lower() for node in unknown_nodes]
|
||||
lowered_namespaces = [ns.lower() for ns in unknown_namespaces]
|
||||
kept: list[str] = []
|
||||
for sent in sentences:
|
||||
lower = sent.lower()
|
||||
if lowered_nodes and any(node in lower for node in lowered_nodes):
|
||||
continue
|
||||
if lowered_namespaces and any(f"namespace {ns}" in lower for ns in lowered_namespaces):
|
||||
continue
|
||||
kept.append(sent)
|
||||
cleaned = " ".join(kept).strip()
|
||||
return cleaned or reply
|
||||
|
||||
|
||||
def _lexicon_context(summary: dict[str, Any]) -> str:
|
||||
if not isinstance(summary, dict):
|
||||
return ""
|
||||
|
||||
@ -207,6 +207,17 @@ DEDUP_PROMPT = (
|
||||
"Return only the cleaned answer."
|
||||
)
|
||||
|
||||
# System prompt for the fact-selection call: reuses the cluster system
# prompt and narrows it to picking question-relevant fact lines,
# constraining the reply to JSON.
FACT_SELECT_SYSTEM = CLUSTER_SYSTEM + (
    " Select the most relevant fact lines for the question. "
    "Return JSON only."
)
|
||||
|
||||
# User-prompt template for fact selection. It is consumed via
# str.format(max_lines=...), so the literal JSON example's braces must
# be doubled ("{{" / "}}") — the previous single braces made .format()
# raise KeyError('"lines"') on every call.
FACT_SELECT_PROMPT = (
    "Pick up to {max_lines} lines from Candidates that best answer the question. "
    "Return JSON with field: lines (list of strings). "
    "If none apply, return {{\"lines\": []}}."
)
|
||||
|
||||
SELECT_CLAIMS_PROMPT = (
|
||||
"Select relevant claim ids for the follow-up. "
|
||||
"Return JSON with field: claim_ids (list)."
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user