diff --git a/atlasbot/engine/answerer.py b/atlasbot/engine/answerer.py
index 61b145b..ae3efd2 100644
--- a/atlasbot/engine/answerer.py
+++ b/atlasbot/engine/answerer.py
@@ -2,6 +2,7 @@ import asyncio
 import json
 import logging
 import math
+import asyncio
 import re
 import time
 import difflib
@@ -75,6 +76,7 @@ class ModePlan:
     chunk_lines: int
     chunk_top: int
     chunk_group: int
+    parallelism: int
     use_tool: bool
     use_critic: bool
     use_gap: bool
@@ -320,6 +322,7 @@ class AnswerEngine:
         chunks = _chunk_lines(summary_lines, plan.chunk_lines)
         metric_keys: list[str] = []
         must_chunk_ids: list[str] = []
+        metric_task = None
         if (classify.get("question_type") in {"metric", "diagnostic"} or force_metric) and summary_lines:
             metric_ctx = {
                 "question": normalized,
@@ -328,13 +331,11 @@ class AnswerEngine:
                 "keyword_tokens": keyword_tokens,
                 "summary_lines": summary_lines,
             }
-            metric_keys, must_chunk_ids = await _select_metric_chunks(
-                call_llm,
-                metric_ctx,
-                chunks,
-                plan,
-            )
-        scored = await _score_chunks(call_llm, chunks, normalized, sub_questions, plan)
+            metric_task = asyncio.create_task(_select_metric_chunks(call_llm, metric_ctx, chunks, plan))
+        scored_task = asyncio.create_task(_score_chunks(call_llm, chunks, normalized, sub_questions, plan))
+        if metric_task:
+            metric_keys, must_chunk_ids = await metric_task
+        scored = await scored_task
         selected = _select_chunks(chunks, scored, plan, keyword_tokens, must_chunk_ids)
         fact_candidates = _collect_fact_candidates(selected, limit=plan.max_subquestions * 12)
         key_facts = await _select_fact_lines(
@@ -474,30 +475,40 @@ class AnswerEngine:
         if observer:
             observer("subanswers", "drafting subanswers")
         subanswers: list[str] = []
-        for subq in sub_questions:
+        async def _subanswer_for(subq: str) -> str:
             sub_prompt = prompts.SUBANSWER_PROMPT + "\nQuestion: " + subq
             if plan.subanswer_retries > 1:
-                candidates: list[str] = []
-                for _ in range(plan.subanswer_retries):
-                    candidate = await call_llm(
-                        prompts.ANSWER_SYSTEM,
-                        sub_prompt,
-                        context=context,
-                        model=plan.model,
-                        tag="subanswer",
-                    )
-                    candidates.append(candidate)
-                best_idx = await _select_best_candidate(call_llm, subq, candidates, plan, "subanswer_select")
-                subanswers.append(candidates[best_idx])
-            else:
-                sub_answer = await call_llm(
-                    prompts.ANSWER_SYSTEM,
-                    sub_prompt,
-                    context=context,
-                    model=plan.model,
-                    tag="subanswer",
+                candidates = await _gather_limited(
+                    [
+                        call_llm(
+                            prompts.ANSWER_SYSTEM,
+                            sub_prompt,
+                            context=context,
+                            model=plan.model,
+                            tag="subanswer",
+                        )
+                        for _ in range(plan.subanswer_retries)
+                    ],
+                    plan.parallelism,
                 )
-                subanswers.append(sub_answer)
+                best_idx = await _select_best_candidate(call_llm, subq, candidates, plan, "subanswer_select")
+                return candidates[best_idx]
+            return await call_llm(
+                prompts.ANSWER_SYSTEM,
+                sub_prompt,
+                context=context,
+                model=plan.model,
+                tag="subanswer",
+            )
+
+        if plan.parallelism > 1 and len(sub_questions) > 1:
+            subanswers = await _gather_limited(
+                [_subanswer_for(subq) for subq in sub_questions],
+                plan.parallelism,
+            )
+        else:
+            for subq in sub_questions:
+                subanswers.append(await _subanswer_for(subq))
 
         if observer:
             observer("synthesize", "synthesizing")
@@ -776,8 +787,31 @@ class AnswerEngine:
                 + f"\nDraftIndex: {idx + 1}"
             )
         drafts: list[str] = []
-        for prompt in draft_prompts:
-            drafts.append(await call_llm(prompts.SYNTHESIZE_SYSTEM, prompt, context=context, model=plan.model, tag="synth"))
+        if plan.parallelism > 1 and len(draft_prompts) > 1:
+            drafts = await _gather_limited(
+                [
+                    call_llm(
+                        prompts.SYNTHESIZE_SYSTEM,
+                        prompt,
+                        context=context,
+                        model=plan.model,
+                        tag="synth",
+                    )
+                    for prompt in draft_prompts
+                ],
+                plan.parallelism,
+            )
+        else:
+            for prompt in draft_prompts:
+                drafts.append(
+                    await call_llm(
+                        prompts.SYNTHESIZE_SYSTEM,
+                        prompt,
+                        context=context,
+                        model=plan.model,
+                        tag="synth",
+                    )
+                )
         if len(drafts) == 1:
             return drafts[0]
         select_prompt = (
@@ -1025,6 +1059,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
             chunk_lines=6,
             chunk_top=10,
             chunk_group=4,
+            parallelism=4,
             use_tool=True,
             use_critic=True,
             use_gap=True,
@@ -1041,6 +1076,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
             chunk_lines=8,
             chunk_top=8,
             chunk_group=4,
+            parallelism=2,
             use_tool=True,
             use_critic=True,
             use_gap=True,
@@ -1056,6 +1092,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
             chunk_lines=12,
             chunk_top=5,
             chunk_group=5,
+            parallelism=1,
             use_tool=False,
             use_critic=False,
             use_gap=False,
@@ -1117,14 +1154,23 @@ async def _score_chunks(
     scores: dict[str, float] = {chunk["id"]: 0.0 for chunk in chunks}
     if not chunks:
         return scores
+    groups: list[list[dict[str, Any]]] = []
     group: list[dict[str, Any]] = []
     for chunk in chunks:
         group.append({"id": chunk["id"], "summary": chunk["summary"]})
         if len(group) >= plan.chunk_group:
-            scores.update(await _score_chunk_group(call_llm, group, question, sub_questions))
+            groups.append(group)
             group = []
     if group:
-        scores.update(await _score_chunk_group(call_llm, group, question, sub_questions))
+        groups.append(group)
+    if plan.parallelism <= 1 or len(groups) <= 1:
+        for grp in groups:
+            scores.update(await _score_chunk_group(call_llm, grp, question, sub_questions))
+        return scores
+    coros = [_score_chunk_group(call_llm, grp, question, sub_questions) for grp in groups]
+    results = await _gather_limited(coros, plan.parallelism)
+    for result in results:
+        scores.update(result)
     return scores
 
 
@@ -1316,9 +1362,11 @@ async def _select_metric_chunks(
     if available_keys:
         missing = await _validate_metric_keys(
             call_llm,
-            question,
-            sub_questions,
-            selected,
+            {
+                "question": question,
+                "sub_questions": sub_questions,
+                "selected": selected,
+            },
             available_keys,
             plan,
         )
@@ -1332,14 +1380,15 @@
 
 async def _validate_metric_keys(
     call_llm: Callable[..., Awaitable[str]],
-    question: str,
-    sub_questions: list[str],
-    selected: list[str],
+    ctx: dict[str, Any],
     available: list[str],
     plan: ModePlan,
 ) -> list[str]:
     if not available:
         return []
+    question = str(ctx.get("question") or "")
+    sub_questions = ctx.get("sub_questions") if isinstance(ctx.get("sub_questions"), list) else []
+    selected = ctx.get("selected") if isinstance(ctx.get("selected"), list) else []
     cap = max(12, plan.max_subquestions * 4)
     available_list = available[:cap]
     prompt = prompts.METRIC_KEYS_VALIDATE_PROMPT.format(
@@ -1366,6 +1415,18 @@ async def _validate_metric_keys(
     return out
 
 
+async def _gather_limited(coros: list[Awaitable[Any]], limit: int) -> list[Any]:
+    if not coros:
+        return []
+    semaphore = asyncio.Semaphore(max(1, limit))
+
+    async def runner(coro: Awaitable[Any]) -> Any:
+        async with semaphore:
+            return await coro
+
+    return await asyncio.gather(*(runner(coro) for coro in coros))
+
+
 def _metric_ctx_values(ctx: dict[str, Any]) -> tuple[list[str], str, list[str], list[str], set[str]]:
     summary_lines = ctx.get("summary_lines") if isinstance(ctx, dict) else None
     if not isinstance(summary_lines, list):