364 lines
12 KiB
Python
364 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
import json
|
|
from typing import Any, Iterable
|
|
|
|
from .database import Database
|
|
|
|
|
|
# Ordered set of provisioning task names. Not referenced inside this part of
# the file; presumably callers hand it to Storage.ensure_task_rows /
# Storage.tasks_complete for each access request — TODO confirm.
REQUIRED_TASKS = (
    # keycloak_* tasks
    "keycloak_user",
    "keycloak_password",
    "keycloak_groups",
    # mail-related tasks
    "mailu_app_password",
    "mailu_sync",
    "nextcloud_mail_sync",
    # remaining per-service account tasks
    "wger_account",
    "firefly_account",
    "vaultwarden_invite",
)
|
|
|
|
|
|
@dataclass(frozen=True)
class AccessRequest:
    """Read-only snapshot of one row of the ``access_requests`` table.

    Field order matches the SELECT column list used when loading requests,
    and instances are immutable (``frozen=True``). Timestamp fields hold
    whatever the database driver returns for those columns, or ``None``.
    """

    # Identity and contact details of the requester.
    request_code: str
    username: str
    contact_email: str

    # Workflow state; values seen in this module's SQL include "pending",
    # "approved", "accounts_building" and "denied".
    status: str

    email_verified_at: datetime | None

    # Generated password, stored until first written (NULL-guarded update).
    initial_password: str | None
    initial_password_revealed_at: datetime | None

    # Last time provisioning was attempted for this request.
    provision_attempted_at: datetime | None

    # Decision metadata recorded when the request is approved or denied.
    approval_flags: list[str]
    approval_note: str | None
    denial_note: str | None
|
|
|
|
|
|
@dataclass(frozen=True)
class TaskRunRecord:
    """Immutable description of a single task execution.

    Each field maps one-to-one, in order, onto a column of the
    ``ariadne_task_runs`` insert performed by ``Storage.record_task_run``.
    """

    # Owning access request; typed as optional, so a run may have none.
    request_code: str | None
    task: str
    status: str
    detail: str | None

    # Timing of the run; the end time and duration are optional.
    started_at: datetime
    finished_at: datetime | None
    duration_ms: int | None
|
|
|
|
|
|
@dataclass(frozen=True)
class ScheduleState:
    """Immutable snapshot of a scheduled task's bookkeeping.

    Persisted by ``Storage.update_schedule_state`` as an upsert into
    ``ariadne_schedule_state`` keyed on ``task_name``.
    """

    # Upsert key and the cron expression the task is registered with.
    task_name: str
    cron_expr: str

    # Outcome of the most recent run; every field is optional.
    last_started_at: datetime | None
    last_finished_at: datetime | None
    last_status: str | None
    last_error: str | None
    last_duration_ms: int | None

    # Next planned execution time, if known.
    next_run_at: datetime | None
|
|
|
|
|
|
class Storage:
    """Database access layer for task bookkeeping, schedules, events and access requests.

    ``db`` backs the ``access_request_tasks``, ``ariadne_task_runs``,
    ``ariadne_schedule_state`` and ``ariadne_events`` tables. ``portal_db``
    backs the ``access_requests`` table; when it is omitted, ``db`` serves
    both roles.
    """

    # Columns read by every query whose rows are turned into an AccessRequest
    # by _row_to_request. Kept in one place so the readers cannot drift apart.
    # Only fixed text like this is interpolated into SQL; caller-supplied
    # values always travel as driver parameters.
    _REQUEST_COLUMNS = (
        "request_code, username, contact_email, status, email_verified_at, "
        "initial_password, initial_password_revealed_at, provision_attempted_at, "
        "approval_flags, approval_note, denial_note"
    )

    # Columns returned by list_task_runs.
    _TASK_RUN_COLUMNS = (
        "id, request_code, task, status, detail, started_at, finished_at, duration_ms"
    )

    def __init__(self, db: Database, portal_db: Database | None = None) -> None:
        """Bind the storage to its database handle(s).

        Args:
            db: Handle used for task, run, schedule and event tables.
            portal_db: Optional separate handle for ``access_requests``;
                defaults to ``db``.
        """
        self._db = db
        # Compare against None explicitly: ``portal_db or db`` would also
        # discard a valid handle object that happens to evaluate as falsy.
        self._portal_db = portal_db if portal_db is not None else db

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _clamp_limit(limit: int | None) -> int:
        """Coerce *limit* to an int in [1, 500]; falsy values become 200."""
        return max(1, min(int(limit or 200), 500))

    @staticmethod
    def _where_clause(conditions: list[str]) -> str:
        """AND-join fixed SQL predicates into a WHERE clause ("" when empty)."""
        return f"WHERE {' AND '.join(conditions)}" if conditions else ""

    # ------------------------------------------------------------------
    # Per-request task status (db)
    # ------------------------------------------------------------------

    def ensure_task_rows(self, request_code: str, tasks: Iterable[str]) -> None:
        """Insert a 'pending' row for each task that does not already exist.

        Existing ``(request_code, task)`` rows are left unchanged. An empty
        *tasks* iterable issues no query.
        """
        tasks_list = list(tasks)
        if not tasks_list:
            return
        self._db.execute(
            """
            INSERT INTO access_request_tasks (request_code, task, status, detail, updated_at)
            SELECT %s, task, 'pending', NULL, NOW()
            FROM UNNEST(%s::text[]) AS task
            ON CONFLICT (request_code, task) DO NOTHING
            """,
            (request_code, tasks_list),
        )

    def update_task(self, request_code: str, task: str, status: str, detail: str | None) -> None:
        """Upsert the status and detail of one task for a request."""
        self._db.execute(
            """
            INSERT INTO access_request_tasks (request_code, task, status, detail, updated_at)
            VALUES (%s, %s, %s, %s, NOW())
            ON CONFLICT (request_code, task)
            DO UPDATE SET status = EXCLUDED.status, detail = EXCLUDED.detail, updated_at = NOW()
            """,
            (request_code, task, status, detail),
        )

    def task_statuses(self, request_code: str) -> dict[str, str]:
        """Return ``{task: status}`` for a request.

        Rows whose task or status is not a string are skipped.
        """
        rows = self._db.fetchall(
            "SELECT task, status FROM access_request_tasks WHERE request_code = %s",
            (request_code,),
        )
        statuses: dict[str, str] = {}
        for row in rows:
            task, status = row.get("task"), row.get("status")
            if isinstance(task, str) and isinstance(status, str):
                statuses[task] = status
        return statuses

    def tasks_complete(self, request_code: str, tasks: Iterable[str]) -> bool:
        """Return True when every task in *tasks* has status "ok".

        A missing task counts as incomplete. An empty *tasks* yields True.
        """
        statuses = self.task_statuses(request_code)
        return all(statuses.get(task) == "ok" for task in tasks)

    # ------------------------------------------------------------------
    # Access requests (portal_db)
    # ------------------------------------------------------------------

    def fetch_access_request(self, request_code: str) -> AccessRequest | None:
        """Load the access request with *request_code*, or None if absent."""
        row = self._portal_db.fetchone(
            f"""
            SELECT {self._REQUEST_COLUMNS}
            FROM access_requests
            WHERE request_code = %s
            """,
            (request_code,),
        )
        return self._row_to_request(row) if row else None

    def find_access_request_by_username(self, username: str) -> AccessRequest | None:
        """Load the most recently created request for *username*, or None."""
        row = self._portal_db.fetchone(
            f"""
            SELECT {self._REQUEST_COLUMNS}
            FROM access_requests
            WHERE username = %s
            ORDER BY created_at DESC
            LIMIT 1
            """,
            (username,),
        )
        return self._row_to_request(row) if row else None

    def list_pending_requests(self) -> list[dict[str, Any]]:
        """Return up to 200 'pending' requests as raw rows, oldest first."""
        return self._portal_db.fetchall(
            """
            SELECT request_code, username, contact_email, note, status, created_at
            FROM access_requests
            WHERE status = 'pending'
            ORDER BY created_at ASC
            LIMIT 200
            """
        )

    def list_provision_candidates(self) -> list[AccessRequest]:
        """Return up to 200 'approved' or 'accounts_building' requests, oldest first."""
        rows = self._portal_db.fetchall(
            f"""
            SELECT {self._REQUEST_COLUMNS}
            FROM access_requests
            WHERE status IN ('approved', 'accounts_building')
            ORDER BY created_at ASC
            LIMIT 200
            """
        )
        return [self._row_to_request(row) for row in rows]

    def update_status(self, request_code: str, status: str) -> None:
        """Set the status column of one access request."""
        self._portal_db.execute(
            "UPDATE access_requests SET status = %s WHERE request_code = %s",
            (status, request_code),
        )

    def mark_provision_attempted(self, request_code: str) -> None:
        """Stamp ``provision_attempted_at`` with the current database time."""
        self._portal_db.execute(
            "UPDATE access_requests SET provision_attempted_at = NOW() WHERE request_code = %s",
            (request_code,),
        )

    def set_initial_password(self, request_code: str, password: str) -> None:
        """Store *password* only if no initial password is set yet (first write wins)."""
        self._portal_db.execute(
            """
            UPDATE access_requests
            SET initial_password = %s
            WHERE request_code = %s AND initial_password IS NULL
            """,
            (password, request_code),
        )

    def mark_welcome_sent(self, request_code: str) -> None:
        """Stamp ``welcome_email_sent_at`` once; later calls leave it unchanged."""
        self._portal_db.execute(
            """
            UPDATE access_requests
            SET welcome_email_sent_at = NOW()
            WHERE request_code = %s AND welcome_email_sent_at IS NULL
            """,
            (request_code,),
        )

    def update_approval(
        self, request_code: str, status: str, decided_by: str, flags: list[str], note: str | None
    ) -> None:
        """Record an approval decision on a request.

        Sets status, ``decided_at`` (NOW()), ``decided_by`` and
        ``approval_flags``. An empty ``decided_by`` or ``flags`` is stored as
        NULL. ``approval_note`` is always set to *note*. ``denial_note`` is set
        to *note* only when *status* is 'denied'; otherwise it keeps its
        current value.
        """
        self._portal_db.execute(
            """
            UPDATE access_requests
            SET status = %s,
                decided_at = NOW(),
                decided_by = %s,
                approval_flags = %s,
                approval_note = %s,
                denial_note = CASE WHEN %s = 'denied' THEN %s ELSE denial_note END
            WHERE request_code = %s
            """,
            (status, decided_by or None, flags or None, note, status, note, request_code),
        )

    # ------------------------------------------------------------------
    # Task runs, schedules and events (db)
    # ------------------------------------------------------------------

    def record_task_run(self, record: TaskRunRecord) -> None:
        """Append one row to ``ariadne_task_runs`` from *record*."""
        self._db.execute(
            """
            INSERT INTO ariadne_task_runs
                (request_code, task, status, detail, started_at, finished_at, duration_ms)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
            """,
            (
                record.request_code,
                record.task,
                record.status,
                record.detail,
                record.started_at,
                record.finished_at,
                record.duration_ms,
            ),
        )

    def update_schedule_state(self, state: ScheduleState) -> None:
        """Upsert the schedule bookkeeping row keyed on ``state.task_name``."""
        self._db.execute(
            """
            INSERT INTO ariadne_schedule_state
                (task_name, cron_expr, last_started_at, last_finished_at, last_status,
                 last_error, last_duration_ms, next_run_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, NOW())
            ON CONFLICT (task_name) DO UPDATE
            SET cron_expr = EXCLUDED.cron_expr,
                last_started_at = EXCLUDED.last_started_at,
                last_finished_at = EXCLUDED.last_finished_at,
                last_status = EXCLUDED.last_status,
                last_error = EXCLUDED.last_error,
                last_duration_ms = EXCLUDED.last_duration_ms,
                next_run_at = EXCLUDED.next_run_at,
                updated_at = NOW()
            """,
            (
                state.task_name,
                state.cron_expr,
                state.last_started_at,
                state.last_finished_at,
                state.last_status,
                state.last_error,
                state.last_duration_ms,
                state.next_run_at,
            ),
        )

    def record_event(self, event_type: str, detail: dict[str, Any] | str | None) -> None:
        """Insert an event row.

        A dict *detail* is serialised to JSON with non-ASCII characters
        escaped. Strings and None are passed through unchanged.
        """
        payload = json.dumps(detail, ensure_ascii=True) if isinstance(detail, dict) else detail
        self._db.execute(
            "INSERT INTO ariadne_events (event_type, detail) VALUES (%s, %s)",
            (event_type, payload),
        )

    def list_events(self, limit: int = 200, event_type: str | None = None) -> list[dict[str, Any]]:
        """Return the newest events, optionally filtered by *event_type*.

        *limit* is clamped to [1, 500]; a falsy value means 200.
        """
        clamped = self._clamp_limit(limit)
        conditions: list[str] = []
        params: list[Any] = []
        if event_type:
            conditions.append("event_type = %s")
            params.append(event_type)
        params.append(clamped)
        return self._db.fetchall(
            f"""
            SELECT id, event_type, detail, created_at
            FROM ariadne_events
            {self._where_clause(conditions)}
            ORDER BY created_at DESC
            LIMIT %s
            """,
            tuple(params),
        )

    def list_task_runs(
        self,
        limit: int = 200,
        request_code: str | None = None,
        task: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return the newest task runs, optionally filtered by request and/or task.

        Empty or None filters are ignored. *limit* is clamped to [1, 500]; a
        falsy value means 200.
        """
        clamped = self._clamp_limit(limit)
        conditions: list[str] = []
        params: list[Any] = []
        if request_code:
            conditions.append("request_code = %s")
            params.append(request_code)
        if task:
            conditions.append("task = %s")
            params.append(task)
        params.append(clamped)
        return self._db.fetchall(
            f"""
            SELECT {self._TASK_RUN_COLUMNS}
            FROM ariadne_task_runs
            {self._where_clause(conditions)}
            ORDER BY started_at DESC
            LIMIT %s
            """,
            tuple(params),
        )

    # ------------------------------------------------------------------
    # Row mapping
    # ------------------------------------------------------------------

    @staticmethod
    def _row_to_request(row: dict[str, Any]) -> AccessRequest:
        """Convert a row dict (selected with _REQUEST_COLUMNS) into an AccessRequest.

        Missing text columns become "". ``approval_flags`` is kept only when it
        is a list, with falsy items dropped and the rest stringified.
        """
        flags = row.get("approval_flags")
        flags_list: list[str] = (
            [str(item) for item in flags if item] if isinstance(flags, list) else []
        )
        return AccessRequest(
            request_code=str(row.get("request_code") or ""),
            username=str(row.get("username") or ""),
            contact_email=str(row.get("contact_email") or ""),
            status=str(row.get("status") or ""),
            email_verified_at=row.get("email_verified_at"),
            initial_password=row.get("initial_password"),
            initial_password_revealed_at=row.get("initial_password_revealed_at"),
            provision_attempted_at=row.get("provision_attempted_at"),
            approval_flags=flags_list,
            approval_note=row.get("approval_note"),
            denial_note=row.get("denial_note"),
        )
|