276 lines
11 KiB
Python

import asyncio
import logging
import re
import time
import uuid
from collections.abc import Awaitable, Callable
from typing import Any
from urllib.parse import quote

import httpx

from atlasbot.config import MatrixBotConfig, Settings
from atlasbot.engine.answerer import AnswerEngine, AnswerResult
log = logging.getLogger(__name__)
class MatrixClient:
    """Wrap the Matrix client-server endpoints used by the bot runtime.

    Every call opens its own short-lived ``httpx.AsyncClient``; no connection
    or session state is kept between calls.
    """

    def __init__(self, settings: Settings, bot: MatrixBotConfig) -> None:
        self._settings = settings
        self._bot = bot

    @staticmethod
    def _auth_headers(token: str) -> dict[str, str]:
        """Build the bearer-token header sent on authenticated requests."""
        return {"Authorization": f"Bearer {token}"}

    def _room_url(self, room_id: str, suffix: str) -> str:
        """Return the ``/rooms/{room_id}/{suffix}`` URL on ``matrix_base``."""
        # Room ids contain reserved characters (e.g. "!" and ":"), and path
        # parameters must be percent-encoded.
        room = quote(room_id, safe="")
        return f"{self._settings.matrix_base}/_matrix/client/v3/rooms/{room}/{suffix}"

    async def login(self) -> str:
        """Exchange bot credentials for a Matrix access token.

        Returns:
            The ``access_token`` field of the login response, or ``""`` when
            the response does not contain one.

        Raises:
            httpx.HTTPStatusError: when the login request is rejected.
        """
        payload = {
            "type": "m.login.password",
            "identifier": {"type": "m.id.user", "user": self._bot.username},
            "password": self._bot.password,
        }
        url = f"{self._settings.auth_base}/_matrix/client/v3/login"
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.post(url, json=payload)
            resp.raise_for_status()
            data = resp.json()
            return data.get("access_token", "")

    async def resolve_room(self, token: str) -> str:
        """Resolve the configured room alias into a room id.

        Returns:
            The room id, or ``""`` when the lookup fails (the failure is logged).
        """
        alias = quote(self._settings.room_alias, safe="")
        url = f"{self._settings.matrix_base}/_matrix/client/v3/directory/room/{alias}"
        async with httpx.AsyncClient(timeout=15.0) as client:
            try:
                resp = await client.get(url, headers=self._auth_headers(token))
                resp.raise_for_status()
                data = resp.json()
            except httpx.HTTPError as exc:
                log.warning(
                    "matrix resolve_room failed",
                    extra={"extra": {"error": str(exc), "alias": self._settings.room_alias}},
                )
                return ""
            return data.get("room_id", "")

    async def join_room(self, token: str, room_id: str) -> None:
        """Ask the homeserver to join ``room_id``.

        Best effort: an error status is logged rather than raised, so the
        caller can still start syncing.
        """
        url = self._room_url(room_id, "join")
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.post(url, json={}, headers=self._auth_headers(token))
        if resp.is_error:
            log.warning(
                "matrix join_room failed",
                extra={"extra": {"status": resp.status_code, "room_id": room_id}},
            )

    async def send_message(self, token: str, room_id: str, text: str) -> None:
        """Send a plain text (``m.text``) message to ``room_id``.

        An error status is logged rather than raised; transport errors
        (timeouts, connection failures) still propagate.
        """
        # The v3 API sends events with PUT .../send/{eventType}/{txnId}; the
        # random transaction id lets the server de-duplicate retried requests.
        txn_id = uuid.uuid4().hex
        url = self._room_url(room_id, f"send/m.room.message/{txn_id}")
        payload = {"msgtype": "m.text", "body": text}
        async with httpx.AsyncClient(timeout=15.0) as client:
            resp = await client.put(url, json=payload, headers=self._auth_headers(token))
        if resp.is_error:
            log.warning(
                "matrix send_message failed",
                extra={"extra": {"status": resp.status_code, "room_id": room_id}},
            )

    async def sync(self, token: str, since: str | None) -> dict[str, Any]:
        """Long-poll ``/sync`` for events after ``since``.

        When ``since`` is empty, no ``since`` parameter is sent (an initial
        sync). The request asks the server to wait up to 30000 ms; the HTTP
        timeout is 40 s to leave headroom.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response (e.g. 401 for a
                rejected access token).
        """
        url = f"{self._settings.matrix_base}/_matrix/client/v3/sync"
        params: dict[str, Any] = {"timeout": 30000}
        if since:
            params["since"] = since
        async with httpx.AsyncClient(timeout=40.0) as client:
            resp = await client.get(url, headers=self._auth_headers(token), params=params)
            resp.raise_for_status()
            return resp.json()
class MatrixBot:
    """Drive Matrix conversation handling and heartbeat replies.

    ``run`` logs in, joins the configured room, and long-polls ``/sync``.
    Each new ``m.room.message`` that mentions the bot (see ``_extract_mode``)
    is answered through ``answer_handler`` (or ``engine.answer`` when no
    handler is given), while "Still thinking" progress notes are posted.
    Events are handled one at a time, in sync order.
    """

    def __init__(
        self,
        settings: Settings,
        bot: MatrixBotConfig,
        engine: AnswerEngine,
        answer_handler: Callable[
            [str, str, list[dict[str, str]] | None, str | None, Callable[[str, str], None] | None],
            Awaitable[AnswerResult],
        ]
        | None = None,
    ) -> None:
        self._settings = settings
        self._bot = bot
        self._engine = engine
        self._client = MatrixClient(settings, bot)
        # Called as handler(question, mode, history, conversation_id, observer).
        self._answer_handler = answer_handler
        # Per-room memory: the last 4 {"q": question, "a": reply} exchanges.
        self._history: dict[str, list[dict[str, str]]] = {}

    async def run(self) -> None:
        """Continuously bootstrap, sync, and answer Matrix events.

        Any bootstrap failure, including an access token rejected while
        syncing, is logged and retried with a fresh login after 10 seconds.
        """
        while True:
            try:
                token = await self._client.login()
                if not token:
                    # Every sync would fail without a token; retry the login instead.
                    raise RuntimeError("matrix login returned no access_token")
                room_id = await self._client.resolve_room(token)
                if room_id:
                    await self._client.join_room(token, room_id)
                await self._sync_loop(token)
            except Exception as exc:
                log.warning("matrix bootstrap failed", extra={"extra": {"error": str(exc)}})
                await asyncio.sleep(10)

    async def _sync_loop(self, token: str) -> None:
        """Long-poll ``/sync`` forever and dispatch new events.

        The first response (no ``since`` token yet) replays recent room
        history, which may include mentions answered before a restart. It is
        only used to obtain a ``next_batch`` token and is not answered.

        Raises:
            httpx.HTTPStatusError: on a 401, i.e. the access token is no longer
                valid, so that ``run`` logs in again.
        """
        since: str | None = None
        while True:
            try:
                payload = await self._client.sync(token, since)
                is_backlog = since is None
                # Keep the last token if next_batch is missing; otherwise the next
                # request would be a fresh initial sync that replays history.
                since = payload.get("next_batch") or since
                if not is_backlog:
                    await self._handle_sync(token, payload)
            except Exception as exc:
                if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 401:
                    raise
                log.warning("matrix sync failed", extra={"extra": {"error": str(exc)}})
                await asyncio.sleep(5)

    async def _handle_sync(self, token: str, payload: dict[str, Any]) -> None:
        """Answer every timeline message in ``payload`` that addresses the bot."""
        joined = (payload.get("rooms") or {}).get("join") or {}
        for room_id, room_data in joined.items():
            events = ((room_data or {}).get("timeline") or {}).get("events") or []
            for event in events:
                request = self._parse_request(event)
                if request is None:
                    continue
                mode, question = request
                await self._client.send_message(token, room_id, "Thinking…")
                await self._answer_with_heartbeat(token, room_id, question, mode)

    def _parse_request(self, event: Any) -> tuple[str, str] | None:
        """Return ``(mode, question)`` for a message addressed to the bot, else None."""
        if not isinstance(event, dict) or event.get("type") != "m.room.message":
            return None
        sender = event.get("sender") or ""
        if not sender or self._is_own_sender(sender):
            return None
        content = event.get("content") or {}
        if not isinstance(content, dict):
            return None
        # m.notice is meant for automated senders and must never be answered
        # automatically; this also prevents bot-to-bot loops.
        if content.get("msgtype") == "m.notice":
            return None
        # Edits repeat the original body; answering them would duplicate a reply.
        relates = content.get("m.relates_to")
        if isinstance(relates, dict) and relates.get("rel_type") == "m.replace":
            return None
        body = content.get("body")
        if not isinstance(body, str) or not body:
            return None
        mode, question = _extract_mode(body, self._bot.mentions, self._bot.mode)
        return (mode, question) if question else None

    def _is_own_sender(self, sender: str) -> bool:
        """Return True when ``sender`` is the bot account itself."""
        username = self._bot.username or ""
        if sender == username or sender.endswith(f"/{username}"):
            return True
        if ":" in username:
            # Already a full user id ("@name:server"); the exact match above decides.
            return False
        # Matrix senders are "@localpart:server"; a bare localpart username
        # would otherwise never match and the bot would answer itself.
        localpart = username.lstrip("@")
        return bool(localpart) and sender.startswith(f"@{localpart}:")

    async def _answer_with_heartbeat(self, token: str, room_id: str, question: str, mode: str) -> None:
        """Answer ``question`` in ``room_id`` while posting progress notes.

        Progress reported through the observer callback is relayed as
        "Still thinking — <note>" every ``thinking_interval_sec`` when the note
        changed, or after max(60 s, two intervals) otherwise. The heartbeat is
        stopped and awaited before the reply or failure notice is posted.
        """
        latest = {"stage": "", "note": ""}
        stop = asyncio.Event()

        def observer(stage: str, note: str) -> None:
            latest["stage"] = stage
            latest["note"] = note

        async def heartbeat() -> None:
            last_note = ""
            last_sent = 0.0
            while not stop.is_set():
                await asyncio.sleep(self._settings.thinking_interval_sec)
                if stop.is_set():
                    break
                note = (latest.get("note") or latest.get("stage") or "thinking").strip()
                snippet = (note or "thinking")[:64]
                now = time.monotonic()
                repeat_after = max(60.0, self._settings.thinking_interval_sec * 2)
                if snippet != last_note or now - last_sent >= repeat_after:
                    await self._client.send_message(token, room_id, f"Still thinking — {snippet}")
                    last_note = snippet
                    last_sent = now

        task = asyncio.create_task(heartbeat())
        started = time.monotonic()
        # Copy so the handler never holds a list this method mutates afterwards.
        history = list(self._history.get(room_id, []))
        result: AnswerResult | None = None
        error: Exception | None = None
        elapsed = 0.0
        try:
            result = await self._invoke_handler(question, mode, history, room_id, observer)
            elapsed = time.monotonic() - started
        except Exception as exc:  # asyncio.TimeoutError from wait_for lands here too
            error = exc
        finally:
            # Runs on cancellation as well, so the heartbeat task never leaks.
            await self._stop_heartbeat(stop, task)

        if isinstance(error, asyncio.TimeoutError):
            await self._send_timeout_notice(token, room_id, mode)
            return
        if error is not None or result is None:
            reason = str(error) if error is not None else "answer handler returned no result"
            log.warning(
                "matrix_answer_failed",
                extra={"extra": {"mode": mode, "error": reason}},
            )
            await self._client.send_message(
                token,
                room_id,
                "I hit an internal error while answering. Please retry, or switch to atlas-smart.",
            )
            return

        # Past this point the answer is delivered; later failures must not
        # trigger the "internal error" notice.
        await self._client.send_message(token, room_id, result.reply)
        history.append({"q": question, "a": result.reply})
        self._history[room_id] = history[-4:]
        scores = getattr(result, "scores", None)
        log.info(
            "matrix_answer",
            extra={
                "extra": {
                    "mode": mode,
                    "seconds": round(elapsed, 2),
                    "scores": getattr(scores, "__dict__", scores),
                }
            },
        )

    async def _invoke_handler(
        self,
        question: str,
        mode: str,
        history: list[dict[str, str]],
        room_id: str,
        observer: Callable[[str, str], None],
    ) -> AnswerResult:
        """Run the answer handler, bounded by the mode's time budget plus 1 s when set.

        Raises:
            asyncio.TimeoutError: when the budget is exceeded.
        """
        handler = self._answer_handler or (
            lambda q, m, h, cid, obs: self._engine.answer(q, mode=m, history=h, observer=obs, conversation_id=cid)
        )
        timeout_sec = _mode_timeout_sec(self._settings, mode)
        pending = handler(question, mode, history, room_id, observer)
        if timeout_sec > 0:
            return await asyncio.wait_for(pending, timeout=timeout_sec + 1.0)
        return await pending

    async def _send_timeout_notice(self, token: str, room_id: str, mode: str) -> None:
        """Tell the room the mode's time budget ran out and suggest a next step."""
        timeout_sec = max(1, int(round(_mode_timeout_sec(self._settings, mode))))
        if mode in {"quick", "fast"}:
            msg = (
                f"Quick mode hit {timeout_sec}s time budget before finishing. "
                "Try atlas-smart for a deeper answer."
            )
        elif mode == "smart":
            msg = (
                f"Smart mode hit {timeout_sec}s time budget before finishing. "
                "Try atlas-genius or ask a narrower follow-up."
            )
        else:
            msg = "I ran out of time before I could finish this answer."
        await self._client.send_message(token, room_id, msg)
        log.warning(
            "matrix_answer_timeout",
            extra={"extra": {"mode": mode, "seconds": timeout_sec}},
        )

    @staticmethod
    async def _stop_heartbeat(stop: asyncio.Event, task: asyncio.Task) -> None:
        """Stop the heartbeat task and wait until it has really finished."""
        stop.set()
        task.cancel()
        # asyncio.wait does not re-raise the task's outcome, so a cancellation
        # aimed at the caller still propagates normally.
        await asyncio.wait({task})
        if not task.cancelled() and task.exception() is not None:
            log.warning(
                "matrix heartbeat failed",
                extra={"extra": {"error": str(task.exception())}},
            )
def _extract_mode(body: str, mentions: tuple[str, ...], default_mode: str) -> tuple[str, str]:
lower = body.lower()
for mention in mentions:
if mention and mention.lower() in lower:
mode = default_mode or "quick"
if not default_mode:
if "atlas-smart" in lower or "smart" in lower:
mode = "smart"
if "atlas-genius" in lower or "genius" in lower:
mode = "genius"
if "atlas-quick" in lower or "quick" in lower:
mode = "quick"
cleaned = body
for tag in mentions:
cleaned = cleaned.replace(tag, "")
cleaned = cleaned.replace(tag.capitalize(), "")
return mode, cleaned.strip()
return ("", "")
def _mode_timeout_sec(settings: Settings, mode: str) -> float:
if mode == "genius":
return settings.genius_time_budget_sec
if mode == "smart":
return settings.smart_time_budget_sec
return settings.quick_time_budget_sec