atlasbot: wire quick smart genius modes

This commit is contained in:
Brad Stein 2026-03-30 16:51:23 -03:00
parent a3a00cfa9d
commit 0d8a2c5531
2 changed files with 333 additions and 47 deletions

View File

@ -11,14 +11,23 @@ from urllib import error, parse, request
BASE = os.environ.get("MATRIX_BASE", "http://othrys-synapse-matrix-synapse:8008")
AUTH_BASE = os.environ.get("AUTH_BASE", "http://matrix-authentication-service:8080")
USER = os.environ["BOT_USER"]
PASSWORD = os.environ["BOT_PASS"]
BOT_USER = os.environ["BOT_USER"]
BOT_PASS = os.environ["BOT_PASS"]
BOT_USER_QUICK = os.environ.get("BOT_USER_QUICK", "").strip()
BOT_PASS_QUICK = os.environ.get("BOT_PASS_QUICK", "").strip()
BOT_USER_SMART = os.environ.get("BOT_USER_SMART", "").strip()
BOT_PASS_SMART = os.environ.get("BOT_PASS_SMART", "").strip()
BOT_USER_GENIUS = os.environ.get("BOT_USER_GENIUS", "").strip()
BOT_PASS_GENIUS = os.environ.get("BOT_PASS_GENIUS", "").strip()
USER = BOT_USER
PASSWORD = BOT_PASS
ROOM_ALIAS = "#othrys:live.bstein.dev"
OLLAMA_URL = os.environ.get("OLLAMA_URL", "https://chat.ai.bstein.dev/")
MODEL = os.environ.get("OLLAMA_MODEL", "qwen2.5:14b-instruct")
MODEL_FAST = os.environ.get("ATLASBOT_MODEL_FAST", "")
MODEL_DEEP = os.environ.get("ATLASBOT_MODEL_DEEP", "")
MODEL_SMART = os.environ.get("ATLASBOT_MODEL_SMART", os.environ.get("ATLASBOT_MODEL_DEEP", "")).strip()
MODEL_GENIUS = os.environ.get("ATLASBOT_MODEL_GENIUS", MODEL_SMART).strip()
FALLBACK_MODEL = os.environ.get("OLLAMA_FALLBACK_MODEL", "")
API_KEY = os.environ.get("CHAT_API_KEY", "")
OLLAMA_TIMEOUT_SEC = float(os.environ.get("OLLAMA_TIMEOUT_SEC", "480"))
@ -31,7 +40,7 @@ VM_URL = os.environ.get("VM_URL", "http://victoria-metrics-single-server.monitor
ARIADNE_STATE_URL = os.environ.get("ARIADNE_STATE_URL", "")
ARIADNE_STATE_TOKEN = os.environ.get("ARIADNE_STATE_TOKEN", "")
BOT_MENTIONS = os.environ.get("BOT_MENTIONS", f"{USER},atlas")
BOT_MENTIONS = os.environ.get("BOT_MENTIONS", f"{BOT_USER},atlas")
SERVER_NAME = os.environ.get("MATRIX_SERVER_NAME", "live.bstein.dev")
MAX_KB_CHARS = int(os.environ.get("ATLASBOT_MAX_KB_CHARS", "2500"))
@ -39,6 +48,9 @@ MAX_TOOL_CHARS = int(os.environ.get("ATLASBOT_MAX_TOOL_CHARS", "2500"))
MAX_FACTS_CHARS = int(os.environ.get("ATLASBOT_MAX_FACTS_CHARS", "8000"))
MAX_CONTEXT_CHARS = int(os.environ.get("ATLASBOT_MAX_CONTEXT_CHARS", "12000"))
THINKING_INTERVAL_SEC = int(os.environ.get("ATLASBOT_THINKING_INTERVAL_SEC", "120"))
QUICK_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_QUICK_TIME_BUDGET_SEC", "15"))
SMART_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_SMART_TIME_BUDGET_SEC", "45"))
GENIUS_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_GENIUS_TIME_BUDGET_SEC", "180"))
OLLAMA_RETRIES = int(os.environ.get("ATLASBOT_OLLAMA_RETRIES", "2"))
OLLAMA_SERIALIZE = os.environ.get("ATLASBOT_OLLAMA_SERIALIZE", "true").lower() != "false"
@ -380,27 +392,103 @@ def _strip_bot_mention(text: str) -> str:
return cleaned or text.strip()
# NOTE(review): this span is a rendered diff hunk — pre-change and post-change
# lines are interleaved (two `def` signatures, paired old/new `return`s), so it
# is not valid standalone Python as shown. Comments mark the suspected pairs;
# confirm against the committed file.
def _detect_mode_from_body(body: str, *, default: str = "deep") -> str:  # pre-change signature
def _detect_mode_from_body(body: str, *, default: str = "smart") -> str:
    # Detect an explicit mode request embedded in the message text.
    # Normalize for case-insensitive keyword matching.
    lower = normalize_query(body or "")
    # Per-mode bot handles ("atlas_quick" / "atlas-quick", etc.) win first.
    if "atlas_quick" in lower or "atlas-quick" in lower:
        return "fast"
    if "atlas_smart" in lower or "atlas-smart" in lower:
        return "deep"  # pre-change mapping
        return "smart"  # post-change mapping
    if "atlas_genius" in lower or "atlas-genius" in lower:
        return "genius"
    # Leading keyword prefixes ("quick ", "smart ", "genius ", legacy "deep ").
    if lower.startswith("quick ") or lower.startswith("fast "):
        return "fast"
    if lower.startswith("smart ") or lower.startswith("deep "):  # pre-change branch
        return "deep"
    if lower.startswith("smart "):
        return "smart"
    # Post-change: the legacy "deep " prefix now routes to genius mode.
    if lower.startswith("genius ") or lower.startswith("deep "):
        return "genius"
    return default
def _detect_mode(
    content: dict[str, Any],
    body: str,
    *,
    default: str = "smart",
    account_user: str = "",
) -> str:
    """Pick the answer mode ("fast"/"smart"/"genius") for an incoming event.

    Precedence:
      1. an explicit mention of a per-mode bot account in ``m.mentions``,
      2. the account this sync loop is logged in as (``account_user``),
      3. keywords in the message body (via ``_detect_mode_from_body``).

    Args:
        content: Matrix event ``content`` dict. Untrusted input — may be
            missing or malformed, so every shape is checked before use.
        body: plain-text message body.
        default: mode returned when nothing else matches.
        account_user: the bot account that received the event, if any.
    """
    mode = _detect_mode_from_body(body, default=default)
    # Fix: ``m.mentions`` can be any JSON value in a hostile event; the
    # original called ``.get`` on it unconditionally and would raise
    # AttributeError on a non-dict. Guard both levels explicitly.
    mentions = content.get("m.mentions", {})
    user_ids = mentions.get("user_ids", []) if isinstance(mentions, dict) else []
    if isinstance(user_ids, list):
        normalized = {normalize_user_id(uid).lower() for uid in user_ids if isinstance(uid, str)}

        def _mentioned(user: str) -> bool:
            # True when a configured (non-empty) bot user appears in the mentions.
            return bool(user) and normalize_user_id(user).lower() in normalized

        if _mentioned(BOT_USER_QUICK):
            return "fast"
        if _mentioned(BOT_USER_SMART):
            return "smart"
        if _mentioned(BOT_USER_GENIUS):
            return "genius"
        if _mentioned(BOT_USER):
            return "smart"
    if account_user:
        # Fall back to the identity of the receiving account itself.
        account_norm = normalize_user_id(account_user)
        if BOT_USER_QUICK and account_norm == normalize_user_id(BOT_USER_QUICK):
            return "fast"
        if BOT_USER_SMART and account_norm == normalize_user_id(BOT_USER_SMART):
            return "smart"
        if BOT_USER_GENIUS and account_norm == normalize_user_id(BOT_USER_GENIUS):
            return "genius"
    return mode
def _model_for_mode(mode: str) -> str:
    # Resolve the Ollama model name for a mode; falls back to the base MODEL
    # when no mode-specific override env var is set.
    # NOTE(review): rendered diff hunk — the ``MODEL_DEEP`` branch appears to
    # be a pre-change line that was removed in this commit; confirm against
    # the committed file before relying on this span.
    if mode == "fast" and MODEL_FAST:
        return MODEL_FAST
    if mode == "deep" and MODEL_DEEP:
        return MODEL_DEEP
    if mode == "smart" and MODEL_SMART:
        return MODEL_SMART
    if mode == "genius" and MODEL_GENIUS:
        return MODEL_GENIUS
    # Legacy "deep" callers get the smart model when one is configured.
    if mode == "deep" and MODEL_SMART:
        return MODEL_SMART
    return MODEL
def _normalize_mode(mode: str) -> str:
normalized = (mode or "").strip().lower()
if normalized in {"quick", "fast"}:
return "fast"
if normalized in {"smart"}:
return "smart"
if normalized in {"genius", "deep"}:
return "genius"
return "smart"
def _mode_time_budget_sec(mode: str) -> float:
    """Return the wall-clock budget (seconds) for answering in *mode*, >= 1s."""
    budgets = {
        "fast": QUICK_TIME_BUDGET_SEC,
        "smart": SMART_TIME_BUDGET_SEC,
        "genius": GENIUS_TIME_BUDGET_SEC,
    }
    # _normalize_mode always yields one of the keys above; the .get default
    # mirrors the original's unreachable fallback to the smart budget.
    raw = budgets.get(_normalize_mode(mode), SMART_TIME_BUDGET_SEC)
    return max(1.0, raw)
def _mode_ollama_timeout_sec(mode: str) -> float:
    """Per-request LLM timeout for *mode*.

    Takes the mode's time budget, subtracts headroom for Matrix round-trips,
    clamps below by a per-mode floor and above by the global OLLAMA_TIMEOUT_SEC.
    """
    normalized = _normalize_mode(mode)
    budget = _mode_time_budget_sec(normalized)
    # (minimum timeout, headroom subtracted from the budget) per mode; the
    # default pair matches the original's unreachable smart-like fallback.
    floor, headroom = {
        "fast": (6.0, 2.0),
        "smart": (12.0, 5.0),
        "genius": (20.0, 10.0),
    }.get(normalized, (12.0, 5.0))
    return max(floor, min(budget - headroom, OLLAMA_TIMEOUT_SEC))
def _mode_heartbeat_sec(mode: str) -> int:
    """Seconds between "thinking" heartbeat messages, scaled to the mode budget.

    Roughly a third of the budget, never below 5s and never above the
    configured THINKING_INTERVAL_SEC.
    """
    budget = _mode_time_budget_sec(_normalize_mode(mode))
    scaled = int(max(5.0, budget / 3.0))
    return max(5, min(THINKING_INTERVAL_SEC, scaled))
# Matrix HTTP helper.
def req(method: str, path: str, token: str | None = None, body=None, timeout=60, base: str | None = None):
url = (base or BASE) + path
@ -416,12 +504,12 @@ def req(method: str, path: str, token: str | None = None, body=None, timeout=60,
raw = resp.read()
return json.loads(raw.decode()) if raw else {}
# NOTE(review): rendered diff hunk — the first two lines are the pre-change
# signature/body; post-change takes explicit credentials so each per-mode bot
# account can authenticate. Confirm against the committed file.
def login() -> str:  # pre-change signature
    login_user = normalize_user_id(USER)  # pre-change body line
def login(user: str, password: str) -> str:
    # Password login against the auth service (AUTH_BASE/MAS), not Synapse.
    login_user = normalize_user_id(user)
    payload = {
        "type": "m.login.password",
        "identifier": {"type": "m.id.user", "user": login_user},
        "password": PASSWORD,  # pre-change line (module-level constant)
        "password": password,
    }
    res = req("POST", "/_matrix/client/v3/login", body=payload, base=AUTH_BASE)
    # Raises KeyError if the login response carries no access_token.
    return res["access_token"]
@ -3813,9 +3901,12 @@ def _open_ended_multi(
def _open_ended_total_steps(mode: str) -> int:
    # Total progress steps reported via the ThoughtState heartbeat per mode.
    # NOTE(review): rendered diff hunk — `if mode == "fast"` and `return 9`
    # look like the pre-change lines; confirm against the committed file.
    if mode == "fast":  # pre-change check (raw mode string)
    normalized = _normalize_mode(mode)
    if normalized == "fast":
        return 2
        return 9  # pre-change step count
    if normalized == "smart":
        return 3
    return 4
def _fast_fact_lines(
@ -4150,6 +4241,7 @@ def _open_ended_fast_single(
use_history=False,
system_override=_open_ended_system(),
model=model,
timeout=_mode_ollama_timeout_sec("fast"),
)
if not _has_body_lines(reply):
reply = _ollama_call(
@ -4159,6 +4251,7 @@ def _open_ended_fast_single(
use_history=False,
system_override=_open_ended_system(),
model=model,
timeout=_mode_ollama_timeout_sec("fast"),
)
fallback = _fallback_fact_answer(prompt, context)
if fallback and (_is_quantitative_prompt(prompt) or not _has_body_lines(reply)):
@ -4219,16 +4312,53 @@ def _open_ended_deep(
fact_lines: list[str],
fact_meta: dict[str, dict[str, Any]],
history_lines: list[str],
mode: str,
state: ThoughtState | None = None,
) -> str:
return _open_ended_multi(
prompt,
fact_pack=fact_pack,
fact_lines=fact_lines,
fact_meta=fact_meta,
history_lines=history_lines,
state=state,
normalized = _normalize_mode(mode)
model = _model_for_mode(normalized)
subjective = _is_subjective_query(prompt)
primary_tags = _primary_tags_for_prompt(prompt)
focus_tags = _preferred_tags_for_prompt(prompt)
if not focus_tags and subjective:
focus_tags = set(_ALLOWED_INSIGHT_TAGS)
avoid_tags = _history_focus_tags(history_lines) if (subjective or _is_followup_query(prompt)) else set()
limit = 12 if normalized == "smart" else 18
selected_lines = _fast_fact_lines(
fact_lines,
fact_meta,
focus_tags=focus_tags,
avoid_tags=avoid_tags,
primary_tags=primary_tags,
limit=limit,
)
selected_meta = _fact_pack_meta(selected_lines)
selected_pack = _fact_pack_text(selected_lines, selected_meta)
if _needs_full_fact_pack(prompt) or not selected_lines or normalized == "genius":
selected_pack = fact_pack
fallback = _fallback_fact_answer(prompt, selected_pack)
if not subjective and fallback:
if state:
state.update("done", step=_open_ended_total_steps(normalized))
return _ensure_scores(fallback)
if state:
state.update("drafting", step=1, note="synthesizing")
reply = _ollama_call(
("atlasbot_deep", "atlasbot_deep"),
prompt,
context=_append_history_context(selected_pack, history_lines),
use_history=False,
system_override=_open_ended_system(),
model=model,
timeout=_mode_ollama_timeout_sec(normalized),
)
if fallback and (_is_quantitative_prompt(prompt) or not _has_body_lines(reply)):
reply = fallback
if not _has_body_lines(reply):
reply = "I don't have enough data in the current snapshot to answer that."
if state:
state.update("done", step=_open_ended_total_steps(normalized))
return _ensure_scores(reply)
def open_ended_answer(
@ -4256,7 +4386,8 @@ def open_ended_answer(
return _ensure_scores("I don't have enough data to answer that.")
fact_meta = _fact_pack_meta(lines)
fact_pack = _fact_pack_text(lines, fact_meta)
if mode == "fast":
normalized = _normalize_mode(mode)
if normalized == "fast":
return _open_ended_fast(
prompt,
fact_pack=fact_pack,
@ -4271,6 +4402,7 @@ def open_ended_answer(
fact_lines=lines,
fact_meta=fact_meta,
history_lines=history_lines,
mode=normalized,
state=state,
)
@ -4292,6 +4424,7 @@ def _non_cluster_reply(prompt: str, *, history_lines: list[str], mode: str) -> s
use_history=False,
system_override=system,
model=model,
timeout=_mode_ollama_timeout_sec(mode),
)
reply = re.sub(r"\bconfidence\s*:\s*(high|medium|low)\b\.?\s*", "", reply, flags=re.IGNORECASE).strip()
return _ensure_scores(reply)
@ -4343,13 +4476,7 @@ class _AtlasbotHandler(BaseHTTPRequestHandler):
self._write_json(400, {"error": "missing_prompt"})
return
cleaned = _strip_bot_mention(prompt)
mode = str(payload.get("mode") or "deep").lower()
if mode in ("quick", "fast"):
mode = "fast"
elif mode in ("smart", "deep"):
mode = "deep"
else:
mode = "deep"
mode = _normalize_mode(str(payload.get("mode") or "smart"))
snapshot = _snapshot_state()
inventory = _snapshot_inventory(snapshot) or node_inventory_live()
workloads = _snapshot_workloads(snapshot)
@ -4640,6 +4767,7 @@ def _ollama_call(
use_history: bool = True,
system_override: str | None = None,
model: str | None = None,
timeout: float | None = None,
) -> str:
system = system_override or (
"System: You are Atlas, the Titan lab assistant for Atlas/Othrys. "
@ -4673,6 +4801,7 @@ def _ollama_call(
messages.append({"role": "user", "content": prompt})
model_name = model or MODEL
request_timeout = timeout if timeout is not None else OLLAMA_TIMEOUT_SEC
payload = {"model": model_name, "messages": messages, "stream": False}
headers = {"Content-Type": "application/json"}
if API_KEY:
@ -4683,13 +4812,13 @@ def _ollama_call(
lock.acquire()
try:
try:
with request.urlopen(r, timeout=OLLAMA_TIMEOUT_SEC) as resp:
with request.urlopen(r, timeout=request_timeout) as resp:
data = json.loads(resp.read().decode())
except error.HTTPError as exc:
if exc.code == 404 and FALLBACK_MODEL and FALLBACK_MODEL != payload["model"]:
payload["model"] = FALLBACK_MODEL
r = request.Request(endpoint, data=json.dumps(payload).encode(), headers=headers)
with request.urlopen(r, timeout=OLLAMA_TIMEOUT_SEC) as resp:
with request.urlopen(r, timeout=request_timeout) as resp:
data = json.loads(resp.read().decode())
else:
raise
@ -4714,6 +4843,7 @@ def ollama_reply(
fallback: str = "",
use_history: bool = True,
model: str | None = None,
timeout: float | None = None,
) -> str:
last_error = None
for attempt in range(max(1, OLLAMA_RETRIES + 1)):
@ -4724,6 +4854,7 @@ def ollama_reply(
context=context,
use_history=use_history,
model=model,
timeout=timeout,
)
except Exception as exc: # noqa: BLE001
last_error = exc
@ -4744,6 +4875,7 @@ def ollama_reply_with_thinking(
fallback: str,
use_history: bool = True,
model: str | None = None,
timeout: float | None = None,
) -> str:
result: dict[str, str] = {"reply": ""}
done = threading.Event()
@ -4756,6 +4888,7 @@ def ollama_reply_with_thinking(
fallback=fallback,
use_history=use_history,
model=model,
timeout=timeout,
)
done.set()
@ -4812,7 +4945,7 @@ def open_ended_with_thinking(
thread.start()
if not done.wait(2.0):
send_msg(token, room, "Thinking…")
heartbeat = max(10, THINKING_INTERVAL_SEC)
heartbeat = _mode_heartbeat_sec(mode)
next_heartbeat = time.monotonic() + heartbeat
while not done.wait(max(0, next_heartbeat - time.monotonic())):
send_msg(token, room, state.status_line())
@ -4820,7 +4953,7 @@ def open_ended_with_thinking(
thread.join(timeout=1)
return result["reply"] or "Model backend is busy. Try again in a moment."
def sync_loop(token: str, room_id: str):
def sync_loop(token: str, room_id: str, *, account_user: str, default_mode: str):
since = None
try:
res = req("GET", "/_matrix/client/v3/sync?timeout=0", token, timeout=10)
@ -4861,7 +4994,7 @@ def sync_loop(token: str, room_id: str):
if not body:
continue
sender = ev.get("sender", "")
if sender == f"@{USER}:live.bstein.dev":
if account_user and sender == normalize_user_id(account_user):
continue
mentioned = is_mentioned(content, body)
@ -4874,7 +5007,12 @@ def sync_loop(token: str, room_id: str):
cleaned_body = _strip_bot_mention(body)
lower_body = cleaned_body.lower()
mode = _detect_mode_from_body(body, default="deep" if is_dm else "deep")
mode = _detect_mode(
content,
body,
default=_normalize_mode(default_mode),
account_user=account_user,
)
# Only do live cluster introspection in DMs.
allow_tools = is_dm
@ -4938,39 +5076,80 @@ def sync_loop(token: str, room_id: str):
snapshot=snapshot,
workloads=workloads,
history_lines=history[hist_key],
mode=mode if mode in ("fast", "deep") else "deep",
mode=_normalize_mode(mode),
allow_tools=allow_tools,
)
else:
reply = _non_cluster_reply(
cleaned_body,
history_lines=history[hist_key],
mode=mode if mode in ("fast", "deep") else "deep",
mode=_normalize_mode(mode),
)
send_msg(token, rid, reply)
history[hist_key].append(f"Atlas: {reply}")
history[hist_key] = history[hist_key][-80:]
def login_with_retry():  # pre-change signature (no credentials)
def login_with_retry(user: str, password: str):
    # Attempt login up to 10 times with capped exponential backoff (2**attempt,
    # max 30s per sleep); re-raise the last error if every attempt fails.
    last_err = None
    for attempt in range(10):
        try:
            return login()  # pre-change call
            return login(user, password)
        except Exception as exc:  # noqa: BLE001
            last_err = exc
            time.sleep(min(30, 2 ** attempt))
    raise last_err
def _bot_accounts() -> list[dict[str, str]]:
    """Assemble the Matrix accounts to run, one sync loop per answer mode.

    The smart account (falling back to the legacy BOT_USER credentials) is
    listed first; quick and genius accounts are added only when both their
    user and password env vars are set. Duplicate user ids (compared after
    normalization, case-insensitively) keep their first occurrence.
    """
    candidates: list[tuple[str, str, str]] = [
        (BOT_USER_SMART or BOT_USER, BOT_PASS_SMART or BOT_PASS, "smart"),
    ]
    if BOT_USER_QUICK and BOT_PASS_QUICK:
        candidates.append((BOT_USER_QUICK, BOT_PASS_QUICK, "fast"))
    if BOT_USER_GENIUS and BOT_PASS_GENIUS:
        candidates.append((BOT_USER_GENIUS, BOT_PASS_GENIUS, "genius"))
    accounts = [
        {"user": user, "password": password, "mode": mode}
        for user, password, mode in candidates
        if user and password
    ]
    # Legacy single-account deployments: keep BOT_USER reachable in smart mode
    # when it was not already covered above (raw-string comparison, as before).
    if BOT_USER and BOT_PASS and all(acc["user"] != BOT_USER for acc in accounts):
        accounts.append({"user": BOT_USER, "password": BOT_PASS, "mode": "smart"})
    seen: set[str] = set()
    unique: list[dict[str, str]] = []
    for acc in accounts:
        uid = normalize_user_id(acc["user"]).lower()
        if uid not in seen:
            seen.add(uid)
            unique.append(acc)
    return unique
def main():
    # Entry point: load the knowledge base, start the local HTTP API, then run
    # one Matrix sync loop per configured bot account (quick/smart/genius).
    load_kb()
    _start_http_server()
    # NOTE(review): rendered diff hunk — the single-account lines below are the
    # pre-change body; confirm against the committed file.
    token = login_with_retry()  # pre-change
    try:
        room_id = resolve_alias(token, ROOM_ALIAS)
        join_room(token, room_id)
    except Exception:
        # Room resolution is best-effort; the sync loop tolerates room_id=None.
        room_id = None
    sync_loop(token, room_id)  # pre-change
    # Post-change: one daemon thread per account; joining keeps main alive.
    accounts = _bot_accounts()
    threads: list[threading.Thread] = []
    for acc in accounts:
        token = login_with_retry(acc["user"], acc["password"])
        try:
            room_id = resolve_alias(token, ROOM_ALIAS)
            join_room(token, room_id)
        except Exception:
            room_id = None
        thread = threading.Thread(
            target=sync_loop,
            args=(token, room_id),
            kwargs={
                "account_user": acc["user"],
                "default_mode": acc["mode"],
            },
            daemon=True,
        )
        thread.start()
        threads.append(thread)
    for thread in threads:
        thread.join()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,107 @@
from __future__ import annotations
import importlib.util
import os
from pathlib import Path
from unittest import TestCase, mock
BOT_PATH = Path(__file__).resolve().parents[1] / "atlasbot" / "bot.py"
def load_bot_module():
    """Import atlasbot/bot.py as a fresh module under a controlled environment.

    The mode-related env vars are patched before exec so the bot's
    module-level constants (accounts, models, time budgets) pick up
    deterministic test values.
    """
    overrides = {
        "BOT_USER": "atlas-smart",
        "BOT_PASS": "smart-pass",
        "BOT_USER_QUICK": "atlas-quick",
        "BOT_PASS_QUICK": "quick-pass",
        "BOT_USER_SMART": "atlas-smart",
        "BOT_PASS_SMART": "smart-pass",
        "BOT_USER_GENIUS": "atlas-genius",
        "BOT_PASS_GENIUS": "genius-pass",
        "OLLAMA_URL": "http://ollama.invalid",
        "OLLAMA_MODEL": "base-model",
        "ATLASBOT_MODEL_FAST": "fast-model",
        "ATLASBOT_MODEL_SMART": "smart-model",
        "ATLASBOT_MODEL_GENIUS": "genius-model",
        "ATLASBOT_QUICK_TIME_BUDGET_SEC": "15",
        "ATLASBOT_SMART_TIME_BUDGET_SEC": "45",
        "ATLASBOT_GENIUS_TIME_BUDGET_SEC": "180",
        "KB_DIR": "",
        "VM_URL": "http://vm.invalid",
        "ARIADNE_STATE_URL": "",
        "ARIADNE_STATE_TOKEN": "",
    }
    with mock.patch.dict(os.environ, overrides, clear=False):
        spec = importlib.util.spec_from_file_location("atlasbot_bot", BOT_PATH)
        module = importlib.util.module_from_spec(spec)
        assert spec.loader is not None
        # exec inside the patch so module-level os.environ reads see overrides.
        spec.loader.exec_module(module)
    return module
class AtlasbotModeTests(TestCase):
    """Tests for the quick/smart/genius mode wiring in atlasbot/bot.py."""

    def setUp(self):
        # Re-import the bot module per test so env-driven constants are fresh.
        self.bot = load_bot_module()

    def test_bot_accounts_include_genius_mode(self):
        # Each configured account must map to its expected mode.
        accounts = self.bot._bot_accounts()
        by_user = {account["user"]: account["mode"] for account in accounts}
        self.assertEqual(by_user["atlas-quick"], "fast")
        self.assertEqual(by_user["atlas-smart"], "smart")
        self.assertEqual(by_user["atlas-genius"], "genius")

    def test_objective_cluster_question_uses_fact_pack_without_llm(self):
        # Quantitative questions should be answered straight from the fact
        # pack; _ollama_call is patched to fail loudly if it is invoked.
        fact_lines = [
            "hottest_cpu: longhorn-system (6.69)",
            "hottest_ram: longhorn-system (36.05 GB)",
        ]
        with (
            mock.patch.object(self.bot, "_fact_pack_lines", return_value=fact_lines),
            mock.patch.object(self.bot, "_ollama_call", side_effect=AssertionError("LLM should not be called")),
        ):
            reply = self.bot.open_ended_answer(
                "what is the hottest cpu node in titan lab currently?",
                inventory=[],
                snapshot=None,
                workloads=[],
                history_lines=[],
                mode="smart",
                allow_tools=True,
            )
        self.assertIn("longhorn-system", reply)
        self.assertIn("Confidence:", reply)

    def test_subjective_genius_answer_uses_genius_model(self):
        # Subjective questions in genius mode should reach the LLM with the
        # genius model and a timeout within the genius time budget.
        fact_lines = [
            "hottest_cpu: longhorn-system (6.69)",
            "worker_nodes: titan-01, titan-02, titan-03",
        ]
        captured: dict[str, object] = {}

        # Signature must match _ollama_call's keyword-only interface.
        def fake_ollama_call(hist_key, prompt, *, context, use_history=True, system_override=None, model=None, timeout=None):
            captured["model"] = model
            captured["timeout"] = timeout
            captured["context"] = context
            return "The worker spread stands out because Titan keeps meaningful capacity on the same cluster. Confidence: high"

        with (
            mock.patch.object(self.bot, "_fact_pack_lines", return_value=fact_lines),
            mock.patch.object(self.bot, "_ollama_call", side_effect=fake_ollama_call),
        ):
            reply = self.bot.open_ended_answer(
                "what stands out about titan lab?",
                inventory=[],
                snapshot=None,
                workloads=[],
                history_lines=[],
                mode="genius",
                allow_tools=True,
            )
        self.assertIn("The worker spread stands out", reply)
        self.assertEqual(captured["model"], "genius-model")
        self.assertLessEqual(float(captured["timeout"]), 180.0)