From 40ce60241242f2beb3b961fbe7b911ce9875b235 Mon Sep 17 00:00:00 2001 From: Brad Stein Date: Mon, 30 Mar 2026 16:51:23 -0300 Subject: [PATCH] atlasbot: wire quick smart genius modes --- services/comms/scripts/atlasbot/bot.py | 177 ++++++++++++++---- .../scripts/tests/test_atlasbot_modes.py | 107 +++++++++++ 2 files changed, 248 insertions(+), 36 deletions(-) create mode 100644 services/comms/scripts/tests/test_atlasbot_modes.py diff --git a/services/comms/scripts/atlasbot/bot.py b/services/comms/scripts/atlasbot/bot.py index 59486346..84d771fe 100644 --- a/services/comms/scripts/atlasbot/bot.py +++ b/services/comms/scripts/atlasbot/bot.py @@ -17,12 +17,15 @@ BOT_USER_QUICK = os.environ.get("BOT_USER_QUICK", "").strip() BOT_PASS_QUICK = os.environ.get("BOT_PASS_QUICK", "").strip() BOT_USER_SMART = os.environ.get("BOT_USER_SMART", "").strip() BOT_PASS_SMART = os.environ.get("BOT_PASS_SMART", "").strip() +BOT_USER_GENIUS = os.environ.get("BOT_USER_GENIUS", "").strip() +BOT_PASS_GENIUS = os.environ.get("BOT_PASS_GENIUS", "").strip() ROOM_ALIAS = "#othrys:live.bstein.dev" OLLAMA_URL = os.environ.get("OLLAMA_URL", "https://chat.ai.bstein.dev/") MODEL = os.environ.get("OLLAMA_MODEL", "qwen2.5:14b-instruct") MODEL_FAST = os.environ.get("ATLASBOT_MODEL_FAST", "") -MODEL_DEEP = os.environ.get("ATLASBOT_MODEL_DEEP", "") +MODEL_SMART = os.environ.get("ATLASBOT_MODEL_SMART", os.environ.get("ATLASBOT_MODEL_DEEP", "")).strip() +MODEL_GENIUS = os.environ.get("ATLASBOT_MODEL_GENIUS", MODEL_SMART).strip() FALLBACK_MODEL = os.environ.get("OLLAMA_FALLBACK_MODEL", "") API_KEY = os.environ.get("CHAT_API_KEY", "") OLLAMA_TIMEOUT_SEC = float(os.environ.get("OLLAMA_TIMEOUT_SEC", "480")) @@ -43,6 +46,9 @@ MAX_TOOL_CHARS = int(os.environ.get("ATLASBOT_MAX_TOOL_CHARS", "2500")) MAX_FACTS_CHARS = int(os.environ.get("ATLASBOT_MAX_FACTS_CHARS", "8000")) MAX_CONTEXT_CHARS = int(os.environ.get("ATLASBOT_MAX_CONTEXT_CHARS", "12000")) THINKING_INTERVAL_SEC = int(os.environ.get("ATLASBOT_THINKING_INTERVAL_SEC", "120")) +QUICK_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_QUICK_TIME_BUDGET_SEC", "15")) +SMART_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_SMART_TIME_BUDGET_SEC", "45")) +GENIUS_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_GENIUS_TIME_BUDGET_SEC", "180")) OLLAMA_RETRIES = int(os.environ.get("ATLASBOT_OLLAMA_RETRIES", "2")) OLLAMA_SERIALIZE = os.environ.get("ATLASBOT_OLLAMA_SERIALIZE", "true").lower() != "false" @@ -384,16 +390,20 @@ def _strip_bot_mention(text: str) -> str: return cleaned or text.strip() -def _detect_mode_from_body(body: str, *, default: str = "deep") -> str: +def _detect_mode_from_body(body: str, *, default: str = "smart") -> str: lower = normalize_query(body or "") if "atlas_quick" in lower or "atlas-quick" in lower: return "fast" if "atlas_smart" in lower or "atlas-smart" in lower: - return "deep" + return "smart" + if "atlas_genius" in lower or "atlas-genius" in lower: + return "genius" if lower.startswith("quick ") or lower.startswith("fast "): return "fast" - if lower.startswith("smart ") or lower.startswith("deep "): - return "deep" + if lower.startswith("smart "): + return "smart" + if lower.startswith("genius ") or lower.startswith("deep "): + return "genius" return default @@ -401,7 +411,7 @@ def _detect_mode( content: dict[str, Any], body: str, *, - default: str = "deep", + default: str = "smart", account_user: str = "", ) -> str: mode = _detect_mode_from_body(body, default=default) @@ -412,24 +422,72 @@ def _detect_mode( if BOT_USER_QUICK and normalize_user_id(BOT_USER_QUICK).lower() in normalized: return "fast" if BOT_USER_SMART and normalize_user_id(BOT_USER_SMART).lower() in normalized: - return "deep" + return "smart" + if BOT_USER_GENIUS and normalize_user_id(BOT_USER_GENIUS).lower() in normalized: + return "genius" if BOT_USER and normalize_user_id(BOT_USER).lower() in normalized: - return "deep" + return "smart" if account_user and BOT_USER_QUICK and normalize_user_id(account_user) == normalize_user_id(BOT_USER_QUICK): return "fast" if account_user and BOT_USER_SMART and normalize_user_id(account_user) == normalize_user_id(BOT_USER_SMART): - return "deep" + return "smart" + if account_user and BOT_USER_GENIUS and normalize_user_id(account_user) == normalize_user_id(BOT_USER_GENIUS): + return "genius" return mode def _model_for_mode(mode: str) -> str: if mode == "fast" and MODEL_FAST: return MODEL_FAST - if mode == "deep" and MODEL_DEEP: - return MODEL_DEEP + if mode == "smart" and MODEL_SMART: + return MODEL_SMART + if mode == "genius" and MODEL_GENIUS: + return MODEL_GENIUS + if mode == "deep" and MODEL_SMART: + return MODEL_SMART return MODEL +def _normalize_mode(mode: str) -> str: + normalized = (mode or "").strip().lower() + if normalized in {"quick", "fast"}: + return "fast" + if normalized in {"smart"}: + return "smart" + if normalized in {"genius", "deep"}: + return "genius" + return "smart" + + +def _mode_time_budget_sec(mode: str) -> float: + normalized = _normalize_mode(mode) + if normalized == "fast": + return max(1.0, QUICK_TIME_BUDGET_SEC) + if normalized == "smart": + return max(1.0, SMART_TIME_BUDGET_SEC) + if normalized == "genius": + return max(1.0, GENIUS_TIME_BUDGET_SEC) + return max(1.0, SMART_TIME_BUDGET_SEC) + + +def _mode_ollama_timeout_sec(mode: str) -> float: + normalized = _normalize_mode(mode) + budget = _mode_time_budget_sec(normalized) + if normalized == "fast": + return max(6.0, min(budget - 2.0, OLLAMA_TIMEOUT_SEC)) + if normalized == "smart": + return max(12.0, min(budget - 5.0, OLLAMA_TIMEOUT_SEC)) + if normalized == "genius": + return max(20.0, min(budget - 10.0, OLLAMA_TIMEOUT_SEC)) + return max(12.0, min(budget - 5.0, OLLAMA_TIMEOUT_SEC)) + + +def _mode_heartbeat_sec(mode: str) -> int: + normalized = _normalize_mode(mode) + budget = _mode_time_budget_sec(normalized) + return max(5, min(THINKING_INTERVAL_SEC, int(max(5.0, budget / 3.0)))) + + # Matrix HTTP helper. def req(method: str, path: str, token: str | None = None, body=None, timeout=60, base: str | None = None): url = (base or BASE) + path @@ -3842,9 +3900,12 @@ def _open_ended_multi( def _open_ended_total_steps(mode: str) -> int: - if mode == "fast": + normalized = _normalize_mode(mode) + if normalized == "fast": return 2 - return 9 + if normalized == "smart": + return 3 + return 4 def _fast_fact_lines( @@ -4179,6 +4240,7 @@ def _open_ended_fast_single( use_history=False, system_override=_open_ended_system(), model=model, + timeout=_mode_ollama_timeout_sec("fast"), ) if not _has_body_lines(reply): reply = _ollama_call( @@ -4188,6 +4250,7 @@ def _open_ended_fast_single( use_history=False, system_override=_open_ended_system(), model=model, + timeout=_mode_ollama_timeout_sec("fast"), ) fallback = _fallback_fact_answer(prompt, context) if fallback and (_is_quantitative_prompt(prompt) or not _has_body_lines(reply)): @@ -4248,16 +4311,53 @@ def _open_ended_deep( fact_lines: list[str], fact_meta: dict[str, dict[str, Any]], history_lines: list[str], + mode: str, state: ThoughtState | None = None, ) -> str: - return _open_ended_multi( - prompt, - fact_pack=fact_pack, - fact_lines=fact_lines, - fact_meta=fact_meta, - history_lines=history_lines, - state=state, + normalized = _normalize_mode(mode) + model = _model_for_mode(normalized) + subjective = _is_subjective_query(prompt) + primary_tags = _primary_tags_for_prompt(prompt) + focus_tags = _preferred_tags_for_prompt(prompt) + if not focus_tags and subjective: + focus_tags = set(_ALLOWED_INSIGHT_TAGS) + avoid_tags = _history_focus_tags(history_lines) if (subjective or _is_followup_query(prompt)) else set() + limit = 12 if normalized == "smart" else 18 + selected_lines = _fast_fact_lines( + fact_lines, + fact_meta, + focus_tags=focus_tags, + avoid_tags=avoid_tags, + primary_tags=primary_tags, + limit=limit, ) + selected_meta = _fact_pack_meta(selected_lines) + selected_pack = _fact_pack_text(selected_lines, selected_meta) + if _needs_full_fact_pack(prompt) or not selected_lines or normalized == "genius": + selected_pack = fact_pack + fallback = _fallback_fact_answer(prompt, selected_pack) + if not subjective and fallback: + if state: + state.update("done", step=_open_ended_total_steps(normalized)) + return _ensure_scores(fallback) + if state: + state.update("drafting", step=1, note="synthesizing") + reply = _ollama_call( + ("atlasbot_deep", "atlasbot_deep"), + prompt, + context=_append_history_context(selected_pack, history_lines), + use_history=False, + system_override=_open_ended_system(), + model=model, + timeout=_mode_ollama_timeout_sec(normalized), + ) + if fallback and (_is_quantitative_prompt(prompt) or not _has_body_lines(reply)): + reply = fallback + if not _has_body_lines(reply): + reply = "I don't have enough data in the current snapshot to answer that." + if state: + state.update("done", step=_open_ended_total_steps(normalized)) + return _ensure_scores(reply) def open_ended_answer( @@ -4285,7 +4385,8 @@ def open_ended_answer( return _ensure_scores("I don't have enough data to answer that.") fact_meta = _fact_pack_meta(lines) fact_pack = _fact_pack_text(lines, fact_meta) - if mode == "fast": + normalized = _normalize_mode(mode) + if normalized == "fast": return _open_ended_fast( prompt, fact_pack=fact_pack, @@ -4300,6 +4401,7 @@ def open_ended_answer( fact_lines=lines, fact_meta=fact_meta, history_lines=history_lines, + mode=normalized, state=state, ) @@ -4321,6 +4423,7 @@ def _non_cluster_reply(prompt: str, *, history_lines: list[str], mode: str) -> s use_history=False, system_override=system, model=model, + timeout=_mode_ollama_timeout_sec(mode), ) reply = re.sub(r"\bconfidence\s*:\s*(high|medium|low)\b\.?\s*", "", reply, flags=re.IGNORECASE).strip() return _ensure_scores(reply) @@ -4372,13 +4475,7 @@ class _AtlasbotHandler(BaseHTTPRequestHandler): self._write_json(400, {"error": "missing_prompt"}) return cleaned = _strip_bot_mention(prompt) - mode = str(payload.get("mode") or "deep").lower() - if mode in ("quick", "fast"): - mode = "fast" - elif mode in ("smart", "deep"): - mode = "deep" - else: - mode = "deep" + mode = _normalize_mode(str(payload.get("mode") or "smart")) snapshot = _snapshot_state() inventory = _snapshot_inventory(snapshot) or node_inventory_live() workloads = _snapshot_workloads(snapshot) @@ -4669,6 +4766,7 @@ def _ollama_call( use_history: bool = True, system_override: str | None = None, model: str | None = None, + timeout: float | None = None, ) -> str: system = system_override or ( "System: You are Atlas, the Titan lab assistant for Atlas/Othrys. " @@ -4702,6 +4800,7 @@ def _ollama_call( messages.append({"role": "user", "content": prompt}) model_name = model or MODEL + request_timeout = timeout if timeout is not None else OLLAMA_TIMEOUT_SEC payload = {"model": model_name, "messages": messages, "stream": False} headers = {"Content-Type": "application/json"} if API_KEY: @@ -4712,13 +4811,13 @@ def _ollama_call( lock.acquire() try: try: - with request.urlopen(r, timeout=OLLAMA_TIMEOUT_SEC) as resp: + with request.urlopen(r, timeout=request_timeout) as resp: data = json.loads(resp.read().decode()) except error.HTTPError as exc: if exc.code == 404 and FALLBACK_MODEL and FALLBACK_MODEL != payload["model"]: payload["model"] = FALLBACK_MODEL r = request.Request(endpoint, data=json.dumps(payload).encode(), headers=headers) - with request.urlopen(r, timeout=OLLAMA_TIMEOUT_SEC) as resp: + with request.urlopen(r, timeout=request_timeout) as resp: data = json.loads(resp.read().decode()) else: raise @@ -4743,6 +4842,7 @@ def ollama_reply( fallback: str = "", use_history: bool = True, model: str | None = None, + timeout: float | None = None, ) -> str: last_error = None for attempt in range(max(1, OLLAMA_RETRIES + 1)): @@ -4753,6 +4853,7 @@ def ollama_reply( context=context, use_history=use_history, model=model, + timeout=timeout, ) except Exception as exc: # noqa: BLE001 last_error = exc @@ -4773,6 +4874,7 @@ def ollama_reply_with_thinking( fallback: str, use_history: bool = True, model: str | None = None, + timeout: float | None = None, ) -> str: result: dict[str, str] = {"reply": ""} done = threading.Event() @@ -4785,6 +4887,7 @@ def ollama_reply_with_thinking( fallback=fallback, use_history=use_history, model=model, + timeout=timeout, ) done.set() @@ -4841,7 +4944,7 @@ def open_ended_with_thinking( thread.start() if not done.wait(2.0): send_msg(token, room, "Thinking…") - heartbeat = max(10, THINKING_INTERVAL_SEC) + heartbeat = _mode_heartbeat_sec(mode) next_heartbeat = time.monotonic() + heartbeat while not done.wait(max(0, next_heartbeat - time.monotonic())): send_msg(token, room, state.status_line()) @@ -4906,7 +5009,7 @@ def sync_loop(token: str, room_id: str, *, account_user: str, default_mode: str) mode = _detect_mode( content, body, - default=default_mode if default_mode in ("fast", "deep") else "deep", + default=_normalize_mode(default_mode), account_user=account_user, ) @@ -4972,14 +5075,14 @@ def sync_loop(token: str, room_id: str, *, account_user: str, default_mode: str) snapshot=snapshot, workloads=workloads, history_lines=history[hist_key], - mode=mode if mode in ("fast", "deep") else "deep", + mode=_normalize_mode(mode), allow_tools=allow_tools, ) else: reply = _non_cluster_reply( cleaned_body, history_lines=history[hist_key], - mode=mode if mode in ("fast", "deep") else "deep", + mode=_normalize_mode(mode), ) send_msg(token, rid, reply) history[hist_key].append(f"Atlas: {reply}") @@ -5003,11 +5106,13 @@ def _bot_accounts() -> list[dict[str, str]]: return accounts.append({"user": user, "password": password, "mode": mode}) - add(BOT_USER_SMART or BOT_USER, BOT_PASS_SMART or BOT_PASS, "deep") + add(BOT_USER_SMART or BOT_USER, BOT_PASS_SMART or BOT_PASS, "smart") if BOT_USER_QUICK and BOT_PASS_QUICK: add(BOT_USER_QUICK, BOT_PASS_QUICK, "fast") + if BOT_USER_GENIUS and BOT_PASS_GENIUS: + add(BOT_USER_GENIUS, BOT_PASS_GENIUS, "genius") if BOT_USER and BOT_PASS and all(acc["user"] != BOT_USER for acc in accounts): - add(BOT_USER, BOT_PASS, "deep") + add(BOT_USER, BOT_PASS, "smart") seen: set[str] = set() unique: list[dict[str, str]] = [] diff --git a/services/comms/scripts/tests/test_atlasbot_modes.py b/services/comms/scripts/tests/test_atlasbot_modes.py new file mode 100644 index 00000000..9a1a8037 --- /dev/null +++ b/services/comms/scripts/tests/test_atlasbot_modes.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import importlib.util +import os +from pathlib import Path +from unittest import TestCase, mock + + +BOT_PATH = Path(__file__).resolve().parents[1] / "atlasbot" / "bot.py" + + +def load_bot_module(): + env = { + "BOT_USER": "atlas-smart", + "BOT_PASS": "smart-pass", + "BOT_USER_QUICK": "atlas-quick", + "BOT_PASS_QUICK": "quick-pass", + "BOT_USER_SMART": "atlas-smart", + "BOT_PASS_SMART": "smart-pass", + "BOT_USER_GENIUS": "atlas-genius", + "BOT_PASS_GENIUS": "genius-pass", + "OLLAMA_URL": "http://ollama.invalid", + "OLLAMA_MODEL": "base-model", + "ATLASBOT_MODEL_FAST": "fast-model", + "ATLASBOT_MODEL_SMART": "smart-model", + "ATLASBOT_MODEL_GENIUS": "genius-model", + "ATLASBOT_QUICK_TIME_BUDGET_SEC": "15", + "ATLASBOT_SMART_TIME_BUDGET_SEC": "45", + "ATLASBOT_GENIUS_TIME_BUDGET_SEC": "180", + "KB_DIR": "", + "VM_URL": "http://vm.invalid", + "ARIADNE_STATE_URL": "", + "ARIADNE_STATE_TOKEN": "", + } + with mock.patch.dict(os.environ, env, clear=False): + spec = importlib.util.spec_from_file_location("atlasbot_bot", BOT_PATH) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +class AtlasbotModeTests(TestCase): + def setUp(self): + self.bot = load_bot_module() + + def test_bot_accounts_include_genius_mode(self): + accounts = self.bot._bot_accounts() + by_user = {account["user"]: account["mode"] for account in accounts} + + self.assertEqual(by_user["atlas-quick"], "fast") + self.assertEqual(by_user["atlas-smart"], "smart") + self.assertEqual(by_user["atlas-genius"], "genius") + + def test_objective_cluster_question_uses_fact_pack_without_llm(self): + fact_lines = [ + "hottest_cpu: longhorn-system (6.69)", + "hottest_ram: longhorn-system (36.05 GB)", + ] + + with ( + mock.patch.object(self.bot, "_fact_pack_lines", return_value=fact_lines), + mock.patch.object(self.bot, "_ollama_call", side_effect=AssertionError("LLM should not be called")), + ): + reply = self.bot.open_ended_answer( + "what is the hottest cpu node in titan lab currently?", + inventory=[], + snapshot=None, + workloads=[], + history_lines=[], + mode="smart", + allow_tools=True, + ) + + self.assertIn("longhorn-system", reply) + self.assertIn("Confidence:", reply) + + def test_subjective_genius_answer_uses_genius_model(self): + fact_lines = [ + "hottest_cpu: longhorn-system (6.69)", + "worker_nodes: titan-01, titan-02, titan-03", + ] + captured: dict[str, object] = {} + + def fake_ollama_call(hist_key, prompt, *, context, use_history=True, system_override=None, model=None, timeout=None): + captured["model"] = model + captured["timeout"] = timeout + captured["context"] = context + return "The worker spread stands out because Titan keeps meaningful capacity on the same cluster. Confidence: high" + + with ( + mock.patch.object(self.bot, "_fact_pack_lines", return_value=fact_lines), + mock.patch.object(self.bot, "_ollama_call", side_effect=fake_ollama_call), + ): + reply = self.bot.open_ended_answer( + "what stands out about titan lab?", + inventory=[], + snapshot=None, + workloads=[], + history_lines=[], + mode="genius", + allow_tools=True, + ) + + self.assertIn("The worker spread stands out", reply) + self.assertEqual(captured["model"], "genius-model") + self.assertLessEqual(float(captured["timeout"]), 180.0)