atlasbot: parallelize independent llm calls

Author: Brad Stein
Date: 2026-02-03 22:25:55 -03:00
parent 5dee72cd1a
commit 68d229032d


@@ -2,6 +2,7 @@ import asyncio
 import json
 import logging
 import math
+import asyncio
 import re
 import time
 import difflib
@@ -75,6 +76,7 @@ class ModePlan:
     chunk_lines: int
     chunk_top: int
     chunk_group: int
+    parallelism: int
     use_tool: bool
     use_critic: bool
     use_gap: bool
@@ -320,6 +322,7 @@ class AnswerEngine:
         chunks = _chunk_lines(summary_lines, plan.chunk_lines)
         metric_keys: list[str] = []
         must_chunk_ids: list[str] = []
+        metric_task = None
         if (classify.get("question_type") in {"metric", "diagnostic"} or force_metric) and summary_lines:
             metric_ctx = {
                 "question": normalized,
@@ -328,13 +331,11 @@ class AnswerEngine:
                 "keyword_tokens": keyword_tokens,
                 "summary_lines": summary_lines,
             }
-            metric_keys, must_chunk_ids = await _select_metric_chunks(
-                call_llm,
-                metric_ctx,
-                chunks,
-                plan,
-            )
-        scored = await _score_chunks(call_llm, chunks, normalized, sub_questions, plan)
+            metric_task = asyncio.create_task(_select_metric_chunks(call_llm, metric_ctx, chunks, plan))
+        scored_task = asyncio.create_task(_score_chunks(call_llm, chunks, normalized, sub_questions, plan))
+        if metric_task:
+            metric_keys, must_chunk_ids = await metric_task
+        scored = await scored_task
         selected = _select_chunks(chunks, scored, plan, keyword_tokens, must_chunk_ids)
         fact_candidates = _collect_fact_candidates(selected, limit=plan.max_subquestions * 12)
         key_facts = await _select_fact_lines(
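
The hunk above overlaps two independent awaits: metric-chunk selection is started as a task, chunk scoring is started right after it, and both results are collected once they are in flight. A minimal standalone sketch of that pattern, with placeholder coroutines and sleeps standing in for the LLM calls:

import asyncio


async def select_metric_chunks() -> list[str]:
    # Placeholder for the LLM-backed metric-chunk selection.
    await asyncio.sleep(1.0)
    return ["metric-a", "metric-b"]


async def score_chunks() -> dict[str, float]:
    # Placeholder for the LLM-backed chunk scoring.
    await asyncio.sleep(1.0)
    return {"chunk-1": 0.9}


async def main() -> None:
    # create_task schedules both coroutines immediately; the awaits below
    # only collect results, so total time is ~1s instead of ~2s sequential.
    metric_task = asyncio.create_task(select_metric_chunks())
    scored_task = asyncio.create_task(score_chunks())
    metric_keys = await metric_task
    scores = await scored_task
    print(metric_keys, scores)


asyncio.run(main())
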
@@ -474,30 +475,40 @@ class AnswerEngine:
         if observer:
             observer("subanswers", "drafting subanswers")
         subanswers: list[str] = []
-        for subq in sub_questions:
+
+        async def _subanswer_for(subq: str) -> str:
             sub_prompt = prompts.SUBANSWER_PROMPT + "\nQuestion: " + subq
             if plan.subanswer_retries > 1:
-                candidates: list[str] = []
-                for _ in range(plan.subanswer_retries):
-                    candidate = await call_llm(
-                        prompts.ANSWER_SYSTEM,
-                        sub_prompt,
-                        context=context,
-                        model=plan.model,
-                        tag="subanswer",
-                    )
-                    candidates.append(candidate)
-                best_idx = await _select_best_candidate(call_llm, subq, candidates, plan, "subanswer_select")
-                subanswers.append(candidates[best_idx])
-            else:
-                sub_answer = await call_llm(
-                    prompts.ANSWER_SYSTEM,
-                    sub_prompt,
-                    context=context,
-                    model=plan.model,
-                    tag="subanswer",
-                )
-                subanswers.append(sub_answer)
+                candidates = await _gather_limited(
+                    [
+                        call_llm(
+                            prompts.ANSWER_SYSTEM,
+                            sub_prompt,
+                            context=context,
+                            model=plan.model,
+                            tag="subanswer",
+                        )
+                        for _ in range(plan.subanswer_retries)
+                    ],
+                    plan.parallelism,
+                )
+                best_idx = await _select_best_candidate(call_llm, subq, candidates, plan, "subanswer_select")
+                return candidates[best_idx]
+            return await call_llm(
+                prompts.ANSWER_SYSTEM,
+                sub_prompt,
+                context=context,
+                model=plan.model,
+                tag="subanswer",
+            )
+
+        if plan.parallelism > 1 and len(sub_questions) > 1:
+            subanswers = await _gather_limited(
+                [_subanswer_for(subq) for subq in sub_questions],
+                plan.parallelism,
+            )
+        else:
+            for subq in sub_questions:
+                subanswers.append(await _subanswer_for(subq))
 
         if observer:
             observer("synthesize", "synthesizing")
@@ -776,8 +787,31 @@ class AnswerEngine:
                 + f"\nDraftIndex: {idx + 1}"
             )
         drafts: list[str] = []
-        for prompt in draft_prompts:
-            drafts.append(await call_llm(prompts.SYNTHESIZE_SYSTEM, prompt, context=context, model=plan.model, tag="synth"))
+        if plan.parallelism > 1 and len(draft_prompts) > 1:
+            drafts = await _gather_limited(
+                [
+                    call_llm(
+                        prompts.SYNTHESIZE_SYSTEM,
+                        prompt,
+                        context=context,
+                        model=plan.model,
+                        tag="synth",
+                    )
+                    for prompt in draft_prompts
+                ],
+                plan.parallelism,
+            )
+        else:
+            for prompt in draft_prompts:
+                drafts.append(
+                    await call_llm(
+                        prompts.SYNTHESIZE_SYSTEM,
+                        prompt,
+                        context=context,
+                        model=plan.model,
+                        tag="synth",
+                    )
+                )
         if len(drafts) == 1:
             return drafts[0]
         select_prompt = (
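
One property the draft fan-out relies on: asyncio.gather returns results in the order the awaitables were supplied, not in completion order, so drafts stays aligned with draft_prompts. A small demonstration (the delays are arbitrary):

import asyncio


async def finish_after(delay: float, label: str) -> str:
    await asyncio.sleep(delay)
    return label


async def main() -> None:
    # "second" finishes last, yet results come back in input order.
    results = await asyncio.gather(
        finish_after(0.2, "first"),
        finish_after(0.3, "second"),
        finish_after(0.1, "third"),
    )
    print(results)  # ['first', 'second', 'third']


asyncio.run(main())
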
@@ -1025,6 +1059,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
            chunk_lines=6,
            chunk_top=10,
            chunk_group=4,
+           parallelism=4,
            use_tool=True,
            use_critic=True,
            use_gap=True,
@@ -1041,6 +1076,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
            chunk_lines=8,
            chunk_top=8,
            chunk_group=4,
+           parallelism=2,
            use_tool=True,
            use_critic=True,
            use_gap=True,
@@ -1056,6 +1092,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
            chunk_lines=12,
            chunk_top=5,
            chunk_group=5,
+           parallelism=1,
            use_tool=False,
            use_critic=False,
            use_gap=False,
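
Together with the new ModePlan.parallelism field, these three hunks give each answer mode its own concurrency budget (4, 2, and 1). A stripped-down sketch of that shape; the mode names here are illustrative guesses, and the real ModePlan carries many more fields:

from dataclasses import dataclass


@dataclass
class MiniPlan:
    # Only the field added in this commit; ModePlan has many more.
    parallelism: int


def mini_mode_plan(mode: str) -> MiniPlan:
    # Hypothetical mode names; the 4/2/1 values mirror the hunks above.
    if mode == "deep":
        return MiniPlan(parallelism=4)
    if mode == "balanced":
        return MiniPlan(parallelism=2)
    return MiniPlan(parallelism=1)  # lightest mode stays fully sequential
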
@@ -1117,14 +1154,23 @@ async def _score_chunks(
     scores: dict[str, float] = {chunk["id"]: 0.0 for chunk in chunks}
     if not chunks:
         return scores
+    groups: list[list[dict[str, Any]]] = []
     group: list[dict[str, Any]] = []
     for chunk in chunks:
         group.append({"id": chunk["id"], "summary": chunk["summary"]})
         if len(group) >= plan.chunk_group:
-            scores.update(await _score_chunk_group(call_llm, group, question, sub_questions))
+            groups.append(group)
             group = []
     if group:
-        scores.update(await _score_chunk_group(call_llm, group, question, sub_questions))
+        groups.append(group)
+    if plan.parallelism <= 1 or len(groups) <= 1:
+        for grp in groups:
+            scores.update(await _score_chunk_group(call_llm, grp, question, sub_questions))
+        return scores
+    coros = [_score_chunk_group(call_llm, grp, question, sub_questions) for grp in groups]
+    results = await _gather_limited(coros, plan.parallelism)
+    for result in results:
+        scores.update(result)
     return scores
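
The _score_chunks change is collect-then-scatter: the chunk groups are assembled first, and only then scored, either in a loop or through the bounded gather. A standalone sketch of the same batching with a fake scorer:

import asyncio
from typing import Any


def batch(items: list[Any], size: int) -> list[list[Any]]:
    # Same grouping as the loop above: fixed-size groups plus a remainder.
    groups: list[list[Any]] = []
    group: list[Any] = []
    for item in items:
        group.append(item)
        if len(group) >= size:
            groups.append(group)
            group = []
    if group:
        groups.append(group)
    return groups


async def score_group(group: list[str]) -> dict[str, float]:
    # Placeholder for one LLM scoring call over a group of chunk summaries.
    await asyncio.sleep(0.1)
    return {chunk_id: 1.0 for chunk_id in group}


async def main() -> None:
    chunk_ids = [f"chunk-{i}" for i in range(10)]
    scores: dict[str, float] = {}
    for result in await asyncio.gather(*(score_group(g) for g in batch(chunk_ids, 4))):
        scores.update(result)
    print(scores)


asyncio.run(main())
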
@@ -1316,9 +1362,11 @@ async def _select_metric_chunks(
     if available_keys:
         missing = await _validate_metric_keys(
             call_llm,
-            question,
-            sub_questions,
-            selected,
+            {
+                "question": question,
+                "sub_questions": sub_questions,
+                "selected": selected,
+            },
             available_keys,
             plan,
         )
@@ -1332,14 +1380,15 @@ async def _select_metric_chunks(
 async def _validate_metric_keys(
     call_llm: Callable[..., Awaitable[str]],
-    question: str,
-    sub_questions: list[str],
-    selected: list[str],
+    ctx: dict[str, Any],
     available: list[str],
     plan: ModePlan,
 ) -> list[str]:
     if not available:
         return []
+    question = str(ctx.get("question") or "")
+    sub_questions = ctx.get("sub_questions") if isinstance(ctx.get("sub_questions"), list) else []
+    selected = ctx.get("selected") if isinstance(ctx.get("selected"), list) else []
     cap = max(12, plan.max_subquestions * 4)
     available_list = available[:cap]
     prompt = prompts.METRIC_KEYS_VALIDATE_PROMPT.format(
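
The signature change above folds three positional parameters into one ctx dict, unpacked defensively so a missing or mistyped key degrades to an empty value instead of raising. A minimal sketch of that unpacking pattern; the keys match the diff, the function itself is a stand-in:

from typing import Any


def unpack_metric_ctx(ctx: dict[str, Any]) -> tuple[str, list[str], list[str]]:
    # Each field falls back to a harmless default when absent or wrongly typed.
    question = str(ctx.get("question") or "")
    sub_questions = ctx.get("sub_questions") if isinstance(ctx.get("sub_questions"), list) else []
    selected = ctx.get("selected") if isinstance(ctx.get("selected"), list) else []
    return question, sub_questions, selected


print(unpack_metric_ctx({"question": "why did latency spike?", "selected": ["chunk-3"]}))
# -> ('why did latency spike?', [], ['chunk-3'])
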
@@ -1366,6 +1415,18 @@ async def _validate_metric_keys(
     return out
 
 
+async def _gather_limited(coros: list[Awaitable[Any]], limit: int) -> list[Any]:
+    if not coros:
+        return []
+    semaphore = asyncio.Semaphore(max(1, limit))
+
+    async def runner(coro: Awaitable[Any]) -> Any:
+        async with semaphore:
+            return await coro
+
+    return await asyncio.gather(*(runner(coro) for coro in coros))
+
+
 def _metric_ctx_values(ctx: dict[str, Any]) -> tuple[list[str], str, list[str], list[str], set[str]]:
     summary_lines = ctx.get("summary_lines") if isinstance(ctx, dict) else None
     if not isinstance(summary_lines, list):
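
A usage sketch for the new _gather_limited helper. It expects plain, not-yet-scheduled coroutines: runner only awaits each one after acquiring the semaphore, so nothing starts running until a slot frees up. If callers wrapped the work in asyncio.create_task first, the semaphore would merely throttle result collection, not execution. The slow_call coroutine and the limit of 2 below are illustrative:

import asyncio
from typing import Any, Awaitable


async def _gather_limited(coros: list[Awaitable[Any]], limit: int) -> list[Any]:
    # Same shape as the helper added in this commit.
    if not coros:
        return []
    semaphore = asyncio.Semaphore(max(1, limit))

    async def runner(coro: Awaitable[Any]) -> Any:
        async with semaphore:
            return await coro

    return await asyncio.gather(*(runner(coro) for coro in coros))


async def slow_call(i: int) -> int:
    await asyncio.sleep(0.2)
    return i


async def main() -> None:
    # Ten pending calls, at most two in flight at once: roughly 1s total
    # instead of 2s sequentially, with results kept in input order.
    results = await _gather_limited([slow_call(i) for i in range(10)], 2)
    print(results)


asyncio.run(main())
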