feat: allow test snapshot pin and enforce node whitelist
This commit is contained in:
parent
400be36093
commit
ea11460c92
@ -21,6 +21,7 @@ class AnswerRequest(BaseModel):
|
||||
mode: str | None = None
|
||||
history: list[dict[str, str]] | None = None
|
||||
conversation_id: str | None = None
|
||||
snapshot_pin: bool | None = None
|
||||
|
||||
|
||||
class AnswerResponse(BaseModel):
|
||||
@ -31,7 +32,7 @@ class Api:
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings,
|
||||
answer_handler: Callable[[str, str, list[dict[str, str]] | None, str | None], Awaitable[AnswerResult]],
|
||||
answer_handler: Callable[[str, str, list[dict[str, str]] | None, str | None, bool | None], Awaitable[AnswerResult]],
|
||||
) -> None:
|
||||
self._settings = settings
|
||||
self._answer_handler = answer_handler
|
||||
@ -55,7 +56,7 @@ class Api:
|
||||
raise HTTPException(status_code=400, detail="missing question")
|
||||
mode = (payload.mode or "quick").strip().lower()
|
||||
conversation_id = payload.conversation_id
|
||||
result = await self._answer_handler(question, mode, payload.history, conversation_id)
|
||||
result = await self._answer_handler(question, mode, payload.history, conversation_id, payload.snapshot_pin)
|
||||
log.info(
|
||||
"answer",
|
||||
extra={
|
||||
|
||||
@ -97,6 +97,7 @@ class AnswerEngine:
|
||||
history: list[dict[str, str]] | None = None,
|
||||
observer: Callable[[str, str], None] | None = None,
|
||||
conversation_id: str | None = None,
|
||||
snapshot_pin: bool | None = None,
|
||||
) -> AnswerResult:
|
||||
question = (question or "").strip()
|
||||
if not question:
|
||||
@ -148,9 +149,10 @@ class AnswerEngine:
|
||||
return response
|
||||
|
||||
state = self._get_state(conversation_id)
|
||||
pin_snapshot = bool(snapshot_pin) or self._settings.snapshot_pin_enabled
|
||||
snapshot = self._snapshot.get()
|
||||
snapshot_used = snapshot
|
||||
if self._settings.snapshot_pin_enabled and state and state.snapshot:
|
||||
if pin_snapshot and state and state.snapshot:
|
||||
snapshot_used = state.snapshot
|
||||
summary = build_summary(snapshot_used)
|
||||
allowed_nodes = _allowed_nodes(summary)
|
||||
@ -488,6 +490,11 @@ class AnswerEngine:
|
||||
model=plan.model,
|
||||
tag="evidence_fix",
|
||||
)
|
||||
if unknown_nodes or unknown_namespaces:
|
||||
refreshed_nodes = _find_unknown_nodes(reply, allowed_nodes)
|
||||
refreshed_namespaces = _find_unknown_namespaces(reply, allowed_namespaces)
|
||||
if refreshed_nodes or refreshed_namespaces:
|
||||
reply = _strip_unknown_entities(reply, refreshed_nodes, refreshed_namespaces)
|
||||
if runbook_paths and resolved_runbook and _needs_runbook_reference(normalized, runbook_paths, reply):
|
||||
if observer:
|
||||
observer("runbook_enforce", "enforcing runbook path")
|
||||
@ -856,7 +863,7 @@ class AnswerEngine:
|
||||
)
|
||||
|
||||
if conversation_id and claims:
|
||||
self._store_state(conversation_id, claims, summary, snapshot_used)
|
||||
self._store_state(conversation_id, claims, summary, snapshot_used, pin_snapshot)
|
||||
|
||||
meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started)
|
||||
return AnswerResult(reply, scores, meta)
|
||||
@ -1086,9 +1093,10 @@ class AnswerEngine:
|
||||
claims: list[ClaimItem],
|
||||
summary: dict[str, Any],
|
||||
snapshot: dict[str, Any] | None,
|
||||
pin_snapshot: bool,
|
||||
) -> None:
|
||||
snapshot_id = _snapshot_id(summary)
|
||||
pinned_snapshot = snapshot if self._settings.snapshot_pin_enabled else None
|
||||
pinned_snapshot = snapshot if pin_snapshot else None
|
||||
payload = {
|
||||
"updated_at": time.monotonic(),
|
||||
"claims": _claims_to_payload(claims),
|
||||
|
||||
@ -32,23 +32,47 @@ async def main() -> None:
|
||||
async def handler(payload: dict[str, object]) -> dict[str, object]:
|
||||
history = payload.get("history") if isinstance(payload, dict) else None
|
||||
conversation_id = payload.get("conversation_id") if isinstance(payload, dict) else None
|
||||
snapshot_pin = payload.get("snapshot_pin") if isinstance(payload, dict) else None
|
||||
result = await engine.answer(
|
||||
str(payload.get("question", "") or ""),
|
||||
mode=str(payload.get("mode", "quick") or "quick"),
|
||||
history=history if isinstance(history, list) else None,
|
||||
conversation_id=str(conversation_id) if isinstance(conversation_id, str) else None,
|
||||
snapshot_pin=bool(snapshot_pin) if isinstance(snapshot_pin, bool) else None,
|
||||
)
|
||||
return {"reply": result.reply, "scores": result.scores.__dict__}
|
||||
|
||||
queue = QueueManager(settings, handler)
|
||||
await queue.start()
|
||||
|
||||
async def answer_handler(question: str, mode: str, history=None, conversation_id=None, observer=None) -> AnswerResult:
|
||||
async def answer_handler(
|
||||
question: str,
|
||||
mode: str,
|
||||
history=None,
|
||||
conversation_id=None,
|
||||
snapshot_pin: bool | None = None,
|
||||
observer=None,
|
||||
) -> AnswerResult:
|
||||
if settings.queue_enabled:
|
||||
payload = await queue.submit({"question": question, "mode": mode, "history": history or [], "conversation_id": conversation_id})
|
||||
payload = await queue.submit(
|
||||
{
|
||||
"question": question,
|
||||
"mode": mode,
|
||||
"history": history or [],
|
||||
"conversation_id": conversation_id,
|
||||
"snapshot_pin": snapshot_pin,
|
||||
}
|
||||
)
|
||||
reply = payload.get("reply", "") if isinstance(payload, dict) else ""
|
||||
return AnswerResult(reply=reply or "", scores=result_scores(payload), meta={"mode": mode})
|
||||
return await engine.answer(question, mode=mode, history=history, observer=observer, conversation_id=conversation_id)
|
||||
return await engine.answer(
|
||||
question,
|
||||
mode=mode,
|
||||
history=history,
|
||||
observer=observer,
|
||||
conversation_id=conversation_id,
|
||||
snapshot_pin=snapshot_pin,
|
||||
)
|
||||
|
||||
api = Api(settings, answer_handler)
|
||||
server = uvicorn.Server(uvicorn.Config(api.app, host="0.0.0.0", port=settings.http_port, log_level="info"))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user