atlasbot: retry chunk scoring
This commit is contained in:
parent
79f15eaa7a
commit
9716755fdb
@ -77,6 +77,7 @@ class ModePlan:
|
|||||||
chunk_top: int
|
chunk_top: int
|
||||||
chunk_group: int
|
chunk_group: int
|
||||||
parallelism: int
|
parallelism: int
|
||||||
|
score_retries: int
|
||||||
use_tool: bool
|
use_tool: bool
|
||||||
use_critic: bool
|
use_critic: bool
|
||||||
use_gap: bool
|
use_gap: bool
|
||||||
@ -86,6 +87,14 @@ class ModePlan:
|
|||||||
subanswer_retries: int
|
subanswer_retries: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScoreContext:
|
||||||
|
question: str
|
||||||
|
sub_questions: list[str]
|
||||||
|
retries: int
|
||||||
|
parallelism: int
|
||||||
|
|
||||||
|
|
||||||
class AnswerEngine:
|
class AnswerEngine:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -1060,6 +1069,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
|
|||||||
chunk_top=10,
|
chunk_top=10,
|
||||||
chunk_group=4,
|
chunk_group=4,
|
||||||
parallelism=4,
|
parallelism=4,
|
||||||
|
score_retries=3,
|
||||||
use_tool=True,
|
use_tool=True,
|
||||||
use_critic=True,
|
use_critic=True,
|
||||||
use_gap=True,
|
use_gap=True,
|
||||||
@ -1077,6 +1087,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
|
|||||||
chunk_top=8,
|
chunk_top=8,
|
||||||
chunk_group=4,
|
chunk_group=4,
|
||||||
parallelism=2,
|
parallelism=2,
|
||||||
|
score_retries=2,
|
||||||
use_tool=True,
|
use_tool=True,
|
||||||
use_critic=True,
|
use_critic=True,
|
||||||
use_gap=True,
|
use_gap=True,
|
||||||
@ -1093,6 +1104,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
|
|||||||
chunk_top=5,
|
chunk_top=5,
|
||||||
chunk_group=5,
|
chunk_group=5,
|
||||||
parallelism=1,
|
parallelism=1,
|
||||||
|
score_retries=1,
|
||||||
use_tool=False,
|
use_tool=False,
|
||||||
use_critic=False,
|
use_critic=False,
|
||||||
use_gap=False,
|
use_gap=False,
|
||||||
@ -1144,6 +1156,19 @@ def _chunk_lines(lines: list[str], lines_per_chunk: int) -> list[dict[str, Any]]
|
|||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def _build_chunk_groups(chunks: list[dict[str, Any]], group_size: int) -> list[list[dict[str, Any]]]:
|
||||||
|
groups: list[list[dict[str, Any]]] = []
|
||||||
|
group: list[dict[str, Any]] = []
|
||||||
|
for chunk in chunks:
|
||||||
|
group.append({"id": chunk["id"], "summary": chunk["summary"]})
|
||||||
|
if len(group) >= group_size:
|
||||||
|
groups.append(group)
|
||||||
|
group = []
|
||||||
|
if group:
|
||||||
|
groups.append(group)
|
||||||
|
return groups
|
||||||
|
|
||||||
|
|
||||||
async def _score_chunks(
|
async def _score_chunks(
|
||||||
call_llm: Callable[..., Any],
|
call_llm: Callable[..., Any],
|
||||||
chunks: list[dict[str, Any]],
|
chunks: list[dict[str, Any]],
|
||||||
@ -1154,23 +1179,46 @@ async def _score_chunks(
|
|||||||
scores: dict[str, float] = {chunk["id"]: 0.0 for chunk in chunks}
|
scores: dict[str, float] = {chunk["id"]: 0.0 for chunk in chunks}
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return scores
|
return scores
|
||||||
groups: list[list[dict[str, Any]]] = []
|
groups = _build_chunk_groups(chunks, plan.chunk_group)
|
||||||
group: list[dict[str, Any]] = []
|
ctx = ScoreContext(
|
||||||
for chunk in chunks:
|
question=question,
|
||||||
group.append({"id": chunk["id"], "summary": chunk["summary"]})
|
sub_questions=sub_questions,
|
||||||
if len(group) >= plan.chunk_group:
|
retries=max(1, plan.score_retries),
|
||||||
groups.append(group)
|
parallelism=plan.parallelism,
|
||||||
group = []
|
)
|
||||||
if group:
|
if ctx.parallelism <= 1 or len(groups) * ctx.retries <= 1:
|
||||||
groups.append(group)
|
return await _score_groups_serial(call_llm, groups, ctx)
|
||||||
if plan.parallelism <= 1 or len(groups) <= 1:
|
return await _score_groups_parallel(call_llm, groups, ctx)
|
||||||
|
|
||||||
|
|
||||||
|
async def _score_groups_serial(
|
||||||
|
call_llm: Callable[..., Any],
|
||||||
|
groups: list[list[dict[str, Any]]],
|
||||||
|
ctx: ScoreContext,
|
||||||
|
) -> dict[str, float]:
|
||||||
|
scores: dict[str, float] = {}
|
||||||
for grp in groups:
|
for grp in groups:
|
||||||
scores.update(await _score_chunk_group(call_llm, grp, question, sub_questions))
|
runs = [await _score_chunk_group(call_llm, grp, ctx.question, ctx.sub_questions) for _ in range(ctx.retries)]
|
||||||
|
scores.update(_merge_score_runs(runs))
|
||||||
return scores
|
return scores
|
||||||
coros = [_score_chunk_group(call_llm, grp, question, sub_questions) for grp in groups]
|
|
||||||
results = await _gather_limited(coros, plan.parallelism)
|
|
||||||
for result in results:
|
async def _score_groups_parallel(
|
||||||
scores.update(result)
|
call_llm: Callable[..., Any],
|
||||||
|
groups: list[list[dict[str, Any]]],
|
||||||
|
ctx: ScoreContext,
|
||||||
|
) -> dict[str, float]:
|
||||||
|
coros: list[Awaitable[tuple[int, dict[str, float]]]] = []
|
||||||
|
for idx, grp in enumerate(groups):
|
||||||
|
for _ in range(ctx.retries):
|
||||||
|
coros.append(_score_chunk_group_run(call_llm, idx, grp, ctx.question, ctx.sub_questions))
|
||||||
|
results = await _gather_limited(coros, ctx.parallelism)
|
||||||
|
grouped: dict[int, list[dict[str, float]]] = {}
|
||||||
|
for idx, result in results:
|
||||||
|
grouped.setdefault(idx, []).append(result)
|
||||||
|
scores: dict[str, float] = {}
|
||||||
|
for runs in grouped.values():
|
||||||
|
scores.update(_merge_score_runs(runs))
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
@ -1206,6 +1254,28 @@ async def _score_chunk_group(
|
|||||||
return scored
|
return scored
|
||||||
|
|
||||||
|
|
||||||
|
async def _score_chunk_group_run(
|
||||||
|
call_llm: Callable[..., Any],
|
||||||
|
idx: int,
|
||||||
|
group: list[dict[str, Any]],
|
||||||
|
question: str,
|
||||||
|
sub_questions: list[str],
|
||||||
|
) -> tuple[int, dict[str, float]]:
|
||||||
|
return idx, await _score_chunk_group(call_llm, group, question, sub_questions)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_score_runs(runs: list[dict[str, float]]) -> dict[str, float]:
|
||||||
|
if not runs:
|
||||||
|
return {}
|
||||||
|
totals: dict[str, float] = {}
|
||||||
|
counts: dict[str, int] = {}
|
||||||
|
for run in runs:
|
||||||
|
for key, value in run.items():
|
||||||
|
totals[key] = totals.get(key, 0.0) + float(value)
|
||||||
|
counts[key] = counts.get(key, 0) + 1
|
||||||
|
return {key: totals[key] / counts[key] for key in totals}
|
||||||
|
|
||||||
|
|
||||||
def _keyword_hits(
|
def _keyword_hits(
|
||||||
ranked: list[dict[str, Any]],
|
ranked: list[dict[str, Any]],
|
||||||
head: dict[str, Any],
|
head: dict[str, Any],
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user