atlasbot: wire quick smart genius modes
This commit is contained in:
parent
a3a00cfa9d
commit
0d8a2c5531
@ -11,14 +11,23 @@ from urllib import error, parse, request
|
||||
|
||||
BASE = os.environ.get("MATRIX_BASE", "http://othrys-synapse-matrix-synapse:8008")
|
||||
AUTH_BASE = os.environ.get("AUTH_BASE", "http://matrix-authentication-service:8080")
|
||||
USER = os.environ["BOT_USER"]
|
||||
PASSWORD = os.environ["BOT_PASS"]
|
||||
BOT_USER = os.environ["BOT_USER"]
|
||||
BOT_PASS = os.environ["BOT_PASS"]
|
||||
BOT_USER_QUICK = os.environ.get("BOT_USER_QUICK", "").strip()
|
||||
BOT_PASS_QUICK = os.environ.get("BOT_PASS_QUICK", "").strip()
|
||||
BOT_USER_SMART = os.environ.get("BOT_USER_SMART", "").strip()
|
||||
BOT_PASS_SMART = os.environ.get("BOT_PASS_SMART", "").strip()
|
||||
BOT_USER_GENIUS = os.environ.get("BOT_USER_GENIUS", "").strip()
|
||||
BOT_PASS_GENIUS = os.environ.get("BOT_PASS_GENIUS", "").strip()
|
||||
USER = BOT_USER
|
||||
PASSWORD = BOT_PASS
|
||||
ROOM_ALIAS = "#othrys:live.bstein.dev"
|
||||
|
||||
OLLAMA_URL = os.environ.get("OLLAMA_URL", "https://chat.ai.bstein.dev/")
|
||||
MODEL = os.environ.get("OLLAMA_MODEL", "qwen2.5:14b-instruct")
|
||||
MODEL_FAST = os.environ.get("ATLASBOT_MODEL_FAST", "")
|
||||
MODEL_DEEP = os.environ.get("ATLASBOT_MODEL_DEEP", "")
|
||||
MODEL_SMART = os.environ.get("ATLASBOT_MODEL_SMART", os.environ.get("ATLASBOT_MODEL_DEEP", "")).strip()
|
||||
MODEL_GENIUS = os.environ.get("ATLASBOT_MODEL_GENIUS", MODEL_SMART).strip()
|
||||
FALLBACK_MODEL = os.environ.get("OLLAMA_FALLBACK_MODEL", "")
|
||||
API_KEY = os.environ.get("CHAT_API_KEY", "")
|
||||
OLLAMA_TIMEOUT_SEC = float(os.environ.get("OLLAMA_TIMEOUT_SEC", "480"))
|
||||
@ -31,7 +40,7 @@ VM_URL = os.environ.get("VM_URL", "http://victoria-metrics-single-server.monitor
|
||||
ARIADNE_STATE_URL = os.environ.get("ARIADNE_STATE_URL", "")
|
||||
ARIADNE_STATE_TOKEN = os.environ.get("ARIADNE_STATE_TOKEN", "")
|
||||
|
||||
BOT_MENTIONS = os.environ.get("BOT_MENTIONS", f"{USER},atlas")
|
||||
BOT_MENTIONS = os.environ.get("BOT_MENTIONS", f"{BOT_USER},atlas")
|
||||
SERVER_NAME = os.environ.get("MATRIX_SERVER_NAME", "live.bstein.dev")
|
||||
|
||||
MAX_KB_CHARS = int(os.environ.get("ATLASBOT_MAX_KB_CHARS", "2500"))
|
||||
@ -39,6 +48,9 @@ MAX_TOOL_CHARS = int(os.environ.get("ATLASBOT_MAX_TOOL_CHARS", "2500"))
|
||||
MAX_FACTS_CHARS = int(os.environ.get("ATLASBOT_MAX_FACTS_CHARS", "8000"))
|
||||
MAX_CONTEXT_CHARS = int(os.environ.get("ATLASBOT_MAX_CONTEXT_CHARS", "12000"))
|
||||
THINKING_INTERVAL_SEC = int(os.environ.get("ATLASBOT_THINKING_INTERVAL_SEC", "120"))
|
||||
QUICK_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_QUICK_TIME_BUDGET_SEC", "15"))
|
||||
SMART_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_SMART_TIME_BUDGET_SEC", "45"))
|
||||
GENIUS_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_GENIUS_TIME_BUDGET_SEC", "180"))
|
||||
OLLAMA_RETRIES = int(os.environ.get("ATLASBOT_OLLAMA_RETRIES", "2"))
|
||||
OLLAMA_SERIALIZE = os.environ.get("ATLASBOT_OLLAMA_SERIALIZE", "true").lower() != "false"
|
||||
|
||||
@ -380,27 +392,103 @@ def _strip_bot_mention(text: str) -> str:
|
||||
return cleaned or text.strip()
|
||||
|
||||
|
||||
def _detect_mode_from_body(body: str, *, default: str = "deep") -> str:
|
||||
def _detect_mode_from_body(body: str, *, default: str = "smart") -> str:
|
||||
lower = normalize_query(body or "")
|
||||
if "atlas_quick" in lower or "atlas-quick" in lower:
|
||||
return "fast"
|
||||
if "atlas_smart" in lower or "atlas-smart" in lower:
|
||||
return "deep"
|
||||
return "smart"
|
||||
if "atlas_genius" in lower or "atlas-genius" in lower:
|
||||
return "genius"
|
||||
if lower.startswith("quick ") or lower.startswith("fast "):
|
||||
return "fast"
|
||||
if lower.startswith("smart ") or lower.startswith("deep "):
|
||||
return "deep"
|
||||
if lower.startswith("smart "):
|
||||
return "smart"
|
||||
if lower.startswith("genius ") or lower.startswith("deep "):
|
||||
return "genius"
|
||||
return default
|
||||
|
||||
def _detect_mode(
|
||||
content: dict[str, Any],
|
||||
body: str,
|
||||
*,
|
||||
default: str = "smart",
|
||||
account_user: str = "",
|
||||
) -> str:
|
||||
mode = _detect_mode_from_body(body, default=default)
|
||||
mentions = content.get("m.mentions", {})
|
||||
user_ids = mentions.get("user_ids", [])
|
||||
if isinstance(user_ids, list):
|
||||
normalized = {normalize_user_id(uid).lower() for uid in user_ids if isinstance(uid, str)}
|
||||
if BOT_USER_QUICK and normalize_user_id(BOT_USER_QUICK).lower() in normalized:
|
||||
return "fast"
|
||||
if BOT_USER_SMART and normalize_user_id(BOT_USER_SMART).lower() in normalized:
|
||||
return "smart"
|
||||
if BOT_USER_GENIUS and normalize_user_id(BOT_USER_GENIUS).lower() in normalized:
|
||||
return "genius"
|
||||
if BOT_USER and normalize_user_id(BOT_USER).lower() in normalized:
|
||||
return "smart"
|
||||
if account_user and BOT_USER_QUICK and normalize_user_id(account_user) == normalize_user_id(BOT_USER_QUICK):
|
||||
return "fast"
|
||||
if account_user and BOT_USER_SMART and normalize_user_id(account_user) == normalize_user_id(BOT_USER_SMART):
|
||||
return "smart"
|
||||
if account_user and BOT_USER_GENIUS and normalize_user_id(account_user) == normalize_user_id(BOT_USER_GENIUS):
|
||||
return "genius"
|
||||
return mode
|
||||
|
||||
|
||||
def _model_for_mode(mode: str) -> str:
|
||||
if mode == "fast" and MODEL_FAST:
|
||||
return MODEL_FAST
|
||||
if mode == "deep" and MODEL_DEEP:
|
||||
return MODEL_DEEP
|
||||
if mode == "smart" and MODEL_SMART:
|
||||
return MODEL_SMART
|
||||
if mode == "genius" and MODEL_GENIUS:
|
||||
return MODEL_GENIUS
|
||||
if mode == "deep" and MODEL_SMART:
|
||||
return MODEL_SMART
|
||||
return MODEL
|
||||
|
||||
|
||||
def _normalize_mode(mode: str) -> str:
|
||||
normalized = (mode or "").strip().lower()
|
||||
if normalized in {"quick", "fast"}:
|
||||
return "fast"
|
||||
if normalized in {"smart"}:
|
||||
return "smart"
|
||||
if normalized in {"genius", "deep"}:
|
||||
return "genius"
|
||||
return "smart"
|
||||
|
||||
|
||||
def _mode_time_budget_sec(mode: str) -> float:
|
||||
normalized = _normalize_mode(mode)
|
||||
if normalized == "fast":
|
||||
return max(1.0, QUICK_TIME_BUDGET_SEC)
|
||||
if normalized == "smart":
|
||||
return max(1.0, SMART_TIME_BUDGET_SEC)
|
||||
if normalized == "genius":
|
||||
return max(1.0, GENIUS_TIME_BUDGET_SEC)
|
||||
return max(1.0, SMART_TIME_BUDGET_SEC)
|
||||
|
||||
|
||||
def _mode_ollama_timeout_sec(mode: str) -> float:
|
||||
normalized = _normalize_mode(mode)
|
||||
budget = _mode_time_budget_sec(normalized)
|
||||
if normalized == "fast":
|
||||
return max(6.0, min(budget - 2.0, OLLAMA_TIMEOUT_SEC))
|
||||
if normalized == "smart":
|
||||
return max(12.0, min(budget - 5.0, OLLAMA_TIMEOUT_SEC))
|
||||
if normalized == "genius":
|
||||
return max(20.0, min(budget - 10.0, OLLAMA_TIMEOUT_SEC))
|
||||
return max(12.0, min(budget - 5.0, OLLAMA_TIMEOUT_SEC))
|
||||
|
||||
|
||||
def _mode_heartbeat_sec(mode: str) -> int:
|
||||
normalized = _normalize_mode(mode)
|
||||
budget = _mode_time_budget_sec(normalized)
|
||||
return max(5, min(THINKING_INTERVAL_SEC, int(max(5.0, budget / 3.0))))
|
||||
|
||||
|
||||
# Matrix HTTP helper.
|
||||
def req(method: str, path: str, token: str | None = None, body=None, timeout=60, base: str | None = None):
|
||||
url = (base or BASE) + path
|
||||
@ -416,12 +504,12 @@ def req(method: str, path: str, token: str | None = None, body=None, timeout=60,
|
||||
raw = resp.read()
|
||||
return json.loads(raw.decode()) if raw else {}
|
||||
|
||||
def login() -> str:
|
||||
login_user = normalize_user_id(USER)
|
||||
def login(user: str, password: str) -> str:
|
||||
login_user = normalize_user_id(user)
|
||||
payload = {
|
||||
"type": "m.login.password",
|
||||
"identifier": {"type": "m.id.user", "user": login_user},
|
||||
"password": PASSWORD,
|
||||
"password": password,
|
||||
}
|
||||
res = req("POST", "/_matrix/client/v3/login", body=payload, base=AUTH_BASE)
|
||||
return res["access_token"]
|
||||
@ -3813,9 +3901,12 @@ def _open_ended_multi(
|
||||
|
||||
|
||||
def _open_ended_total_steps(mode: str) -> int:
|
||||
if mode == "fast":
|
||||
normalized = _normalize_mode(mode)
|
||||
if normalized == "fast":
|
||||
return 2
|
||||
return 9
|
||||
if normalized == "smart":
|
||||
return 3
|
||||
return 4
|
||||
|
||||
|
||||
def _fast_fact_lines(
|
||||
@ -4150,6 +4241,7 @@ def _open_ended_fast_single(
|
||||
use_history=False,
|
||||
system_override=_open_ended_system(),
|
||||
model=model,
|
||||
timeout=_mode_ollama_timeout_sec("fast"),
|
||||
)
|
||||
if not _has_body_lines(reply):
|
||||
reply = _ollama_call(
|
||||
@ -4159,6 +4251,7 @@ def _open_ended_fast_single(
|
||||
use_history=False,
|
||||
system_override=_open_ended_system(),
|
||||
model=model,
|
||||
timeout=_mode_ollama_timeout_sec("fast"),
|
||||
)
|
||||
fallback = _fallback_fact_answer(prompt, context)
|
||||
if fallback and (_is_quantitative_prompt(prompt) or not _has_body_lines(reply)):
|
||||
@ -4219,16 +4312,53 @@ def _open_ended_deep(
|
||||
fact_lines: list[str],
|
||||
fact_meta: dict[str, dict[str, Any]],
|
||||
history_lines: list[str],
|
||||
mode: str,
|
||||
state: ThoughtState | None = None,
|
||||
) -> str:
|
||||
return _open_ended_multi(
|
||||
prompt,
|
||||
fact_pack=fact_pack,
|
||||
fact_lines=fact_lines,
|
||||
fact_meta=fact_meta,
|
||||
history_lines=history_lines,
|
||||
state=state,
|
||||
normalized = _normalize_mode(mode)
|
||||
model = _model_for_mode(normalized)
|
||||
subjective = _is_subjective_query(prompt)
|
||||
primary_tags = _primary_tags_for_prompt(prompt)
|
||||
focus_tags = _preferred_tags_for_prompt(prompt)
|
||||
if not focus_tags and subjective:
|
||||
focus_tags = set(_ALLOWED_INSIGHT_TAGS)
|
||||
avoid_tags = _history_focus_tags(history_lines) if (subjective or _is_followup_query(prompt)) else set()
|
||||
limit = 12 if normalized == "smart" else 18
|
||||
selected_lines = _fast_fact_lines(
|
||||
fact_lines,
|
||||
fact_meta,
|
||||
focus_tags=focus_tags,
|
||||
avoid_tags=avoid_tags,
|
||||
primary_tags=primary_tags,
|
||||
limit=limit,
|
||||
)
|
||||
selected_meta = _fact_pack_meta(selected_lines)
|
||||
selected_pack = _fact_pack_text(selected_lines, selected_meta)
|
||||
if _needs_full_fact_pack(prompt) or not selected_lines or normalized == "genius":
|
||||
selected_pack = fact_pack
|
||||
fallback = _fallback_fact_answer(prompt, selected_pack)
|
||||
if not subjective and fallback:
|
||||
if state:
|
||||
state.update("done", step=_open_ended_total_steps(normalized))
|
||||
return _ensure_scores(fallback)
|
||||
if state:
|
||||
state.update("drafting", step=1, note="synthesizing")
|
||||
reply = _ollama_call(
|
||||
("atlasbot_deep", "atlasbot_deep"),
|
||||
prompt,
|
||||
context=_append_history_context(selected_pack, history_lines),
|
||||
use_history=False,
|
||||
system_override=_open_ended_system(),
|
||||
model=model,
|
||||
timeout=_mode_ollama_timeout_sec(normalized),
|
||||
)
|
||||
if fallback and (_is_quantitative_prompt(prompt) or not _has_body_lines(reply)):
|
||||
reply = fallback
|
||||
if not _has_body_lines(reply):
|
||||
reply = "I don't have enough data in the current snapshot to answer that."
|
||||
if state:
|
||||
state.update("done", step=_open_ended_total_steps(normalized))
|
||||
return _ensure_scores(reply)
|
||||
|
||||
|
||||
def open_ended_answer(
|
||||
@ -4256,7 +4386,8 @@ def open_ended_answer(
|
||||
return _ensure_scores("I don't have enough data to answer that.")
|
||||
fact_meta = _fact_pack_meta(lines)
|
||||
fact_pack = _fact_pack_text(lines, fact_meta)
|
||||
if mode == "fast":
|
||||
normalized = _normalize_mode(mode)
|
||||
if normalized == "fast":
|
||||
return _open_ended_fast(
|
||||
prompt,
|
||||
fact_pack=fact_pack,
|
||||
@ -4271,6 +4402,7 @@ def open_ended_answer(
|
||||
fact_lines=lines,
|
||||
fact_meta=fact_meta,
|
||||
history_lines=history_lines,
|
||||
mode=normalized,
|
||||
state=state,
|
||||
)
|
||||
|
||||
@ -4292,6 +4424,7 @@ def _non_cluster_reply(prompt: str, *, history_lines: list[str], mode: str) -> s
|
||||
use_history=False,
|
||||
system_override=system,
|
||||
model=model,
|
||||
timeout=_mode_ollama_timeout_sec(mode),
|
||||
)
|
||||
reply = re.sub(r"\bconfidence\s*:\s*(high|medium|low)\b\.?\s*", "", reply, flags=re.IGNORECASE).strip()
|
||||
return _ensure_scores(reply)
|
||||
@ -4343,13 +4476,7 @@ class _AtlasbotHandler(BaseHTTPRequestHandler):
|
||||
self._write_json(400, {"error": "missing_prompt"})
|
||||
return
|
||||
cleaned = _strip_bot_mention(prompt)
|
||||
mode = str(payload.get("mode") or "deep").lower()
|
||||
if mode in ("quick", "fast"):
|
||||
mode = "fast"
|
||||
elif mode in ("smart", "deep"):
|
||||
mode = "deep"
|
||||
else:
|
||||
mode = "deep"
|
||||
mode = _normalize_mode(str(payload.get("mode") or "smart"))
|
||||
snapshot = _snapshot_state()
|
||||
inventory = _snapshot_inventory(snapshot) or node_inventory_live()
|
||||
workloads = _snapshot_workloads(snapshot)
|
||||
@ -4640,6 +4767,7 @@ def _ollama_call(
|
||||
use_history: bool = True,
|
||||
system_override: str | None = None,
|
||||
model: str | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> str:
|
||||
system = system_override or (
|
||||
"System: You are Atlas, the Titan lab assistant for Atlas/Othrys. "
|
||||
@ -4673,6 +4801,7 @@ def _ollama_call(
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
model_name = model or MODEL
|
||||
request_timeout = timeout if timeout is not None else OLLAMA_TIMEOUT_SEC
|
||||
payload = {"model": model_name, "messages": messages, "stream": False}
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if API_KEY:
|
||||
@ -4683,13 +4812,13 @@ def _ollama_call(
|
||||
lock.acquire()
|
||||
try:
|
||||
try:
|
||||
with request.urlopen(r, timeout=OLLAMA_TIMEOUT_SEC) as resp:
|
||||
with request.urlopen(r, timeout=request_timeout) as resp:
|
||||
data = json.loads(resp.read().decode())
|
||||
except error.HTTPError as exc:
|
||||
if exc.code == 404 and FALLBACK_MODEL and FALLBACK_MODEL != payload["model"]:
|
||||
payload["model"] = FALLBACK_MODEL
|
||||
r = request.Request(endpoint, data=json.dumps(payload).encode(), headers=headers)
|
||||
with request.urlopen(r, timeout=OLLAMA_TIMEOUT_SEC) as resp:
|
||||
with request.urlopen(r, timeout=request_timeout) as resp:
|
||||
data = json.loads(resp.read().decode())
|
||||
else:
|
||||
raise
|
||||
@ -4714,6 +4843,7 @@ def ollama_reply(
|
||||
fallback: str = "",
|
||||
use_history: bool = True,
|
||||
model: str | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> str:
|
||||
last_error = None
|
||||
for attempt in range(max(1, OLLAMA_RETRIES + 1)):
|
||||
@ -4724,6 +4854,7 @@ def ollama_reply(
|
||||
context=context,
|
||||
use_history=use_history,
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
last_error = exc
|
||||
@ -4744,6 +4875,7 @@ def ollama_reply_with_thinking(
|
||||
fallback: str,
|
||||
use_history: bool = True,
|
||||
model: str | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> str:
|
||||
result: dict[str, str] = {"reply": ""}
|
||||
done = threading.Event()
|
||||
@ -4756,6 +4888,7 @@ def ollama_reply_with_thinking(
|
||||
fallback=fallback,
|
||||
use_history=use_history,
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
)
|
||||
done.set()
|
||||
|
||||
@ -4812,7 +4945,7 @@ def open_ended_with_thinking(
|
||||
thread.start()
|
||||
if not done.wait(2.0):
|
||||
send_msg(token, room, "Thinking…")
|
||||
heartbeat = max(10, THINKING_INTERVAL_SEC)
|
||||
heartbeat = _mode_heartbeat_sec(mode)
|
||||
next_heartbeat = time.monotonic() + heartbeat
|
||||
while not done.wait(max(0, next_heartbeat - time.monotonic())):
|
||||
send_msg(token, room, state.status_line())
|
||||
@ -4820,7 +4953,7 @@ def open_ended_with_thinking(
|
||||
thread.join(timeout=1)
|
||||
return result["reply"] or "Model backend is busy. Try again in a moment."
|
||||
|
||||
def sync_loop(token: str, room_id: str):
|
||||
def sync_loop(token: str, room_id: str, *, account_user: str, default_mode: str):
|
||||
since = None
|
||||
try:
|
||||
res = req("GET", "/_matrix/client/v3/sync?timeout=0", token, timeout=10)
|
||||
@ -4861,7 +4994,7 @@ def sync_loop(token: str, room_id: str):
|
||||
if not body:
|
||||
continue
|
||||
sender = ev.get("sender", "")
|
||||
if sender == f"@{USER}:live.bstein.dev":
|
||||
if account_user and sender == normalize_user_id(account_user):
|
||||
continue
|
||||
|
||||
mentioned = is_mentioned(content, body)
|
||||
@ -4874,7 +5007,12 @@ def sync_loop(token: str, room_id: str):
|
||||
|
||||
cleaned_body = _strip_bot_mention(body)
|
||||
lower_body = cleaned_body.lower()
|
||||
mode = _detect_mode_from_body(body, default="deep" if is_dm else "deep")
|
||||
mode = _detect_mode(
|
||||
content,
|
||||
body,
|
||||
default=_normalize_mode(default_mode),
|
||||
account_user=account_user,
|
||||
)
|
||||
|
||||
# Only do live cluster introspection in DMs.
|
||||
allow_tools = is_dm
|
||||
@ -4938,39 +5076,80 @@ def sync_loop(token: str, room_id: str):
|
||||
snapshot=snapshot,
|
||||
workloads=workloads,
|
||||
history_lines=history[hist_key],
|
||||
mode=mode if mode in ("fast", "deep") else "deep",
|
||||
mode=_normalize_mode(mode),
|
||||
allow_tools=allow_tools,
|
||||
)
|
||||
else:
|
||||
reply = _non_cluster_reply(
|
||||
cleaned_body,
|
||||
history_lines=history[hist_key],
|
||||
mode=mode if mode in ("fast", "deep") else "deep",
|
||||
mode=_normalize_mode(mode),
|
||||
)
|
||||
send_msg(token, rid, reply)
|
||||
history[hist_key].append(f"Atlas: {reply}")
|
||||
history[hist_key] = history[hist_key][-80:]
|
||||
|
||||
def login_with_retry():
|
||||
def login_with_retry(user: str, password: str):
|
||||
last_err = None
|
||||
for attempt in range(10):
|
||||
try:
|
||||
return login()
|
||||
return login(user, password)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
last_err = exc
|
||||
time.sleep(min(30, 2 ** attempt))
|
||||
raise last_err
|
||||
|
||||
def _bot_accounts() -> list[dict[str, str]]:
|
||||
accounts: list[dict[str, str]] = []
|
||||
|
||||
def add(user: str, password: str, mode: str):
|
||||
if not user or not password:
|
||||
return
|
||||
accounts.append({"user": user, "password": password, "mode": mode})
|
||||
|
||||
add(BOT_USER_SMART or BOT_USER, BOT_PASS_SMART or BOT_PASS, "smart")
|
||||
if BOT_USER_QUICK and BOT_PASS_QUICK:
|
||||
add(BOT_USER_QUICK, BOT_PASS_QUICK, "fast")
|
||||
if BOT_USER_GENIUS and BOT_PASS_GENIUS:
|
||||
add(BOT_USER_GENIUS, BOT_PASS_GENIUS, "genius")
|
||||
if BOT_USER and BOT_PASS and all(acc["user"] != BOT_USER for acc in accounts):
|
||||
add(BOT_USER, BOT_PASS, "smart")
|
||||
|
||||
seen: set[str] = set()
|
||||
unique: list[dict[str, str]] = []
|
||||
for acc in accounts:
|
||||
uid = normalize_user_id(acc["user"]).lower()
|
||||
if uid in seen:
|
||||
continue
|
||||
seen.add(uid)
|
||||
unique.append(acc)
|
||||
return unique
|
||||
|
||||
def main():
|
||||
load_kb()
|
||||
_start_http_server()
|
||||
token = login_with_retry()
|
||||
try:
|
||||
room_id = resolve_alias(token, ROOM_ALIAS)
|
||||
join_room(token, room_id)
|
||||
except Exception:
|
||||
room_id = None
|
||||
sync_loop(token, room_id)
|
||||
accounts = _bot_accounts()
|
||||
threads: list[threading.Thread] = []
|
||||
for acc in accounts:
|
||||
token = login_with_retry(acc["user"], acc["password"])
|
||||
try:
|
||||
room_id = resolve_alias(token, ROOM_ALIAS)
|
||||
join_room(token, room_id)
|
||||
except Exception:
|
||||
room_id = None
|
||||
thread = threading.Thread(
|
||||
target=sync_loop,
|
||||
args=(token, room_id),
|
||||
kwargs={
|
||||
"account_user": acc["user"],
|
||||
"default_mode": acc["mode"],
|
||||
},
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
107
services/comms/scripts/tests/test_atlasbot_modes.py
Normal file
107
services/comms/scripts/tests/test_atlasbot_modes.py
Normal file
@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest import TestCase, mock
|
||||
|
||||
|
||||
BOT_PATH = Path(__file__).resolve().parents[1] / "atlasbot" / "bot.py"
|
||||
|
||||
|
||||
def load_bot_module():
|
||||
env = {
|
||||
"BOT_USER": "atlas-smart",
|
||||
"BOT_PASS": "smart-pass",
|
||||
"BOT_USER_QUICK": "atlas-quick",
|
||||
"BOT_PASS_QUICK": "quick-pass",
|
||||
"BOT_USER_SMART": "atlas-smart",
|
||||
"BOT_PASS_SMART": "smart-pass",
|
||||
"BOT_USER_GENIUS": "atlas-genius",
|
||||
"BOT_PASS_GENIUS": "genius-pass",
|
||||
"OLLAMA_URL": "http://ollama.invalid",
|
||||
"OLLAMA_MODEL": "base-model",
|
||||
"ATLASBOT_MODEL_FAST": "fast-model",
|
||||
"ATLASBOT_MODEL_SMART": "smart-model",
|
||||
"ATLASBOT_MODEL_GENIUS": "genius-model",
|
||||
"ATLASBOT_QUICK_TIME_BUDGET_SEC": "15",
|
||||
"ATLASBOT_SMART_TIME_BUDGET_SEC": "45",
|
||||
"ATLASBOT_GENIUS_TIME_BUDGET_SEC": "180",
|
||||
"KB_DIR": "",
|
||||
"VM_URL": "http://vm.invalid",
|
||||
"ARIADNE_STATE_URL": "",
|
||||
"ARIADNE_STATE_TOKEN": "",
|
||||
}
|
||||
with mock.patch.dict(os.environ, env, clear=False):
|
||||
spec = importlib.util.spec_from_file_location("atlasbot_bot", BOT_PATH)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
class AtlasbotModeTests(TestCase):
|
||||
def setUp(self):
|
||||
self.bot = load_bot_module()
|
||||
|
||||
def test_bot_accounts_include_genius_mode(self):
|
||||
accounts = self.bot._bot_accounts()
|
||||
by_user = {account["user"]: account["mode"] for account in accounts}
|
||||
|
||||
self.assertEqual(by_user["atlas-quick"], "fast")
|
||||
self.assertEqual(by_user["atlas-smart"], "smart")
|
||||
self.assertEqual(by_user["atlas-genius"], "genius")
|
||||
|
||||
def test_objective_cluster_question_uses_fact_pack_without_llm(self):
|
||||
fact_lines = [
|
||||
"hottest_cpu: longhorn-system (6.69)",
|
||||
"hottest_ram: longhorn-system (36.05 GB)",
|
||||
]
|
||||
|
||||
with (
|
||||
mock.patch.object(self.bot, "_fact_pack_lines", return_value=fact_lines),
|
||||
mock.patch.object(self.bot, "_ollama_call", side_effect=AssertionError("LLM should not be called")),
|
||||
):
|
||||
reply = self.bot.open_ended_answer(
|
||||
"what is the hottest cpu node in titan lab currently?",
|
||||
inventory=[],
|
||||
snapshot=None,
|
||||
workloads=[],
|
||||
history_lines=[],
|
||||
mode="smart",
|
||||
allow_tools=True,
|
||||
)
|
||||
|
||||
self.assertIn("longhorn-system", reply)
|
||||
self.assertIn("Confidence:", reply)
|
||||
|
||||
def test_subjective_genius_answer_uses_genius_model(self):
|
||||
fact_lines = [
|
||||
"hottest_cpu: longhorn-system (6.69)",
|
||||
"worker_nodes: titan-01, titan-02, titan-03",
|
||||
]
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def fake_ollama_call(hist_key, prompt, *, context, use_history=True, system_override=None, model=None, timeout=None):
|
||||
captured["model"] = model
|
||||
captured["timeout"] = timeout
|
||||
captured["context"] = context
|
||||
return "The worker spread stands out because Titan keeps meaningful capacity on the same cluster. Confidence: high"
|
||||
|
||||
with (
|
||||
mock.patch.object(self.bot, "_fact_pack_lines", return_value=fact_lines),
|
||||
mock.patch.object(self.bot, "_ollama_call", side_effect=fake_ollama_call),
|
||||
):
|
||||
reply = self.bot.open_ended_answer(
|
||||
"what stands out about titan lab?",
|
||||
inventory=[],
|
||||
snapshot=None,
|
||||
workloads=[],
|
||||
history_lines=[],
|
||||
mode="genius",
|
||||
allow_tools=True,
|
||||
)
|
||||
|
||||
self.assertIn("The worker spread stands out", reply)
|
||||
self.assertEqual(captured["model"], "genius-model")
|
||||
self.assertLessEqual(float(captured["timeout"]), 180.0)
|
||||
Loading…
x
Reference in New Issue
Block a user