atlasbot: add conversation state
This commit is contained in:
parent
0e471ecc37
commit
0476edae6a
@ -20,6 +20,7 @@ class AnswerRequest(BaseModel):
|
||||
content: str | None = None
|
||||
mode: str | None = None
|
||||
history: list[dict[str, str]] | None = None
|
||||
conversation_id: str | None = None
|
||||
|
||||
|
||||
class AnswerResponse(BaseModel):
|
||||
@ -30,7 +31,7 @@ class Api:
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings,
|
||||
answer_handler: Callable[[str, str, list[dict[str, str]] | None], Awaitable[AnswerResult]],
|
||||
answer_handler: Callable[[str, str, list[dict[str, str]] | None, str | None], Awaitable[AnswerResult]],
|
||||
) -> None:
|
||||
self._settings = settings
|
||||
self._answer_handler = answer_handler
|
||||
@ -53,7 +54,8 @@ class Api:
|
||||
if not question:
|
||||
raise HTTPException(status_code=400, detail="missing question")
|
||||
mode = (payload.mode or "quick").strip().lower()
|
||||
result = await self._answer_handler(question, mode, payload.history)
|
||||
conversation_id = payload.conversation_id
|
||||
result = await self._answer_handler(question, mode, payload.history, conversation_id)
|
||||
log.info(
|
||||
"answer",
|
||||
extra={
|
||||
|
||||
@ -53,6 +53,8 @@ class Settings:
|
||||
|
||||
snapshot_ttl_sec: int
|
||||
thinking_interval_sec: int
|
||||
conversation_ttl_sec: int
|
||||
snapshot_pin_enabled: bool
|
||||
|
||||
queue_enabled: bool
|
||||
nats_url: str
|
||||
@ -141,6 +143,8 @@ def load_settings() -> Settings:
|
||||
ariadne_state_token=os.getenv("ARIADNE_STATE_TOKEN", ""),
|
||||
snapshot_ttl_sec=_env_int("ATLASBOT_SNAPSHOT_TTL_SEC", "30"),
|
||||
thinking_interval_sec=_env_int("ATLASBOT_THINKING_INTERVAL_SEC", "30"),
|
||||
conversation_ttl_sec=_env_int("ATLASBOT_CONVERSATION_TTL_SEC", "900"),
|
||||
snapshot_pin_enabled=_env_bool("ATLASBOT_SNAPSHOT_PIN_ENABLED", "false"),
|
||||
queue_enabled=_env_bool("ATLASBOT_QUEUE_ENABLED", "false"),
|
||||
nats_url=os.getenv("ATLASBOT_NATS_URL", "nats://nats.nats.svc.cluster.local:4222"),
|
||||
nats_stream=os.getenv("ATLASBOT_NATS_STREAM", "atlasbot"),
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
@ -9,7 +10,7 @@ from atlasbot.config import Settings
|
||||
from atlasbot.knowledge.loader import KnowledgeBase
|
||||
from atlasbot.llm.client import LLMClient, build_messages, parse_json
|
||||
from atlasbot.llm import prompts
|
||||
from atlasbot.snapshot.builder import SnapshotProvider, summary_text
|
||||
from atlasbot.snapshot.builder import SnapshotProvider, build_summary, summary_text
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -29,6 +30,29 @@ class AnswerResult:
|
||||
meta: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
class EvidenceItem:
    """A single snapshot-summary lookup that backs a claim."""

    # Dot/bracket path into the snapshot summary, e.g. ``metrics.node_load[0].node``.
    path: str
    # Explanation supplied alongside the path for why it supports the claim.
    reason: str
    # Most recently resolved value at ``path``; refreshed when a follow-up is answered.
    value: Any | None = None
    # Value resolved at ``path`` when the claim was extracted; used to report drift.
    value_at_claim: Any | None = None
|
||||
|
||||
|
||||
@dataclass
class ClaimItem:
    """A statement taken from a reply together with the evidence that supports it."""

    # Identifier used to refer back to this claim when selecting follow-up targets.
    id: str
    # The claim text itself.
    claim: str
    # Evidence entries backing the claim.
    evidence: list[EvidenceItem]
|
||||
|
||||
|
||||
@dataclass
class ConversationState:
    """Per-conversation memory kept between an answer and its follow-ups."""

    # ``time.monotonic()`` reading taken when the state was stored; drives TTL expiry.
    updated_at: float
    # Claims extracted from the last synthesized reply.
    claims: list[ClaimItem]
    # Identifier of the snapshot summary the claims were extracted against, if any.
    snapshot_id: str | None = None
    # Pinned raw snapshot; only populated when snapshot pinning is enabled.
    snapshot: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AnswerEngine:
|
||||
def __init__(
|
||||
self,
|
||||
@ -41,6 +65,7 @@ class AnswerEngine:
|
||||
self._llm = llm
|
||||
self._kb = kb
|
||||
self._snapshot = snapshot
|
||||
self._state: dict[str, ConversationState] = {}
|
||||
|
||||
async def answer(
|
||||
self,
|
||||
@ -49,6 +74,7 @@ class AnswerEngine:
|
||||
mode: str,
|
||||
history: list[dict[str, str]] | None = None,
|
||||
observer: Callable[[str, str], None] | None = None,
|
||||
conversation_id: str | None = None,
|
||||
) -> AnswerResult:
|
||||
question = (question or "").strip()
|
||||
if not question:
|
||||
@ -56,10 +82,15 @@ class AnswerEngine:
|
||||
if mode == "stock":
|
||||
return await self._answer_stock(question)
|
||||
|
||||
state = self._get_state(conversation_id)
|
||||
snapshot = self._snapshot.get()
|
||||
snapshot_used = snapshot
|
||||
if self._settings.snapshot_pin_enabled and state and state.snapshot:
|
||||
snapshot_used = state.snapshot
|
||||
summary = build_summary(snapshot_used)
|
||||
kb_summary = self._kb.summary()
|
||||
runbooks = self._kb.runbook_titles(limit=4)
|
||||
snapshot_ctx = summary_text(snapshot)
|
||||
snapshot_ctx = summary_text(snapshot_used)
|
||||
history_ctx = _format_history(history)
|
||||
base_context = _join_context([
|
||||
kb_summary,
|
||||
@ -97,15 +128,30 @@ class AnswerEngine:
|
||||
"atlasbot_selection",
|
||||
extra={"extra": {"mode": mode, "selected": len(best), "scores": scores.__dict__}},
|
||||
)
|
||||
if classify.get("follow_up") and state and state.claims:
|
||||
if observer:
|
||||
observer("followup", "answering follow-up")
|
||||
reply = await self._answer_followup(question, state, summary, classify, mode)
|
||||
meta = {
|
||||
"mode": mode,
|
||||
"follow_up": True,
|
||||
"classify": classify,
|
||||
}
|
||||
return AnswerResult(reply, scores, meta)
|
||||
|
||||
if observer:
|
||||
observer("synthesize", "synthesizing reply")
|
||||
reply = await self._synthesize(question, best, base_context, classify, mode)
|
||||
claims = await self._extract_claims(question, reply, summary)
|
||||
if conversation_id and claims:
|
||||
self._store_state(conversation_id, claims, summary, snapshot_used)
|
||||
meta = {
|
||||
"mode": mode,
|
||||
"angles": angles,
|
||||
"scores": scores.__dict__,
|
||||
"classify": classify,
|
||||
"candidates": len(candidates),
|
||||
"claims": len(claims),
|
||||
}
|
||||
return AnswerResult(reply, scores, meta)
|
||||
|
||||
@ -121,6 +167,8 @@ class AnswerEngine:
|
||||
data = _parse_json_block(raw, fallback={"needs_snapshot": True})
|
||||
if "answer_style" not in data:
|
||||
data["answer_style"] = "direct"
|
||||
if "follow_up_kind" not in data:
|
||||
data["follow_up_kind"] = "other"
|
||||
return data
|
||||
|
||||
async def _angles(self, question: str, classify: dict[str, Any], mode: str) -> list[dict[str, Any]]:
|
||||
@ -216,6 +264,114 @@ class AnswerEngine:
|
||||
refine_messages = build_messages(prompts.CLUSTER_SYSTEM, refine_prompt, context=context)
|
||||
return await self._llm.chat(refine_messages, model=model)
|
||||
|
||||
async def _extract_claims(
    self,
    question: str,
    reply: str,
    summary: dict[str, Any],
) -> list[ClaimItem]:
    """Map a synthesized reply onto evidence paths in the snapshot summary.

    The fast model is asked for a claim map (``CLAIM_MAP_PROMPT``) with a
    truncated JSON rendering of ``summary`` as context. Each returned
    evidence path is resolved against ``summary`` and the resolved value is
    stored both as the current value and as the value at claim time.

    Args:
        question: The user question that produced ``reply``.
        reply: The synthesized answer to extract claims from.
        summary: Snapshot summary the evidence paths must point into.

    Returns:
        Claims that have non-empty text and at least one evidence entry with
        a non-empty path. Empty when ``reply`` or ``summary`` is empty or the
        model output is not usable.
    """
    if not reply or not summary:
        return []
    summary_json = _json_excerpt(summary)
    prompt = prompts.CLAIM_MAP_PROMPT + "\nQuestion: " + question + "\nAnswer: " + reply
    messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=f"SnapshotSummaryJson:{summary_json}")
    raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
    data = _parse_json_block(raw, fallback={})
    claims_raw = data.get("claims") if isinstance(data, dict) else None
    if not isinstance(claims_raw, list):
        return []

    claims: list[ClaimItem] = []
    for entry in claims_raw:
        if not isinstance(entry, dict):
            continue
        claim_text = str(entry.get("claim") or "").strip()
        # Fall back to a positional id when the model omits one.
        claim_id = str(entry.get("id") or "").strip() or f"c{len(claims)+1}"
        evidence_raw = entry.get("evidence")
        # Model output is untrusted: a non-list value (e.g. a number) would
        # raise TypeError when iterated, so treat it as "no evidence".
        if not isinstance(evidence_raw, list):
            continue
        evidence_items: list[EvidenceItem] = []
        for ev in evidence_raw:
            if not isinstance(ev, dict):
                continue
            path = str(ev.get("path") or "").strip()
            if not path:
                continue
            reason = str(ev.get("reason") or "").strip()
            value = _resolve_path(summary, path)
            evidence_items.append(EvidenceItem(path=path, reason=reason, value=value, value_at_claim=value))
        if claim_text and evidence_items:
            claims.append(ClaimItem(id=claim_id, claim=claim_text, evidence=evidence_items))
    return claims
|
||||
|
||||
async def _answer_followup(
    self,
    question: str,
    state: ConversationState,
    summary: dict[str, Any],
    classify: dict[str, Any],
    mode: str,
) -> str:
    """Answer a follow-up question from previously stored claims and evidence.

    The fast model picks which stored claims the follow-up refers to. Each
    selected claim's evidence is re-resolved against ``summary``; the stored
    ``value`` is updated in place and a ``(now ...)`` note is added when the
    value differs from the one recorded at claim time.

    Args:
        question: The follow-up question.
        state: Stored conversation state holding the prior claims.
        summary: Snapshot summary used to re-resolve evidence paths.
        classify: Classification result; ``follow_up_kind`` selects the prompt.
        mode: Answer mode, used to choose the synthesis model.

    Returns:
        The model's reply text.
    """
    follow_kind = classify.get("follow_up_kind") if isinstance(classify, dict) else "other"
    wanted = set(await self._select_claims(question, state.claims))
    selected = [claim for claim in state.claims if claim.id in wanted]
    if not selected:
        # Either nothing was selected or the model returned ids that match no
        # stored claim; use the leading claims rather than answer with no evidence.
        selected = state.claims[:2]

    evidence_lines = []
    for claim in selected:
        evidence_lines.append(f"Claim: {claim.claim}")
        for ev in claim.evidence:
            current = _resolve_path(summary, ev.path)
            ev.value = current
            delta_note = ""
            if ev.value_at_claim is not None and current is not None and current != ev.value_at_claim:
                delta_note = f" (now {current})"
            evidence_lines.append(f"- {ev.path}: {ev.value_at_claim}{delta_note} {('- ' + ev.reason) if ev.reason else ''}")
    evidence_ctx = "\n".join(evidence_lines)

    # Action-oriented follow-ups may suggest steps; everything else stays evidence-only.
    prompt = prompts.FOLLOWUP_EVIDENCE_PROMPT
    if follow_kind in {"next_steps", "change"}:
        prompt = prompts.FOLLOWUP_ACTION_PROMPT
    prompt = prompt + "\nFollow-up: " + question + "\nEvidence:\n" + evidence_ctx
    messages = build_messages(prompts.CLUSTER_SYSTEM, prompt)
    model = _synthesis_model(self._settings, mode)
    return await self._llm.chat(messages, model=model)
|
||||
|
||||
async def _select_claims(self, question: str, claims: list[ClaimItem]) -> list[str]:
    """Ask the fast model which stored claims a follow-up refers to.

    Returns the selected claim ids as strings; an empty list when there are
    no claims or the model output has no usable ``claim_ids`` list.
    """
    if not claims:
        return []

    # Only the id and text are sent; evidence is not needed to pick a claim.
    brief = []
    for item in claims:
        brief.append({"id": item.id, "claim": item.claim})

    prompt = "".join(
        [
            prompts.SELECT_CLAIMS_PROMPT,
            "\nFollow-up: ",
            question,
            "\nClaims: ",
            json.dumps(brief),
        ]
    )
    messages = build_messages(prompts.CLUSTER_SYSTEM, prompt)
    raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
    parsed = _parse_json_block(raw, fallback={})

    picked = parsed.get("claim_ids") if isinstance(parsed, dict) else []
    if not isinstance(picked, list):
        return []
    return [str(entry) for entry in picked if entry]
|
||||
|
||||
def _get_state(self, conversation_id: str | None) -> ConversationState | None:
    """Return the stored state for a conversation, or ``None`` if absent.

    A missing or empty ``conversation_id`` yields ``None`` without touching
    the store.
    """
    if conversation_id:
        # Purge expired entries first so a stale conversation is never returned.
        self._cleanup_state()
        return self._state.get(conversation_id)
    return None
|
||||
|
||||
def _store_state(
    self,
    conversation_id: str,
    claims: list[ClaimItem],
    summary: dict[str, Any],
    snapshot: dict[str, Any] | None,
) -> None:
    """Record the claims for a conversation, replacing any previous state.

    The snapshot itself is kept only when snapshot pinning is enabled in
    settings. Expired conversations are purged after the write.
    """
    identifier = _snapshot_id(summary)
    if self._settings.snapshot_pin_enabled:
        pinned = snapshot
    else:
        pinned = None
    entry = ConversationState(
        updated_at=time.monotonic(),
        claims=claims,
        snapshot_id=identifier,
        snapshot=pinned,
    )
    self._state[conversation_id] = entry
    self._cleanup_state()
|
||||
|
||||
def _cleanup_state(self) -> None:
|
||||
ttl = max(60, self._settings.conversation_ttl_sec)
|
||||
now = time.monotonic()
|
||||
expired = [key for key, state in self._state.items() if now - state.updated_at > ttl]
|
||||
for key in expired:
|
||||
self._state.pop(key, None)
|
||||
|
||||
|
||||
def _join_context(parts: list[str]) -> str:
|
||||
text = "\n".join([p for p in parts if p])
|
||||
@ -227,12 +383,19 @@ def _format_history(history: list[dict[str, str]] | None) -> str:
|
||||
return ""
|
||||
lines = ["Recent conversation:"]
|
||||
for entry in history[-4:]:
|
||||
question = entry.get("q") if isinstance(entry, dict) else None
|
||||
answer = entry.get("a") if isinstance(entry, dict) else None
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
question = entry.get("q")
|
||||
answer = entry.get("a")
|
||||
role = entry.get("role")
|
||||
content = entry.get("content")
|
||||
if question:
|
||||
lines.append(f"Q: {question}")
|
||||
if answer:
|
||||
lines.append(f"A: {answer}")
|
||||
if role and content:
|
||||
prefix = "Q" if role == "user" else "A"
|
||||
lines.append(f"{prefix}: {content}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@ -312,3 +475,44 @@ def _coerce_int(value: Any, default: int) -> int:
|
||||
|
||||
def _default_scores() -> AnswerScores:
    """Build the fallback scores: 60 for each numeric score, medium hallucination risk."""
    fallback = {
        "confidence": 60,
        "relevance": 60,
        "satisfaction": 60,
        "hallucination_risk": "medium",
    }
    return AnswerScores(**fallback)
|
||||
|
||||
|
||||
def _resolve_path(data: Any, path: str) -> Any | None:
|
||||
cursor = data
|
||||
for part in re.split(r"\\.(?![^\\[]*\\])", path):
|
||||
if not part:
|
||||
continue
|
||||
match = re.match(r"^(\\w+)(?:\\[(\\d+)\\])?$", part)
|
||||
if not match:
|
||||
return None
|
||||
key = match.group(1)
|
||||
index = match.group(2)
|
||||
if isinstance(cursor, dict):
|
||||
cursor = cursor.get(key)
|
||||
else:
|
||||
return None
|
||||
if index is not None:
|
||||
try:
|
||||
idx = int(index)
|
||||
if isinstance(cursor, list) and 0 <= idx < len(cursor):
|
||||
cursor = cursor[idx]
|
||||
else:
|
||||
return None
|
||||
except ValueError:
|
||||
return None
|
||||
return cursor
|
||||
|
||||
|
||||
def _snapshot_id(summary: dict[str, Any]) -> str | None:
|
||||
if not summary:
|
||||
return None
|
||||
for key in ("generated_at", "snapshot_ts", "snapshot_id"):
|
||||
value = summary.get(key)
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _json_excerpt(summary: dict[str, Any], max_chars: int = 12000) -> str:
|
||||
raw = json.dumps(summary, ensure_ascii=False)
|
||||
return raw[:max_chars]
|
||||
|
||||
@ -16,7 +16,8 @@ CLASSIFY_PROMPT = (
|
||||
"needs_snapshot (bool), needs_kb (bool), needs_metrics (bool), "
|
||||
"needs_general (bool), intent (short string), ambiguity (0-1), "
|
||||
"answer_style (direct|insightful), topic_summary (short string), "
|
||||
"follow_up (bool), question_type (metric|diagnostic|planning|open_ended)."
|
||||
"follow_up (bool), follow_up_kind (evidence|why|clarify|next_steps|change|other), "
|
||||
"question_type (metric|diagnostic|planning|open_ended)."
|
||||
)
|
||||
|
||||
ANGLE_PROMPT = (
|
||||
@ -54,6 +55,37 @@ REFINE_PROMPT = (
|
||||
"Keep it grounded in the context. If you cannot add insight, say so explicitly."
|
||||
)
|
||||
|
||||
# Asks the model to break an answer into claims, each backed by evidence paths
# into the snapshot summary JSON supplied as context.
CLAIM_MAP_PROMPT = (
    "Extract a claim map from the answer. "
    "Return JSON with fields: claims (list). "
    "Each claim object: id (short string), claim (short sentence), "
    "evidence (list of objects with path and reason). "
    "Paths must point into the provided SnapshotSummary JSON using dot notation, "
    "with list indexes in brackets, e.g. metrics.node_load[0].node. "
    "Do not invent evidence; if no evidence exists, omit the claim."
)

# Asks the model which previously stored claims a follow-up question targets.
# NOTE(review): the prompt also requests follow_up_kind, but the visible caller
# only reads claim_ids from the response — confirm whether the field is needed.
SELECT_CLAIMS_PROMPT = (
    "Pick which prior claim(s) the follow-up refers to. "
    "Return JSON with fields: claim_ids (list of ids), follow_up_kind "
    "(evidence|why|clarify|next_steps|change|other). "
    "If none apply, return an empty list."
)

# Default follow-up prompt: answer strictly from the listed claims and evidence.
FOLLOWUP_EVIDENCE_PROMPT = (
    "Answer the follow-up using only the provided claims and evidence. "
    "Be conversational, not bullet-heavy. "
    "If evidence does not support a claim, say so plainly. "
    "Do not add new claims."
)

# Follow-up prompt for "next_steps" / "change" follow-ups: suggestions are
# allowed but must stay tied to the evidence.
FOLLOWUP_ACTION_PROMPT = (
    "Answer the follow-up using the provided claims and evidence. "
    "You may suggest next steps or changes, but keep them tightly tied "
    "to the evidence list. "
    "Be conversational and concise."
)
|
||||
|
||||
STOCK_SYSTEM = (
|
||||
"You are Atlas, a helpful assistant. "
|
||||
"Be concise and truthful. "
|
||||
|
||||
@ -31,22 +31,24 @@ async def main() -> None:
|
||||
|
||||
async def handler(payload: dict[str, object]) -> dict[str, object]:
|
||||
history = payload.get("history") if isinstance(payload, dict) else None
|
||||
conversation_id = payload.get("conversation_id") if isinstance(payload, dict) else None
|
||||
result = await engine.answer(
|
||||
str(payload.get("question", "") or ""),
|
||||
mode=str(payload.get("mode", "quick") or "quick"),
|
||||
history=history if isinstance(history, list) else None,
|
||||
conversation_id=str(conversation_id) if isinstance(conversation_id, str) else None,
|
||||
)
|
||||
return {"reply": result.reply, "scores": result.scores.__dict__}
|
||||
|
||||
queue = QueueManager(settings, handler)
|
||||
await queue.start()
|
||||
|
||||
async def answer_handler(question: str, mode: str, history=None, observer=None) -> AnswerResult:
|
||||
async def answer_handler(question: str, mode: str, history=None, conversation_id=None, observer=None) -> AnswerResult:
|
||||
if settings.queue_enabled:
|
||||
payload = await queue.submit({"question": question, "mode": mode, "history": history or []})
|
||||
payload = await queue.submit({"question": question, "mode": mode, "history": history or [], "conversation_id": conversation_id})
|
||||
reply = payload.get("reply", "") if isinstance(payload, dict) else ""
|
||||
return AnswerResult(reply=reply or "", scores=result_scores(payload), meta={"mode": mode})
|
||||
return await engine.answer(question, mode=mode, history=history, observer=observer)
|
||||
return await engine.answer(question, mode=mode, history=history, observer=observer, conversation_id=conversation_id)
|
||||
|
||||
api = Api(settings, answer_handler)
|
||||
server = uvicorn.Server(uvicorn.Config(api.app, host="0.0.0.0", port=settings.http_port, log_level="info"))
|
||||
|
||||
@ -80,7 +80,10 @@ class MatrixBot:
|
||||
settings: Settings,
|
||||
bot: MatrixBotConfig,
|
||||
engine: AnswerEngine,
|
||||
answer_handler: Callable[[str, str, list[dict[str, str]] | None, Callable[[str, str], None] | None], Awaitable[AnswerResult]]
|
||||
answer_handler: Callable[
|
||||
[str, str, list[dict[str, str]] | None, str | None, Callable[[str, str], None] | None],
|
||||
Awaitable[AnswerResult],
|
||||
]
|
||||
| None = None,
|
||||
) -> None:
|
||||
self._settings = settings
|
||||
@ -155,9 +158,11 @@ class MatrixBot:
|
||||
task = asyncio.create_task(heartbeat())
|
||||
started = time.monotonic()
|
||||
try:
|
||||
handler = self._answer_handler or (lambda q, m, h, obs: self._engine.answer(q, mode=m, history=h, observer=obs))
|
||||
handler = self._answer_handler or (
|
||||
lambda q, m, h, cid, obs: self._engine.answer(q, mode=m, history=h, observer=obs, conversation_id=cid)
|
||||
)
|
||||
history = self._history.get(room_id, [])
|
||||
result = await handler(question, mode, history, observer)
|
||||
result = await handler(question, mode, history, room_id, observer)
|
||||
elapsed = time.monotonic() - started
|
||||
await self._client.send_message(token, room_id, result.reply)
|
||||
log.info(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user