from __future__ import annotations

import re
from collections.abc import Callable
from typing import Any

from atlasbot.llm import prompts
from atlasbot.llm.client import parse_json
from ._base import *


def _parse_json_block(text: str, *, fallback: dict[str, Any]) -> dict[str, Any]:
    """Parse the first JSON object embedded in ``text``, falling back to the raw string."""
    raw = text.strip()
    match = re.search(r"\{.*\}", raw, flags=re.S)
    if match:
        return parse_json(match.group(0), fallback=fallback)
    return parse_json(raw, fallback=fallback)
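
# Illustration of the salvage behaviour above (made-up strings; assumes
# ``parse_json`` returns ``fallback`` on unparseable input):
#     _parse_json_block('Sure! {"best": 2}', fallback={})   -> {"best": 2}
#     _parse_json_block("no json here", fallback={"x": 1})  -> {"x": 1}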


def _metric_key_tokens(summary_lines: list[str]) -> set[str]:
    """Collect lower-cased keys (and their word parts) from ``key: value`` summary lines."""
    tokens: set[str] = set()
    for line in summary_lines:
        if not isinstance(line, str) or ":" not in line:
            continue
        key = line.split(":", 1)[0].strip().lower()
        if not key:
            continue
        tokens.add(key)
        for part in re.split(r"[_\s]+", key):
            if part:
                tokens.add(part)
    return tokens


async def _select_best_candidate(
    call_llm: Callable[..., Any],
    question: str,
    candidates: list[str],
    plan: ModePlan,
    tag: str,
) -> int:
    """Ask the model to pick the best candidate; return a 0-based index (0 on failure)."""
    if len(candidates) <= 1:
        return 0
    prompt = (
        prompts.CANDIDATE_SELECT_PROMPT
        + "\nQuestion: "
        + question
        + "\nCandidates:\n"
        + "\n".join([f"{idx + 1}) {cand}" for idx, cand in enumerate(candidates)])
    )
    raw = await call_llm(prompts.CANDIDATE_SELECT_SYSTEM, prompt, model=plan.model, tag=tag)
    data = _parse_json_block(raw, fallback={})
    best = data.get("best") if isinstance(data, dict) else None
    # The model replies with a 1-based index as {"best": N}; convert to 0-based.
    if isinstance(best, int) and 1 <= best <= len(candidates):
        return best - 1
    return 0


def _dedupe_lines(lines: list[str], limit: int | None = None) -> list[str]:
    """Strip, de-duplicate, and optionally cap lines, dropping lexicon/units noise."""
    seen: set[str] = set()
    cleaned: list[str] = []
    for line in lines:
        value = (line or "").strip()
        if not value or value in seen:
            continue
        if value.lower().startswith("lexicon_") or value.lower().startswith("units:"):
            continue
        cleaned.append(value)
        seen.add(value)
        if limit and len(cleaned) >= limit:
            break
    return cleaned


def _collect_fact_candidates(selected: list[dict[str, Any]], limit: int) -> list[str]:
    """Flatten the non-empty text lines of the selected chunks into deduped candidates."""
    lines: list[str] = []
    for chunk in selected:
        text = chunk.get("text") if isinstance(chunk, dict) else None
        if not isinstance(text, str):
            continue
        lines.extend([line for line in text.splitlines() if line.strip()])
    return _dedupe_lines(lines, limit=limit)


async def _select_best_list(
    call_llm: Callable[..., Any],
    question: str,
    candidates: list[list[str]],
    plan: ModePlan,
    tag: str,
) -> list[str]:
    """Pick the best candidate list via the LLM; merge every list if the pick is empty."""
    if not candidates:
        return []
    if len(candidates) == 1:
        return candidates[0]
    render = ["; ".join(items) for items in candidates]
    best_idx = await _select_best_candidate(call_llm, question, render, plan, tag)
    chosen = candidates[best_idx] if 0 <= best_idx < len(candidates) else candidates[0]
    if not chosen:
        merged: list[str] = []
        for entry in candidates:
            for item in entry:
                if item not in merged:
                    merged.append(item)
        chosen = merged
    return chosen
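
# Illustration of the empty-pick fallback (made-up data): with candidates
# [["a"], [], ["a", "b"]] and a selector that picks the empty list, the merge
# branch returns ["a", "b"] (first-seen order, duplicates dropped).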


async def _extract_fact_types(
    call_llm: Callable[..., Any],
    question: str,
    keywords: list[str],
    plan: ModePlan,
) -> list[str]:
    """Sample fact-type lists from the fast model and keep the best one (at most 10)."""
    prompt = prompts.FACT_TYPES_PROMPT + "\nQuestion: " + question
    if keywords:
        prompt += "\nKeywords: " + ", ".join(keywords)
    candidates: list[list[str]] = []
    attempts = max(plan.metric_retries, 1)
    for _ in range(attempts):
        raw = await call_llm(prompts.FACT_TYPES_SYSTEM, prompt, model=plan.fast_model, tag="fact_types")
        data = _parse_json_block(raw, fallback={})
        items = data.get("fact_types") if isinstance(data, dict) else None
        if not isinstance(items, list):
            continue
        cleaned = _dedupe_lines([str(item) for item in items if isinstance(item, (str, int, float))], limit=10)
        if cleaned:
            candidates.append(cleaned)
    chosen = await _select_best_list(call_llm, question, candidates, plan, "fact_types_select")
    return chosen[:10]


async def _derive_signals(
    call_llm: Callable[..., Any],
    question: str,
    fact_types: list[str],
    plan: ModePlan,
) -> list[str]:
    """Sample signal lists for the given fact types and keep the best one (at most 12)."""
    if not fact_types:
        return []
    prompt = prompts.SIGNAL_PROMPT.format(question=question, fact_types="; ".join(fact_types))
    candidates: list[list[str]] = []
    attempts = max(plan.metric_retries, 1)
    for _ in range(attempts):
        raw = await call_llm(prompts.SIGNAL_SYSTEM, prompt, model=plan.fast_model, tag="signals")
        data = _parse_json_block(raw, fallback={})
        items = data.get("signals") if isinstance(data, dict) else None
        if not isinstance(items, list):
            continue
        cleaned = _dedupe_lines([str(item) for item in items if isinstance(item, (str, int, float))], limit=12)
        if cleaned:
            candidates.append(cleaned)
    chosen = await _select_best_list(call_llm, question, candidates, plan, "signals_select")
    return chosen[:12]


async def _scan_chunk_for_signals(
    call_llm: Callable[..., Any],
    question: str,
    signals: list[str],
    chunk_lines: list[str],
    plan: ModePlan,
) -> list[str]:
    """Ask the fast model which chunk lines match the signals; keep at most 15."""
    if not signals or not chunk_lines:
        return []
    prompt = prompts.CHUNK_SCAN_PROMPT.format(
        signals="; ".join(signals),
        lines="\n".join(chunk_lines),
    )
    attempts = max(1, min(plan.metric_retries, 2))
    candidates: list[list[str]] = []
    for _ in range(attempts):
        raw = await call_llm(prompts.CHUNK_SCAN_SYSTEM, prompt, model=plan.fast_model, tag="chunk_scan")
        data = _parse_json_block(raw, fallback={})
        items = data.get("lines") if isinstance(data, dict) else None
        if not isinstance(items, list):
            continue
        # Keep only lines the model echoed back verbatim, in their original order.
        cleaned = [line for line in chunk_lines if line in items]
        cleaned = _dedupe_lines(cleaned, limit=15)
        if cleaned:
            candidates.append(cleaned)
    chosen = await _select_best_list(call_llm, question, candidates, plan, "chunk_scan_select")
    return chosen[:15]


async def _prune_metric_candidates(
    call_llm: Callable[..., Any],
    question: str,
    candidates: list[str],
    plan: ModePlan,
    attempts: int,
) -> list[str]:
    """Prune candidate lines down to at most 6 via repeated prune-and-pick passes."""
    if not candidates:
        return []
    prompt = prompts.FACT_PRUNE_PROMPT.format(question=question, candidates="\n".join(candidates), max_lines=6)
    picks: list[list[str]] = []
    for _ in range(max(attempts, 1)):
        raw = await call_llm(prompts.FACT_PRUNE_SYSTEM, prompt, model=plan.fast_model, tag="fact_prune")
        data = _parse_json_block(raw, fallback={})
        items = data.get("lines") if isinstance(data, dict) else None
        if not isinstance(items, list):
            continue
        cleaned = [line for line in candidates if line in items]
        cleaned = _dedupe_lines(cleaned, limit=6)
        if cleaned:
            picks.append(cleaned)
    chosen = await _select_best_list(call_llm, question, picks, plan, "fact_prune_select")
    return chosen[:6]


async def _select_fact_lines(
    call_llm: Callable[..., Any],
    question: str,
    candidates: list[str],
    plan: ModePlan,
    max_lines: int,
) -> list[str]:
    """Select up to ``max_lines`` candidate lines via repeated prune-and-pick passes."""
    if not candidates:
        return []
    prompt = prompts.FACT_PRUNE_PROMPT.format(question=question, candidates="\n".join(candidates), max_lines=max_lines)
    picks: list[list[str]] = []
    attempts = max(plan.metric_retries, 1)
    for _ in range(attempts):
        raw = await call_llm(prompts.FACT_PRUNE_SYSTEM, prompt, model=plan.fast_model, tag="fact_select")
        data = _parse_json_block(raw, fallback={})
        items = data.get("lines") if isinstance(data, dict) else None
        if not isinstance(items, list):
            continue
        cleaned = [line for line in candidates if line in items]
        cleaned = _dedupe_lines(cleaned, limit=max_lines)
        if cleaned:
            picks.append(cleaned)
    chosen = await _select_best_list(call_llm, question, picks, plan, "fact_select_best")
    return chosen[:max_lines]


# Export every single-underscore helper defined (or star-imported) above.
__all__ = [name for name in globals() if name.startswith("_") and not name.startswith("__")]
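

# Illustrative sketch only, not part of the module's API: a minimal smoke run
# of the candidate-selection path. ``_FakePlan`` and ``_fake_llm`` are
# hypothetical stand-ins invented here; only the ``model``, ``fast_model`` and
# ``metric_retries`` attributes of the real ModePlan are assumed, since those
# are the only fields the helpers above touch.
if __name__ == "__main__":  # pragma: no cover
    import asyncio
    from dataclasses import dataclass

    @dataclass
    class _FakePlan:
        model: str = "large"
        fast_model: str = "fast"
        metric_retries: int = 1

    async def _fake_llm(system: str, prompt: str, *, model: str, tag: str) -> str:
        # Stub LLM client: always claims the second candidate is best.
        return '{"best": 2}'

    async def _demo() -> None:
        idx = await _select_best_candidate(
            _fake_llm,
            "What is the p99 latency?",
            ["p99_latency: 9 ms", "p99_latency: 12 ms"],
            _FakePlan(),  # runtime duck-typing; the annotation expects ModePlan
            "demo",
        )
        print("selected candidate index:", idx)  # -> 1

    asyncio.run(_demo())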