diff --git a/atlasbot/engine/answerer.py b/atlasbot/engine/answerer.py index ae3efd2..ed4c89b 100644 --- a/atlasbot/engine/answerer.py +++ b/atlasbot/engine/answerer.py @@ -77,6 +77,7 @@ class ModePlan: chunk_top: int chunk_group: int parallelism: int + score_retries: int use_tool: bool use_critic: bool use_gap: bool @@ -86,6 +87,14 @@ class ModePlan: subanswer_retries: int +@dataclass +class ScoreContext: + question: str + sub_questions: list[str] + retries: int + parallelism: int + + class AnswerEngine: def __init__( self, @@ -1060,6 +1069,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan: chunk_top=10, chunk_group=4, parallelism=4, + score_retries=3, use_tool=True, use_critic=True, use_gap=True, @@ -1077,6 +1087,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan: chunk_top=8, chunk_group=4, parallelism=2, + score_retries=2, use_tool=True, use_critic=True, use_gap=True, @@ -1093,6 +1104,7 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan: chunk_top=5, chunk_group=5, parallelism=1, + score_retries=1, use_tool=False, use_critic=False, use_gap=False, @@ -1144,6 +1156,19 @@ def _chunk_lines(lines: list[str], lines_per_chunk: int) -> list[dict[str, Any]] return chunks +def _build_chunk_groups(chunks: list[dict[str, Any]], group_size: int) -> list[list[dict[str, Any]]]: + groups: list[list[dict[str, Any]]] = [] + group: list[dict[str, Any]] = [] + for chunk in chunks: + group.append({"id": chunk["id"], "summary": chunk["summary"]}) + if len(group) >= group_size: + groups.append(group) + group = [] + if group: + groups.append(group) + return groups + + async def _score_chunks( call_llm: Callable[..., Any], chunks: list[dict[str, Any]], @@ -1154,23 +1179,46 @@ async def _score_chunks( scores: dict[str, float] = {chunk["id"]: 0.0 for chunk in chunks} if not chunks: return scores - groups: list[list[dict[str, Any]]] = [] - group: list[dict[str, Any]] = [] - for chunk in chunks: - group.append({"id": chunk["id"], "summary": 
chunk["summary"]}) - if len(group) >= plan.chunk_group: - groups.append(group) - group = [] - if group: - groups.append(group) - if plan.parallelism <= 1 or len(groups) <= 1: - for grp in groups: - scores.update(await _score_chunk_group(call_llm, grp, question, sub_questions)) - return scores - coros = [_score_chunk_group(call_llm, grp, question, sub_questions) for grp in groups] - results = await _gather_limited(coros, plan.parallelism) - for result in results: - scores.update(result) + groups = _build_chunk_groups(chunks, plan.chunk_group) + ctx = ScoreContext( + question=question, + sub_questions=sub_questions, + retries=max(1, plan.score_retries), + parallelism=plan.parallelism, + ) + if ctx.parallelism <= 1 or len(groups) * ctx.retries <= 1: + return await _score_groups_serial(call_llm, groups, ctx) + return await _score_groups_parallel(call_llm, groups, ctx) + + + async def _score_groups_serial( + call_llm: Callable[..., Any], + groups: list[list[dict[str, Any]]], + ctx: ScoreContext, + ) -> dict[str, float]: + scores: dict[str, float] = {c["id"]: 0.0 for grp in groups for c in grp} + for grp in groups: + runs = [await _score_chunk_group(call_llm, grp, ctx.question, ctx.sub_questions) for _ in range(ctx.retries)] + scores.update(_merge_score_runs(runs)) + return scores + + + async def _score_groups_parallel( + call_llm: Callable[..., Any], + groups: list[list[dict[str, Any]]], + ctx: ScoreContext, + ) -> dict[str, float]: + coros: list[Awaitable[tuple[int, dict[str, float]]]] = [] + for idx, grp in enumerate(groups): + for _ in range(ctx.retries): + coros.append(_score_chunk_group_run(call_llm, idx, grp, ctx.question, ctx.sub_questions)) + results = await _gather_limited(coros, ctx.parallelism) + grouped: dict[int, list[dict[str, float]]] = {} + for idx, result in results: + grouped.setdefault(idx, []).append(result) + scores: dict[str, float] = {c["id"]: 0.0 for grp in groups for c in grp} + for runs in grouped.values(): + scores.update(_merge_score_runs(runs)) return scores @@ -1206,6 +1254,28 @@ async def _score_chunk_group( return scored
+async def _score_chunk_group_run( + call_llm: Callable[..., Any], + idx: int, + group: list[dict[str, Any]], + question: str, + sub_questions: list[str], +) -> tuple[int, dict[str, float]]: + return idx, await _score_chunk_group(call_llm, group, question, sub_questions) + + +def _merge_score_runs(runs: list[dict[str, float]]) -> dict[str, float]: + if not runs: + return {} + totals: dict[str, float] = {} + counts: dict[str, int] = {} + for run in runs: + for key, value in run.items(): + totals[key] = totals.get(key, 0.0) + float(value) + counts[key] = counts.get(key, 0) + 1 + return {key: totals[key] / counts[key] for key in totals} + + def _keyword_hits( ranked: list[dict[str, Any]], head: dict[str, Any],