396 lines
15 KiB
Python

from __future__ import annotations
import re
import secrets
import string
from typing import Any
from flask import jsonify, request, g
import psycopg
from ..db import connect, configured
from ..keycloak import admin_client, require_auth
from ..rate_limit import rate_limit_allow
from ..provisioning import provision_tasks_complete
from .. import settings
def _extract_request_payload() -> tuple[str, str, str]:
    """Return the stripped ``(username, email, note)`` fields of the JSON body.

    Missing, ``null``, empty, or non-string fields become ``""``. A body that is
    not a JSON object (array, string, number, or unparseable) is treated as an
    empty object, so the caller's validation answers 400 instead of the handler
    raising ``AttributeError`` on ``.get``/``.strip``.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    def _field(key: str) -> str:
        # Non-string values are treated as absent rather than crashing on .strip().
        value = payload.get(key)
        return value.strip() if isinstance(value, str) else ""

    return _field("username"), _field("email"), _field("note")
def _random_request_code(username: str) -> str:
    """Build a request code of the form ``<username>~<10 random chars>``.

    The suffix is drawn with :mod:`secrets` from uppercase ASCII letters and
    digits.
    """
    alphabet = string.ascii_uppercase + string.digits
    picks = [secrets.choice(alphabet) for _ in range(10)]
    return username + "~" + "".join(picks)
def _client_ip() -> str:
    """Return a best-effort client address, used as the rate-limiting key.

    Precedence:
    1. ``X-Forwarded-For`` present: its first comma-separated entry (or
       ``"unknown"`` if that entry is blank).
    2. Otherwise ``X-Real-IP``.
    3. Otherwise the socket peer address.

    ``"unknown"`` is returned when nothing yields a value.

    NOTE(review): the left-most ``X-Forwarded-For`` entry is whatever the client
    sent unless a trusted proxy overwrites the header, so callers could rotate
    it to sidestep per-IP rate limits — confirm the ingress strips/overwrites it.
    """
    forwarded = (request.headers.get("X-Forwarded-For") or "").strip()
    if forwarded:
        first_hop, _, _ = forwarded.partition(",")
        return first_hop.strip() or "unknown"
    real_ip = (request.headers.get("X-Real-IP") or "").strip()
    return real_ip or request.remote_addr or "unknown"
# Every step a request must complete before it is promoted to "ready"; the
# order here is the order reported to clients as "required_steps".
ONBOARDING_STEPS: tuple[str, ...] = (
    "keycloak_password_changed",
    "keycloak_mfa_configured",
    "vaultwarden_master_password",
    "element_recovery_key",
    "element_recovery_key_stored",
)

# Steps inferred from Keycloak user state instead of user attestation.
KEYCLOAK_MANAGED_STEPS: set[str] = set(("keycloak_password_changed", "keycloak_mfa_configured"))

# Keycloak credential "type" values (compared lower-cased) that count as MFA.
KEYCLOAK_OTP_CRED_TYPES: set[str] = set(("otp", "totp"))
def _normalize_status(status: str) -> str:
    """Canonicalize a stored status string.

    Whitespace is stripped and the value lower-cased. The legacy ``approved``
    value maps to ``accounts_building``. An empty or ``None`` status becomes
    ``unknown``.
    """
    cleaned = (status or "").strip().lower()
    if not cleaned:
        return "unknown"
    return "accounts_building" if cleaned == "approved" else cleaned
def _fetch_completed_onboarding_steps(conn, request_code: str) -> set[str]:
    """Return the onboarding steps recorded in the database for ``request_code``.

    Only dict rows whose ``step`` is a non-empty string are kept; anything
    else is silently ignored.
    """
    cursor = conn.execute(
        "SELECT step FROM access_request_onboarding_steps WHERE request_code = %s",
        (request_code,),
    )
    candidates = (row.get("step") for row in cursor.fetchall() if isinstance(row, dict))
    return {step for step in candidates if isinstance(step, str) and step}
def _auto_completed_keycloak_steps(username: str) -> set[str]:
    """Infer which Keycloak-managed onboarding steps ``username`` has completed.

    Looks the user up through the Keycloak admin client and derives:

    * ``keycloak_password_changed`` when ``UPDATE_PASSWORD`` is not among the
      user's required actions;
    * ``keycloak_mfa_configured`` when the user has a credential whose type is
      in ``KEYCLOAK_OTP_CRED_TYPES``, or ``CONFIGURE_TOTP`` is not among the
      required actions.

    Returns an empty set when ``username`` is empty, the admin client is not
    ready, the user (or its string id) cannot be found, or any error escapes
    the inner fallbacks.
    """
    if not username:
        return set()
    if not admin_client().ready():
        return set()
    completed: set[str] = set()
    try:
        # ``or {}`` keeps the ``.get`` below safe when find_user returns a falsy value.
        user = admin_client().find_user(username) or {}
        user_id = user.get("id") if isinstance(user, dict) else None
        if not isinstance(user_id, str) or not user_id:
            return set()
        full = {}
        try:
            full = admin_client().get_user(user_id)
        except Exception:
            # Fall back to the search result when the detailed fetch fails.
            full = user if isinstance(user, dict) else {}
        actions = full.get("requiredActions")
        required_actions: set[str] = set()
        if isinstance(actions, list):
            required_actions = {a for a in actions if isinstance(a, str)}
        # NOTE(review): if the user representation lacks a ``requiredActions``
        # list, required_actions stays empty and BOTH steps below are reported
        # complete (fail-open) — confirm this is intended.
        if "UPDATE_PASSWORD" not in required_actions:
            completed.add("keycloak_password_changed")
        otp_present = False
        try:
            creds = admin_client().get_user_credentials(user_id)
            for cred in creds:
                ctype = cred.get("type") if isinstance(cred, dict) else None
                if isinstance(ctype, str) and ctype.lower() in KEYCLOAK_OTP_CRED_TYPES:
                    otp_present = True
                    break
        except Exception:
            # A failed credential lookup is non-fatal: decide on required actions alone.
            otp_present = False
        # MFA counts as configured if an OTP credential exists, or if Keycloak is
        # not requiring CONFIGURE_TOTP for this user.
        if otp_present or "CONFIGURE_TOTP" not in required_actions:
            completed.add("keycloak_mfa_configured")
    except Exception:
        # Any other failure (e.g. get_user returning a non-dict, so ``full.get``
        # raises) yields an empty set.
        return set()
    return completed
def _completed_onboarding_steps(conn, request_code: str, username: str) -> set[str]:
    """Return every completed onboarding step for a request.

    This is the union of the steps attested in the database and the steps
    inferred from the user's Keycloak state. The database is queried first,
    then Keycloak.
    """
    from_database = _fetch_completed_onboarding_steps(conn, request_code)
    from_keycloak = _auto_completed_keycloak_steps(username)
    return from_database.union(from_keycloak)
def _automation_ready(conn, request_code: str, username: str) -> bool:
    """Report whether account provisioning for a request has finished.

    Returns ``False`` when ``username`` is empty or the Keycloak admin client is
    not ready.

    If ``access_request_tasks`` has any row for the request, the answer comes
    from ``provision_tasks_complete``. Otherwise (legacy requests) the request
    counts as provisioned when the Keycloak user has a non-empty
    ``mailu_app_password`` attribute. For a list-valued attribute, only the
    first element is checked.

    Any Keycloak lookup error in the legacy path yields ``False``. Database
    errors from the task query propagate.
    """
    if not username or not admin_client().ready():
        return False

    has_tasks = conn.execute(
        "SELECT 1 FROM access_request_tasks WHERE request_code = %s LIMIT 1",
        (request_code,),
    ).fetchone()
    if has_tasks:
        # Task rows exist: the provisioning task table is authoritative.
        return provision_tasks_complete(conn, request_code)

    # Legacy path: inspect the Keycloak user's mail app password attribute.
    try:
        found = admin_client().find_user(username)
        user_id = found.get("id") if isinstance(found, dict) and found else None
        if not user_id:
            return False
        attributes = admin_client().get_user(str(user_id)).get("attributes") or {}
        if not isinstance(attributes, dict):
            return False
        value = attributes.get("mailu_app_password")
        if isinstance(value, list):
            value = value[0] if value else None
        return isinstance(value, str) and bool(value)
    except Exception:
        return False
def _advance_status(conn, request_code: str, username: str, status: str) -> str:
    """Advance a request at most one lifecycle step and return its status.

    ``status`` is the request's current status, stored or already normalized.
    It is normalized first, so legacy ``approved`` rows are handled as
    ``accounts_building``.

    Transitions:

    * ``accounts_building`` -> ``awaiting_onboarding`` once
      ``_automation_ready`` reports provisioning done.
    * ``awaiting_onboarding`` -> ``ready`` once every entry of
      ``ONBOARDING_STEPS`` is in ``_completed_onboarding_steps``.

    Each transition is persisted with an ``UPDATE`` guarded on the stored
    status. Any other status is returned normalized and unchanged.
    """
    status = _normalize_status(status)
    if status == "accounts_building" and _automation_ready(conn, request_code, username):
        # Legacy 'approved' rows normalize to accounts_building, so they must be
        # matched too. Guarding on 'accounts_building' alone would leave such a
        # row stuck at 'approved' while this function reports awaiting_onboarding.
        conn.execute(
            "UPDATE access_requests SET status = 'awaiting_onboarding' "
            "WHERE request_code = %s AND status IN ('accounts_building', 'approved')",
            (request_code,),
        )
        return "awaiting_onboarding"
    if status == "awaiting_onboarding":
        completed = _completed_onboarding_steps(conn, request_code, username)
        if set(ONBOARDING_STEPS).issubset(completed):
            conn.execute(
                "UPDATE access_requests SET status = 'ready' WHERE request_code = %s AND status = 'awaiting_onboarding'",
                (request_code,),
            )
            return "ready"
    return status
def register(app) -> None:
    """Attach the self-service access-request endpoints to ``app``.

    All routes are ``POST`` with JSON in and JSON out:

    * ``/api/access/request`` — submit a request for ``username`` (optional
      ``email`` and ``note``); returns the request's ``request_code``.
    * ``/api/access/request/status`` — poll a request by ``request_code``
      (alias ``code``). Advances its status, reveals the initial password at
      most once, and reports onboarding progress.
    * ``/api/access/request/onboarding/attest`` — requires auth. Lets the
      owning user mark or unmark a non-Keycloak onboarding step.
    """

    def _json_object() -> dict[str, Any]:
        # Non-object JSON bodies (arrays, strings, numbers) and unparseable
        # bodies are treated as empty. Validation then answers 400 instead of
        # the handler raising AttributeError on ``.get``.
        payload = request.get_json(silent=True)
        return payload if isinstance(payload, dict) else {}

    def _text_field(payload: dict[str, Any], *keys: str) -> str:
        # First truthy value among ``keys``, stripped. A non-string value is
        # treated as missing instead of crashing on ``.strip()``.
        for key in keys:
            value = payload.get(key)
            if value:
                return value.strip() if isinstance(value, str) else ""
        return ""

    @app.route("/api/access/request", methods=["POST"])
    def request_access() -> Any:
        if not settings.ACCESS_REQUEST_ENABLED:
            return jsonify({"error": "request access disabled"}), 503
        if not configured():
            return jsonify({"error": "server not configured"}), 503

        ip = _client_ip()
        username, email, note = _extract_request_payload()

        # Key the limiter on ip+username when a username was supplied.
        rate_key = f"{ip}:{username}" if username else ip
        if not rate_limit_allow(
            rate_key,
            key="access_request_submit",
            limit=settings.ACCESS_REQUEST_SUBMIT_RATE_LIMIT,
            window_sec=settings.ACCESS_REQUEST_SUBMIT_RATE_WINDOW_SEC,
        ):
            return jsonify({"error": "rate limited"}), 429

        if not username:
            return jsonify({"error": "username is required"}), 400
        if len(username) < 3 or len(username) > 32:
            return jsonify({"error": "username must be 3-32 characters"}), 400
        if not re.fullmatch(r"[a-zA-Z0-9._-]+", username):
            return jsonify({"error": "username contains invalid characters"}), 400
        if email and "@" not in email:
            return jsonify({"error": "invalid email"}), 400

        # A Keycloak failure gets the same JSON 502 as a database failure
        # instead of escaping as an unhandled exception.
        try:
            if admin_client().ready() and admin_client().find_user(username):
                return jsonify({"error": "username already exists"}), 409
        except Exception:
            return jsonify({"error": "failed to submit request"}), 502

        pending_sql = """
            SELECT request_code, status
            FROM access_requests
            WHERE username = %s AND status = 'pending'
            ORDER BY created_at DESC
            LIMIT 1
        """
        try:
            with connect() as conn:
                # NOTE(review): resubmitting a username that already has a
                # pending request returns that request's code to ANY caller. The
                # status endpoint reveals the initial password to whoever holds
                # the code, so confirm this disclosure is acceptable.
                existing = conn.execute(pending_sql, (username,)).fetchone()
                if existing:
                    return jsonify({"ok": True, "request_code": existing["request_code"], "status": existing["status"]})

                request_code = _random_request_code(username)
                try:
                    conn.execute(
                        """
                        INSERT INTO access_requests
                            (request_code, username, contact_email, note, status)
                        VALUES
                            (%s, %s, %s, %s, 'pending')
                        """,
                        (request_code, username, email or None, note or None),
                    )
                except psycopg.errors.UniqueViolation:
                    # Presumably a concurrent submission for the same username
                    # won the insert (the unique constraint is not visible
                    # here — confirm). Return the pending request if one exists.
                    conn.rollback()
                    existing = conn.execute(pending_sql, (username,)).fetchone()
                    if not existing:
                        raise
                    return jsonify({"ok": True, "request_code": existing["request_code"], "status": existing["status"]})
        except Exception:
            return jsonify({"error": "failed to submit request"}), 502

        return jsonify({"ok": True, "request_code": request_code, "status": "pending"})

    @app.route("/api/access/request/status", methods=["POST"])
    def request_access_status() -> Any:
        if not settings.ACCESS_REQUEST_ENABLED:
            return jsonify({"error": "request access disabled"}), 503
        if not configured():
            return jsonify({"error": "server not configured"}), 503

        ip = _client_ip()
        if not rate_limit_allow(
            ip,
            key="access_request_status",
            limit=settings.ACCESS_REQUEST_STATUS_RATE_LIMIT,
            window_sec=settings.ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC,
        ):
            return jsonify({"error": "rate limited"}), 429

        code = _text_field(_json_object(), "request_code", "code")
        if not code:
            return jsonify({"error": "request_code is required"}), 400

        # Additional per-code limiter to avoid global NAT rate-limit blowups.
        if not rate_limit_allow(
            f"{ip}:{code}",
            key="access_request_status_code",
            limit=max(20, settings.ACCESS_REQUEST_STATUS_RATE_LIMIT),
            window_sec=settings.ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC,
        ):
            return jsonify({"error": "rate limited"}), 429

        try:
            with connect() as conn:
                row = conn.execute(
                    "SELECT status, username, initial_password, initial_password_revealed_at FROM access_requests WHERE request_code = %s",
                    (code,),
                ).fetchone()
                if not row:
                    return jsonify({"error": "not found"}), 404

                username = row.get("username") or ""
                status = _advance_status(conn, code, username, row.get("status") or "")
                response: dict[str, Any] = {
                    "ok": True,
                    "status": status,
                    "username": username,
                }

                if status in {"awaiting_onboarding", "ready"}:
                    password = row.get("initial_password")
                    if isinstance(password, str) and password and row.get("initial_password_revealed_at") is None:
                        # Claim the one-time reveal atomically. Only the caller
                        # whose UPDATE flips revealed_at from NULL gets a row
                        # back, so concurrent polls cannot both receive the
                        # password. The earlier SELECT alone is a check-then-act race.
                        claimed = conn.execute(
                            "UPDATE access_requests SET initial_password_revealed_at = NOW() "
                            "WHERE request_code = %s AND initial_password_revealed_at IS NULL "
                            "RETURNING initial_password",
                            (code,),
                        ).fetchone()
                        revealed = claimed.get("initial_password") if claimed else None
                        if isinstance(revealed, str) and revealed:
                            response["initial_password"] = revealed
                    response["onboarding_url"] = f"/onboarding?code={code}"
                    response["onboarding"] = {
                        "required_steps": list(ONBOARDING_STEPS),
                        "completed_steps": sorted(_completed_onboarding_steps(conn, code, username)),
                    }
                return jsonify(response)
        except Exception:
            return jsonify({"error": "failed to load status"}), 502

    @app.route("/api/access/request/onboarding/attest", methods=["POST"])
    @require_auth
    def request_access_onboarding_attest() -> Any:
        if not configured():
            return jsonify({"error": "server not configured"}), 503

        payload = _json_object()
        code = _text_field(payload, "request_code", "code")
        step = _text_field(payload, "step")
        completed = payload.get("completed")
        if not code:
            return jsonify({"error": "request_code is required"}), 400
        if step not in ONBOARDING_STEPS:
            return jsonify({"error": "invalid step"}), 400
        if step in KEYCLOAK_MANAGED_STEPS:
            # These are inferred from Keycloak state, not attested by the user.
            return jsonify({"error": "step is managed by keycloak"}), 400

        username = getattr(g, "keycloak_username", "") or ""
        if not username:
            return jsonify({"error": "invalid token"}), 401

        try:
            with connect() as conn:
                row = conn.execute(
                    "SELECT username, status FROM access_requests WHERE request_code = %s",
                    (code,),
                ).fetchone()
                if not row:
                    return jsonify({"error": "not found"}), 404
                if (row.get("username") or "") != username:
                    return jsonify({"error": "forbidden"}), 403

                status = _normalize_status(row.get("status") or "")
                if status not in {"awaiting_onboarding", "ready"}:
                    return jsonify({"error": "onboarding not available"}), 409

                # Default is "mark done"; only an explicit JSON boolean can un-mark a step.
                mark_done = completed if isinstance(completed, bool) else True
                if mark_done:
                    conn.execute(
                        """
                        INSERT INTO access_request_onboarding_steps (request_code, step)
                        VALUES (%s, %s)
                        ON CONFLICT (request_code, step) DO NOTHING
                        """,
                        (code, step),
                    )
                else:
                    conn.execute(
                        "DELETE FROM access_request_onboarding_steps WHERE request_code = %s AND step = %s",
                        (code, step),
                    )

                # Re-evaluate completion to update request status to ready if applicable.
                status = _advance_status(conn, code, username, status)
                completed_steps = sorted(_completed_onboarding_steps(conn, code, username))
        except Exception:
            return jsonify({"error": "failed to update onboarding"}), 502

        return jsonify(
            {
                "ok": True,
                "status": status,
                "onboarding": {"required_steps": list(ONBOARDING_STEPS), "completed_steps": completed_steps},
            }
        )