183 lines
6.6 KiB
Python
183 lines
6.6 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import asdict, dataclass, is_dataclass
|
|
from datetime import datetime, timezone
|
|
import json
|
|
import threading
|
|
import time
|
|
from typing import Any, Callable
|
|
|
|
from croniter import croniter
|
|
|
|
from ..db.storage import ScheduleState, Storage, TaskRunRecord
|
|
from ..metrics.metrics import record_schedule_state, record_task_run
|
|
from ..utils.logging import get_logger, task_context
|
|
|
|
|
|
@dataclass(eq=True, frozen=True)
class CronTask:
    """Immutable registration record for a single scheduled job.

    Attributes:
        name: Identifier of the task; also used as its label in log records,
            metrics and storage rows.
        cron_expr: Cron expression from which the scheduler derives fire times.
        runner: Zero-argument callable invoked on every run. Its return value,
            if any, is recorded as the run's result.
    """

    # Field order is the positional constructor order; keep it stable.
    name: str
    cron_expr: str
    runner: Callable[[], None]
|
|
|
|
|
|
class CronScheduler:
    """Run named cron tasks while recording schedule state and outcomes.

    ``start()`` launches a background thread. Every ``tick_sec`` seconds it
    starts each task whose next fire time has passed, on that task's own
    daemon thread. A task is never launched while a previous run of it is
    still in flight. In that case its next fire time is left unchanged, so
    it is launched on the first tick after the in-flight run finishes.

    Each run is reported to metrics, to the logger and to ``Storage``. The
    storage writes are best-effort: a failure is logged and never affects
    scheduling.
    """

    def __init__(self, storage: Storage, tick_sec: float = 5.0) -> None:
        """Create an idle scheduler; register tasks with ``add_task``, then ``start``.

        Args:
            storage: Receives a run event, a run record and the schedule state
                after every task run.
            tick_sec: Seconds between checks for due tasks.
        """
        self._storage = storage
        self._tick_sec = tick_sec
        self._tasks: dict[str, CronTask] = {}
        # Next fire time per task name (timezone-aware, see _compute_next).
        self._next_run: dict[str, datetime] = {}
        # Names of tasks with a run in flight; only touched while holding _lock.
        self._running: set[str] = set()
        self._lock = threading.Lock()
        self._stop_event = threading.Event()
        self._thread: threading.Thread | None = None
        self._logger = get_logger(__name__)

    def add_task(self, name: str, cron_expr: str, runner: Callable[[], None]) -> None:
        """Register ``runner`` under ``name``, replacing any task of that name.

        The first run is scheduled at the next ``cron_expr`` fire time after
        now. The fire time is computed before the task is registered. If
        croniter raises for ``cron_expr``, the exception propagates and no
        task without a fire time is left behind.
        """
        first_run = self._compute_next(cron_expr, datetime.now(timezone.utc))
        self._tasks[name] = CronTask(name=name, cron_expr=cron_expr, runner=runner)
        self._next_run[name] = first_run

    def start(self) -> None:
        """Start the background scheduler thread (no-op if it is already alive)."""
        if self._thread and self._thread.is_alive():
            return
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._run_loop, name="ariadne-scheduler", daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Ask the scheduler thread to exit and wait up to 5 seconds for it.

        Task runs already in flight execute on their own threads and are
        neither interrupted nor awaited.
        """
        self._stop_event.set()
        if self._thread:
            self._thread.join(timeout=5)

    def _compute_next(self, cron_expr: str, base: datetime) -> datetime:
        """Return the next fire time of ``cron_expr`` after ``base``.

        A naive datetime returned by croniter is interpreted as UTC. This keeps
        the result comparable with ``datetime.now(timezone.utc)``.
        """
        next_time = croniter(cron_expr, base).get_next(datetime)
        if next_time.tzinfo is None:
            return next_time.replace(tzinfo=timezone.utc)
        return next_time

    def _run_loop(self) -> None:
        """Scheduler-thread body: launch due tasks on every tick until stopped."""
        while not self._stop_event.is_set():
            now = datetime.now(timezone.utc)
            # Iterate over a snapshot so add_task() on another thread cannot
            # change the dict size mid-iteration.
            for name, task in list(self._tasks.items()):
                due_at = self._next_run.get(name)
                if due_at is None or now < due_at:
                    continue
                with self._lock:
                    if name in self._running:
                        # Previous run still in flight: keep the fire time so the
                        # task is launched on the first tick after that run ends.
                        continue
                    self._running.add(name)
                next_run = self._compute_next(task.cron_expr, now)
                self._next_run[name] = next_run
                threading.Thread(
                    target=self._execute_task,
                    args=(task,),
                    name=f"ariadne-scheduler-{name}",
                    daemon=True,
                ).start()
                record_schedule_state(name, None, None, next_run.timestamp(), None)
            # Event.wait instead of time.sleep: stop() wakes the loop immediately.
            # Otherwise it would wait out a tick that can be as long as stop()'s
            # 5-second join timeout.
            self._stop_event.wait(self._tick_sec)

    def _execute_task(self, task: CronTask) -> None:
        """Run ``task`` once on the calling thread and record its outcome.

        The task is removed from the in-flight set in a ``finally`` block.
        An exception while reporting the outcome (metrics, logging) therefore
        cannot leave the task marked as running and block all future runs.
        """
        try:
            self._run_and_record(task)
        finally:
            with self._lock:
                self._running.discard(task.name)

    def _run_and_record(self, task: CronTask) -> None:
        """Invoke the task's runner, then report the outcome to metrics, logs and storage.

        An exception from the runner is caught and recorded as status
        ``"error"``, with its message as the detail; it is not re-raised.
        """
        started = datetime.now(timezone.utc)
        status = "ok"
        detail: str | None = None
        failure: Exception | None = None
        result_detail = ""
        result_payload: Any | None = None
        with task_context(task.name):
            self._logger.info(
                "schedule task started",
                extra={"event": "schedule_start", "task": task.name},
            )
            try:
                result = task.runner()
                result_detail, result_payload = self._format_result(result)
            except Exception as exc:
                status = "error"
                detail = str(exc).strip() or "task failed"
                failure = exc
            finished = datetime.now(timezone.utc)
            duration_sec = (finished - started).total_seconds()
            detail_value = detail or result_detail or ""
            record_task_run(task.name, status, duration_sec)
            self._logger.info(
                "schedule task finished",
                extra={
                    "event": "schedule_finish",
                    "task": task.name,
                    "status": status,
                    "duration_sec": round(duration_sec, 3),
                    "detail": detail_value,
                    "result": result_payload if result_payload is not None else "",
                },
                # Attach the runner's traceback; the detail string alone loses it.
                exc_info=failure,
            )

        next_run = self._next_run.get(task.name)
        duration_ms = int(duration_sec * 1000)

        try:
            # Built inside the try: comparing an exotic payload with "" may raise.
            event_detail: dict[str, Any] = {
                "task": task.name,
                "status": status,
                "duration_sec": round(duration_sec, 3),
                "detail": detail_value,
                "next_run_at": next_run.isoformat() if next_run else "",
            }
            if result_payload not in (None, ""):
                event_detail["result"] = result_payload
            self._storage.record_event("schedule_task", event_detail)
        except Exception:
            self._log_storage_error("record_event", task.name)

        record_schedule_state(
            task.name,
            started.timestamp(),
            finished.timestamp() if status == "ok" else None,
            next_run.timestamp() if next_run else None,
            status == "ok",
        )

        # Separate try blocks so one failed write does not skip the other.
        try:
            self._storage.record_task_run(
                TaskRunRecord(
                    request_code=None,
                    task=task.name,
                    status=status,
                    detail=detail_value or None,
                    started_at=started,
                    finished_at=finished,
                    duration_ms=duration_ms,
                )
            )
        except Exception:
            self._log_storage_error("record_task_run", task.name)
        try:
            self._storage.update_schedule_state(
                ScheduleState(
                    task_name=task.name,
                    cron_expr=task.cron_expr,
                    last_started_at=started,
                    last_finished_at=finished,
                    last_status=status,
                    last_error=detail,
                    last_duration_ms=duration_ms,
                    next_run_at=next_run,
                )
            )
        except Exception:
            self._log_storage_error("update_schedule_state", task.name)

    def _log_storage_error(self, action: str, task_name: str) -> None:
        """Log a failed best-effort storage write together with its traceback.

        Call only from inside an ``except`` block, so ``exc_info=True``
        picks up the exception being handled.
        """
        self._logger.warning(
            "schedule storage write failed",
            extra={"event": "schedule_storage_error", "task": task_name, "action": action},
            exc_info=True,
        )

    @staticmethod
    def _format_result(result: Any) -> tuple[str, Any | None]:
        """Turn a runner's return value into ``(detail_text, payload)``.

        Return values are handled as follows:

        * ``None`` maps to ``("", None)``.
        * A dataclass instance is converted with ``asdict`` first.
        * A dict is rendered as ASCII JSON. If it is not JSON-serialisable
          (e.g. it holds datetimes or sets, has non-string-like keys, or is
          circular), ``str()`` is used instead, so a successful run is never
          reported as an error merely because its result could not be encoded.
        * Anything else is rendered with ``str()``.
        """
        if result is None:
            return "", None
        # is_dataclass() is also true for dataclass *types*; asdict() only accepts instances.
        if is_dataclass(result) and not isinstance(result, type):
            result = asdict(result)
        if isinstance(result, dict):
            try:
                return json.dumps(result, ensure_ascii=True), result
            except (TypeError, ValueError):
                return str(result), result
        return str(result), result
|