From 2cb601a6143bb34ac366d3000e3f795c08560447 Mon Sep 17 00:00:00 2001 From: Brad Stein Date: Mon, 30 Mar 2026 02:52:37 -0300 Subject: [PATCH] atlasbot: enforce mode time budgets and quick guardrails --- atlasbot/config.py | 6 ++ atlasbot/engine/answerer.py | 109 +++++++++++++++++++++++++++++++----- atlasbot/llm/client.py | 11 +++- tests/test_engine.py | 25 ++++++++- 4 files changed, 135 insertions(+), 16 deletions(-) diff --git a/atlasbot/config.py b/atlasbot/config.py index b6c7a3a..bc0d321 100644 --- a/atlasbot/config.py +++ b/atlasbot/config.py @@ -53,6 +53,9 @@ class Settings: snapshot_ttl_sec: int thinking_interval_sec: int + quick_time_budget_sec: float + smart_time_budget_sec: float + genius_time_budget_sec: float conversation_ttl_sec: int snapshot_pin_enabled: bool @@ -153,6 +156,9 @@ def load_settings() -> Settings: ariadne_state_token=os.getenv("ARIADNE_STATE_TOKEN", ""), snapshot_ttl_sec=_env_int("ATLASBOT_SNAPSHOT_TTL_SEC", "30"), thinking_interval_sec=_env_int("ATLASBOT_THINKING_INTERVAL_SEC", "30"), + quick_time_budget_sec=_env_float("ATLASBOT_QUICK_TIME_BUDGET_SEC", "15"), + smart_time_budget_sec=_env_float("ATLASBOT_SMART_TIME_BUDGET_SEC", "45"), + genius_time_budget_sec=_env_float("ATLASBOT_GENIUS_TIME_BUDGET_SEC", "180"), conversation_ttl_sec=_env_int("ATLASBOT_CONVERSATION_TTL_SEC", "900"), snapshot_pin_enabled=_env_bool("ATLASBOT_SNAPSHOT_PIN_ENABLED", "false"), state_db_path=os.getenv("ATLASBOT_STATE_DB", "/data/atlasbot_state.db"), diff --git a/atlasbot/engine/answerer.py b/atlasbot/engine/answerer.py index f2bfb9a..5b904e8 100644 --- a/atlasbot/engine/answerer.py +++ b/atlasbot/engine/answerer.py @@ -2,7 +2,6 @@ import asyncio import json import logging import math -import asyncio import re import time import difflib @@ -34,6 +33,10 @@ class LLMLimitReached(RuntimeError): pass +class LLMTimeBudgetExceeded(RuntimeError): + pass + + @dataclass class AnswerScores: confidence: int @@ -163,6 +166,7 @@ class AnswerEngine: call_cap = math.ceil(call_limit * self._settings.llm_limit_multiplier) call_count = 0 limit_hit = False + time_budget_hit = False debug_tags = { "route", @@ -177,6 +181,8 @@ class AnswerEngine: "select_claims", "evidence_fix", } + started = time.monotonic() + time_budget_sec = _mode_time_budget(self._settings, mode) if not limitless else 0.0 def _debug_log(name: str, payload: Any) -> None: if not self._settings.debug_pipeline: @@ -184,13 +190,20 @@ class AnswerEngine: log.info("atlasbot_debug", extra={"extra": {"name": name, "payload": payload}}) async def call_llm(system: str, prompt: str, *, context: str | None = None, model: str | None = None, tag: str = "") -> str: - nonlocal call_count, limit_hit + nonlocal call_count, limit_hit, time_budget_hit if not limitless and call_count >= call_cap: limit_hit = True raise LLMLimitReached("llm_limit") + timeout_sec = None + if not limitless and time_budget_sec > 0: + time_left = time_budget_sec - (time.monotonic() - started) + if time_left <= 0: + time_budget_hit = True + raise LLMTimeBudgetExceeded("time_budget") + timeout_sec = min(self._settings.ollama_timeout_sec, time_left) call_count += 1 messages = build_messages(system, prompt, context=context) - response = await self._llm.chat(messages, model=model or plan.model) + response = await self._llm.chat(messages, model=model or plan.model, timeout_sec=timeout_sec) log.info( "atlasbot_llm_call", extra={"extra": {"mode": mode, "tag": tag, "call": call_count, "limit": call_cap}}, @@ -221,7 +234,6 @@ class AnswerEngine: metric_facts: list[str] = [] facts_used: list[str] = [] - started = time.monotonic() reply = "" scores = _default_scores() claims: list[ClaimItem] = [] @@ -283,7 +295,17 @@ class AnswerEngine: metric_facts = _merge_fact_lines([spine_line], metric_facts) if spine_answer and mode in {"fast", "quick"}: scores = _default_scores() - meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started) + meta = _build_meta( + mode, + call_count, + call_cap, + limit_hit, + time_budget_hit, + time_budget_sec, + classify, + tool_hint, + started, + ) return AnswerResult(spine_answer, scores, meta) cluster_terms = ( "atlas", @@ -353,7 +375,17 @@ class AnswerEngine: observer("followup", "answering follow-up") reply = await self._answer_followup(question, state, summary, classify, plan, call_llm) scores = await self._score_answer(question, reply, plan, call_llm) - meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started) + meta = _build_meta( + mode, + call_count, + call_cap, + limit_hit, + time_budget_hit, + time_budget_sec, + classify, + tool_hint, + started, + ) return AnswerResult(reply, scores, meta) if observer: @@ -879,6 +911,25 @@ class AnswerEngine: scores = await self._score_answer(normalized, reply, plan, call_llm) claims = await self._extract_claims(normalized, reply, summary, facts_used, call_llm) + except LLMTimeBudgetExceeded: + time_budget_hit = True + if not reply: + budget = max(1, int(round(time_budget_sec))) if time_budget_sec > 0 else 0 + if mode in {"quick", "fast"}: + budget_text = f"{budget}s" if budget else "its configured" + reply = ( + f"Quick mode hit {budget_text} time budget before finishing. " + "Try atlas-smart for a deeper answer." + ) + elif mode == "smart": + budget_text = f"{budget}s" if budget else "its configured" + reply = ( + f"Smart mode hit {budget_text} time budget before finishing. " + "Try atlas-genius or ask a narrower follow-up." + ) + else: + reply = "I ran out of time before I could finish this answer." + scores = _default_scores() except LLMLimitReached: if not reply: reply = "I started working on this but hit my reasoning limit. Ask again with 'Run limitless' for a deeper pass." @@ -887,7 +938,17 @@ class AnswerEngine: elapsed = round(time.monotonic() - started, 2) log.info( "atlasbot_answer", - extra={"extra": {"mode": mode, "seconds": elapsed, "llm_calls": call_count, "limit": call_cap, "limit_hit": limit_hit}}, + extra={ + "extra": { + "mode": mode, + "seconds": elapsed, + "llm_calls": call_count, + "limit": call_cap, + "limit_hit": limit_hit, + "time_budget_sec": time_budget_sec, + "time_budget_hit": time_budget_hit, + } + }, ) if limit_hit and "run limitless" not in reply.lower(): @@ -896,7 +957,17 @@ class AnswerEngine: if conversation_id and claims: self._store_state(conversation_id, claims, summary, snapshot_used, pin_snapshot) - meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started) + meta = _build_meta( + mode, + call_count, + call_cap, + limit_hit, + time_budget_hit, + time_budget_sec, + classify, + tool_hint, + started, + ) return AnswerResult(reply, scores, meta) async def _answer_stock(self, question: str) -> AnswerResult: @@ -1188,6 +1259,8 @@ def _build_meta( # noqa: PLR0913 call_count: int, call_cap: int, limit_hit: bool, + time_budget_hit: bool, + time_budget_sec: float, classify: dict[str, Any], tool_hint: dict[str, Any] | None, started: float, @@ -1197,6 +1270,8 @@ def _build_meta( # noqa: PLR0913 "llm_calls": call_count, "llm_limit": call_cap, "llm_limit_hit": limit_hit, + "time_budget_sec": time_budget_sec, + "time_budget_hit": time_budget_hit, "classify": classify, "tool_hint": tool_hint, "elapsed_sec": round(time.monotonic() - started, 2), @@ -1251,12 +1326,12 @@ def _mode_plan(settings: Settings, mode: str) -> ModePlan: return ModePlan( model=settings.ollama_model_fast, fast_model=settings.ollama_model_fast, - max_subquestions=2, - chunk_lines=12, - chunk_top=5, + max_subquestions=1, + chunk_lines=16, + chunk_top=3, chunk_group=5, - kb_max_chars=1200, - kb_max_files=6, + kb_max_chars=800, + kb_max_files=4, use_raw_snapshot=False, parallelism=1, score_retries=1, @@ -1279,6 +1354,14 @@ def _llm_call_limit(settings: Settings, mode: str) -> int: return settings.fast_llm_calls_max +def _mode_time_budget(settings: Settings, mode: str) -> float: + if mode == "genius": + return max(0.0, settings.genius_time_budget_sec) + if mode == "smart": + return max(0.0, settings.smart_time_budget_sec) + return max(0.0, settings.quick_time_budget_sec) + + def _select_subquestions(parts: list[dict[str, Any]], fallback: str, limit: int) -> list[str]: if not parts: return [fallback] diff --git a/atlasbot/llm/client.py b/atlasbot/llm/client.py index 091ba48..3baf595 100644 --- a/atlasbot/llm/client.py +++ b/atlasbot/llm/client.py @@ -29,15 +29,22 @@ class LLMClient: return base return base + "/api/chat" - async def chat(self, messages: list[dict[str, str]], *, model: str | None = None) -> str: + async def chat( + self, + messages: list[dict[str, str]], + *, + model: str | None = None, + timeout_sec: float | None = None, + ) -> str: payload = { "model": model or self._settings.ollama_model, "messages": messages, "stream": False, } + timeout = timeout_sec if timeout_sec is not None else self._timeout for attempt in range(max(1, self._settings.ollama_retries + 1)): try: - async with httpx.AsyncClient(timeout=self._timeout) as client: + async with httpx.AsyncClient(timeout=timeout) as client: resp = await client.post(self._endpoint(), json=payload, headers=self._headers) if resp.status_code == FALLBACK_STATUS_CODE and self._settings.ollama_fallback_model: payload["model"] = self._settings.ollama_fallback_model diff --git a/tests/test_engine.py b/tests/test_engine.py index 3d5fb43..34768bc 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -1,4 +1,5 @@ import asyncio +from dataclasses import replace from atlasbot.engine.answerer import AnswerEngine from atlasbot.knowledge.loader import KnowledgeBase @@ -10,7 +11,7 @@ class FakeLLM: def __init__(self) -> None: self.calls: list[str] = [] - async def chat(self, messages, *, model=None): + async def chat(self, messages, *, model=None, timeout_sec=None): prompt = messages[-1]["content"] self.calls.append(prompt) if "normalized" in prompt and "keywords" in prompt: @@ -30,6 +31,12 @@ class FakeLLM: return "{}" +class SlowFakeLLM(FakeLLM): + async def chat(self, messages, *, model=None, timeout_sec=None): + await asyncio.sleep(0.02) + return await super().chat(messages, model=model, timeout_sec=timeout_sec) + + def _settings() -> Settings: return Settings( matrix_base="", @@ -57,6 +64,9 @@ def _settings() -> Settings: ariadne_state_token="", snapshot_ttl_sec=30, thinking_interval_sec=30, + quick_time_budget_sec=15.0, + smart_time_budget_sec=45.0, + genius_time_budget_sec=180.0, conversation_ttl_sec=300, snapshot_pin_enabled=False, queue_enabled=False, @@ -74,6 +84,7 @@ def _settings() -> Settings: smart_llm_calls_max=17, genius_llm_calls_max=32, llm_limit_multiplier=1.5, + state_db_path="/tmp/atlasbot_test_state.db", ) @@ -86,3 +97,15 @@ def test_engine_answer_basic(): result = asyncio.run(engine.answer("What is Atlas?", mode="quick")) assert "Atlas has 22 nodes" in result.reply + + +def test_quick_mode_time_budget_guard(): + llm = SlowFakeLLM() + settings = replace(_settings(), quick_time_budget_sec=0.01) + kb = KnowledgeBase("") + snapshot = SnapshotProvider(settings) + engine = AnswerEngine(settings, llm, kb, snapshot) + + result = asyncio.run(engine.answer("What is Atlas?", mode="quick")) + assert "time budget" in result.reply + assert result.meta.get("time_budget_hit") is True