diff --git a/atlasbot/engine/answerer.py b/atlasbot/engine/answerer.py index e974975..a5c62f0 100644 --- a/atlasbot/engine/answerer.py +++ b/atlasbot/engine/answerer.py @@ -320,6 +320,7 @@ class AnswerEngine: "question": normalized, "sub_questions": sub_questions, "keywords": keywords, + "keyword_tokens": keyword_tokens, "summary_lines": summary_lines, } metric_keys, must_chunk_ids = await _select_metric_chunks( @@ -405,6 +406,10 @@ class AnswerEngine: ) if not metric_facts and fallback_candidates: metric_facts = fallback_candidates[: max(2, plan.max_subquestions)] + if metric_keys: + key_lines = _lines_for_metric_keys(summary_lines, metric_keys, max_lines=plan.max_subquestions * 3) + if key_lines: + metric_facts = _merge_fact_lines(key_lines, metric_facts) if metric_facts: metric_facts = _ensure_token_coverage( metric_facts, @@ -1288,10 +1293,15 @@ async def _select_metric_chunks( if not keys or not chunks: return [], [] max_keys = max(4, plan.max_subquestions * 2) - prompt = prompts.METRIC_KEYS_PROMPT.format(available="\n".join(keys), max_keys=max_keys) question = ctx.get("question") if isinstance(ctx, dict) else "" sub_questions = ctx.get("sub_questions") if isinstance(ctx, dict) else [] keywords = ctx.get("keywords") if isinstance(ctx, dict) else [] + keyword_tokens = ctx.get("keyword_tokens") if isinstance(ctx, dict) else [] + token_set = set([str(token) for token in keyword_tokens if token]) + token_set |= set(_extract_keywords(str(question), str(question), sub_questions=sub_questions, keywords=keywords)) + candidate_keys = _filter_metric_keys(keys, token_set) + available_keys = candidate_keys or keys + prompt = prompts.METRIC_KEYS_PROMPT.format(available="\n".join(available_keys), max_keys=max_keys) raw = await call_llm( prompts.METRIC_KEYS_SYSTEM, prompt + "\nQuestion: " + str(question) + "\nSubQuestions:\n" + "\n".join([str(item) for item in sub_questions]), @@ -1299,7 +1309,9 @@ async def _select_metric_chunks( model=plan.fast_model, tag="metric_keys", ) - selected = _parse_key_list(raw, keys, max_keys) + selected = _parse_key_list(raw, available_keys, max_keys) + if not selected and candidate_keys: + selected = candidate_keys[:max_keys] if not selected: return [], [] ids = _chunk_ids_for_keys(chunks, selected) @@ -1358,6 +1370,37 @@ def _chunk_ids_for_keys(chunks: list[dict[str, Any]], keys: list[str]) -> list[s return ids +def _filter_metric_keys(keys: list[str], tokens: set[str]) -> list[str]: + if not keys or not tokens: + return [] + lowered_tokens = {token.lower() for token in tokens if token and len(token) >= TOKEN_MIN_LEN} + ranked: list[tuple[int, str]] = [] + for key in keys: + parts = [part for part in re.split(r"[^a-zA-Z0-9_-]+", key.lower()) if part] + if not parts: + continue + hits = len(set(parts) & lowered_tokens) + if hits: + ranked.append((hits, key)) + ranked.sort(key=lambda item: (-item[0], item[1])) + return [item[1] for item in ranked] + + +def _lines_for_metric_keys(lines: list[str], keys: list[str], max_lines: int = 0) -> list[str]: + if not lines or not keys: + return [] + prefixes = {f"{key}:" for key in keys} + selected: list[str] = [] + for line in lines: + for prefix in prefixes: + if line.startswith(prefix): + selected.append(line) + break + if max_lines and len(selected) >= max_lines: + break + return selected + + def _merge_fact_lines(primary: list[str], fallback: list[str]) -> list[str]: seen = set() merged: list[str] = []