atlasbot: add conversation state

Brad Stein 2026-01-30 16:41:04 -03:00
parent 0e471ecc37
commit 0476edae6a
6 changed files with 262 additions and 13 deletions

View File

@@ -20,6 +20,7 @@ class AnswerRequest(BaseModel):
content: str | None = None
mode: str | None = None
history: list[dict[str, str]] | None = None
conversation_id: str | None = None
class AnswerResponse(BaseModel):
@@ -30,7 +31,7 @@ class Api:
def __init__(
self,
settings: Settings,
answer_handler: Callable[[str, str, list[dict[str, str]] | None], Awaitable[AnswerResult]],
answer_handler: Callable[[str, str, list[dict[str, str]] | None, str | None], Awaitable[AnswerResult]],
) -> None:
self._settings = settings
self._answer_handler = answer_handler
@@ -53,7 +54,8 @@ class Api:
if not question:
raise HTTPException(status_code=400, detail="missing question")
mode = (payload.mode or "quick").strip().lower()
result = await self._answer_handler(question, mode, payload.history)
conversation_id = payload.conversation_id
result = await self._answer_handler(question, mode, payload.history, conversation_id)
log.info(
"answer",
extra={
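
For reference, a request that opts into conversation state might look like the sketch below; the endpoint path is an assumption, while the field names come from AnswerRequest above:

    POST /answer
    {
        "content": "why is node a1 overloaded?",
        "mode": "quick",
        "history": [{"q": "cluster status?", "a": "all nodes healthy"}],
        "conversation_id": "room-1"
    }

Reusing the same conversation_id within the TTL window lets follow-ups be answered from the stored claim map instead of a fresh synthesis pass.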

View File

@@ -53,6 +53,8 @@ class Settings:
snapshot_ttl_sec: int
thinking_interval_sec: int
conversation_ttl_sec: int
snapshot_pin_enabled: bool
queue_enabled: bool
nats_url: str
@@ -141,6 +143,8 @@ def load_settings() -> Settings:
ariadne_state_token=os.getenv("ARIADNE_STATE_TOKEN", ""),
snapshot_ttl_sec=_env_int("ATLASBOT_SNAPSHOT_TTL_SEC", "30"),
thinking_interval_sec=_env_int("ATLASBOT_THINKING_INTERVAL_SEC", "30"),
conversation_ttl_sec=_env_int("ATLASBOT_CONVERSATION_TTL_SEC", "900"),
snapshot_pin_enabled=_env_bool("ATLASBOT_SNAPSHOT_PIN_ENABLED", "false"),
queue_enabled=_env_bool("ATLASBOT_QUEUE_ENABLED", "false"),
nats_url=os.getenv("ATLASBOT_NATS_URL", "nats://nats.nats.svc.cluster.local:4222"),
nats_stream=os.getenv("ATLASBOT_NATS_STREAM", "atlasbot"),
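
Deployments can tune the new behavior through the environment (values illustrative; the variable names are the ones read above):

    ATLASBOT_CONVERSATION_TTL_SEC=1800   # keep per-conversation state for 30 minutes
    ATLASBOT_SNAPSHOT_PIN_ENABLED=true   # answer follow-ups against the snapshot seen at claim time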

View File

@@ -1,4 +1,5 @@
import asyncio
import json
import logging
import re
import time
@@ -9,7 +10,7 @@ from atlasbot.config import Settings
from atlasbot.knowledge.loader import KnowledgeBase
from atlasbot.llm.client import LLMClient, build_messages, parse_json
from atlasbot.llm import prompts
from atlasbot.snapshot.builder import SnapshotProvider, summary_text
from atlasbot.snapshot.builder import SnapshotProvider, build_summary, summary_text
log = logging.getLogger(__name__)
@@ -29,6 +30,29 @@ class AnswerResult:
meta: dict[str, Any]
@dataclass
class EvidenceItem:
path: str
reason: str
value: Any | None = None
value_at_claim: Any | None = None
@dataclass
class ClaimItem:
id: str
claim: str
evidence: list[EvidenceItem]
@dataclass
class ConversationState:
updated_at: float
claims: list[ClaimItem]
snapshot_id: str | None = None
snapshot: dict[str, Any] | None = None
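
For orientation, a minimal sketch of one stored entry in the engine's state map (all values invented for illustration; the evidence path follows the dot/bracket notation used by the claim-map prompt):

    state = ConversationState(
        updated_at=time.monotonic(),
        claims=[
            ClaimItem(
                id="c1",
                claim="node a1 is the most loaded node",
                evidence=[
                    EvidenceItem(
                        path="metrics.node_load[0].node",
                        reason="top of load ranking",
                        value="a1",
                        value_at_claim="a1",
                    )
                ],
            )
        ],
        snapshot_id="2026-01-30T19:40:00Z",
    )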
class AnswerEngine:
def __init__(
self,
@@ -41,6 +65,7 @@ class AnswerEngine:
self._llm = llm
self._kb = kb
self._snapshot = snapshot
self._state: dict[str, ConversationState] = {}
async def answer(
self,
@@ -49,6 +74,7 @@
mode: str,
history: list[dict[str, str]] | None = None,
observer: Callable[[str, str], None] | None = None,
conversation_id: str | None = None,
) -> AnswerResult:
question = (question or "").strip()
if not question:
@@ -56,10 +82,15 @@
if mode == "stock":
return await self._answer_stock(question)
state = self._get_state(conversation_id)
snapshot = self._snapshot.get()
snapshot_used = snapshot
if self._settings.snapshot_pin_enabled and state and state.snapshot:
snapshot_used = state.snapshot
summary = build_summary(snapshot_used)
kb_summary = self._kb.summary()
runbooks = self._kb.runbook_titles(limit=4)
snapshot_ctx = summary_text(snapshot)
snapshot_ctx = summary_text(snapshot_used)
history_ctx = _format_history(history)
base_context = _join_context([
kb_summary,
@@ -97,15 +128,30 @@
"atlasbot_selection",
extra={"extra": {"mode": mode, "selected": len(best), "scores": scores.__dict__}},
)
if classify.get("follow_up") and state and state.claims:
if observer:
observer("followup", "answering follow-up")
reply = await self._answer_followup(question, state, summary, classify, mode)
meta = {
"mode": mode,
"follow_up": True,
"classify": classify,
}
return AnswerResult(reply, scores, meta)
if observer:
observer("synthesize", "synthesizing reply")
reply = await self._synthesize(question, best, base_context, classify, mode)
claims = await self._extract_claims(question, reply, summary)
if conversation_id and claims:
self._store_state(conversation_id, claims, summary, snapshot_used)
meta = {
"mode": mode,
"angles": angles,
"scores": scores.__dict__,
"classify": classify,
"candidates": len(candidates),
"claims": len(claims),
}
return AnswerResult(reply, scores, meta)
@@ -121,6 +167,8 @@
data = _parse_json_block(raw, fallback={"needs_snapshot": True})
if "answer_style" not in data:
data["answer_style"] = "direct"
if "follow_up_kind" not in data:
data["follow_up_kind"] = "other"
return data
async def _angles(self, question: str, classify: dict[str, Any], mode: str) -> list[dict[str, Any]]:
@@ -216,6 +264,114 @@
refine_messages = build_messages(prompts.CLUSTER_SYSTEM, refine_prompt, context=context)
return await self._llm.chat(refine_messages, model=model)
async def _extract_claims(
self,
question: str,
reply: str,
summary: dict[str, Any],
) -> list[ClaimItem]:
if not reply or not summary:
return []
summary_json = _json_excerpt(summary)
prompt = prompts.CLAIM_MAP_PROMPT + "\nQuestion: " + question + "\nAnswer: " + reply
messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=f"SnapshotSummaryJson:{summary_json}")
raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
data = _parse_json_block(raw, fallback={})
claims_raw = data.get("claims") if isinstance(data, dict) else None
claims: list[ClaimItem] = []
if isinstance(claims_raw, list):
for entry in claims_raw:
if not isinstance(entry, dict):
continue
claim_text = str(entry.get("claim") or "").strip()
claim_id = str(entry.get("id") or "").strip() or f"c{len(claims)+1}"
evidence_items: list[EvidenceItem] = []
for ev in entry.get("evidence") or []:
if not isinstance(ev, dict):
continue
path = str(ev.get("path") or "").strip()
if not path:
continue
reason = str(ev.get("reason") or "").strip()
value = _resolve_path(summary, path)
evidence_items.append(EvidenceItem(path=path, reason=reason, value=value, value_at_claim=value))
if claim_text and evidence_items:
claims.append(ClaimItem(id=claim_id, claim=claim_text, evidence=evidence_items))
return claims
async def _answer_followup(
self,
question: str,
state: ConversationState,
summary: dict[str, Any],
classify: dict[str, Any],
mode: str,
) -> str:
follow_kind = classify.get("follow_up_kind") if isinstance(classify, dict) else "other"
claim_ids = await self._select_claims(question, state.claims)
selected = [claim for claim in state.claims if claim.id in claim_ids] if claim_ids else state.claims[:2]
evidence_lines = []
for claim in selected:
evidence_lines.append(f"Claim: {claim.claim}")
for ev in claim.evidence:
current = _resolve_path(summary, ev.path)
ev.value = current
delta_note = ""
if ev.value_at_claim is not None and current is not None and current != ev.value_at_claim:
delta_note = f" (now {current})"
evidence_lines.append(f"- {ev.path}: {ev.value_at_claim}{delta_note}{(' - ' + ev.reason) if ev.reason else ''}")
evidence_ctx = "\n".join(evidence_lines)
prompt = prompts.FOLLOWUP_EVIDENCE_PROMPT
if follow_kind in {"next_steps", "change"}:
prompt = prompts.FOLLOWUP_ACTION_PROMPT
prompt = prompt + "\nFollow-up: " + question + "\nEvidence:\n" + evidence_ctx
messages = build_messages(prompts.CLUSTER_SYSTEM, prompt)
model = _synthesis_model(self._settings, mode)
return await self._llm.chat(messages, model=model)
async def _select_claims(self, question: str, claims: list[ClaimItem]) -> list[str]:
if not claims:
return []
claims_brief = [{"id": claim.id, "claim": claim.claim} for claim in claims]
prompt = prompts.SELECT_CLAIMS_PROMPT + "\nFollow-up: " + question + "\nClaims: " + json.dumps(claims_brief)
messages = build_messages(prompts.CLUSTER_SYSTEM, prompt)
raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
data = _parse_json_block(raw, fallback={})
ids = data.get("claim_ids") if isinstance(data, dict) else []
if isinstance(ids, list):
return [str(item) for item in ids if item]
return []
def _get_state(self, conversation_id: str | None) -> ConversationState | None:
if not conversation_id:
return None
self._cleanup_state()
return self._state.get(conversation_id)
def _store_state(
self,
conversation_id: str,
claims: list[ClaimItem],
summary: dict[str, Any],
snapshot: dict[str, Any] | None,
) -> None:
snapshot_id = _snapshot_id(summary)
pinned_snapshot = snapshot if self._settings.snapshot_pin_enabled else None
self._state[conversation_id] = ConversationState(
updated_at=time.monotonic(),
claims=claims,
snapshot_id=snapshot_id,
snapshot=pinned_snapshot,
)
self._cleanup_state()
def _cleanup_state(self) -> None:
ttl = max(60, self._settings.conversation_ttl_sec)
now = time.monotonic()
expired = [key for key, state in self._state.items() if now - state.updated_at > ttl]
for key in expired:
self._state.pop(key, None)
def _join_context(parts: list[str]) -> str:
text = "\n".join([p for p in parts if p])
@@ -227,12 +383,19 @@ def _format_history(history: list[dict[str, str]] | None) -> str:
return ""
lines = ["Recent conversation:"]
for entry in history[-4:]:
question = entry.get("q") if isinstance(entry, dict) else None
answer = entry.get("a") if isinstance(entry, dict) else None
if not isinstance(entry, dict):
continue
question = entry.get("q")
answer = entry.get("a")
role = entry.get("role")
content = entry.get("content")
if question:
lines.append(f"Q: {question}")
if answer:
lines.append(f"A: {answer}")
if role and content:
prefix = "Q" if role == "user" else "A"
lines.append(f"{prefix}: {content}")
return "\n".join(lines)
@@ -312,3 +475,44 @@ def _coerce_int(value: Any, default: int) -> int:
def _default_scores() -> AnswerScores:
return AnswerScores(confidence=60, relevance=60, satisfaction=60, hallucination_risk="medium")
def _resolve_path(data: Any, path: str) -> Any | None:
cursor = data
for part in re.split(r"\.(?![^\[]*\])", path):
if not part:
continue
match = re.match(r"^(\w+)(?:\[(\d+)\])?$", part)
if not match:
return None
key = match.group(1)
index = match.group(2)
if isinstance(cursor, dict):
cursor = cursor.get(key)
else:
return None
if index is not None:
try:
idx = int(index)
if isinstance(cursor, list) and 0 <= idx < len(cursor):
cursor = cursor[idx]
else:
return None
except ValueError:
return None
return cursor
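
To make the path grammar concrete, a few illustrative calls against invented data:

    summary = {"metrics": {"node_load": [{"node": "a1", "load": 7.5}]}}
    _resolve_path(summary, "metrics.node_load[0].load")  # -> 7.5
    _resolve_path(summary, "metrics.node_load[3].load")  # -> None (index out of range)
    _resolve_path(summary, "metrics.missing")            # -> None (key absent)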
def _snapshot_id(summary: dict[str, Any]) -> str | None:
if not summary:
return None
for key in ("generated_at", "snapshot_ts", "snapshot_id"):
value = summary.get(key)
if isinstance(value, str) and value:
return value
return None
def _json_excerpt(summary: dict[str, Any], max_chars: int = 12000) -> str:
raw = json.dumps(summary, ensure_ascii=False)
return raw[:max_chars]

View File

@@ -16,7 +16,8 @@ CLASSIFY_PROMPT = (
"needs_snapshot (bool), needs_kb (bool), needs_metrics (bool), "
"needs_general (bool), intent (short string), ambiguity (0-1), "
"answer_style (direct|insightful), topic_summary (short string), "
"follow_up (bool), question_type (metric|diagnostic|planning|open_ended)."
"follow_up (bool), follow_up_kind (evidence|why|clarify|next_steps|change|other), "
"question_type (metric|diagnostic|planning|open_ended)."
)
ANGLE_PROMPT = (
@@ -54,6 +55,37 @@ REFINE_PROMPT = (
"Keep it grounded in the context. If you cannot add insight, say so explicitly."
)
CLAIM_MAP_PROMPT = (
"Extract a claim map from the answer. "
"Return JSON with fields: claims (list). "
"Each claim object: id (short string), claim (short sentence), "
"evidence (list of objects with path and reason). "
"Paths must point into the provided SnapshotSummary JSON using dot notation, "
"with list indexes in brackets, e.g. metrics.node_load[0].node. "
"Do not invent evidence; if no evidence exists, omit the claim."
)
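
Under this contract, a well-formed model response would be JSON roughly like the following (contents illustrative):

    {
      "claims": [
        {
          "id": "c1",
          "claim": "node a1 carries the highest load",
          "evidence": [
            {"path": "metrics.node_load[0].node", "reason": "first entry in the load ranking"}
          ]
        }
      ]
    }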
SELECT_CLAIMS_PROMPT = (
"Pick which prior claim(s) the follow-up refers to. "
"Return JSON with fields: claim_ids (list of ids), follow_up_kind "
"(evidence|why|clarify|next_steps|change|other). "
"If none apply, return an empty list."
)
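
A matching selection response would then look like (illustrative):

    {"claim_ids": ["c1"], "follow_up_kind": "evidence"}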
FOLLOWUP_EVIDENCE_PROMPT = (
"Answer the follow-up using only the provided claims and evidence. "
"Be conversational, not bullet-heavy. "
"If evidence does not support a claim, say so plainly. "
"Do not add new claims."
)
FOLLOWUP_ACTION_PROMPT = (
"Answer the follow-up using the provided claims and evidence. "
"You may suggest next steps or changes, but keep them tightly tied "
"to the evidence list. "
"Be conversational and concise."
)
STOCK_SYSTEM = (
"You are Atlas, a helpful assistant. "
"Be concise and truthful. "

View File

@@ -31,22 +31,24 @@ async def main() -> None:
async def handler(payload: dict[str, object]) -> dict[str, object]:
history = payload.get("history") if isinstance(payload, dict) else None
conversation_id = payload.get("conversation_id") if isinstance(payload, dict) else None
result = await engine.answer(
str(payload.get("question", "") or ""),
mode=str(payload.get("mode", "quick") or "quick"),
history=history if isinstance(history, list) else None,
conversation_id=str(conversation_id) if isinstance(conversation_id, str) else None,
)
return {"reply": result.reply, "scores": result.scores.__dict__}
queue = QueueManager(settings, handler)
await queue.start()
async def answer_handler(question: str, mode: str, history=None, observer=None) -> AnswerResult:
async def answer_handler(question: str, mode: str, history=None, conversation_id=None, observer=None) -> AnswerResult:
if settings.queue_enabled:
payload = await queue.submit({"question": question, "mode": mode, "history": history or []})
payload = await queue.submit({"question": question, "mode": mode, "history": history or [], "conversation_id": conversation_id})
reply = payload.get("reply", "") if isinstance(payload, dict) else ""
return AnswerResult(reply=reply or "", scores=result_scores(payload), meta={"mode": mode})
return await engine.answer(question, mode=mode, history=history, observer=observer)
return await engine.answer(question, mode=mode, history=history, observer=observer, conversation_id=conversation_id)
api = Api(settings, answer_handler)
server = uvicorn.Server(uvicorn.Config(api.app, host="0.0.0.0", port=settings.http_port, log_level="info"))

View File

@@ -80,7 +80,10 @@ class MatrixBot:
settings: Settings,
bot: MatrixBotConfig,
engine: AnswerEngine,
answer_handler: Callable[[str, str, list[dict[str, str]] | None, Callable[[str, str], None] | None], Awaitable[AnswerResult]]
answer_handler: Callable[
[str, str, list[dict[str, str]] | None, str | None, Callable[[str, str], None] | None],
Awaitable[AnswerResult],
]
| None = None,
) -> None:
self._settings = settings
@@ -155,9 +158,11 @@ class MatrixBot:
task = asyncio.create_task(heartbeat())
started = time.monotonic()
try:
handler = self._answer_handler or (lambda q, m, h, obs: self._engine.answer(q, mode=m, history=h, observer=obs))
handler = self._answer_handler or (
lambda q, m, h, cid, obs: self._engine.answer(q, mode=m, history=h, observer=obs, conversation_id=cid)
)
history = self._history.get(room_id, [])
result = await handler(question, mode, history, observer)
result = await handler(question, mode, history, room_id, observer)
elapsed = time.monotonic() - started
await self._client.send_message(token, room_id, result.reply)
log.info(