91 lines
3.1 KiB
Python

import asyncio
import json
import logging
from typing import Any, Awaitable, Callable
from nats.aio.client import Client as NATS
from nats.js.errors import NotFoundError
from atlasbot.config import Settings
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
class QueueManager:
    """Run request handling inline or through a NATS JetStream work queue.

    When ``settings.queue_enabled`` is false, :meth:`submit` awaits ``handler``
    directly. Otherwise :meth:`start` connects to NATS, makes sure the stream
    exists and launches a background worker that pulls messages from
    ``settings.nats_subject`` and runs ``handler`` on each one; :meth:`submit`
    publishes a request envelope and waits for the worker's reply on a private
    inbox.
    """

    def __init__(
        self,
        settings: Settings,
        handler: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]],
    ) -> None:
        """Store configuration; no network activity happens until start().

        Args:
            settings: Application settings. This class reads ``queue_enabled``,
                ``nats_url``, ``nats_stream`` and ``nats_subject``.
            handler: Coroutine function mapping a request payload to a result
                dict; in queue mode the result must be JSON-serializable.
        """
        self._settings = settings
        self._handler = handler
        self._nc: NATS | None = None
        self._js: Any = None  # JetStream context; set by start(), cleared by stop()
        self._worker_task: asyncio.Task | None = None

    async def start(self) -> None:
        """Connect to NATS, ensure the stream exists and start the worker.

        Does nothing when the queue is disabled. If stream setup fails, the
        freshly opened connection is closed before the error propagates.
        """
        if not self._settings.queue_enabled:
            return
        nc = NATS()
        await nc.connect(self._settings.nats_url)
        self._nc = nc
        self._js = nc.jetstream()
        try:
            await self._ensure_stream()
        except Exception:
            # Don't leak an open connection when start() fails part-way.
            self._nc = None
            self._js = None
            await nc.close()
            raise
        self._worker_task = asyncio.create_task(self._worker_loop())

    async def stop(self) -> None:
        """Stop the worker task, then drain and forget the NATS connection.

        Safe to call when start() never ran or the queue is disabled. After
        stop(), submit() in queue mode raises RuntimeError until start() runs
        again.
        """
        task, self._worker_task = self._worker_task, None
        if task is not None:
            task.cancel()
            # Wait for the worker to actually finish so it is no longer using
            # the connection while it is drained. return_exceptions=True keeps
            # the worker's CancelledError (or any crash) from escaping here,
            # while a cancellation of stop() itself still propagates.
            await asyncio.gather(task, return_exceptions=True)
        nc, self._nc, self._js = self._nc, None, None
        if nc is not None:
            await nc.drain()

    async def submit(
        self, payload: dict[str, Any], *, timeout: float = 300.0
    ) -> dict[str, Any]:
        """Run the handler on ``payload`` and return its result.

        With the queue disabled the handler is awaited directly. Otherwise the
        payload is published to the stream together with a private reply
        inbox, and this coroutine waits for a worker to reply.

        Args:
            payload: JSON-serializable request body.
            timeout: Seconds to wait for the worker's reply (queue mode only).

        Returns:
            The handler result (queue disabled) or the decoded JSON reply.

        Raises:
            RuntimeError: The queue is enabled but start() has not run.
            The timeout error raised by nats-py's ``Subscription.next_msg``
            when no reply arrives in time. This also happens when the worker's
            handler fails, because the worker sends no reply in that case.
        """
        if not self._settings.queue_enabled:
            return await self._handler(payload)
        if self._nc is None or self._js is None:
            raise RuntimeError("queue not initialized")
        reply = self._nc.new_inbox()
        # Serialize before subscribing so an unserializable payload cannot
        # leave a dangling subscription behind.
        body = json.dumps({"reply": reply, "payload": payload}).encode()
        sub = await self._nc.subscribe(reply)
        try:
            await self._js.publish(self._settings.nats_subject, body)
            msg = await sub.next_msg(timeout=timeout)
        finally:
            # Always release the inbox subscription, including on timeout.
            await sub.unsubscribe()
        return json.loads(msg.data.decode())

    async def _ensure_stream(self) -> None:
        """Create the work-queue stream unless it already exists.

        An existing stream is used as-is; its configuration is not compared
        or updated.
        """
        if self._js is None:
            raise RuntimeError("queue not initialized")
        try:
            await self._js.stream_info(self._settings.nats_stream)
        except NotFoundError:
            await self._js.add_stream(
                name=self._settings.nats_stream,
                subjects=[self._settings.nats_subject],
                # Work-queue retention: each message is consumed once.
                retention="workqueue",
                # Cap the backlog at 10,000 messages / 50 MiB.
                max_msgs=10000,
                max_bytes=50 * 1024 * 1024,
            )

    async def _worker_loop(self) -> None:
        """Pull and handle messages one at a time until the task is cancelled."""
        if self._js is None:
            raise RuntimeError("queue not initialized")
        sub = await self._js.pull_subscribe(
            self._settings.nats_subject, durable="atlasbot-worker"
        )
        while True:
            try:
                msgs = await sub.fetch(1, timeout=1)
            except asyncio.TimeoutError:
                # Empty poll window; fetch already waited, so poll again.
                continue
            except Exception as exc:
                log.warning("queue fetch failed", extra={"extra": {"error": str(exc)}})
                # Back off briefly so a persistent error doesn't spin the loop.
                await asyncio.sleep(0.2)
                continue
            for msg in msgs:
                try:
                    await self._handle_message(msg)
                except Exception as exc:
                    # One bad message (or a failed ack) must not kill the worker.
                    log.warning(
                        "queue message handling failed",
                        extra={"extra": {"error": str(exc)}},
                        exc_info=True,
                    )

    async def _handle_message(self, msg: Any) -> None:
        """Decode one queued envelope, run the handler, reply, then ack.

        The message is acked even when decoding or the handler fails, so a
        bad message is not redelivered. No reply is sent on failure.
        """
        try:
            envelope = json.loads(msg.data.decode())
        except ValueError:  # covers UnicodeDecodeError and JSONDecodeError
            log.warning("queue message is not valid JSON; dropping")
            await msg.ack()
            return
        if not isinstance(envelope, dict):
            log.warning("queue message envelope is not a JSON object; dropping")
            await msg.ack()
            return
        payload = envelope.get("payload", envelope)
        # Only the envelope's explicit inbox is a valid reply target: on a
        # JetStream message, msg.reply is the ack subject that msg.ack() uses.
        reply = envelope.get("reply")
        try:
            result = await self._handler(payload)
            if reply and self._nc is not None:
                await self._nc.publish(reply, json.dumps(result).encode())
        except Exception as exc:
            log.warning(
                "queue handler failed",
                extra={"extra": {"error": str(exc)}},
                exc_info=True,
            )
        finally:
            await msg.ack()