atlasbot: enforce mode time budgets and quick guardrails
This commit is contained in:
parent
c61f0a6847
commit
2cb601a614
@ -53,6 +53,9 @@ class Settings:
|
||||
|
||||
snapshot_ttl_sec: int
|
||||
thinking_interval_sec: int
|
||||
quick_time_budget_sec: float
|
||||
smart_time_budget_sec: float
|
||||
genius_time_budget_sec: float
|
||||
conversation_ttl_sec: int
|
||||
snapshot_pin_enabled: bool
|
||||
|
||||
@ -153,6 +156,9 @@ def load_settings() -> Settings:
|
||||
ariadne_state_token=os.getenv("ARIADNE_STATE_TOKEN", ""),
|
||||
snapshot_ttl_sec=_env_int("ATLASBOT_SNAPSHOT_TTL_SEC", "30"),
|
||||
thinking_interval_sec=_env_int("ATLASBOT_THINKING_INTERVAL_SEC", "30"),
|
||||
quick_time_budget_sec=_env_float("ATLASBOT_QUICK_TIME_BUDGET_SEC", "15"),
|
||||
smart_time_budget_sec=_env_float("ATLASBOT_SMART_TIME_BUDGET_SEC", "45"),
|
||||
genius_time_budget_sec=_env_float("ATLASBOT_GENIUS_TIME_BUDGET_SEC", "180"),
|
||||
conversation_ttl_sec=_env_int("ATLASBOT_CONVERSATION_TTL_SEC", "900"),
|
||||
snapshot_pin_enabled=_env_bool("ATLASBOT_SNAPSHOT_PIN_ENABLED", "false"),
|
||||
state_db_path=os.getenv("ATLASBOT_STATE_DB", "/data/atlasbot_state.db"),
|
||||
|
||||
@ -2,7 +2,6 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
import difflib
|
||||
@ -34,6 +33,10 @@ class LLMLimitReached(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class LLMTimeBudgetExceeded(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnswerScores:
|
||||
confidence: int
|
||||
@ -163,6 +166,7 @@ class AnswerEngine:
|
||||
call_cap = math.ceil(call_limit * self._settings.llm_limit_multiplier)
|
||||
call_count = 0
|
||||
limit_hit = False
|
||||
time_budget_hit = False
|
||||
|
||||
debug_tags = {
|
||||
"route",
|
||||
@ -177,6 +181,8 @@ class AnswerEngine:
|
||||
"select_claims",
|
||||
"evidence_fix",
|
||||
}
|
||||
started = time.monotonic()
|
||||
time_budget_sec = _mode_time_budget(self._settings, mode) if not limitless else 0.0
|
||||
|
||||
def _debug_log(name: str, payload: Any) -> None:
|
||||
if not self._settings.debug_pipeline:
|
||||
@ -184,13 +190,20 @@ class AnswerEngine:
|
||||
log.info("atlasbot_debug", extra={"extra": {"name": name, "payload": payload}})
|
||||
|
||||
async def call_llm(system: str, prompt: str, *, context: str | None = None, model: str | None = None, tag: str = "") -> str:
|
||||
nonlocal call_count, limit_hit
|
||||
nonlocal call_count, limit_hit, time_budget_hit
|
||||
if not limitless and call_count >= call_cap:
|
||||
limit_hit = True
|
||||
raise LLMLimitReached("llm_limit")
|
||||
timeout_sec = None
|
||||
if not limitless and time_budget_sec > 0:
|
||||
time_left = time_budget_sec - (time.monotonic() - started)
|
||||
if time_left <= 0:
|
||||
time_budget_hit = True
|
||||
raise LLMTimeBudgetExceeded("time_budget")
|
||||
timeout_sec = min(self._settings.ollama_timeout_sec, time_left)
|
||||
call_count += 1
|
||||
messages = build_messages(system, prompt, context=context)
|
||||
response = await self._llm.chat(messages, model=model or plan.model)
|
||||
response = await self._llm.chat(messages, model=model or plan.model, timeout_sec=timeout_sec)
|
||||
log.info(
|
||||
"atlasbot_llm_call",
|
||||
extra={"extra": {"mode": mode, "tag": tag, "call": call_count, "limit": call_cap}},
|
||||
@ -221,7 +234,6 @@ class AnswerEngine:
|
||||
metric_facts: list[str] = []
|
||||
facts_used: list[str] = []
|
||||
|
||||
started = time.monotonic()
|
||||
reply = ""
|
||||
scores = _default_scores()
|
||||
claims: list[ClaimItem] = []
|
||||
@ -283,7 +295,17 @@ class AnswerEngine:
|
||||
metric_facts = _merge_fact_lines([spine_line], metric_facts)
|
||||
if spine_answer and mode in {"fast", "quick"}:
|
||||
scores = _default_scores()
|
||||
meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started)
|
||||
meta = _build_meta(
|
||||
mode,
|
||||
call_count,
|
||||
call_cap,
|
||||
limit_hit,
|
||||
time_budget_hit,
|
||||
time_budget_sec,
|
||||
classify,
|
||||
tool_hint,
|
||||
started,
|
||||
)
|
||||
return AnswerResult(spine_answer, scores, meta)
|
||||
cluster_terms = (
|
||||
"atlas",
|
||||
@ -353,7 +375,17 @@ class AnswerEngine:
|
||||
observer("followup", "answering follow-up")
|
||||
reply = await self._answer_followup(question, state, summary, classify, plan, call_llm)
|
||||
scores = await self._score_answer(question, reply, plan, call_llm)
|
||||
meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started)
|
||||
meta = _build_meta(
|
||||
mode,
|
||||
call_count,
|
||||
call_cap,
|
||||
limit_hit,
|
||||
time_budget_hit,
|
||||
time_budget_sec,
|
||||
classify,
|
||||
tool_hint,
|
||||
started,
|
||||
)
|
||||
return AnswerResult(reply, scores, meta)
|
||||
|
||||
if observer:
|
||||
@ -879,6 +911,25 @@ class AnswerEngine:
|
||||
|
||||
scores = await self._score_answer(normalized, reply, plan, call_llm)
|
||||
claims = await self._extract_claims(normalized, reply, summary, facts_used, call_llm)
|
||||
except LLMTimeBudgetExceeded:
|
||||
time_budget_hit = True
|
||||
if not reply:
|
||||
budget = max(1, int(round(time_budget_sec))) if time_budget_sec > 0 else 0
|
||||
if mode in {"quick", "fast"}:
|
||||
budget_text = f"{budget}s" if budget else "its configured"
|
||||
reply = (
|
||||
f"Quick mode hit {budget_text} time budget before finishing. "
|
||||
"Try atlas-smart for a deeper answer."
|
||||
)
|
||||
elif mode == "smart":
|
||||
budget_text = f"{budget}s" if budget else "its configured"
|
||||
reply = (
|
||||
f"Smart mode hit {budget_text} time budget before finishing. "
|
||||
"Try atlas-genius or ask a narrower follow-up."
|
||||
)
|
||||
else:
|
||||
reply = "I ran out of time before I could finish this answer."
|
||||
scores = _default_scores()
|
||||
except LLMLimitReached:
|
||||
if not reply:
|
||||
reply = "I started working on this but hit my reasoning limit. Ask again with 'Run limitless' for a deeper pass."
|
||||
@ -887,7 +938,17 @@ class AnswerEngine:
|
||||
elapsed = round(time.monotonic() - started, 2)
|
||||
log.info(
|
||||
"atlasbot_answer",
|
||||
extra={"extra": {"mode": mode, "seconds": elapsed, "llm_calls": call_count, "limit": call_cap, "limit_hit": limit_hit}},
|
||||
extra={
|
||||
"extra": {
|
||||
"mode": mode,
|
||||
"seconds": elapsed,
|
||||
"llm_calls": call_count,
|
||||
"limit": call_cap,
|
||||
"limit_hit": limit_hit,
|
||||
"time_budget_sec": time_budget_sec,
|
||||
"time_budget_hit": time_budget_hit,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if limit_hit and "run limitless" not in reply.lower():
|
||||
@ -896,7 +957,17 @@ class AnswerEngine:
|
||||
if conversation_id and claims:
|
||||
self._store_state(conversation_id, claims, summary, snapshot_used, pin_snapshot)
|
||||
|
||||
meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started)
|
||||
meta = _build_meta(
|
||||
mode,
|
||||
call_count,
|
||||
call_cap,
|
||||
limit_hit,
|
||||
time_budget_hit,
|
||||
time_budget_sec,
|
||||
classify,
|
||||
tool_hint,
|
||||
started,
|
||||
)
|
||||
return AnswerResult(reply, scores, meta)
|
||||
|
||||
async def _answer_stock(self, question: str) -> AnswerResult:
|
||||
@ -1188,6 +1259,8 @@ def _build_meta( # noqa: PLR0913
|
||||
call_count: int,
|
||||
call_cap: int,
|
||||
limit_hit: bool,
|
||||
time_budget_hit: bool,
|
||||
time_budget_sec: float,
|
||||
classify: dict[str, Any],
|
||||
tool_hint: dict[str, Any] | None,
|
||||
started: float,
|
||||
@ -1197,6 +1270,8 @@ def _build_meta( # noqa: PLR0913
|
||||
"llm_calls": call_count,
|
||||
"llm_limit": call_cap,
|
||||
"llm_limit_hit": limit_hit,
|
||||
"time_budget_sec": time_budget_sec,
|
||||
"time_budget_hit": time_budget_hit,
|
||||
"classify": classify,
|
||||
"tool_hint": tool_hint,
|
||||
"elapsed_sec": round(time.monotonic() - started, 2),
|
||||
@ -1251,12 +1326,12 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
|
||||
return ModePlan(
|
||||
model=settings.ollama_model_fast,
|
||||
fast_model=settings.ollama_model_fast,
|
||||
max_subquestions=2,
|
||||
chunk_lines=12,
|
||||
chunk_top=5,
|
||||
max_subquestions=1,
|
||||
chunk_lines=16,
|
||||
chunk_top=3,
|
||||
chunk_group=5,
|
||||
kb_max_chars=1200,
|
||||
kb_max_files=6,
|
||||
kb_max_chars=800,
|
||||
kb_max_files=4,
|
||||
use_raw_snapshot=False,
|
||||
parallelism=1,
|
||||
score_retries=1,
|
||||
@ -1279,6 +1354,14 @@ def _llm_call_limit(settings: Settings, mode: str) -> int:
|
||||
return settings.fast_llm_calls_max
|
||||
|
||||
|
||||
def _mode_time_budget(settings: Settings, mode: str) -> float:
|
||||
if mode == "genius":
|
||||
return max(0.0, settings.genius_time_budget_sec)
|
||||
if mode == "smart":
|
||||
return max(0.0, settings.smart_time_budget_sec)
|
||||
return max(0.0, settings.quick_time_budget_sec)
|
||||
|
||||
|
||||
def _select_subquestions(parts: list[dict[str, Any]], fallback: str, limit: int) -> list[str]:
|
||||
if not parts:
|
||||
return [fallback]
|
||||
|
||||
@ -29,15 +29,22 @@ class LLMClient:
|
||||
return base
|
||||
return base + "/api/chat"
|
||||
|
||||
async def chat(self, messages: list[dict[str, str]], *, model: str | None = None) -> str:
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
*,
|
||||
model: str | None = None,
|
||||
timeout_sec: float | None = None,
|
||||
) -> str:
|
||||
payload = {
|
||||
"model": model or self._settings.ollama_model,
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
}
|
||||
timeout = timeout_sec if timeout_sec is not None else self._timeout
|
||||
for attempt in range(max(1, self._settings.ollama_retries + 1)):
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
resp = await client.post(self._endpoint(), json=payload, headers=self._headers)
|
||||
if resp.status_code == FALLBACK_STATUS_CODE and self._settings.ollama_fallback_model:
|
||||
payload["model"] = self._settings.ollama_fallback_model
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from dataclasses import replace
|
||||
|
||||
from atlasbot.engine.answerer import AnswerEngine
|
||||
from atlasbot.knowledge.loader import KnowledgeBase
|
||||
@ -10,7 +11,7 @@ class FakeLLM:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[str] = []
|
||||
|
||||
async def chat(self, messages, *, model=None):
|
||||
async def chat(self, messages, *, model=None, timeout_sec=None):
|
||||
prompt = messages[-1]["content"]
|
||||
self.calls.append(prompt)
|
||||
if "normalized" in prompt and "keywords" in prompt:
|
||||
@ -30,6 +31,12 @@ class FakeLLM:
|
||||
return "{}"
|
||||
|
||||
|
||||
class SlowFakeLLM(FakeLLM):
|
||||
async def chat(self, messages, *, model=None, timeout_sec=None):
|
||||
await asyncio.sleep(0.02)
|
||||
return await super().chat(messages, model=model, timeout_sec=timeout_sec)
|
||||
|
||||
|
||||
def _settings() -> Settings:
|
||||
return Settings(
|
||||
matrix_base="",
|
||||
@ -57,6 +64,9 @@ def _settings() -> Settings:
|
||||
ariadne_state_token="",
|
||||
snapshot_ttl_sec=30,
|
||||
thinking_interval_sec=30,
|
||||
quick_time_budget_sec=15.0,
|
||||
smart_time_budget_sec=45.0,
|
||||
genius_time_budget_sec=180.0,
|
||||
conversation_ttl_sec=300,
|
||||
snapshot_pin_enabled=False,
|
||||
queue_enabled=False,
|
||||
@ -74,6 +84,7 @@ def _settings() -> Settings:
|
||||
smart_llm_calls_max=17,
|
||||
genius_llm_calls_max=32,
|
||||
llm_limit_multiplier=1.5,
|
||||
state_db_path="/tmp/atlasbot_test_state.db",
|
||||
)
|
||||
|
||||
|
||||
@ -86,3 +97,15 @@ def test_engine_answer_basic():
|
||||
|
||||
result = asyncio.run(engine.answer("What is Atlas?", mode="quick"))
|
||||
assert "Atlas has 22 nodes" in result.reply
|
||||
|
||||
|
||||
def test_quick_mode_time_budget_guard():
|
||||
llm = SlowFakeLLM()
|
||||
settings = replace(_settings(), quick_time_budget_sec=0.01)
|
||||
kb = KnowledgeBase("")
|
||||
snapshot = SnapshotProvider(settings)
|
||||
engine = AnswerEngine(settings, llm, kb, snapshot)
|
||||
|
||||
result = asyncio.run(engine.answer("What is Atlas?", mode="quick"))
|
||||
assert "time budget" in result.reply
|
||||
assert result.meta.get("time_budget_hit") is True
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user