portal: make provisioning retries safe

This commit is contained in:
Brad Stein 2026-01-03 04:08:13 -03:00
parent 71678a4819
commit b41a2a2b1d
4 changed files with 197 additions and 142 deletions

View File

@ -40,12 +40,14 @@ def ensure_schema() -> None:
decided_at TIMESTAMPTZ, decided_at TIMESTAMPTZ,
decided_by TEXT, decided_by TEXT,
initial_password TEXT, initial_password TEXT,
initial_password_revealed_at TIMESTAMPTZ initial_password_revealed_at TIMESTAMPTZ,
provision_attempted_at TIMESTAMPTZ
) )
""" """
) )
conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS initial_password TEXT") conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS initial_password TEXT")
conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS initial_password_revealed_at TIMESTAMPTZ") conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS initial_password_revealed_at TIMESTAMPTZ")
conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS provision_attempted_at TIMESTAMPTZ")
conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS email_verification_token_hash TEXT") conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS email_verification_token_hash TEXT")
conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS email_verification_sent_at TIMESTAMPTZ") conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS email_verification_sent_at TIMESTAMPTZ")
conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS email_verified_at TIMESTAMPTZ") conn.execute("ALTER TABLE access_requests ADD COLUMN IF NOT EXISTS email_verified_at TIMESTAMPTZ")

View File

@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone
import hashlib
import time import time
import httpx import httpx
@ -30,6 +32,11 @@ class ProvisionResult:
status: str status: str
def _advisory_lock_id(request_code: str) -> int:
    """Derive a deterministic signed 64-bit integer key from a request code.

    The code is UTF-8 encoded and hashed with SHA-256; the first eight bytes
    of the digest are read as a big-endian signed integer, so the result
    always lies in the range [-2**63, 2**63 - 1]. The same request code
    always yields the same key.
    """
    hasher = hashlib.sha256()
    hasher.update(request_code.encode("utf-8"))
    # Only the leading 8 bytes are kept so the value fits in 64 bits.
    leading_bytes = hasher.digest()[:8]
    return int.from_bytes(leading_bytes, byteorder="big", signed=True)
def _upsert_task(conn, request_code: str, task: str, status: str, detail: str | None = None) -> None: def _upsert_task(conn, request_code: str, task: str, status: str, detail: str | None = None) -> None:
conn.execute( conn.execute(
""" """
@ -77,6 +84,15 @@ def provision_access_request(request_code: str) -> ProvisionResult:
required_tasks = list(REQUIRED_PROVISION_TASKS) required_tasks = list(REQUIRED_PROVISION_TASKS)
with connect() as conn: with connect() as conn:
lock_id = _advisory_lock_id(request_code)
lock_row = conn.execute(
"SELECT pg_try_advisory_lock(%s) AS locked",
(lock_id,),
).fetchone()
if not lock_row or not lock_row.get("locked"):
return ProvisionResult(ok=False, status="accounts_building")
try:
row = conn.execute( row = conn.execute(
""" """
SELECT username, SELECT username,
@ -84,7 +100,8 @@ def provision_access_request(request_code: str) -> ProvisionResult:
email_verified_at, email_verified_at,
status, status,
initial_password, initial_password,
initial_password_revealed_at initial_password_revealed_at,
provision_attempted_at
FROM access_requests FROM access_requests
WHERE request_code = %s WHERE request_code = %s
""", """,
@ -99,10 +116,24 @@ def provision_access_request(request_code: str) -> ProvisionResult:
status = str(row.get("status") or "") status = str(row.get("status") or "")
initial_password = row.get("initial_password") initial_password = row.get("initial_password")
revealed_at = row.get("initial_password_revealed_at") revealed_at = row.get("initial_password_revealed_at")
attempted_at = row.get("provision_attempted_at")
if status not in {"accounts_building", "awaiting_onboarding", "ready"}: if status not in {"accounts_building", "awaiting_onboarding", "ready"}:
return ProvisionResult(ok=False, status=status or "unknown") return ProvisionResult(ok=False, status=status or "unknown")
if status == "accounts_building":
now = datetime.now(timezone.utc)
if isinstance(attempted_at, datetime):
if attempted_at.tzinfo is None:
attempted_at = attempted_at.replace(tzinfo=timezone.utc)
age_sec = (now - attempted_at).total_seconds()
if age_sec < settings.ACCESS_REQUEST_PROVISION_RETRY_COOLDOWN_SEC:
return ProvisionResult(ok=False, status="accounts_building")
conn.execute(
"UPDATE access_requests SET provision_attempted_at = NOW() WHERE request_code = %s",
(request_code,),
)
user_id = "" user_id = ""
mailu_email = f"{username}@{settings.MAILU_DOMAIN}" mailu_email = f"{username}@{settings.MAILU_DOMAIN}"
@ -126,6 +157,7 @@ def provision_access_request(request_code: str) -> ProvisionResult:
user_id = str((user or {}).get("id") or "") user_id = str((user or {}).get("id") or "")
if not user_id: if not user_id:
raise RuntimeError("user id missing") raise RuntimeError("user id missing")
try: try:
full = admin_client().get_user(user_id) full = admin_client().get_user(user_id)
attrs = full.get("attributes") or {} attrs = full.get("attributes") or {}
@ -146,43 +178,51 @@ def provision_access_request(request_code: str) -> ProvisionResult:
mailu_email = f"{username}@{settings.MAILU_DOMAIN}" mailu_email = f"{username}@{settings.MAILU_DOMAIN}"
admin_client().set_user_attribute(username, MAILU_EMAIL_ATTR, mailu_email) admin_client().set_user_attribute(username, MAILU_EMAIL_ATTR, mailu_email)
except Exception: except Exception:
# Non-fatal: Mailu sync will fall back to username@domain.
mailu_email = f"{username}@{settings.MAILU_DOMAIN}" mailu_email = f"{username}@{settings.MAILU_DOMAIN}"
_upsert_task(conn, request_code, "keycloak_user", "ok", None) _upsert_task(conn, request_code, "keycloak_user", "ok", None)
except Exception: except Exception:
_upsert_task(conn, request_code, "keycloak_user", "error", "failed to ensure user") _upsert_task(conn, request_code, "keycloak_user", "error", "failed to ensure user")
# Task: set initial temporary password and store it for "show once" onboarding # Task: set initial temporary password and store it for "show once" onboarding.
try: try:
if user_id: if not user_id:
password_value = "" raise RuntimeError("missing user id")
should_reset = status == "accounts_building" and revealed_at is None
password_value: str | None = None
if should_reset:
if isinstance(initial_password, str) and initial_password: if isinstance(initial_password, str) and initial_password:
password_value = initial_password password_value = initial_password
elif initial_password is None and revealed_at is None: elif initial_password is None:
password_value = random_password(20) password_value = random_password(20)
conn.execute( conn.execute(
""" """
UPDATE access_requests UPDATE access_requests
SET initial_password = %s, initial_password_revealed_at = NULL SET initial_password = %s
WHERE request_code = %s AND initial_password IS NULL WHERE request_code = %s AND initial_password IS NULL
""", """,
(password_value, request_code), (password_value, request_code),
) )
initial_password = password_value initial_password = password_value
elif isinstance(initial_password, str) and initial_password and revealed_at is None:
password_value = initial_password
if password_value: if password_value:
admin_client().reset_password(user_id, password_value, temporary=True) admin_client().reset_password(user_id, password_value, temporary=True)
if isinstance(initial_password, str) and initial_password:
_upsert_task(conn, request_code, "keycloak_password", "ok", None) _upsert_task(conn, request_code, "keycloak_password", "ok", None)
elif revealed_at is not None:
_upsert_task(conn, request_code, "keycloak_password", "ok", "initial password already revealed")
else: else:
raise RuntimeError("missing user id") raise RuntimeError("initial password missing")
except Exception: except Exception:
_upsert_task(conn, request_code, "keycloak_password", "error", "failed to set password") _upsert_task(conn, request_code, "keycloak_password", "error", "failed to set password")
# Task: group membership (default dev) # Task: group membership (default dev)
try: try:
if user_id: if not user_id:
raise RuntimeError("missing user id")
groups = settings.DEFAULT_USER_GROUPS or ["dev"] groups = settings.DEFAULT_USER_GROUPS or ["dev"]
for group_name in groups: for group_name in groups:
gid = admin_client().get_group_id(group_name) gid = admin_client().get_group_id(group_name)
@ -190,14 +230,13 @@ def provision_access_request(request_code: str) -> ProvisionResult:
raise RuntimeError("group missing") raise RuntimeError("group missing")
admin_client().add_user_to_group(user_id, gid) admin_client().add_user_to_group(user_id, gid)
_upsert_task(conn, request_code, "keycloak_groups", "ok", None) _upsert_task(conn, request_code, "keycloak_groups", "ok", None)
else:
raise RuntimeError("missing user id")
except Exception: except Exception:
_upsert_task(conn, request_code, "keycloak_groups", "error", "failed to add groups") _upsert_task(conn, request_code, "keycloak_groups", "error", "failed to add groups")
# Task: ensure mailu_app_password attribute exists # Task: ensure mailu_app_password attribute exists
try: try:
if user_id: if not user_id:
raise RuntimeError("missing user id")
full = admin_client().get_user(user_id) full = admin_client().get_user(user_id)
attrs = full.get("attributes") or {} attrs = full.get("attributes") or {}
existing = None existing = None
@ -210,8 +249,6 @@ def provision_access_request(request_code: str) -> ProvisionResult:
if not existing: if not existing:
admin_client().set_user_attribute(username, MAILU_APP_PASSWORD_ATTR, random_password()) admin_client().set_user_attribute(username, MAILU_APP_PASSWORD_ATTR, random_password())
_upsert_task(conn, request_code, "mailu_app_password", "ok", None) _upsert_task(conn, request_code, "mailu_app_password", "ok", None)
else:
raise RuntimeError("missing user id")
except Exception: except Exception:
_upsert_task(conn, request_code, "mailu_app_password", "error", "failed to set mail password") _upsert_task(conn, request_code, "mailu_app_password", "error", "failed to set mail password")
@ -233,18 +270,16 @@ def provision_access_request(request_code: str) -> ProvisionResult:
# Task: ensure Vaultwarden account exists (invite flow) # Task: ensure Vaultwarden account exists (invite flow)
try: try:
if user_id: if not user_id:
raise RuntimeError("missing user id")
result = invite_user(mailu_email or f"{username}@{settings.MAILU_DOMAIN}") result = invite_user(mailu_email or f"{username}@{settings.MAILU_DOMAIN}")
if result.ok: if result.ok:
_upsert_task(conn, request_code, "vaultwarden_invite", "ok", result.status) _upsert_task(conn, request_code, "vaultwarden_invite", "ok", result.status)
else: else:
_upsert_task(conn, request_code, "vaultwarden_invite", "error", result.detail or result.status) _upsert_task(conn, request_code, "vaultwarden_invite", "error", result.detail or result.status)
else:
raise RuntimeError("missing user id")
except Exception: except Exception:
_upsert_task(conn, request_code, "vaultwarden_invite", "error", "failed to provision vaultwarden") _upsert_task(conn, request_code, "vaultwarden_invite", "error", "failed to provision vaultwarden")
# If everything is OK, advance to awaiting_onboarding.
if _all_tasks_ok(conn, request_code, required_tasks): if _all_tasks_ok(conn, request_code, required_tasks):
conn.execute( conn.execute(
""" """
@ -257,3 +292,5 @@ def provision_access_request(request_code: str) -> ProvisionResult:
return ProvisionResult(ok=True, status="awaiting_onboarding") return ProvisionResult(ok=True, status="awaiting_onboarding")
return ProvisionResult(ok=False, status="accounts_building") return ProvisionResult(ok=False, status="accounts_building")
finally:
conn.execute("SELECT pg_advisory_unlock(%s)", (lock_id,))

View File

@ -17,7 +17,7 @@ from ..db import connect, configured
from ..keycloak import admin_client, require_auth from ..keycloak import admin_client, require_auth
from ..mailer import MailerError, access_request_verification_body, send_text_email from ..mailer import MailerError, access_request_verification_body, send_text_email
from ..rate_limit import rate_limit_allow from ..rate_limit import rate_limit_allow
from ..provisioning import provision_tasks_complete from ..provisioning import provision_access_request, provision_tasks_complete
from .. import settings from .. import settings
@ -443,6 +443,19 @@ def register(app) -> None:
).fetchone() ).fetchone()
if not row: if not row:
return jsonify({"error": "not found"}), 404 return jsonify({"error": "not found"}), 404
current_status = _normalize_status(row.get("status") or "")
if current_status == "accounts_building":
try:
provision_access_request(code)
except Exception:
pass
row = conn.execute(
"SELECT status, username, initial_password, initial_password_revealed_at FROM access_requests WHERE request_code = %s",
(code,),
).fetchone()
if not row:
return jsonify({"error": "not found"}), 404
status = _advance_status(conn, code, row.get("username") or "", row.get("status") or "") status = _advance_status(conn, code, row.get("username") or "", row.get("status") or "")
response: dict[str, Any] = { response: dict[str, Any] = {
"ok": True, "ok": True,

View File

@ -72,6 +72,9 @@ ACCESS_REQUEST_SUBMIT_RATE_WINDOW_SEC = int(
ACCESS_REQUEST_STATUS_RATE_LIMIT = int(os.getenv("ACCESS_REQUEST_STATUS_RATE_LIMIT", "60")) ACCESS_REQUEST_STATUS_RATE_LIMIT = int(os.getenv("ACCESS_REQUEST_STATUS_RATE_LIMIT", "60"))
ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC = int(os.getenv("ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC", "60")) ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC = int(os.getenv("ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC", "60"))
ACCESS_REQUEST_EMAIL_VERIFY_TTL_SEC = int(os.getenv("ACCESS_REQUEST_EMAIL_VERIFY_TTL_SEC", str(24 * 60 * 60))) ACCESS_REQUEST_EMAIL_VERIFY_TTL_SEC = int(os.getenv("ACCESS_REQUEST_EMAIL_VERIFY_TTL_SEC", str(24 * 60 * 60)))
# Cooldown in seconds (float) for provisioning retries: a request still in
# "accounts_building" whose last recorded provisioning attempt is younger than
# this value is not re-attempted. Read from the environment; defaults to 30.
ACCESS_REQUEST_PROVISION_RETRY_COOLDOWN_SEC = float(
os.getenv("ACCESS_REQUEST_PROVISION_RETRY_COOLDOWN_SEC", "30")
)
PORTAL_PUBLIC_BASE_URL = os.getenv("PORTAL_PUBLIC_BASE_URL", "https://bstein.dev").rstrip("/") PORTAL_PUBLIC_BASE_URL = os.getenv("PORTAL_PUBLIC_BASE_URL", "https://bstein.dev").rstrip("/")