atlasbot: parallelize independent llm calls

Brad Stein 2026-02-03 22:25:55 -03:00
parent 5dee72cd1a
commit 68d229032d


@@ -2,6 +2,7 @@ import asyncio
import json
import logging
import math
import asyncio
import re
import time
import difflib
@@ -75,6 +76,7 @@ class ModePlan:
chunk_lines: int
chunk_top: int
chunk_group: int
parallelism: int
use_tool: bool
use_critic: bool
use_gap: bool
@@ -320,6 +322,7 @@ class AnswerEngine:
chunks = _chunk_lines(summary_lines, plan.chunk_lines)
metric_keys: list[str] = []
must_chunk_ids: list[str] = []
metric_task = None
if (classify.get("question_type") in {"metric", "diagnostic"} or force_metric) and summary_lines:
metric_ctx = {
"question": normalized,
@@ -328,13 +331,11 @@
"keyword_tokens": keyword_tokens,
"summary_lines": summary_lines,
}
metric_keys, must_chunk_ids = await _select_metric_chunks(
call_llm,
metric_ctx,
chunks,
plan,
)
scored = await _score_chunks(call_llm, chunks, normalized, sub_questions, plan)
metric_task = asyncio.create_task(_select_metric_chunks(call_llm, metric_ctx, chunks, plan))
scored_task = asyncio.create_task(_score_chunks(call_llm, chunks, normalized, sub_questions, plan))
if metric_task:
metric_keys, must_chunk_ids = await metric_task
scored = await scored_task
selected = _select_chunks(chunks, scored, plan, keyword_tokens, must_chunk_ids)
fact_candidates = _collect_fact_candidates(selected, limit=plan.max_subquestions * 12)
key_facts = await _select_fact_lines(
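The hunk above replaces two sequential awaits with asyncio.create_task: the metric-chunk selection and the chunk scoring are independent LLM round-trips, so both tasks are scheduled before either is awaited and their latencies overlap. A minimal, self-contained sketch of the pattern, with placeholder coroutines standing in for _select_metric_chunks and _score_chunks:

import asyncio

async def select_metrics() -> tuple[list[str], list[str]]:
    # stand-in for _select_metric_chunks: one LLM round-trip
    await asyncio.sleep(1)
    return ["latency_p95"], ["chunk-3"]

async def score_chunks() -> dict[str, float]:
    # stand-in for _score_chunks: an independent LLM round-trip
    await asyncio.sleep(1)
    return {"chunk-3": 0.9}

async def main() -> None:
    # schedule both tasks before awaiting either so the waits overlap
    metric_task = asyncio.create_task(select_metrics())
    scored_task = asyncio.create_task(score_chunks())
    metric_keys, must_chunk_ids = await metric_task
    scored = await scored_task  # total wall time is ~1s instead of ~2s
    print(metric_keys, must_chunk_ids, scored)

asyncio.run(main())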
@@ -474,30 +475,40 @@ class AnswerEngine:
if observer:
observer("subanswers", "drafting subanswers")
subanswers: list[str] = []
for subq in sub_questions:
async def _subanswer_for(subq: str) -> str:
sub_prompt = prompts.SUBANSWER_PROMPT + "\nQuestion: " + subq
if plan.subanswer_retries > 1:
candidates: list[str] = []
for _ in range(plan.subanswer_retries):
candidate = await call_llm(
candidates = await _gather_limited(
[
call_llm(
prompts.ANSWER_SYSTEM,
sub_prompt,
context=context,
model=plan.model,
tag="subanswer",
)
candidates.append(candidate)
for _ in range(plan.subanswer_retries)
],
plan.parallelism,
)
best_idx = await _select_best_candidate(call_llm, subq, candidates, plan, "subanswer_select")
subanswers.append(candidates[best_idx])
else:
sub_answer = await call_llm(
return candidates[best_idx]
return await call_llm(
prompts.ANSWER_SYSTEM,
sub_prompt,
context=context,
model=plan.model,
tag="subanswer",
)
subanswers.append(sub_answer)
if plan.parallelism > 1 and len(sub_questions) > 1:
subanswers = await _gather_limited(
[_subanswer_for(subq) for subq in sub_questions],
plan.parallelism,
)
else:
for subq in sub_questions:
subanswers.append(await _subanswer_for(subq))
if observer:
observer("synthesize", "synthesizing")
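The per-sub-question loop is folded into a local _subanswer_for coroutine so that independent sub-questions can be drafted concurrently through _gather_limited, capped at plan.parallelism in-flight calls; the original sequential path is kept when parallelism is 1 or there is only one sub-question. A reduced sketch of the shape of the change, reusing the _gather_limited helper defined later in this diff (the prompt text and call_llm here are stand-ins, and the retry/candidate-selection branch is omitted):

async def draft_subanswers(sub_questions: list[str], call_llm, parallelism: int) -> list[str]:
    async def _subanswer_for(subq: str) -> str:
        # one LLM call per sub-question
        return await call_llm("answer-system", "Question: " + subq)

    if parallelism > 1 and len(sub_questions) > 1:
        # fan out, but never more than `parallelism` calls in flight at once
        return await _gather_limited([_subanswer_for(q) for q in sub_questions], parallelism)
    return [await _subanswer_for(q) for q in sub_questions]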
@@ -776,8 +787,31 @@ class AnswerEngine:
+ f"\nDraftIndex: {idx + 1}"
)
drafts: list[str] = []
if plan.parallelism > 1 and len(draft_prompts) > 1:
drafts = await _gather_limited(
[
call_llm(
prompts.SYNTHESIZE_SYSTEM,
prompt,
context=context,
model=plan.model,
tag="synth",
)
for prompt in draft_prompts
],
plan.parallelism,
)
else:
for prompt in draft_prompts:
drafts.append(await call_llm(prompts.SYNTHESIZE_SYSTEM, prompt, context=context, model=plan.model, tag="synth"))
drafts.append(
await call_llm(
prompts.SYNTHESIZE_SYSTEM,
prompt,
context=context,
model=plan.model,
tag="synth",
)
)
if len(drafts) == 1:
return drafts[0]
select_prompt = (
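Draft synthesis gets the same bounded fan-out. One property worth calling out: asyncio.gather returns results in the order the awaitables were passed in, not in completion order, so drafts[i] still corresponds to draft_prompts[i] and the DraftIndex labels baked into each prompt stay aligned with the list the selection step sees. For example (slow_draft and fast_draft are hypothetical coroutines):

# gather preserves input order regardless of which coroutine finishes first
results = await asyncio.gather(slow_draft(), fast_draft())
# results[0] is slow_draft's value even if fast_draft completed earlier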
@@ -1025,6 +1059,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
chunk_lines=6,
chunk_top=10,
chunk_group=4,
parallelism=4,
use_tool=True,
use_critic=True,
use_gap=True,
@@ -1041,6 +1076,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
chunk_lines=8,
chunk_top=8,
chunk_group=4,
parallelism=2,
use_tool=True,
use_critic=True,
use_gap=True,
@@ -1056,6 +1092,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
chunk_lines=12,
chunk_top=5,
chunk_group=5,
parallelism=1,
use_tool=False,
use_critic=False,
use_gap=False,
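The new parallelism field caps concurrent LLM calls per mode: the richest plan allows four in-flight calls, the middle plan two, and the cheapest plan keeps parallelism=1, which routes every call through the unchanged sequential branches. A sketch of how a call site branches on it, mirroring the pattern used throughout this diff (the coros list is illustrative):

plan = _mode_plan(settings, mode)
if plan.parallelism > 1 and len(coros) > 1:
    results = await _gather_limited(coros, plan.parallelism)
else:
    results = [await coro for coro in coros]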
@@ -1117,14 +1154,23 @@ async def _score_chunks(
scores: dict[str, float] = {chunk["id"]: 0.0 for chunk in chunks}
if not chunks:
return scores
groups: list[list[dict[str, Any]]] = []
group: list[dict[str, Any]] = []
for chunk in chunks:
group.append({"id": chunk["id"], "summary": chunk["summary"]})
if len(group) >= plan.chunk_group:
scores.update(await _score_chunk_group(call_llm, group, question, sub_questions))
groups.append(group)
group = []
if group:
scores.update(await _score_chunk_group(call_llm, group, question, sub_questions))
groups.append(group)
if plan.parallelism <= 1 or len(groups) <= 1:
for grp in groups:
scores.update(await _score_chunk_group(call_llm, grp, question, sub_questions))
return scores
coros = [_score_chunk_group(call_llm, grp, question, sub_questions) for grp in groups]
results = await _gather_limited(coros, plan.parallelism)
for result in results:
scores.update(result)
return scores
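_score_chunks now separates grouping from scoring: groups of up to chunk_group summaries are built first, then each group is scored with one LLM call, concurrently when the plan allows. Because every group call returns a dict keyed by chunk id, merging with scores.update is insensitive to completion order. A condensed sketch (the real code projects each chunk down to its id and summary before grouping):

groups = [chunks[i:i + plan.chunk_group] for i in range(0, len(chunks), plan.chunk_group)]
coros = [_score_chunk_group(call_llm, grp, question, sub_questions) for grp in groups]
for result in await _gather_limited(coros, plan.parallelism):
    scores.update(result)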
@@ -1316,9 +1362,11 @@ async def _select_metric_chunks(
if available_keys:
missing = await _validate_metric_keys(
call_llm,
question,
sub_questions,
selected,
{
"question": question,
"sub_questions": sub_questions,
"selected": selected,
},
available_keys,
plan,
)
@@ -1332,14 +1380,15 @@ async def _validate_metric_keys(
async def _validate_metric_keys(
call_llm: Callable[..., Awaitable[str]],
question: str,
sub_questions: list[str],
selected: list[str],
ctx: dict[str, Any],
available: list[str],
plan: ModePlan,
) -> list[str]:
if not available:
return []
question = str(ctx.get("question") or "")
sub_questions = ctx.get("sub_questions") if isinstance(ctx.get("sub_questions"), list) else []
selected = ctx.get("selected") if isinstance(ctx.get("selected"), list) else []
cap = max(12, plan.max_subquestions * 4)
available_list = available[:cap]
prompt = prompts.METRIC_KEYS_VALIDATE_PROMPT.format(
@@ -1366,6 +1415,18 @@ async def _validate_metric_keys(
return out
async def _gather_limited(coros: list[Awaitable[Any]], limit: int) -> list[Any]:
if not coros:
return []
semaphore = asyncio.Semaphore(max(1, limit))
async def runner(coro: Awaitable[Any]) -> Any:
async with semaphore:
return await coro
return await asyncio.gather(*(runner(coro) for coro in coros))
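_gather_limited is a small semaphore-bounded wrapper over asyncio.gather: each supplied coroutine is wrapped in a runner that must acquire the semaphore before awaiting it, so at most `limit` of them execute at any moment, and the result list comes back in input order (a guarantee of asyncio.gather). A usage sketch (prompt_list and the "probe" tag are hypothetical):

coros = [
    call_llm(prompts.ANSWER_SYSTEM, p, context=context, model=plan.model, tag="probe")
    for p in prompt_list
]
# at most two calls in flight at any moment; results[i] matches prompt_list[i]
results = await _gather_limited(coros, 2)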
def _metric_ctx_values(ctx: dict[str, Any]) -> tuple[list[str], str, list[str], list[str], set[str]]:
summary_lines = ctx.get("summary_lines") if isinstance(ctx, dict) else None
if not isinstance(summary_lines, list):