atlasbot: parallelize independent llm calls

parent 5dee72cd1a
commit 68d229032d
@@ -2,6 +2,7 @@ import asyncio
 import json
 import logging
 import math
+import asyncio
 import re
 import time
 import difflib
@@ -75,6 +76,7 @@ class ModePlan:
     chunk_lines: int
     chunk_top: int
     chunk_group: int
+    parallelism: int
     use_tool: bool
     use_critic: bool
     use_gap: bool
@@ -320,6 +322,7 @@ class AnswerEngine:
         chunks = _chunk_lines(summary_lines, plan.chunk_lines)
         metric_keys: list[str] = []
         must_chunk_ids: list[str] = []
+        metric_task = None
         if (classify.get("question_type") in {"metric", "diagnostic"} or force_metric) and summary_lines:
             metric_ctx = {
                 "question": normalized,
@@ -328,13 +331,11 @@ class AnswerEngine:
                 "keyword_tokens": keyword_tokens,
                 "summary_lines": summary_lines,
             }
-            metric_keys, must_chunk_ids = await _select_metric_chunks(
-                call_llm,
-                metric_ctx,
-                chunks,
-                plan,
-            )
-        scored = await _score_chunks(call_llm, chunks, normalized, sub_questions, plan)
+            metric_task = asyncio.create_task(_select_metric_chunks(call_llm, metric_ctx, chunks, plan))
+        scored_task = asyncio.create_task(_score_chunks(call_llm, chunks, normalized, sub_questions, plan))
+        if metric_task:
+            metric_keys, must_chunk_ids = await metric_task
+        scored = await scored_task
         selected = _select_chunks(chunks, scored, plan, keyword_tokens, must_chunk_ids)
         fact_candidates = _collect_fact_candidates(selected, limit=plan.max_subquestions * 12)
         key_facts = await _select_fact_lines(
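
Note: combined with the metric_task = None initialisation above, this hunk overlaps metric-key selection with chunk scoring. Both coroutines are wrapped in asyncio.create_task, which schedules them on the running event loop right away, and they are awaited only afterwards, so the two LLM round-trips run concurrently instead of back to back. A minimal sketch of the pattern (slow_a and slow_b are placeholders, not code from this repo):

    import asyncio

    async def slow_a() -> str:
        await asyncio.sleep(1)  # stands in for one LLM call
        return "a"

    async def slow_b() -> str:
        await asyncio.sleep(1)  # runs concurrently with slow_a
        return "b"

    async def main() -> None:
        task_a = asyncio.create_task(slow_a())  # starts executing now
        task_b = asyncio.create_task(slow_b())  # starts executing now
        print(await task_a, await task_b)       # ~1s total, not ~2s

    asyncio.run(main())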
@@ -474,30 +475,40 @@ class AnswerEngine:
         if observer:
             observer("subanswers", "drafting subanswers")
         subanswers: list[str] = []
-        for subq in sub_questions:
+        async def _subanswer_for(subq: str) -> str:
             sub_prompt = prompts.SUBANSWER_PROMPT + "\nQuestion: " + subq
             if plan.subanswer_retries > 1:
-                candidates: list[str] = []
-                for _ in range(plan.subanswer_retries):
-                    candidate = await call_llm(
+                candidates = await _gather_limited(
+                    [
+                        call_llm(
                             prompts.ANSWER_SYSTEM,
                             sub_prompt,
                             context=context,
                             model=plan.model,
                             tag="subanswer",
                         )
-                    candidates.append(candidate)
+                        for _ in range(plan.subanswer_retries)
+                    ],
+                    plan.parallelism,
+                )
                 best_idx = await _select_best_candidate(call_llm, subq, candidates, plan, "subanswer_select")
-                subanswers.append(candidates[best_idx])
-            else:
-                sub_answer = await call_llm(
+                return candidates[best_idx]
+            return await call_llm(
                 prompts.ANSWER_SYSTEM,
                 sub_prompt,
                 context=context,
                 model=plan.model,
                 tag="subanswer",
             )
-            subanswers.append(sub_answer)
+
+        if plan.parallelism > 1 and len(sub_questions) > 1:
+            subanswers = await _gather_limited(
+                [_subanswer_for(subq) for subq in sub_questions],
+                plan.parallelism,
+            )
+        else:
+            for subq in sub_questions:
+                subanswers.append(await _subanswer_for(subq))

         if observer:
             observer("synthesize", "synthesizing")
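
Note: the per-subquestion body moves into a local _subanswer_for coroutine so one code path serves both the parallel and the sequential branch. The list comprehensions only build coroutine objects; a coroutine starts executing when it is awaited (or wrapped in a task), so the semaphore inside _gather_limited genuinely caps the number of in-flight LLM calls. A quick illustration with hypothetical names:

    import asyncio

    async def work(i: int) -> int:
        print("started", i)
        await asyncio.sleep(0.1)
        return i

    async def main() -> None:
        coros = [work(i) for i in range(3)]  # nothing printed yet: coroutines are lazy
        print(await asyncio.gather(*coros))  # "started 0..2" appears only now

    asyncio.run(main())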
@@ -776,8 +787,31 @@ class AnswerEngine:
                 + f"\nDraftIndex: {idx + 1}"
             )
         drafts: list[str] = []
+        if plan.parallelism > 1 and len(draft_prompts) > 1:
+            drafts = await _gather_limited(
+                [
+                    call_llm(
+                        prompts.SYNTHESIZE_SYSTEM,
+                        prompt,
+                        context=context,
+                        model=plan.model,
+                        tag="synth",
+                    )
+                    for prompt in draft_prompts
+                ],
+                plan.parallelism,
+            )
+        else:
             for prompt in draft_prompts:
-            drafts.append(await call_llm(prompts.SYNTHESIZE_SYSTEM, prompt, context=context, model=plan.model, tag="synth"))
+                drafts.append(
+                    await call_llm(
+                        prompts.SYNTHESIZE_SYSTEM,
+                        prompt,
+                        context=context,
+                        model=plan.model,
+                        tag="synth",
+                    )
+                )
         if len(drafts) == 1:
             return drafts[0]
         select_prompt = (
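
Note: draft synthesis gets the same fan-out shape, and the sequential branch is only reflowed onto multiple lines with no behaviour change. Ordering matters here because the select step later refers to drafts by DraftIndex; asyncio.gather (and therefore _gather_limited) returns results in the order of its input awaitables regardless of completion order, so the parallel branch is a drop-in replacement for the loop. A small self-check, with draft as a stand-in coroutine:

    import asyncio
    import random

    async def draft(i: int) -> str:
        await asyncio.sleep(random.random() / 10)  # finish in random order
        return f"draft-{i}"

    async def main() -> None:
        out = await asyncio.gather(*(draft(i) for i in range(5)))
        assert out == [f"draft-{i}" for i in range(5)]  # input order preserved

    asyncio.run(main())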
@@ -1025,6 +1059,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
         chunk_lines=6,
         chunk_top=10,
         chunk_group=4,
+        parallelism=4,
         use_tool=True,
         use_critic=True,
         use_gap=True,
@@ -1041,6 +1076,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
         chunk_lines=8,
         chunk_top=8,
         chunk_group=4,
+        parallelism=2,
         use_tool=True,
         use_critic=True,
         use_gap=True,
@@ -1056,6 +1092,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
         chunk_lines=12,
         chunk_top=5,
         chunk_group=5,
+        parallelism=1,
         use_tool=False,
         use_critic=False,
         use_gap=False,
@@ -1117,14 +1154,23 @@ async def _score_chunks(
     scores: dict[str, float] = {chunk["id"]: 0.0 for chunk in chunks}
     if not chunks:
         return scores
+    groups: list[list[dict[str, Any]]] = []
     group: list[dict[str, Any]] = []
     for chunk in chunks:
         group.append({"id": chunk["id"], "summary": chunk["summary"]})
         if len(group) >= plan.chunk_group:
-            scores.update(await _score_chunk_group(call_llm, group, question, sub_questions))
+            groups.append(group)
             group = []
     if group:
-        scores.update(await _score_chunk_group(call_llm, group, question, sub_questions))
+        groups.append(group)
+    if plan.parallelism <= 1 or len(groups) <= 1:
+        for grp in groups:
+            scores.update(await _score_chunk_group(call_llm, grp, question, sub_questions))
+        return scores
+    coros = [_score_chunk_group(call_llm, grp, question, sub_questions) for grp in groups]
+    results = await _gather_limited(coros, plan.parallelism)
+    for result in results:
+        scores.update(result)
     return scores


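
Note: _score_chunks now collects the groups first and scores them afterwards, either in a loop or through _gather_limited. Merging with scores.update is order-independent here as long as _score_chunk_group only returns keys for the chunk ids it was given, which the grouping suggests: the groups partition the chunks, so the per-group result dicts have disjoint keys. Sketch of the merge, with made-up values:

    scores = {"c1": 0.0, "c2": 0.0, "c3": 0.0}
    results = [{"c1": 0.8}, {"c2": 0.4, "c3": 0.6}]  # disjoint per-group outputs
    for result in results:
        scores.update(result)
    assert scores == {"c1": 0.8, "c2": 0.4, "c3": 0.6}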
@@ -1316,9 +1362,11 @@ async def _select_metric_chunks(
     if available_keys:
         missing = await _validate_metric_keys(
             call_llm,
-            question,
-            sub_questions,
-            selected,
+            {
+                "question": question,
+                "sub_questions": sub_questions,
+                "selected": selected,
+            },
             available_keys,
             plan,
         )
@@ -1332,14 +1380,15 @@ async def _select_metric_chunks(

 async def _validate_metric_keys(
     call_llm: Callable[..., Awaitable[str]],
-    question: str,
-    sub_questions: list[str],
-    selected: list[str],
+    ctx: dict[str, Any],
     available: list[str],
     plan: ModePlan,
 ) -> list[str]:
     if not available:
         return []
+    question = str(ctx.get("question") or "")
+    sub_questions = ctx.get("sub_questions") if isinstance(ctx.get("sub_questions"), list) else []
+    selected = ctx.get("selected") if isinstance(ctx.get("selected"), list) else []
     cap = max(12, plan.max_subquestions * 4)
     available_list = available[:cap]
     prompt = prompts.METRIC_KEYS_VALIDATE_PROMPT.format(
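
Note: this hunk and the call-site hunk above fold three positional parameters into one ctx dict, mirroring the metric_ctx mapping the caller already builds, which keeps the asyncio.create_task call site compact. The isinstance guards make the unpacking tolerant of a missing or mistyped entry, e.g. (hypothetical value):

    ctx = {"question": "p95 latency?"}  # no sub_questions or selected keys
    sub_questions = ctx.get("sub_questions") if isinstance(ctx.get("sub_questions"), list) else []
    # -> [] rather than None, so later iteration stays safe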
@@ -1366,6 +1415,18 @@ async def _validate_metric_keys(
     return out


+async def _gather_limited(coros: list[Awaitable[Any]], limit: int) -> list[Any]:
+    if not coros:
+        return []
+    semaphore = asyncio.Semaphore(max(1, limit))
+
+    async def runner(coro: Awaitable[Any]) -> Any:
+        async with semaphore:
+            return await coro
+
+    return await asyncio.gather(*(runner(coro) for coro in coros))
+
+
 def _metric_ctx_values(ctx: dict[str, Any]) -> tuple[list[str], str, list[str], list[str], set[str]]:
     summary_lines = ctx.get("summary_lines") if isinstance(ctx, dict) else None
     if not isinstance(summary_lines, list):
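
Note: _gather_limited is the concurrency primitive the other hunks rely on. Every coroutine is handed to asyncio.gather up front, but each one runs only after its runner acquires the semaphore, so at most max(1, limit) calls are in flight at once, and results come back in input order. One behaviour worth knowing: gather without return_exceptions=True re-raises the first failure while the remaining runners keep executing in the background. Usage sketch (fetch is a placeholder):

    import asyncio

    async def fetch(i: int) -> int:
        await asyncio.sleep(0.1)  # pretend LLM/network latency
        return i * i

    async def main() -> None:
        results = await _gather_limited([fetch(i) for i in range(10)], 3)
        print(results)  # [0, 1, 4, ..., 81]; never more than 3 in flight

    asyncio.run(main())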
|
|||||||
Loading…
x
Reference in New Issue
Block a user