feat: allow test snapshot pin and enforce node whitelist

This commit is contained in:
Brad Stein 2026-02-02 16:52:33 -03:00
parent 400be36093
commit ea11460c92
3 changed files with 41 additions and 8 deletions

View File

@ -21,6 +21,7 @@ class AnswerRequest(BaseModel):
mode: str | None = None mode: str | None = None
history: list[dict[str, str]] | None = None history: list[dict[str, str]] | None = None
conversation_id: str | None = None conversation_id: str | None = None
snapshot_pin: bool | None = None
class AnswerResponse(BaseModel): class AnswerResponse(BaseModel):
@ -31,7 +32,7 @@ class Api:
def __init__( def __init__(
self, self,
settings: Settings, settings: Settings,
answer_handler: Callable[[str, str, list[dict[str, str]] | None, str | None], Awaitable[AnswerResult]], answer_handler: Callable[[str, str, list[dict[str, str]] | None, str | None, bool | None], Awaitable[AnswerResult]],
) -> None: ) -> None:
self._settings = settings self._settings = settings
self._answer_handler = answer_handler self._answer_handler = answer_handler
@ -55,7 +56,7 @@ class Api:
raise HTTPException(status_code=400, detail="missing question") raise HTTPException(status_code=400, detail="missing question")
mode = (payload.mode or "quick").strip().lower() mode = (payload.mode or "quick").strip().lower()
conversation_id = payload.conversation_id conversation_id = payload.conversation_id
result = await self._answer_handler(question, mode, payload.history, conversation_id) result = await self._answer_handler(question, mode, payload.history, conversation_id, payload.snapshot_pin)
log.info( log.info(
"answer", "answer",
extra={ extra={

View File

@ -97,6 +97,7 @@ class AnswerEngine:
history: list[dict[str, str]] | None = None, history: list[dict[str, str]] | None = None,
observer: Callable[[str, str], None] | None = None, observer: Callable[[str, str], None] | None = None,
conversation_id: str | None = None, conversation_id: str | None = None,
snapshot_pin: bool | None = None,
) -> AnswerResult: ) -> AnswerResult:
question = (question or "").strip() question = (question or "").strip()
if not question: if not question:
@ -148,9 +149,10 @@ class AnswerEngine:
return response return response
state = self._get_state(conversation_id) state = self._get_state(conversation_id)
pin_snapshot = bool(snapshot_pin) or self._settings.snapshot_pin_enabled
snapshot = self._snapshot.get() snapshot = self._snapshot.get()
snapshot_used = snapshot snapshot_used = snapshot
if self._settings.snapshot_pin_enabled and state and state.snapshot: if pin_snapshot and state and state.snapshot:
snapshot_used = state.snapshot snapshot_used = state.snapshot
summary = build_summary(snapshot_used) summary = build_summary(snapshot_used)
allowed_nodes = _allowed_nodes(summary) allowed_nodes = _allowed_nodes(summary)
@ -488,6 +490,11 @@ class AnswerEngine:
model=plan.model, model=plan.model,
tag="evidence_fix", tag="evidence_fix",
) )
if unknown_nodes or unknown_namespaces:
refreshed_nodes = _find_unknown_nodes(reply, allowed_nodes)
refreshed_namespaces = _find_unknown_namespaces(reply, allowed_namespaces)
if refreshed_nodes or refreshed_namespaces:
reply = _strip_unknown_entities(reply, refreshed_nodes, refreshed_namespaces)
if runbook_paths and resolved_runbook and _needs_runbook_reference(normalized, runbook_paths, reply): if runbook_paths and resolved_runbook and _needs_runbook_reference(normalized, runbook_paths, reply):
if observer: if observer:
observer("runbook_enforce", "enforcing runbook path") observer("runbook_enforce", "enforcing runbook path")
@ -856,7 +863,7 @@ class AnswerEngine:
) )
if conversation_id and claims: if conversation_id and claims:
self._store_state(conversation_id, claims, summary, snapshot_used) self._store_state(conversation_id, claims, summary, snapshot_used, pin_snapshot)
meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started) meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started)
return AnswerResult(reply, scores, meta) return AnswerResult(reply, scores, meta)
@ -1086,9 +1093,10 @@ class AnswerEngine:
claims: list[ClaimItem], claims: list[ClaimItem],
summary: dict[str, Any], summary: dict[str, Any],
snapshot: dict[str, Any] | None, snapshot: dict[str, Any] | None,
pin_snapshot: bool,
) -> None: ) -> None:
snapshot_id = _snapshot_id(summary) snapshot_id = _snapshot_id(summary)
pinned_snapshot = snapshot if self._settings.snapshot_pin_enabled else None pinned_snapshot = snapshot if pin_snapshot else None
payload = { payload = {
"updated_at": time.monotonic(), "updated_at": time.monotonic(),
"claims": _claims_to_payload(claims), "claims": _claims_to_payload(claims),

View File

@ -32,23 +32,47 @@ async def main() -> None:
async def handler(payload: dict[str, object]) -> dict[str, object]: async def handler(payload: dict[str, object]) -> dict[str, object]:
history = payload.get("history") if isinstance(payload, dict) else None history = payload.get("history") if isinstance(payload, dict) else None
conversation_id = payload.get("conversation_id") if isinstance(payload, dict) else None conversation_id = payload.get("conversation_id") if isinstance(payload, dict) else None
snapshot_pin = payload.get("snapshot_pin") if isinstance(payload, dict) else None
result = await engine.answer( result = await engine.answer(
str(payload.get("question", "") or ""), str(payload.get("question", "") or ""),
mode=str(payload.get("mode", "quick") or "quick"), mode=str(payload.get("mode", "quick") or "quick"),
history=history if isinstance(history, list) else None, history=history if isinstance(history, list) else None,
conversation_id=str(conversation_id) if isinstance(conversation_id, str) else None, conversation_id=str(conversation_id) if isinstance(conversation_id, str) else None,
snapshot_pin=bool(snapshot_pin) if isinstance(snapshot_pin, bool) else None,
) )
return {"reply": result.reply, "scores": result.scores.__dict__} return {"reply": result.reply, "scores": result.scores.__dict__}
queue = QueueManager(settings, handler) queue = QueueManager(settings, handler)
await queue.start() await queue.start()
async def answer_handler(question: str, mode: str, history=None, conversation_id=None, observer=None) -> AnswerResult: async def answer_handler(
question: str,
mode: str,
history=None,
conversation_id=None,
snapshot_pin: bool | None = None,
observer=None,
) -> AnswerResult:
if settings.queue_enabled: if settings.queue_enabled:
payload = await queue.submit({"question": question, "mode": mode, "history": history or [], "conversation_id": conversation_id}) payload = await queue.submit(
{
"question": question,
"mode": mode,
"history": history or [],
"conversation_id": conversation_id,
"snapshot_pin": snapshot_pin,
}
)
reply = payload.get("reply", "") if isinstance(payload, dict) else "" reply = payload.get("reply", "") if isinstance(payload, dict) else ""
return AnswerResult(reply=reply or "", scores=result_scores(payload), meta={"mode": mode}) return AnswerResult(reply=reply or "", scores=result_scores(payload), meta={"mode": mode})
return await engine.answer(question, mode=mode, history=history, observer=observer, conversation_id=conversation_id) return await engine.answer(
question,
mode=mode,
history=history,
observer=observer,
conversation_id=conversation_id,
snapshot_pin=snapshot_pin,
)
api = Api(settings, answer_handler) api = Api(settings, answer_handler)
server = uvicorn.Server(uvicorn.Config(api.app, host="0.0.0.0", port=settings.http_port, log_level="info")) server = uvicorn.Server(uvicorn.Config(api.app, host="0.0.0.0", port=settings.http_port, log_level="info"))