atlasbot: support quick/smart Matrix accounts

This commit is contained in:
Brad Stein 2026-01-30 10:20:50 -03:00
parent a5d4c63cd3
commit a851e184ca

View File

@ -11,8 +11,12 @@ from urllib import error, parse, request
BASE = os.environ.get("MATRIX_BASE", "http://othrys-synapse-matrix-synapse:8008") BASE = os.environ.get("MATRIX_BASE", "http://othrys-synapse-matrix-synapse:8008")
AUTH_BASE = os.environ.get("AUTH_BASE", "http://matrix-authentication-service:8080") AUTH_BASE = os.environ.get("AUTH_BASE", "http://matrix-authentication-service:8080")
USER = os.environ["BOT_USER"] BOT_USER = os.environ["BOT_USER"]
PASSWORD = os.environ["BOT_PASS"] BOT_PASS = os.environ["BOT_PASS"]
BOT_USER_QUICK = os.environ.get("BOT_USER_QUICK", "").strip()
BOT_PASS_QUICK = os.environ.get("BOT_PASS_QUICK", "").strip()
BOT_USER_SMART = os.environ.get("BOT_USER_SMART", "").strip()
BOT_PASS_SMART = os.environ.get("BOT_PASS_SMART", "").strip()
ROOM_ALIAS = "#othrys:live.bstein.dev" ROOM_ALIAS = "#othrys:live.bstein.dev"
OLLAMA_URL = os.environ.get("OLLAMA_URL", "https://chat.ai.bstein.dev/") OLLAMA_URL = os.environ.get("OLLAMA_URL", "https://chat.ai.bstein.dev/")
@ -31,7 +35,7 @@ VM_URL = os.environ.get("VM_URL", "http://victoria-metrics-single-server.monitor
ARIADNE_STATE_URL = os.environ.get("ARIADNE_STATE_URL", "") ARIADNE_STATE_URL = os.environ.get("ARIADNE_STATE_URL", "")
ARIADNE_STATE_TOKEN = os.environ.get("ARIADNE_STATE_TOKEN", "") ARIADNE_STATE_TOKEN = os.environ.get("ARIADNE_STATE_TOKEN", "")
BOT_MENTIONS = os.environ.get("BOT_MENTIONS", f"{USER},atlas") BOT_MENTIONS = os.environ.get("BOT_MENTIONS", f"{BOT_USER},atlas")
SERVER_NAME = os.environ.get("MATRIX_SERVER_NAME", "live.bstein.dev") SERVER_NAME = os.environ.get("MATRIX_SERVER_NAME", "live.bstein.dev")
MAX_KB_CHARS = int(os.environ.get("ATLASBOT_MAX_KB_CHARS", "2500")) MAX_KB_CHARS = int(os.environ.get("ATLASBOT_MAX_KB_CHARS", "2500"))
@ -393,6 +397,31 @@ def _detect_mode_from_body(body: str, *, default: str = "deep") -> str:
return default return default
def _detect_mode(
content: dict[str, Any],
body: str,
*,
default: str = "deep",
account_user: str = "",
) -> str:
mode = _detect_mode_from_body(body, default=default)
mentions = content.get("m.mentions", {})
user_ids = mentions.get("user_ids", [])
if isinstance(user_ids, list):
normalized = {normalize_user_id(uid).lower() for uid in user_ids if isinstance(uid, str)}
if BOT_USER_QUICK and normalize_user_id(BOT_USER_QUICK).lower() in normalized:
return "fast"
if BOT_USER_SMART and normalize_user_id(BOT_USER_SMART).lower() in normalized:
return "deep"
if BOT_USER and normalize_user_id(BOT_USER).lower() in normalized:
return "deep"
if account_user and BOT_USER_QUICK and normalize_user_id(account_user) == normalize_user_id(BOT_USER_QUICK):
return "fast"
if account_user and BOT_USER_SMART and normalize_user_id(account_user) == normalize_user_id(BOT_USER_SMART):
return "deep"
return mode
def _model_for_mode(mode: str) -> str: def _model_for_mode(mode: str) -> str:
if mode == "fast" and MODEL_FAST: if mode == "fast" and MODEL_FAST:
return MODEL_FAST return MODEL_FAST
@ -416,12 +445,12 @@ def req(method: str, path: str, token: str | None = None, body=None, timeout=60,
raw = resp.read() raw = resp.read()
return json.loads(raw.decode()) if raw else {} return json.loads(raw.decode()) if raw else {}
def login() -> str: def login(user: str, password: str) -> str:
login_user = normalize_user_id(USER) login_user = normalize_user_id(user)
payload = { payload = {
"type": "m.login.password", "type": "m.login.password",
"identifier": {"type": "m.id.user", "user": login_user}, "identifier": {"type": "m.id.user", "user": login_user},
"password": PASSWORD, "password": password,
} }
res = req("POST", "/_matrix/client/v3/login", body=payload, base=AUTH_BASE) res = req("POST", "/_matrix/client/v3/login", body=payload, base=AUTH_BASE)
return res["access_token"] return res["access_token"]
@ -4820,7 +4849,7 @@ def open_ended_with_thinking(
thread.join(timeout=1) thread.join(timeout=1)
return result["reply"] or "Model backend is busy. Try again in a moment." return result["reply"] or "Model backend is busy. Try again in a moment."
def sync_loop(token: str, room_id: str): def sync_loop(token: str, room_id: str, *, account_user: str, default_mode: str):
since = None since = None
try: try:
res = req("GET", "/_matrix/client/v3/sync?timeout=0", token, timeout=10) res = req("GET", "/_matrix/client/v3/sync?timeout=0", token, timeout=10)
@ -4861,7 +4890,7 @@ def sync_loop(token: str, room_id: str):
if not body: if not body:
continue continue
sender = ev.get("sender", "") sender = ev.get("sender", "")
if sender == f"@{USER}:live.bstein.dev": if account_user and sender == normalize_user_id(account_user):
continue continue
mentioned = is_mentioned(content, body) mentioned = is_mentioned(content, body)
@ -4874,7 +4903,12 @@ def sync_loop(token: str, room_id: str):
cleaned_body = _strip_bot_mention(body) cleaned_body = _strip_bot_mention(body)
lower_body = cleaned_body.lower() lower_body = cleaned_body.lower()
mode = _detect_mode_from_body(body, default="deep" if is_dm else "deep") mode = _detect_mode(
content,
body,
default=default_mode if default_mode in ("fast", "deep") else "deep",
account_user=account_user,
)
# Only do live cluster introspection in DMs. # Only do live cluster introspection in DMs.
allow_tools = is_dm allow_tools = is_dm
@ -4951,26 +4985,65 @@ def sync_loop(token: str, room_id: str):
history[hist_key].append(f"Atlas: {reply}") history[hist_key].append(f"Atlas: {reply}")
history[hist_key] = history[hist_key][-80:] history[hist_key] = history[hist_key][-80:]
def login_with_retry(): def login_with_retry(user: str, password: str):
last_err = None last_err = None
for attempt in range(10): for attempt in range(10):
try: try:
return login() return login(user, password)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
last_err = exc last_err = exc
time.sleep(min(30, 2 ** attempt)) time.sleep(min(30, 2 ** attempt))
raise last_err raise last_err
def _bot_accounts() -> list[dict[str, str]]:
accounts: list[dict[str, str]] = []
def add(user: str, password: str, mode: str):
if not user or not password:
return
accounts.append({"user": user, "password": password, "mode": mode})
add(BOT_USER_SMART or BOT_USER, BOT_PASS_SMART or BOT_PASS, "deep")
if BOT_USER_QUICK and BOT_PASS_QUICK:
add(BOT_USER_QUICK, BOT_PASS_QUICK, "fast")
if BOT_USER and BOT_PASS and all(acc["user"] != BOT_USER for acc in accounts):
add(BOT_USER, BOT_PASS, "deep")
seen: set[str] = set()
unique: list[dict[str, str]] = []
for acc in accounts:
uid = normalize_user_id(acc["user"]).lower()
if uid in seen:
continue
seen.add(uid)
unique.append(acc)
return unique
def main(): def main():
load_kb() load_kb()
_start_http_server() _start_http_server()
token = login_with_retry() accounts = _bot_accounts()
try: threads: list[threading.Thread] = []
room_id = resolve_alias(token, ROOM_ALIAS) for acc in accounts:
join_room(token, room_id) token = login_with_retry(acc["user"], acc["password"])
except Exception: try:
room_id = None room_id = resolve_alias(token, ROOM_ALIAS)
sync_loop(token, room_id) join_room(token, room_id)
except Exception:
room_id = None
thread = threading.Thread(
target=sync_loop,
args=(token, room_id),
kwargs={
"account_user": acc["user"],
"default_mode": acc["mode"],
},
daemon=True,
)
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
if __name__ == "__main__": if __name__ == "__main__":
main() main()