91 lines
3.1 KiB
Python
91 lines
3.1 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Any, Awaitable, Callable
|
|
|
|
from nats.aio.client import Client as NATS
|
|
from nats.js.errors import NotFoundError
|
|
|
|
from atlasbot.config import Settings
|
|
|
|
# Logger for this module; QueueManager reports worker-side failures through it.
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
class QueueManager:
    """Route request payloads through a NATS JetStream work queue.

    When ``settings.queue_enabled`` is false the manager is a pass-through:
    :meth:`submit` awaits ``handler`` directly and no NATS connection is made.

    When enabled, :meth:`submit` publishes a JSON envelope
    ``{"reply": <inbox>, "payload": <payload>}`` to ``settings.nats_subject``
    and waits for the answer on a private inbox. The background worker started
    by :meth:`start` pulls envelopes from the stream, runs ``handler`` on the
    payload and publishes the JSON-encoded result to the envelope's reply
    subject.
    """

    def __init__(
        self,
        settings: Settings,
        handler: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]],
    ) -> None:
        """Store configuration; no I/O happens until :meth:`start`.

        Args:
            settings: Application settings. This class reads ``queue_enabled``,
                ``nats_url``, ``nats_stream`` and ``nats_subject``.
            handler: Coroutine function mapping a request payload to a
                response dict. In queue mode the response is sent as JSON, so
                it must be JSON-serialisable.
        """
        self._settings = settings
        self._handler = handler
        self._nc: NATS | None = None
        self._js: Any = None  # JetStream context; set by start(), cleared by stop()
        self._worker_task: asyncio.Task | None = None

    async def start(self) -> None:
        """Connect to NATS, make sure the stream exists and launch the worker.

        Does nothing when the queue is disabled or the manager is already
        started. If stream setup fails, the new connection is closed before
        the error propagates.
        """
        if not self._settings.queue_enabled or self._nc is not None:
            return
        nc = NATS()
        await nc.connect(self._settings.nats_url)
        try:
            self._js = nc.jetstream()
            await self._ensure_stream()
        except BaseException:
            # Do not leak the connection when stream setup fails.
            self._js = None
            await nc.close()
            raise
        self._nc = nc
        self._worker_task = asyncio.create_task(
            self._worker_loop(), name="atlasbot-queue-worker"
        )

    async def stop(self) -> None:
        """Stop the worker, drain the NATS connection and reset state.

        Safe to call when :meth:`start` never ran or the queue is disabled.
        After it returns, :meth:`start` may be called again.
        """
        task, self._worker_task = self._worker_task, None
        if task is not None:
            task.cancel()
            try:
                # Wait for the worker to actually finish, so no message
                # handling overlaps with the drain below.
                await task
            except asyncio.CancelledError:
                pass
            except Exception:
                log.exception("queue worker exited with an error")
        nc, self._nc = self._nc, None
        self._js = None
        if nc is not None:
            await nc.drain()

    async def submit(self, payload: dict[str, Any], timeout: float = 300) -> dict[str, Any]:
        """Process *payload*, going through the queue when it is enabled.

        Args:
            payload: Request payload handed to the handler.
            timeout: Seconds to wait for the worker's reply. Used in queue mode
                only.

        Returns:
            The handler's response dict, decoded from the worker's reply in
            queue mode.

        Raises:
            RuntimeError: If the queue is enabled but :meth:`start` has not
                completed.
            nats errors: From the NATS client, e.g. its timeout error when no
                reply arrives within *timeout*.
        """
        if not self._settings.queue_enabled:
            return await self._handler(payload)
        if self._nc is None or self._js is None:
            raise RuntimeError("queue not initialized")
        reply = self._nc.new_inbox()
        # Subscribe before publishing so a fast worker's reply cannot be missed.
        sub = await self._nc.subscribe(reply)
        try:
            envelope = {"reply": reply, "payload": payload}
            await self._js.publish(self._settings.nats_subject, json.dumps(envelope).encode())
            msg = await sub.next_msg(timeout=timeout)
        finally:
            # Always drop the inbox subscription, including on timeout or
            # publish failure; otherwise every failed submit leaks one.
            await sub.unsubscribe()
        return json.loads(msg.data.decode())

    async def _ensure_stream(self) -> None:
        """Create the work-queue stream unless it already exists."""
        assert self._js is not None
        try:
            await self._js.stream_info(self._settings.nats_stream)
        except NotFoundError:
            await self._js.add_stream(
                name=self._settings.nats_stream,
                subjects=[self._settings.nats_subject],
                retention="workqueue",
                max_msgs=10000,
                max_bytes=50 * 1024 * 1024,
            )

    async def _worker_loop(self) -> None:
        """Pull messages one at a time and handle them until cancelled."""
        assert self._js is not None
        sub = await self._js.pull_subscribe(self._settings.nats_subject, durable="atlasbot-worker")
        while True:
            try:
                msgs = await sub.fetch(1, timeout=1)
            except asyncio.TimeoutError:
                # No message within the poll window: the normal idle case.
                # The NATS client's timeout error subclasses asyncio.TimeoutError.
                continue
            except Exception:
                log.warning("queue fetch failed", exc_info=True)
                await asyncio.sleep(0.2)
                continue
            for msg in msgs:
                await self._handle_message(msg)

    async def _handle_message(self, msg: Any) -> None:
        """Run the handler for one queued message and reply with its result.

        The message is always acknowledged, even when it is malformed or the
        handler fails, so a poison message is not redelivered forever.
        """
        try:
            envelope = json.loads(msg.data.decode())
        except Exception:
            log.warning("dropping undecodable queue message", exc_info=True)
            await msg.ack()
            return
        if not isinstance(envelope, dict):
            # A JSON array/scalar would break the .get() calls below and kill
            # the worker loop.
            log.warning("dropping queue message whose envelope is not a JSON object")
            await msg.ack()
            return
        payload = envelope.get("payload", envelope)
        # Only the envelope's reply inbox is a valid reply target. msg.reply on a
        # message pulled from JetStream is the stream's ack subject, not a caller.
        reply = envelope.get("reply")
        try:
            result = await self._handler(payload)
            if reply and self._nc is not None:
                await self._nc.publish(reply, json.dumps(result).encode())
        except Exception as exc:
            # NOTE(review): no reply is sent on failure, so the submitter waits
            # for its full timeout; consider an explicit error reply protocol.
            log.warning(
                "queue handler failed",
                extra={"extra": {"error": str(exc)}},
                exc_info=True,
            )
        finally:
            await msg.ack()
|