atlasbot: filter metric keys by tokens

This commit is contained in:
Brad Stein 2026-02-03 21:30:37 -03:00
parent 8ec369fe10
commit 0b6ffc9521

View File

@ -320,6 +320,7 @@ class AnswerEngine:
"question": normalized, "question": normalized,
"sub_questions": sub_questions, "sub_questions": sub_questions,
"keywords": keywords, "keywords": keywords,
"keyword_tokens": keyword_tokens,
"summary_lines": summary_lines, "summary_lines": summary_lines,
} }
metric_keys, must_chunk_ids = await _select_metric_chunks( metric_keys, must_chunk_ids = await _select_metric_chunks(
@ -405,6 +406,10 @@ class AnswerEngine:
) )
if not metric_facts and fallback_candidates: if not metric_facts and fallback_candidates:
metric_facts = fallback_candidates[: max(2, plan.max_subquestions)] metric_facts = fallback_candidates[: max(2, plan.max_subquestions)]
if metric_keys:
key_lines = _lines_for_metric_keys(summary_lines, metric_keys, max_lines=plan.max_subquestions * 3)
if key_lines:
metric_facts = _merge_fact_lines(key_lines, metric_facts)
if metric_facts: if metric_facts:
metric_facts = _ensure_token_coverage( metric_facts = _ensure_token_coverage(
metric_facts, metric_facts,
@ -1288,10 +1293,15 @@ async def _select_metric_chunks(
if not keys or not chunks: if not keys or not chunks:
return [], [] return [], []
max_keys = max(4, plan.max_subquestions * 2) max_keys = max(4, plan.max_subquestions * 2)
prompt = prompts.METRIC_KEYS_PROMPT.format(available="\n".join(keys), max_keys=max_keys)
question = ctx.get("question") if isinstance(ctx, dict) else "" question = ctx.get("question") if isinstance(ctx, dict) else ""
sub_questions = ctx.get("sub_questions") if isinstance(ctx, dict) else [] sub_questions = ctx.get("sub_questions") if isinstance(ctx, dict) else []
keywords = ctx.get("keywords") if isinstance(ctx, dict) else [] keywords = ctx.get("keywords") if isinstance(ctx, dict) else []
keyword_tokens = ctx.get("keyword_tokens") if isinstance(ctx, dict) else []
token_set = set([str(token) for token in keyword_tokens if token])
token_set |= set(_extract_keywords(str(question), str(question), sub_questions=sub_questions, keywords=keywords))
candidate_keys = _filter_metric_keys(keys, token_set)
available_keys = candidate_keys or keys
prompt = prompts.METRIC_KEYS_PROMPT.format(available="\n".join(available_keys), max_keys=max_keys)
raw = await call_llm( raw = await call_llm(
prompts.METRIC_KEYS_SYSTEM, prompts.METRIC_KEYS_SYSTEM,
prompt + "\nQuestion: " + str(question) + "\nSubQuestions:\n" + "\n".join([str(item) for item in sub_questions]), prompt + "\nQuestion: " + str(question) + "\nSubQuestions:\n" + "\n".join([str(item) for item in sub_questions]),
@ -1299,7 +1309,9 @@ async def _select_metric_chunks(
model=plan.fast_model, model=plan.fast_model,
tag="metric_keys", tag="metric_keys",
) )
selected = _parse_key_list(raw, keys, max_keys) selected = _parse_key_list(raw, available_keys, max_keys)
if not selected and candidate_keys:
selected = candidate_keys[:max_keys]
if not selected: if not selected:
return [], [] return [], []
ids = _chunk_ids_for_keys(chunks, selected) ids = _chunk_ids_for_keys(chunks, selected)
@ -1358,6 +1370,37 @@ def _chunk_ids_for_keys(chunks: list[dict[str, Any]], keys: list[str]) -> list[s
return ids return ids
def _filter_metric_keys(keys: list[str], tokens: set[str]) -> list[str]:
if not keys or not tokens:
return []
lowered_tokens = {token.lower() for token in tokens if token and len(token) >= TOKEN_MIN_LEN}
ranked: list[tuple[int, str]] = []
for key in keys:
parts = [part for part in re.split(r"[^a-zA-Z0-9_-]+", key.lower()) if part]
if not parts:
continue
hits = len(set(parts) & lowered_tokens)
if hits:
ranked.append((hits, key))
ranked.sort(key=lambda item: (-item[0], item[1]))
return [item[1] for item in ranked]
def _lines_for_metric_keys(lines: list[str], keys: list[str], max_lines: int = 0) -> list[str]:
if not lines or not keys:
return []
prefixes = {f"{key}:" for key in keys}
selected: list[str] = []
for line in lines:
for prefix in prefixes:
if line.startswith(prefix):
selected.append(line)
break
if max_lines and len(selected) >= max_lines:
break
return selected
def _merge_fact_lines(primary: list[str], fallback: list[str]) -> list[str]: def _merge_fact_lines(primary: list[str], fallback: list[str]) -> list[str]:
seen = set() seen = set()
merged: list[str] = [] merged: list[str] = []