From 0d8a2c5531b27362df1c2d549125999138e2ff78 Mon Sep 17 00:00:00 2001
From: Brad Stein
Date: Mon, 30 Mar 2026 16:51:23 -0300
Subject: [PATCH] atlasbot: wire quick/smart/genius modes

---
 services/comms/scripts/atlasbot/bot.py        | 273 +++++++++++++++---
 .../scripts/tests/test_atlasbot_modes.py      | 107 +++++++
 2 files changed, 333 insertions(+), 47 deletions(-)
 create mode 100644 services/comms/scripts/tests/test_atlasbot_modes.py

diff --git a/services/comms/scripts/atlasbot/bot.py b/services/comms/scripts/atlasbot/bot.py
index be256c0e..c2311a1f 100644
--- a/services/comms/scripts/atlasbot/bot.py
+++ b/services/comms/scripts/atlasbot/bot.py
@@ -11,14 +11,23 @@ from urllib import error, parse, request
 BASE = os.environ.get("MATRIX_BASE", "http://othrys-synapse-matrix-synapse:8008")
 AUTH_BASE = os.environ.get("AUTH_BASE", "http://matrix-authentication-service:8080")
-USER = os.environ["BOT_USER"]
-PASSWORD = os.environ["BOT_PASS"]
+BOT_USER = os.environ["BOT_USER"]
+BOT_PASS = os.environ["BOT_PASS"]
+BOT_USER_QUICK = os.environ.get("BOT_USER_QUICK", "").strip()
+BOT_PASS_QUICK = os.environ.get("BOT_PASS_QUICK", "").strip()
+BOT_USER_SMART = os.environ.get("BOT_USER_SMART", "").strip()
+BOT_PASS_SMART = os.environ.get("BOT_PASS_SMART", "").strip()
+BOT_USER_GENIUS = os.environ.get("BOT_USER_GENIUS", "").strip()
+BOT_PASS_GENIUS = os.environ.get("BOT_PASS_GENIUS", "").strip()
+USER = BOT_USER
+PASSWORD = BOT_PASS
 ROOM_ALIAS = "#othrys:live.bstein.dev"
 OLLAMA_URL = os.environ.get("OLLAMA_URL", "https://chat.ai.bstein.dev/")
 MODEL = os.environ.get("OLLAMA_MODEL", "qwen2.5:14b-instruct")
 MODEL_FAST = os.environ.get("ATLASBOT_MODEL_FAST", "")
-MODEL_DEEP = os.environ.get("ATLASBOT_MODEL_DEEP", "")
+MODEL_SMART = os.environ.get("ATLASBOT_MODEL_SMART", os.environ.get("ATLASBOT_MODEL_DEEP", "")).strip()
+MODEL_GENIUS = os.environ.get("ATLASBOT_MODEL_GENIUS", MODEL_SMART).strip()
 FALLBACK_MODEL = os.environ.get("OLLAMA_FALLBACK_MODEL", "")
 API_KEY = os.environ.get("CHAT_API_KEY", "")
 OLLAMA_TIMEOUT_SEC = float(os.environ.get("OLLAMA_TIMEOUT_SEC", "480"))
@@ -31,7 +40,7 @@ VM_URL = os.environ.get("VM_URL", "http://victoria-metrics-single-server.monitor
 ARIADNE_STATE_URL = os.environ.get("ARIADNE_STATE_URL", "")
 ARIADNE_STATE_TOKEN = os.environ.get("ARIADNE_STATE_TOKEN", "")
 
-BOT_MENTIONS = os.environ.get("BOT_MENTIONS", f"{USER},atlas")
+BOT_MENTIONS = os.environ.get("BOT_MENTIONS", f"{BOT_USER},atlas")
 SERVER_NAME = os.environ.get("MATRIX_SERVER_NAME", "live.bstein.dev")
 
 MAX_KB_CHARS = int(os.environ.get("ATLASBOT_MAX_KB_CHARS", "2500"))
@@ -39,6 +48,9 @@ MAX_TOOL_CHARS = int(os.environ.get("ATLASBOT_MAX_TOOL_CHARS", "2500"))
 MAX_FACTS_CHARS = int(os.environ.get("ATLASBOT_MAX_FACTS_CHARS", "8000"))
 MAX_CONTEXT_CHARS = int(os.environ.get("ATLASBOT_MAX_CONTEXT_CHARS", "12000"))
 THINKING_INTERVAL_SEC = int(os.environ.get("ATLASBOT_THINKING_INTERVAL_SEC", "120"))
+QUICK_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_QUICK_TIME_BUDGET_SEC", "15"))
+SMART_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_SMART_TIME_BUDGET_SEC", "45"))
+GENIUS_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_GENIUS_TIME_BUDGET_SEC", "180"))
 OLLAMA_RETRIES = int(os.environ.get("ATLASBOT_OLLAMA_RETRIES", "2"))
 OLLAMA_SERIALIZE = os.environ.get("ATLASBOT_OLLAMA_SERIALIZE", "true").lower() != "false"
 
@@ -380,27 +392,103 @@ def _strip_bot_mention(text: str) -> str:
     return cleaned or text.strip()
 
 
-def _detect_mode_from_body(body: str, *, default: str = "deep") -> str:
+def _detect_mode_from_body(body: str, *, default: str = "smart") -> str:
     lower = normalize_query(body or "")
     if "atlas_quick" in lower or "atlas-quick" in lower:
         return "fast"
     if "atlas_smart" in lower or "atlas-smart" in lower:
-        return "deep"
+        return "smart"
+    if "atlas_genius" in lower or "atlas-genius" in lower:
+        return "genius"
     if lower.startswith("quick ") or lower.startswith("fast "):
         return "fast"
-    if lower.startswith("smart ") or lower.startswith("deep "):
-        return "deep"
+    if lower.startswith("smart "):
+        return "smart"
+    if lower.startswith("genius ") or lower.startswith("deep "):
+        return "genius"
     return default
 
 
+def _detect_mode(
+    content: dict[str, Any],
+    body: str,
+    *,
+    default: str = "smart",
+    account_user: str = "",
+) -> str:
+    mode = _detect_mode_from_body(body, default=default)
+    mentions = content.get("m.mentions", {})
+    user_ids = mentions.get("user_ids", [])
+    if isinstance(user_ids, list):
+        normalized = {normalize_user_id(uid).lower() for uid in user_ids if isinstance(uid, str)}
+        if BOT_USER_QUICK and normalize_user_id(BOT_USER_QUICK).lower() in normalized:
+            return "fast"
+        if BOT_USER_SMART and normalize_user_id(BOT_USER_SMART).lower() in normalized:
+            return "smart"
+        if BOT_USER_GENIUS and normalize_user_id(BOT_USER_GENIUS).lower() in normalized:
+            return "genius"
+        if BOT_USER and normalize_user_id(BOT_USER).lower() in normalized:
+            return "smart"
+    if account_user and BOT_USER_QUICK and normalize_user_id(account_user) == normalize_user_id(BOT_USER_QUICK):
+        return "fast"
+    if account_user and BOT_USER_SMART and normalize_user_id(account_user) == normalize_user_id(BOT_USER_SMART):
+        return "smart"
+    if account_user and BOT_USER_GENIUS and normalize_user_id(account_user) == normalize_user_id(BOT_USER_GENIUS):
+        return "genius"
+    return mode
+
+
 def _model_for_mode(mode: str) -> str:
     if mode == "fast" and MODEL_FAST:
         return MODEL_FAST
-    if mode == "deep" and MODEL_DEEP:
-        return MODEL_DEEP
+    if mode == "smart" and MODEL_SMART:
+        return MODEL_SMART
+    if mode == "genius" and MODEL_GENIUS:
+        return MODEL_GENIUS
+    if mode == "deep" and MODEL_SMART:
+        return MODEL_SMART
     return MODEL
 
 
+def _normalize_mode(mode: str) -> str:
+    normalized = (mode or "").strip().lower()
+    if normalized in {"quick", "fast"}:
+        return "fast"
+    if normalized in {"smart"}:
+        return "smart"
+    if normalized in {"genius", "deep"}:
+        return "genius"
+    return "smart"
+
+
+def _mode_time_budget_sec(mode: str) -> float:
+    normalized = _normalize_mode(mode)
+    if normalized == "fast":
+        return max(1.0, QUICK_TIME_BUDGET_SEC)
+    if normalized == "smart":
+        return max(1.0, SMART_TIME_BUDGET_SEC)
+    if normalized == "genius":
+        return max(1.0, GENIUS_TIME_BUDGET_SEC)
+    return max(1.0, SMART_TIME_BUDGET_SEC)
+
+
+def _mode_ollama_timeout_sec(mode: str) -> float:
+    normalized = _normalize_mode(mode)
+    budget = _mode_time_budget_sec(normalized)
+    if normalized == "fast":
+        return max(6.0, min(budget - 2.0, OLLAMA_TIMEOUT_SEC))
+    if normalized == "smart":
+        return max(12.0, min(budget - 5.0, OLLAMA_TIMEOUT_SEC))
+    if normalized == "genius":
+        return max(20.0, min(budget - 10.0, OLLAMA_TIMEOUT_SEC))
+    return max(12.0, min(budget - 5.0, OLLAMA_TIMEOUT_SEC))
+
+
+def _mode_heartbeat_sec(mode: str) -> int:
+    normalized = _normalize_mode(mode)
+    budget = _mode_time_budget_sec(normalized)
+    return max(5, min(THINKING_INTERVAL_SEC, int(max(5.0, budget / 3.0))))
+
+
 # Matrix HTTP helper.
def req(method: str, path: str, token: str | None = None, body=None, timeout=60, base: str | None = None): url = (base or BASE) + path @@ -416,12 +504,12 @@ def req(method: str, path: str, token: str | None = None, body=None, timeout=60, raw = resp.read() return json.loads(raw.decode()) if raw else {} -def login() -> str: - login_user = normalize_user_id(USER) +def login(user: str, password: str) -> str: + login_user = normalize_user_id(user) payload = { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": login_user}, - "password": PASSWORD, + "password": password, } res = req("POST", "/_matrix/client/v3/login", body=payload, base=AUTH_BASE) return res["access_token"] @@ -3813,9 +3901,12 @@ def _open_ended_multi( def _open_ended_total_steps(mode: str) -> int: - if mode == "fast": + normalized = _normalize_mode(mode) + if normalized == "fast": return 2 - return 9 + if normalized == "smart": + return 3 + return 4 def _fast_fact_lines( @@ -4150,6 +4241,7 @@ def _open_ended_fast_single( use_history=False, system_override=_open_ended_system(), model=model, + timeout=_mode_ollama_timeout_sec("fast"), ) if not _has_body_lines(reply): reply = _ollama_call( @@ -4159,6 +4251,7 @@ def _open_ended_fast_single( use_history=False, system_override=_open_ended_system(), model=model, + timeout=_mode_ollama_timeout_sec("fast"), ) fallback = _fallback_fact_answer(prompt, context) if fallback and (_is_quantitative_prompt(prompt) or not _has_body_lines(reply)): @@ -4219,16 +4312,53 @@ def _open_ended_deep( fact_lines: list[str], fact_meta: dict[str, dict[str, Any]], history_lines: list[str], + mode: str, state: ThoughtState | None = None, ) -> str: - return _open_ended_multi( - prompt, - fact_pack=fact_pack, - fact_lines=fact_lines, - fact_meta=fact_meta, - history_lines=history_lines, - state=state, + normalized = _normalize_mode(mode) + model = _model_for_mode(normalized) + subjective = _is_subjective_query(prompt) + primary_tags = _primary_tags_for_prompt(prompt) + focus_tags = _preferred_tags_for_prompt(prompt) + if not focus_tags and subjective: + focus_tags = set(_ALLOWED_INSIGHT_TAGS) + avoid_tags = _history_focus_tags(history_lines) if (subjective or _is_followup_query(prompt)) else set() + limit = 12 if normalized == "smart" else 18 + selected_lines = _fast_fact_lines( + fact_lines, + fact_meta, + focus_tags=focus_tags, + avoid_tags=avoid_tags, + primary_tags=primary_tags, + limit=limit, ) + selected_meta = _fact_pack_meta(selected_lines) + selected_pack = _fact_pack_text(selected_lines, selected_meta) + if _needs_full_fact_pack(prompt) or not selected_lines or normalized == "genius": + selected_pack = fact_pack + fallback = _fallback_fact_answer(prompt, selected_pack) + if not subjective and fallback: + if state: + state.update("done", step=_open_ended_total_steps(normalized)) + return _ensure_scores(fallback) + if state: + state.update("drafting", step=1, note="synthesizing") + reply = _ollama_call( + ("atlasbot_deep", "atlasbot_deep"), + prompt, + context=_append_history_context(selected_pack, history_lines), + use_history=False, + system_override=_open_ended_system(), + model=model, + timeout=_mode_ollama_timeout_sec(normalized), + ) + if fallback and (_is_quantitative_prompt(prompt) or not _has_body_lines(reply)): + reply = fallback + if not _has_body_lines(reply): + reply = "I don't have enough data in the current snapshot to answer that." 
+    if state:
+        state.update("done", step=_open_ended_total_steps(normalized))
+    return _ensure_scores(reply)
 
 
 def open_ended_answer(
@@ -4256,7 +4386,8 @@
         return _ensure_scores("I don't have enough data to answer that.")
     fact_meta = _fact_pack_meta(lines)
     fact_pack = _fact_pack_text(lines, fact_meta)
-    if mode == "fast":
+    normalized = _normalize_mode(mode)
+    if normalized == "fast":
         return _open_ended_fast(
             prompt,
             fact_pack=fact_pack,
@@ -4271,6 +4402,7 @@
         fact_lines=lines,
         fact_meta=fact_meta,
         history_lines=history_lines,
+        mode=normalized,
         state=state,
     )
 
@@ -4292,6 +4424,7 @@ def _non_cluster_reply(prompt: str, *, history_lines: list[str], mode: str) -> s
         use_history=False,
         system_override=system,
         model=model,
+        timeout=_mode_ollama_timeout_sec(mode),
     )
     reply = re.sub(r"\bconfidence\s*:\s*(high|medium|low)\b\.?\s*", "", reply, flags=re.IGNORECASE).strip()
     return _ensure_scores(reply)
@@ -4343,13 +4476,7 @@ class _AtlasbotHandler(BaseHTTPRequestHandler):
             self._write_json(400, {"error": "missing_prompt"})
             return
         cleaned = _strip_bot_mention(prompt)
-        mode = str(payload.get("mode") or "deep").lower()
-        if mode in ("quick", "fast"):
-            mode = "fast"
-        elif mode in ("smart", "deep"):
-            mode = "deep"
-        else:
-            mode = "deep"
+        mode = _normalize_mode(str(payload.get("mode") or "smart"))
         snapshot = _snapshot_state()
         inventory = _snapshot_inventory(snapshot) or node_inventory_live()
         workloads = _snapshot_workloads(snapshot)
@@ -4640,6 +4767,7 @@ def _ollama_call(
     use_history: bool = True,
     system_override: str | None = None,
     model: str | None = None,
+    timeout: float | None = None,
 ) -> str:
     system = system_override or (
         "System: You are Atlas, the Titan lab assistant for Atlas/Othrys. "
@@ -4673,6 +4801,7 @@
     messages.append({"role": "user", "content": prompt})
 
     model_name = model or MODEL
+    request_timeout = timeout if timeout is not None else OLLAMA_TIMEOUT_SEC
     payload = {"model": model_name, "messages": messages, "stream": False}
     headers = {"Content-Type": "application/json"}
     if API_KEY:
@@ -4683,13 +4812,13 @@
     lock.acquire()
     try:
         try:
-            with request.urlopen(r, timeout=OLLAMA_TIMEOUT_SEC) as resp:
+            with request.urlopen(r, timeout=request_timeout) as resp:
                 data = json.loads(resp.read().decode())
         except error.HTTPError as exc:
             if exc.code == 404 and FALLBACK_MODEL and FALLBACK_MODEL != payload["model"]:
                 payload["model"] = FALLBACK_MODEL
                 r = request.Request(endpoint, data=json.dumps(payload).encode(), headers=headers)
-                with request.urlopen(r, timeout=OLLAMA_TIMEOUT_SEC) as resp:
+                with request.urlopen(r, timeout=request_timeout) as resp:
                     data = json.loads(resp.read().decode())
             else:
                 raise
@@ -4714,6 +4843,7 @@ def ollama_reply(
     fallback: str = "",
     use_history: bool = True,
     model: str | None = None,
+    timeout: float | None = None,
 ) -> str:
     last_error = None
     for attempt in range(max(1, OLLAMA_RETRIES + 1)):
@@ -4724,6 +4854,7 @@
                 context=context,
                 use_history=use_history,
                 model=model,
+                timeout=timeout,
             )
         except Exception as exc:  # noqa: BLE001
             last_error = exc
@@ -4744,6 +4875,7 @@
     fallback: str,
     use_history: bool = True,
     model: str | None = None,
+    timeout: float | None = None,
 ) -> str:
     result: dict[str, str] = {"reply": ""}
     done = threading.Event()
@@ -4756,6 +4888,7 @@
             fallback=fallback,
             use_history=use_history,
             model=model,
+            timeout=timeout,
         )
         done.set()
 
@@ -4812,7 +4945,7 @@ def open_ended_with_thinking(
     thread.start()
     if not done.wait(2.0):
         send_msg(token, room, "Thinking…")
-    heartbeat = max(10, THINKING_INTERVAL_SEC)
+    heartbeat = _mode_heartbeat_sec(mode)
     next_heartbeat = time.monotonic() + heartbeat
     while not done.wait(max(0, next_heartbeat - time.monotonic())):
         send_msg(token, room, state.status_line())
     thread.join(timeout=1)
     return result["reply"] or "Model backend is busy. Try again in a moment."
 
-def sync_loop(token: str, room_id: str):
+def sync_loop(token: str, room_id: str, *, account_user: str, default_mode: str):
     since = None
     try:
         res = req("GET", "/_matrix/client/v3/sync?timeout=0", token, timeout=10)
@@ -4861,7 +4994,7 @@
             if not body:
                 continue
             sender = ev.get("sender", "")
-            if sender == f"@{USER}:live.bstein.dev":
+            if account_user and sender == normalize_user_id(account_user):
                 continue
 
             mentioned = is_mentioned(content, body)
@@ -4874,7 +5007,12 @@
             cleaned_body = _strip_bot_mention(body)
             lower_body = cleaned_body.lower()
-            mode = _detect_mode_from_body(body, default="deep" if is_dm else "deep")
+            mode = _detect_mode(
+                content,
+                body,
+                default=_normalize_mode(default_mode),
+                account_user=account_user,
+            )
 
             # Only do live cluster introspection in DMs.
             allow_tools = is_dm
@@ -4938,39 +5076,80 @@
                     snapshot=snapshot,
                     workloads=workloads,
                     history_lines=history[hist_key],
-                    mode=mode if mode in ("fast", "deep") else "deep",
+                    mode=_normalize_mode(mode),
                     allow_tools=allow_tools,
                 )
             else:
                 reply = _non_cluster_reply(
                     cleaned_body,
                     history_lines=history[hist_key],
-                    mode=mode if mode in ("fast", "deep") else "deep",
+                    mode=_normalize_mode(mode),
                 )
             send_msg(token, rid, reply)
             history[hist_key].append(f"Atlas: {reply}")
             history[hist_key] = history[hist_key][-80:]
 
 
-def login_with_retry():
+def login_with_retry(user: str, password: str):
     last_err = None
     for attempt in range(10):
         try:
-            return login()
+            return login(user, password)
         except Exception as exc:  # noqa: BLE001
            last_err = exc
        time.sleep(min(30, 2 ** attempt))
    raise last_err
 
 
+def _bot_accounts() -> list[dict[str, str]]:
+    accounts: list[dict[str, str]] = []
+
+    def add(user: str, password: str, mode: str):
+        if not user or not password:
+            return
+        accounts.append({"user": user, "password": password, "mode": mode})
+
+    add(BOT_USER_SMART or BOT_USER, BOT_PASS_SMART or BOT_PASS, "smart")
+    if BOT_USER_QUICK and BOT_PASS_QUICK:
+        add(BOT_USER_QUICK, BOT_PASS_QUICK, "fast")
+    if BOT_USER_GENIUS and BOT_PASS_GENIUS:
+        add(BOT_USER_GENIUS, BOT_PASS_GENIUS, "genius")
+    if BOT_USER and BOT_PASS and all(acc["user"] != BOT_USER for acc in accounts):
+        add(BOT_USER, BOT_PASS, "smart")
+
+    seen: set[str] = set()
+    unique: list[dict[str, str]] = []
+    for acc in accounts:
+        uid = normalize_user_id(acc["user"]).lower()
+        if uid in seen:
+            continue
+        seen.add(uid)
+        unique.append(acc)
+    return unique
+
+
 def main():
     load_kb()
     _start_http_server()
-    token = login_with_retry()
-    try:
-        room_id = resolve_alias(token, ROOM_ALIAS)
-        join_room(token, room_id)
-    except Exception:
-        room_id = None
-    sync_loop(token, room_id)
+    accounts = _bot_accounts()
+    threads: list[threading.Thread] = []
+    for acc in accounts:
+        token = login_with_retry(acc["user"], acc["password"])
+        try:
+            room_id = resolve_alias(token, ROOM_ALIAS)
+            join_room(token, room_id)
+        except Exception:
+            room_id = None
+        thread = threading.Thread(
+            target=sync_loop,
+            args=(token, room_id),
+            kwargs={
+                "account_user": acc["user"],
+                "default_mode": acc["mode"],
+            },
+            daemon=True,
+        )
+        thread.start()
+        threads.append(thread)
+    for thread in threads:
+        thread.join()
 
 
 if __name__ == "__main__":
     main()
diff --git a/services/comms/scripts/tests/test_atlasbot_modes.py b/services/comms/scripts/tests/test_atlasbot_modes.py
new file mode 100644
index 00000000..9a1a8037
--- /dev/null
+++ b/services/comms/scripts/tests/test_atlasbot_modes.py
@@ -0,0 +1,107 @@
+from __future__ import annotations
+
+import importlib.util
+import os
+from pathlib import Path
+from unittest import TestCase, mock
+
+
+BOT_PATH = Path(__file__).resolve().parents[1] / "atlasbot" / "bot.py"
+
+
+def load_bot_module():
+    env = {
+        "BOT_USER": "atlas-smart",
+        "BOT_PASS": "smart-pass",
+        "BOT_USER_QUICK": "atlas-quick",
+        "BOT_PASS_QUICK": "quick-pass",
+        "BOT_USER_SMART": "atlas-smart",
+        "BOT_PASS_SMART": "smart-pass",
+        "BOT_USER_GENIUS": "atlas-genius",
+        "BOT_PASS_GENIUS": "genius-pass",
+        "OLLAMA_URL": "http://ollama.invalid",
+        "OLLAMA_MODEL": "base-model",
+        "ATLASBOT_MODEL_FAST": "fast-model",
+        "ATLASBOT_MODEL_SMART": "smart-model",
+        "ATLASBOT_MODEL_GENIUS": "genius-model",
+        "ATLASBOT_QUICK_TIME_BUDGET_SEC": "15",
+        "ATLASBOT_SMART_TIME_BUDGET_SEC": "45",
+        "ATLASBOT_GENIUS_TIME_BUDGET_SEC": "180",
+        "KB_DIR": "",
+        "VM_URL": "http://vm.invalid",
+        "ARIADNE_STATE_URL": "",
+        "ARIADNE_STATE_TOKEN": "",
+    }
+    with mock.patch.dict(os.environ, env, clear=False):
+        spec = importlib.util.spec_from_file_location("atlasbot_bot", BOT_PATH)
+        module = importlib.util.module_from_spec(spec)
+        assert spec.loader is not None
+        spec.loader.exec_module(module)
+    return module
+
+
+class AtlasbotModeTests(TestCase):
+    def setUp(self):
+        self.bot = load_bot_module()
+
+    def test_bot_accounts_include_genius_mode(self):
+        accounts = self.bot._bot_accounts()
+        by_user = {account["user"]: account["mode"] for account in accounts}
+
+        self.assertEqual(by_user["atlas-quick"], "fast")
+        self.assertEqual(by_user["atlas-smart"], "smart")
+        self.assertEqual(by_user["atlas-genius"], "genius")
+
+    def test_objective_cluster_question_uses_fact_pack_without_llm(self):
+        fact_lines = [
+            "hottest_cpu: longhorn-system (6.69)",
+            "hottest_ram: longhorn-system (36.05 GB)",
+        ]
+
+        with (
+            mock.patch.object(self.bot, "_fact_pack_lines", return_value=fact_lines),
+            mock.patch.object(self.bot, "_ollama_call", side_effect=AssertionError("LLM should not be called")),
+        ):
+            reply = self.bot.open_ended_answer(
+                "what is the hottest cpu node in titan lab currently?",
+                inventory=[],
+                snapshot=None,
+                workloads=[],
+                history_lines=[],
+                mode="smart",
+                allow_tools=True,
+            )
+
+        self.assertIn("longhorn-system", reply)
+        self.assertIn("Confidence:", reply)
+
+    def test_subjective_genius_answer_uses_genius_model(self):
+        fact_lines = [
+            "hottest_cpu: longhorn-system (6.69)",
+            "worker_nodes: titan-01, titan-02, titan-03",
+        ]
+        captured: dict[str, object] = {}
+
+        def fake_ollama_call(hist_key, prompt, *, context, use_history=True, system_override=None, model=None, timeout=None):
+            captured["model"] = model
+            captured["timeout"] = timeout
+            captured["context"] = context
+            return "The worker spread stands out because Titan keeps meaningful capacity on the same cluster. Confidence: high"
+
+        with (
+            mock.patch.object(self.bot, "_fact_pack_lines", return_value=fact_lines),
+            mock.patch.object(self.bot, "_ollama_call", side_effect=fake_ollama_call),
+        ):
+            reply = self.bot.open_ended_answer(
+                "what stands out about titan lab?",
+                inventory=[],
+                snapshot=None,
+                workloads=[],
+                history_lines=[],
+                mode="genius",
+                allow_tools=True,
+            )
+
+        self.assertIn("The worker spread stands out", reply)
+        self.assertEqual(captured["model"], "genius-model")
+        self.assertLessEqual(float(captured["timeout"]), 180.0)
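
To sanity-check the new mode plumbing by hand, the patched module can be loaded the same way test_atlasbot_modes.py loads it and the helpers exercised directly. The sketch below is illustrative rather than part of the patch: the env values mirror the test fixture, the path assumes the repo root, and the printed numbers follow from the default budgets wired above (fast: 15s budget / 13s request timeout / 5s heartbeat; smart: 45/40/15; genius: 180/170/60).

# check_modes.py -- illustrative sketch, not part of this patch.
# Loads bot.py under the same env fixture the tests use, then pokes the
# new mode helpers. The values below come from the test fixture, not from
# any real deployment.
import importlib.util
import os
from pathlib import Path
from unittest import mock

BOT_PATH = Path("services/comms/scripts/atlasbot/bot.py")  # assumes repo root

ENV = {
    "BOT_USER": "atlas-smart", "BOT_PASS": "smart-pass",
    "BOT_USER_QUICK": "atlas-quick", "BOT_PASS_QUICK": "quick-pass",
    "BOT_USER_SMART": "atlas-smart", "BOT_PASS_SMART": "smart-pass",
    "BOT_USER_GENIUS": "atlas-genius", "BOT_PASS_GENIUS": "genius-pass",
    "ATLASBOT_MODEL_FAST": "fast-model",
    "ATLASBOT_MODEL_SMART": "smart-model",
    "ATLASBOT_MODEL_GENIUS": "genius-model",
    "KB_DIR": "", "VM_URL": "http://vm.invalid",
    "ARIADNE_STATE_URL": "", "ARIADNE_STATE_TOKEN": "",
}

with mock.patch.dict(os.environ, ENV, clear=False):
    spec = importlib.util.spec_from_file_location("atlasbot_bot", BOT_PATH)
    assert spec is not None and spec.loader is not None
    bot = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(bot)

# Aliases collapse onto the three canonical modes; "deep" is kept as a
# legacy alias for genius, and unknown input falls back to smart.
assert bot._normalize_mode("quick") == "fast"
assert bot._normalize_mode("deep") == "genius"
assert bot._normalize_mode("") == "smart"

# Each mode resolves to its own model and derives its per-request Ollama
# timeout and heartbeat cadence from its time budget.
for mode in ("fast", "smart", "genius"):
    print(mode, bot._model_for_mode(mode),
          bot._mode_time_budget_sec(mode),
          bot._mode_ollama_timeout_sec(mode),
          bot._mode_heartbeat_sec(mode))
# fast fast-model 15.0 13.0 5
# smart smart-model 45.0 40.0 15
# genius genius-model 180.0 170.0 60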