atlasbot: support quick/smart Matrix accounts
This commit is contained in:
parent
a5d4c63cd3
commit
a851e184ca
@ -11,8 +11,12 @@ from urllib import error, parse, request
|
||||
|
||||
BASE = os.environ.get("MATRIX_BASE", "http://othrys-synapse-matrix-synapse:8008")
|
||||
AUTH_BASE = os.environ.get("AUTH_BASE", "http://matrix-authentication-service:8080")
|
||||
USER = os.environ["BOT_USER"]
|
||||
PASSWORD = os.environ["BOT_PASS"]
|
||||
BOT_USER = os.environ["BOT_USER"]
|
||||
BOT_PASS = os.environ["BOT_PASS"]
|
||||
BOT_USER_QUICK = os.environ.get("BOT_USER_QUICK", "").strip()
|
||||
BOT_PASS_QUICK = os.environ.get("BOT_PASS_QUICK", "").strip()
|
||||
BOT_USER_SMART = os.environ.get("BOT_USER_SMART", "").strip()
|
||||
BOT_PASS_SMART = os.environ.get("BOT_PASS_SMART", "").strip()
|
||||
ROOM_ALIAS = "#othrys:live.bstein.dev"
|
||||
|
||||
OLLAMA_URL = os.environ.get("OLLAMA_URL", "https://chat.ai.bstein.dev/")
|
||||
@ -31,7 +35,7 @@ VM_URL = os.environ.get("VM_URL", "http://victoria-metrics-single-server.monitor
|
||||
ARIADNE_STATE_URL = os.environ.get("ARIADNE_STATE_URL", "")
|
||||
ARIADNE_STATE_TOKEN = os.environ.get("ARIADNE_STATE_TOKEN", "")
|
||||
|
||||
BOT_MENTIONS = os.environ.get("BOT_MENTIONS", f"{USER},atlas")
|
||||
BOT_MENTIONS = os.environ.get("BOT_MENTIONS", f"{BOT_USER},atlas")
|
||||
SERVER_NAME = os.environ.get("MATRIX_SERVER_NAME", "live.bstein.dev")
|
||||
|
||||
MAX_KB_CHARS = int(os.environ.get("ATLASBOT_MAX_KB_CHARS", "2500"))
|
||||
@ -393,6 +397,31 @@ def _detect_mode_from_body(body: str, *, default: str = "deep") -> str:
|
||||
return default
|
||||
|
||||
|
||||
def _detect_mode(
|
||||
content: dict[str, Any],
|
||||
body: str,
|
||||
*,
|
||||
default: str = "deep",
|
||||
account_user: str = "",
|
||||
) -> str:
|
||||
mode = _detect_mode_from_body(body, default=default)
|
||||
mentions = content.get("m.mentions", {})
|
||||
user_ids = mentions.get("user_ids", [])
|
||||
if isinstance(user_ids, list):
|
||||
normalized = {normalize_user_id(uid).lower() for uid in user_ids if isinstance(uid, str)}
|
||||
if BOT_USER_QUICK and normalize_user_id(BOT_USER_QUICK).lower() in normalized:
|
||||
return "fast"
|
||||
if BOT_USER_SMART and normalize_user_id(BOT_USER_SMART).lower() in normalized:
|
||||
return "deep"
|
||||
if BOT_USER and normalize_user_id(BOT_USER).lower() in normalized:
|
||||
return "deep"
|
||||
if account_user and BOT_USER_QUICK and normalize_user_id(account_user) == normalize_user_id(BOT_USER_QUICK):
|
||||
return "fast"
|
||||
if account_user and BOT_USER_SMART and normalize_user_id(account_user) == normalize_user_id(BOT_USER_SMART):
|
||||
return "deep"
|
||||
return mode
|
||||
|
||||
|
||||
def _model_for_mode(mode: str) -> str:
|
||||
if mode == "fast" and MODEL_FAST:
|
||||
return MODEL_FAST
|
||||
@ -416,12 +445,12 @@ def req(method: str, path: str, token: str | None = None, body=None, timeout=60,
|
||||
raw = resp.read()
|
||||
return json.loads(raw.decode()) if raw else {}
|
||||
|
||||
def login() -> str:
|
||||
login_user = normalize_user_id(USER)
|
||||
def login(user: str, password: str) -> str:
|
||||
login_user = normalize_user_id(user)
|
||||
payload = {
|
||||
"type": "m.login.password",
|
||||
"identifier": {"type": "m.id.user", "user": login_user},
|
||||
"password": PASSWORD,
|
||||
"password": password,
|
||||
}
|
||||
res = req("POST", "/_matrix/client/v3/login", body=payload, base=AUTH_BASE)
|
||||
return res["access_token"]
|
||||
@ -4820,7 +4849,7 @@ def open_ended_with_thinking(
|
||||
thread.join(timeout=1)
|
||||
return result["reply"] or "Model backend is busy. Try again in a moment."
|
||||
|
||||
def sync_loop(token: str, room_id: str):
|
||||
def sync_loop(token: str, room_id: str, *, account_user: str, default_mode: str):
|
||||
since = None
|
||||
try:
|
||||
res = req("GET", "/_matrix/client/v3/sync?timeout=0", token, timeout=10)
|
||||
@ -4861,7 +4890,7 @@ def sync_loop(token: str, room_id: str):
|
||||
if not body:
|
||||
continue
|
||||
sender = ev.get("sender", "")
|
||||
if sender == f"@{USER}:live.bstein.dev":
|
||||
if account_user and sender == normalize_user_id(account_user):
|
||||
continue
|
||||
|
||||
mentioned = is_mentioned(content, body)
|
||||
@ -4874,7 +4903,12 @@ def sync_loop(token: str, room_id: str):
|
||||
|
||||
cleaned_body = _strip_bot_mention(body)
|
||||
lower_body = cleaned_body.lower()
|
||||
mode = _detect_mode_from_body(body, default="deep" if is_dm else "deep")
|
||||
mode = _detect_mode(
|
||||
content,
|
||||
body,
|
||||
default=default_mode if default_mode in ("fast", "deep") else "deep",
|
||||
account_user=account_user,
|
||||
)
|
||||
|
||||
# Only do live cluster introspection in DMs.
|
||||
allow_tools = is_dm
|
||||
@ -4951,26 +4985,65 @@ def sync_loop(token: str, room_id: str):
|
||||
history[hist_key].append(f"Atlas: {reply}")
|
||||
history[hist_key] = history[hist_key][-80:]
|
||||
|
||||
def login_with_retry():
|
||||
def login_with_retry(user: str, password: str):
|
||||
last_err = None
|
||||
for attempt in range(10):
|
||||
try:
|
||||
return login()
|
||||
return login(user, password)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
last_err = exc
|
||||
time.sleep(min(30, 2 ** attempt))
|
||||
raise last_err
|
||||
|
||||
def _bot_accounts() -> list[dict[str, str]]:
|
||||
accounts: list[dict[str, str]] = []
|
||||
|
||||
def add(user: str, password: str, mode: str):
|
||||
if not user or not password:
|
||||
return
|
||||
accounts.append({"user": user, "password": password, "mode": mode})
|
||||
|
||||
add(BOT_USER_SMART or BOT_USER, BOT_PASS_SMART or BOT_PASS, "deep")
|
||||
if BOT_USER_QUICK and BOT_PASS_QUICK:
|
||||
add(BOT_USER_QUICK, BOT_PASS_QUICK, "fast")
|
||||
if BOT_USER and BOT_PASS and all(acc["user"] != BOT_USER for acc in accounts):
|
||||
add(BOT_USER, BOT_PASS, "deep")
|
||||
|
||||
seen: set[str] = set()
|
||||
unique: list[dict[str, str]] = []
|
||||
for acc in accounts:
|
||||
uid = normalize_user_id(acc["user"]).lower()
|
||||
if uid in seen:
|
||||
continue
|
||||
seen.add(uid)
|
||||
unique.append(acc)
|
||||
return unique
|
||||
|
||||
def main():
|
||||
load_kb()
|
||||
_start_http_server()
|
||||
token = login_with_retry()
|
||||
try:
|
||||
room_id = resolve_alias(token, ROOM_ALIAS)
|
||||
join_room(token, room_id)
|
||||
except Exception:
|
||||
room_id = None
|
||||
sync_loop(token, room_id)
|
||||
accounts = _bot_accounts()
|
||||
threads: list[threading.Thread] = []
|
||||
for acc in accounts:
|
||||
token = login_with_retry(acc["user"], acc["password"])
|
||||
try:
|
||||
room_id = resolve_alias(token, ROOM_ALIAS)
|
||||
join_room(token, room_id)
|
||||
except Exception:
|
||||
room_id = None
|
||||
thread = threading.Thread(
|
||||
target=sync_loop,
|
||||
args=(token, room_id),
|
||||
kwargs={
|
||||
"account_user": acc["user"],
|
||||
"default_mode": acc["mode"],
|
||||
},
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user