atlasbot: enforce mode time budgets and quick guardrails

This commit is contained in:
Brad Stein 2026-03-30 02:52:37 -03:00
parent c61f0a6847
commit 2cb601a614
4 changed files with 135 additions and 16 deletions

View File

@ -53,6 +53,9 @@ class Settings:
snapshot_ttl_sec: int snapshot_ttl_sec: int
thinking_interval_sec: int thinking_interval_sec: int
quick_time_budget_sec: float
smart_time_budget_sec: float
genius_time_budget_sec: float
conversation_ttl_sec: int conversation_ttl_sec: int
snapshot_pin_enabled: bool snapshot_pin_enabled: bool
@ -153,6 +156,9 @@ def load_settings() -> Settings:
ariadne_state_token=os.getenv("ARIADNE_STATE_TOKEN", ""), ariadne_state_token=os.getenv("ARIADNE_STATE_TOKEN", ""),
snapshot_ttl_sec=_env_int("ATLASBOT_SNAPSHOT_TTL_SEC", "30"), snapshot_ttl_sec=_env_int("ATLASBOT_SNAPSHOT_TTL_SEC", "30"),
thinking_interval_sec=_env_int("ATLASBOT_THINKING_INTERVAL_SEC", "30"), thinking_interval_sec=_env_int("ATLASBOT_THINKING_INTERVAL_SEC", "30"),
quick_time_budget_sec=_env_float("ATLASBOT_QUICK_TIME_BUDGET_SEC", "15"),
smart_time_budget_sec=_env_float("ATLASBOT_SMART_TIME_BUDGET_SEC", "45"),
genius_time_budget_sec=_env_float("ATLASBOT_GENIUS_TIME_BUDGET_SEC", "180"),
conversation_ttl_sec=_env_int("ATLASBOT_CONVERSATION_TTL_SEC", "900"), conversation_ttl_sec=_env_int("ATLASBOT_CONVERSATION_TTL_SEC", "900"),
snapshot_pin_enabled=_env_bool("ATLASBOT_SNAPSHOT_PIN_ENABLED", "false"), snapshot_pin_enabled=_env_bool("ATLASBOT_SNAPSHOT_PIN_ENABLED", "false"),
state_db_path=os.getenv("ATLASBOT_STATE_DB", "/data/atlasbot_state.db"), state_db_path=os.getenv("ATLASBOT_STATE_DB", "/data/atlasbot_state.db"),

View File

@ -2,7 +2,6 @@ import asyncio
import json import json
import logging import logging
import math import math
import asyncio
import re import re
import time import time
import difflib import difflib
@ -34,6 +33,10 @@ class LLMLimitReached(RuntimeError):
pass pass
class LLMTimeBudgetExceeded(RuntimeError):
    """Raised when a mode's wall-clock time budget runs out before the answer is done."""
@dataclass @dataclass
class AnswerScores: class AnswerScores:
confidence: int confidence: int
@ -163,6 +166,7 @@ class AnswerEngine:
call_cap = math.ceil(call_limit * self._settings.llm_limit_multiplier) call_cap = math.ceil(call_limit * self._settings.llm_limit_multiplier)
call_count = 0 call_count = 0
limit_hit = False limit_hit = False
time_budget_hit = False
debug_tags = { debug_tags = {
"route", "route",
@ -177,6 +181,8 @@ class AnswerEngine:
"select_claims", "select_claims",
"evidence_fix", "evidence_fix",
} }
started = time.monotonic()
time_budget_sec = _mode_time_budget(self._settings, mode) if not limitless else 0.0
def _debug_log(name: str, payload: Any) -> None: def _debug_log(name: str, payload: Any) -> None:
if not self._settings.debug_pipeline: if not self._settings.debug_pipeline:
@ -184,13 +190,20 @@ class AnswerEngine:
log.info("atlasbot_debug", extra={"extra": {"name": name, "payload": payload}}) log.info("atlasbot_debug", extra={"extra": {"name": name, "payload": payload}})
async def call_llm(system: str, prompt: str, *, context: str | None = None, model: str | None = None, tag: str = "") -> str: async def call_llm(system: str, prompt: str, *, context: str | None = None, model: str | None = None, tag: str = "") -> str:
nonlocal call_count, limit_hit nonlocal call_count, limit_hit, time_budget_hit
if not limitless and call_count >= call_cap: if not limitless and call_count >= call_cap:
limit_hit = True limit_hit = True
raise LLMLimitReached("llm_limit") raise LLMLimitReached("llm_limit")
timeout_sec = None
if not limitless and time_budget_sec > 0:
time_left = time_budget_sec - (time.monotonic() - started)
if time_left <= 0:
time_budget_hit = True
raise LLMTimeBudgetExceeded("time_budget")
timeout_sec = min(self._settings.ollama_timeout_sec, time_left)
call_count += 1 call_count += 1
messages = build_messages(system, prompt, context=context) messages = build_messages(system, prompt, context=context)
response = await self._llm.chat(messages, model=model or plan.model) response = await self._llm.chat(messages, model=model or plan.model, timeout_sec=timeout_sec)
log.info( log.info(
"atlasbot_llm_call", "atlasbot_llm_call",
extra={"extra": {"mode": mode, "tag": tag, "call": call_count, "limit": call_cap}}, extra={"extra": {"mode": mode, "tag": tag, "call": call_count, "limit": call_cap}},
@ -221,7 +234,6 @@ class AnswerEngine:
metric_facts: list[str] = [] metric_facts: list[str] = []
facts_used: list[str] = [] facts_used: list[str] = []
started = time.monotonic()
reply = "" reply = ""
scores = _default_scores() scores = _default_scores()
claims: list[ClaimItem] = [] claims: list[ClaimItem] = []
@ -283,7 +295,17 @@ class AnswerEngine:
metric_facts = _merge_fact_lines([spine_line], metric_facts) metric_facts = _merge_fact_lines([spine_line], metric_facts)
if spine_answer and mode in {"fast", "quick"}: if spine_answer and mode in {"fast", "quick"}:
scores = _default_scores() scores = _default_scores()
meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started) meta = _build_meta(
mode,
call_count,
call_cap,
limit_hit,
time_budget_hit,
time_budget_sec,
classify,
tool_hint,
started,
)
return AnswerResult(spine_answer, scores, meta) return AnswerResult(spine_answer, scores, meta)
cluster_terms = ( cluster_terms = (
"atlas", "atlas",
@ -353,7 +375,17 @@ class AnswerEngine:
observer("followup", "answering follow-up") observer("followup", "answering follow-up")
reply = await self._answer_followup(question, state, summary, classify, plan, call_llm) reply = await self._answer_followup(question, state, summary, classify, plan, call_llm)
scores = await self._score_answer(question, reply, plan, call_llm) scores = await self._score_answer(question, reply, plan, call_llm)
meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started) meta = _build_meta(
mode,
call_count,
call_cap,
limit_hit,
time_budget_hit,
time_budget_sec,
classify,
tool_hint,
started,
)
return AnswerResult(reply, scores, meta) return AnswerResult(reply, scores, meta)
if observer: if observer:
@ -879,6 +911,25 @@ class AnswerEngine:
scores = await self._score_answer(normalized, reply, plan, call_llm) scores = await self._score_answer(normalized, reply, plan, call_llm)
claims = await self._extract_claims(normalized, reply, summary, facts_used, call_llm) claims = await self._extract_claims(normalized, reply, summary, facts_used, call_llm)
except LLMTimeBudgetExceeded:
time_budget_hit = True
if not reply:
budget = max(1, int(round(time_budget_sec))) if time_budget_sec > 0 else 0
if mode in {"quick", "fast"}:
budget_text = f"{budget}s" if budget else "its configured"
reply = (
f"Quick mode hit {budget_text} time budget before finishing. "
"Try atlas-smart for a deeper answer."
)
elif mode == "smart":
budget_text = f"{budget}s" if budget else "its configured"
reply = (
f"Smart mode hit {budget_text} time budget before finishing. "
"Try atlas-genius or ask a narrower follow-up."
)
else:
reply = "I ran out of time before I could finish this answer."
scores = _default_scores()
except LLMLimitReached: except LLMLimitReached:
if not reply: if not reply:
reply = "I started working on this but hit my reasoning limit. Ask again with 'Run limitless' for a deeper pass." reply = "I started working on this but hit my reasoning limit. Ask again with 'Run limitless' for a deeper pass."
@ -887,7 +938,17 @@ class AnswerEngine:
elapsed = round(time.monotonic() - started, 2) elapsed = round(time.monotonic() - started, 2)
log.info( log.info(
"atlasbot_answer", "atlasbot_answer",
extra={"extra": {"mode": mode, "seconds": elapsed, "llm_calls": call_count, "limit": call_cap, "limit_hit": limit_hit}}, extra={
"extra": {
"mode": mode,
"seconds": elapsed,
"llm_calls": call_count,
"limit": call_cap,
"limit_hit": limit_hit,
"time_budget_sec": time_budget_sec,
"time_budget_hit": time_budget_hit,
}
},
) )
if limit_hit and "run limitless" not in reply.lower(): if limit_hit and "run limitless" not in reply.lower():
@ -896,7 +957,17 @@ class AnswerEngine:
if conversation_id and claims: if conversation_id and claims:
self._store_state(conversation_id, claims, summary, snapshot_used, pin_snapshot) self._store_state(conversation_id, claims, summary, snapshot_used, pin_snapshot)
meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started) meta = _build_meta(
mode,
call_count,
call_cap,
limit_hit,
time_budget_hit,
time_budget_sec,
classify,
tool_hint,
started,
)
return AnswerResult(reply, scores, meta) return AnswerResult(reply, scores, meta)
async def _answer_stock(self, question: str) -> AnswerResult: async def _answer_stock(self, question: str) -> AnswerResult:
@ -1188,6 +1259,8 @@ def _build_meta( # noqa: PLR0913
call_count: int, call_count: int,
call_cap: int, call_cap: int,
limit_hit: bool, limit_hit: bool,
time_budget_hit: bool,
time_budget_sec: float,
classify: dict[str, Any], classify: dict[str, Any],
tool_hint: dict[str, Any] | None, tool_hint: dict[str, Any] | None,
started: float, started: float,
@ -1197,6 +1270,8 @@ def _build_meta( # noqa: PLR0913
"llm_calls": call_count, "llm_calls": call_count,
"llm_limit": call_cap, "llm_limit": call_cap,
"llm_limit_hit": limit_hit, "llm_limit_hit": limit_hit,
"time_budget_sec": time_budget_sec,
"time_budget_hit": time_budget_hit,
"classify": classify, "classify": classify,
"tool_hint": tool_hint, "tool_hint": tool_hint,
"elapsed_sec": round(time.monotonic() - started, 2), "elapsed_sec": round(time.monotonic() - started, 2),
@ -1251,12 +1326,12 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan:
return ModePlan( return ModePlan(
model=settings.ollama_model_fast, model=settings.ollama_model_fast,
fast_model=settings.ollama_model_fast, fast_model=settings.ollama_model_fast,
max_subquestions=2, max_subquestions=1,
chunk_lines=12, chunk_lines=16,
chunk_top=5, chunk_top=3,
chunk_group=5, chunk_group=5,
kb_max_chars=1200, kb_max_chars=800,
kb_max_files=6, kb_max_files=4,
use_raw_snapshot=False, use_raw_snapshot=False,
parallelism=1, parallelism=1,
score_retries=1, score_retries=1,
@ -1279,6 +1354,14 @@ def _llm_call_limit(settings: Settings, mode: str) -> int:
return settings.fast_llm_calls_max return settings.fast_llm_calls_max
def _mode_time_budget(settings: Settings, mode: str) -> float:
if mode == "genius":
return max(0.0, settings.genius_time_budget_sec)
if mode == "smart":
return max(0.0, settings.smart_time_budget_sec)
return max(0.0, settings.quick_time_budget_sec)
def _select_subquestions(parts: list[dict[str, Any]], fallback: str, limit: int) -> list[str]: def _select_subquestions(parts: list[dict[str, Any]], fallback: str, limit: int) -> list[str]:
if not parts: if not parts:
return [fallback] return [fallback]

View File

@ -29,15 +29,22 @@ class LLMClient:
return base return base
return base + "/api/chat" return base + "/api/chat"
async def chat(self, messages: list[dict[str, str]], *, model: str | None = None) -> str: async def chat(
self,
messages: list[dict[str, str]],
*,
model: str | None = None,
timeout_sec: float | None = None,
) -> str:
payload = { payload = {
"model": model or self._settings.ollama_model, "model": model or self._settings.ollama_model,
"messages": messages, "messages": messages,
"stream": False, "stream": False,
} }
timeout = timeout_sec if timeout_sec is not None else self._timeout
for attempt in range(max(1, self._settings.ollama_retries + 1)): for attempt in range(max(1, self._settings.ollama_retries + 1)):
try: try:
async with httpx.AsyncClient(timeout=self._timeout) as client: async with httpx.AsyncClient(timeout=timeout) as client:
resp = await client.post(self._endpoint(), json=payload, headers=self._headers) resp = await client.post(self._endpoint(), json=payload, headers=self._headers)
if resp.status_code == FALLBACK_STATUS_CODE and self._settings.ollama_fallback_model: if resp.status_code == FALLBACK_STATUS_CODE and self._settings.ollama_fallback_model:
payload["model"] = self._settings.ollama_fallback_model payload["model"] = self._settings.ollama_fallback_model

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
from dataclasses import replace
from atlasbot.engine.answerer import AnswerEngine from atlasbot.engine.answerer import AnswerEngine
from atlasbot.knowledge.loader import KnowledgeBase from atlasbot.knowledge.loader import KnowledgeBase
@ -10,7 +11,7 @@ class FakeLLM:
def __init__(self) -> None: def __init__(self) -> None:
self.calls: list[str] = [] self.calls: list[str] = []
async def chat(self, messages, *, model=None): async def chat(self, messages, *, model=None, timeout_sec=None):
prompt = messages[-1]["content"] prompt = messages[-1]["content"]
self.calls.append(prompt) self.calls.append(prompt)
if "normalized" in prompt and "keywords" in prompt: if "normalized" in prompt and "keywords" in prompt:
@ -30,6 +31,12 @@ class FakeLLM:
return "{}" return "{}"
class SlowFakeLLM(FakeLLM):
    """FakeLLM variant that adds a small async delay before each chat call.

    Used to make a mode's wall-clock time budget expire deterministically in
    tests without needing a real (slow) backend.
    """

    async def chat(self, messages, *, model=None, timeout_sec=None):
        # 20 ms per call: far larger than the tiny test budget, so the
        # engine's time-budget guard is guaranteed to trip mid-pipeline.
        await asyncio.sleep(0.02)
        return await super().chat(messages, model=model, timeout_sec=timeout_sec)
def _settings() -> Settings: def _settings() -> Settings:
return Settings( return Settings(
matrix_base="", matrix_base="",
@ -57,6 +64,9 @@ def _settings() -> Settings:
ariadne_state_token="", ariadne_state_token="",
snapshot_ttl_sec=30, snapshot_ttl_sec=30,
thinking_interval_sec=30, thinking_interval_sec=30,
quick_time_budget_sec=15.0,
smart_time_budget_sec=45.0,
genius_time_budget_sec=180.0,
conversation_ttl_sec=300, conversation_ttl_sec=300,
snapshot_pin_enabled=False, snapshot_pin_enabled=False,
queue_enabled=False, queue_enabled=False,
@ -74,6 +84,7 @@ def _settings() -> Settings:
smart_llm_calls_max=17, smart_llm_calls_max=17,
genius_llm_calls_max=32, genius_llm_calls_max=32,
llm_limit_multiplier=1.5, llm_limit_multiplier=1.5,
state_db_path="/tmp/atlasbot_test_state.db",
) )
@ -86,3 +97,15 @@ def test_engine_answer_basic():
result = asyncio.run(engine.answer("What is Atlas?", mode="quick")) result = asyncio.run(engine.answer("What is Atlas?", mode="quick"))
assert "Atlas has 22 nodes" in result.reply assert "Atlas has 22 nodes" in result.reply
def test_quick_mode_time_budget_guard():
    """Quick mode must return a fallback reply and flag meta when its budget expires."""
    llm = SlowFakeLLM()
    # Budget (10 ms) is well below SlowFakeLLM's 20 ms per-call delay, so the
    # time-budget guard must fire before the pipeline can finish.
    settings = replace(_settings(), quick_time_budget_sec=0.01)
    kb = KnowledgeBase("")
    snapshot = SnapshotProvider(settings)
    engine = AnswerEngine(settings, llm, kb, snapshot)
    result = asyncio.run(engine.answer("What is Atlas?", mode="quick"))
    # Fallback reply mentions the budget; meta records that the guard tripped.
    assert "time budget" in result.reply
    assert result.meta.get("time_budget_hit") is True