from __future__ import annotations from datetime import datetime, timezone import hashlib import hmac import re import secrets import string from typing import Any from urllib.parse import quote from flask import request from .. import ariadne_client from ..db import connect, configured from ..keycloak import admin_client, oidc_client from ..mailer import MailerError, access_request_verification_body, send_text_email from ..rate_limit import rate_limit_allow from ..provisioning import provision_access_request, provision_tasks_complete from .. import settings from .access_request_onboarding_policy import ( KEYCLOAK_MANAGED_STEPS, ONBOARDING_OPTIONAL_STEPS, ONBOARDING_REQUIRED_STEPS, ONBOARDING_STEP_PREREQUISITES, ONBOARDING_STEPS, VAULTWARDEN_GRANDFATHERED_FLAG, _KEYCLOAK_PASSWORD_ROTATION_REQUESTED_ARTIFACT, _VAULTWARDEN_READY_STATUSES, ) def _extract_request_payload() -> tuple[str, str, str, str, str]: payload = request.get_json(silent=True) or {} username = (payload.get("username") or "").strip() email = (payload.get("email") or "").strip() note = (payload.get("note") or "").strip() first_name = (payload.get("first_name") or "").strip() last_name = (payload.get("last_name") or "").strip() return username, email, note, first_name, last_name def _normalize_name(value: str) -> str: return " ".join(value.strip().split()) def _validate_name(value: str, *, label: str, required: bool) -> str | None: cleaned = _normalize_name(value) if not cleaned: return f"{label} is required" if required else None if len(cleaned) > 80: return f"{label} must be 1-80 characters" if any(ch in "\r\n\t" for ch in cleaned): return f"{label} contains invalid whitespace" return None def _validate_username(username: str) -> str | None: if not username: return "username is required" if len(username) < 3 or len(username) > 32: return "username must be 3-32 characters" if not re.fullmatch(r"[a-zA-Z0-9._-]+", username): return "username contains invalid characters" return None def _random_request_code(username: str) -> str: suffix = "".join(secrets.choice(string.ascii_uppercase + string.digits) for _ in range(10)) return f"{username}~{suffix}" def _client_ip() -> str: xff = (request.headers.get("X-Forwarded-For") or "").strip() if xff: return xff.split(",", 1)[0].strip() or "unknown" x_real_ip = (request.headers.get("X-Real-IP") or "").strip() if x_real_ip: return x_real_ip return request.remote_addr or "unknown" EMAIL_VERIFY_PENDING_STATUS = "pending_email_verification" def _hash_verification_token(token: str) -> str: return hashlib.sha256(token.encode("utf-8")).hexdigest() def _verify_url(request_code: str, token: str) -> str: base = settings.PORTAL_PUBLIC_BASE_URL.rstrip("/") return f"{base}/api/access/request/verify-link?code={quote(request_code)}&token={quote(token)}" def _send_verification_email(*, request_code: str, email: str, token: str) -> None: verify_url = _verify_url(request_code, token) send_text_email( to_addr=email, subject="Atlas: confirm your email", body=access_request_verification_body(request_code=request_code, verify_url=verify_url), ) class VerificationError(Exception): """Describe an email verification failure with an HTTP status.""" def __init__(self, status_code: int, message: str) -> None: super().__init__(message) self.status_code = status_code self.message = message def _verify_request(conn, code: str, token: str) -> str: """Validate email proof and atomically advance a pending request.""" row = conn.execute( """ SELECT status, email_verification_token_hash, email_verification_sent_at, email_verified_at FROM access_requests WHERE request_code = %s """, (code,), ).fetchone() if not row: raise VerificationError(404, "not found") status = _normalize_status(row.get("status") or "") if status != EMAIL_VERIFY_PENDING_STATUS: return status stored_hash = str(row.get("email_verification_token_hash") or "") if not stored_hash: raise VerificationError(409, "verification token missing") provided_hash = _hash_verification_token(token) if not hmac.compare_digest(stored_hash, provided_hash): raise VerificationError(401, "invalid token") sent_at = row.get("email_verification_sent_at") if isinstance(sent_at, datetime): now = datetime.now(timezone.utc) if sent_at.tzinfo is None: sent_at = sent_at.replace(tzinfo=timezone.utc) age_sec = (now - sent_at).total_seconds() if age_sec > settings.ACCESS_REQUEST_EMAIL_VERIFY_TTL_SEC: raise VerificationError(410, "verification token expired") conn.execute( """ UPDATE access_requests SET status = 'pending', email_verified_at = NOW(), email_verification_token_hash = NULL WHERE request_code = %s AND status = %s """, (code, EMAIL_VERIFY_PENDING_STATUS), ) return "pending" def _normalize_status(status: str) -> str: cleaned = (status or "").strip().lower() if cleaned == "approved": return "accounts_building" return cleaned or "unknown" def _fetch_completed_onboarding_steps(conn, request_code: str) -> set[str]: """Return manually attested onboarding steps for one request.""" rows = conn.execute( "SELECT step FROM access_request_onboarding_steps WHERE request_code = %s", (request_code,), ).fetchall() completed: set[str] = set() for row in rows: step = row.get("step") if isinstance(row, dict) else None if isinstance(step, str) and step: completed.add(step) return completed def _normalize_flag_list(raw: Any) -> set[str]: if isinstance(raw, list): return {item for item in raw if isinstance(item, str) and item} if isinstance(raw, str) and raw: return {raw} return set() def _fetch_request_flags_and_email(conn, request_code: str) -> tuple[set[str], str]: """Return approval flags and contact email used by onboarding decisions.""" row = conn.execute( "SELECT approval_flags, contact_email FROM access_requests WHERE request_code = %s", (request_code,), ).fetchone() if not row: return set(), "" flags = _normalize_flag_list(row.get("approval_flags")) email = row.get("contact_email") if isinstance(row, dict) else "" return flags, (email or "").strip() def _user_in_group(username: str, group_name: str) -> bool: """Return whether a Keycloak user belongs to a named group.""" if not username or not group_name: return False if not admin_client().ready(): return False try: user = admin_client().find_user(username) or {} user_id = user.get("id") if isinstance(user, dict) else None if not isinstance(user_id, str) or not user_id: return False groups = admin_client().list_user_groups(user_id) except Exception: return False return group_name in groups def _vaultwarden_grandfathered(conn, request_code: str, username: str) -> tuple[bool, str]: flags, contact_email = _fetch_request_flags_and_email(conn, request_code) if VAULTWARDEN_GRANDFATHERED_FLAG in flags: return True, contact_email if _user_in_group(username, VAULTWARDEN_GRANDFATHERED_FLAG): return True, contact_email return False, contact_email def _resolve_recovery_email(username: str, fallback: str) -> str: """Find the best recovery email for Vaultwarden onboarding.""" if username and admin_client().ready(): try: user = admin_client().find_user(username) or {} user_id = user.get("id") if isinstance(user, dict) else None if isinstance(user_id, str) and user_id: full = admin_client().get_user(user_id) email = full.get("email") if isinstance(email, str) and email.strip(): return email.strip() except Exception: pass return (fallback or "").strip() def _password_rotation_requested(conn, request_code: str) -> bool: """Return whether Keycloak password rotation was requested for this request.""" row = conn.execute( """ SELECT 1 FROM access_request_onboarding_artifacts WHERE request_code = %s AND artifact = %s LIMIT 1 """, (request_code, _KEYCLOAK_PASSWORD_ROTATION_REQUESTED_ARTIFACT), ).fetchone() return bool(row) def _request_keycloak_password_rotation(conn, request_code: str, username: str) -> None: """Require Keycloak password rotation and persist the request marker.""" if not username: raise ValueError("username missing") if not admin_client().ready(): raise RuntimeError("keycloak admin unavailable") user = admin_client().find_user(username) or {} user_id = user.get("id") if isinstance(user, dict) else None if not isinstance(user_id, str) or not user_id: raise RuntimeError("keycloak user not found") full = admin_client().get_user(user_id) actions = full.get("requiredActions") actions_list: list[str] = [] if isinstance(actions, list): actions_list = [a for a in actions if isinstance(a, str)] if "UPDATE_PASSWORD" not in actions_list: actions_list.append("UPDATE_PASSWORD") admin_client().update_user_safe(user_id, {"requiredActions": actions_list}) conn.execute( """ INSERT INTO access_request_onboarding_artifacts (request_code, artifact, value_hash) VALUES (%s, %s, NOW()::text) ON CONFLICT (request_code, artifact) DO NOTHING """, (request_code, _KEYCLOAK_PASSWORD_ROTATION_REQUESTED_ARTIFACT), ) def _extract_attr(attrs: Any, key: str) -> str: """Return the first string value for a Keycloak attribute.""" if not isinstance(attrs, dict): return "" raw = attrs.get(key) if isinstance(raw, list): for item in raw: if isinstance(item, str) and item.strip(): return item.strip() return "" if isinstance(raw, str) and raw.strip(): return raw.strip() return "" def _vaultwarden_status_for_user(username: str) -> str: """Read the Vaultwarden lifecycle status mirrored on a Keycloak user.""" if not username: return "" if not admin_client().ready(): return "" try: user = admin_client().find_user(username) or {} user_id = user.get("id") if isinstance(user, dict) else None if not isinstance(user_id, str) or not user_id: return "" full = admin_client().get_user(user_id) attrs = full.get("attributes") if isinstance(full, dict) else {} return _extract_attr(attrs, "vaultwarden_status") except Exception: return "" def _auto_completed_service_steps(attrs: Any) -> set[str]: """Infer onboarding steps completed by backend service automation.""" completed: set[str] = set() if not isinstance(attrs, dict): return completed vaultwarden_status = _extract_attr(attrs, "vaultwarden_status") vaultwarden_master = _extract_attr(attrs, "vaultwarden_master_password_set_at") if vaultwarden_master or vaultwarden_status in _VAULTWARDEN_READY_STATUSES: completed.add("vaultwarden_master_password") nextcloud_synced_at = _extract_attr(attrs, "nextcloud_mail_synced_at") if nextcloud_synced_at: completed.add("nextcloud_mail_integration") firefly_rotated_at = _extract_attr(attrs, "firefly_password_rotated_at") if firefly_rotated_at: completed.add("firefly_password_rotated") wger_rotated_at = _extract_attr(attrs, "wger_password_rotated_at") if wger_rotated_at: completed.add("wger_password_rotated") return completed def _auto_completed_keycloak_steps(conn, request_code: str, username: str) -> set[str]: """Infer onboarding steps from Keycloak profile state.""" if not username: return set() if not admin_client().ready(): return set() if not request_code: return set() completed: set[str] = set() try: user = admin_client().find_user(username) or {} user_id = user.get("id") if isinstance(user, dict) else None if not isinstance(user_id, str) or not user_id: return set() full = {} try: full = admin_client().get_user(user_id) except Exception: full = user if isinstance(user, dict) else {} attrs = full.get("attributes") if isinstance(full, dict) else {} completed |= _auto_completed_service_steps(attrs) actions = full.get("requiredActions") required_actions: set[str] = set() actions_list: list[str] = [] if isinstance(actions, list): actions_list = [a for a in actions if isinstance(a, str)] required_actions = set(actions_list) if _password_rotation_requested(conn, request_code) and "UPDATE_PASSWORD" not in required_actions: completed.add("keycloak_password_rotated") # Backfill: earlier accounts were created with CONFIGURE_TOTP as a required action, # which forces users to enroll MFA at first login. We no longer require that, so # remove it if present. if "CONFIGURE_TOTP" in required_actions: try: admin_client().update_user_safe( user_id, {"requiredActions": [a for a in actions_list if a != "CONFIGURE_TOTP"]}, ) except Exception: pass except Exception: return set() return completed def _completed_onboarding_steps(conn, request_code: str, username: str) -> set[str]: completed = _fetch_completed_onboarding_steps(conn, request_code) return completed | _auto_completed_keycloak_steps(conn, request_code, username) def _automation_ready(conn, request_code: str, username: str) -> bool: """Return whether account provisioning has finished enough for onboarding.""" if not username: return False if not admin_client().ready(): return False # Prefer task-based readiness when we have task rows for the request. task_row = conn.execute( "SELECT 1 FROM access_request_tasks WHERE request_code = %s LIMIT 1", (request_code,), ).fetchone() if task_row: return provision_tasks_complete(conn, request_code) # Fallback for legacy requests: confirm user exists and has a mail app password. try: user = admin_client().find_user(username) if not user: return False user_id = user.get("id") if isinstance(user, dict) else None if not user_id: return False full = admin_client().get_user(str(user_id)) attrs = full.get("attributes") or {} if not isinstance(attrs, dict): return False raw_pw = attrs.get("mailu_app_password") if isinstance(raw_pw, list): return bool(raw_pw and isinstance(raw_pw[0], str) and raw_pw[0]) return bool(isinstance(raw_pw, str) and raw_pw) except Exception: return False def _advance_status(conn, request_code: str, username: str, status: str) -> str: """Advance an access request through automatic status transitions.""" status = _normalize_status(status) if status == "accounts_building" and _automation_ready(conn, request_code, username): conn.execute( "UPDATE access_requests SET status = 'awaiting_onboarding' WHERE request_code = %s AND status = 'accounts_building'", (request_code,), ) return "awaiting_onboarding" if status == "awaiting_onboarding": completed = _completed_onboarding_steps(conn, request_code, username) required_steps = set(ONBOARDING_REQUIRED_STEPS) grandfathered, _ = _vaultwarden_grandfathered(conn, request_code, username) vaultwarden_status = _vaultwarden_status_for_user(username) if grandfathered and vaultwarden_status == "grandfathered": required_steps.add("vaultwarden_store_temp_password") if required_steps.issubset(completed): conn.execute( "UPDATE access_requests SET status = 'ready' WHERE request_code = %s AND status = 'awaiting_onboarding'", (request_code,), ) return "ready" return status def _onboarding_payload(conn, request_code: str, username: str) -> dict[str, Any]: """Build the onboarding progress payload returned to the frontend.""" completed_steps = sorted(_completed_onboarding_steps(conn, request_code, username)) password_rotation_requested = _password_rotation_requested(conn, request_code) grandfathered, contact_email = _vaultwarden_grandfathered(conn, request_code, username) recovery_email = _resolve_recovery_email(username, contact_email) if grandfathered else "" vaultwarden_status = _vaultwarden_status_for_user(username) vaultwarden_matched = grandfathered and vaultwarden_status == "grandfathered" required_steps = list(ONBOARDING_REQUIRED_STEPS) if vaultwarden_matched: required_steps.append("vaultwarden_store_temp_password") return { "required_steps": required_steps, "optional_steps": sorted(ONBOARDING_OPTIONAL_STEPS), "completed_steps": completed_steps, "keycloak": { "password_rotation_requested": password_rotation_requested, }, "vaultwarden": { "grandfathered": grandfathered, "recovery_email": recovery_email, "matched": vaultwarden_matched, }, } # Keep the historical access_requests module patch surface intact for tests and # callers while the route handlers live in smaller focused modules. __all__ = [name for name in globals() if not name.startswith("__")]