diff --git a/atlasbot/config.py b/atlasbot/config.py index c53bf23..1b0316a 100644 --- a/atlasbot/config.py +++ b/atlasbot/config.py @@ -68,6 +68,10 @@ class Settings: fast_max_candidates: int smart_max_candidates: int genius_max_candidates: int + fast_llm_calls_max: int + smart_llm_calls_max: int + genius_llm_calls_max: int + llm_limit_multiplier: float @dataclass(frozen=True) @@ -156,4 +160,8 @@ def load_settings() -> Settings: fast_max_candidates=_env_int("ATLASBOT_FAST_MAX_CANDIDATES", "2"), smart_max_candidates=_env_int("ATLASBOT_SMART_MAX_CANDIDATES", "6"), genius_max_candidates=_env_int("ATLASBOT_GENIUS_MAX_CANDIDATES", "10"), + fast_llm_calls_max=_env_int("ATLASBOT_FAST_LLM_CALLS_MAX", "9"), + smart_llm_calls_max=_env_int("ATLASBOT_SMART_LLM_CALLS_MAX", "17"), + genius_llm_calls_max=_env_int("ATLASBOT_GENIUS_LLM_CALLS_MAX", "32"), + llm_limit_multiplier=_env_float("ATLASBOT_LLM_LIMIT_MULTIPLIER", "1.5"), ) diff --git a/atlasbot/engine/answerer.py b/atlasbot/engine/answerer.py index ad42da3..5906706 100644 --- a/atlasbot/engine/answerer.py +++ b/atlasbot/engine/answerer.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import math import re import time from dataclasses import dataclass @@ -15,6 +16,10 @@ from atlasbot.snapshot.builder import SnapshotProvider, build_summary, summary_t log = logging.getLogger(__name__) +class LLMLimitReached(RuntimeError): + pass + + @dataclass class AnswerScores: confidence: int @@ -53,6 +58,21 @@ class ConversationState: snapshot: dict[str, Any] | None = None +@dataclass +class ModePlan: + model: str + fast_model: str + max_subquestions: int + chunk_lines: int + chunk_top: int + chunk_group: int + use_tool: bool + use_critic: bool + use_gap: bool + use_scores: bool + drafts: int + + class AnswerEngine: def __init__( self, @@ -82,96 +102,161 @@ class AnswerEngine: if mode == "stock": return await self._answer_stock(question) + limitless = "run limitless" in question.lower() + if limitless: + question = re.sub(r"(?i)run limitless", "", question).strip() + plan = _mode_plan(self._settings, mode) + call_limit = _llm_call_limit(self._settings, mode) + call_cap = math.ceil(call_limit * self._settings.llm_limit_multiplier) + call_count = 0 + limit_hit = False + + async def call_llm(system: str, prompt: str, *, context: str | None = None, model: str | None = None, tag: str = "") -> str: + nonlocal call_count, limit_hit + if not limitless and call_count >= call_cap: + limit_hit = True + raise LLMLimitReached("llm_limit") + call_count += 1 + messages = build_messages(system, prompt, context=context) + response = await self._llm.chat(messages, model=model or plan.model) + log.info( + "atlasbot_llm_call", + extra={"extra": {"mode": mode, "tag": tag, "call": call_count, "limit": call_cap}}, + ) + return response + state = self._get_state(conversation_id) snapshot = self._snapshot.get() snapshot_used = snapshot if self._settings.snapshot_pin_enabled and state and state.snapshot: snapshot_used = state.snapshot summary = build_summary(snapshot_used) - kb_summary = self._kb.summary() - runbooks = self._kb.runbook_titles(limit=4) summary_lines = _summary_lines(snapshot_used) - core_context = _build_context(summary_lines, question, {"answer_style": "direct"}, max_lines=40) - base_context = _join_context([ - kb_summary, - runbooks, - f"ClusterSnapshot:{core_context}" if core_context else "", - ]) + kb_summary = self._kb.summary() + runbooks = self._kb.runbook_titles(limit=6) + history_ctx = _format_history(history) started = time.monotonic() - if observer: 
- observer("classify", "classifying intent") - classify = await self._classify(question, base_context) - history_ctx = _format_history(history) - context_lines = _build_context(summary_lines, question, classify, max_lines=120) - base_context = _join_context([ - kb_summary, - runbooks, - f"ClusterSnapshot:{context_lines}" if context_lines else "", - ]) - if history_ctx and classify.get("follow_up"): - history_ctx = "ConversationHistory (non-authoritative, use only for phrasing):\n" + history_ctx - base_context = _join_context([base_context, history_ctx]) - log.info( - "atlasbot_context", - extra={ - "extra": { - "mode": mode, - "lines": len(context_lines.splitlines()) if context_lines else 0, - "chars": len(context_lines) if context_lines else 0, - } - }, - ) - log.info( - "atlasbot_classify", - extra={"extra": {"mode": mode, "elapsed_sec": round(time.monotonic() - started, 2), "classify": classify}}, - ) - if observer: - observer("angles", "drafting angles") - angles = await self._angles(question, classify, mode) - log.info( - "atlasbot_angles", - extra={"extra": {"mode": mode, "count": len(angles)}}, - ) - if observer: - observer("candidates", "drafting answers") - candidates = await self._candidates(question, angles, base_context, classify, mode) - log.info( - "atlasbot_candidates", - extra={"extra": {"mode": mode, "count": len(candidates)}}, - ) - if observer: - observer("select", "scoring candidates") - best, scores = await self._select_best(question, candidates) - log.info( - "atlasbot_selection", - extra={"extra": {"mode": mode, "selected": len(best), "scores": scores.__dict__}}, - ) - if classify.get("follow_up") and state and state.claims: + reply = "" + scores = _default_scores() + claims: list[ClaimItem] = [] + classify: dict[str, Any] = {} + tool_hint: dict[str, Any] | None = None + try: if observer: - observer("followup", "answering follow-up") - reply = await self._answer_followup(question, state, summary, classify, mode) - meta = { - "mode": mode, - "follow_up": True, - "classify": classify, - } - return AnswerResult(reply, scores, meta) + observer("normalize", "normalizing") + normalize_prompt = prompts.NORMALIZE_PROMPT + "\nQuestion: " + question + normalize_raw = await call_llm(prompts.NORMALIZE_SYSTEM, normalize_prompt, model=plan.fast_model, tag="normalize") + normalize = _parse_json_block(normalize_raw, fallback={"normalized": question, "keywords": []}) + normalized = str(normalize.get("normalized") or question).strip() or question + keywords = normalize.get("keywords") or [] + + if observer: + observer("route", "routing") + route_prompt = prompts.ROUTE_PROMPT + "\nQuestion: " + normalized + "\nKeywords: " + json.dumps(keywords) + route_raw = await call_llm(prompts.ROUTE_SYSTEM, route_prompt, context=kb_summary, model=plan.fast_model, tag="route") + classify = _parse_json_block(route_raw, fallback={}) + classify.setdefault("needs_snapshot", True) + classify.setdefault("answer_style", "direct") + classify.setdefault("follow_up", False) + + if classify.get("follow_up") and state and state.claims: + if observer: + observer("followup", "answering follow-up") + reply = await self._answer_followup(question, state, summary, classify, plan, call_llm) + scores = await self._score_answer(question, reply, plan, call_llm) + meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started) + return AnswerResult(reply, scores, meta) + + if observer: + observer("decompose", "decomposing") + decompose_prompt = 
prompts.DECOMPOSE_PROMPT.format(max_parts=plan.max_subquestions * 2) + decompose_raw = await call_llm( + prompts.DECOMPOSE_SYSTEM, + decompose_prompt + "\nQuestion: " + normalized, + model=plan.fast_model if mode == "quick" else plan.model, + tag="decompose", + ) + parts = _parse_json_list(decompose_raw) + sub_questions = _select_subquestions(parts, normalized, plan.max_subquestions) + + snapshot_context = "" + if classify.get("needs_snapshot"): + if observer: + observer("retrieve", "scoring chunks") + chunks = _chunk_lines(summary_lines, plan.chunk_lines) + scored = await _score_chunks(call_llm, chunks, normalized, sub_questions, plan) + selected = _select_chunks(chunks, scored, plan) + snapshot_context = "ClusterSnapshot:\n" + "\n".join([chunk["text"] for chunk in selected]) + + context = _join_context( + [kb_summary, _format_runbooks(runbooks), snapshot_context, history_ctx if classify.get("follow_up") else ""] + ) + + if plan.use_tool and classify.get("needs_tool"): + if observer: + observer("tool", "suggesting tools") + tool_prompt = prompts.TOOL_PROMPT + "\nQuestion: " + normalized + tool_raw = await call_llm(prompts.TOOL_SYSTEM, tool_prompt, context=context, model=plan.fast_model, tag="tool") + tool_hint = _parse_json_block(tool_raw, fallback={}) + + if observer: + observer("subanswers", "drafting subanswers") + subanswers: list[str] = [] + for subq in sub_questions: + sub_prompt = prompts.SUBANSWER_PROMPT + "\nQuestion: " + subq + sub_answer = await call_llm(prompts.ANSWER_SYSTEM, sub_prompt, context=context, model=plan.model, tag="subanswer") + subanswers.append(sub_answer) + + if observer: + observer("synthesize", "synthesizing") + reply = await self._synthesize_answer(normalized, subanswers, context, classify, plan, call_llm) + + if plan.use_critic: + if observer: + observer("critic", "reviewing") + critic_prompt = prompts.CRITIC_PROMPT + "\nQuestion: " + normalized + "\nAnswer: " + reply + critic_raw = await call_llm(prompts.CRITIC_SYSTEM, critic_prompt, context=context, model=plan.model, tag="critic") + critic = _parse_json_block(critic_raw, fallback={}) + if critic.get("issues"): + revise_prompt = ( + prompts.REVISION_PROMPT + + "\nQuestion: " + + normalized + + "\nDraft: " + + reply + + "\nCritique: " + + json.dumps(critic) + ) + reply = await call_llm(prompts.REVISION_SYSTEM, revise_prompt, context=context, model=plan.model, tag="revise") + + if plan.use_gap: + if observer: + observer("gap", "checking gaps") + gap_prompt = prompts.EVIDENCE_GAP_PROMPT + "\nQuestion: " + normalized + "\nAnswer: " + reply + gap_raw = await call_llm(prompts.GAP_SYSTEM, gap_prompt, context=context, model=plan.fast_model, tag="gap") + gap = _parse_json_block(gap_raw, fallback={}) + note = str(gap.get("note") or "").strip() + if note: + reply = f"{reply}\n\n{note}" + + scores = await self._score_answer(normalized, reply, plan, call_llm) + claims = await self._extract_claims(normalized, reply, summary, call_llm) + except LLMLimitReached: + if not reply: + reply = "I started working on this but hit my reasoning limit. Ask again with 'Run limitless' for a deeper pass." 
+ scores = _default_scores() + finally: + elapsed = round(time.monotonic() - started, 2) + log.info( + "atlasbot_answer", + extra={"extra": {"mode": mode, "seconds": elapsed, "llm_calls": call_count, "limit": call_cap, "limit_hit": limit_hit}}, + ) - if observer: - observer("synthesize", "synthesizing reply") - reply = await self._synthesize(question, best, base_context, classify, mode) - claims = await self._extract_claims(question, reply, summary) if conversation_id and claims: self._store_state(conversation_id, claims, summary, snapshot_used) - meta = { - "mode": mode, - "angles": angles, - "scores": scores.__dict__, - "classify": classify, - "candidates": len(candidates), - "claims": len(claims), - } + + meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started) return AnswerResult(reply, scores, meta) async def _answer_stock(self, question: str) -> AnswerResult: @@ -179,122 +264,73 @@ class AnswerEngine: reply = await self._llm.chat(messages, model=self._settings.ollama_model) return AnswerResult(reply, _default_scores(), {"mode": "stock"}) - async def _classify(self, question: str, context: str) -> dict[str, Any]: - prompt = prompts.CLASSIFY_PROMPT + "\nQuestion: " + question - messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=context) - raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast) - data = _parse_json_block(raw, fallback={"needs_snapshot": True}) - if "answer_style" not in data: - data["answer_style"] = "direct" - if "follow_up_kind" not in data: - data["follow_up_kind"] = "other" - return data - - async def _angles(self, question: str, classify: dict[str, Any], mode: str) -> list[dict[str, Any]]: - max_angles = _angles_limit(self._settings, mode) - prompt = prompts.ANGLE_PROMPT.format(max_angles=max_angles) + "\nQuestion: " + question - messages = build_messages(prompts.CLUSTER_SYSTEM, prompt) - raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast) - angles = _parse_json_list(raw) - if not angles: - return [{"name": "primary", "question": question, "relevance": 100}] - if classify.get("answer_style") == "insightful": - if not any("implication" in (a.get("name") or "").lower() for a in angles): - angles.append({"name": "implications", "question": f"What are the implications of the data for: {question}", "relevance": 85}) - return angles[:max_angles] - - async def _candidates( + async def _synthesize_answer( self, question: str, - angles: list[dict[str, Any]], + subanswers: list[str], context: str, classify: dict[str, Any], - mode: str, - ) -> list[dict[str, Any]]: - limit = _candidates_limit(self._settings, mode) - selected = angles[:limit] - tasks = [] - model = _candidate_model(self._settings, mode) - for angle in selected: - angle_q = angle.get("question") or question - prompt = prompts.CANDIDATE_PROMPT + "\nQuestion: " + angle_q - if classify.get("answer_style"): - prompt += f"\nAnswerStyle: {classify.get('answer_style')}" - messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=context) - tasks.append(self._llm.chat(messages, model=model)) - replies = await asyncio.gather(*tasks) - candidates = [] - for angle, reply in zip(selected, replies, strict=False): - candidates.append({"angle": angle, "reply": reply}) - return candidates - - async def _select_best(self, question: str, candidates: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], AnswerScores]: - if not candidates: - return ([], _default_scores()) - scored: list[tuple[dict[str, Any], AnswerScores]] = [] - for 
entry in candidates: - prompt = prompts.SCORE_PROMPT + "\nQuestion: " + question + "\nAnswer: " + entry["reply"] - messages = build_messages(prompts.CLUSTER_SYSTEM, prompt) - raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast) - data = _parse_json_block(raw, fallback={}) - scores = _scores_from_json(data) - scored.append((entry, scores)) - scored.sort(key=lambda item: (item[1].relevance, item[1].confidence), reverse=True) - best = [entry for entry, _scores in scored[:3]] - return best, scored[0][1] - - async def _synthesize( - self, - question: str, - best: list[dict[str, Any]], - context: str, - classify: dict[str, Any], - mode: str, + plan: ModePlan, + call_llm: Callable[..., Any], ) -> str: - if not best: - return "I do not have enough information to answer that yet." - parts = [] - for item in best: - parts.append(f"- {item['reply']}") - style = classify.get("answer_style") if isinstance(classify, dict) else None - intent = classify.get("intent") if isinstance(classify, dict) else None - ambiguity = classify.get("ambiguity") if isinstance(classify, dict) else None - style_line = f"AnswerStyle: {style}" if style else "AnswerStyle: default" - if intent: - style_line = f"{style_line}; Intent: {intent}" - if ambiguity is not None: - style_line = f"{style_line}; Ambiguity: {ambiguity}" - prompt = ( - prompts.SYNTHESIZE_PROMPT - + "\n" - + style_line + if not subanswers: + prompt = prompts.SYNTHESIZE_PROMPT + "\nQuestion: " + question + return await call_llm(prompts.SYNTHESIZE_SYSTEM, prompt, context=context, model=plan.model, tag="synth") + draft_prompts = [] + for idx in range(plan.drafts): + draft_prompts.append( + prompts.SYNTHESIZE_PROMPT + + "\nQuestion: " + + question + + "\nSubanswers:\n" + + "\n".join([f"- {item}" for item in subanswers]) + + f"\nDraftIndex: {idx + 1}" + ) + drafts: list[str] = [] + for prompt in draft_prompts: + drafts.append(await call_llm(prompts.SYNTHESIZE_SYSTEM, prompt, context=context, model=plan.model, tag="synth")) + if len(drafts) == 1: + return drafts[0] + select_prompt = ( + prompts.DRAFT_SELECT_PROMPT + "\nQuestion: " + question - + "\nCandidate answers:\n" - + "\n".join(parts) + + "\nDrafts:\n" + + "\n\n".join([f"Draft {idx + 1}: {text}" for idx, text in enumerate(drafts)]) ) - messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=context) - model = _synthesis_model(self._settings, mode) - reply = await self._llm.chat(messages, model=model) - needs_refine = _needs_refine(reply, classify) - if not needs_refine: - return reply - refine_prompt = prompts.REFINE_PROMPT + "\nQuestion: " + question + "\nDraft: " + reply - refine_messages = build_messages(prompts.CLUSTER_SYSTEM, refine_prompt, context=context) - return await self._llm.chat(refine_messages, model=model) + select_raw = await call_llm(prompts.CRITIC_SYSTEM, select_prompt, context=context, model=plan.fast_model, tag="draft_select") + selection = _parse_json_block(select_raw, fallback={}) + idx = int(selection.get("best", 1)) - 1 + if 0 <= idx < len(drafts): + return drafts[idx] + return drafts[0] + + async def _score_answer( + self, + question: str, + reply: str, + plan: ModePlan, + call_llm: Callable[..., Any], + ) -> AnswerScores: + if not plan.use_scores: + return _default_scores() + prompt = prompts.SCORE_PROMPT + "\nQuestion: " + question + "\nAnswer: " + reply + raw = await call_llm(prompts.SCORE_SYSTEM, prompt, model=plan.fast_model, tag="score") + data = _parse_json_block(raw, fallback={}) + return _scores_from_json(data) async def _extract_claims( 
self, question: str, reply: str, summary: dict[str, Any], + call_llm: Callable[..., Any], ) -> list[ClaimItem]: if not reply or not summary: return [] summary_json = _json_excerpt(summary) prompt = prompts.CLAIM_MAP_PROMPT + "\nQuestion: " + question + "\nAnswer: " + reply - messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=f"SnapshotSummaryJson:{summary_json}") - raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast) + raw = await call_llm(prompts.CLAIM_SYSTEM, prompt, context=f"SnapshotSummaryJson:{summary_json}", model=self._settings.ollama_model_fast, tag="claim_map") data = _parse_json_block(raw, fallback={}) claims_raw = data.get("claims") if isinstance(data, dict) else None claims: list[ClaimItem] = [] @@ -324,10 +360,10 @@ class AnswerEngine: state: ConversationState, summary: dict[str, Any], classify: dict[str, Any], - mode: str, + plan: ModePlan, + call_llm: Callable[..., Any], ) -> str: - follow_kind = classify.get("follow_up_kind") if isinstance(classify, dict) else "other" - claim_ids = await self._select_claims(question, state.claims) + claim_ids = await self._select_claims(question, state.claims, plan, call_llm) selected = [claim for claim in state.claims if claim.id in claim_ids] if claim_ids else state.claims[:2] evidence_lines = [] for claim in selected: @@ -338,23 +374,23 @@ class AnswerEngine: delta_note = "" if ev.value_at_claim is not None and current is not None and current != ev.value_at_claim: delta_note = f" (now {current})" - evidence_lines.append(f"- {ev.path}: {ev.value_at_claim}{delta_note} {('- ' + ev.reason) if ev.reason else ''}") + evidence_lines.append(f"- {ev.path}: {ev.value_at_claim}{delta_note}") evidence_ctx = "\n".join(evidence_lines) - prompt = prompts.FOLLOWUP_EVIDENCE_PROMPT - if follow_kind in {"next_steps", "change"}: - prompt = prompts.FOLLOWUP_ACTION_PROMPT - prompt = prompt + "\nFollow-up: " + question + "\nEvidence:\n" + evidence_ctx - messages = build_messages(prompts.CLUSTER_SYSTEM, prompt) - model = _synthesis_model(self._settings, mode) - return await self._llm.chat(messages, model=model) + prompt = prompts.FOLLOWUP_PROMPT + "\nFollow-up: " + question + "\nEvidence:\n" + evidence_ctx + return await call_llm(prompts.FOLLOWUP_SYSTEM, prompt, model=plan.model, tag="followup") - async def _select_claims(self, question: str, claims: list[ClaimItem]) -> list[str]: + async def _select_claims( + self, + question: str, + claims: list[ClaimItem], + plan: ModePlan, + call_llm: Callable[..., Any], + ) -> list[str]: if not claims: return [] claims_brief = [{"id": claim.id, "claim": claim.claim} for claim in claims] prompt = prompts.SELECT_CLAIMS_PROMPT + "\nFollow-up: " + question + "\nClaims: " + json.dumps(claims_brief) - messages = build_messages(prompts.CLUSTER_SYSTEM, prompt) - raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast) + raw = await call_llm(prompts.FOLLOWUP_SYSTEM, prompt, model=plan.fast_model, tag="select_claims") data = _parse_json_block(raw, fallback={}) ids = data.get("claim_ids") if isinstance(data, dict) else [] if isinstance(ids, list): @@ -392,15 +428,191 @@ class AnswerEngine: self._state.pop(key, None) +def _build_meta( + mode: str, + call_count: int, + call_cap: int, + limit_hit: bool, + classify: dict[str, Any], + tool_hint: dict[str, Any] | None, + started: float, +) -> dict[str, Any]: + return { + "mode": mode, + "llm_calls": call_count, + "llm_limit": call_cap, + "llm_limit_hit": limit_hit, + "classify": classify, + "tool_hint": tool_hint, + "elapsed_sec": 
round(time.monotonic() - started, 2), + } + + +def _mode_plan(settings: Settings, mode: str) -> ModePlan: + if mode == "genius": + return ModePlan( + model=settings.ollama_model_genius, + fast_model=settings.ollama_model_fast, + max_subquestions=6, + chunk_lines=6, + chunk_top=10, + chunk_group=4, + use_tool=True, + use_critic=True, + use_gap=True, + use_scores=True, + drafts=2, + ) + if mode == "smart": + return ModePlan( + model=settings.ollama_model_smart, + fast_model=settings.ollama_model_fast, + max_subquestions=4, + chunk_lines=8, + chunk_top=8, + chunk_group=4, + use_tool=True, + use_critic=True, + use_gap=True, + use_scores=True, + drafts=1, + ) + return ModePlan( + model=settings.ollama_model_fast, + fast_model=settings.ollama_model_fast, + max_subquestions=2, + chunk_lines=12, + chunk_top=5, + chunk_group=5, + use_tool=False, + use_critic=False, + use_gap=False, + use_scores=False, + drafts=1, + ) + + +def _llm_call_limit(settings: Settings, mode: str) -> int: + if mode == "genius": + return settings.genius_llm_calls_max + if mode == "smart": + return settings.smart_llm_calls_max + return settings.fast_llm_calls_max + + +def _select_subquestions(parts: list[dict[str, Any]], fallback: str, limit: int) -> list[str]: + if not parts: + return [fallback] + ranked = [] + for entry in parts: + if not isinstance(entry, dict): + continue + question = str(entry.get("question") or "").strip() + if not question: + continue + priority = entry.get("priority") + try: + weight = float(priority) + except (TypeError, ValueError): + weight = 1.0 + ranked.append((weight, question)) + ranked.sort(key=lambda item: item[0]) + questions = [item[1] for item in ranked][:limit] + return questions or [fallback] + + +def _chunk_lines(lines: list[str], lines_per_chunk: int) -> list[dict[str, Any]]: + chunks: list[dict[str, Any]] = [] + if not lines: + return chunks + for idx in range(0, len(lines), lines_per_chunk): + chunk_lines = lines[idx : idx + lines_per_chunk] + text = "\n".join(chunk_lines) + summary = " | ".join(chunk_lines[:2]) + chunks.append({"id": f"c{idx//lines_per_chunk}", "text": text, "summary": summary}) + return chunks + + +async def _score_chunks( + call_llm: Callable[..., Any], + chunks: list[dict[str, Any]], + question: str, + sub_questions: list[str], + plan: ModePlan, +) -> dict[str, float]: + scores: dict[str, float] = {chunk["id"]: 0.0 for chunk in chunks} + if not chunks: + return scores + group: list[dict[str, Any]] = [] + for chunk in chunks: + group.append({"id": chunk["id"], "summary": chunk["summary"]}) + if len(group) >= plan.chunk_group: + scores.update(await _score_chunk_group(call_llm, group, question, sub_questions)) + group = [] + if group: + scores.update(await _score_chunk_group(call_llm, group, question, sub_questions)) + return scores + + +async def _score_chunk_group( + call_llm: Callable[..., Any], + group: list[dict[str, Any]], + question: str, + sub_questions: list[str], +) -> dict[str, float]: + prompt = ( + prompts.CHUNK_SCORE_PROMPT + + "\nQuestion: " + + question + + "\nSubQuestions: " + + json.dumps(sub_questions) + + "\nChunks: " + + json.dumps(group) + ) + raw = await call_llm(prompts.RETRIEVER_SYSTEM, prompt, model=None, tag="chunk_score") + data = _parse_json_list(raw) + scored: dict[str, float] = {} + for entry in data: + if not isinstance(entry, dict): + continue + cid = str(entry.get("id") or "").strip() + if not cid: + continue + try: + score = float(entry.get("score") or 0) + except (TypeError, ValueError): + score = 0.0 + scored[cid] = score + 
return scored + + +def _select_chunks( + chunks: list[dict[str, Any]], + scores: dict[str, float], + plan: ModePlan, +) -> list[dict[str, Any]]: + if not chunks: + return [] + ranked = sorted(chunks, key=lambda item: scores.get(item["id"], 0.0), reverse=True) + selected = ranked[: plan.chunk_top] + return selected + + +def _format_runbooks(runbooks: list[str]) -> str: + if not runbooks: + return "" + return "Relevant runbooks:\n" + "\n".join([f"- {item}" for item in runbooks]) + + def _join_context(parts: list[str]) -> str: - text = "\n".join([p for p in parts if p]) + text = "\n".join([part for part in parts if part]) return text.strip() def _format_history(history: list[dict[str, str]] | None) -> str: if not history: return "" - lines = ["Recent conversation:"] + lines = ["Recent conversation (non-authoritative):"] for entry in history[-4:]: if not isinstance(entry, dict): continue @@ -418,32 +630,11 @@ def _format_history(history: list[dict[str, str]] | None) -> str: return "\n".join(lines) -def _angles_limit(settings: Settings, mode: str) -> int: - if mode == "genius": - return settings.genius_max_angles - if mode == "quick": - return settings.fast_max_angles - return settings.smart_max_angles - - -def _candidates_limit(settings: Settings, mode: str) -> int: - if mode == "genius": - return settings.genius_max_candidates - if mode == "quick": - return settings.fast_max_candidates - return settings.smart_max_candidates - - -def _candidate_model(settings: Settings, mode: str) -> str: - if mode == "genius": - return settings.ollama_model_genius - return settings.ollama_model_smart - - -def _synthesis_model(settings: Settings, mode: str) -> str: - if mode == "genius": - return settings.ollama_model_genius - return settings.ollama_model_smart +def _summary_lines(snapshot: dict[str, Any] | None) -> list[str]: + text = summary_text(snapshot) + if not text: + return [] + return [line for line in text.splitlines() if line.strip()] def _parse_json_block(text: str, *, fallback: dict[str, Any]) -> dict[str, Any]: @@ -454,19 +645,6 @@ def _parse_json_block(text: str, *, fallback: dict[str, Any]) -> dict[str, Any]: return parse_json(raw, fallback=fallback) -def _needs_refine(reply: str, classify: dict[str, Any]) -> bool: - if not reply: - return False - style = classify.get("answer_style") if isinstance(classify, dict) else None - if style != "insightful": - return False - metric_markers = ["cpu", "ram", "pods", "connections", "%"] - lower = reply.lower() - metric_hits = sum(1 for m in metric_markers if m in lower) - sentence_count = reply.count(".") + reply.count("!") + reply.count("?") - return metric_hits >= 2 and sentence_count <= 2 - - def _parse_json_list(text: str) -> list[dict[str, Any]]: raw = text.strip() match = re.search(r"\[.*\]", raw, flags=re.S) @@ -498,10 +676,10 @@ def _default_scores() -> AnswerScores: def _resolve_path(data: Any, path: str) -> Any | None: cursor = data - for part in re.split(r"\\.(?![^\\[]*\\])", path): + for part in re.split(r"\.(?![^\[]*\])", path): if not part: continue - match = re.match(r"^(\\w+)(?:\\[(\\d+)\\])?$", part) + match = re.match(r"^(\w+)(?:\[(\d+)\])?$", part) if not match: return None key = match.group(1) @@ -532,100 +710,6 @@ def _snapshot_id(summary: dict[str, Any]) -> str | None: return None -def _summary_lines(snapshot: dict[str, Any] | None) -> list[str]: - text = summary_text(snapshot) - if not text: - return [] - return [line for line in text.splitlines() if line.strip()] - - -def _build_context( - summary_lines: list[str], - question: str, 
- classify: dict[str, Any], - *, - max_lines: int, -) -> str: - if not summary_lines: - return "" - lower = (question or "").lower() - prefixes: set[str] = set() - core_prefixes = { - "nodes", - "archs", - "roles", - "hardware", - "node_arch", - "node_os", - "pods", - "namespaces_top", - "namespace_pods", - "namespace_nodes", - "hottest", - "postgres", - "signals", - "profiles", - "watchlist", - "snapshot", - } - prefixes.update(core_prefixes) - - def _want(words: tuple[str, ...]) -> bool: - return any(word in lower for word in words) - - if _want(("cpu", "load", "ram", "memory", "io", "disk", "net", "network")): - prefixes.update( - { - "node_usage", - "node_load", - "node_load_summary", - "node_usage_top", - "root_disk", - "pvc_usage", - "namespace_cpu_top", - "namespace_mem_top", - "namespace_net_top", - "namespace_io_top", - } - ) - if _want(("namespace", "quota", "overcommit", "capacity", "active", "activity")): - prefixes.update( - { - "namespace_capacity", - "namespace_capacity_summary", - "namespace_cpu_top", - "namespace_mem_top", - "namespace_net_top", - "namespace_io_top", - "namespace_nodes_top", - "namespaces_top", - } - ) - if _want(("pod", "pending", "crash", "image", "pull", "fail")): - prefixes.update({"pod_issues", "pod_usage", "pod_events", "events", "event_summary", "pods_pending_oldest", "pods_pending_over_15m"}) - if _want(("restart", "restarts")): - prefixes.update({"restarts_1h_top", "restarts_1h_namespace_top", "pod_restarts"}) - if _want(("alert", "alerting", "incident", "error")): - prefixes.update({"signals", "events", "event_summary", "pod_issues", "watchlist"}) - if _want(("flux", "reconcile", "gitops")): - prefixes.update({"flux"}) - if _want(("longhorn", "volume", "pvc", "storage")): - prefixes.update({"longhorn", "pvc_usage", "root_disk"}) - if _want(("workload", "deployment", "stateful", "daemon", "schedule", "heavy")): - prefixes.update({"workloads", "workloads_by_namespace", "workload_health"}) - if classify.get("answer_style") == "insightful" or classify.get("question_type") == "open_ended": - prefixes.update({"signals", "profiles", "watchlist", "hottest"}) - - selected: list[str] = [] - for line in summary_lines: - prefix = line.split(":", 1)[0].strip().lower() - if prefix in prefixes or any(prefix.startswith(pfx) for pfx in prefixes): - selected.append(line) - if len(selected) >= max_lines: - break - return "\n".join(selected) - - def _json_excerpt(summary: dict[str, Any], max_chars: int = 12000) -> str: raw = json.dumps(summary, ensure_ascii=False) return raw[:max_chars] diff --git a/atlasbot/llm/prompts.py b/atlasbot/llm/prompts.py index 7007687..b79eeb0 100644 --- a/atlasbot/llm/prompts.py +++ b/atlasbot/llm/prompts.py @@ -1,97 +1,147 @@ CLUSTER_SYSTEM = ( "You are Atlas, the Titan Lab assistant for the Atlas cluster. " - "Use the provided context as your source of truth. " - "Context is authoritative; do not ignore it. " - "If Context is present, you must base numbers and facts on it. " - "If a fact or number is not present in the context, say you do not know. " - "Do not invent metrics or capacities. " - "If history conflicts with the snapshot, trust the snapshot. " - "If the question is about Atlas, respond in short paragraphs. " - "Avoid commands unless explicitly asked. " - "If information is missing, say so clearly and avoid guessing. " - "If the question is open-ended, provide grounded interpretation or implications, " - "not just a list of metrics. 
" - "Do not mention the context, snapshot, or knowledge base unless the user asks about sources. " + "Use provided context as authoritative. " + "If a fact is not in context, say you do not know. " + "Be conversational and grounded. " + "Avoid commands unless the user asks for them. " + "Do not mention the context or knowledge base unless asked." ) -CLASSIFY_PROMPT = ( - "Classify the user question. Return JSON with fields: " - "needs_snapshot (bool), needs_kb (bool), needs_metrics (bool), " - "needs_general (bool), intent (short string), ambiguity (0-1), " - "answer_style (direct|insightful), topic_summary (short string), " - "follow_up (bool), follow_up_kind (evidence|why|clarify|next_steps|change|other), " - "question_type (metric|diagnostic|planning|open_ended)." +NORMALIZE_SYSTEM = ( + "Normalize user questions for reasoning. " + "Return JSON only." ) -ANGLE_PROMPT = ( - "Generate up to {max_angles} possible angles to answer the question. " - "If the question is open-ended, include at least one angle that focuses on implications. " - "Return JSON list of objects with: name, question, relevance (0-100)." +NORMALIZE_PROMPT = ( + "Return JSON with fields: normalized (string), keywords (list), entities (list), " + "intent (short string), wants_metrics (bool), wants_opinion (bool)." ) -CANDIDATE_PROMPT = ( - "Answer this angle using the provided context. " - "Context facts override any prior or remembered statements. " - "Keep it concise, 2-4 sentences. " - "If the question is open-ended, include one grounded interpretation or implication. " - "Avoid dumping raw metrics unless asked; prefer what the numbers imply. " - "Do not mention the context or snapshot unless explicitly asked." +ROUTE_SYSTEM = ( + "Route the question to the best sources and answer style. " + "Return JSON only." ) -SCORE_PROMPT = ( - "Score the candidate response. Return JSON with fields: " - "confidence (0-100), relevance (0-100), satisfaction (0-100), " - "hallucination_risk (low|medium|high)." +ROUTE_PROMPT = ( + "Return JSON with fields: needs_snapshot (bool), needs_kb (bool), needs_tool (bool), " + "answer_style (direct|insightful), follow_up (bool), question_type (metric|diagnostic|planning|open_ended)." +) + +DECOMPOSE_SYSTEM = ( + "Break complex questions into smaller, answerable sub-questions. " + "Return JSON only." +) + +DECOMPOSE_PROMPT = ( + "Generate up to {max_parts} sub-questions. " + "Return JSON list of objects with: id, question, priority (1-5), kind (metric|analysis|context)." +) + +RETRIEVER_SYSTEM = ( + "Score relevance of chunk summaries to the question and sub-questions. " + "Return JSON list only." +) + +CHUNK_SCORE_PROMPT = ( + "Given chunk summaries, score relevance 0-100. " + "Return JSON list of objects with: id, score, reason (<=12 words)." +) + +TOOL_SYSTEM = ( + "Suggest a safe, read-only command that could refine the answer. " + "Return JSON only." +) + +TOOL_PROMPT = ( + "Return JSON with fields: command (string), rationale (string). " + "If no tool is useful, return empty strings." +) + +ANSWER_SYSTEM = ( + "Answer a focused sub-question using the provided context. " + "Be concise and grounded." +) + +SUBANSWER_PROMPT = ( + "Answer the sub-question using the context. " + "If context lacks the fact, say so." +) + +SYNTHESIZE_SYSTEM = ( + "Synthesize a final answer from sub-answers. " + "Keep it conversational and grounded." ) SYNTHESIZE_PROMPT = ( - "Synthesize a final response from the best candidates. " - "Use a natural, helpful tone with light reasoning. 
" - "Avoid lists unless the user asked for lists. " - "If AnswerStyle is insightful, add one grounded insight or mild hypothesis, " - "but mark uncertainty briefly. " - "If AnswerStyle is direct, keep it short and factual. " - "Do not include confidence scores or evaluation metadata." + "Write a final response to the user. " + "Use sub-answers as evidence, avoid raw metric dumps unless asked." ) -REFINE_PROMPT = ( - "Improve the answer if it reads like a raw metric dump or ignores the question's intent. " - "Keep it grounded in the context. If you cannot add insight, say so explicitly." +DRAFT_SELECT_PROMPT = ( + "Pick the best draft for accuracy, clarity, and helpfulness. " + "Return JSON with field: best (1-based index)." +) + +CRITIC_SYSTEM = ( + "Critique answers for unsupported claims or missing context. " + "Return JSON only." +) + +CRITIC_PROMPT = ( + "Return JSON with fields: issues (list), missing_data (list), risky_claims (list)." +) + +REVISION_SYSTEM = ( + "Revise the answer based on critique. " + "Keep the response grounded and concise." +) + +REVISION_PROMPT = ( + "Rewrite the answer using the critique. " + "Do not introduce new facts." +) + +GAP_SYSTEM = ( + "Identify missing data that would improve the answer. " + "Return JSON only." +) + +EVIDENCE_GAP_PROMPT = ( + "Return JSON with field: note (string). " + "If nothing is missing, return empty note." +) + +CLAIM_SYSTEM = ( + "Extract claim-evidence mappings from the answer. " + "Return JSON only." ) CLAIM_MAP_PROMPT = ( - "Extract a claim map from the answer. " - "Return JSON with fields: claims (list). " - "Each claim object: id (short string), claim (short sentence), " - "evidence (list of objects with path and reason). " - "Paths must point into the provided SnapshotSummary JSON using dot notation, " - "with list indexes in brackets, e.g. metrics.node_load[0].node. " - "Do not invent evidence; if no evidence exists, omit the claim." + "Return JSON with claims list; each claim: id, claim, evidence (list of {path, reason})." +) + +FOLLOWUP_SYSTEM = ( + "Answer follow-ups using prior claim evidence only. " + "Return JSON only when asked to select claims." +) + +FOLLOWUP_PROMPT = ( + "Answer the follow-up using provided evidence. Be conversational and concise." ) SELECT_CLAIMS_PROMPT = ( - "Pick which prior claim(s) the follow-up refers to. " - "Return JSON with fields: claim_ids (list of ids), follow_up_kind " - "(evidence|why|clarify|next_steps|change|other). " - "If none apply, return an empty list." + "Select relevant claim ids for the follow-up. " + "Return JSON with field: claim_ids (list)." ) -FOLLOWUP_EVIDENCE_PROMPT = ( - "Answer the follow-up using only the provided claims and evidence. " - "Be conversational, not bullet-heavy. " - "If evidence does not support a claim, say so plainly. " - "Do not add new claims." +SCORE_SYSTEM = ( + "Score response quality. Return JSON only." ) -FOLLOWUP_ACTION_PROMPT = ( - "Answer the follow-up using the provided claims and evidence. " - "You may suggest next steps or changes, but keep them tightly tied " - "to the evidence list. " - "Be conversational and concise." +SCORE_PROMPT = ( + "Return JSON with fields: confidence (0-100), relevance (0-100), satisfaction (0-100), hallucination_risk (low|medium|high)." ) STOCK_SYSTEM = ( - "You are Atlas, a helpful assistant. " - "Be concise and truthful. " - "If unsure, say so." + "You are Atlas, a helpful assistant. Be concise and truthful." 
) diff --git a/tests/test_engine.py b/tests/test_engine.py index 49b1290..3d5fb43 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -7,13 +7,27 @@ from atlasbot.config import Settings class FakeLLM: - def __init__(self, replies: list[str]) -> None: - self._replies = replies + def __init__(self) -> None: self.calls: list[str] = [] async def chat(self, messages, *, model=None): - self.calls.append(model or "") - return self._replies.pop(0) + prompt = messages[-1]["content"] + self.calls.append(prompt) + if "normalized" in prompt and "keywords" in prompt: + return '{"normalized":"What is Atlas?","keywords":["atlas"]}' + if "needs_snapshot" in prompt: + return '{"needs_snapshot": true, "answer_style": "direct"}' + if "sub-questions" in prompt: + return '[{"id":"q1","question":"What is Atlas?","priority":1}]' + if "sub-question" in prompt: + return "Atlas has 22 nodes." + if "final response" in prompt: + return "Atlas has 22 nodes." + if "Score response quality" in prompt: + return '{"confidence":80,"relevance":90,"satisfaction":85,"hallucination_risk":"low"}' + if "claims list" in prompt: + return '{"claims": []}' + return "{}" def _settings() -> Settings: @@ -43,6 +57,8 @@ def _settings() -> Settings: ariadne_state_token="", snapshot_ttl_sec=30, thinking_interval_sec=30, + conversation_ttl_sec=300, + snapshot_pin_enabled=False, queue_enabled=False, nats_url="", nats_stream="", @@ -54,19 +70,15 @@ def _settings() -> Settings: fast_max_candidates=1, smart_max_candidates=1, genius_max_candidates=1, + fast_llm_calls_max=9, + smart_llm_calls_max=17, + genius_llm_calls_max=32, + llm_limit_multiplier=1.5, ) def test_engine_answer_basic(): - llm = FakeLLM( - [ - '{"needs_snapshot": true}', - '[{"name":"primary","question":"What is Atlas?","relevance":90}]', - "Based on the snapshot, Atlas has 22 nodes.", - '{"confidence":80,"relevance":90,"satisfaction":85,"hallucination_risk":"low"}', - "Atlas has 22 nodes and is healthy.", - ] - ) + llm = FakeLLM() settings = _settings() kb = KnowledgeBase("") snapshot = SnapshotProvider(settings)
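
For reference, a minimal standalone sketch of the call-budget pattern that the new call_llm closure in AnswerEngine.answer implements: each mode gets a soft call limit from Settings (fast_llm_calls_max, smart_llm_calls_max, genius_llm_calls_max), the hard cap is that limit times llm_limit_multiplier, and exceeding the cap raises LLMLimitReached unless the user asked to "run limitless". The BudgetedLLM and fake_chat names below are illustrative stand-ins, not part of the patch; the real engine wires the budget through Settings, build_messages, and self._llm.chat.

import asyncio
import math


class LLMLimitReached(RuntimeError):
    pass


class BudgetedLLM:
    def __init__(self, chat, limit: int, multiplier: float = 1.5, limitless: bool = False) -> None:
        self._chat = chat                     # async callable: (prompt) -> str
        self._cap = math.ceil(limit * multiplier)
        self._limitless = limitless
        self.calls = 0
        self.limit_hit = False

    async def __call__(self, prompt: str) -> str:
        # Refuse further calls once the soft limit times the multiplier is spent,
        # unless the caller asked to run limitless.
        if not self._limitless and self.calls >= self._cap:
            self.limit_hit = True
            raise LLMLimitReached("llm_limit")
        self.calls += 1
        return await self._chat(prompt)


async def _demo() -> None:
    async def fake_chat(prompt: str) -> str:
        return f"answer to: {prompt}"

    # Smart-mode defaults in this patch: 17 calls with 1.5x headroom -> cap of 26.
    llm = BudgetedLLM(fake_chat, limit=17)
    try:
        for step in range(40):
            await llm(f"step {step}")
    except LLMLimitReached:
        # The engine catches this, keeps whatever draft it already has (or a
        # "hit my reasoning limit" fallback), and reports llm_limit_hit in meta.
        pass
    print(llm.calls, llm.limit_hit)  # 26 True


asyncio.run(_demo())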