atlasbot: wire quick smart genius modes

This commit is contained in:
Brad Stein 2026-03-30 16:51:23 -03:00
parent f04f032721
commit a1e90f4600
2 changed files with 248 additions and 36 deletions

View File

@ -17,12 +17,15 @@ BOT_USER_QUICK = os.environ.get("BOT_USER_QUICK", "").strip()
BOT_PASS_QUICK = os.environ.get("BOT_PASS_QUICK", "").strip()
BOT_USER_SMART = os.environ.get("BOT_USER_SMART", "").strip()
BOT_PASS_SMART = os.environ.get("BOT_PASS_SMART", "").strip()
BOT_USER_GENIUS = os.environ.get("BOT_USER_GENIUS", "").strip()
BOT_PASS_GENIUS = os.environ.get("BOT_PASS_GENIUS", "").strip()
ROOM_ALIAS = "#othrys:live.bstein.dev"
OLLAMA_URL = os.environ.get("OLLAMA_URL", "https://chat.ai.bstein.dev/")
MODEL = os.environ.get("OLLAMA_MODEL", "qwen2.5:14b-instruct")
MODEL_FAST = os.environ.get("ATLASBOT_MODEL_FAST", "")
MODEL_DEEP = os.environ.get("ATLASBOT_MODEL_DEEP", "")
MODEL_SMART = os.environ.get("ATLASBOT_MODEL_SMART", os.environ.get("ATLASBOT_MODEL_DEEP", "")).strip()
MODEL_GENIUS = os.environ.get("ATLASBOT_MODEL_GENIUS", MODEL_SMART).strip()
FALLBACK_MODEL = os.environ.get("OLLAMA_FALLBACK_MODEL", "")
API_KEY = os.environ.get("CHAT_API_KEY", "")
OLLAMA_TIMEOUT_SEC = float(os.environ.get("OLLAMA_TIMEOUT_SEC", "480"))
@ -43,6 +46,9 @@ MAX_TOOL_CHARS = int(os.environ.get("ATLASBOT_MAX_TOOL_CHARS", "2500"))
MAX_FACTS_CHARS = int(os.environ.get("ATLASBOT_MAX_FACTS_CHARS", "8000"))
MAX_CONTEXT_CHARS = int(os.environ.get("ATLASBOT_MAX_CONTEXT_CHARS", "12000"))
THINKING_INTERVAL_SEC = int(os.environ.get("ATLASBOT_THINKING_INTERVAL_SEC", "120"))
QUICK_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_QUICK_TIME_BUDGET_SEC", "15"))
SMART_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_SMART_TIME_BUDGET_SEC", "45"))
GENIUS_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_GENIUS_TIME_BUDGET_SEC", "180"))
OLLAMA_RETRIES = int(os.environ.get("ATLASBOT_OLLAMA_RETRIES", "2"))
OLLAMA_SERIALIZE = os.environ.get("ATLASBOT_OLLAMA_SERIALIZE", "true").lower() != "false"
@ -384,16 +390,20 @@ def _strip_bot_mention(text: str) -> str:
return cleaned or text.strip()
def _detect_mode_from_body(body: str, *, default: str = "deep") -> str:
def _detect_mode_from_body(body: str, *, default: str = "smart") -> str:
lower = normalize_query(body or "")
if "atlas_quick" in lower or "atlas-quick" in lower:
return "fast"
if "atlas_smart" in lower or "atlas-smart" in lower:
return "deep"
return "smart"
if "atlas_genius" in lower or "atlas-genius" in lower:
return "genius"
if lower.startswith("quick ") or lower.startswith("fast "):
return "fast"
if lower.startswith("smart ") or lower.startswith("deep "):
return "deep"
if lower.startswith("smart "):
return "smart"
if lower.startswith("genius ") or lower.startswith("deep "):
return "genius"
return default
@ -401,7 +411,7 @@ def _detect_mode(
content: dict[str, Any],
body: str,
*,
default: str = "deep",
default: str = "smart",
account_user: str = "",
) -> str:
mode = _detect_mode_from_body(body, default=default)
@ -412,24 +422,72 @@ def _detect_mode(
if BOT_USER_QUICK and normalize_user_id(BOT_USER_QUICK).lower() in normalized:
return "fast"
if BOT_USER_SMART and normalize_user_id(BOT_USER_SMART).lower() in normalized:
return "deep"
return "smart"
if BOT_USER_GENIUS and normalize_user_id(BOT_USER_GENIUS).lower() in normalized:
return "genius"
if BOT_USER and normalize_user_id(BOT_USER).lower() in normalized:
return "deep"
return "smart"
if account_user and BOT_USER_QUICK and normalize_user_id(account_user) == normalize_user_id(BOT_USER_QUICK):
return "fast"
if account_user and BOT_USER_SMART and normalize_user_id(account_user) == normalize_user_id(BOT_USER_SMART):
return "deep"
return "smart"
if account_user and BOT_USER_GENIUS and normalize_user_id(account_user) == normalize_user_id(BOT_USER_GENIUS):
return "genius"
return mode
def _model_for_mode(mode: str) -> str:
if mode == "fast" and MODEL_FAST:
return MODEL_FAST
if mode == "deep" and MODEL_DEEP:
return MODEL_DEEP
if mode == "smart" and MODEL_SMART:
return MODEL_SMART
if mode == "genius" and MODEL_GENIUS:
return MODEL_GENIUS
if mode == "deep" and MODEL_SMART:
return MODEL_SMART
return MODEL
def _normalize_mode(mode: str) -> str:
normalized = (mode or "").strip().lower()
if normalized in {"quick", "fast"}:
return "fast"
if normalized in {"smart"}:
return "smart"
if normalized in {"genius", "deep"}:
return "genius"
return "smart"
def _mode_time_budget_sec(mode: str) -> float:
    """Return the wall-clock answer budget (seconds) for *mode*.

    Budgets come from the ATLASBOT_*_TIME_BUDGET_SEC environment knobs and
    are clamped to at least one second so downstream timeout math stays sane.
    Unknown modes fall back to the "smart" budget.
    """
    budget_by_mode = {
        "fast": QUICK_TIME_BUDGET_SEC,
        "smart": SMART_TIME_BUDGET_SEC,
        "genius": GENIUS_TIME_BUDGET_SEC,
    }
    raw = budget_by_mode.get(_normalize_mode(mode), SMART_TIME_BUDGET_SEC)
    return max(1.0, raw)
def _mode_ollama_timeout_sec(mode: str) -> float:
    """Derive the per-request Ollama timeout (seconds) from the mode budget.

    Each mode keeps headroom below its time budget for post-processing,
    is capped by the global OLLAMA_TIMEOUT_SEC, and has a per-mode floor so
    a tiny budget never produces an unusably short request timeout.
    """
    canonical = _normalize_mode(mode)
    # (floor, headroom) per canonical mode; "smart" values double as fallback.
    floor, headroom = {
        "fast": (6.0, 2.0),
        "smart": (12.0, 5.0),
        "genius": (20.0, 10.0),
    }.get(canonical, (12.0, 5.0))
    budget = _mode_time_budget_sec(canonical)
    return max(floor, min(budget - headroom, OLLAMA_TIMEOUT_SEC))
def _mode_heartbeat_sec(mode: str) -> int:
    """Interval (seconds) between "thinking" heartbeats for *mode*.

    Scales with the mode's time budget (one heartbeat per third of the
    budget), bounded between 5 seconds and THINKING_INTERVAL_SEC.
    """
    budget = _mode_time_budget_sec(_normalize_mode(mode))
    interval = int(max(5.0, budget / 3.0))
    return max(5, min(THINKING_INTERVAL_SEC, interval))
# Matrix HTTP helper.
def req(method: str, path: str, token: str | None = None, body=None, timeout=60, base: str | None = None):
url = (base or BASE) + path
@ -3842,9 +3900,12 @@ def _open_ended_multi(
def _open_ended_total_steps(mode: str) -> int:
if mode == "fast":
normalized = _normalize_mode(mode)
if normalized == "fast":
return 2
return 9
if normalized == "smart":
return 3
return 4
def _fast_fact_lines(
@ -4179,6 +4240,7 @@ def _open_ended_fast_single(
use_history=False,
system_override=_open_ended_system(),
model=model,
timeout=_mode_ollama_timeout_sec("fast"),
)
if not _has_body_lines(reply):
reply = _ollama_call(
@ -4188,6 +4250,7 @@ def _open_ended_fast_single(
use_history=False,
system_override=_open_ended_system(),
model=model,
timeout=_mode_ollama_timeout_sec("fast"),
)
fallback = _fallback_fact_answer(prompt, context)
if fallback and (_is_quantitative_prompt(prompt) or not _has_body_lines(reply)):
@ -4248,16 +4311,53 @@ def _open_ended_deep(
fact_lines: list[str],
fact_meta: dict[str, dict[str, Any]],
history_lines: list[str],
mode: str,
state: ThoughtState | None = None,
) -> str:
return _open_ended_multi(
prompt,
fact_pack=fact_pack,
fact_lines=fact_lines,
fact_meta=fact_meta,
history_lines=history_lines,
state=state,
normalized = _normalize_mode(mode)
model = _model_for_mode(normalized)
subjective = _is_subjective_query(prompt)
primary_tags = _primary_tags_for_prompt(prompt)
focus_tags = _preferred_tags_for_prompt(prompt)
if not focus_tags and subjective:
focus_tags = set(_ALLOWED_INSIGHT_TAGS)
avoid_tags = _history_focus_tags(history_lines) if (subjective or _is_followup_query(prompt)) else set()
limit = 12 if normalized == "smart" else 18
selected_lines = _fast_fact_lines(
fact_lines,
fact_meta,
focus_tags=focus_tags,
avoid_tags=avoid_tags,
primary_tags=primary_tags,
limit=limit,
)
selected_meta = _fact_pack_meta(selected_lines)
selected_pack = _fact_pack_text(selected_lines, selected_meta)
if _needs_full_fact_pack(prompt) or not selected_lines or normalized == "genius":
selected_pack = fact_pack
fallback = _fallback_fact_answer(prompt, selected_pack)
if not subjective and fallback:
if state:
state.update("done", step=_open_ended_total_steps(normalized))
return _ensure_scores(fallback)
if state:
state.update("drafting", step=1, note="synthesizing")
reply = _ollama_call(
("atlasbot_deep", "atlasbot_deep"),
prompt,
context=_append_history_context(selected_pack, history_lines),
use_history=False,
system_override=_open_ended_system(),
model=model,
timeout=_mode_ollama_timeout_sec(normalized),
)
if fallback and (_is_quantitative_prompt(prompt) or not _has_body_lines(reply)):
reply = fallback
if not _has_body_lines(reply):
reply = "I don't have enough data in the current snapshot to answer that."
if state:
state.update("done", step=_open_ended_total_steps(normalized))
return _ensure_scores(reply)
def open_ended_answer(
@ -4285,7 +4385,8 @@ def open_ended_answer(
return _ensure_scores("I don't have enough data to answer that.")
fact_meta = _fact_pack_meta(lines)
fact_pack = _fact_pack_text(lines, fact_meta)
if mode == "fast":
normalized = _normalize_mode(mode)
if normalized == "fast":
return _open_ended_fast(
prompt,
fact_pack=fact_pack,
@ -4300,6 +4401,7 @@ def open_ended_answer(
fact_lines=lines,
fact_meta=fact_meta,
history_lines=history_lines,
mode=normalized,
state=state,
)
@ -4321,6 +4423,7 @@ def _non_cluster_reply(prompt: str, *, history_lines: list[str], mode: str) -> s
use_history=False,
system_override=system,
model=model,
timeout=_mode_ollama_timeout_sec(mode),
)
reply = re.sub(r"\bconfidence\s*:\s*(high|medium|low)\b\.?\s*", "", reply, flags=re.IGNORECASE).strip()
return _ensure_scores(reply)
@ -4372,13 +4475,7 @@ class _AtlasbotHandler(BaseHTTPRequestHandler):
self._write_json(400, {"error": "missing_prompt"})
return
cleaned = _strip_bot_mention(prompt)
mode = str(payload.get("mode") or "deep").lower()
if mode in ("quick", "fast"):
mode = "fast"
elif mode in ("smart", "deep"):
mode = "deep"
else:
mode = "deep"
mode = _normalize_mode(str(payload.get("mode") or "smart"))
snapshot = _snapshot_state()
inventory = _snapshot_inventory(snapshot) or node_inventory_live()
workloads = _snapshot_workloads(snapshot)
@ -4669,6 +4766,7 @@ def _ollama_call(
use_history: bool = True,
system_override: str | None = None,
model: str | None = None,
timeout: float | None = None,
) -> str:
system = system_override or (
"System: You are Atlas, the Titan lab assistant for Atlas/Othrys. "
@ -4702,6 +4800,7 @@ def _ollama_call(
messages.append({"role": "user", "content": prompt})
model_name = model or MODEL
request_timeout = timeout if timeout is not None else OLLAMA_TIMEOUT_SEC
payload = {"model": model_name, "messages": messages, "stream": False}
headers = {"Content-Type": "application/json"}
if API_KEY:
@ -4712,13 +4811,13 @@ def _ollama_call(
lock.acquire()
try:
try:
with request.urlopen(r, timeout=OLLAMA_TIMEOUT_SEC) as resp:
with request.urlopen(r, timeout=request_timeout) as resp:
data = json.loads(resp.read().decode())
except error.HTTPError as exc:
if exc.code == 404 and FALLBACK_MODEL and FALLBACK_MODEL != payload["model"]:
payload["model"] = FALLBACK_MODEL
r = request.Request(endpoint, data=json.dumps(payload).encode(), headers=headers)
with request.urlopen(r, timeout=OLLAMA_TIMEOUT_SEC) as resp:
with request.urlopen(r, timeout=request_timeout) as resp:
data = json.loads(resp.read().decode())
else:
raise
@ -4743,6 +4842,7 @@ def ollama_reply(
fallback: str = "",
use_history: bool = True,
model: str | None = None,
timeout: float | None = None,
) -> str:
last_error = None
for attempt in range(max(1, OLLAMA_RETRIES + 1)):
@ -4753,6 +4853,7 @@ def ollama_reply(
context=context,
use_history=use_history,
model=model,
timeout=timeout,
)
except Exception as exc: # noqa: BLE001
last_error = exc
@ -4773,6 +4874,7 @@ def ollama_reply_with_thinking(
fallback: str,
use_history: bool = True,
model: str | None = None,
timeout: float | None = None,
) -> str:
result: dict[str, str] = {"reply": ""}
done = threading.Event()
@ -4785,6 +4887,7 @@ def ollama_reply_with_thinking(
fallback=fallback,
use_history=use_history,
model=model,
timeout=timeout,
)
done.set()
@ -4841,7 +4944,7 @@ def open_ended_with_thinking(
thread.start()
if not done.wait(2.0):
send_msg(token, room, "Thinking…")
heartbeat = max(10, THINKING_INTERVAL_SEC)
heartbeat = _mode_heartbeat_sec(mode)
next_heartbeat = time.monotonic() + heartbeat
while not done.wait(max(0, next_heartbeat - time.monotonic())):
send_msg(token, room, state.status_line())
@ -4906,7 +5009,7 @@ def sync_loop(token: str, room_id: str, *, account_user: str, default_mode: str)
mode = _detect_mode(
content,
body,
default=default_mode if default_mode in ("fast", "deep") else "deep",
default=_normalize_mode(default_mode),
account_user=account_user,
)
@ -4972,14 +5075,14 @@ def sync_loop(token: str, room_id: str, *, account_user: str, default_mode: str)
snapshot=snapshot,
workloads=workloads,
history_lines=history[hist_key],
mode=mode if mode in ("fast", "deep") else "deep",
mode=_normalize_mode(mode),
allow_tools=allow_tools,
)
else:
reply = _non_cluster_reply(
cleaned_body,
history_lines=history[hist_key],
mode=mode if mode in ("fast", "deep") else "deep",
mode=_normalize_mode(mode),
)
send_msg(token, rid, reply)
history[hist_key].append(f"Atlas: {reply}")
@ -5003,11 +5106,13 @@ def _bot_accounts() -> list[dict[str, str]]:
return
accounts.append({"user": user, "password": password, "mode": mode})
add(BOT_USER_SMART or BOT_USER, BOT_PASS_SMART or BOT_PASS, "deep")
add(BOT_USER_SMART or BOT_USER, BOT_PASS_SMART or BOT_PASS, "smart")
if BOT_USER_QUICK and BOT_PASS_QUICK:
add(BOT_USER_QUICK, BOT_PASS_QUICK, "fast")
if BOT_USER_GENIUS and BOT_PASS_GENIUS:
add(BOT_USER_GENIUS, BOT_PASS_GENIUS, "genius")
if BOT_USER and BOT_PASS and all(acc["user"] != BOT_USER for acc in accounts):
add(BOT_USER, BOT_PASS, "deep")
add(BOT_USER, BOT_PASS, "smart")
seen: set[str] = set()
unique: list[dict[str, str]] = []

View File

@ -0,0 +1,107 @@
from __future__ import annotations
import importlib.util
import os
from pathlib import Path
from unittest import TestCase, mock
BOT_PATH = Path(__file__).resolve().parents[1] / "atlasbot" / "bot.py"
def load_bot_module():
    """Import atlasbot/bot.py as a fresh module under a controlled environment.

    The environment overlay pins every account/model/budget knob the module
    reads at import time, so each test sees deterministic constants.
    """
    test_env = {
        "BOT_USER": "atlas-smart",
        "BOT_PASS": "smart-pass",
        "BOT_USER_QUICK": "atlas-quick",
        "BOT_PASS_QUICK": "quick-pass",
        "BOT_USER_SMART": "atlas-smart",
        "BOT_PASS_SMART": "smart-pass",
        "BOT_USER_GENIUS": "atlas-genius",
        "BOT_PASS_GENIUS": "genius-pass",
        "OLLAMA_URL": "http://ollama.invalid",
        "OLLAMA_MODEL": "base-model",
        "ATLASBOT_MODEL_FAST": "fast-model",
        "ATLASBOT_MODEL_SMART": "smart-model",
        "ATLASBOT_MODEL_GENIUS": "genius-model",
        "ATLASBOT_QUICK_TIME_BUDGET_SEC": "15",
        "ATLASBOT_SMART_TIME_BUDGET_SEC": "45",
        "ATLASBOT_GENIUS_TIME_BUDGET_SEC": "180",
        "KB_DIR": "",
        "VM_URL": "http://vm.invalid",
        "ARIADNE_STATE_URL": "",
        "ARIADNE_STATE_TOKEN": "",
    }
    with mock.patch.dict(os.environ, test_env, clear=False):
        spec = importlib.util.spec_from_file_location("atlasbot_bot", BOT_PATH)
        loader = spec.loader
        assert loader is not None
        module = importlib.util.module_from_spec(spec)
        # Execute while the patched environment is still active: the module
        # reads os.environ at import time.
        loader.exec_module(module)
        return module
class AtlasbotModeTests(TestCase):
    """Exercise account wiring and mode-driven model/timeout selection."""

    def setUp(self):
        # Fresh module per test so env-derived constants are re-read cleanly.
        self.bot = load_bot_module()

    def test_bot_accounts_include_genius_mode(self):
        """Each configured bot account maps to its expected mode."""
        mode_by_user = {acct["user"]: acct["mode"] for acct in self.bot._bot_accounts()}
        self.assertEqual(mode_by_user["atlas-quick"], "fast")
        self.assertEqual(mode_by_user["atlas-smart"], "smart")
        self.assertEqual(mode_by_user["atlas-genius"], "genius")

    def test_objective_cluster_question_uses_fact_pack_without_llm(self):
        """An objective metric question is answered from facts; the LLM mock raises if called."""
        facts = [
            "hottest_cpu: longhorn-system (6.69)",
            "hottest_ram: longhorn-system (36.05 GB)",
        ]
        fact_patch = mock.patch.object(self.bot, "_fact_pack_lines", return_value=facts)
        llm_patch = mock.patch.object(
            self.bot, "_ollama_call", side_effect=AssertionError("LLM should not be called")
        )
        with fact_patch, llm_patch:
            reply = self.bot.open_ended_answer(
                "what is the hottest cpu node in titan lab currently?",
                inventory=[],
                snapshot=None,
                workloads=[],
                history_lines=[],
                mode="smart",
                allow_tools=True,
            )
        self.assertIn("longhorn-system", reply)
        self.assertIn("Confidence:", reply)

    def test_subjective_genius_answer_uses_genius_model(self):
        """A subjective prompt in genius mode routes to the genius model with a bounded timeout."""
        facts = [
            "hottest_cpu: longhorn-system (6.69)",
            "worker_nodes: titan-01, titan-02, titan-03",
        ]
        seen: dict[str, object] = {}

        def fake_ollama_call(hist_key, prompt, *, context, use_history=True, system_override=None, model=None, timeout=None):
            seen["model"] = model
            seen["timeout"] = timeout
            seen["context"] = context
            return "The worker spread stands out because Titan keeps meaningful capacity on the same cluster. Confidence: high"

        with (
            mock.patch.object(self.bot, "_fact_pack_lines", return_value=facts),
            mock.patch.object(self.bot, "_ollama_call", side_effect=fake_ollama_call),
        ):
            reply = self.bot.open_ended_answer(
                "what stands out about titan lab?",
                inventory=[],
                snapshot=None,
                workloads=[],
                history_lines=[],
                mode="genius",
                allow_tools=True,
            )
        self.assertIn("The worker spread stands out", reply)
        self.assertEqual(seen["model"], "genius-model")
        self.assertLessEqual(float(seen["timeout"]), 180.0)