diff --git a/atlasbot/engine/answerer.py b/atlasbot/engine/answerer.py
index 2205043..55d811a 100644
--- a/atlasbot/engine/answerer.py
+++ b/atlasbot/engine/answerer.py
@@ -1179,30 +1179,74 @@ def _select_chunks(
     if not chunks:
         return []
     ranked = sorted(chunks, key=lambda item: scores.get(item["id"], 0.0), reverse=True)
-    selected: list[dict[str, Any]] = []
+    # The first chunk in input order is always kept and seeds keyword matching.
     head = chunks[0]
-    selected.append(head)
-    id_map = {item["id"]: item for item in chunks}
-    if must_ids:
-        for cid in must_ids:
-            item = id_map.get(cid)
-            if item and item not in selected:
-                selected.append(item)
-            if len(selected) >= plan.chunk_top:
-                return selected
+    selected: list[dict[str, Any]] = [head]
+    if _append_must_chunks(chunks, selected, must_ids, plan.chunk_top):
+        return selected
+    if _append_keyword_chunks(ranked, head, selected, keywords, plan.chunk_top):
+        return selected
+    _append_ranked_chunks(ranked, selected, plan.chunk_top)
+    return selected
+
 
+def _append_must_chunks(
+    chunks: list[dict[str, Any]],
+    selected: list[dict[str, Any]],
+    must_ids: list[str] | None,
+    limit: int,
+) -> bool:
+    """Append the chunks named by ``must_ids`` to ``selected`` in place.
+
+    Ids without a matching chunk and chunks already selected are skipped.
+    Returns True once ``selected`` holds at least ``limit`` items, and False
+    when ``must_ids`` is empty or exhausted before the limit is reached.
+    """
+    if not must_ids:
+        return False
+    id_map = {item["id"]: item for item in chunks}
+    for cid in must_ids:
+        item = id_map.get(cid)
+        if item and item not in selected:
+            selected.append(item)
+        if len(selected) >= limit:
+            return True
+    return False
+
+
+def _append_keyword_chunks(
+    ranked: list[dict[str, Any]],
+    head: dict[str, Any],
+    selected: list[dict[str, Any]],
+    keywords: list[str] | None,
+    limit: int,
+) -> bool:
+    """Append keyword-matching chunks to ``selected`` in place.
+
+    ``head`` is forwarded to ``_keyword_hits`` unchanged; the caller passes the
+    first input chunk. Returns True if the limit was hit while hits were still
+    pending, otherwise False.
+    """
     for item in _keyword_hits(ranked, head, keywords):
-        if len(selected) >= plan.chunk_top:
-            return selected
+        # Test the cap before appending so the selection never overshoots it.
+        if len(selected) >= limit:
+            return True
         if item not in selected:
             selected.append(item)
+    return False
+
 
+def _append_ranked_chunks(
+    ranked: list[dict[str, Any]],
+    selected: list[dict[str, Any]],
+    limit: int,
+) -> None:
+    """Top up ``selected`` from ``ranked``, in order, until ``limit`` items."""
     for item in ranked:
-        if len(selected) >= plan.chunk_top:
+        if len(selected) >= limit:
             break
         if item not in selected:
             selected.append(item)
-    return selected
 
 
 def _format_runbooks(runbooks: list[str]) -> str: