atlasbot: retry chunk scoring

This commit is contained in:
Brad Stein 2026-02-03 23:39:17 -03:00
parent 79f15eaa7a
commit 9716755fdb

View File

@ -77,6 +77,7 @@ class ModePlan:
chunk_top: int
chunk_group: int
parallelism: int
score_retries: int
use_tool: bool
use_critic: bool
use_gap: bool
@ -86,6 +87,14 @@ class ModePlan:
subanswer_retries: int
@dataclass
class ScoreContext:
question: str
sub_questions: list[str]
retries: int
parallelism: int
class AnswerEngine:
def __init__(
self,
@ -1060,6 +1069,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
chunk_top=10,
chunk_group=4,
parallelism=4,
score_retries=3,
use_tool=True,
use_critic=True,
use_gap=True,
@ -1077,6 +1087,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
chunk_top=8,
chunk_group=4,
parallelism=2,
score_retries=2,
use_tool=True,
use_critic=True,
use_gap=True,
@ -1093,6 +1104,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
chunk_top=5,
chunk_group=5,
parallelism=1,
score_retries=1,
use_tool=False,
use_critic=False,
use_gap=False,
@ -1144,6 +1156,19 @@ def _chunk_lines(lines: list[str], lines_per_chunk: int) -> list[dict[str, Any]]
return chunks
def _build_chunk_groups(chunks: list[dict[str, Any]], group_size: int) -> list[list[dict[str, Any]]]:
groups: list[list[dict[str, Any]]] = []
group: list[dict[str, Any]] = []
for chunk in chunks:
group.append({"id": chunk["id"], "summary": chunk["summary"]})
if len(group) >= group_size:
groups.append(group)
group = []
if group:
groups.append(group)
return groups
async def _score_chunks(
call_llm: Callable[..., Any],
chunks: list[dict[str, Any]],
@ -1154,23 +1179,46 @@ async def _score_chunks(
scores: dict[str, float] = {chunk["id"]: 0.0 for chunk in chunks}
if not chunks:
return scores
groups: list[list[dict[str, Any]]] = []
group: list[dict[str, Any]] = []
for chunk in chunks:
group.append({"id": chunk["id"], "summary": chunk["summary"]})
if len(group) >= plan.chunk_group:
groups.append(group)
group = []
if group:
groups.append(group)
if plan.parallelism <= 1 or len(groups) <= 1:
groups = _build_chunk_groups(chunks, plan.chunk_group)
ctx = ScoreContext(
question=question,
sub_questions=sub_questions,
retries=max(1, plan.score_retries),
parallelism=plan.parallelism,
)
if ctx.parallelism <= 1 or len(groups) * ctx.retries <= 1:
return await _score_groups_serial(call_llm, groups, ctx)
return await _score_groups_parallel(call_llm, groups, ctx)
async def _score_groups_serial(
call_llm: Callable[..., Any],
groups: list[list[dict[str, Any]]],
ctx: ScoreContext,
) -> dict[str, float]:
scores: dict[str, float] = {}
for grp in groups:
scores.update(await _score_chunk_group(call_llm, grp, question, sub_questions))
runs = [await _score_chunk_group(call_llm, grp, ctx.question, ctx.sub_questions) for _ in range(ctx.retries)]
scores.update(_merge_score_runs(runs))
return scores
coros = [_score_chunk_group(call_llm, grp, question, sub_questions) for grp in groups]
results = await _gather_limited(coros, plan.parallelism)
for result in results:
scores.update(result)
async def _score_groups_parallel(
call_llm: Callable[..., Any],
groups: list[list[dict[str, Any]]],
ctx: ScoreContext,
) -> dict[str, float]:
coros: list[Awaitable[tuple[int, dict[str, float]]]] = []
for idx, grp in enumerate(groups):
for _ in range(ctx.retries):
coros.append(_score_chunk_group_run(call_llm, idx, grp, ctx.question, ctx.sub_questions))
results = await _gather_limited(coros, ctx.parallelism)
grouped: dict[int, list[dict[str, float]]] = {}
for idx, result in results:
grouped.setdefault(idx, []).append(result)
scores: dict[str, float] = {}
for runs in grouped.values():
scores.update(_merge_score_runs(runs))
return scores
@ -1206,6 +1254,28 @@ async def _score_chunk_group(
return scored
async def _score_chunk_group_run(
call_llm: Callable[..., Any],
idx: int,
group: list[dict[str, Any]],
question: str,
sub_questions: list[str],
) -> tuple[int, dict[str, float]]:
return idx, await _score_chunk_group(call_llm, group, question, sub_questions)
def _merge_score_runs(runs: list[dict[str, float]]) -> dict[str, float]:
if not runs:
return {}
totals: dict[str, float] = {}
counts: dict[str, int] = {}
for run in runs:
for key, value in run.items():
totals[key] = totals.get(key, 0.0) + float(value)
counts[key] = counts.get(key, 0) + 1
return {key: totals[key] / counts[key] for key in totals}
def _keyword_hits(
ranked: list[dict[str, Any]],
head: dict[str, Any],