atlasbot: filter metric keys by tokens

commit 0b6ffc9521
parent 8ec369fe10
@@ -320,6 +320,7 @@ class AnswerEngine:
             "question": normalized,
             "sub_questions": sub_questions,
             "keywords": keywords,
+            "keyword_tokens": keyword_tokens,
             "summary_lines": summary_lines,
         }
         metric_keys, must_chunk_ids = await _select_metric_chunks(
@@ -405,6 +406,10 @@ class AnswerEngine:
         )
         if not metric_facts and fallback_candidates:
             metric_facts = fallback_candidates[: max(2, plan.max_subquestions)]
+        if metric_keys:
+            key_lines = _lines_for_metric_keys(summary_lines, metric_keys, max_lines=plan.max_subquestions * 3)
+            if key_lines:
+                metric_facts = _merge_fact_lines(key_lines, metric_facts)
         if metric_facts:
             metric_facts = _ensure_token_coverage(
                 metric_facts,
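
The merge step above gives summary lines selected via the chosen metric keys precedence over the previously gathered facts. A minimal sketch of that intent with invented sample data; the real _merge_fact_lines body is outside this diff (only its opening seen/merged setup appears further down), so the dedup-preserving merge below is an assumption, not the repository implementation:

# Stand-in for _merge_fact_lines (assumed behaviour: keep order, drop duplicates,
# key-derived lines first). Sample strings are made up for illustration.
def merge(primary: list[str], fallback: list[str]) -> list[str]:
    seen: set[str] = set()
    merged: list[str] = []
    for line in [*primary, *fallback]:
        if line not in seen:
            seen.add(line)
            merged.append(line)
    return merged

key_lines = ["cpu.usage.total: 73%"]
metric_facts = ["memory.rss.bytes: 512 MiB", "cpu.usage.total: 73%"]
print(merge(key_lines, metric_facts))
# ['cpu.usage.total: 73%', 'memory.rss.bytes: 512 MiB']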
@@ -1288,10 +1293,15 @@ async def _select_metric_chunks(
     if not keys or not chunks:
         return [], []
     max_keys = max(4, plan.max_subquestions * 2)
-    prompt = prompts.METRIC_KEYS_PROMPT.format(available="\n".join(keys), max_keys=max_keys)
     question = ctx.get("question") if isinstance(ctx, dict) else ""
     sub_questions = ctx.get("sub_questions") if isinstance(ctx, dict) else []
     keywords = ctx.get("keywords") if isinstance(ctx, dict) else []
+    keyword_tokens = ctx.get("keyword_tokens") if isinstance(ctx, dict) else []
+    token_set = set([str(token) for token in keyword_tokens if token])
+    token_set |= set(_extract_keywords(str(question), str(question), sub_questions=sub_questions, keywords=keywords))
+    candidate_keys = _filter_metric_keys(keys, token_set)
+    available_keys = candidate_keys or keys
+    prompt = prompts.METRIC_KEYS_PROMPT.format(available="\n".join(available_keys), max_keys=max_keys)
     raw = await call_llm(
         prompts.METRIC_KEYS_SYSTEM,
         prompt + "\nQuestion: " + str(question) + "\nSubQuestions:\n" + "\n".join([str(item) for item in sub_questions]),
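
Whether candidate_keys ends up non-empty depends on how metric keys split into parts. A quick, self-contained sketch of the splitting rule used by the _filter_metric_keys helper added later in this commit (the sample keys are invented):

import re

# Anything that is not alphanumeric, underscore, or hyphen separates key parts,
# so dotted keys split while snake_case keys stay whole.
for key in ["cpu.usage.total", "gc_pause_ms", "http.server.requests"]:
    parts = [part for part in re.split(r"[^a-zA-Z0-9_-]+", key.lower()) if part]
    print(key, "->", parts)
# cpu.usage.total -> ['cpu', 'usage', 'total']
# gc_pause_ms -> ['gc_pause_ms']
# http.server.requests -> ['http', 'server', 'requests']

If no key shares a part with the token set, candidate_keys stays empty and the prompt falls back to the full key list via available_keys = candidate_keys or keys.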
@@ -1299,7 +1309,9 @@ async def _select_metric_chunks(
         model=plan.fast_model,
         tag="metric_keys",
     )
-    selected = _parse_key_list(raw, keys, max_keys)
+    selected = _parse_key_list(raw, available_keys, max_keys)
+    if not selected and candidate_keys:
+        selected = candidate_keys[:max_keys]
     if not selected:
         return [], []
     ids = _chunk_ids_for_keys(chunks, selected)
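
With the new fallback, a model reply that parses to nothing no longer empties the selection when token matching already produced candidates. A small sketch of that path with made-up values (max_keys mirrors the max(4, plan.max_subquestions * 2) expression above):

max_subquestions = 4                     # made-up plan value
max_keys = max(4, max_subquestions * 2)  # 8
candidate_keys = ["cpu.usage.total", "memory.rss.bytes"]
selected: list[str] = []                 # e.g. the LLM reply could not be parsed
if not selected and candidate_keys:
    selected = candidate_keys[:max_keys]
print(selected)
# ['cpu.usage.total', 'memory.rss.bytes']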
@@ -1358,6 +1370,37 @@ def _chunk_ids_for_keys(chunks: list[dict[str, Any]], keys: list[str]) -> list[str]:
     return ids
 
 
+def _filter_metric_keys(keys: list[str], tokens: set[str]) -> list[str]:
+    if not keys or not tokens:
+        return []
+    lowered_tokens = {token.lower() for token in tokens if token and len(token) >= TOKEN_MIN_LEN}
+    ranked: list[tuple[int, str]] = []
+    for key in keys:
+        parts = [part for part in re.split(r"[^a-zA-Z0-9_-]+", key.lower()) if part]
+        if not parts:
+            continue
+        hits = len(set(parts) & lowered_tokens)
+        if hits:
+            ranked.append((hits, key))
+    ranked.sort(key=lambda item: (-item[0], item[1]))
+    return [item[1] for item in ranked]
+
+
+def _lines_for_metric_keys(lines: list[str], keys: list[str], max_lines: int = 0) -> list[str]:
+    if not lines or not keys:
+        return []
+    prefixes = {f"{key}:" for key in keys}
+    selected: list[str] = []
+    for line in lines:
+        for prefix in prefixes:
+            if line.startswith(prefix):
+                selected.append(line)
+                break
+        if max_lines and len(selected) >= max_lines:
+            break
+    return selected
+
+
 def _merge_fact_lines(primary: list[str], fallback: list[str]) -> list[str]:
     seen = set()
     merged: list[str] = []
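
Taken together, the two helpers narrow the key list the model sees and then pull the matching summary lines back into the fact set. A usage sketch with invented sample data; the import path is hypothetical (the helpers ship in this commit, but the module name is not shown here), and TOKEN_MIN_LEN is assumed to be a small constant such as 3:

from atlasbot.answer_engine import _filter_metric_keys, _lines_for_metric_keys  # hypothetical path

keys = ["cpu.usage.total", "memory.rss.bytes", "disk.io.read_bytes"]
tokens = {"cpu", "usage", "memory"}

# cpu.usage.total shares two parts with the token set, memory.rss.bytes one,
# disk.io.read_bytes none (so it is dropped from the ranking).
print(_filter_metric_keys(keys, tokens))
# ['cpu.usage.total', 'memory.rss.bytes']

# Summary lines are assumed to be "key: value" strings; only lines whose prefix
# matches a selected key are returned, capped by max_lines.
summary_lines = [
    "cpu.usage.total: 73%",
    "memory.rss.bytes: 512 MiB",
    "disk.io.read_bytes: 10 MB/s",
]
print(_lines_for_metric_keys(summary_lines, ["cpu.usage.total"], max_lines=2))
# ['cpu.usage.total: 73%']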