diff --git a/atlasbot/api/http.py b/atlasbot/api/http.py index c164db1..49de446 100644 --- a/atlasbot/api/http.py +++ b/atlasbot/api/http.py @@ -21,6 +21,7 @@ class AnswerRequest(BaseModel): mode: str | None = None history: list[dict[str, str]] | None = None conversation_id: str | None = None + snapshot_pin: bool | None = None class AnswerResponse(BaseModel): @@ -31,7 +32,7 @@ class Api: def __init__( self, settings: Settings, - answer_handler: Callable[[str, str, list[dict[str, str]] | None, str | None], Awaitable[AnswerResult]], + answer_handler: Callable[[str, str, list[dict[str, str]] | None, str | None, bool | None], Awaitable[AnswerResult]], ) -> None: self._settings = settings self._answer_handler = answer_handler @@ -55,7 +56,7 @@ class Api: raise HTTPException(status_code=400, detail="missing question") mode = (payload.mode or "quick").strip().lower() conversation_id = payload.conversation_id - result = await self._answer_handler(question, mode, payload.history, conversation_id) + result = await self._answer_handler(question, mode, payload.history, conversation_id, payload.snapshot_pin) log.info( "answer", extra={ diff --git a/atlasbot/engine/answerer.py b/atlasbot/engine/answerer.py index 3fecd82..8062aa2 100644 --- a/atlasbot/engine/answerer.py +++ b/atlasbot/engine/answerer.py @@ -97,6 +97,7 @@ class AnswerEngine: history: list[dict[str, str]] | None = None, observer: Callable[[str, str], None] | None = None, conversation_id: str | None = None, + snapshot_pin: bool | None = None, ) -> AnswerResult: question = (question or "").strip() if not question: @@ -148,9 +149,10 @@ class AnswerEngine: return response state = self._get_state(conversation_id) + pin_snapshot = bool(snapshot_pin) or self._settings.snapshot_pin_enabled snapshot = self._snapshot.get() snapshot_used = snapshot - if self._settings.snapshot_pin_enabled and state and state.snapshot: + if pin_snapshot and state and state.snapshot: snapshot_used = state.snapshot summary = build_summary(snapshot_used) allowed_nodes = _allowed_nodes(summary) @@ -488,6 +490,11 @@ class AnswerEngine: model=plan.model, tag="evidence_fix", ) + if unknown_nodes or unknown_namespaces: + refreshed_nodes = _find_unknown_nodes(reply, allowed_nodes) + refreshed_namespaces = _find_unknown_namespaces(reply, allowed_namespaces) + if refreshed_nodes or refreshed_namespaces: + reply = _strip_unknown_entities(reply, refreshed_nodes, refreshed_namespaces) if runbook_paths and resolved_runbook and _needs_runbook_reference(normalized, runbook_paths, reply): if observer: observer("runbook_enforce", "enforcing runbook path") @@ -856,7 +863,7 @@ class AnswerEngine: ) if conversation_id and claims: - self._store_state(conversation_id, claims, summary, snapshot_used) + self._store_state(conversation_id, claims, summary, snapshot_used, pin_snapshot) meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started) return AnswerResult(reply, scores, meta) @@ -1086,9 +1093,10 @@ class AnswerEngine: claims: list[ClaimItem], summary: dict[str, Any], snapshot: dict[str, Any] | None, + pin_snapshot: bool, ) -> None: snapshot_id = _snapshot_id(summary) - pinned_snapshot = snapshot if self._settings.snapshot_pin_enabled else None + pinned_snapshot = snapshot if pin_snapshot else None payload = { "updated_at": time.monotonic(), "claims": _claims_to_payload(claims), diff --git a/atlasbot/main.py b/atlasbot/main.py index 9316356..c05961e 100644 --- a/atlasbot/main.py +++ b/atlasbot/main.py @@ -32,23 +32,47 @@ async def main() -> None: async def handler(payload: dict[str, object]) -> dict[str, object]: history = payload.get("history") if isinstance(payload, dict) else None conversation_id = payload.get("conversation_id") if isinstance(payload, dict) else None + snapshot_pin = payload.get("snapshot_pin") if isinstance(payload, dict) else None result = await engine.answer( str(payload.get("question", "") or ""), mode=str(payload.get("mode", "quick") or "quick"), history=history if isinstance(history, list) else None, conversation_id=str(conversation_id) if isinstance(conversation_id, str) else None, + snapshot_pin=bool(snapshot_pin) if isinstance(snapshot_pin, bool) else None, ) return {"reply": result.reply, "scores": result.scores.__dict__} queue = QueueManager(settings, handler) await queue.start() - async def answer_handler(question: str, mode: str, history=None, conversation_id=None, observer=None) -> AnswerResult: + async def answer_handler( + question: str, + mode: str, + history=None, + conversation_id=None, + snapshot_pin: bool | None = None, + observer=None, + ) -> AnswerResult: if settings.queue_enabled: - payload = await queue.submit({"question": question, "mode": mode, "history": history or [], "conversation_id": conversation_id}) + payload = await queue.submit( + { + "question": question, + "mode": mode, + "history": history or [], + "conversation_id": conversation_id, + "snapshot_pin": snapshot_pin, + } + ) reply = payload.get("reply", "") if isinstance(payload, dict) else "" return AnswerResult(reply=reply or "", scores=result_scores(payload), meta={"mode": mode}) - return await engine.answer(question, mode=mode, history=history, observer=observer, conversation_id=conversation_id) + return await engine.answer( + question, + mode=mode, + history=history, + observer=observer, + conversation_id=conversation_id, + snapshot_pin=snapshot_pin, + ) api = Api(settings, answer_handler) server = uvicorn.Server(uvicorn.Config(api.app, host="0.0.0.0", port=settings.http_port, log_level="info"))