portal: make provisioning retries safe
This commit is contained in:
parent
71678a4819
commit
b41a2a2b1d
@ -40,12 +40,14 @@ def ensure_schema() -> None:
|
||||
decided_at TIMESTAMPTZ,
|
||||
decided_by TEXT,
|
||||
initial_password TEXT,
|
||||
initial_password_revealed_at TIMESTAMPTZ
|
||||
initial_password_revealed_at TIMESTAMPTZ,
|
||||
provision_attempted_at TIMESTAMPTZ
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS initial_password TEXT")
|
||||
conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS initial_password_revealed_at TIMESTAMPTZ")
|
||||
conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS provision_attempted_at TIMESTAMPTZ")
|
||||
conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS email_verification_token_hash TEXT")
|
||||
conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS email_verification_sent_at TIMESTAMPTZ")
|
||||
conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS email_verified_at TIMESTAMPTZ")
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
import hashlib
|
||||
import time
|
||||
|
||||
import httpx
|
||||
@ -30,6 +32,11 @@ class ProvisionResult:
|
||||
status: str
|
||||
|
||||
|
||||
def _advisory_lock_id(request_code: str) -> int:
|
||||
digest = hashlib.sha256(request_code.encode("utf-8")).digest()
|
||||
return int.from_bytes(digest[:8], "big", signed=True)
|
||||
|
||||
|
||||
def _upsert_task(conn, request_code: str, task: str, status: str, detail: str | None = None) -> None:
|
||||
conn.execute(
|
||||
"""
|
||||
@ -77,6 +84,15 @@ def provision_access_request(request_code: str) -> ProvisionResult:
|
||||
required_tasks = list(REQUIRED_PROVISION_TASKS)
|
||||
|
||||
with connect() as conn:
|
||||
lock_id = _advisory_lock_id(request_code)
|
||||
lock_row = conn.execute(
|
||||
"SELECT pg_try_advisory_lock(%s) AS locked",
|
||||
(lock_id,),
|
||||
).fetchone()
|
||||
if not lock_row or not lock_row.get("locked"):
|
||||
return ProvisionResult(ok=False, status="accounts_building")
|
||||
|
||||
try:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT username,
|
||||
@ -84,7 +100,8 @@ def provision_access_request(request_code: str) -> ProvisionResult:
|
||||
email_verified_at,
|
||||
status,
|
||||
initial_password,
|
||||
initial_password_revealed_at
|
||||
initial_password_revealed_at,
|
||||
provision_attempted_at
|
||||
FROM access_requests
|
||||
WHERE request_code = %s
|
||||
""",
|
||||
@ -99,10 +116,24 @@ def provision_access_request(request_code: str) -> ProvisionResult:
|
||||
status = str(row.get("status") or "")
|
||||
initial_password = row.get("initial_password")
|
||||
revealed_at = row.get("initial_password_revealed_at")
|
||||
attempted_at = row.get("provision_attempted_at")
|
||||
|
||||
if status not in {"accounts_building", "awaiting_onboarding", "ready"}:
|
||||
return ProvisionResult(ok=False, status=status or "unknown")
|
||||
|
||||
if status == "accounts_building":
|
||||
now = datetime.now(timezone.utc)
|
||||
if isinstance(attempted_at, datetime):
|
||||
if attempted_at.tzinfo is None:
|
||||
attempted_at = attempted_at.replace(tzinfo=timezone.utc)
|
||||
age_sec = (now - attempted_at).total_seconds()
|
||||
if age_sec < settings.ACCESS_REQUEST_PROVISION_RETRY_COOLDOWN_SEC:
|
||||
return ProvisionResult(ok=False, status="accounts_building")
|
||||
conn.execute(
|
||||
"UPDATE access_requests SET provision_attempted_at = NOW() WHERE request_code = %s",
|
||||
(request_code,),
|
||||
)
|
||||
|
||||
user_id = ""
|
||||
mailu_email = f"{username}@{settings.MAILU_DOMAIN}"
|
||||
|
||||
@ -126,6 +157,7 @@ def provision_access_request(request_code: str) -> ProvisionResult:
|
||||
user_id = str((user or {}).get("id") or "")
|
||||
if not user_id:
|
||||
raise RuntimeError("user id missing")
|
||||
|
||||
try:
|
||||
full = admin_client().get_user(user_id)
|
||||
attrs = full.get("attributes") or {}
|
||||
@ -146,43 +178,51 @@ def provision_access_request(request_code: str) -> ProvisionResult:
|
||||
mailu_email = f"{username}@{settings.MAILU_DOMAIN}"
|
||||
admin_client().set_user_attribute(username, MAILU_EMAIL_ATTR, mailu_email)
|
||||
except Exception:
|
||||
# Non-fatal: Mailu sync will fall back to username@domain.
|
||||
mailu_email = f"{username}@{settings.MAILU_DOMAIN}"
|
||||
|
||||
_upsert_task(conn, request_code, "keycloak_user", "ok", None)
|
||||
except Exception:
|
||||
_upsert_task(conn, request_code, "keycloak_user", "error", "failed to ensure user")
|
||||
|
||||
# Task: set initial temporary password and store it for "show once" onboarding
|
||||
# Task: set initial temporary password and store it for "show once" onboarding.
|
||||
try:
|
||||
if user_id:
|
||||
password_value = ""
|
||||
if not user_id:
|
||||
raise RuntimeError("missing user id")
|
||||
|
||||
should_reset = status == "accounts_building" and revealed_at is None
|
||||
password_value: str | None = None
|
||||
|
||||
if should_reset:
|
||||
if isinstance(initial_password, str) and initial_password:
|
||||
password_value = initial_password
|
||||
elif initial_password is None and revealed_at is None:
|
||||
elif initial_password is None:
|
||||
password_value = random_password(20)
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE access_requests
|
||||
SET initial_password = %s, initial_password_revealed_at = NULL
|
||||
SET initial_password = %s
|
||||
WHERE request_code = %s AND initial_password IS NULL
|
||||
""",
|
||||
(password_value, request_code),
|
||||
)
|
||||
initial_password = password_value
|
||||
elif isinstance(initial_password, str) and initial_password and revealed_at is None:
|
||||
password_value = initial_password
|
||||
|
||||
if password_value:
|
||||
admin_client().reset_password(user_id, password_value, temporary=True)
|
||||
|
||||
if isinstance(initial_password, str) and initial_password:
|
||||
_upsert_task(conn, request_code, "keycloak_password", "ok", None)
|
||||
elif revealed_at is not None:
|
||||
_upsert_task(conn, request_code, "keycloak_password", "ok", "initial password already revealed")
|
||||
else:
|
||||
raise RuntimeError("missing user id")
|
||||
raise RuntimeError("initial password missing")
|
||||
except Exception:
|
||||
_upsert_task(conn, request_code, "keycloak_password", "error", "failed to set password")
|
||||
|
||||
# Task: group membership (default dev)
|
||||
try:
|
||||
if user_id:
|
||||
if not user_id:
|
||||
raise RuntimeError("missing user id")
|
||||
groups = settings.DEFAULT_USER_GROUPS or ["dev"]
|
||||
for group_name in groups:
|
||||
gid = admin_client().get_group_id(group_name)
|
||||
@ -190,14 +230,13 @@ def provision_access_request(request_code: str) -> ProvisionResult:
|
||||
raise RuntimeError("group missing")
|
||||
admin_client().add_user_to_group(user_id, gid)
|
||||
_upsert_task(conn, request_code, "keycloak_groups", "ok", None)
|
||||
else:
|
||||
raise RuntimeError("missing user id")
|
||||
except Exception:
|
||||
_upsert_task(conn, request_code, "keycloak_groups", "error", "failed to add groups")
|
||||
|
||||
# Task: ensure mailu_app_password attribute exists
|
||||
try:
|
||||
if user_id:
|
||||
if not user_id:
|
||||
raise RuntimeError("missing user id")
|
||||
full = admin_client().get_user(user_id)
|
||||
attrs = full.get("attributes") or {}
|
||||
existing = None
|
||||
@ -210,8 +249,6 @@ def provision_access_request(request_code: str) -> ProvisionResult:
|
||||
if not existing:
|
||||
admin_client().set_user_attribute(username, MAILU_APP_PASSWORD_ATTR, random_password())
|
||||
_upsert_task(conn, request_code, "mailu_app_password", "ok", None)
|
||||
else:
|
||||
raise RuntimeError("missing user id")
|
||||
except Exception:
|
||||
_upsert_task(conn, request_code, "mailu_app_password", "error", "failed to set mail password")
|
||||
|
||||
@ -233,18 +270,16 @@ def provision_access_request(request_code: str) -> ProvisionResult:
|
||||
|
||||
# Task: ensure Vaultwarden account exists (invite flow)
|
||||
try:
|
||||
if user_id:
|
||||
if not user_id:
|
||||
raise RuntimeError("missing user id")
|
||||
result = invite_user(mailu_email or f"{username}@{settings.MAILU_DOMAIN}")
|
||||
if result.ok:
|
||||
_upsert_task(conn, request_code, "vaultwarden_invite", "ok", result.status)
|
||||
else:
|
||||
_upsert_task(conn, request_code, "vaultwarden_invite", "error", result.detail or result.status)
|
||||
else:
|
||||
raise RuntimeError("missing user id")
|
||||
except Exception:
|
||||
_upsert_task(conn, request_code, "vaultwarden_invite", "error", "failed to provision vaultwarden")
|
||||
|
||||
# If everything is OK, advance to awaiting_onboarding.
|
||||
if _all_tasks_ok(conn, request_code, required_tasks):
|
||||
conn.execute(
|
||||
"""
|
||||
@ -257,3 +292,5 @@ def provision_access_request(request_code: str) -> ProvisionResult:
|
||||
return ProvisionResult(ok=True, status="awaiting_onboarding")
|
||||
|
||||
return ProvisionResult(ok=False, status="accounts_building")
|
||||
finally:
|
||||
conn.execute("SELECT pg_advisory_unlock(%s)", (lock_id,))
|
||||
|
||||
@ -17,7 +17,7 @@ from ..db import connect, configured
|
||||
from ..keycloak import admin_client, require_auth
|
||||
from ..mailer import MailerError, access_request_verification_body, send_text_email
|
||||
from ..rate_limit import rate_limit_allow
|
||||
from ..provisioning import provision_tasks_complete
|
||||
from ..provisioning import provision_access_request, provision_tasks_complete
|
||||
from .. import settings
|
||||
|
||||
|
||||
@ -443,6 +443,19 @@ def register(app) -> None:
|
||||
).fetchone()
|
||||
if not row:
|
||||
return jsonify({"error": "not found"}), 404
|
||||
current_status = _normalize_status(row.get("status") or "")
|
||||
if current_status == "accounts_building":
|
||||
try:
|
||||
provision_access_request(code)
|
||||
except Exception:
|
||||
pass
|
||||
row = conn.execute(
|
||||
"SELECT status, username, initial_password, initial_password_revealed_at FROM access_requests WHERE request_code = %s",
|
||||
(code,),
|
||||
).fetchone()
|
||||
if not row:
|
||||
return jsonify({"error": "not found"}), 404
|
||||
|
||||
status = _advance_status(conn, code, row.get("username") or "", row.get("status") or "")
|
||||
response: dict[str, Any] = {
|
||||
"ok": True,
|
||||
|
||||
@ -72,6 +72,9 @@ ACCESS_REQUEST_SUBMIT_RATE_WINDOW_SEC = int(
|
||||
ACCESS_REQUEST_STATUS_RATE_LIMIT = int(os.getenv("ACCESS_REQUEST_STATUS_RATE_LIMIT", "60"))
|
||||
ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC = int(os.getenv("ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC", "60"))
|
||||
ACCESS_REQUEST_EMAIL_VERIFY_TTL_SEC = int(os.getenv("ACCESS_REQUEST_EMAIL_VERIFY_TTL_SEC", str(24 * 60 * 60)))
|
||||
ACCESS_REQUEST_PROVISION_RETRY_COOLDOWN_SEC = float(
|
||||
os.getenv("ACCESS_REQUEST_PROVISION_RETRY_COOLDOWN_SEC", "30")
|
||||
)
|
||||
|
||||
PORTAL_PUBLIC_BASE_URL = os.getenv("PORTAL_PUBLIC_BASE_URL", "https://bstein.dev").rstrip("/")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user