From 0476edae6a33ee4f4ae6c92e1c753025a872c097 Mon Sep 17 00:00:00 2001 From: Brad Stein Date: Fri, 30 Jan 2026 16:41:04 -0300 Subject: [PATCH] atlasbot: add conversation state --- atlasbot/api/http.py | 6 +- atlasbot/config.py | 4 + atlasbot/engine/answerer.py | 212 +++++++++++++++++++++++++++++++++++- atlasbot/llm/prompts.py | 34 +++++- atlasbot/main.py | 8 +- atlasbot/matrix/bot.py | 11 +- 6 files changed, 262 insertions(+), 13 deletions(-) diff --git a/atlasbot/api/http.py b/atlasbot/api/http.py index 13a3522..c164db1 100644 --- a/atlasbot/api/http.py +++ b/atlasbot/api/http.py @@ -20,6 +20,7 @@ class AnswerRequest(BaseModel): content: str | None = None mode: str | None = None history: list[dict[str, str]] | None = None + conversation_id: str | None = None class AnswerResponse(BaseModel): @@ -30,7 +31,7 @@ class Api: def __init__( self, settings: Settings, - answer_handler: Callable[[str, str, list[dict[str, str]] | None], Awaitable[AnswerResult]], + answer_handler: Callable[[str, str, list[dict[str, str]] | None, str | None], Awaitable[AnswerResult]], ) -> None: self._settings = settings self._answer_handler = answer_handler @@ -53,7 +54,8 @@ class Api: if not question: raise HTTPException(status_code=400, detail="missing question") mode = (payload.mode or "quick").strip().lower() - result = await self._answer_handler(question, mode, payload.history) + conversation_id = payload.conversation_id + result = await self._answer_handler(question, mode, payload.history, conversation_id) log.info( "answer", extra={ diff --git a/atlasbot/config.py b/atlasbot/config.py index 03ffae7..c53bf23 100644 --- a/atlasbot/config.py +++ b/atlasbot/config.py @@ -53,6 +53,8 @@ class Settings: snapshot_ttl_sec: int thinking_interval_sec: int + conversation_ttl_sec: int + snapshot_pin_enabled: bool queue_enabled: bool nats_url: str @@ -141,6 +143,8 @@ def load_settings() -> Settings: ariadne_state_token=os.getenv("ARIADNE_STATE_TOKEN", ""), snapshot_ttl_sec=_env_int("ATLASBOT_SNAPSHOT_TTL_SEC", "30"), thinking_interval_sec=_env_int("ATLASBOT_THINKING_INTERVAL_SEC", "30"), + conversation_ttl_sec=_env_int("ATLASBOT_CONVERSATION_TTL_SEC", "900"), + snapshot_pin_enabled=_env_bool("ATLASBOT_SNAPSHOT_PIN_ENABLED", "false"), queue_enabled=_env_bool("ATLASBOT_QUEUE_ENABLED", "false"), nats_url=os.getenv("ATLASBOT_NATS_URL", "nats://nats.nats.svc.cluster.local:4222"), nats_stream=os.getenv("ATLASBOT_NATS_STREAM", "atlasbot"), diff --git a/atlasbot/engine/answerer.py b/atlasbot/engine/answerer.py index b0febae..e3b1621 100644 --- a/atlasbot/engine/answerer.py +++ b/atlasbot/engine/answerer.py @@ -1,4 +1,5 @@ import asyncio +import json import logging import re import time @@ -9,7 +10,7 @@ from atlasbot.config import Settings from atlasbot.knowledge.loader import KnowledgeBase from atlasbot.llm.client import LLMClient, build_messages, parse_json from atlasbot.llm import prompts -from atlasbot.snapshot.builder import SnapshotProvider, summary_text +from atlasbot.snapshot.builder import SnapshotProvider, build_summary, summary_text log = logging.getLogger(__name__) @@ -29,6 +30,29 @@ class AnswerResult: meta: dict[str, Any] +@dataclass +class EvidenceItem: + path: str + reason: str + value: Any | None = None + value_at_claim: Any | None = None + + +@dataclass +class ClaimItem: + id: str + claim: str + evidence: list[EvidenceItem] + + +@dataclass +class ConversationState: + updated_at: float + claims: list[ClaimItem] + snapshot_id: str | None = None + snapshot: dict[str, Any] | None = None + + class AnswerEngine: def __init__( self, @@ -41,6 +65,7 @@ class AnswerEngine: self._llm = llm self._kb = kb self._snapshot = snapshot + self._state: dict[str, ConversationState] = {} async def answer( self, @@ -49,6 +74,7 @@ class AnswerEngine: mode: str, history: list[dict[str, str]] | None = None, observer: Callable[[str, str], None] | None = None, + conversation_id: str | None = None, ) -> AnswerResult: question = (question or "").strip() if not question: @@ -56,10 +82,15 @@ class AnswerEngine: if mode == "stock": return await self._answer_stock(question) + state = self._get_state(conversation_id) snapshot = self._snapshot.get() + snapshot_used = snapshot + if self._settings.snapshot_pin_enabled and state and state.snapshot: + snapshot_used = state.snapshot + summary = build_summary(snapshot_used) kb_summary = self._kb.summary() runbooks = self._kb.runbook_titles(limit=4) - snapshot_ctx = summary_text(snapshot) + snapshot_ctx = summary_text(snapshot_used) history_ctx = _format_history(history) base_context = _join_context([ kb_summary, @@ -97,15 +128,30 @@ class AnswerEngine: "atlasbot_selection", extra={"extra": {"mode": mode, "selected": len(best), "scores": scores.__dict__}}, ) + if classify.get("follow_up") and state and state.claims: + if observer: + observer("followup", "answering follow-up") + reply = await self._answer_followup(question, state, summary, classify, mode) + meta = { + "mode": mode, + "follow_up": True, + "classify": classify, + } + return AnswerResult(reply, scores, meta) + if observer: observer("synthesize", "synthesizing reply") reply = await self._synthesize(question, best, base_context, classify, mode) + claims = await self._extract_claims(question, reply, summary) + if conversation_id and claims: + self._store_state(conversation_id, claims, summary, snapshot_used) meta = { "mode": mode, "angles": angles, "scores": scores.__dict__, "classify": classify, "candidates": len(candidates), + "claims": len(claims), } return AnswerResult(reply, scores, meta) @@ -121,6 +167,8 @@ class AnswerEngine: data = _parse_json_block(raw, fallback={"needs_snapshot": True}) if "answer_style" not in data: data["answer_style"] = "direct" + if "follow_up_kind" not in data: + data["follow_up_kind"] = "other" return data async def _angles(self, question: str, classify: dict[str, Any], mode: str) -> list[dict[str, Any]]: @@ -216,6 +264,114 @@ class AnswerEngine: refine_messages = build_messages(prompts.CLUSTER_SYSTEM, refine_prompt, context=context) return await self._llm.chat(refine_messages, model=model) + async def _extract_claims( + self, + question: str, + reply: str, + summary: dict[str, Any], + ) -> list[ClaimItem]: + if not reply or not summary: + return [] + summary_json = _json_excerpt(summary) + prompt = prompts.CLAIM_MAP_PROMPT + "\nQuestion: " + question + "\nAnswer: " + reply + messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=f"SnapshotSummaryJson:{summary_json}") + raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast) + data = _parse_json_block(raw, fallback={}) + claims_raw = data.get("claims") if isinstance(data, dict) else None + claims: list[ClaimItem] = [] + if isinstance(claims_raw, list): + for entry in claims_raw: + if not isinstance(entry, dict): + continue + claim_text = str(entry.get("claim") or "").strip() + claim_id = str(entry.get("id") or "").strip() or f"c{len(claims)+1}" + evidence_items: list[EvidenceItem] = [] + for ev in entry.get("evidence") or []: + if not isinstance(ev, dict): + continue + path = str(ev.get("path") or "").strip() + if not path: + continue + reason = str(ev.get("reason") or "").strip() + value = _resolve_path(summary, path) + evidence_items.append(EvidenceItem(path=path, reason=reason, value=value, value_at_claim=value)) + if claim_text and evidence_items: + claims.append(ClaimItem(id=claim_id, claim=claim_text, evidence=evidence_items)) + return claims + + async def _answer_followup( + self, + question: str, + state: ConversationState, + summary: dict[str, Any], + classify: dict[str, Any], + mode: str, + ) -> str: + follow_kind = classify.get("follow_up_kind") if isinstance(classify, dict) else "other" + claim_ids = await self._select_claims(question, state.claims) + selected = [claim for claim in state.claims if claim.id in claim_ids] if claim_ids else state.claims[:2] + evidence_lines = [] + for claim in selected: + evidence_lines.append(f"Claim: {claim.claim}") + for ev in claim.evidence: + current = _resolve_path(summary, ev.path) + ev.value = current + delta_note = "" + if ev.value_at_claim is not None and current is not None and current != ev.value_at_claim: + delta_note = f" (now {current})" + evidence_lines.append(f"- {ev.path}: {ev.value_at_claim}{delta_note} {('- ' + ev.reason) if ev.reason else ''}") + evidence_ctx = "\n".join(evidence_lines) + prompt = prompts.FOLLOWUP_EVIDENCE_PROMPT + if follow_kind in {"next_steps", "change"}: + prompt = prompts.FOLLOWUP_ACTION_PROMPT + prompt = prompt + "\nFollow-up: " + question + "\nEvidence:\n" + evidence_ctx + messages = build_messages(prompts.CLUSTER_SYSTEM, prompt) + model = _synthesis_model(self._settings, mode) + return await self._llm.chat(messages, model=model) + + async def _select_claims(self, question: str, claims: list[ClaimItem]) -> list[str]: + if not claims: + return [] + claims_brief = [{"id": claim.id, "claim": claim.claim} for claim in claims] + prompt = prompts.SELECT_CLAIMS_PROMPT + "\nFollow-up: " + question + "\nClaims: " + json.dumps(claims_brief) + messages = build_messages(prompts.CLUSTER_SYSTEM, prompt) + raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast) + data = _parse_json_block(raw, fallback={}) + ids = data.get("claim_ids") if isinstance(data, dict) else [] + if isinstance(ids, list): + return [str(item) for item in ids if item] + return [] + + def _get_state(self, conversation_id: str | None) -> ConversationState | None: + if not conversation_id: + return None + self._cleanup_state() + return self._state.get(conversation_id) + + def _store_state( + self, + conversation_id: str, + claims: list[ClaimItem], + summary: dict[str, Any], + snapshot: dict[str, Any] | None, + ) -> None: + snapshot_id = _snapshot_id(summary) + pinned_snapshot = snapshot if self._settings.snapshot_pin_enabled else None + self._state[conversation_id] = ConversationState( + updated_at=time.monotonic(), + claims=claims, + snapshot_id=snapshot_id, + snapshot=pinned_snapshot, + ) + self._cleanup_state() + + def _cleanup_state(self) -> None: + ttl = max(60, self._settings.conversation_ttl_sec) + now = time.monotonic() + expired = [key for key, state in self._state.items() if now - state.updated_at > ttl] + for key in expired: + self._state.pop(key, None) + def _join_context(parts: list[str]) -> str: text = "\n".join([p for p in parts if p]) @@ -227,12 +383,19 @@ def _format_history(history: list[dict[str, str]] | None) -> str: return "" lines = ["Recent conversation:"] for entry in history[-4:]: - question = entry.get("q") if isinstance(entry, dict) else None - answer = entry.get("a") if isinstance(entry, dict) else None + if not isinstance(entry, dict): + continue + question = entry.get("q") + answer = entry.get("a") + role = entry.get("role") + content = entry.get("content") if question: lines.append(f"Q: {question}") if answer: lines.append(f"A: {answer}") + if role and content: + prefix = "Q" if role == "user" else "A" + lines.append(f"{prefix}: {content}") return "\n".join(lines) @@ -312,3 +475,44 @@ def _coerce_int(value: Any, default: int) -> int: def _default_scores() -> AnswerScores: return AnswerScores(confidence=60, relevance=60, satisfaction=60, hallucination_risk="medium") + + +def _resolve_path(data: Any, path: str) -> Any | None: + cursor = data + for part in re.split(r"\\.(?![^\\[]*\\])", path): + if not part: + continue + match = re.match(r"^(\\w+)(?:\\[(\\d+)\\])?$", part) + if not match: + return None + key = match.group(1) + index = match.group(2) + if isinstance(cursor, dict): + cursor = cursor.get(key) + else: + return None + if index is not None: + try: + idx = int(index) + if isinstance(cursor, list) and 0 <= idx < len(cursor): + cursor = cursor[idx] + else: + return None + except ValueError: + return None + return cursor + + +def _snapshot_id(summary: dict[str, Any]) -> str | None: + if not summary: + return None + for key in ("generated_at", "snapshot_ts", "snapshot_id"): + value = summary.get(key) + if isinstance(value, str) and value: + return value + return None + + +def _json_excerpt(summary: dict[str, Any], max_chars: int = 12000) -> str: + raw = json.dumps(summary, ensure_ascii=False) + return raw[:max_chars] diff --git a/atlasbot/llm/prompts.py b/atlasbot/llm/prompts.py index e6d632a..68b8c7d 100644 --- a/atlasbot/llm/prompts.py +++ b/atlasbot/llm/prompts.py @@ -16,7 +16,8 @@ CLASSIFY_PROMPT = ( "needs_snapshot (bool), needs_kb (bool), needs_metrics (bool), " "needs_general (bool), intent (short string), ambiguity (0-1), " "answer_style (direct|insightful), topic_summary (short string), " - "follow_up (bool), question_type (metric|diagnostic|planning|open_ended)." + "follow_up (bool), follow_up_kind (evidence|why|clarify|next_steps|change|other), " + "question_type (metric|diagnostic|planning|open_ended)." ) ANGLE_PROMPT = ( @@ -54,6 +55,37 @@ REFINE_PROMPT = ( "Keep it grounded in the context. If you cannot add insight, say so explicitly." ) +CLAIM_MAP_PROMPT = ( + "Extract a claim map from the answer. " + "Return JSON with fields: claims (list). " + "Each claim object: id (short string), claim (short sentence), " + "evidence (list of objects with path and reason). " + "Paths must point into the provided SnapshotSummary JSON using dot notation, " + "with list indexes in brackets, e.g. metrics.node_load[0].node. " + "Do not invent evidence; if no evidence exists, omit the claim." +) + +SELECT_CLAIMS_PROMPT = ( + "Pick which prior claim(s) the follow-up refers to. " + "Return JSON with fields: claim_ids (list of ids), follow_up_kind " + "(evidence|why|clarify|next_steps|change|other). " + "If none apply, return an empty list." +) + +FOLLOWUP_EVIDENCE_PROMPT = ( + "Answer the follow-up using only the provided claims and evidence. " + "Be conversational, not bullet-heavy. " + "If evidence does not support a claim, say so plainly. " + "Do not add new claims." +) + +FOLLOWUP_ACTION_PROMPT = ( + "Answer the follow-up using the provided claims and evidence. " + "You may suggest next steps or changes, but keep them tightly tied " + "to the evidence list. " + "Be conversational and concise." +) + STOCK_SYSTEM = ( "You are Atlas, a helpful assistant. " "Be concise and truthful. " diff --git a/atlasbot/main.py b/atlasbot/main.py index 57d3ca3..9316356 100644 --- a/atlasbot/main.py +++ b/atlasbot/main.py @@ -31,22 +31,24 @@ async def main() -> None: async def handler(payload: dict[str, object]) -> dict[str, object]: history = payload.get("history") if isinstance(payload, dict) else None + conversation_id = payload.get("conversation_id") if isinstance(payload, dict) else None result = await engine.answer( str(payload.get("question", "") or ""), mode=str(payload.get("mode", "quick") or "quick"), history=history if isinstance(history, list) else None, + conversation_id=str(conversation_id) if isinstance(conversation_id, str) else None, ) return {"reply": result.reply, "scores": result.scores.__dict__} queue = QueueManager(settings, handler) await queue.start() - async def answer_handler(question: str, mode: str, history=None, observer=None) -> AnswerResult: + async def answer_handler(question: str, mode: str, history=None, conversation_id=None, observer=None) -> AnswerResult: if settings.queue_enabled: - payload = await queue.submit({"question": question, "mode": mode, "history": history or []}) + payload = await queue.submit({"question": question, "mode": mode, "history": history or [], "conversation_id": conversation_id}) reply = payload.get("reply", "") if isinstance(payload, dict) else "" return AnswerResult(reply=reply or "", scores=result_scores(payload), meta={"mode": mode}) - return await engine.answer(question, mode=mode, history=history, observer=observer) + return await engine.answer(question, mode=mode, history=history, observer=observer, conversation_id=conversation_id) api = Api(settings, answer_handler) server = uvicorn.Server(uvicorn.Config(api.app, host="0.0.0.0", port=settings.http_port, log_level="info")) diff --git a/atlasbot/matrix/bot.py b/atlasbot/matrix/bot.py index ee49118..07147cb 100644 --- a/atlasbot/matrix/bot.py +++ b/atlasbot/matrix/bot.py @@ -80,7 +80,10 @@ class MatrixBot: settings: Settings, bot: MatrixBotConfig, engine: AnswerEngine, - answer_handler: Callable[[str, str, list[dict[str, str]] | None, Callable[[str, str], None] | None], Awaitable[AnswerResult]] + answer_handler: Callable[ + [str, str, list[dict[str, str]] | None, str | None, Callable[[str, str], None] | None], + Awaitable[AnswerResult], + ] | None = None, ) -> None: self._settings = settings @@ -155,9 +158,11 @@ class MatrixBot: task = asyncio.create_task(heartbeat()) started = time.monotonic() try: - handler = self._answer_handler or (lambda q, m, h, obs: self._engine.answer(q, mode=m, history=h, observer=obs)) + handler = self._answer_handler or ( + lambda q, m, h, cid, obs: self._engine.answer(q, mode=m, history=h, observer=obs, conversation_id=cid) + ) history = self._history.get(room_id, []) - result = await handler(question, mode, history, observer) + result = await handler(question, mode, history, room_id, observer) elapsed = time.monotonic() - started await self._client.send_message(token, room_id, result.reply) log.info(