atlasbot: enforce mode time budgets and quick guardrails

This commit is contained in:
Brad Stein 2026-03-30 02:52:37 -03:00
parent c61f0a6847
commit 2cb601a614
4 changed files with 135 additions and 16 deletions

View File

@ -53,6 +53,9 @@ class Settings:
snapshot_ttl_sec: int
thinking_interval_sec: int
quick_time_budget_sec: float
smart_time_budget_sec: float
genius_time_budget_sec: float
conversation_ttl_sec: int
snapshot_pin_enabled: bool
@ -153,6 +156,9 @@ def load_settings() -> Settings:
ariadne_state_token=os.getenv("ARIADNE_STATE_TOKEN", ""),
snapshot_ttl_sec=_env_int("ATLASBOT_SNAPSHOT_TTL_SEC", "30"),
thinking_interval_sec=_env_int("ATLASBOT_THINKING_INTERVAL_SEC", "30"),
quick_time_budget_sec=_env_float("ATLASBOT_QUICK_TIME_BUDGET_SEC", "15"),
smart_time_budget_sec=_env_float("ATLASBOT_SMART_TIME_BUDGET_SEC", "45"),
genius_time_budget_sec=_env_float("ATLASBOT_GENIUS_TIME_BUDGET_SEC", "180"),
conversation_ttl_sec=_env_int("ATLASBOT_CONVERSATION_TTL_SEC", "900"),
snapshot_pin_enabled=_env_bool("ATLASBOT_SNAPSHOT_PIN_ENABLED", "false"),
state_db_path=os.getenv("ATLASBOT_STATE_DB", "/data/atlasbot_state.db"),

View File

@ -2,7 +2,6 @@ import asyncio
import json
import logging
import math
import asyncio
import re
import time
import difflib
@ -34,6 +33,10 @@ class LLMLimitReached(RuntimeError):
pass
class LLMTimeBudgetExceeded(RuntimeError):
    """Raised inside the answer pipeline when the mode's wall-clock time
    budget is exhausted before an LLM call can start (see call_llm)."""
    pass
@dataclass
class AnswerScores:
confidence: int
@ -163,6 +166,7 @@ class AnswerEngine:
call_cap = math.ceil(call_limit * self._settings.llm_limit_multiplier)
call_count = 0
limit_hit = False
time_budget_hit = False
debug_tags = {
"route",
@ -177,6 +181,8 @@ class AnswerEngine:
"select_claims",
"evidence_fix",
}
started = time.monotonic()
time_budget_sec = _mode_time_budget(self._settings, mode) if not limitless else 0.0
def _debug_log(name: str, payload: Any) -> None:
if not self._settings.debug_pipeline:
@ -184,13 +190,20 @@ class AnswerEngine:
log.info("atlasbot_debug", extra={"extra": {"name": name, "payload": payload}})
async def call_llm(system: str, prompt: str, *, context: str | None = None, model: str | None = None, tag: str = "") -> str:
nonlocal call_count, limit_hit
nonlocal call_count, limit_hit, time_budget_hit
if not limitless and call_count >= call_cap:
limit_hit = True
raise LLMLimitReached("llm_limit")
timeout_sec = None
if not limitless and time_budget_sec > 0:
time_left = time_budget_sec - (time.monotonic() - started)
if time_left <= 0:
time_budget_hit = True
raise LLMTimeBudgetExceeded("time_budget")
timeout_sec = min(self._settings.ollama_timeout_sec, time_left)
call_count += 1
messages = build_messages(system, prompt, context=context)
response = await self._llm.chat(messages, model=model or plan.model)
response = await self._llm.chat(messages, model=model or plan.model, timeout_sec=timeout_sec)
log.info(
"atlasbot_llm_call",
extra={"extra": {"mode": mode, "tag": tag, "call": call_count, "limit": call_cap}},
@ -221,7 +234,6 @@ class AnswerEngine:
metric_facts: list[str] = []
facts_used: list[str] = []
started = time.monotonic()
reply = ""
scores = _default_scores()
claims: list[ClaimItem] = []
@ -283,7 +295,17 @@ class AnswerEngine:
metric_facts = _merge_fact_lines([spine_line], metric_facts)
if spine_answer and mode in {"fast", "quick"}:
scores = _default_scores()
meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started)
meta = _build_meta(
mode,
call_count,
call_cap,
limit_hit,
time_budget_hit,
time_budget_sec,
classify,
tool_hint,
started,
)
return AnswerResult(spine_answer, scores, meta)
cluster_terms = (
"atlas",
@ -353,7 +375,17 @@ class AnswerEngine:
observer("followup", "answering follow-up")
reply = await self._answer_followup(question, state, summary, classify, plan, call_llm)
scores = await self._score_answer(question, reply, plan, call_llm)
meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started)
meta = _build_meta(
mode,
call_count,
call_cap,
limit_hit,
time_budget_hit,
time_budget_sec,
classify,
tool_hint,
started,
)
return AnswerResult(reply, scores, meta)
if observer:
@ -879,6 +911,25 @@ class AnswerEngine:
scores = await self._score_answer(normalized, reply, plan, call_llm)
claims = await self._extract_claims(normalized, reply, summary, facts_used, call_llm)
except LLMTimeBudgetExceeded:
time_budget_hit = True
if not reply:
budget = max(1, int(round(time_budget_sec))) if time_budget_sec > 0 else 0
if mode in {"quick", "fast"}:
budget_text = f"{budget}s" if budget else "its configured"
reply = (
f"Quick mode hit {budget_text} time budget before finishing. "
"Try atlas-smart for a deeper answer."
)
elif mode == "smart":
budget_text = f"{budget}s" if budget else "its configured"
reply = (
f"Smart mode hit {budget_text} time budget before finishing. "
"Try atlas-genius or ask a narrower follow-up."
)
else:
reply = "I ran out of time before I could finish this answer."
scores = _default_scores()
except LLMLimitReached:
if not reply:
reply = "I started working on this but hit my reasoning limit. Ask again with 'Run limitless' for a deeper pass."
@ -887,7 +938,17 @@ class AnswerEngine:
elapsed = round(time.monotonic() - started, 2)
log.info(
"atlasbot_answer",
extra={"extra": {"mode": mode, "seconds": elapsed, "llm_calls": call_count, "limit": call_cap, "limit_hit": limit_hit}},
extra={
"extra": {
"mode": mode,
"seconds": elapsed,
"llm_calls": call_count,
"limit": call_cap,
"limit_hit": limit_hit,
"time_budget_sec": time_budget_sec,
"time_budget_hit": time_budget_hit,
}
},
)
if limit_hit and "run limitless" not in reply.lower():
@ -896,7 +957,17 @@ class AnswerEngine:
if conversation_id and claims:
self._store_state(conversation_id, claims, summary, snapshot_used, pin_snapshot)
meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started)
meta = _build_meta(
mode,
call_count,
call_cap,
limit_hit,
time_budget_hit,
time_budget_sec,
classify,
tool_hint,
started,
)
return AnswerResult(reply, scores, meta)
async def _answer_stock(self, question: str) -> AnswerResult:
@ -1188,6 +1259,8 @@ def _build_meta( # noqa: PLR0913
call_count: int,
call_cap: int,
limit_hit: bool,
time_budget_hit: bool,
time_budget_sec: float,
classify: dict[str, Any],
tool_hint: dict[str, Any] | None,
started: float,
@ -1197,6 +1270,8 @@ def _build_meta( # noqa: PLR0913
"llm_calls": call_count,
"llm_limit": call_cap,
"llm_limit_hit": limit_hit,
"time_budget_sec": time_budget_sec,
"time_budget_hit": time_budget_hit,
"classify": classify,
"tool_hint": tool_hint,
"elapsed_sec": round(time.monotonic() - started, 2),
@ -1251,12 +1326,12 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
return ModePlan(
model=settings.ollama_model_fast,
fast_model=settings.ollama_model_fast,
max_subquestions=2,
chunk_lines=12,
chunk_top=5,
max_subquestions=1,
chunk_lines=16,
chunk_top=3,
chunk_group=5,
kb_max_chars=1200,
kb_max_files=6,
kb_max_chars=800,
kb_max_files=4,
use_raw_snapshot=False,
parallelism=1,
score_retries=1,
@ -1279,6 +1354,14 @@ def _llm_call_limit(settings: Settings, mode: str) -> int:
return settings.fast_llm_calls_max
def _mode_time_budget(settings: Settings, mode: str) -> float:
if mode == "genius":
return max(0.0, settings.genius_time_budget_sec)
if mode == "smart":
return max(0.0, settings.smart_time_budget_sec)
return max(0.0, settings.quick_time_budget_sec)
def _select_subquestions(parts: list[dict[str, Any]], fallback: str, limit: int) -> list[str]:
if not parts:
return [fallback]

View File

@ -29,15 +29,22 @@ class LLMClient:
return base
return base + "/api/chat"
async def chat(self, messages: list[dict[str, str]], *, model: str | None = None) -> str:
async def chat(
self,
messages: list[dict[str, str]],
*,
model: str | None = None,
timeout_sec: float | None = None,
) -> str:
payload = {
"model": model or self._settings.ollama_model,
"messages": messages,
"stream": False,
}
timeout = timeout_sec if timeout_sec is not None else self._timeout
for attempt in range(max(1, self._settings.ollama_retries + 1)):
try:
async with httpx.AsyncClient(timeout=self._timeout) as client:
async with httpx.AsyncClient(timeout=timeout) as client:
resp = await client.post(self._endpoint(), json=payload, headers=self._headers)
if resp.status_code == FALLBACK_STATUS_CODE and self._settings.ollama_fallback_model:
payload["model"] = self._settings.ollama_fallback_model

View File

@ -1,4 +1,5 @@
import asyncio
from dataclasses import replace
from atlasbot.engine.answerer import AnswerEngine
from atlasbot.knowledge.loader import KnowledgeBase
@ -10,7 +11,7 @@ class FakeLLM:
def __init__(self) -> None:
self.calls: list[str] = []
async def chat(self, messages, *, model=None):
async def chat(self, messages, *, model=None, timeout_sec=None):
prompt = messages[-1]["content"]
self.calls.append(prompt)
if "normalized" in prompt and "keywords" in prompt:
@ -30,6 +31,12 @@ class FakeLLM:
return "{}"
class SlowFakeLLM(FakeLLM):
    """FakeLLM variant that sleeps before answering, so tests can trip
    sub-second time budgets deterministically."""

    # Per-call artificial latency; large enough to exceed a 0.01s budget.
    _DELAY_SEC = 0.02

    async def chat(self, messages, *, model=None, timeout_sec=None):
        await asyncio.sleep(self._DELAY_SEC)
        return await super().chat(messages, model=model, timeout_sec=timeout_sec)
def _settings() -> Settings:
return Settings(
matrix_base="",
@ -57,6 +64,9 @@ def _settings() -> Settings:
ariadne_state_token="",
snapshot_ttl_sec=30,
thinking_interval_sec=30,
quick_time_budget_sec=15.0,
smart_time_budget_sec=45.0,
genius_time_budget_sec=180.0,
conversation_ttl_sec=300,
snapshot_pin_enabled=False,
queue_enabled=False,
@ -74,6 +84,7 @@ def _settings() -> Settings:
smart_llm_calls_max=17,
genius_llm_calls_max=32,
llm_limit_multiplier=1.5,
state_db_path="/tmp/atlasbot_test_state.db",
)
@ -86,3 +97,15 @@ def test_engine_answer_basic():
result = asyncio.run(engine.answer("What is Atlas?", mode="quick"))
assert "Atlas has 22 nodes" in result.reply
def test_quick_mode_time_budget_guard():
    """A near-zero quick budget must produce the guardrail reply and set
    time_budget_hit in the answer metadata."""
    slow_llm = SlowFakeLLM()
    cfg = replace(_settings(), quick_time_budget_sec=0.01)
    engine = AnswerEngine(cfg, slow_llm, KnowledgeBase(""), SnapshotProvider(cfg))

    outcome = asyncio.run(engine.answer("What is Atlas?", mode="quick"))

    assert "time budget" in outcome.reply
    assert outcome.meta.get("time_budget_hit") is True