diff --git a/services/comms/scripts/atlasbot/bot.py b/services/comms/scripts/atlasbot/bot.py index be256c0..5948634 100644 --- a/services/comms/scripts/atlasbot/bot.py +++ b/services/comms/scripts/atlasbot/bot.py @@ -11,8 +11,12 @@ from urllib import error, parse, request BASE = os.environ.get("MATRIX_BASE", "http://othrys-synapse-matrix-synapse:8008") AUTH_BASE = os.environ.get("AUTH_BASE", "http://matrix-authentication-service:8080") -USER = os.environ["BOT_USER"] -PASSWORD = os.environ["BOT_PASS"] +BOT_USER = os.environ["BOT_USER"] +BOT_PASS = os.environ["BOT_PASS"] +BOT_USER_QUICK = os.environ.get("BOT_USER_QUICK", "").strip() +BOT_PASS_QUICK = os.environ.get("BOT_PASS_QUICK", "").strip() +BOT_USER_SMART = os.environ.get("BOT_USER_SMART", "").strip() +BOT_PASS_SMART = os.environ.get("BOT_PASS_SMART", "").strip() ROOM_ALIAS = "#othrys:live.bstein.dev" OLLAMA_URL = os.environ.get("OLLAMA_URL", "https://chat.ai.bstein.dev/") @@ -31,7 +35,7 @@ VM_URL = os.environ.get("VM_URL", "http://victoria-metrics-single-server.monitor ARIADNE_STATE_URL = os.environ.get("ARIADNE_STATE_URL", "") ARIADNE_STATE_TOKEN = os.environ.get("ARIADNE_STATE_TOKEN", "") -BOT_MENTIONS = os.environ.get("BOT_MENTIONS", f"{USER},atlas") +BOT_MENTIONS = os.environ.get("BOT_MENTIONS", f"{BOT_USER},atlas") SERVER_NAME = os.environ.get("MATRIX_SERVER_NAME", "live.bstein.dev") MAX_KB_CHARS = int(os.environ.get("ATLASBOT_MAX_KB_CHARS", "2500")) @@ -393,6 +397,31 @@ def _detect_mode_from_body(body: str, *, default: str = "deep") -> str: return default +def _detect_mode( + content: dict[str, Any], + body: str, + *, + default: str = "deep", + account_user: str = "", +) -> str: + mode = _detect_mode_from_body(body, default=default) + mentions = content.get("m.mentions", {}) + user_ids = mentions.get("user_ids", []) + if isinstance(user_ids, list): + normalized = {normalize_user_id(uid).lower() for uid in user_ids if isinstance(uid, str)} + if BOT_USER_QUICK and normalize_user_id(BOT_USER_QUICK).lower() in normalized: + return "fast" + if BOT_USER_SMART and normalize_user_id(BOT_USER_SMART).lower() in normalized: + return "deep" + if BOT_USER and normalize_user_id(BOT_USER).lower() in normalized: + return "deep" + if account_user and BOT_USER_QUICK and normalize_user_id(account_user) == normalize_user_id(BOT_USER_QUICK): + return "fast" + if account_user and BOT_USER_SMART and normalize_user_id(account_user) == normalize_user_id(BOT_USER_SMART): + return "deep" + return mode + + def _model_for_mode(mode: str) -> str: if mode == "fast" and MODEL_FAST: return MODEL_FAST @@ -416,12 +445,12 @@ def req(method: str, path: str, token: str | None = None, body=None, timeout=60, raw = resp.read() return json.loads(raw.decode()) if raw else {} -def login() -> str: - login_user = normalize_user_id(USER) +def login(user: str, password: str) -> str: + login_user = normalize_user_id(user) payload = { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": login_user}, - "password": PASSWORD, + "password": password, } res = req("POST", "/_matrix/client/v3/login", body=payload, base=AUTH_BASE) return res["access_token"] @@ -4820,7 +4849,7 @@ def open_ended_with_thinking( thread.join(timeout=1) return result["reply"] or "Model backend is busy. Try again in a moment." -def sync_loop(token: str, room_id: str): +def sync_loop(token: str, room_id: str, *, account_user: str, default_mode: str): since = None try: res = req("GET", "/_matrix/client/v3/sync?timeout=0", token, timeout=10) @@ -4861,7 +4890,7 @@ def sync_loop(token: str, room_id: str): if not body: continue sender = ev.get("sender", "") - if sender == f"@{USER}:live.bstein.dev": + if account_user and sender == normalize_user_id(account_user): continue mentioned = is_mentioned(content, body) @@ -4874,7 +4903,12 @@ def sync_loop(token: str, room_id: str): cleaned_body = _strip_bot_mention(body) lower_body = cleaned_body.lower() - mode = _detect_mode_from_body(body, default="deep" if is_dm else "deep") + mode = _detect_mode( + content, + body, + default=default_mode if default_mode in ("fast", "deep") else "deep", + account_user=account_user, + ) # Only do live cluster introspection in DMs. allow_tools = is_dm @@ -4951,26 +4985,65 @@ def sync_loop(token: str, room_id: str): history[hist_key].append(f"Atlas: {reply}") history[hist_key] = history[hist_key][-80:] -def login_with_retry(): +def login_with_retry(user: str, password: str): last_err = None for attempt in range(10): try: - return login() + return login(user, password) except Exception as exc: # noqa: BLE001 last_err = exc time.sleep(min(30, 2 ** attempt)) raise last_err +def _bot_accounts() -> list[dict[str, str]]: + accounts: list[dict[str, str]] = [] + + def add(user: str, password: str, mode: str): + if not user or not password: + return + accounts.append({"user": user, "password": password, "mode": mode}) + + add(BOT_USER_SMART or BOT_USER, BOT_PASS_SMART or BOT_PASS, "deep") + if BOT_USER_QUICK and BOT_PASS_QUICK: + add(BOT_USER_QUICK, BOT_PASS_QUICK, "fast") + if BOT_USER and BOT_PASS and all(acc["user"] != BOT_USER for acc in accounts): + add(BOT_USER, BOT_PASS, "deep") + + seen: set[str] = set() + unique: list[dict[str, str]] = [] + for acc in accounts: + uid = normalize_user_id(acc["user"]).lower() + if uid in seen: + continue + seen.add(uid) + unique.append(acc) + return unique + def main(): load_kb() _start_http_server() - token = login_with_retry() - try: - room_id = resolve_alias(token, ROOM_ALIAS) - join_room(token, room_id) - except Exception: - room_id = None - sync_loop(token, room_id) + accounts = _bot_accounts() + threads: list[threading.Thread] = [] + for acc in accounts: + token = login_with_retry(acc["user"], acc["password"]) + try: + room_id = resolve_alias(token, ROOM_ALIAS) + join_room(token, room_id) + except Exception: + room_id = None + thread = threading.Thread( + target=sync_loop, + args=(token, room_id), + kwargs={ + "account_user": acc["user"], + "default_mode": acc["mode"], + }, + daemon=True, + ) + thread.start() + threads.append(thread) + for thread in threads: + thread.join() if __name__ == "__main__": main()