542 lines
22 KiB
Python

from __future__ import annotations
from datetime import datetime, timezone
import hashlib
import hmac
import re
import secrets
import string
from typing import Any
from urllib.parse import quote
from flask import jsonify, request, g
import psycopg
from ..db import connect, configured
from ..keycloak import admin_client, require_auth
from ..mailer import MailerError, access_request_verification_body, send_text_email
from ..rate_limit import rate_limit_allow
from ..provisioning import provision_tasks_complete
from .. import settings
def _extract_request_payload() -> tuple[str, str, str]:
    """Read ``username``, ``email`` and ``note`` from the request's JSON body.

    Returns:
        A ``(username, email, note)`` tuple of stripped strings. A field that
        is missing, falsy, or not a string is returned as ``""`` so the caller
        can answer with a validation error instead of raising.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        # Valid JSON that is not an object (list, string, number) has no fields.
        payload = {}

    def _text(key: str) -> str:
        value = payload.get(key)
        # Untrusted input: a number/list/object here would break ``.strip()``.
        return value.strip() if isinstance(value, str) else ""

    return _text("username"), _text("email"), _text("note")
def _random_request_code(username: str) -> str:
    """Build a request code of the form ``<username>~<suffix>``.

    The suffix is 10 characters drawn from uppercase ASCII letters and digits
    using the ``secrets`` module.
    """
    alphabet = string.ascii_uppercase + string.digits
    suffix_chars = [secrets.choice(alphabet) for _ in range(10)]
    return username + "~" + "".join(suffix_chars)
def _client_ip() -> str:
    """Return a best-effort client address for the current request.

    Checks, in order: the first comma-separated entry of ``X-Forwarded-For``,
    the ``X-Real-IP`` header, then ``request.remote_addr``. Falls back to the
    literal ``"unknown"`` when nothing usable is present.
    """
    forwarded_for = (request.headers.get("X-Forwarded-For") or "").strip()
    if forwarded_for:
        first_hop, _, _ = forwarded_for.partition(",")
        return first_hop.strip() or "unknown"
    real_ip = (request.headers.get("X-Real-IP") or "").strip()
    return real_ip or request.remote_addr or "unknown"
# Status given to a request when the submit endpoint inserts it; the verify
# endpoint moves it to 'pending' once the emailed token checks out.
EMAIL_VERIFY_PENDING_STATUS = "pending_email_verification"
def _hash_verification_token(token: str) -> str:
    """Return the hexadecimal SHA-256 digest of ``token`` encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(token.encode("utf-8"))
    return hasher.hexdigest()
def _verify_url(request_code: str, token: str) -> str:
    """Build the email-verification link for a request.

    The link is ``settings.PORTAL_PUBLIC_BASE_URL`` (trailing slashes removed)
    followed by ``/request-access`` with the URL-quoted request code and token
    as the ``code`` and ``verify`` query parameters.
    """
    root = settings.PORTAL_PUBLIC_BASE_URL.rstrip("/")
    quoted_code = quote(request_code)
    quoted_token = quote(token)
    return "{}/request-access?code={}&verify={}".format(root, quoted_code, quoted_token)
# Every step that must be completed before _advance_status moves a request
# from 'awaiting_onboarding' to 'ready'. Also returned to clients as
# "required_steps".
ONBOARDING_STEPS: tuple[str, ...] = (
    "keycloak_password_changed",
    "keycloak_mfa_configured",
    "vaultwarden_master_password",
    "element_recovery_key",
    "element_recovery_key_stored",
)
# Steps derived from the user's Keycloak account by
# _auto_completed_keycloak_steps; the attest endpoint refuses to set them.
KEYCLOAK_MANAGED_STEPS: set[str] = {"keycloak_password_changed", "keycloak_mfa_configured"}
# Credential "type" values (compared lower-cased) that count as an OTP credential.
KEYCLOAK_OTP_CRED_TYPES: set[str] = {"otp", "totp"}
def _normalize_status(status: str) -> str:
    """Canonicalize a stored status string.

    The value is stripped and lower-cased; ``"approved"`` is reported as
    ``"accounts_building"`` and an empty or missing value as ``"unknown"``.
    """
    normalized = (status or "").strip().lower()
    if not normalized:
        return "unknown"
    return "accounts_building" if normalized == "approved" else normalized
def _fetch_completed_onboarding_steps(conn, request_code: str) -> set[str]:
    """Return the onboarding steps recorded in the database for a request.

    Args:
        conn: Database connection exposing ``execute(sql, params)``.
        request_code: Code of the access request to look up.

    Returns:
        The set of non-empty string ``step`` values. Rows that are not dicts,
        and step values that are empty or not strings, are ignored.
    """
    cursor = conn.execute(
        "SELECT step FROM access_request_onboarding_steps WHERE request_code = %s",
        (request_code,),
    )
    candidates = (row.get("step") for row in cursor.fetchall() if isinstance(row, dict))
    return {step for step in candidates if isinstance(step, str) and step}
def _auto_completed_keycloak_steps(username: str) -> set[str]:
    """Derive the Keycloak-managed onboarding steps from the user's account.

    ``keycloak_password_changed`` is reported when ``UPDATE_PASSWORD`` is not
    among the user's required actions. ``keycloak_mfa_configured`` is reported
    when the user has a credential whose type is in
    ``KEYCLOAK_OTP_CRED_TYPES`` or when ``CONFIGURE_TOTP`` is not a required
    action.

    Returns:
        The set of derived step names. Empty when ``username`` is empty, the
        admin client is not ready, the user cannot be resolved to a string id,
        or any unexpected error occurs.
    """
    if not username or not admin_client().ready():
        return set()

    done: set[str] = set()
    try:
        found = admin_client().find_user(username) or {}
        found_is_dict = isinstance(found, dict)
        user_id = found.get("id") if found_is_dict else None
        if not (isinstance(user_id, str) and user_id):
            return set()

        try:
            profile = admin_client().get_user(user_id)
        except Exception:
            # Fall back to the search result when the full lookup fails.
            profile = found if found_is_dict else {}

        raw_actions = profile.get("requiredActions")
        pending_actions: set[str] = (
            {action for action in raw_actions if isinstance(action, str)}
            if isinstance(raw_actions, list)
            else set()
        )

        if "UPDATE_PASSWORD" not in pending_actions:
            done.add("keycloak_password_changed")

        try:
            cred_types = (
                cred.get("type") if isinstance(cred, dict) else None
                for cred in admin_client().get_user_credentials(user_id)
            )
            has_otp = any(
                isinstance(ctype, str) and ctype.lower() in KEYCLOAK_OTP_CRED_TYPES
                for ctype in cred_types
            )
        except Exception:
            # A failed credential lookup counts as "no OTP credential found".
            has_otp = False

        if has_otp or "CONFIGURE_TOTP" not in pending_actions:
            done.add("keycloak_mfa_configured")
    except Exception:
        return set()
    return done
def _completed_onboarding_steps(conn, request_code: str, username: str) -> set[str]:
    """Return the steps recorded in the database plus those derived from Keycloak."""
    recorded = _fetch_completed_onboarding_steps(conn, request_code)
    derived = _auto_completed_keycloak_steps(username)
    return recorded.union(derived)
def _automation_ready(conn, request_code: str, username: str) -> bool:
    """Report whether account provisioning for a request appears finished.

    When the request has rows in ``access_request_tasks`` the answer comes
    from ``provision_tasks_complete``. Otherwise (legacy requests) the user
    must exist in Keycloak and carry a non-empty ``mailu_app_password``
    attribute.

    Returns:
        ``False`` when ``username`` is empty, the admin client is not ready,
        or the legacy Keycloak lookup fails for any reason.
    """
    if not (username and admin_client().ready()):
        return False

    # Prefer task-based readiness when we have task rows for the request.
    has_tasks = conn.execute(
        "SELECT 1 FROM access_request_tasks WHERE request_code = %s LIMIT 1",
        (request_code,),
    ).fetchone()
    if has_tasks:
        return provision_tasks_complete(conn, request_code)

    # Fallback for legacy requests: confirm user exists and has a mail app password.
    try:
        found = admin_client().find_user(username)
        if not found:
            return False
        user_id = found.get("id") if isinstance(found, dict) else None
        if not user_id:
            return False
        attributes = admin_client().get_user(str(user_id)).get("attributes") or {}
        if not isinstance(attributes, dict):
            return False
        app_password = attributes.get("mailu_app_password")
        if isinstance(app_password, list):
            # Attribute values may arrive as a list; only the first entry counts.
            app_password = app_password[0] if app_password else None
        return isinstance(app_password, str) and bool(app_password)
    except Exception:
        return False
def _advance_status(conn, request_code: str, username: str, status: str) -> str:
    """Advance a request's status by at most one stage and return the result.

    ``accounts_building`` becomes ``awaiting_onboarding`` once
    ``_automation_ready`` reports true; ``awaiting_onboarding`` becomes
    ``ready`` once every entry of ``ONBOARDING_STEPS`` is completed. Each
    UPDATE is guarded on the current status. Any other status is returned
    normalized and unchanged.
    """
    current = _normalize_status(status)

    if current == "accounts_building":
        if not _automation_ready(conn, request_code, username):
            return current
        conn.execute(
            "UPDATE access_requests SET status = 'awaiting_onboarding' WHERE request_code = %s AND status = 'accounts_building'",
            (request_code,),
        )
        return "awaiting_onboarding"

    if current == "awaiting_onboarding":
        done = _completed_onboarding_steps(conn, request_code, username)
        if all(step in done for step in ONBOARDING_STEPS):
            conn.execute(
                "UPDATE access_requests SET status = 'ready' WHERE request_code = %s AND status = 'awaiting_onboarding'",
                (request_code,),
            )
            return "ready"

    return current
def _json_object_payload() -> dict[str, Any]:
    """Return the request's JSON body if it is an object, otherwise ``{}``."""
    payload = request.get_json(silent=True)
    return payload if isinstance(payload, dict) else {}


def _payload_text(payload: dict[str, Any], *keys: str) -> str:
    """Return the first truthy value among ``keys``, stripped.

    A truthy value that is not a string yields ``""`` so that untrusted input
    such as ``{"code": 123}`` ends in a validation error instead of an
    ``AttributeError``.
    """
    for key in keys:
        value = payload.get(key)
        if value:
            return value.strip() if isinstance(value, str) else ""
    return ""


def _find_open_request(conn, username: str) -> Any:
    """Return the newest request for ``username`` that is still unverified or pending."""
    return conn.execute(
        """
        SELECT request_code, status
        FROM access_requests
        WHERE username = %s AND status IN (%s, 'pending')
        ORDER BY created_at DESC
        LIMIT 1
        """,
        (username, EMAIL_VERIFY_PENDING_STATUS),
    ).fetchone()


def _send_verification_email(email: str, request_code: str, token: str) -> bool:
    """Email the verification link; return ``False`` when the mailer fails."""
    verify_url = _verify_url(request_code, token)
    try:
        send_text_email(
            to_addr=email,
            subject="Atlas: confirm your email",
            body=access_request_verification_body(request_code=request_code, verify_url=verify_url),
        )
    except MailerError:
        return False
    return True


def register(app) -> None:
    """Attach the access-request endpoints to ``app``.

    Routes registered:
        POST /api/access/request                    submit a request
        POST /api/access/request/verify             confirm the emailed token
        POST /api/access/request/status             poll request status
        POST /api/access/request/onboarding/attest  record an onboarding step
    """

    @app.route("/api/access/request", methods=["POST"])
    def request_access() -> Any:
        """Create (or refresh) an access request and email a verification link."""
        if not settings.ACCESS_REQUEST_ENABLED:
            return jsonify({"error": "request access disabled"}), 503
        if not configured():
            return jsonify({"error": "server not configured"}), 503

        ip = _client_ip()
        username, email, note = _extract_request_payload()

        rate_key = f"{ip}:{username}" if username else ip
        if not rate_limit_allow(
            rate_key,
            key="access_request_submit",
            limit=settings.ACCESS_REQUEST_SUBMIT_RATE_LIMIT,
            window_sec=settings.ACCESS_REQUEST_SUBMIT_RATE_WINDOW_SEC,
        ):
            return jsonify({"error": "rate limited"}), 429

        if not username:
            return jsonify({"error": "username is required"}), 400
        if len(username) < 3 or len(username) > 32:
            return jsonify({"error": "username must be 3-32 characters"}), 400
        if not re.fullmatch(r"[a-zA-Z0-9._-]+", username):
            return jsonify({"error": "username contains invalid characters"}), 400
        if not email:
            return jsonify({"error": "email is required"}), 400
        if "@" not in email:
            return jsonify({"error": "invalid email"}), 400
        if email.lower().endswith(f"@{settings.MAILU_DOMAIN.lower()}"):
            return jsonify({"error": "email must be an external address"}), 400

        if admin_client().ready() and admin_client().find_user(username):
            return jsonify({"error": "username already exists"}), 409

        try:
            with connect() as conn:
                existing = _find_open_request(conn, username)
                if existing:
                    existing_status = str(existing.get("status") or "")
                    request_code = str(existing.get("request_code") or "")
                    if existing_status != EMAIL_VERIFY_PENDING_STATUS:
                        # Already verified: report it, do not re-send anything.
                        return jsonify({"ok": True, "request_code": request_code, "status": existing_status})
                    # Still unverified: rotate the token and refresh contact details.
                    token = secrets.token_urlsafe(24)
                    conn.execute(
                        """
                        UPDATE access_requests
                        SET contact_email = %s,
                            note = %s,
                            email_verification_token_hash = %s,
                            email_verification_sent_at = NOW(),
                            email_verified_at = NULL
                        WHERE request_code = %s AND status = %s
                        """,
                        (
                            email,
                            note or None,
                            _hash_verification_token(token),
                            request_code,
                            EMAIL_VERIFY_PENDING_STATUS,
                        ),
                    )
                else:
                    request_code = _random_request_code(username)
                    token = secrets.token_urlsafe(24)
                    try:
                        conn.execute(
                            """
                            INSERT INTO access_requests
                                (request_code, username, contact_email, note, status,
                                 email_verification_token_hash, email_verification_sent_at)
                            VALUES
                                (%s, %s, %s, %s, %s, %s, NOW())
                            """,
                            (
                                request_code,
                                username,
                                email,
                                note or None,
                                EMAIL_VERIFY_PENDING_STATUS,
                                _hash_verification_token(token),
                            ),
                        )
                    except psycopg.errors.UniqueViolation:
                        # A concurrent submit won the insert; report that request.
                        conn.rollback()
                        existing = _find_open_request(conn, username)
                        if not existing:
                            raise
                        return jsonify(
                            {"ok": True, "request_code": existing["request_code"], "status": existing["status"]}
                        )

                if not _send_verification_email(email, request_code, token):
                    return (
                        jsonify({"error": "failed to send verification email", "request_code": request_code}),
                        502,
                    )
        except Exception:
            return jsonify({"error": "failed to submit request"}), 502

        return jsonify({"ok": True, "request_code": request_code, "status": EMAIL_VERIFY_PENDING_STATUS})

    @app.route("/api/access/request/verify", methods=["POST"])
    def request_access_verify() -> Any:
        """Check the emailed token and move the request to ``pending``."""
        if not settings.ACCESS_REQUEST_ENABLED:
            return jsonify({"error": "request access disabled"}), 503
        if not configured():
            return jsonify({"error": "server not configured"}), 503

        ip = _client_ip()
        if not rate_limit_allow(
            ip,
            key="access_request_verify",
            limit=60,
            window_sec=60,
        ):
            return jsonify({"error": "rate limited"}), 429

        payload = _json_object_payload()
        code = _payload_text(payload, "request_code", "code")
        token = _payload_text(payload, "token", "verify")
        if not code:
            return jsonify({"error": "request_code is required"}), 400
        if not token:
            return jsonify({"error": "token is required"}), 400

        if not rate_limit_allow(
            f"{ip}:{code}",
            key="access_request_verify_code",
            limit=30,
            window_sec=60,
        ):
            return jsonify({"error": "rate limited"}), 429

        try:
            with connect() as conn:
                row = conn.execute(
                    """
                    SELECT status, email_verification_token_hash, email_verification_sent_at, email_verified_at
                    FROM access_requests
                    WHERE request_code = %s
                    """,
                    (code,),
                ).fetchone()
                if not row:
                    return jsonify({"error": "not found"}), 404

                status = _normalize_status(row.get("status") or "")
                if status != EMAIL_VERIFY_PENDING_STATUS:
                    # Nothing left to verify; report the current status.
                    return jsonify({"ok": True, "status": status})

                stored_hash = str(row.get("email_verification_token_hash") or "")
                if not stored_hash:
                    return jsonify({"error": "verification token missing"}), 409
                provided_hash = _hash_verification_token(token)
                # Constant-time comparison of the two digests.
                if not hmac.compare_digest(stored_hash, provided_hash):
                    return jsonify({"error": "invalid token"}), 401

                sent_at = row.get("email_verification_sent_at")
                if isinstance(sent_at, datetime):
                    now = datetime.now(timezone.utc)
                    if sent_at.tzinfo is None:
                        # Naive timestamps are interpreted as UTC.
                        sent_at = sent_at.replace(tzinfo=timezone.utc)
                    age_sec = (now - sent_at).total_seconds()
                    if age_sec > settings.ACCESS_REQUEST_EMAIL_VERIFY_TTL_SEC:
                        return jsonify({"error": "verification token expired"}), 410

                conn.execute(
                    """
                    UPDATE access_requests
                    SET status = 'pending',
                        email_verified_at = NOW(),
                        email_verification_token_hash = NULL
                    WHERE request_code = %s AND status = %s
                    """,
                    (code, EMAIL_VERIFY_PENDING_STATUS),
                )
                return jsonify({"ok": True, "status": "pending"})
        except Exception:
            return jsonify({"error": "failed to verify"}), 502

    @app.route("/api/access/request/status", methods=["POST"])
    def request_access_status() -> Any:
        """Report a request's status, advancing it when its conditions are met."""
        if not settings.ACCESS_REQUEST_ENABLED:
            return jsonify({"error": "request access disabled"}), 503
        if not configured():
            return jsonify({"error": "server not configured"}), 503

        ip = _client_ip()
        if not rate_limit_allow(
            ip,
            key="access_request_status",
            limit=settings.ACCESS_REQUEST_STATUS_RATE_LIMIT,
            window_sec=settings.ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC,
        ):
            return jsonify({"error": "rate limited"}), 429

        code = _payload_text(_json_object_payload(), "request_code", "code")
        if not code:
            return jsonify({"error": "request_code is required"}), 400

        # Additional per-code limiter to avoid global NAT rate-limit blowups.
        if not rate_limit_allow(
            f"{ip}:{code}",
            key="access_request_status_code",
            limit=max(20, settings.ACCESS_REQUEST_STATUS_RATE_LIMIT),
            window_sec=settings.ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC,
        ):
            return jsonify({"error": "rate limited"}), 429

        try:
            with connect() as conn:
                row = conn.execute(
                    "SELECT status, username, initial_password, initial_password_revealed_at FROM access_requests WHERE request_code = %s",
                    (code,),
                ).fetchone()
                if not row:
                    return jsonify({"error": "not found"}), 404

                username = row.get("username") or ""
                status = _advance_status(conn, code, username, row.get("status") or "")
                response: dict[str, Any] = {
                    "ok": True,
                    "status": status,
                    "username": username,
                }

                if status in {"awaiting_onboarding", "ready"}:
                    password = row.get("initial_password")
                    unrevealed = row.get("initial_password_revealed_at") is None
                    if isinstance(password, str) and password and unrevealed:
                        # Claim the reveal first and disclose the password only
                        # if this request's UPDATE matched a row, so concurrent
                        # polls cannot both receive it.
                        claimed = conn.execute(
                            "UPDATE access_requests SET initial_password_revealed_at = NOW() WHERE request_code = %s AND initial_password_revealed_at IS NULL",
                            (code,),
                        )
                        if claimed.rowcount > 0:
                            response["initial_password"] = password

                    # Quote the code the same way _verify_url does.
                    response["onboarding_url"] = f"/onboarding?code={quote(code)}"
                    response["onboarding"] = {
                        "required_steps": list(ONBOARDING_STEPS),
                        "completed_steps": sorted(_completed_onboarding_steps(conn, code, username)),
                    }
                return jsonify(response)
        except Exception:
            return jsonify({"error": "failed to load status"}), 502

    @app.route("/api/access/request/onboarding/attest", methods=["POST"])
    @require_auth
    def request_access_onboarding_attest() -> Any:
        """Record or clear a self-attested onboarding step for the caller's request."""
        if not configured():
            return jsonify({"error": "server not configured"}), 503

        payload = _json_object_payload()
        code = _payload_text(payload, "request_code", "code")
        step = _payload_text(payload, "step")
        completed = payload.get("completed")

        if not code:
            return jsonify({"error": "request_code is required"}), 400
        if step not in ONBOARDING_STEPS:
            return jsonify({"error": "invalid step"}), 400
        if step in KEYCLOAK_MANAGED_STEPS:
            return jsonify({"error": "step is managed by keycloak"}), 400

        username = getattr(g, "keycloak_username", "") or ""
        if not username:
            return jsonify({"error": "invalid token"}), 401

        try:
            with connect() as conn:
                row = conn.execute(
                    "SELECT username, status FROM access_requests WHERE request_code = %s",
                    (code,),
                ).fetchone()
                if not row:
                    return jsonify({"error": "not found"}), 404
                if (row.get("username") or "") != username:
                    return jsonify({"error": "forbidden"}), 403

                status = _normalize_status(row.get("status") or "")
                if status not in {"awaiting_onboarding", "ready"}:
                    return jsonify({"error": "onboarding not available"}), 409

                # Anything other than an explicit boolean means "mark as done".
                mark_done = completed if isinstance(completed, bool) else True
                if mark_done:
                    conn.execute(
                        """
                        INSERT INTO access_request_onboarding_steps (request_code, step)
                        VALUES (%s, %s)
                        ON CONFLICT (request_code, step) DO NOTHING
                        """,
                        (code, step),
                    )
                else:
                    conn.execute(
                        "DELETE FROM access_request_onboarding_steps WHERE request_code = %s AND step = %s",
                        (code, step),
                    )

                # Re-evaluate completion to update request status to ready if applicable.
                status = _advance_status(conn, code, username, status)
                completed_steps = sorted(_completed_onboarding_steps(conn, code, username))
        except Exception:
            return jsonify({"error": "failed to update onboarding"}), 502

        return jsonify(
            {
                "ok": True,
                "status": status,
                "onboarding": {"required_steps": list(ONBOARDING_STEPS), "completed_steps": completed_steps},
            }
        )