atlasbot: force metric key chunks via llm
This commit is contained in:
parent
c6a3eda478
commit
ecb0776b3e
@ -312,8 +312,23 @@ class AnswerEngine:
|
|||||||
if observer:
|
if observer:
|
||||||
observer("retrieve", "scoring chunks")
|
observer("retrieve", "scoring chunks")
|
||||||
chunks = _chunk_lines(summary_lines, plan.chunk_lines)
|
chunks = _chunk_lines(summary_lines, plan.chunk_lines)
|
||||||
|
metric_keys: list[str] = []
|
||||||
|
must_chunk_ids: list[str] = []
|
||||||
|
if (classify.get("question_type") in {"metric", "diagnostic"} or force_metric) and summary_lines:
|
||||||
|
metric_ctx = {
|
||||||
|
"question": normalized,
|
||||||
|
"sub_questions": sub_questions,
|
||||||
|
"keywords": keywords,
|
||||||
|
"summary_lines": summary_lines,
|
||||||
|
}
|
||||||
|
metric_keys, must_chunk_ids = await _select_metric_chunks(
|
||||||
|
call_llm,
|
||||||
|
metric_ctx,
|
||||||
|
chunks,
|
||||||
|
plan,
|
||||||
|
)
|
||||||
scored = await _score_chunks(call_llm, chunks, normalized, sub_questions, plan)
|
scored = await _score_chunks(call_llm, chunks, normalized, sub_questions, plan)
|
||||||
selected = _select_chunks(chunks, scored, plan, keyword_tokens)
|
selected = _select_chunks(chunks, scored, plan, keyword_tokens, must_chunk_ids)
|
||||||
fact_candidates = _collect_fact_candidates(selected, limit=plan.max_subquestions * 12)
|
fact_candidates = _collect_fact_candidates(selected, limit=plan.max_subquestions * 12)
|
||||||
key_facts = await _select_fact_lines(
|
key_facts = await _select_fact_lines(
|
||||||
call_llm,
|
call_llm,
|
||||||
@ -422,6 +437,8 @@ class AnswerEngine:
|
|||||||
{
|
{
|
||||||
"selected_ids": [item["id"] for item in selected],
|
"selected_ids": [item["id"] for item in selected],
|
||||||
"top_scored": scored_preview,
|
"top_scored": scored_preview,
|
||||||
|
"metric_keys": metric_keys,
|
||||||
|
"forced_chunks": must_chunk_ids,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
facts_used = list(dict.fromkeys(key_facts)) if key_facts else list(dict.fromkeys(metric_facts))
|
facts_used = list(dict.fromkeys(key_facts)) if key_facts else list(dict.fromkeys(metric_facts))
|
||||||
@ -1157,6 +1174,7 @@ def _select_chunks(
|
|||||||
scores: dict[str, float],
|
scores: dict[str, float],
|
||||||
plan: ModePlan,
|
plan: ModePlan,
|
||||||
keywords: list[str] | None = None,
|
keywords: list[str] | None = None,
|
||||||
|
must_ids: list[str] | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return []
|
return []
|
||||||
@ -1164,6 +1182,14 @@ def _select_chunks(
|
|||||||
selected: list[dict[str, Any]] = []
|
selected: list[dict[str, Any]] = []
|
||||||
head = chunks[0]
|
head = chunks[0]
|
||||||
selected.append(head)
|
selected.append(head)
|
||||||
|
id_map = {item["id"]: item for item in chunks}
|
||||||
|
if must_ids:
|
||||||
|
for cid in must_ids:
|
||||||
|
item = id_map.get(cid)
|
||||||
|
if item and item not in selected:
|
||||||
|
selected.append(item)
|
||||||
|
if len(selected) >= plan.chunk_top:
|
||||||
|
return selected
|
||||||
|
|
||||||
for item in _keyword_hits(ranked, head, keywords):
|
for item in _keyword_hits(ranked, head, keywords):
|
||||||
if len(selected) >= plan.chunk_top:
|
if len(selected) >= plan.chunk_top:
|
||||||
@ -1218,6 +1244,89 @@ def _summary_lines(snapshot: dict[str, Any] | None) -> list[str]:
|
|||||||
return [line for line in text.splitlines() if line.strip()]
|
return [line for line in text.splitlines() if line.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
async def _select_metric_chunks(
    call_llm: Callable[..., Awaitable[str]],
    ctx: dict[str, Any],
    chunks: list[dict[str, Any]],
    plan: ModePlan,
) -> tuple[list[str], list[str]]:
    """Ask the fast LLM which metric keys matter and map them to chunk ids.

    Extracts candidate ``key:`` prefixes from the context's summary lines,
    asks the model to pick the relevant ones, and resolves the picks to the
    ids of chunks that contain matching lines.

    Returns ``(selected_keys, chunk_ids)``; both lists are empty when the
    context has no usable summary lines, no keys can be extracted, there are
    no chunks, or the model selects nothing.
    """
    summary_lines = ctx.get("summary_lines") if isinstance(ctx, dict) else None
    if not isinstance(summary_lines, list):
        return [], []
    keys = _extract_metric_keys(summary_lines)
    if not keys or not chunks:
        return [], []
    max_keys = max(4, plan.max_subquestions * 2)
    prompt = prompts.METRIC_KEYS_PROMPT.format(available="\n".join(keys), max_keys=max_keys)
    # ctx is known to be a dict here (the summary_lines guard returned
    # otherwise), so guard against absent/None fields instead of re-checking
    # the type; previously a missing "sub_questions"/"keywords" key yielded
    # None and crashed the joins below, and a missing "question" injected the
    # literal string "None" into the prompt.
    question = ctx.get("question") or ""
    sub_questions = ctx.get("sub_questions") or []
    keywords = ctx.get("keywords") or []
    raw = await call_llm(
        prompts.METRIC_KEYS_SYSTEM,
        prompt
        + "\nQuestion: "
        + str(question)
        + "\nSubQuestions:\n"
        + "\n".join(str(item) for item in sub_questions),
        context="Keywords:\n" + ", ".join(str(item) for item in keywords if item),
        model=plan.fast_model,
        tag="metric_keys",
    )
    selected = _parse_key_list(raw, keys, max_keys)
    if not selected:
        return [], []
    ids = _chunk_ids_for_keys(chunks, selected)
    return selected, ids
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_metric_keys(lines: list[str]) -> list[str]:
|
||||||
|
keys: list[str] = []
|
||||||
|
for line in lines:
|
||||||
|
if ":" not in line:
|
||||||
|
continue
|
||||||
|
key = line.split(":", 1)[0].strip()
|
||||||
|
if not key or " " in key:
|
||||||
|
continue
|
||||||
|
if key not in keys:
|
||||||
|
keys.append(key)
|
||||||
|
return keys
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_key_list(raw: str, allowed: list[str], max_keys: int) -> list[str]:
    """Validate an LLM key-selection response against the allowed keys.

    Accepts either a bare JSON list or an object with a ``keys`` field.
    Returns up to ``max_keys`` unique strings in response order, keeping only
    values present in ``allowed``; any malformed response yields ``[]``.
    """
    parsed = _parse_json_block(raw, fallback={})
    if isinstance(parsed, list):
        candidates = parsed
    elif isinstance(parsed, dict):
        candidates = parsed.get("keys")
    else:
        candidates = []
    if not isinstance(candidates, list):
        return []
    permitted = set(allowed)
    chosen: list[str] = []
    for candidate in candidates:
        if not isinstance(candidate, str):
            continue
        if candidate in permitted and candidate not in chosen:
            chosen.append(candidate)
        if len(chosen) >= max_keys:
            break
    return chosen
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk_ids_for_keys(chunks: list[dict[str, Any]], keys: list[str]) -> list[str]:
|
||||||
|
if not keys:
|
||||||
|
return []
|
||||||
|
ids: list[str] = []
|
||||||
|
key_set = {f"{key}:" for key in keys}
|
||||||
|
for chunk in chunks:
|
||||||
|
text = str(chunk.get("text") or "")
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
for line in text.splitlines():
|
||||||
|
for key in key_set:
|
||||||
|
if line.startswith(key):
|
||||||
|
cid = chunk.get("id")
|
||||||
|
if cid and cid not in ids:
|
||||||
|
ids.append(cid)
|
||||||
|
break
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
def _merge_fact_lines(primary: list[str], fallback: list[str]) -> list[str]:
|
def _merge_fact_lines(primary: list[str], fallback: list[str]) -> list[str]:
|
||||||
seen = set()
|
seen = set()
|
||||||
merged: list[str] = []
|
merged: list[str] = []
|
||||||
|
|||||||
@ -69,6 +69,20 @@ METRIC_PREFIX_PROMPT = (
|
|||||||
"Only use values from AvailablePrefixes."
|
"Only use values from AvailablePrefixes."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# System prompt for the metric-key selection call: extends the shared
# CLUSTER_SYSTEM persona and constrains the model to JSON-only output.
METRIC_KEYS_SYSTEM = (
    CLUSTER_SYSTEM
    + " Select the metric keys required to answer the question. "
    + "Return JSON only."
)

# User-prompt template for metric-key selection; ``available`` is a
# newline-joined list of candidate keys and ``max_keys`` caps how many keys
# the model may return in the ``keys`` list.
METRIC_KEYS_PROMPT = (
    "AvailableKeys:\n{available}\n\n"
    "Return JSON with field: keys (list). "
    "Choose only keys needed to answer the question. "
    "If none apply, return an empty list. "
    "Limit to at most {max_keys} keys."
)
|
||||||
|
|
||||||
TOOL_SYSTEM = (
|
TOOL_SYSTEM = (
|
||||||
CLUSTER_SYSTEM
|
CLUSTER_SYSTEM
|
||||||
+ " Suggest a safe, read-only command that could refine the answer. "
|
+ " Suggest a safe, read-only command that could refine the answer. "
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user