diff --git a/atlasbot/engine/answerer.py b/atlasbot/engine/answerer.py index 4ea70b8..088d4a3 100644 --- a/atlasbot/engine/answerer.py +++ b/atlasbot/engine/answerer.py @@ -46,6 +46,17 @@ class AnswerResult: meta: dict[str, Any] +@dataclass(frozen=True) +class InsightGuardInput: + question: str + reply: str + classify: dict[str, Any] + context: str + plan: "ModePlan" + call_llm: Callable[..., Awaitable[str]] + facts: list[str] + + @dataclass class EvidenceItem: path: str @@ -783,13 +794,15 @@ class AnswerEngine: if observer: observer("insight_guard", "checking for concrete signals") reply = await _apply_insight_guard( - normalized, - reply, - classify, - context, - plan, - call_llm, - metric_facts or key_facts, + InsightGuardInput( + question=normalized, + reply=reply, + classify=classify, + context=context, + plan=plan, + call_llm=call_llm, + facts=metric_facts or key_facts, + ) ) if plan.use_critic: @@ -2447,36 +2460,28 @@ def _should_use_insight_guard(classify: dict[str, Any]) -> bool: return style == "insightful" or qtype in {"open_ended", "planning"} -async def _apply_insight_guard( - question: str, - reply: str, - classify: dict[str, Any], - context: str, - plan: ModePlan, - call_llm: Callable[..., Awaitable[str]], - facts: list[str], -) -> str: - if not reply or not _should_use_insight_guard(classify): - return reply - guard_prompt = prompts.INSIGHT_GUARD_PROMPT.format(question=question, answer=reply) - guard_raw = await call_llm( +async def _apply_insight_guard(inputs: InsightGuardInput) -> str: + if not inputs.reply or not _should_use_insight_guard(inputs.classify): + return inputs.reply + guard_prompt = prompts.INSIGHT_GUARD_PROMPT.format(question=inputs.question, answer=inputs.reply) + guard_raw = await inputs.call_llm( prompts.INSIGHT_GUARD_SYSTEM, guard_prompt, - context=context, - model=plan.fast_model, + context=inputs.context, + model=inputs.plan.fast_model, tag="insight_guard", ) guard = _parse_json_block(guard_raw, fallback={}) if guard.get("ok") is True: - return reply - fix_prompt = prompts.INSIGHT_FIX_PROMPT.format(question=question, answer=reply) - if facts: - fix_prompt = fix_prompt + "\nFacts:\n" + "\n".join(facts[:6]) - return await call_llm( + return inputs.reply + fix_prompt = prompts.INSIGHT_FIX_PROMPT.format(question=inputs.question, answer=inputs.reply) + if inputs.facts: + fix_prompt = fix_prompt + "\nFacts:\n" + "\n".join(inputs.facts[:6]) + return await inputs.call_llm( prompts.INSIGHT_FIX_SYSTEM, fix_prompt, - context=context, - model=plan.model, + context=inputs.context, + model=inputs.plan.model, tag="insight_fix", )