diff --git a/atlasbot/engine/answerer.py b/atlasbot/engine/answerer.py index 0cef0b9..2205043 100644 --- a/atlasbot/engine/answerer.py +++ b/atlasbot/engine/answerer.py @@ -312,8 +312,23 @@ class AnswerEngine: if observer: observer("retrieve", "scoring chunks") chunks = _chunk_lines(summary_lines, plan.chunk_lines) + metric_keys: list[str] = [] + must_chunk_ids: list[str] = [] + if (classify.get("question_type") in {"metric", "diagnostic"} or force_metric) and summary_lines: + metric_ctx = { + "question": normalized, + "sub_questions": sub_questions, + "keywords": keywords, + "summary_lines": summary_lines, + } + metric_keys, must_chunk_ids = await _select_metric_chunks( + call_llm, + metric_ctx, + chunks, + plan, + ) scored = await _score_chunks(call_llm, chunks, normalized, sub_questions, plan) - selected = _select_chunks(chunks, scored, plan, keyword_tokens) + selected = _select_chunks(chunks, scored, plan, keyword_tokens, must_chunk_ids) fact_candidates = _collect_fact_candidates(selected, limit=plan.max_subquestions * 12) key_facts = await _select_fact_lines( call_llm, @@ -422,6 +437,8 @@ class AnswerEngine: { "selected_ids": [item["id"] for item in selected], "top_scored": scored_preview, + "metric_keys": metric_keys, + "forced_chunks": must_chunk_ids, }, ) facts_used = list(dict.fromkeys(key_facts)) if key_facts else list(dict.fromkeys(metric_facts)) @@ -1157,6 +1174,7 @@ def _select_chunks( scores: dict[str, float], plan: ModePlan, keywords: list[str] | None = None, + must_ids: list[str] | None = None, ) -> list[dict[str, Any]]: if not chunks: return [] @@ -1164,6 +1182,14 @@ def _select_chunks( selected: list[dict[str, Any]] = [] head = chunks[0] selected.append(head) + id_map = {item["id"]: item for item in chunks} + if must_ids: + for cid in must_ids: + item = id_map.get(cid) + if item and item not in selected: + selected.append(item) + if len(selected) >= plan.chunk_top: + return selected for item in _keyword_hits(ranked, head, keywords): if len(selected) >= plan.chunk_top: @@ -1218,6 +1244,89 @@ def _summary_lines(snapshot: dict[str, Any] | None) -> list[str]: return [line for line in text.splitlines() if line.strip()] +async def _select_metric_chunks( + call_llm: Callable[..., Awaitable[str]], + ctx: dict[str, Any], + chunks: list[dict[str, Any]], + plan: ModePlan, +) -> tuple[list[str], list[str]]: + summary_lines = ctx.get("summary_lines") if isinstance(ctx, dict) else None + if not isinstance(summary_lines, list): + return [], [] + keys = _extract_metric_keys(summary_lines) + if not keys or not chunks: + return [], [] + max_keys = max(4, plan.max_subquestions * 2) + prompt = prompts.METRIC_KEYS_PROMPT.format(available="\n".join(keys), max_keys=max_keys) + question = ctx.get("question") if isinstance(ctx, dict) else "" + sub_questions = ctx.get("sub_questions") if isinstance(ctx, dict) else [] + keywords = ctx.get("keywords") if isinstance(ctx, dict) else [] + raw = await call_llm( + prompts.METRIC_KEYS_SYSTEM, + prompt + "\nQuestion: " + str(question) + "\nSubQuestions:\n" + "\n".join([str(item) for item in sub_questions]), + context="Keywords:\n" + ", ".join([str(item) for item in keywords if item]), + model=plan.fast_model, + tag="metric_keys", + ) + selected = _parse_key_list(raw, keys, max_keys) + if not selected: + return [], [] + ids = _chunk_ids_for_keys(chunks, selected) + return selected, ids + + +def _extract_metric_keys(lines: list[str]) -> list[str]: + keys: list[str] = [] + for line in lines: + if ":" not in line: + continue + key = line.split(":", 1)[0].strip() + if not key or " " in key: + continue + if key not in keys: + keys.append(key) + return keys + + +def _parse_key_list(raw: str, allowed: list[str], max_keys: int) -> list[str]: + parsed = _parse_json_block(raw, fallback={}) + if isinstance(parsed, list): + items = parsed + else: + items = parsed.get("keys") if isinstance(parsed, dict) else [] + if not isinstance(items, list): + return [] + allowed_set = set(allowed) + out: list[str] = [] + for item in items: + if not isinstance(item, str): + continue + if item in allowed_set and item not in out: + out.append(item) + if len(out) >= max_keys: + break + return out + + +def _chunk_ids_for_keys(chunks: list[dict[str, Any]], keys: list[str]) -> list[str]: + if not keys: + return [] + ids: list[str] = [] + key_set = {f"{key}:" for key in keys} + for chunk in chunks: + text = str(chunk.get("text") or "") + if not text: + continue + for line in text.splitlines(): + for key in key_set: + if line.startswith(key): + cid = chunk.get("id") + if cid and cid not in ids: + ids.append(cid) + break + return ids + + def _merge_fact_lines(primary: list[str], fallback: list[str]) -> list[str]: seen = set() merged: list[str] = [] diff --git a/atlasbot/llm/prompts.py b/atlasbot/llm/prompts.py index 4f24e2f..d3accf1 100644 --- a/atlasbot/llm/prompts.py +++ b/atlasbot/llm/prompts.py @@ -69,6 +69,20 @@ METRIC_PREFIX_PROMPT = ( "Only use values from AvailablePrefixes." ) +METRIC_KEYS_SYSTEM = ( + CLUSTER_SYSTEM + + " Select the metric keys required to answer the question. " + + "Return JSON only." +) + +METRIC_KEYS_PROMPT = ( + "AvailableKeys:\n{available}\n\n" + "Return JSON with field: keys (list). " + "Choose only keys needed to answer the question. " + "If none apply, return an empty list. " + "Limit to at most {max_keys} keys." +) + TOOL_SYSTEM = ( CLUSTER_SYSTEM + " Suggest a safe, read-only command that could refine the answer. "