atlasbot: support quick/smart Matrix accounts
This commit is contained in:
parent
a5d4c63cd3
commit
a851e184ca
@ -11,8 +11,12 @@ from urllib import error, parse, request
|
|||||||
|
|
||||||
BASE = os.environ.get("MATRIX_BASE", "http://othrys-synapse-matrix-synapse:8008")
|
BASE = os.environ.get("MATRIX_BASE", "http://othrys-synapse-matrix-synapse:8008")
|
||||||
AUTH_BASE = os.environ.get("AUTH_BASE", "http://matrix-authentication-service:8080")
|
AUTH_BASE = os.environ.get("AUTH_BASE", "http://matrix-authentication-service:8080")
|
||||||
USER = os.environ["BOT_USER"]
|
BOT_USER = os.environ["BOT_USER"]
|
||||||
PASSWORD = os.environ["BOT_PASS"]
|
BOT_PASS = os.environ["BOT_PASS"]
|
||||||
|
BOT_USER_QUICK = os.environ.get("BOT_USER_QUICK", "").strip()
|
||||||
|
BOT_PASS_QUICK = os.environ.get("BOT_PASS_QUICK", "").strip()
|
||||||
|
BOT_USER_SMART = os.environ.get("BOT_USER_SMART", "").strip()
|
||||||
|
BOT_PASS_SMART = os.environ.get("BOT_PASS_SMART", "").strip()
|
||||||
ROOM_ALIAS = "#othrys:live.bstein.dev"
|
ROOM_ALIAS = "#othrys:live.bstein.dev"
|
||||||
|
|
||||||
OLLAMA_URL = os.environ.get("OLLAMA_URL", "https://chat.ai.bstein.dev/")
|
OLLAMA_URL = os.environ.get("OLLAMA_URL", "https://chat.ai.bstein.dev/")
|
||||||
@ -31,7 +35,7 @@ VM_URL = os.environ.get("VM_URL", "http://victoria-metrics-single-server.monitor
|
|||||||
ARIADNE_STATE_URL = os.environ.get("ARIADNE_STATE_URL", "")
|
ARIADNE_STATE_URL = os.environ.get("ARIADNE_STATE_URL", "")
|
||||||
ARIADNE_STATE_TOKEN = os.environ.get("ARIADNE_STATE_TOKEN", "")
|
ARIADNE_STATE_TOKEN = os.environ.get("ARIADNE_STATE_TOKEN", "")
|
||||||
|
|
||||||
BOT_MENTIONS = os.environ.get("BOT_MENTIONS", f"{USER},atlas")
|
BOT_MENTIONS = os.environ.get("BOT_MENTIONS", f"{BOT_USER},atlas")
|
||||||
SERVER_NAME = os.environ.get("MATRIX_SERVER_NAME", "live.bstein.dev")
|
SERVER_NAME = os.environ.get("MATRIX_SERVER_NAME", "live.bstein.dev")
|
||||||
|
|
||||||
MAX_KB_CHARS = int(os.environ.get("ATLASBOT_MAX_KB_CHARS", "2500"))
|
MAX_KB_CHARS = int(os.environ.get("ATLASBOT_MAX_KB_CHARS", "2500"))
|
||||||
@ -393,6 +397,31 @@ def _detect_mode_from_body(body: str, *, default: str = "deep") -> str:
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_mode(
|
||||||
|
content: dict[str, Any],
|
||||||
|
body: str,
|
||||||
|
*,
|
||||||
|
default: str = "deep",
|
||||||
|
account_user: str = "",
|
||||||
|
) -> str:
|
||||||
|
mode = _detect_mode_from_body(body, default=default)
|
||||||
|
mentions = content.get("m.mentions", {})
|
||||||
|
user_ids = mentions.get("user_ids", [])
|
||||||
|
if isinstance(user_ids, list):
|
||||||
|
normalized = {normalize_user_id(uid).lower() for uid in user_ids if isinstance(uid, str)}
|
||||||
|
if BOT_USER_QUICK and normalize_user_id(BOT_USER_QUICK).lower() in normalized:
|
||||||
|
return "fast"
|
||||||
|
if BOT_USER_SMART and normalize_user_id(BOT_USER_SMART).lower() in normalized:
|
||||||
|
return "deep"
|
||||||
|
if BOT_USER and normalize_user_id(BOT_USER).lower() in normalized:
|
||||||
|
return "deep"
|
||||||
|
if account_user and BOT_USER_QUICK and normalize_user_id(account_user) == normalize_user_id(BOT_USER_QUICK):
|
||||||
|
return "fast"
|
||||||
|
if account_user and BOT_USER_SMART and normalize_user_id(account_user) == normalize_user_id(BOT_USER_SMART):
|
||||||
|
return "deep"
|
||||||
|
return mode
|
||||||
|
|
||||||
|
|
||||||
def _model_for_mode(mode: str) -> str:
|
def _model_for_mode(mode: str) -> str:
|
||||||
if mode == "fast" and MODEL_FAST:
|
if mode == "fast" and MODEL_FAST:
|
||||||
return MODEL_FAST
|
return MODEL_FAST
|
||||||
@ -416,12 +445,12 @@ def req(method: str, path: str, token: str | None = None, body=None, timeout=60,
|
|||||||
raw = resp.read()
|
raw = resp.read()
|
||||||
return json.loads(raw.decode()) if raw else {}
|
return json.loads(raw.decode()) if raw else {}
|
||||||
|
|
||||||
def login() -> str:
|
def login(user: str, password: str) -> str:
|
||||||
login_user = normalize_user_id(USER)
|
login_user = normalize_user_id(user)
|
||||||
payload = {
|
payload = {
|
||||||
"type": "m.login.password",
|
"type": "m.login.password",
|
||||||
"identifier": {"type": "m.id.user", "user": login_user},
|
"identifier": {"type": "m.id.user", "user": login_user},
|
||||||
"password": PASSWORD,
|
"password": password,
|
||||||
}
|
}
|
||||||
res = req("POST", "/_matrix/client/v3/login", body=payload, base=AUTH_BASE)
|
res = req("POST", "/_matrix/client/v3/login", body=payload, base=AUTH_BASE)
|
||||||
return res["access_token"]
|
return res["access_token"]
|
||||||
@ -4820,7 +4849,7 @@ def open_ended_with_thinking(
|
|||||||
thread.join(timeout=1)
|
thread.join(timeout=1)
|
||||||
return result["reply"] or "Model backend is busy. Try again in a moment."
|
return result["reply"] or "Model backend is busy. Try again in a moment."
|
||||||
|
|
||||||
def sync_loop(token: str, room_id: str):
|
def sync_loop(token: str, room_id: str, *, account_user: str, default_mode: str):
|
||||||
since = None
|
since = None
|
||||||
try:
|
try:
|
||||||
res = req("GET", "/_matrix/client/v3/sync?timeout=0", token, timeout=10)
|
res = req("GET", "/_matrix/client/v3/sync?timeout=0", token, timeout=10)
|
||||||
@ -4861,7 +4890,7 @@ def sync_loop(token: str, room_id: str):
|
|||||||
if not body:
|
if not body:
|
||||||
continue
|
continue
|
||||||
sender = ev.get("sender", "")
|
sender = ev.get("sender", "")
|
||||||
if sender == f"@{USER}:live.bstein.dev":
|
if account_user and sender == normalize_user_id(account_user):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
mentioned = is_mentioned(content, body)
|
mentioned = is_mentioned(content, body)
|
||||||
@ -4874,7 +4903,12 @@ def sync_loop(token: str, room_id: str):
|
|||||||
|
|
||||||
cleaned_body = _strip_bot_mention(body)
|
cleaned_body = _strip_bot_mention(body)
|
||||||
lower_body = cleaned_body.lower()
|
lower_body = cleaned_body.lower()
|
||||||
mode = _detect_mode_from_body(body, default="deep" if is_dm else "deep")
|
mode = _detect_mode(
|
||||||
|
content,
|
||||||
|
body,
|
||||||
|
default=default_mode if default_mode in ("fast", "deep") else "deep",
|
||||||
|
account_user=account_user,
|
||||||
|
)
|
||||||
|
|
||||||
# Only do live cluster introspection in DMs.
|
# Only do live cluster introspection in DMs.
|
||||||
allow_tools = is_dm
|
allow_tools = is_dm
|
||||||
@ -4951,26 +4985,65 @@ def sync_loop(token: str, room_id: str):
|
|||||||
history[hist_key].append(f"Atlas: {reply}")
|
history[hist_key].append(f"Atlas: {reply}")
|
||||||
history[hist_key] = history[hist_key][-80:]
|
history[hist_key] = history[hist_key][-80:]
|
||||||
|
|
||||||
def login_with_retry():
|
def login_with_retry(user: str, password: str):
|
||||||
last_err = None
|
last_err = None
|
||||||
for attempt in range(10):
|
for attempt in range(10):
|
||||||
try:
|
try:
|
||||||
return login()
|
return login(user, password)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
last_err = exc
|
last_err = exc
|
||||||
time.sleep(min(30, 2 ** attempt))
|
time.sleep(min(30, 2 ** attempt))
|
||||||
raise last_err
|
raise last_err
|
||||||
|
|
||||||
|
def _bot_accounts() -> list[dict[str, str]]:
|
||||||
|
accounts: list[dict[str, str]] = []
|
||||||
|
|
||||||
|
def add(user: str, password: str, mode: str):
|
||||||
|
if not user or not password:
|
||||||
|
return
|
||||||
|
accounts.append({"user": user, "password": password, "mode": mode})
|
||||||
|
|
||||||
|
add(BOT_USER_SMART or BOT_USER, BOT_PASS_SMART or BOT_PASS, "deep")
|
||||||
|
if BOT_USER_QUICK and BOT_PASS_QUICK:
|
||||||
|
add(BOT_USER_QUICK, BOT_PASS_QUICK, "fast")
|
||||||
|
if BOT_USER and BOT_PASS and all(acc["user"] != BOT_USER for acc in accounts):
|
||||||
|
add(BOT_USER, BOT_PASS, "deep")
|
||||||
|
|
||||||
|
seen: set[str] = set()
|
||||||
|
unique: list[dict[str, str]] = []
|
||||||
|
for acc in accounts:
|
||||||
|
uid = normalize_user_id(acc["user"]).lower()
|
||||||
|
if uid in seen:
|
||||||
|
continue
|
||||||
|
seen.add(uid)
|
||||||
|
unique.append(acc)
|
||||||
|
return unique
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
load_kb()
|
load_kb()
|
||||||
_start_http_server()
|
_start_http_server()
|
||||||
token = login_with_retry()
|
accounts = _bot_accounts()
|
||||||
try:
|
threads: list[threading.Thread] = []
|
||||||
room_id = resolve_alias(token, ROOM_ALIAS)
|
for acc in accounts:
|
||||||
join_room(token, room_id)
|
token = login_with_retry(acc["user"], acc["password"])
|
||||||
except Exception:
|
try:
|
||||||
room_id = None
|
room_id = resolve_alias(token, ROOM_ALIAS)
|
||||||
sync_loop(token, room_id)
|
join_room(token, room_id)
|
||||||
|
except Exception:
|
||||||
|
room_id = None
|
||||||
|
thread = threading.Thread(
|
||||||
|
target=sync_loop,
|
||||||
|
args=(token, room_id),
|
||||||
|
kwargs={
|
||||||
|
"account_user": acc["user"],
|
||||||
|
"default_mode": acc["mode"],
|
||||||
|
},
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
thread.start()
|
||||||
|
threads.append(thread)
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user