# atlasbot/atlasbot/engine/answerer/retrieval_ext.py
from __future__ import annotations
import re
from collections.abc import Callable
from typing import Any
from atlasbot.llm import prompts
from atlasbot.llm.client import parse_json
from ._base import *
def _parse_json_block(text: str, *, fallback: dict[str, Any]) -> dict[str, Any]:
    """Parse the JSON object embedded in *text*, deferring errors to *fallback*.

    Model replies often wrap the JSON payload in prose, so the widest
    ``{...}`` span is preferred; when no braces are present the whole
    stripped reply is handed to the parser instead.
    """
    stripped = text.strip()
    found = re.search(r"\{.*\}", stripped, flags=re.S)
    payload = found.group(0) if found else stripped
    return parse_json(payload, fallback=fallback)
def _metric_key_tokens(summary_lines: list[str]) -> set[str]:
tokens: set[str] = set()
for line in summary_lines:
if not isinstance(line, str) or ":" not in line:
continue
key = line.split(":", 1)[0].strip().lower()
if not key:
continue
tokens.add(key)
for part in re.split(r"[_\s]+", key):
if part:
tokens.add(part)
return tokens
async def _select_best_candidate(
call_llm: Callable[..., Any],
question: str,
candidates: list[str],
plan: ModePlan,
tag: str,
) -> int:
if len(candidates) <= 1:
return 0
prompt = (
prompts.CANDIDATE_SELECT_PROMPT
+ "\nQuestion: "
+ question
+ "\nCandidates:\n"
+ "\n".join([f"{idx+1}) {cand}" for idx, cand in enumerate(candidates)])
)
raw = await call_llm(prompts.CANDIDATE_SELECT_SYSTEM, prompt, model=plan.model, tag=tag)
data = _parse_json_block(raw, fallback={})
best = data.get("best") if isinstance(data, dict) else None
if isinstance(best, int) and 1 <= best <= len(candidates):
return best - 1
return 0
def _dedupe_lines(lines: list[str], limit: int | None = None) -> list[str]:
seen: set[str] = set()
cleaned: list[str] = []
for line in lines:
value = (line or "").strip()
if not value or value in seen:
continue
if value.lower().startswith("lexicon_") or value.lower().startswith("units:"):
continue
cleaned.append(value)
seen.add(value)
if limit and len(cleaned) >= limit:
break
return cleaned
def _collect_fact_candidates(selected: list[dict[str, Any]], limit: int) -> list[str]:
    """Flatten the ``text`` of *selected* chunks into at most *limit* unique lines."""
    gathered: list[str] = []
    for chunk in selected:
        body = chunk.get("text") if isinstance(chunk, dict) else None
        if isinstance(body, str):
            gathered.extend(line for line in body.splitlines() if line.strip())
    return _dedupe_lines(gathered, limit=limit)
async def _select_best_list(
call_llm: Callable[..., Any],
question: str,
candidates: list[list[str]],
plan: ModePlan,
tag: str,
) -> list[str]:
if not candidates:
return []
if len(candidates) == 1:
return candidates[0]
render = ["; ".join(items) for items in candidates]
best_idx = await _select_best_candidate(call_llm, question, render, plan, tag)
chosen = candidates[best_idx] if 0 <= best_idx < len(candidates) else candidates[0]
if not chosen:
merged: list[str] = []
for entry in candidates:
for item in entry:
if item not in merged:
merged.append(item)
chosen = merged
return chosen
async def _extract_fact_types(
    call_llm: Callable[..., Any],
    question: str,
    keywords: list[str],
    plan: ModePlan,
) -> list[str]:
    """Derive up to ten fact-type labels for *question* via repeated LLM calls.

    Each retry yields one cleaned candidate list; the best list is then
    chosen by the LLM selector and capped at ten entries.
    """
    parts = [prompts.FACT_TYPES_PROMPT, "\nQuestion: ", question]
    if keywords:
        parts += ["\nKeywords: ", ", ".join(keywords)]
    prompt = "".join(parts)
    gathered: list[list[str]] = []
    for _ in range(max(plan.metric_retries, 1)):
        reply = await call_llm(prompts.FACT_TYPES_SYSTEM, prompt, model=plan.fast_model, tag="fact_types")
        payload = _parse_json_block(reply, fallback={})
        raw_items = payload.get("fact_types") if isinstance(payload, dict) else None
        if not isinstance(raw_items, list):
            continue
        usable = _dedupe_lines(
            [str(entry) for entry in raw_items if isinstance(entry, (str, int, float))],
            limit=10,
        )
        if usable:
            gathered.append(usable)
    best = await _select_best_list(call_llm, question, gathered, plan, "fact_types_select")
    return best[:10]
async def _derive_signals(
call_llm: Callable[..., Any],
question: str,
fact_types: list[str],
plan: ModePlan,
) -> list[str]:
if not fact_types:
return []
prompt = prompts.SIGNAL_PROMPT.format(question=question, fact_types="; ".join(fact_types))
candidates: list[list[str]] = []
attempts = max(plan.metric_retries, 1)
for _ in range(attempts):
raw = await call_llm(prompts.SIGNAL_SYSTEM, prompt, model=plan.fast_model, tag="signals")
data = _parse_json_block(raw, fallback={})
items = data.get("signals") if isinstance(data, dict) else None
if not isinstance(items, list):
continue
cleaned = _dedupe_lines([str(item) for item in items if isinstance(item, (str, int, float))], limit=12)
if cleaned:
candidates.append(cleaned)
chosen = await _select_best_list(call_llm, question, candidates, plan, "signals_select")
return chosen[:12]
async def _scan_chunk_for_signals(
call_llm: Callable[..., Any],
question: str,
signals: list[str],
chunk_lines: list[str],
plan: ModePlan,
) -> list[str]:
if not signals or not chunk_lines:
return []
prompt = prompts.CHUNK_SCAN_PROMPT.format(
signals="; ".join(signals),
lines="\n".join(chunk_lines),
)
attempts = max(1, min(plan.metric_retries, 2))
candidates: list[list[str]] = []
for _ in range(attempts):
raw = await call_llm(prompts.CHUNK_SCAN_SYSTEM, prompt, model=plan.fast_model, tag="chunk_scan")
data = _parse_json_block(raw, fallback={})
items = data.get("lines") if isinstance(data, dict) else None
if not isinstance(items, list):
continue
cleaned = [line for line in chunk_lines if line in items]
cleaned = _dedupe_lines(cleaned, limit=15)
if cleaned:
candidates.append(cleaned)
chosen = await _select_best_list(call_llm, question, candidates, plan, "chunk_scan_select")
return chosen[:15]
async def _prune_metric_candidates(
call_llm: Callable[..., Any],
question: str,
candidates: list[str],
plan: ModePlan,
attempts: int,
) -> list[str]:
if not candidates:
return []
prompt = prompts.FACT_PRUNE_PROMPT.format(question=question, candidates="\n".join(candidates), max_lines=6)
picks: list[list[str]] = []
for _ in range(max(attempts, 1)):
raw = await call_llm(prompts.FACT_PRUNE_SYSTEM, prompt, model=plan.fast_model, tag="fact_prune")
data = _parse_json_block(raw, fallback={})
items = data.get("lines") if isinstance(data, dict) else None
if not isinstance(items, list):
continue
cleaned = [line for line in candidates if line in items]
cleaned = _dedupe_lines(cleaned, limit=6)
if cleaned:
picks.append(cleaned)
chosen = await _select_best_list(call_llm, question, picks, plan, "fact_prune_select")
return chosen[:6]
async def _select_fact_lines(
call_llm: Callable[..., Any],
question: str,
candidates: list[str],
plan: ModePlan,
max_lines: int,
) -> list[str]:
if not candidates:
return []
prompt = prompts.FACT_PRUNE_PROMPT.format(question=question, candidates="\n".join(candidates), max_lines=max_lines)
picks: list[list[str]] = []
attempts = max(plan.metric_retries, 1)
for _ in range(attempts):
raw = await call_llm(prompts.FACT_PRUNE_SYSTEM, prompt, model=plan.fast_model, tag="fact_select")
data = _parse_json_block(raw, fallback={})
items = data.get("lines") if isinstance(data, dict) else None
if not isinstance(items, list):
continue
cleaned = [line for line in candidates if line in items]
cleaned = _dedupe_lines(cleaned, limit=max_lines)
if cleaned:
picks.append(cleaned)
chosen = await _select_best_list(call_llm, question, picks, plan, "fact_select_best")
return chosen[:max_lines]
__all__ = [name for name in globals() if name.startswith("_") and not name.startswith("__")]