feat: allow test snapshot pin and enforce node whitelist

This commit is contained in:
Brad Stein 2026-02-02 16:52:33 -03:00
parent 400be36093
commit ea11460c92
3 changed files with 41 additions and 8 deletions

View File

@ -21,6 +21,7 @@ class AnswerRequest(BaseModel):
mode: str | None = None
history: list[dict[str, str]] | None = None
conversation_id: str | None = None
snapshot_pin: bool | None = None
class AnswerResponse(BaseModel):
@ -31,7 +32,7 @@ class Api:
def __init__(
self,
settings: Settings,
answer_handler: Callable[[str, str, list[dict[str, str]] | None, str | None], Awaitable[AnswerResult]],
answer_handler: Callable[[str, str, list[dict[str, str]] | None, str | None, bool | None], Awaitable[AnswerResult]],
) -> None:
self._settings = settings
self._answer_handler = answer_handler
@ -55,7 +56,7 @@ class Api:
raise HTTPException(status_code=400, detail="missing question")
mode = (payload.mode or "quick").strip().lower()
conversation_id = payload.conversation_id
result = await self._answer_handler(question, mode, payload.history, conversation_id)
result = await self._answer_handler(question, mode, payload.history, conversation_id, payload.snapshot_pin)
log.info(
"answer",
extra={

View File

@ -97,6 +97,7 @@ class AnswerEngine:
history: list[dict[str, str]] | None = None,
observer: Callable[[str, str], None] | None = None,
conversation_id: str | None = None,
snapshot_pin: bool | None = None,
) -> AnswerResult:
question = (question or "").strip()
if not question:
@ -148,9 +149,10 @@ class AnswerEngine:
return response
state = self._get_state(conversation_id)
pin_snapshot = bool(snapshot_pin) or self._settings.snapshot_pin_enabled
snapshot = self._snapshot.get()
snapshot_used = snapshot
if self._settings.snapshot_pin_enabled and state and state.snapshot:
if pin_snapshot and state and state.snapshot:
snapshot_used = state.snapshot
summary = build_summary(snapshot_used)
allowed_nodes = _allowed_nodes(summary)
@ -488,6 +490,11 @@ class AnswerEngine:
model=plan.model,
tag="evidence_fix",
)
if unknown_nodes or unknown_namespaces:
refreshed_nodes = _find_unknown_nodes(reply, allowed_nodes)
refreshed_namespaces = _find_unknown_namespaces(reply, allowed_namespaces)
if refreshed_nodes or refreshed_namespaces:
reply = _strip_unknown_entities(reply, refreshed_nodes, refreshed_namespaces)
if runbook_paths and resolved_runbook and _needs_runbook_reference(normalized, runbook_paths, reply):
if observer:
observer("runbook_enforce", "enforcing runbook path")
@ -856,7 +863,7 @@ class AnswerEngine:
)
if conversation_id and claims:
self._store_state(conversation_id, claims, summary, snapshot_used)
self._store_state(conversation_id, claims, summary, snapshot_used, pin_snapshot)
meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started)
return AnswerResult(reply, scores, meta)
@ -1086,9 +1093,10 @@ class AnswerEngine:
claims: list[ClaimItem],
summary: dict[str, Any],
snapshot: dict[str, Any] | None,
pin_snapshot: bool,
) -> None:
snapshot_id = _snapshot_id(summary)
pinned_snapshot = snapshot if self._settings.snapshot_pin_enabled else None
pinned_snapshot = snapshot if pin_snapshot else None
payload = {
"updated_at": time.monotonic(),
"claims": _claims_to_payload(claims),

View File

@ -32,23 +32,47 @@ async def main() -> None:
async def handler(payload: dict[str, object]) -> dict[str, object]:
history = payload.get("history") if isinstance(payload, dict) else None
conversation_id = payload.get("conversation_id") if isinstance(payload, dict) else None
snapshot_pin = payload.get("snapshot_pin") if isinstance(payload, dict) else None
result = await engine.answer(
str(payload.get("question", "") or ""),
mode=str(payload.get("mode", "quick") or "quick"),
history=history if isinstance(history, list) else None,
conversation_id=str(conversation_id) if isinstance(conversation_id, str) else None,
snapshot_pin=bool(snapshot_pin) if isinstance(snapshot_pin, bool) else None,
)
return {"reply": result.reply, "scores": result.scores.__dict__}
queue = QueueManager(settings, handler)
await queue.start()
async def answer_handler(question: str, mode: str, history=None, conversation_id=None, observer=None) -> AnswerResult:
async def answer_handler(
question: str,
mode: str,
history=None,
conversation_id=None,
snapshot_pin: bool | None = None,
observer=None,
) -> AnswerResult:
if settings.queue_enabled:
payload = await queue.submit({"question": question, "mode": mode, "history": history or [], "conversation_id": conversation_id})
payload = await queue.submit(
{
"question": question,
"mode": mode,
"history": history or [],
"conversation_id": conversation_id,
"snapshot_pin": snapshot_pin,
}
)
reply = payload.get("reply", "") if isinstance(payload, dict) else ""
return AnswerResult(reply=reply or "", scores=result_scores(payload), meta={"mode": mode})
return await engine.answer(question, mode=mode, history=history, observer=observer, conversation_id=conversation_id)
return await engine.answer(
question,
mode=mode,
history=history,
observer=observer,
conversation_id=conversation_id,
snapshot_pin=snapshot_pin,
)
api = Api(settings, answer_handler)
server = uvicorn.Server(uvicorn.Config(api.app, host="0.0.0.0", port=settings.http_port, log_level="info"))