atlasbot: force metric key chunks via llm
This commit is contained in:
parent
c6a3eda478
commit
ecb0776b3e
@ -312,8 +312,23 @@ class AnswerEngine:
|
||||
if observer:
|
||||
observer("retrieve", "scoring chunks")
|
||||
chunks = _chunk_lines(summary_lines, plan.chunk_lines)
|
||||
metric_keys: list[str] = []
|
||||
must_chunk_ids: list[str] = []
|
||||
if (classify.get("question_type") in {"metric", "diagnostic"} or force_metric) and summary_lines:
|
||||
metric_ctx = {
|
||||
"question": normalized,
|
||||
"sub_questions": sub_questions,
|
||||
"keywords": keywords,
|
||||
"summary_lines": summary_lines,
|
||||
}
|
||||
metric_keys, must_chunk_ids = await _select_metric_chunks(
|
||||
call_llm,
|
||||
metric_ctx,
|
||||
chunks,
|
||||
plan,
|
||||
)
|
||||
scored = await _score_chunks(call_llm, chunks, normalized, sub_questions, plan)
|
||||
selected = _select_chunks(chunks, scored, plan, keyword_tokens)
|
||||
selected = _select_chunks(chunks, scored, plan, keyword_tokens, must_chunk_ids)
|
||||
fact_candidates = _collect_fact_candidates(selected, limit=plan.max_subquestions * 12)
|
||||
key_facts = await _select_fact_lines(
|
||||
call_llm,
|
||||
@ -422,6 +437,8 @@ class AnswerEngine:
|
||||
{
|
||||
"selected_ids": [item["id"] for item in selected],
|
||||
"top_scored": scored_preview,
|
||||
"metric_keys": metric_keys,
|
||||
"forced_chunks": must_chunk_ids,
|
||||
},
|
||||
)
|
||||
facts_used = list(dict.fromkeys(key_facts)) if key_facts else list(dict.fromkeys(metric_facts))
|
||||
@ -1157,6 +1174,7 @@ def _select_chunks(
|
||||
scores: dict[str, float],
|
||||
plan: ModePlan,
|
||||
keywords: list[str] | None = None,
|
||||
must_ids: list[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
if not chunks:
|
||||
return []
|
||||
@ -1164,6 +1182,14 @@ def _select_chunks(
|
||||
selected: list[dict[str, Any]] = []
|
||||
head = chunks[0]
|
||||
selected.append(head)
|
||||
id_map = {item["id"]: item for item in chunks}
|
||||
if must_ids:
|
||||
for cid in must_ids:
|
||||
item = id_map.get(cid)
|
||||
if item and item not in selected:
|
||||
selected.append(item)
|
||||
if len(selected) >= plan.chunk_top:
|
||||
return selected
|
||||
|
||||
for item in _keyword_hits(ranked, head, keywords):
|
||||
if len(selected) >= plan.chunk_top:
|
||||
@ -1218,6 +1244,89 @@ def _summary_lines(snapshot: dict[str, Any] | None) -> list[str]:
|
||||
return [line for line in text.splitlines() if line.strip()]
|
||||
|
||||
|
||||
async def _select_metric_chunks(
|
||||
call_llm: Callable[..., Awaitable[str]],
|
||||
ctx: dict[str, Any],
|
||||
chunks: list[dict[str, Any]],
|
||||
plan: ModePlan,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
summary_lines = ctx.get("summary_lines") if isinstance(ctx, dict) else None
|
||||
if not isinstance(summary_lines, list):
|
||||
return [], []
|
||||
keys = _extract_metric_keys(summary_lines)
|
||||
if not keys or not chunks:
|
||||
return [], []
|
||||
max_keys = max(4, plan.max_subquestions * 2)
|
||||
prompt = prompts.METRIC_KEYS_PROMPT.format(available="\n".join(keys), max_keys=max_keys)
|
||||
question = ctx.get("question") if isinstance(ctx, dict) else ""
|
||||
sub_questions = ctx.get("sub_questions") if isinstance(ctx, dict) else []
|
||||
keywords = ctx.get("keywords") if isinstance(ctx, dict) else []
|
||||
raw = await call_llm(
|
||||
prompts.METRIC_KEYS_SYSTEM,
|
||||
prompt + "\nQuestion: " + str(question) + "\nSubQuestions:\n" + "\n".join([str(item) for item in sub_questions]),
|
||||
context="Keywords:\n" + ", ".join([str(item) for item in keywords if item]),
|
||||
model=plan.fast_model,
|
||||
tag="metric_keys",
|
||||
)
|
||||
selected = _parse_key_list(raw, keys, max_keys)
|
||||
if not selected:
|
||||
return [], []
|
||||
ids = _chunk_ids_for_keys(chunks, selected)
|
||||
return selected, ids
|
||||
|
||||
|
||||
def _extract_metric_keys(lines: list[str]) -> list[str]:
|
||||
keys: list[str] = []
|
||||
for line in lines:
|
||||
if ":" not in line:
|
||||
continue
|
||||
key = line.split(":", 1)[0].strip()
|
||||
if not key or " " in key:
|
||||
continue
|
||||
if key not in keys:
|
||||
keys.append(key)
|
||||
return keys
|
||||
|
||||
|
||||
def _parse_key_list(raw: str, allowed: list[str], max_keys: int) -> list[str]:
|
||||
parsed = _parse_json_block(raw, fallback={})
|
||||
if isinstance(parsed, list):
|
||||
items = parsed
|
||||
else:
|
||||
items = parsed.get("keys") if isinstance(parsed, dict) else []
|
||||
if not isinstance(items, list):
|
||||
return []
|
||||
allowed_set = set(allowed)
|
||||
out: list[str] = []
|
||||
for item in items:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
if item in allowed_set and item not in out:
|
||||
out.append(item)
|
||||
if len(out) >= max_keys:
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def _chunk_ids_for_keys(chunks: list[dict[str, Any]], keys: list[str]) -> list[str]:
|
||||
if not keys:
|
||||
return []
|
||||
ids: list[str] = []
|
||||
key_set = {f"{key}:" for key in keys}
|
||||
for chunk in chunks:
|
||||
text = str(chunk.get("text") or "")
|
||||
if not text:
|
||||
continue
|
||||
for line in text.splitlines():
|
||||
for key in key_set:
|
||||
if line.startswith(key):
|
||||
cid = chunk.get("id")
|
||||
if cid and cid not in ids:
|
||||
ids.append(cid)
|
||||
break
|
||||
return ids
|
||||
|
||||
|
||||
def _merge_fact_lines(primary: list[str], fallback: list[str]) -> list[str]:
|
||||
seen = set()
|
||||
merged: list[str] = []
|
||||
|
||||
@ -69,6 +69,20 @@ METRIC_PREFIX_PROMPT = (
|
||||
"Only use values from AvailablePrefixes."
|
||||
)
|
||||
|
||||
METRIC_KEYS_SYSTEM = (
|
||||
CLUSTER_SYSTEM
|
||||
+ " Select the metric keys required to answer the question. "
|
||||
+ "Return JSON only."
|
||||
)
|
||||
|
||||
METRIC_KEYS_PROMPT = (
|
||||
"AvailableKeys:\n{available}\n\n"
|
||||
"Return JSON with field: keys (list). "
|
||||
"Choose only keys needed to answer the question. "
|
||||
"If none apply, return an empty list. "
|
||||
"Limit to at most {max_keys} keys."
|
||||
)
|
||||
|
||||
TOOL_SYSTEM = (
|
||||
CLUSTER_SYSTEM
|
||||
+ " Suggest a safe, read-only command that could refine the answer. "
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user