from __future__ import annotations import re import secrets import string from typing import Any from flask import jsonify, request, g import psycopg from ..db import connect, configured from ..keycloak import admin_client, require_auth from ..rate_limit import rate_limit_allow from ..provisioning import provision_tasks_complete from .. import settings def _extract_request_payload() -> tuple[str, str, str]: payload = request.get_json(silent=True) or {} username = (payload.get("username") or "").strip() email = (payload.get("email") or "").strip() note = (payload.get("note") or "").strip() return username, email, note def _random_request_code(username: str) -> str: suffix = "".join(secrets.choice(string.ascii_uppercase + string.digits) for _ in range(10)) return f"{username}~{suffix}" def _client_ip() -> str: xff = (request.headers.get("X-Forwarded-For") or "").strip() if xff: return xff.split(",", 1)[0].strip() or "unknown" x_real_ip = (request.headers.get("X-Real-IP") or "").strip() if x_real_ip: return x_real_ip return request.remote_addr or "unknown" ONBOARDING_STEPS: tuple[str, ...] = ( "keycloak_password_changed", "keycloak_mfa_configured", "vaultwarden_master_password", "element_recovery_key", "element_recovery_key_stored", ) KEYCLOAK_MANAGED_STEPS: set[str] = {"keycloak_password_changed", "keycloak_mfa_configured"} KEYCLOAK_OTP_CRED_TYPES: set[str] = {"otp", "totp"} def _normalize_status(status: str) -> str: cleaned = (status or "").strip().lower() if cleaned == "approved": return "accounts_building" return cleaned or "unknown" def _fetch_completed_onboarding_steps(conn, request_code: str) -> set[str]: rows = conn.execute( "SELECT step FROM access_request_onboarding_steps WHERE request_code = %s", (request_code,), ).fetchall() completed: set[str] = set() for row in rows: step = row.get("step") if isinstance(row, dict) else None if isinstance(step, str) and step: completed.add(step) return completed def _auto_completed_keycloak_steps(username: str) -> set[str]: if not username: return set() if not admin_client().ready(): return set() completed: set[str] = set() try: user = admin_client().find_user(username) or {} user_id = user.get("id") if isinstance(user, dict) else None if not isinstance(user_id, str) or not user_id: return set() full = {} try: full = admin_client().get_user(user_id) except Exception: full = user if isinstance(user, dict) else {} actions = full.get("requiredActions") required_actions: set[str] = set() if isinstance(actions, list): required_actions = {a for a in actions if isinstance(a, str)} if "UPDATE_PASSWORD" not in required_actions: completed.add("keycloak_password_changed") otp_present = False try: creds = admin_client().get_user_credentials(user_id) for cred in creds: ctype = cred.get("type") if isinstance(cred, dict) else None if isinstance(ctype, str) and ctype.lower() in KEYCLOAK_OTP_CRED_TYPES: otp_present = True break except Exception: otp_present = False if otp_present or "CONFIGURE_TOTP" not in required_actions: completed.add("keycloak_mfa_configured") except Exception: return set() return completed def _completed_onboarding_steps(conn, request_code: str, username: str) -> set[str]: completed = _fetch_completed_onboarding_steps(conn, request_code) return completed | _auto_completed_keycloak_steps(username) def _automation_ready(conn, request_code: str, username: str) -> bool: if not username: return False if not admin_client().ready(): return False # Prefer task-based readiness when we have task rows for the request. task_row = conn.execute( "SELECT 1 FROM access_request_tasks WHERE request_code = %s LIMIT 1", (request_code,), ).fetchone() if task_row: return provision_tasks_complete(conn, request_code) # Fallback for legacy requests: confirm user exists and has a mail app password. try: user = admin_client().find_user(username) if not user: return False user_id = user.get("id") if isinstance(user, dict) else None if not user_id: return False full = admin_client().get_user(str(user_id)) attrs = full.get("attributes") or {} if not isinstance(attrs, dict): return False raw_pw = attrs.get("mailu_app_password") if isinstance(raw_pw, list): return bool(raw_pw and isinstance(raw_pw[0], str) and raw_pw[0]) return bool(isinstance(raw_pw, str) and raw_pw) except Exception: return False def _advance_status(conn, request_code: str, username: str, status: str) -> str: status = _normalize_status(status) if status == "accounts_building" and _automation_ready(conn, request_code, username): conn.execute( "UPDATE access_requests SET status = 'awaiting_onboarding' WHERE request_code = %s AND status = 'accounts_building'", (request_code,), ) return "awaiting_onboarding" if status == "awaiting_onboarding": completed = _completed_onboarding_steps(conn, request_code, username) if set(ONBOARDING_STEPS).issubset(completed): conn.execute( "UPDATE access_requests SET status = 'ready' WHERE request_code = %s AND status = 'awaiting_onboarding'", (request_code,), ) return "ready" return status def register(app) -> None: @app.route("/api/access/request", methods=["POST"]) def request_access() -> Any: if not settings.ACCESS_REQUEST_ENABLED: return jsonify({"error": "request access disabled"}), 503 if not configured(): return jsonify({"error": "server not configured"}), 503 ip = _client_ip() username, email, note = _extract_request_payload() rate_key = ip if username: rate_key = f"{ip}:{username}" if not rate_limit_allow( rate_key, key="access_request_submit", limit=settings.ACCESS_REQUEST_SUBMIT_RATE_LIMIT, window_sec=settings.ACCESS_REQUEST_SUBMIT_RATE_WINDOW_SEC, ): return jsonify({"error": "rate limited"}), 429 if not username: return jsonify({"error": "username is required"}), 400 if len(username) < 3 or len(username) > 32: return jsonify({"error": "username must be 3-32 characters"}), 400 if not re.fullmatch(r"[a-zA-Z0-9._-]+", username): return jsonify({"error": "username contains invalid characters"}), 400 if email and "@" not in email: return jsonify({"error": "invalid email"}), 400 if admin_client().ready() and admin_client().find_user(username): return jsonify({"error": "username already exists"}), 409 try: with connect() as conn: existing = conn.execute( """ SELECT request_code, status FROM access_requests WHERE username = %s AND status = 'pending' ORDER BY created_at DESC LIMIT 1 """, (username,), ).fetchone() if existing: return jsonify({"ok": True, "request_code": existing["request_code"], "status": existing["status"]}) request_code = _random_request_code(username) try: conn.execute( """ INSERT INTO access_requests (request_code, username, contact_email, note, status) VALUES (%s, %s, %s, %s, 'pending') """, (request_code, username, email or None, note or None), ) except psycopg.errors.UniqueViolation: conn.rollback() existing = conn.execute( """ SELECT request_code, status FROM access_requests WHERE username = %s AND status = 'pending' ORDER BY created_at DESC LIMIT 1 """, (username,), ).fetchone() if not existing: raise return jsonify({"ok": True, "request_code": existing["request_code"], "status": existing["status"]}) except Exception: return jsonify({"error": "failed to submit request"}), 502 return jsonify({"ok": True, "request_code": request_code}) @app.route("/api/access/request/status", methods=["POST"]) def request_access_status() -> Any: if not settings.ACCESS_REQUEST_ENABLED: return jsonify({"error": "request access disabled"}), 503 if not configured(): return jsonify({"error": "server not configured"}), 503 ip = _client_ip() if not rate_limit_allow( ip, key="access_request_status", limit=settings.ACCESS_REQUEST_STATUS_RATE_LIMIT, window_sec=settings.ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC, ): return jsonify({"error": "rate limited"}), 429 payload = request.get_json(silent=True) or {} code = (payload.get("request_code") or payload.get("code") or "").strip() if not code: return jsonify({"error": "request_code is required"}), 400 # Additional per-code limiter to avoid global NAT rate-limit blowups. if not rate_limit_allow( f"{ip}:{code}", key="access_request_status_code", limit=max(20, settings.ACCESS_REQUEST_STATUS_RATE_LIMIT), window_sec=settings.ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC, ): return jsonify({"error": "rate limited"}), 429 try: with connect() as conn: row = conn.execute( "SELECT status, username, initial_password, initial_password_revealed_at FROM access_requests WHERE request_code = %s", (code,), ).fetchone() if not row: return jsonify({"error": "not found"}), 404 status = _advance_status(conn, code, row.get("username") or "", row.get("status") or "") response: dict[str, Any] = { "ok": True, "status": status, "username": row.get("username") or "", } if status in {"awaiting_onboarding", "ready"}: password = row.get("initial_password") revealed_at = row.get("initial_password_revealed_at") if isinstance(password, str) and password and revealed_at is None: response["initial_password"] = password conn.execute( "UPDATE access_requests SET initial_password_revealed_at = NOW() WHERE request_code = %s AND initial_password_revealed_at IS NULL", (code,), ) if status in {"awaiting_onboarding", "ready"}: response["onboarding_url"] = f"/onboarding?code={code}" if status in {"awaiting_onboarding", "ready"}: completed = sorted(_completed_onboarding_steps(conn, code, row.get("username") or "")) response["onboarding"] = { "required_steps": list(ONBOARDING_STEPS), "completed_steps": completed, } return jsonify(response) except Exception: return jsonify({"error": "failed to load status"}), 502 @app.route("/api/access/request/onboarding/attest", methods=["POST"]) @require_auth def request_access_onboarding_attest() -> Any: if not configured(): return jsonify({"error": "server not configured"}), 503 payload = request.get_json(silent=True) or {} code = (payload.get("request_code") or payload.get("code") or "").strip() step = (payload.get("step") or "").strip() completed = payload.get("completed") if not code: return jsonify({"error": "request_code is required"}), 400 if step not in ONBOARDING_STEPS: return jsonify({"error": "invalid step"}), 400 if step in KEYCLOAK_MANAGED_STEPS: return jsonify({"error": "step is managed by keycloak"}), 400 username = getattr(g, "keycloak_username", "") or "" if not username: return jsonify({"error": "invalid token"}), 401 try: with connect() as conn: row = conn.execute( "SELECT username, status FROM access_requests WHERE request_code = %s", (code,), ).fetchone() if not row: return jsonify({"error": "not found"}), 404 if (row.get("username") or "") != username: return jsonify({"error": "forbidden"}), 403 status = _normalize_status(row.get("status") or "") if status not in {"awaiting_onboarding", "ready"}: return jsonify({"error": "onboarding not available"}), 409 mark_done = True if isinstance(completed, bool): mark_done = completed if mark_done: conn.execute( """ INSERT INTO access_request_onboarding_steps (request_code, step) VALUES (%s, %s) ON CONFLICT (request_code, step) DO NOTHING """, (code, step), ) else: conn.execute( "DELETE FROM access_request_onboarding_steps WHERE request_code = %s AND step = %s", (code, step), ) # Re-evaluate completion to update request status to ready if applicable. status = _advance_status(conn, code, username, status) completed_steps = sorted(_completed_onboarding_steps(conn, code, username)) except Exception: return jsonify({"error": "failed to update onboarding"}), 502 return jsonify( { "ok": True, "status": status, "onboarding": {"required_steps": list(ONBOARDING_STEPS), "completed_steps": completed_steps}, } )