diff --git a/atlasbot/engine/answerer.py b/atlasbot/engine/answerer.py index 187caa6..835981d 100644 --- a/atlasbot/engine/answerer.py +++ b/atlasbot/engine/answerer.py @@ -78,6 +78,7 @@ class ModePlan: chunk_group: int parallelism: int score_retries: int + use_deep_retrieval: bool use_tool: bool use_critic: bool use_gap: bool @@ -360,57 +361,59 @@ class AnswerEngine: ) metric_facts: list[str] = [] if classify.get("question_type") in {"metric", "diagnostic"} or force_metric: - if observer: - observer("retrieve", "extracting fact types") - fact_types = await _extract_fact_types( - call_llm, - normalized, - keyword_tokens, - plan, - ) - if observer: - observer("retrieve", "deriving signals") - signals = await _derive_signals( - call_llm, - normalized, - fact_types, - plan, - ) - if isinstance(signals, list): - signal_tokens = [str(item) for item in signals if item] all_tokens = _merge_tokens(signal_tokens, keyword_tokens, question_tokens) - if observer: - observer("retrieve", "scanning chunks") - candidate_lines: list[str] = [] - if signals: - for chunk in selected: - chunk_lines = chunk["text"].splitlines() - if not chunk_lines: - continue - hits = await _scan_chunk_for_signals( - call_llm, - normalized, - signals, - chunk_lines, - plan, - ) - if hits: - candidate_lines.extend(hits) - candidate_lines = list(dict.fromkeys(candidate_lines)) - if candidate_lines: + if plan.use_deep_retrieval: if observer: - observer("retrieve", "pruning candidates") - metric_facts = await _prune_metric_candidates( + observer("retrieve", "extracting fact types") + fact_types = await _extract_fact_types( call_llm, normalized, - candidate_lines, + keyword_tokens, plan, - plan.metric_retries, ) - if metric_facts: - key_facts = _merge_fact_lines(metric_facts, key_facts) - if self._settings.debug_pipeline: - _debug_log("metric_facts_selected", {"facts": metric_facts}) + if observer: + observer("retrieve", "deriving signals") + signals = await _derive_signals( + call_llm, + normalized, + fact_types, + plan, + ) + if isinstance(signals, list): + signal_tokens = [str(item) for item in signals if item] + all_tokens = _merge_tokens(signal_tokens, keyword_tokens, question_tokens) + if observer: + observer("retrieve", "scanning chunks") + candidate_lines: list[str] = [] + if signals: + for chunk in selected: + chunk_lines = chunk["text"].splitlines() + if not chunk_lines: + continue + hits = await _scan_chunk_for_signals( + call_llm, + normalized, + signals, + chunk_lines, + plan, + ) + if hits: + candidate_lines.extend(hits) + candidate_lines = list(dict.fromkeys(candidate_lines)) + if candidate_lines: + if observer: + observer("retrieve", "pruning candidates") + metric_facts = await _prune_metric_candidates( + call_llm, + normalized, + candidate_lines, + plan, + plan.metric_retries, + ) + if metric_facts: + key_facts = _merge_fact_lines(metric_facts, key_facts) + if self._settings.debug_pipeline: + _debug_log("metric_facts_selected", {"facts": metric_facts}) if not metric_facts: if observer: observer("retrieve", "fallback metric selection") @@ -1074,6 +1077,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan: chunk_group=4, parallelism=4, score_retries=3, + use_deep_retrieval=True, use_tool=True, use_critic=True, use_gap=True, @@ -1092,6 +1096,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan: chunk_group=4, parallelism=2, score_retries=2, + use_deep_retrieval=True, use_tool=True, use_critic=True, use_gap=True, @@ -1109,6 +1114,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan: chunk_group=5, parallelism=1, score_retries=1, + use_deep_retrieval=False, use_tool=False, use_critic=False, use_gap=False,