from __future__ import annotations from datetime import datetime, timezone import hashlib import hmac import re import secrets import string from typing import Any from urllib.parse import quote from flask import jsonify, request, g import psycopg from ..db import connect, configured from ..keycloak import admin_client, require_auth from ..mailer import MailerError, access_request_verification_body, send_text_email from ..rate_limit import rate_limit_allow from ..provisioning import provision_access_request, provision_tasks_complete from .. import settings def _extract_request_payload() -> tuple[str, str, str]: payload = request.get_json(silent=True) or {} username = (payload.get("username") or "").strip() email = (payload.get("email") or "").strip() note = (payload.get("note") or "").strip() return username, email, note def _random_request_code(username: str) -> str: suffix = "".join(secrets.choice(string.ascii_uppercase + string.digits) for _ in range(10)) return f"{username}~{suffix}" def _client_ip() -> str: xff = (request.headers.get("X-Forwarded-For") or "").strip() if xff: return xff.split(",", 1)[0].strip() or "unknown" x_real_ip = (request.headers.get("X-Real-IP") or "").strip() if x_real_ip: return x_real_ip return request.remote_addr or "unknown" EMAIL_VERIFY_PENDING_STATUS = "pending_email_verification" def _hash_verification_token(token: str) -> str: return hashlib.sha256(token.encode("utf-8")).hexdigest() def _verify_url(request_code: str, token: str) -> str: base = settings.PORTAL_PUBLIC_BASE_URL.rstrip("/") return f"{base}/request-access?code={quote(request_code)}&verify={quote(token)}" ONBOARDING_STEPS: tuple[str, ...] = ( "keycloak_password_changed", "keycloak_mfa_configured", "vaultwarden_master_password", "element_recovery_key", "element_recovery_key_stored", ) KEYCLOAK_MANAGED_STEPS: set[str] = {"keycloak_password_changed", "keycloak_mfa_configured"} KEYCLOAK_OTP_CRED_TYPES: set[str] = {"otp", "totp"} def _normalize_status(status: str) -> str: cleaned = (status or "").strip().lower() if cleaned == "approved": return "accounts_building" return cleaned or "unknown" def _fetch_completed_onboarding_steps(conn, request_code: str) -> set[str]: rows = conn.execute( "SELECT step FROM access_request_onboarding_steps WHERE request_code = %s", (request_code,), ).fetchall() completed: set[str] = set() for row in rows: step = row.get("step") if isinstance(row, dict) else None if isinstance(step, str) and step: completed.add(step) return completed def _auto_completed_keycloak_steps(username: str) -> set[str]: if not username: return set() if not admin_client().ready(): return set() completed: set[str] = set() try: user = admin_client().find_user(username) or {} user_id = user.get("id") if isinstance(user, dict) else None if not isinstance(user_id, str) or not user_id: return set() full = {} try: full = admin_client().get_user(user_id) except Exception: full = user if isinstance(user, dict) else {} actions = full.get("requiredActions") required_actions: set[str] = set() if isinstance(actions, list): required_actions = {a for a in actions if isinstance(a, str)} if "UPDATE_PASSWORD" not in required_actions: completed.add("keycloak_password_changed") otp_present = False try: creds = admin_client().get_user_credentials(user_id) for cred in creds: ctype = cred.get("type") if isinstance(cred, dict) else None if isinstance(ctype, str) and ctype.lower() in KEYCLOAK_OTP_CRED_TYPES: otp_present = True break except Exception: otp_present = False if otp_present or "CONFIGURE_TOTP" not in required_actions: completed.add("keycloak_mfa_configured") except Exception: return set() return completed def _completed_onboarding_steps(conn, request_code: str, username: str) -> set[str]: completed = _fetch_completed_onboarding_steps(conn, request_code) return completed | _auto_completed_keycloak_steps(username) def _automation_ready(conn, request_code: str, username: str) -> bool: if not username: return False if not admin_client().ready(): return False # Prefer task-based readiness when we have task rows for the request. task_row = conn.execute( "SELECT 1 FROM access_request_tasks WHERE request_code = %s LIMIT 1", (request_code,), ).fetchone() if task_row: return provision_tasks_complete(conn, request_code) # Fallback for legacy requests: confirm user exists and has a mail app password. try: user = admin_client().find_user(username) if not user: return False user_id = user.get("id") if isinstance(user, dict) else None if not user_id: return False full = admin_client().get_user(str(user_id)) attrs = full.get("attributes") or {} if not isinstance(attrs, dict): return False raw_pw = attrs.get("mailu_app_password") if isinstance(raw_pw, list): return bool(raw_pw and isinstance(raw_pw[0], str) and raw_pw[0]) return bool(isinstance(raw_pw, str) and raw_pw) except Exception: return False def _advance_status(conn, request_code: str, username: str, status: str) -> str: status = _normalize_status(status) if status == "accounts_building" and _automation_ready(conn, request_code, username): conn.execute( "UPDATE access_requests SET status = 'awaiting_onboarding' WHERE request_code = %s AND status = 'accounts_building'", (request_code,), ) return "awaiting_onboarding" if status == "awaiting_onboarding": completed = _completed_onboarding_steps(conn, request_code, username) if set(ONBOARDING_STEPS).issubset(completed): conn.execute( "UPDATE access_requests SET status = 'ready' WHERE request_code = %s AND status = 'awaiting_onboarding'", (request_code,), ) return "ready" return status def register(app) -> None: @app.route("/api/access/request", methods=["POST"]) def request_access() -> Any: if not settings.ACCESS_REQUEST_ENABLED: return jsonify({"error": "request access disabled"}), 503 if not configured(): return jsonify({"error": "server not configured"}), 503 ip = _client_ip() username, email, note = _extract_request_payload() rate_key = ip if username: rate_key = f"{ip}:{username}" if not rate_limit_allow( rate_key, key="access_request_submit", limit=settings.ACCESS_REQUEST_SUBMIT_RATE_LIMIT, window_sec=settings.ACCESS_REQUEST_SUBMIT_RATE_WINDOW_SEC, ): return jsonify({"error": "rate limited"}), 429 if not username: return jsonify({"error": "username is required"}), 400 if len(username) < 3 or len(username) > 32: return jsonify({"error": "username must be 3-32 characters"}), 400 if not re.fullmatch(r"[a-zA-Z0-9._-]+", username): return jsonify({"error": "username contains invalid characters"}), 400 if not email: return jsonify({"error": "email is required"}), 400 if "@" not in email: return jsonify({"error": "invalid email"}), 400 email_lower = email.lower() if email_lower.endswith(f"@{settings.MAILU_DOMAIN.lower()}") and ( email_lower not in settings.ACCESS_REQUEST_INTERNAL_EMAIL_ALLOWLIST ): return jsonify({"error": "email must be an external address"}), 400 if admin_client().ready() and admin_client().find_user(username): return jsonify({"error": "username already exists"}), 409 try: with connect() as conn: existing = conn.execute( """ SELECT request_code, status FROM access_requests WHERE username = %s AND status IN (%s, 'pending') ORDER BY created_at DESC LIMIT 1 """, (username, EMAIL_VERIFY_PENDING_STATUS), ).fetchone() if existing: existing_status = str(existing.get("status") or "") request_code = str(existing.get("request_code") or "") if existing_status != EMAIL_VERIFY_PENDING_STATUS: return jsonify({"ok": True, "request_code": request_code, "status": existing_status}) token = secrets.token_urlsafe(24) token_hash = _hash_verification_token(token) conn.execute( """ UPDATE access_requests SET contact_email = %s, note = %s, email_verification_token_hash = %s, email_verification_sent_at = NOW(), email_verified_at = NULL WHERE request_code = %s AND status = %s """, (email, note or None, token_hash, request_code, EMAIL_VERIFY_PENDING_STATUS), ) verify_url = _verify_url(request_code, token) try: send_text_email( to_addr=email, subject="Atlas: confirm your email", body=access_request_verification_body(request_code=request_code, verify_url=verify_url), ) except MailerError: return ( jsonify({"error": "failed to send verification email", "request_code": request_code}), 502, ) return jsonify({"ok": True, "request_code": request_code, "status": EMAIL_VERIFY_PENDING_STATUS}) request_code = _random_request_code(username) token = secrets.token_urlsafe(24) token_hash = _hash_verification_token(token) try: conn.execute( """ INSERT INTO access_requests (request_code, username, contact_email, note, status, email_verification_token_hash, email_verification_sent_at) VALUES (%s, %s, %s, %s, %s, %s, NOW()) """, (request_code, username, email, note or None, EMAIL_VERIFY_PENDING_STATUS, token_hash), ) except psycopg.errors.UniqueViolation: conn.rollback() existing = conn.execute( """ SELECT request_code, status FROM access_requests WHERE username = %s AND status IN (%s, 'pending') ORDER BY created_at DESC LIMIT 1 """, (username, EMAIL_VERIFY_PENDING_STATUS), ).fetchone() if not existing: raise return jsonify({"ok": True, "request_code": existing["request_code"], "status": existing["status"]}) verify_url = _verify_url(request_code, token) try: send_text_email( to_addr=email, subject="Atlas: confirm your email", body=access_request_verification_body(request_code=request_code, verify_url=verify_url), ) except MailerError: return jsonify({"error": "failed to send verification email", "request_code": request_code}), 502 except Exception: return jsonify({"error": "failed to submit request"}), 502 return jsonify({"ok": True, "request_code": request_code, "status": EMAIL_VERIFY_PENDING_STATUS}) @app.route("/api/access/request/verify", methods=["POST"]) def request_access_verify() -> Any: if not settings.ACCESS_REQUEST_ENABLED: return jsonify({"error": "request access disabled"}), 503 if not configured(): return jsonify({"error": "server not configured"}), 503 ip = _client_ip() if not rate_limit_allow( ip, key="access_request_verify", limit=60, window_sec=60, ): return jsonify({"error": "rate limited"}), 429 payload = request.get_json(silent=True) or {} code = (payload.get("request_code") or payload.get("code") or "").strip() token = (payload.get("token") or payload.get("verify") or "").strip() if not code: return jsonify({"error": "request_code is required"}), 400 if not token: return jsonify({"error": "token is required"}), 400 if not rate_limit_allow( f"{ip}:{code}", key="access_request_verify_code", limit=30, window_sec=60, ): return jsonify({"error": "rate limited"}), 429 try: with connect() as conn: row = conn.execute( """ SELECT status, email_verification_token_hash, email_verification_sent_at, email_verified_at FROM access_requests WHERE request_code = %s """, (code,), ).fetchone() if not row: return jsonify({"error": "not found"}), 404 status = _normalize_status(row.get("status") or "") if status != EMAIL_VERIFY_PENDING_STATUS: return jsonify({"ok": True, "status": status}) stored_hash = str(row.get("email_verification_token_hash") or "") if not stored_hash: return jsonify({"error": "verification token missing"}), 409 provided_hash = _hash_verification_token(token) if not hmac.compare_digest(stored_hash, provided_hash): return jsonify({"error": "invalid token"}), 401 sent_at = row.get("email_verification_sent_at") if isinstance(sent_at, datetime): now = datetime.now(timezone.utc) if sent_at.tzinfo is None: sent_at = sent_at.replace(tzinfo=timezone.utc) age_sec = (now - sent_at).total_seconds() if age_sec > settings.ACCESS_REQUEST_EMAIL_VERIFY_TTL_SEC: return jsonify({"error": "verification token expired"}), 410 conn.execute( """ UPDATE access_requests SET status = 'pending', email_verified_at = NOW(), email_verification_token_hash = NULL WHERE request_code = %s AND status = %s """, (code, EMAIL_VERIFY_PENDING_STATUS), ) return jsonify({"ok": True, "status": "pending"}) except Exception: return jsonify({"error": "failed to verify"}), 502 @app.route("/api/access/request/status", methods=["POST"]) def request_access_status() -> Any: if not settings.ACCESS_REQUEST_ENABLED: return jsonify({"error": "request access disabled"}), 503 if not configured(): return jsonify({"error": "server not configured"}), 503 ip = _client_ip() if not rate_limit_allow( ip, key="access_request_status", limit=settings.ACCESS_REQUEST_STATUS_RATE_LIMIT, window_sec=settings.ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC, ): return jsonify({"error": "rate limited"}), 429 payload = request.get_json(silent=True) or {} code = (payload.get("request_code") or payload.get("code") or "").strip() if not code: return jsonify({"error": "request_code is required"}), 400 # Additional per-code limiter to avoid global NAT rate-limit blowups. if not rate_limit_allow( f"{ip}:{code}", key="access_request_status_code", limit=max(20, settings.ACCESS_REQUEST_STATUS_RATE_LIMIT), window_sec=settings.ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC, ): return jsonify({"error": "rate limited"}), 429 try: with connect() as conn: row = conn.execute( "SELECT status, username, initial_password, initial_password_revealed_at FROM access_requests WHERE request_code = %s", (code,), ).fetchone() if not row: return jsonify({"error": "not found"}), 404 current_status = _normalize_status(row.get("status") or "") if current_status == "accounts_building": try: provision_access_request(code) except Exception: pass row = conn.execute( "SELECT status, username, initial_password, initial_password_revealed_at FROM access_requests WHERE request_code = %s", (code,), ).fetchone() if not row: return jsonify({"error": "not found"}), 404 status = _advance_status(conn, code, row.get("username") or "", row.get("status") or "") response: dict[str, Any] = { "ok": True, "status": status, "username": row.get("username") or "", } task_rows = conn.execute( """ SELECT task, status, detail, updated_at FROM access_request_tasks WHERE request_code = %s ORDER BY task """, (code,), ).fetchall() if task_rows: tasks: list[dict[str, Any]] = [] blocked = False for task_row in task_rows: task_name = task_row.get("task") if isinstance(task_row, dict) else None task_status = task_row.get("status") if isinstance(task_row, dict) else None detail = task_row.get("detail") if isinstance(task_row, dict) else None updated_at = task_row.get("updated_at") if isinstance(task_row, dict) else None if isinstance(task_status, str) and task_status == "error": blocked = True task_payload: dict[str, Any] = { "task": task_name or "", "status": task_status or "", } if isinstance(detail, str) and detail: task_payload["detail"] = detail if isinstance(updated_at, datetime): task_payload["updated_at"] = updated_at.astimezone(timezone.utc).isoformat() tasks.append(task_payload) response["tasks"] = tasks response["automation_complete"] = provision_tasks_complete(conn, code) response["blocked"] = blocked if status in {"awaiting_onboarding", "ready"}: password = row.get("initial_password") revealed_at = row.get("initial_password_revealed_at") if isinstance(password, str) and password and revealed_at is None: response["initial_password"] = password conn.execute( "UPDATE access_requests SET initial_password_revealed_at = NOW() WHERE request_code = %s AND initial_password_revealed_at IS NULL", (code,), ) if status in {"awaiting_onboarding", "ready"}: response["onboarding_url"] = f"/onboarding?code={code}" if status in {"awaiting_onboarding", "ready"}: completed = sorted(_completed_onboarding_steps(conn, code, row.get("username") or "")) response["onboarding"] = { "required_steps": list(ONBOARDING_STEPS), "completed_steps": completed, } return jsonify(response) except Exception: return jsonify({"error": "failed to load status"}), 502 @app.route("/api/access/request/onboarding/attest", methods=["POST"]) @require_auth def request_access_onboarding_attest() -> Any: if not configured(): return jsonify({"error": "server not configured"}), 503 payload = request.get_json(silent=True) or {} code = (payload.get("request_code") or payload.get("code") or "").strip() step = (payload.get("step") or "").strip() completed = payload.get("completed") if not code: return jsonify({"error": "request_code is required"}), 400 if step not in ONBOARDING_STEPS: return jsonify({"error": "invalid step"}), 400 if step in KEYCLOAK_MANAGED_STEPS: return jsonify({"error": "step is managed by keycloak"}), 400 username = getattr(g, "keycloak_username", "") or "" if not username: return jsonify({"error": "invalid token"}), 401 try: with connect() as conn: row = conn.execute( "SELECT username, status FROM access_requests WHERE request_code = %s", (code,), ).fetchone() if not row: return jsonify({"error": "not found"}), 404 if (row.get("username") or "") != username: return jsonify({"error": "forbidden"}), 403 status = _normalize_status(row.get("status") or "") if status not in {"awaiting_onboarding", "ready"}: return jsonify({"error": "onboarding not available"}), 409 mark_done = True if isinstance(completed, bool): mark_done = completed if mark_done: conn.execute( """ INSERT INTO access_request_onboarding_steps (request_code, step) VALUES (%s, %s) ON CONFLICT (request_code, step) DO NOTHING """, (code, step), ) else: conn.execute( "DELETE FROM access_request_onboarding_steps WHERE request_code = %s AND step = %s", (code, step), ) # Re-evaluate completion to update request status to ready if applicable. status = _advance_status(conn, code, username, status) completed_steps = sorted(_completed_onboarding_steps(conn, code, username)) except Exception: return jsonify({"error": "failed to update onboarding"}), 502 return jsonify( { "ok": True, "status": status, "onboarding": {"required_steps": list(ONBOARDING_STEPS), "completed_steps": completed_steps}, } )