atlasbot: parallelize independent llm calls
parent 5dee72cd1a
commit 68d229032d
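The diff routes batches of independent LLM calls (subanswer retry candidates, per-group chunk scoring, and synthesis drafts) through a shared _gather_limited helper, overlaps metric-chunk selection with chunk scoring via asyncio.create_task, and adds a per-mode parallelism field to ModePlan (4, 2, or 1). A minimal sketch of the bounded-concurrency pattern, with a hypothetical fake_llm_call standing in for the engine's call_llm:

    import asyncio
    from typing import Any, Awaitable


    async def gather_limited(coros: list[Awaitable[Any]], limit: int) -> list[Any]:
        # Same shape as the _gather_limited helper added in this diff: a semaphore
        # caps how many coroutines run at once; asyncio.gather keeps input order.
        if not coros:
            return []
        semaphore = asyncio.Semaphore(max(1, limit))

        async def runner(coro: Awaitable[Any]) -> Any:
            async with semaphore:
                return await coro

        return await asyncio.gather(*(runner(coro) for coro in coros))


    async def fake_llm_call(prompt: str) -> str:
        # Hypothetical stand-in for call_llm; sleeps to mimic request latency.
        await asyncio.sleep(0.1)
        return f"answer: {prompt}"


    async def main() -> None:
        prompts = [f"sub-question {i}" for i in range(8)]
        # The limit mirrors the new ModePlan.parallelism field (4 / 2 / 1 per mode).
        answers = await gather_limited([fake_llm_call(p) for p in prompts], 4)
        print(answers)


    asyncio.run(main())

The coroutine objects are created eagerly but only start running once a semaphore slot frees inside runner, so the cap bounds in-flight LLM requests without changing result order.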
@@ -2,6 +2,7 @@ import asyncio
 import json
 import logging
 import math
+import asyncio
 import re
 import time
 import difflib
@@ -75,6 +76,7 @@ class ModePlan:
     chunk_lines: int
     chunk_top: int
     chunk_group: int
+    parallelism: int
     use_tool: bool
     use_critic: bool
     use_gap: bool
@@ -320,6 +322,7 @@ class AnswerEngine:
         chunks = _chunk_lines(summary_lines, plan.chunk_lines)
         metric_keys: list[str] = []
         must_chunk_ids: list[str] = []
+        metric_task = None
         if (classify.get("question_type") in {"metric", "diagnostic"} or force_metric) and summary_lines:
             metric_ctx = {
                 "question": normalized,
@@ -328,13 +331,11 @@ class AnswerEngine:
                 "keyword_tokens": keyword_tokens,
                 "summary_lines": summary_lines,
             }
-            metric_keys, must_chunk_ids = await _select_metric_chunks(
-                call_llm,
-                metric_ctx,
-                chunks,
-                plan,
-            )
-        scored = await _score_chunks(call_llm, chunks, normalized, sub_questions, plan)
+            metric_task = asyncio.create_task(_select_metric_chunks(call_llm, metric_ctx, chunks, plan))
+        scored_task = asyncio.create_task(_score_chunks(call_llm, chunks, normalized, sub_questions, plan))
+        if metric_task:
+            metric_keys, must_chunk_ids = await metric_task
+        scored = await scored_task
         selected = _select_chunks(chunks, scored, plan, keyword_tokens, must_chunk_ids)
         fact_candidates = _collect_fact_candidates(selected, limit=plan.max_subquestions * 12)
         key_facts = await _select_fact_lines(
@@ -474,30 +475,40 @@ class AnswerEngine:
         if observer:
             observer("subanswers", "drafting subanswers")
         subanswers: list[str] = []
-        for subq in sub_questions:
+        async def _subanswer_for(subq: str) -> str:
             sub_prompt = prompts.SUBANSWER_PROMPT + "\nQuestion: " + subq
             if plan.subanswer_retries > 1:
-                candidates: list[str] = []
-                for _ in range(plan.subanswer_retries):
-                    candidate = await call_llm(
-                        prompts.ANSWER_SYSTEM,
-                        sub_prompt,
-                        context=context,
-                        model=plan.model,
-                        tag="subanswer",
-                    )
-                    candidates.append(candidate)
-                best_idx = await _select_best_candidate(call_llm, subq, candidates, plan, "subanswer_select")
-                subanswers.append(candidates[best_idx])
-            else:
-                sub_answer = await call_llm(
-                    prompts.ANSWER_SYSTEM,
-                    sub_prompt,
-                    context=context,
-                    model=plan.model,
-                    tag="subanswer",
+                candidates = await _gather_limited(
+                    [
+                        call_llm(
+                            prompts.ANSWER_SYSTEM,
+                            sub_prompt,
+                            context=context,
+                            model=plan.model,
+                            tag="subanswer",
+                        )
+                        for _ in range(plan.subanswer_retries)
+                    ],
+                    plan.parallelism,
                 )
-                subanswers.append(sub_answer)
+                best_idx = await _select_best_candidate(call_llm, subq, candidates, plan, "subanswer_select")
+                return candidates[best_idx]
+            return await call_llm(
+                prompts.ANSWER_SYSTEM,
+                sub_prompt,
+                context=context,
+                model=plan.model,
+                tag="subanswer",
+            )
+
+        if plan.parallelism > 1 and len(sub_questions) > 1:
+            subanswers = await _gather_limited(
+                [_subanswer_for(subq) for subq in sub_questions],
+                plan.parallelism,
+            )
+        else:
+            for subq in sub_questions:
+                subanswers.append(await _subanswer_for(subq))
 
         if observer:
             observer("synthesize", "synthesizing")
@@ -776,8 +787,31 @@ class AnswerEngine:
                 + f"\nDraftIndex: {idx + 1}"
             )
         drafts: list[str] = []
-        for prompt in draft_prompts:
-            drafts.append(await call_llm(prompts.SYNTHESIZE_SYSTEM, prompt, context=context, model=plan.model, tag="synth"))
+        if plan.parallelism > 1 and len(draft_prompts) > 1:
+            drafts = await _gather_limited(
+                [
+                    call_llm(
+                        prompts.SYNTHESIZE_SYSTEM,
+                        prompt,
+                        context=context,
+                        model=plan.model,
+                        tag="synth",
+                    )
+                    for prompt in draft_prompts
+                ],
+                plan.parallelism,
+            )
+        else:
+            for prompt in draft_prompts:
+                drafts.append(
+                    await call_llm(
+                        prompts.SYNTHESIZE_SYSTEM,
+                        prompt,
+                        context=context,
+                        model=plan.model,
+                        tag="synth",
+                    )
+                )
         if len(drafts) == 1:
             return drafts[0]
         select_prompt = (
@@ -1025,6 +1059,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
             chunk_lines=6,
             chunk_top=10,
             chunk_group=4,
+            parallelism=4,
             use_tool=True,
             use_critic=True,
             use_gap=True,
@@ -1041,6 +1076,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
             chunk_lines=8,
             chunk_top=8,
             chunk_group=4,
+            parallelism=2,
             use_tool=True,
             use_critic=True,
             use_gap=True,
@@ -1056,6 +1092,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
             chunk_lines=12,
             chunk_top=5,
             chunk_group=5,
+            parallelism=1,
             use_tool=False,
             use_critic=False,
             use_gap=False,
@@ -1117,14 +1154,23 @@ async def _score_chunks(
     scores: dict[str, float] = {chunk["id"]: 0.0 for chunk in chunks}
     if not chunks:
         return scores
+    groups: list[list[dict[str, Any]]] = []
     group: list[dict[str, Any]] = []
     for chunk in chunks:
         group.append({"id": chunk["id"], "summary": chunk["summary"]})
         if len(group) >= plan.chunk_group:
-            scores.update(await _score_chunk_group(call_llm, group, question, sub_questions))
+            groups.append(group)
             group = []
     if group:
-        scores.update(await _score_chunk_group(call_llm, group, question, sub_questions))
+        groups.append(group)
+    if plan.parallelism <= 1 or len(groups) <= 1:
+        for grp in groups:
+            scores.update(await _score_chunk_group(call_llm, grp, question, sub_questions))
+        return scores
+    coros = [_score_chunk_group(call_llm, grp, question, sub_questions) for grp in groups]
+    results = await _gather_limited(coros, plan.parallelism)
+    for result in results:
+        scores.update(result)
     return scores
 
 
@@ -1316,9 +1362,11 @@ async def _select_metric_chunks(
     if available_keys:
         missing = await _validate_metric_keys(
             call_llm,
-            question,
-            sub_questions,
-            selected,
+            {
+                "question": question,
+                "sub_questions": sub_questions,
+                "selected": selected,
+            },
             available_keys,
             plan,
         )
@@ -1332,14 +1380,15 @@ async def _select_metric_chunks(
 
 async def _validate_metric_keys(
     call_llm: Callable[..., Awaitable[str]],
-    question: str,
-    sub_questions: list[str],
-    selected: list[str],
+    ctx: dict[str, Any],
     available: list[str],
     plan: ModePlan,
 ) -> list[str]:
     if not available:
         return []
+    question = str(ctx.get("question") or "")
+    sub_questions = ctx.get("sub_questions") if isinstance(ctx.get("sub_questions"), list) else []
+    selected = ctx.get("selected") if isinstance(ctx.get("selected"), list) else []
     cap = max(12, plan.max_subquestions * 4)
     available_list = available[:cap]
     prompt = prompts.METRIC_KEYS_VALIDATE_PROMPT.format(
@@ -1366,6 +1415,18 @@ async def _validate_metric_keys(
     return out
 
 
+async def _gather_limited(coros: list[Awaitable[Any]], limit: int) -> list[Any]:
+    if not coros:
+        return []
+    semaphore = asyncio.Semaphore(max(1, limit))
+
+    async def runner(coro: Awaitable[Any]) -> Any:
+        async with semaphore:
+            return await coro
+
+    return await asyncio.gather(*(runner(coro) for coro in coros))
+
+
 def _metric_ctx_values(ctx: dict[str, Any]) -> tuple[list[str], str, list[str], list[str], set[str]]:
     summary_lines = ctx.get("summary_lines") if isinstance(ctx, dict) else None
     if not isinstance(summary_lines, list):