atlasbot: retry chunk scoring
This commit is contained in:
parent
79f15eaa7a
commit
9716755fdb
@ -77,6 +77,7 @@ class ModePlan:
|
||||
chunk_top: int
|
||||
chunk_group: int
|
||||
parallelism: int
|
||||
score_retries: int
|
||||
use_tool: bool
|
||||
use_critic: bool
|
||||
use_gap: bool
|
||||
@ -86,6 +87,14 @@ class ModePlan:
|
||||
subanswer_retries: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoreContext:
|
||||
question: str
|
||||
sub_questions: list[str]
|
||||
retries: int
|
||||
parallelism: int
|
||||
|
||||
|
||||
class AnswerEngine:
|
||||
def __init__(
|
||||
self,
|
||||
@ -1060,6 +1069,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
|
||||
chunk_top=10,
|
||||
chunk_group=4,
|
||||
parallelism=4,
|
||||
score_retries=3,
|
||||
use_tool=True,
|
||||
use_critic=True,
|
||||
use_gap=True,
|
||||
@ -1077,6 +1087,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
|
||||
chunk_top=8,
|
||||
chunk_group=4,
|
||||
parallelism=2,
|
||||
score_retries=2,
|
||||
use_tool=True,
|
||||
use_critic=True,
|
||||
use_gap=True,
|
||||
@ -1093,6 +1104,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
|
||||
chunk_top=5,
|
||||
chunk_group=5,
|
||||
parallelism=1,
|
||||
score_retries=1,
|
||||
use_tool=False,
|
||||
use_critic=False,
|
||||
use_gap=False,
|
||||
@ -1144,6 +1156,19 @@ def _chunk_lines(lines: list[str], lines_per_chunk: int) -> list[dict[str, Any]]
|
||||
return chunks
|
||||
|
||||
|
||||
def _build_chunk_groups(chunks: list[dict[str, Any]], group_size: int) -> list[list[dict[str, Any]]]:
|
||||
groups: list[list[dict[str, Any]]] = []
|
||||
group: list[dict[str, Any]] = []
|
||||
for chunk in chunks:
|
||||
group.append({"id": chunk["id"], "summary": chunk["summary"]})
|
||||
if len(group) >= group_size:
|
||||
groups.append(group)
|
||||
group = []
|
||||
if group:
|
||||
groups.append(group)
|
||||
return groups
|
||||
|
||||
|
||||
async def _score_chunks(
|
||||
call_llm: Callable[..., Any],
|
||||
chunks: list[dict[str, Any]],
|
||||
@ -1154,23 +1179,46 @@ async def _score_chunks(
|
||||
scores: dict[str, float] = {chunk["id"]: 0.0 for chunk in chunks}
|
||||
if not chunks:
|
||||
return scores
|
||||
groups: list[list[dict[str, Any]]] = []
|
||||
group: list[dict[str, Any]] = []
|
||||
for chunk in chunks:
|
||||
group.append({"id": chunk["id"], "summary": chunk["summary"]})
|
||||
if len(group) >= plan.chunk_group:
|
||||
groups.append(group)
|
||||
group = []
|
||||
if group:
|
||||
groups.append(group)
|
||||
if plan.parallelism <= 1 or len(groups) <= 1:
|
||||
groups = _build_chunk_groups(chunks, plan.chunk_group)
|
||||
ctx = ScoreContext(
|
||||
question=question,
|
||||
sub_questions=sub_questions,
|
||||
retries=max(1, plan.score_retries),
|
||||
parallelism=plan.parallelism,
|
||||
)
|
||||
if ctx.parallelism <= 1 or len(groups) * ctx.retries <= 1:
|
||||
return await _score_groups_serial(call_llm, groups, ctx)
|
||||
return await _score_groups_parallel(call_llm, groups, ctx)
|
||||
|
||||
|
||||
async def _score_groups_serial(
|
||||
call_llm: Callable[..., Any],
|
||||
groups: list[list[dict[str, Any]]],
|
||||
ctx: ScoreContext,
|
||||
) -> dict[str, float]:
|
||||
scores: dict[str, float] = {}
|
||||
for grp in groups:
|
||||
scores.update(await _score_chunk_group(call_llm, grp, question, sub_questions))
|
||||
runs = [await _score_chunk_group(call_llm, grp, ctx.question, ctx.sub_questions) for _ in range(ctx.retries)]
|
||||
scores.update(_merge_score_runs(runs))
|
||||
return scores
|
||||
coros = [_score_chunk_group(call_llm, grp, question, sub_questions) for grp in groups]
|
||||
results = await _gather_limited(coros, plan.parallelism)
|
||||
for result in results:
|
||||
scores.update(result)
|
||||
|
||||
|
||||
async def _score_groups_parallel(
|
||||
call_llm: Callable[..., Any],
|
||||
groups: list[list[dict[str, Any]]],
|
||||
ctx: ScoreContext,
|
||||
) -> dict[str, float]:
|
||||
coros: list[Awaitable[tuple[int, dict[str, float]]]] = []
|
||||
for idx, grp in enumerate(groups):
|
||||
for _ in range(ctx.retries):
|
||||
coros.append(_score_chunk_group_run(call_llm, idx, grp, ctx.question, ctx.sub_questions))
|
||||
results = await _gather_limited(coros, ctx.parallelism)
|
||||
grouped: dict[int, list[dict[str, float]]] = {}
|
||||
for idx, result in results:
|
||||
grouped.setdefault(idx, []).append(result)
|
||||
scores: dict[str, float] = {}
|
||||
for runs in grouped.values():
|
||||
scores.update(_merge_score_runs(runs))
|
||||
return scores
|
||||
|
||||
|
||||
@ -1206,6 +1254,28 @@ async def _score_chunk_group(
|
||||
return scored
|
||||
|
||||
|
||||
async def _score_chunk_group_run(
|
||||
call_llm: Callable[..., Any],
|
||||
idx: int,
|
||||
group: list[dict[str, Any]],
|
||||
question: str,
|
||||
sub_questions: list[str],
|
||||
) -> tuple[int, dict[str, float]]:
|
||||
return idx, await _score_chunk_group(call_llm, group, question, sub_questions)
|
||||
|
||||
|
||||
def _merge_score_runs(runs: list[dict[str, float]]) -> dict[str, float]:
|
||||
if not runs:
|
||||
return {}
|
||||
totals: dict[str, float] = {}
|
||||
counts: dict[str, int] = {}
|
||||
for run in runs:
|
||||
for key, value in run.items():
|
||||
totals[key] = totals.get(key, 0.0) + float(value)
|
||||
counts[key] = counts.get(key, 0) + 1
|
||||
return {key: totals[key] / counts[key] for key in totals}
|
||||
|
||||
|
||||
def _keyword_hits(
|
||||
ranked: list[dict[str, Any]],
|
||||
head: dict[str, Any],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user