atlasbot: force metric key chunks via llm

Brad Stein 2026-02-03 21:09:15 -03:00
parent c6a3eda478
commit ecb0776b3e
2 changed files with 124 additions and 1 deletion

@@ -312,8 +312,23 @@ class AnswerEngine:
        if observer:
            observer("retrieve", "scoring chunks")
        chunks = _chunk_lines(summary_lines, plan.chunk_lines)
        metric_keys: list[str] = []
        must_chunk_ids: list[str] = []
        if (classify.get("question_type") in {"metric", "diagnostic"} or force_metric) and summary_lines:
            metric_ctx = {
                "question": normalized,
                "sub_questions": sub_questions,
                "keywords": keywords,
                "summary_lines": summary_lines,
            }
            metric_keys, must_chunk_ids = await _select_metric_chunks(
                call_llm,
                metric_ctx,
                chunks,
                plan,
            )
        scored = await _score_chunks(call_llm, chunks, normalized, sub_questions, plan)
        selected = _select_chunks(chunks, scored, plan, keyword_tokens, must_chunk_ids)
        fact_candidates = _collect_fact_candidates(selected, limit=plan.max_subquestions * 12)
        key_facts = await _select_fact_lines(
            call_llm,
@@ -422,6 +437,8 @@ class AnswerEngine:
                {
                    "selected_ids": [item["id"] for item in selected],
                    "top_scored": scored_preview,
                    "metric_keys": metric_keys,
                    "forced_chunks": must_chunk_ids,
                },
            )
        facts_used = list(dict.fromkeys(key_facts)) if key_facts else list(dict.fromkeys(metric_facts))
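For illustration, the retrieval debug entry now records both what the fast model picked and what that forced into selection. A hypothetical payload (the values and the shape of scored_preview are assumptions, not taken from this diff):

    # Hypothetical debug payload after this change; only "metric_keys"
    # and "forced_chunks" are new fields. Values invented for illustration.
    debug_payload = {
        "selected_ids": ["chunk-0", "chunk-3"],
        "top_scored": [{"id": "chunk-3", "score": 0.92}],  # assumed shape
        "metric_keys": ["cpu_usage"],   # keys chosen by the fast model
        "forced_chunks": ["chunk-3"],   # ids force-included in selection
    }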
@@ -1157,6 +1174,7 @@ def _select_chunks(
    scores: dict[str, float],
    plan: ModePlan,
    keywords: list[str] | None = None,
    must_ids: list[str] | None = None,
) -> list[dict[str, Any]]:
    if not chunks:
        return []
@@ -1164,6 +1182,14 @@ def _select_chunks(
    selected: list[dict[str, Any]] = []
    head = chunks[0]
    selected.append(head)
    id_map = {item["id"]: item for item in chunks}
    if must_ids:
        for cid in must_ids:
            item = id_map.get(cid)
            if item and item not in selected:
                selected.append(item)
        if len(selected) >= plan.chunk_top:
            return selected
    for item in _keyword_hits(ranked, head, keywords):
        if len(selected) >= plan.chunk_top:
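A self-contained sketch of the forcing step above, assuming chunks are dicts keyed by "id" (toy data, not from the repo):

    # Toy demonstration of must_ids handling in _select_chunks: the head
    # chunk always stays, known forced ids are appended, unknown ids are
    # silently skipped, and duplicates are not added twice.
    chunks = [{"id": "c0"}, {"id": "c1"}, {"id": "c2"}]
    must_ids = ["c2", "c9"]  # "c9" does not exist and is ignored

    selected = [chunks[0]]
    id_map = {item["id"]: item for item in chunks}
    for cid in must_ids:
        item = id_map.get(cid)
        if item and item not in selected:
            selected.append(item)

    assert [c["id"] for c in selected] == ["c0", "c2"]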
@@ -1218,6 +1244,89 @@ def _summary_lines(snapshot: dict[str, Any] | None) -> list[str]:
    return [line for line in text.splitlines() if line.strip()]


async def _select_metric_chunks(
    call_llm: Callable[..., Awaitable[str]],
    ctx: dict[str, Any],
    chunks: list[dict[str, Any]],
    plan: ModePlan,
) -> tuple[list[str], list[str]]:
    summary_lines = ctx.get("summary_lines") if isinstance(ctx, dict) else None
    if not isinstance(summary_lines, list):
        return [], []
    keys = _extract_metric_keys(summary_lines)
    if not keys or not chunks:
        return [], []
    max_keys = max(4, plan.max_subquestions * 2)
    prompt = prompts.METRIC_KEYS_PROMPT.format(available="\n".join(keys), max_keys=max_keys)
    question = ctx.get("question") if isinstance(ctx, dict) else ""
    sub_questions = ctx.get("sub_questions") if isinstance(ctx, dict) else []
    keywords = ctx.get("keywords") if isinstance(ctx, dict) else []
    raw = await call_llm(
        prompts.METRIC_KEYS_SYSTEM,
        prompt + "\nQuestion: " + str(question) + "\nSubQuestions:\n" + "\n".join([str(item) for item in sub_questions]),
        context="Keywords:\n" + ", ".join([str(item) for item in keywords if item]),
        model=plan.fast_model,
        tag="metric_keys",
    )
    selected = _parse_key_list(raw, keys, max_keys)
    if not selected:
        return [], []
    ids = _chunk_ids_for_keys(chunks, selected)
    return selected, ids


def _extract_metric_keys(lines: list[str]) -> list[str]:
    keys: list[str] = []
    for line in lines:
        if ":" not in line:
            continue
        key = line.split(":", 1)[0].strip()
        if not key or " " in key:
            continue
        if key not in keys:
            keys.append(key)
    return keys


def _parse_key_list(raw: str, allowed: list[str], max_keys: int) -> list[str]:
    parsed = _parse_json_block(raw, fallback={})
    if isinstance(parsed, list):
        items = parsed
    else:
        items = parsed.get("keys") if isinstance(parsed, dict) else []
    if not isinstance(items, list):
        return []
    allowed_set = set(allowed)
    out: list[str] = []
    for item in items:
        if not isinstance(item, str):
            continue
        if item in allowed_set and item not in out:
            out.append(item)
        if len(out) >= max_keys:
            break
    return out


def _chunk_ids_for_keys(chunks: list[dict[str, Any]], keys: list[str]) -> list[str]:
    if not keys:
        return []
    ids: list[str] = []
    key_set = {f"{key}:" for key in keys}
    for chunk in chunks:
        text = str(chunk.get("text") or "")
        if not text:
            continue
        for line in text.splitlines():
            for key in key_set:
                if line.startswith(key):
                    cid = chunk.get("id")
                    if cid and cid not in ids:
                        ids.append(cid)
                    break
    return ids


def _merge_fact_lines(primary: list[str], fallback: list[str]) -> list[str]:
    seen = set()
    merged: list[str] = []
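End to end, the new helpers turn "key: value" summary lines into forced chunk ids. A runnable sketch with toy data (_parse_json_block is not part of this diff, so plain json.loads stands in for it here):

    import json

    summary_lines = [
        "cpu_usage: 0.82",
        "mem_free_bytes: 1073741824",
        "node status: Ready",  # key contains a space -> skipped
        "uptime 3d",           # no colon -> skipped
    ]
    # _extract_metric_keys keeps colon-delimited, space-free keys:
    keys = ["cpu_usage", "mem_free_bytes"]

    raw = '{"keys": ["cpu_usage"]}'  # hypothetical fast-model reply
    selected = [k for k in json.loads(raw)["keys"] if k in keys]

    chunks = [
        {"id": "chunk-1", "text": "cpu_usage: 0.82\nload1: 1.4"},
        {"id": "chunk-2", "text": "mem_free_bytes: 1073741824"},
    ]
    # _chunk_ids_for_keys matches lines that start with "key:":
    ids = [
        chunk["id"]
        for chunk in chunks
        if any(
            line.startswith(f"{key}:")
            for line in chunk["text"].splitlines()
            for key in selected
        )
    ]
    assert ids == ["chunk-1"]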

@@ -69,6 +69,20 @@ METRIC_PREFIX_PROMPT = (
    "Only use values from AvailablePrefixes."
)

METRIC_KEYS_SYSTEM = (
    CLUSTER_SYSTEM
    + " Select the metric keys required to answer the question. "
    + "Return JSON only."
)

METRIC_KEYS_PROMPT = (
    "AvailableKeys:\n{available}\n\n"
    "Return JSON with field: keys (list). "
    "Choose only keys needed to answer the question. "
    "If none apply, return an empty list. "
    "Limit to at most {max_keys} keys."
)

TOOL_SYSTEM = (
    CLUSTER_SYSTEM
    + " Suggest a safe, read-only command that could refine the answer. "