ariadne/ariadne/db/storage.py

364 lines
12 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
import json
from typing import Any, Iterable
from .database import Database
# Names of the per-request provisioning tasks, grouped below by name prefix.
# Kept as a tuple so the shared constant cannot be mutated at runtime.
# NOTE(review): presumably the `tasks` argument handed to
# Storage.ensure_task_rows / Storage.tasks_complete — confirm against callers.
REQUIRED_TASKS = (
    # keycloak_*
    "keycloak_user",
    "keycloak_password",
    "keycloak_groups",
    # mailu_* / mail sync
    "mailu_app_password",
    "mailu_sync",
    "nextcloud_mail_sync",
    # per-application accounts
    "wger_account",
    "firefly_account",
    "vaultwarden_invite",
)
@dataclass(frozen=True)
class AccessRequest:
    """Immutable snapshot of one row of the portal ``access_requests`` table.

    The field order below defines the positional constructor, so it must not
    be rearranged.
    """

    # Identity of the request and of the requesting user.
    request_code: str
    username: str
    contact_email: str
    # Lifecycle state; the queries in this module use values such as
    # 'pending', 'approved', 'accounts_building' and 'denied'.
    status: str
    email_verified_at: datetime | None
    # Initial password (written once by Storage.set_initial_password) and
    # related timestamps.
    initial_password: str | None
    initial_password_revealed_at: datetime | None
    provision_attempted_at: datetime | None
    # Outcome of the admin decision (written by Storage.update_approval).
    approval_flags: list[str]
    approval_note: str | None
    denial_note: str | None
@dataclass(frozen=True)
class TaskRunRecord:
    """One execution of a task, as inserted into ``ariadne_task_runs``.

    Field order matches the positional constructor and must be preserved.
    """

    # Owning access request; None when the run is not tied to a request.
    request_code: str | None
    task: str
    status: str
    detail: str | None
    # Timing of the run; finish time and duration may be unknown.
    started_at: datetime
    finished_at: datetime | None
    duration_ms: int | None
@dataclass(frozen=True)
class ScheduleState:
    """Latest known state of a scheduled task (one ``ariadne_schedule_state`` row).

    Field order matches the positional constructor and must be preserved.
    """

    # Which task, and the cron expression it is scheduled with.
    task_name: str
    cron_expr: str
    # Outcome of the most recent run; all None until a run has happened.
    last_started_at: datetime | None
    last_finished_at: datetime | None
    last_status: str | None
    last_error: str | None
    last_duration_ms: int | None
    # When the scheduler expects to run the task next, if known.
    next_run_at: datetime | None
class Storage:
    """Persistence layer for Ariadne task state and portal access requests.

    Two database handles are used: ``db`` for Ariadne's own tables
    (``access_request_tasks``, ``ariadne_task_runs``,
    ``ariadne_schedule_state``, ``ariadne_events``) and ``portal_db`` for the
    ``access_requests`` table. When no portal handle is given, ``db`` serves
    both roles.
    """

    # Columns selected whenever a row is converted into an AccessRequest.
    # Shared so the three access-request queries cannot drift apart.
    _ACCESS_REQUEST_COLUMNS = """
        request_code, username, contact_email, status, email_verified_at,
        initial_password, initial_password_revealed_at, provision_attempted_at,
        approval_flags, approval_note, denial_note
    """

    # Page-size bounds for the list_* history queries.
    _DEFAULT_LIMIT = 200
    _MAX_LIMIT = 500

    def __init__(self, db: Database, portal_db: Database | None = None) -> None:
        """Create a storage facade.

        :param db: handle for Ariadne's own tables.
        :param portal_db: handle for ``access_requests``; defaults to ``db``.
        """
        self._db = db
        # Explicit None check: a handle object must never be replaced just
        # because it happens to evaluate as falsy.
        self._portal_db = portal_db if portal_db is not None else db

    # ------------------------------------------------------------------
    # Per-request task tracking (access_request_tasks)
    # ------------------------------------------------------------------

    def ensure_task_rows(self, request_code: str, tasks: Iterable[str]) -> None:
        """Insert a ``pending`` row for every task not yet tracked for a request.

        Rows that already exist are left untouched (``ON CONFLICT ... DO
        NOTHING``), so repeated calls do not reset progress.

        :param request_code: access request the tasks belong to.
        :param tasks: task names. A single ``str`` is treated as one task name
            instead of being iterated character by character. Duplicates are
            dropped, keeping first-seen order.
        """
        if isinstance(tasks, str):
            # str is itself an Iterable[str]; iterating it would create one
            # bogus task row per character.
            tasks = (tasks,)
        tasks_list = list(dict.fromkeys(tasks))
        if not tasks_list:
            return
        self._db.execute(
            """
            INSERT INTO access_request_tasks (request_code, task, status, detail, updated_at)
            SELECT %s, task, 'pending', NULL, NOW()
            FROM UNNEST(%s::text[]) AS task
            ON CONFLICT (request_code, task) DO NOTHING
            """,
            (request_code, tasks_list),
        )

    def update_task(self, request_code: str, task: str, status: str, detail: str | None) -> None:
        """Upsert the status/detail of one task for a request, bumping ``updated_at``."""
        self._db.execute(
            """
            INSERT INTO access_request_tasks (request_code, task, status, detail, updated_at)
            VALUES (%s, %s, %s, %s, NOW())
            ON CONFLICT (request_code, task)
            DO UPDATE SET status = EXCLUDED.status, detail = EXCLUDED.detail, updated_at = NOW()
            """,
            (request_code, task, status, detail),
        )

    def task_statuses(self, request_code: str) -> dict[str, str]:
        """Return ``{task: status}`` for a request.

        Rows whose task or status is not a string are skipped.
        """
        rows = self._db.fetchall(
            "SELECT task, status FROM access_request_tasks WHERE request_code = %s",
            (request_code,),
        )
        output: dict[str, str] = {}
        for row in rows:
            task = row.get("task")
            status = row.get("status")
            if isinstance(task, str) and isinstance(status, str):
                output[task] = status
        return output

    def tasks_complete(self, request_code: str, tasks: Iterable[str]) -> bool:
        """Return True when every task in ``tasks`` has status ``"ok"``.

        A task with no row counts as incomplete. An empty ``tasks`` yields True.
        """
        statuses = self.task_statuses(request_code)
        return all(statuses.get(task) == "ok" for task in tasks)

    # ------------------------------------------------------------------
    # Access requests (portal database)
    # ------------------------------------------------------------------

    def fetch_access_request(self, request_code: str) -> AccessRequest | None:
        """Load one access request by code, or None when it does not exist."""
        row = self._portal_db.fetchone(
            f"""
            SELECT {self._ACCESS_REQUEST_COLUMNS}
            FROM access_requests
            WHERE request_code = %s
            """,
            (request_code,),
        )
        return self._row_to_request(row) if row else None

    def find_access_request_by_username(self, username: str) -> AccessRequest | None:
        """Load the most recently created access request for ``username``, if any."""
        row = self._portal_db.fetchone(
            f"""
            SELECT {self._ACCESS_REQUEST_COLUMNS}
            FROM access_requests
            WHERE username = %s
            ORDER BY created_at DESC
            LIMIT 1
            """,
            (username,),
        )
        return self._row_to_request(row) if row else None

    def list_pending_requests(self) -> list[dict[str, Any]]:
        """Return up to 200 ``pending`` requests as raw rows, oldest first."""
        return self._portal_db.fetchall(
            """
            SELECT request_code, username, contact_email, note, status, created_at
            FROM access_requests
            WHERE status = 'pending'
            ORDER BY created_at ASC
            LIMIT 200
            """
        )

    def list_provision_candidates(self) -> list[AccessRequest]:
        """Return up to 200 requests in ``approved``/``accounts_building``, oldest first."""
        rows = self._portal_db.fetchall(
            f"""
            SELECT {self._ACCESS_REQUEST_COLUMNS}
            FROM access_requests
            WHERE status IN ('approved', 'accounts_building')
            ORDER BY created_at ASC
            LIMIT 200
            """
        )
        return [self._row_to_request(row) for row in rows]

    def update_status(self, request_code: str, status: str) -> None:
        """Overwrite the status of an access request."""
        self._portal_db.execute(
            "UPDATE access_requests SET status = %s WHERE request_code = %s",
            (status, request_code),
        )

    def mark_provision_attempted(self, request_code: str) -> None:
        """Stamp ``provision_attempted_at`` with the current database time."""
        self._portal_db.execute(
            "UPDATE access_requests SET provision_attempted_at = NOW() WHERE request_code = %s",
            (request_code,),
        )

    def set_initial_password(self, request_code: str, password: str) -> None:
        """Store the initial password unless one is already set (first write wins)."""
        self._portal_db.execute(
            """
            UPDATE access_requests
            SET initial_password = %s
            WHERE request_code = %s AND initial_password IS NULL
            """,
            (password, request_code),
        )

    def mark_welcome_sent(self, request_code: str) -> None:
        """Stamp ``welcome_email_sent_at`` once; later calls leave it unchanged."""
        self._portal_db.execute(
            """
            UPDATE access_requests
            SET welcome_email_sent_at = NOW()
            WHERE request_code = %s AND welcome_email_sent_at IS NULL
            """,
            (request_code,),
        )

    def update_approval(self, request_code: str, status: str, decided_by: str, flags: list[str], note: str | None) -> None:
        """Record an admin decision on an access request.

        ``note`` is always stored as ``approval_note``. When ``status`` is
        ``"denied"`` it is also written to ``denial_note``; otherwise the
        existing ``denial_note`` is kept. An empty ``decided_by`` or ``flags``
        is stored as NULL.
        """
        # Evaluate the denial check here and bind it as a boolean, rather than
        # comparing an untyped bound parameter to a literal inside SQL.
        is_denied = status == "denied"
        self._portal_db.execute(
            """
            UPDATE access_requests
            SET status = %s,
                decided_at = NOW(),
                decided_by = %s,
                approval_flags = %s,
                approval_note = %s,
                denial_note = CASE WHEN %s THEN %s ELSE denial_note END
            WHERE request_code = %s
            """,
            (status, decided_by or None, flags or None, note, is_denied, note, request_code),
        )

    # ------------------------------------------------------------------
    # Run history, schedule state and events (Ariadne database)
    # ------------------------------------------------------------------

    def record_task_run(self, record: TaskRunRecord) -> None:
        """Append one task execution to ``ariadne_task_runs``."""
        self._db.execute(
            """
            INSERT INTO ariadne_task_runs
            (request_code, task, status, detail, started_at, finished_at, duration_ms)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
            """,
            (
                record.request_code,
                record.task,
                record.status,
                record.detail,
                record.started_at,
                record.finished_at,
                record.duration_ms,
            ),
        )

    def update_schedule_state(self, state: ScheduleState) -> None:
        """Upsert the schedule row for ``state.task_name``, replacing every field."""
        self._db.execute(
            """
            INSERT INTO ariadne_schedule_state
            (task_name, cron_expr, last_started_at, last_finished_at, last_status,
            last_error, last_duration_ms, next_run_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, NOW())
            ON CONFLICT (task_name) DO UPDATE
            SET cron_expr = EXCLUDED.cron_expr,
            last_started_at = EXCLUDED.last_started_at,
            last_finished_at = EXCLUDED.last_finished_at,
            last_status = EXCLUDED.last_status,
            last_error = EXCLUDED.last_error,
            last_duration_ms = EXCLUDED.last_duration_ms,
            next_run_at = EXCLUDED.next_run_at,
            updated_at = NOW()
            """,
            (
                state.task_name,
                state.cron_expr,
                state.last_started_at,
                state.last_finished_at,
                state.last_status,
                state.last_error,
                state.last_duration_ms,
                state.next_run_at,
            ),
        )

    def record_event(self, event_type: str, detail: dict[str, Any] | str | None) -> None:
        """Append an event; a dict ``detail`` is serialised to ASCII-only JSON text."""
        payload = json.dumps(detail, ensure_ascii=True) if isinstance(detail, dict) else detail
        self._db.execute(
            "INSERT INTO ariadne_events (event_type, detail) VALUES (%s, %s)",
            (event_type, payload),
        )

    def list_events(self, limit: int = 200, event_type: str | None = None) -> list[dict[str, Any]]:
        """Return the newest events, optionally filtered by ``event_type``.

        ``limit`` is clamped to 1..500. A falsy limit means 200.
        """
        limit = self._clamp_limit(limit)
        clauses: list[str] = []
        params: list[Any] = []
        if event_type:
            clauses.append("event_type = %s")
            params.append(event_type)
        return self._db.fetchall(
            f"""
            SELECT id, event_type, detail, created_at
            FROM ariadne_events
            {self._where(clauses)}
            ORDER BY created_at DESC
            LIMIT %s
            """,
            (*params, limit),
        )

    def list_task_runs(
        self,
        limit: int = 200,
        request_code: str | None = None,
        task: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return the newest task runs, optionally filtered by request and/or task.

        ``limit`` is clamped to 1..500. A falsy limit means 200. Empty filter
        values are ignored.
        """
        limit = self._clamp_limit(limit)
        clauses: list[str] = []
        params: list[Any] = []
        # Only fixed column fragments are interpolated; values stay bound.
        if request_code:
            clauses.append("request_code = %s")
            params.append(request_code)
        if task:
            clauses.append("task = %s")
            params.append(task)
        return self._db.fetchall(
            f"""
            SELECT id, request_code, task, status, detail, started_at, finished_at, duration_ms
            FROM ariadne_task_runs
            {self._where(clauses)}
            ORDER BY started_at DESC
            LIMIT %s
            """,
            (*params, limit),
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @classmethod
    def _clamp_limit(cls, limit: int | None) -> int:
        """Clamp a page size into ``[1, _MAX_LIMIT]``; falsy values use the default."""
        return max(1, min(int(limit or cls._DEFAULT_LIMIT), cls._MAX_LIMIT))

    @staticmethod
    def _where(clauses: list[str]) -> str:
        """Join filter fragments with AND into a WHERE clause, or return "" if none."""
        return f"WHERE {' AND '.join(clauses)}" if clauses else ""

    @staticmethod
    def _row_to_request(row: dict[str, Any]) -> AccessRequest:
        """Convert an ``access_requests`` row mapping into an AccessRequest.

        Missing text columns become ``""``. ``approval_flags`` is normalised to
        a list of non-empty strings (a non-sequence value yields ``[]``).
        """
        flags = row.get("approval_flags")
        flags_list: list[str] = []
        if isinstance(flags, (list, tuple)):
            flags_list = [str(item) for item in flags if item]
        return AccessRequest(
            request_code=str(row.get("request_code") or ""),
            username=str(row.get("username") or ""),
            contact_email=str(row.get("contact_email") or ""),
            status=str(row.get("status") or ""),
            email_verified_at=row.get("email_verified_at"),
            initial_password=row.get("initial_password"),
            initial_password_revealed_at=row.get("initial_password_revealed_at"),
            provision_attempted_at=row.get("provision_attempted_at"),
            approval_flags=flags_list,
            approval_note=row.get("approval_note"),
            denial_note=row.get("denial_note"),
        )