276 lines
11 KiB
Python
276 lines
11 KiB
Python
import asyncio
import logging
import re
import time
from collections.abc import Awaitable, Callable
from typing import Any
from urllib.parse import quote

import httpx

from atlasbot.config import MatrixBotConfig, Settings
from atlasbot.engine.answerer import AnswerEngine, AnswerResult
|
|
|
|
# Module-level logger. Call sites in this module attach structured context via
# ``extra={"extra": {...}}`` rather than formatting it into the message text.
log = logging.getLogger(__name__)
|
|
|
|
|
|
class MatrixClient:
    """Wrap the Matrix client-server endpoints used by the bot runtime.

    Each method opens its own short-lived ``httpx.AsyncClient``. The instance
    only holds the settings, the bot config, and a counter used to mint
    transaction ids for outgoing messages.
    """

    def __init__(self, settings: Settings, bot: MatrixBotConfig) -> None:
        self._settings = settings
        self._bot = bot
        # Per-instance sequence number feeding _next_txn_id().
        self._txn_seq = 0

    def _next_txn_id(self) -> str:
        """Return a transaction id not previously returned by this instance."""
        self._txn_seq += 1
        # The counter guarantees uniqueness within this instance; the
        # nanosecond timestamp makes collisions across instances unlikely.
        return f"atlasbot-{time.time_ns()}-{self._txn_seq}"

    async def login(self) -> str:
        """Exchange the bot's password credentials for a Matrix access token.

        Returns:
            The ``access_token`` from the login response, or ``""`` when the
            response body does not contain one.

        Raises:
            httpx.HTTPStatusError: if the homeserver answers with a non-2xx status.
            httpx.HTTPError: on transport failures.
        """
        payload = {
            "type": "m.login.password",
            "identifier": {"type": "m.id.user", "user": self._bot.username},
            "password": self._bot.password,
        }
        url = f"{self._settings.auth_base}/_matrix/client/v3/login"
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.post(url, json=payload)
            resp.raise_for_status()
            data = resp.json()
        return data.get("access_token", "")

    async def resolve_room(self, token: str) -> str:
        """Resolve the configured room alias into a room id.

        Returns:
            The ``room_id`` from the directory lookup, or ``""`` when the request
            fails, the body is not valid JSON, or it carries no room id.
            Failures are logged, not raised.
        """
        alias = quote(self._settings.room_alias, safe="")
        url = f"{self._settings.matrix_base}/_matrix/client/v3/directory/room/{alias}"
        headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient(timeout=15.0) as client:
            try:
                resp = await client.get(url, headers=headers)
                resp.raise_for_status()
                data = resp.json()
            except (httpx.HTTPError, ValueError) as exc:
                # ValueError covers a malformed JSON body from resp.json().
                log.warning(
                    "matrix resolve_room failed",
                    extra={"extra": {"error": str(exc), "alias": self._settings.room_alias}},
                )
                return ""
        return data.get("room_id", "")

    async def join_room(self, token: str, room_id: str) -> None:
        """Request to join ``room_id``.

        The join is requested unconditionally (no membership check is made).
        A non-2xx response is logged rather than raised, so bootstrap still
        proceeds to syncing; transport errors from httpx propagate.
        """
        room = quote(room_id, safe="")
        url = f"{self._settings.matrix_base}/_matrix/client/v3/rooms/{room}/join"
        headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.post(url, headers=headers)
        if resp.is_error:
            log.warning(
                "matrix join_room failed",
                extra={"extra": {"status": resp.status_code, "room_id": room_id}},
            )

    async def send_message(self, token: str, room_id: str, text: str) -> None:
        """Send a plain-text ``m.room.message`` event to ``room_id``.

        Uses ``PUT .../send/m.room.message/{txnId}`` with a fresh transaction
        id for every call. A non-2xx response is logged rather than raised;
        transport errors from httpx propagate.
        """
        room = quote(room_id, safe="")
        txn_id = self._next_txn_id()
        url = (
            f"{self._settings.matrix_base}/_matrix/client/v3/rooms/{room}"
            f"/send/m.room.message/{txn_id}"
        )
        headers = {"Authorization": f"Bearer {token}"}
        payload = {"msgtype": "m.text", "body": text}
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.put(url, json=payload, headers=headers)
        if resp.is_error:
            log.warning(
                "matrix send_message failed",
                extra={"extra": {"status": resp.status_code, "room_id": room_id}},
            )

    async def sync(self, token: str, since: str | None) -> dict[str, Any]:
        """Fetch one incremental Matrix sync payload.

        ``timeout=30000`` is the server-side long-poll window (milliseconds in
        the Matrix spec). The 40 s client timeout leaves headroom above it.
        ``since`` is omitted when falsy, which requests an initial sync.

        Raises:
            httpx.HTTPStatusError: if the homeserver answers with a non-2xx status.
        """
        base = f"{self._settings.matrix_base}/_matrix/client/v3/sync"
        params: dict[str, Any] = {"timeout": 30000}
        if since:
            params["since"] = since
        headers = {"Authorization": f"Bearer {token}"}
        async with httpx.AsyncClient(timeout=40.0) as client:
            resp = await client.get(base, headers=headers, params=params)
            resp.raise_for_status()
            return resp.json()
|
|
|
|
|
|
class MatrixBot:
    """Drive Matrix conversation handling and heartbeat replies.

    ``run`` logs in, resolves and joins the configured room, then long-polls
    ``/sync`` and answers messages addressed to the bot. While an answer is
    being produced, a background heartbeat posts "Still thinking" notes.
    The last four question/answer pairs per room are kept in memory and
    passed to the answer handler as history.
    """

    def __init__(
        self,
        settings: Settings,
        bot: MatrixBotConfig,
        engine: AnswerEngine,
        answer_handler: Callable[
            [str, str, list[dict[str, str]] | None, str | None, Callable[[str, str], None] | None],
            Awaitable[AnswerResult],
        ]
        | None = None,
    ) -> None:
        self._settings = settings
        self._bot = bot
        self._engine = engine
        self._client = MatrixClient(settings, bot)
        # Optional override for self._engine.answer. It is called positionally
        # as (question, mode, history, room_id, observer).
        self._answer_handler = answer_handler
        # room_id -> last four {"q": question, "a": reply} turns.
        self._history: dict[str, list[dict[str, str]]] = {}

    async def run(self) -> None:
        """Continuously bootstrap, sync, and answer Matrix events.

        Any failure, including a 401 re-raised by the sync loop, restarts the
        whole bootstrap (and therefore logs in again) after a 10 s pause.
        """
        while True:
            try:
                token = await self._client.login()
                room_id = await self._client.resolve_room(token)
                if room_id:
                    await self._client.join_room(token, room_id)
                await self._sync_loop(token)
            except Exception as exc:
                log.warning("matrix bootstrap failed", extra={"extra": {"error": str(exc)}})
            await asyncio.sleep(10)

    async def _sync_loop(self, token: str) -> None:
        """Long-poll ``/sync`` forever and dispatch each batch to ``_handle_sync``.

        The first successful sync only establishes the ``since`` cursor. An
        initial sync (no ``since``) returns recent room history, and answering
        it would replay old questions on every (re)start. A 401 is re-raised
        so ``run`` can log in again. Other failures are logged and retried
        after 5 s.
        """
        since: str | None = None
        primed = False
        while True:
            try:
                payload = await self._client.sync(token, since)
                # Keep the previous cursor if next_batch is missing; falling
                # back to None would trigger another initial sync.
                since = payload.get("next_batch") or since
                if not primed:
                    primed = True
                    continue
                await self._handle_sync(token, payload)
            except httpx.HTTPStatusError as exc:
                if exc.response.status_code == 401:
                    log.warning(
                        "matrix sync unauthorized; re-authenticating",
                        extra={"extra": {"error": str(exc)}},
                    )
                    raise
                log.warning("matrix sync failed", extra={"extra": {"error": str(exc)}})
                await asyncio.sleep(5)
            except Exception as exc:
                log.warning("matrix sync failed", extra={"extra": {"error": str(exc)}})
                await asyncio.sleep(5)

    async def _handle_sync(self, token: str, payload: dict[str, Any]) -> None:
        """Answer every qualifying ``m.room.message`` in a sync response.

        Events are handled one at a time in timeline order. Each of the
        following is skipped:
        - non-message events;
        - malformed events;
        - messages sent by the bot itself;
        - messages for which ``_extract_mode`` yields no question.

        Each remaining message gets a "Thinking…" notice, then an answer.
        """
        rooms = payload.get("rooms") or {}
        joins = rooms.get("join") or {}
        for room_id, room_data in joins.items():
            if not isinstance(room_data, dict):
                continue
            events = (room_data.get("timeline") or {}).get("events") or []
            for event in events:
                if not isinstance(event, dict) or event.get("type") != "m.room.message":
                    continue
                content = event.get("content") or {}
                if not isinstance(content, dict):
                    continue
                body = content.get("body") or ""
                if not isinstance(body, str):
                    continue
                sender = event.get("sender") or ""
                if self._is_own_sender(sender, self._bot.username):
                    continue
                mode, question = _extract_mode(body, self._bot.mentions, self._bot.mode)
                if not question:
                    continue
                await self._client.send_message(token, room_id, "Thinking…")
                await self._answer_with_heartbeat(token, room_id, question, mode)

    @staticmethod
    def _is_own_sender(sender: str, username: str) -> bool:
        """Return True when ``sender`` is the bot's own account.

        ``username`` may be a bare localpart (``atlas``) or a full user id
        (``@atlas:example.org``). Matrix user ids have the form
        ``@localpart:server``, so a bare localpart is matched against that
        prefix. The exact-match and ``/username`` suffix checks are kept for
        compatibility.
        """
        if sender == username or sender.endswith(f"/{username}"):
            return True
        if not sender or not username or ":" in username:
            # A full user id must match exactly, which was checked above.
            return False
        localpart = username.removeprefix("@")
        return sender.startswith(f"@{localpart}:")

    @staticmethod
    async def _stop_heartbeat(stop: asyncio.Event, task: asyncio.Task[None]) -> None:
        """Signal the heartbeat to stop, cancel it, and wait until it has finished."""
        stop.set()
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            log.warning("matrix heartbeat failed", extra={"extra": {"error": str(exc)}})

    async def _run_handler(
        self,
        question: str,
        mode: str,
        history: list[dict[str, str]],
        room_id: str,
        observer: Callable[[str, str], None],
    ) -> AnswerResult:
        """Invoke the answer handler, enforcing the mode's time budget if positive.

        The handler is given the budget plus one second. ``asyncio.TimeoutError``
        propagates when that is exceeded.
        """
        handler = self._answer_handler or (
            lambda q, m, h, cid, obs: self._engine.answer(q, mode=m, history=h, observer=obs, conversation_id=cid)
        )
        timeout_sec = _mode_timeout_sec(self._settings, mode)
        call = handler(question, mode, history, room_id, observer)
        if timeout_sec > 0:
            return await asyncio.wait_for(call, timeout=timeout_sec + 1.0)
        return await call

    async def _answer_with_heartbeat(self, token: str, room_id: str, question: str, mode: str) -> None:
        """Produce and post an answer while posting periodic progress notes.

        A background task wakes every ``thinking_interval_sec`` seconds. It posts
        "Still thinking — <note>…" when the observer has reported a new note
        (first 64 chars), or otherwise at least every max(60 s, 2 × interval).
        The heartbeat is stopped before the reply or any fallback notice is
        posted, so a late progress note can never follow the answer.

        On timeout or handler failure a fallback message is posted instead of
        an answer. On success the turn is appended to the room history, which
        is capped at the last four turns.
        """
        latest = {"stage": "", "note": ""}
        stop = asyncio.Event()

        def observer(stage: str, note: str) -> None:
            latest["stage"] = stage
            latest["note"] = note

        async def heartbeat() -> None:
            last_note = ""
            last_sent = 0.0
            while not stop.is_set():
                await asyncio.sleep(self._settings.thinking_interval_sec)
                if stop.is_set():
                    break
                note = (latest.get("note") or latest.get("stage") or "thinking").strip() or "thinking"
                snippet = note[:64]
                now = time.monotonic()
                changed = snippet != last_note
                overdue = now - last_sent >= max(60.0, self._settings.thinking_interval_sec * 2)
                if not (changed or overdue):
                    continue
                try:
                    await self._client.send_message(token, room_id, f"Still thinking — {snippet}…")
                except Exception as exc:
                    # Progress notes are best-effort; a failed send must not
                    # kill the heartbeat. The note is retried on the next tick.
                    log.warning("matrix heartbeat send failed", extra={"extra": {"error": str(exc)}})
                    continue
                last_note = snippet
                last_sent = now

        task = asyncio.create_task(heartbeat())
        started = time.monotonic()
        history = self._history.get(room_id, [])
        try:
            try:
                result = await self._run_handler(question, mode, history, room_id, observer)
                elapsed = time.monotonic() - started
            finally:
                await self._stop_heartbeat(stop, task)
            await self._client.send_message(token, room_id, result.reply)
            history.append({"q": question, "a": result.reply})
            self._history[room_id] = history[-4:]
            log.info(
                "matrix_answer",
                extra={
                    "extra": {
                        "mode": mode,
                        "seconds": round(elapsed, 2),
                        "scores": result.scores.__dict__,
                    }
                },
            )
        except asyncio.TimeoutError:
            budget = max(1, int(round(_mode_timeout_sec(self._settings, mode))))
            if mode in {"quick", "fast"}:
                msg = (
                    f"Quick mode hit {budget}s time budget before finishing. "
                    "Try atlas-smart for a deeper answer."
                )
            elif mode == "smart":
                msg = (
                    f"Smart mode hit {budget}s time budget before finishing. "
                    "Try atlas-genius or ask a narrower follow-up."
                )
            else:
                msg = "I ran out of time before I could finish this answer."
            await self._client.send_message(token, room_id, msg)
            log.warning(
                "matrix_answer_timeout",
                extra={"extra": {"mode": mode, "seconds": budget}},
            )
        except Exception as exc:
            log.warning(
                "matrix_answer_failed",
                extra={"extra": {"mode": mode, "error": str(exc)}},
            )
            await self._client.send_message(
                token,
                room_id,
                "I hit an internal error while answering. Please retry, or switch to atlas-smart.",
            )
|
|
|
|
|
|
def _extract_mode(body: str, mentions: tuple[str, ...], default_mode: str) -> tuple[str, str]:
|
|
lower = body.lower()
|
|
for mention in mentions:
|
|
if mention and mention.lower() in lower:
|
|
mode = default_mode or "quick"
|
|
if not default_mode:
|
|
if "atlas-smart" in lower or "smart" in lower:
|
|
mode = "smart"
|
|
if "atlas-genius" in lower or "genius" in lower:
|
|
mode = "genius"
|
|
if "atlas-quick" in lower or "quick" in lower:
|
|
mode = "quick"
|
|
cleaned = body
|
|
for tag in mentions:
|
|
cleaned = cleaned.replace(tag, "")
|
|
cleaned = cleaned.replace(tag.capitalize(), "")
|
|
return mode, cleaned.strip()
|
|
return ("", "")
|
|
|
|
|
|
def _mode_timeout_sec(settings: Settings, mode: str) -> float:
|
|
if mode == "genius":
|
|
return settings.genius_time_budget_sec
|
|
if mode == "smart":
|
|
return settings.smart_time_budget_sec
|
|
return settings.quick_time_budget_sec
|