bstein-dev-home/backend/atlas_portal/routes/access_request_state.py

496 lines
18 KiB
Python

from __future__ import annotations
from datetime import datetime, timezone
import hashlib
import hmac
import re
import secrets
import string
from typing import Any
from urllib.parse import quote
from flask import request
from .. import ariadne_client
from ..db import connect, configured
from ..keycloak import admin_client, oidc_client
from ..mailer import MailerError, access_request_verification_body, send_text_email
from ..rate_limit import rate_limit_allow
from ..provisioning import provision_access_request, provision_tasks_complete
from .. import settings
from .access_request_onboarding_policy import (
KEYCLOAK_MANAGED_STEPS,
ONBOARDING_OPTIONAL_STEPS,
ONBOARDING_REQUIRED_STEPS,
ONBOARDING_STEP_PREREQUISITES,
ONBOARDING_STEPS,
VAULTWARDEN_GRANDFATHERED_FLAG,
_KEYCLOAK_PASSWORD_ROTATION_REQUESTED_ARTIFACT,
_VAULTWARDEN_READY_STATUSES,
)
def _extract_request_payload() -> tuple[str, str, str, str, str]:
    """Read the access-request fields from the JSON body of the current request.

    Returns ``(username, email, note, first_name, last_name)``, each stripped
    of surrounding whitespace. A missing/invalid body, a body that is not a
    JSON object, and any field that is absent or not a string all yield ``""``
    for the affected values, so malformed client input reaches the validators
    as "empty" instead of raising ``AttributeError``.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        # Valid JSON that is not an object (list, string, number) has no fields.
        payload = {}

    def _text(key: str) -> str:
        value = payload.get(key)
        return value.strip() if isinstance(value, str) else ""

    return (
        _text("username"),
        _text("email"),
        _text("note"),
        _text("first_name"),
        _text("last_name"),
    )
def _normalize_name(value: str) -> str:
    """Trim *value* and collapse every run of whitespace into a single space."""
    # split() with no separator already discards leading/trailing whitespace.
    words = value.split()
    return " ".join(words)
def _validate_name(value: str, *, label: str, required: bool) -> str | None:
    """Validate a human name field and return an error message, or ``None`` if valid.

    ``label`` is the field name used in the message; ``required`` decides
    whether an empty (or whitespace-only) value is an error.

    The control-whitespace check inspects the raw, edge-trimmed input. It used
    to run on the normalized value, where ``str.split()`` had already replaced
    every tab/CR/LF with a space, so it could never fire.
    """
    cleaned = _normalize_name(value)
    if not cleaned:
        return f"{label} is required" if required else None
    if len(cleaned) > 80:
        return f"{label} must be 1-80 characters"
    # Look at the input before normalization; only interior characters count,
    # surrounding whitespace is trimmed just like the payload extraction does.
    if any(ch in "\r\n\t" for ch in value.strip()):
        return f"{label} contains invalid whitespace"
    return None
def _validate_username(username: str) -> str | None:
    """Return an error message for an unacceptable username, or ``None`` if valid.

    A valid username is 3-32 characters drawn from ASCII letters, digits,
    ``.``, ``_`` and ``-``.
    """
    if not username:
        return "username is required"
    if not 3 <= len(username) <= 32:
        return "username must be 3-32 characters"
    allowed = re.fullmatch(r"[a-zA-Z0-9._-]+", username)
    if allowed is None:
        return "username contains invalid characters"
    return None
def _random_request_code(username: str) -> str:
    """Build a request code of the form ``<username>~<10 random chars>``.

    The suffix is drawn with ``secrets`` from uppercase ASCII letters and digits.
    """
    alphabet = string.ascii_uppercase + string.digits
    picks = [secrets.choice(alphabet) for _ in range(10)]
    return username + "~" + "".join(picks)
def _client_ip() -> str:
    """Best-effort client address for the current request.

    Preference order: first entry of ``X-Forwarded-For``, then ``X-Real-IP``,
    then the socket peer address; ``"unknown"`` when nothing usable is found.
    """
    # NOTE(review): both headers are client-controlled unless a trusted proxy
    # overwrites them - confirm the deployment before relying on this value.
    forwarded = (request.headers.get("X-Forwarded-For") or "").strip()
    if forwarded:
        first_hop = forwarded.partition(",")[0].strip()
        return first_hop if first_hop else "unknown"
    real_ip = (request.headers.get("X-Real-IP") or "").strip()
    if not real_ip:
        return request.remote_addr or "unknown"
    return real_ip
# Status of a request whose email address has not been confirmed yet; email
# verification moves a request from this status to 'pending'.
EMAIL_VERIFY_PENDING_STATUS = "pending_email_verification"
def _hash_verification_token(token: str) -> str:
    """Return the hex SHA-256 digest of the UTF-8 encoded verification token."""
    digest = hashlib.sha256()
    digest.update(token.encode("utf-8"))
    return digest.hexdigest()
def _verify_url(request_code: str, token: str) -> str:
    """Build the public email-verification link for a request code and token."""
    # Drop trailing slashes so the joined path never contains "//".
    root = settings.PORTAL_PUBLIC_BASE_URL.rstrip("/")
    query = f"code={quote(request_code)}&token={quote(token)}"
    return f"{root}/api/access/request/verify-link?{query}"
def _send_verification_email(*, request_code: str, email: str, token: str) -> None:
    """Email the verification link for *request_code* to *email*.

    Errors raised by the mailer are not handled here and propagate to the caller.
    """
    link = _verify_url(request_code, token)
    message = access_request_verification_body(request_code=request_code, verify_url=link)
    send_text_email(to_addr=email, subject="Atlas: confirm your email", body=message)
class VerificationError(Exception):
    """Email verification failure carrying the HTTP status to respond with."""

    def __init__(self, status_code: int, message: str) -> None:
        # Keep both values as attributes so callers can build a response;
        # the message is also the exception's string form.
        self.status_code = status_code
        self.message = message
        super().__init__(message)
def _verify_request(conn, code: str, token: str) -> str:
    """Check an email-verification token and move the request to ``pending``.

    Returns the normalized current status unchanged when the request is not
    awaiting email verification; otherwise returns ``"pending"`` after the
    update.

    Raises:
        VerificationError: 404 unknown code, 409 no stored token hash,
            401 token mismatch, 410 token older than the configured TTL.
    """
    record = conn.execute(
        """
        SELECT status, email_verification_token_hash, email_verification_sent_at, email_verified_at
        FROM access_requests
        WHERE request_code = %s
        """,
        (code,),
    ).fetchone()
    if not record:
        raise VerificationError(404, "not found")

    current = _normalize_status(record.get("status") or "")
    if current != EMAIL_VERIFY_PENDING_STATUS:
        # Nothing to verify (already verified or further along).
        return current

    expected_digest = str(record.get("email_verification_token_hash") or "")
    if not expected_digest:
        raise VerificationError(409, "verification token missing")
    # Constant-time comparison of the stored hash with the presented token's hash.
    if not hmac.compare_digest(expected_digest, _hash_verification_token(token)):
        raise VerificationError(401, "invalid token")

    issued_at = record.get("email_verification_sent_at")
    if isinstance(issued_at, datetime):
        # The TTL is only enforced when a send timestamp is recorded.
        reference = datetime.now(timezone.utc)
        if issued_at.tzinfo is None:
            # Naive timestamps are interpreted as UTC.
            issued_at = issued_at.replace(tzinfo=timezone.utc)
        elapsed = reference - issued_at
        if elapsed.total_seconds() > settings.ACCESS_REQUEST_EMAIL_VERIFY_TTL_SEC:
            raise VerificationError(410, "verification token expired")

    # The UPDATE only matches while the row is still awaiting verification and
    # clears the token hash.
    # NOTE(review): "pending" is returned even if the UPDATE matched no row.
    conn.execute(
        """
        UPDATE access_requests
        SET status = 'pending',
        email_verified_at = NOW(),
        email_verification_token_hash = NULL
        WHERE request_code = %s AND status = %s
        """,
        (code, EMAIL_VERIFY_PENDING_STATUS),
    )
    return "pending"
def _normalize_status(status: str) -> str:
    """Canonicalize a stored status string.

    Trims and lowercases the value, maps ``approved`` to ``accounts_building``
    and turns an empty value into ``unknown``.
    """
    canonical = (status or "").strip().lower()
    if not canonical:
        return "unknown"
    return "accounts_building" if canonical == "approved" else canonical
def _fetch_completed_onboarding_steps(conn, request_code: str) -> set[str]:
    """Return the onboarding steps recorded in the database for one request.

    Rows that are not dicts, and ``step`` values that are not non-empty
    strings, are ignored.
    """
    rows = conn.execute(
        "SELECT step FROM access_request_onboarding_steps WHERE request_code = %s",
        (request_code,),
    ).fetchall()
    candidates = (row.get("step") for row in rows if isinstance(row, dict))
    return {step for step in candidates if isinstance(step, str) and step}
def _normalize_flag_list(raw: Any) -> set[str]:
    """Coerce a stored flag value into a set of non-empty strings.

    A list contributes its non-empty string items, a non-empty string becomes
    a one-element set, and anything else yields an empty set.
    """
    if isinstance(raw, str):
        return {raw} if raw else set()
    if isinstance(raw, list):
        return {flag for flag in raw if isinstance(flag, str) and flag}
    return set()
def _fetch_request_flags_and_email(conn, request_code: str) -> tuple[set[str], str]:
    """Load the approval flags and trimmed contact email of one request.

    Returns ``(set(), "")`` when no row matches *request_code*.
    """
    record = conn.execute(
        "SELECT approval_flags, contact_email FROM access_requests WHERE request_code = %s",
        (request_code,),
    ).fetchone()
    if not record:
        return set(), ""
    approval_flags = _normalize_flag_list(record.get("approval_flags"))
    # The email is only read from dict rows; other row types yield "".
    contact = record.get("contact_email") if isinstance(record, dict) else ""
    return approval_flags, (contact or "").strip()
def _user_in_group(username: str, group_name: str) -> bool:
    """Tell whether the Keycloak user *username* is listed in *group_name*.

    Returns ``False`` for empty arguments, when the admin client is not
    ready, when the user cannot be resolved, or when a lookup raises.
    """
    if not (username and group_name):
        return False
    if not admin_client().ready():
        return False
    try:
        account = admin_client().find_user(username) or {}
        account_id = account.get("id") if isinstance(account, dict) else None
        if not (isinstance(account_id, str) and account_id):
            return False
        memberships = admin_client().list_user_groups(account_id)
    except Exception:
        # Best effort: any lookup failure counts as "not a member".
        return False
    else:
        return group_name in memberships
def _vaultwarden_grandfathered(conn, request_code: str, username: str) -> tuple[bool, str]:
    """Return ``(grandfathered, contact_email)`` for a request.

    The request counts as grandfathered when the flag is among its approval
    flags or, failing that, when the user is in the Keycloak group of the
    same name.
    """
    flags, contact_email = _fetch_request_flags_and_email(conn, request_code)
    # The group lookup is only attempted when the stored flag is absent.
    is_grandfathered = VAULTWARDEN_GRANDFATHERED_FLAG in flags or _user_in_group(
        username, VAULTWARDEN_GRANDFATHERED_FLAG
    )
    return is_grandfathered, contact_email
def _resolve_recovery_email(username: str, fallback: str) -> str:
    """Pick the recovery email: the Keycloak profile email, else *fallback*.

    Any Keycloak failure, an unavailable admin client, or a blank profile
    email falls through to the trimmed *fallback*.
    """
    if not username or not admin_client().ready():
        return (fallback or "").strip()
    try:
        account = admin_client().find_user(username) or {}
        account_id = account.get("id") if isinstance(account, dict) else None
        if isinstance(account_id, str) and account_id:
            profile = admin_client().get_user(account_id)
            profile_email = profile.get("email")
            if isinstance(profile_email, str) and profile_email.strip():
                return profile_email.strip()
    except Exception:
        # Best effort: use the fallback address below.
        pass
    return (fallback or "").strip()
def _password_rotation_requested(conn, request_code: str) -> bool:
    """Tell whether the password-rotation marker artifact exists for a request."""
    marker = conn.execute(
        """
        SELECT 1
        FROM access_request_onboarding_artifacts
        WHERE request_code = %s AND artifact = %s
        LIMIT 1
        """,
        (request_code, _KEYCLOAK_PASSWORD_ROTATION_REQUESTED_ARTIFACT),
    ).fetchone()
    return True if marker else False
def _request_keycloak_password_rotation(conn, request_code: str, username: str) -> None:
    """Add ``UPDATE_PASSWORD`` to the user's required actions and record a marker.

    Raises:
        ValueError: *username* is empty.
        RuntimeError: the Keycloak admin client is not ready, or the user
            cannot be found.
    """
    if not username:
        raise ValueError("username missing")
    if not admin_client().ready():
        raise RuntimeError("keycloak admin unavailable")

    account = admin_client().find_user(username) or {}
    account_id = account.get("id") if isinstance(account, dict) else None
    if not (isinstance(account_id, str) and account_id):
        raise RuntimeError("keycloak user not found")

    profile = admin_client().get_user(account_id)
    current = profile.get("requiredActions")
    pending = [item for item in current if isinstance(item, str)] if isinstance(current, list) else []
    if "UPDATE_PASSWORD" not in pending:
        # Only touch Keycloak when the action is not already required.
        admin_client().update_user_safe(
            account_id,
            {"requiredActions": [*pending, "UPDATE_PASSWORD"]},
        )

    # Insert the marker row; an existing marker is left untouched.
    conn.execute(
        """
        INSERT INTO access_request_onboarding_artifacts (request_code, artifact, value_hash)
        VALUES (%s, %s, NOW()::text)
        ON CONFLICT (request_code, artifact) DO NOTHING
        """,
        (request_code, _KEYCLOAK_PASSWORD_ROTATION_REQUESTED_ARTIFACT),
    )
def _extract_attr(attrs: Any, key: str) -> str:
    """Return the first non-blank string stored under *key*, stripped.

    The value may be a single string or a list of values; ``""`` is returned
    when *attrs* is not a dict or no usable string is found.
    """
    if not isinstance(attrs, dict):
        return ""
    value = attrs.get(key)
    candidates = value if isinstance(value, list) else [value]
    return next(
        (item.strip() for item in candidates if isinstance(item, str) and item.strip()),
        "",
    )
def _vaultwarden_status_for_user(username: str) -> str:
    """Return the ``vaultwarden_status`` attribute of a Keycloak user.

    Yields ``""`` when the username is empty, the admin client is not ready,
    the user cannot be resolved, or any lookup raises.
    """
    if not username or not admin_client().ready():
        return ""
    try:
        account = admin_client().find_user(username) or {}
        account_id = account.get("id") if isinstance(account, dict) else None
        if isinstance(account_id, str) and account_id:
            profile = admin_client().get_user(account_id)
            attributes = profile.get("attributes") if isinstance(profile, dict) else {}
            return _extract_attr(attributes, "vaultwarden_status")
    except Exception:
        # Best effort: a failed lookup is reported as "no status".
        pass
    return ""
def _auto_completed_service_steps(attrs: Any) -> set[str]:
    """Derive completed onboarding steps from a user's attribute mapping.

    Returns an empty set when *attrs* is not a dict.
    """
    done: set[str] = set()
    if not isinstance(attrs, dict):
        return done

    # Vaultwarden: either a master-password timestamp or a ready status counts.
    status = _extract_attr(attrs, "vaultwarden_status")
    master_set_at = _extract_attr(attrs, "vaultwarden_master_password_set_at")
    if master_set_at or status in _VAULTWARDEN_READY_STATUSES:
        done.add("vaultwarden_master_password")

    # For the remaining services a non-blank attribute marks the step as done.
    attribute_to_step = (
        ("nextcloud_mail_synced_at", "nextcloud_mail_integration"),
        ("firefly_password_rotated_at", "firefly_password_rotated"),
        ("wger_password_rotated_at", "wger_password_rotated"),
    )
    done.update(step for attribute, step in attribute_to_step if _extract_attr(attrs, attribute))
    return done
def _auto_completed_keycloak_steps(conn, request_code: str, username: str) -> set[str]:
    """Derive completed onboarding steps from the user's Keycloak state.

    Combines the attribute-based service steps with
    ``keycloak_password_rotated``, which is reported once a rotation was
    requested for this request and ``UPDATE_PASSWORD`` is no longer among the
    required actions. Returns an empty set when inputs are missing, the admin
    client is not ready, the user cannot be resolved, or anything raises.

    Side effect: a lingering ``CONFIGURE_TOTP`` required action is removed
    from the user (best effort).
    """
    if not username or not admin_client().ready() or not request_code:
        return set()

    done: set[str] = set()
    try:
        account = admin_client().find_user(username) or {}
        account_id = account.get("id") if isinstance(account, dict) else None
        if not (isinstance(account_id, str) and account_id):
            return set()

        try:
            profile = admin_client().get_user(account_id)
        except Exception:
            # Fall back to the search result when the full profile is unavailable.
            profile = account if isinstance(account, dict) else {}

        done |= _auto_completed_service_steps(
            profile.get("attributes") if isinstance(profile, dict) else {}
        )

        raw_actions = profile.get("requiredActions")
        pending_actions: list[str] = (
            [item for item in raw_actions if isinstance(item, str)]
            if isinstance(raw_actions, list)
            else []
        )
        pending_lookup = set(pending_actions)

        if _password_rotation_requested(conn, request_code) and "UPDATE_PASSWORD" not in pending_lookup:
            done.add("keycloak_password_rotated")

        # Backfill: earlier accounts were created with CONFIGURE_TOTP as a
        # required action, forcing MFA enrollment at first login. That is no
        # longer required, so drop it when present.
        if "CONFIGURE_TOTP" in pending_lookup:
            try:
                admin_client().update_user_safe(
                    account_id,
                    {"requiredActions": [item for item in pending_actions if item != "CONFIGURE_TOTP"]},
                )
            except Exception:
                # Cleanup is best effort; the step result is unaffected.
                pass
    except Exception:
        # Any failure discards partial results rather than reporting them.
        return set()
    return done
def _completed_onboarding_steps(conn, request_code: str, username: str) -> set[str]:
    """Return the union of recorded and automatically inferred completed steps."""
    recorded = _fetch_completed_onboarding_steps(conn, request_code)
    inferred = _auto_completed_keycloak_steps(conn, request_code, username)
    return recorded.union(inferred)
def _automation_ready(conn, request_code: str, username: str) -> bool:
    """Tell whether provisioning has progressed far enough to start onboarding.

    When task rows exist for the request, the answer comes from
    ``provision_tasks_complete``. Otherwise the user must exist in Keycloak
    and carry a non-empty ``mailu_app_password`` attribute. An empty
    username, an unavailable admin client and Keycloak errors all yield
    ``False``.
    """
    if not username or not admin_client().ready():
        return False

    # Prefer task-based readiness when task rows exist for the request.
    has_tasks = conn.execute(
        "SELECT 1 FROM access_request_tasks WHERE request_code = %s LIMIT 1",
        (request_code,),
    ).fetchone()
    if has_tasks:
        return provision_tasks_complete(conn, request_code)

    # Fallback for legacy requests: the user exists and has a mail app password.
    try:
        account = admin_client().find_user(username)
        if not account:
            return False
        account_id = account.get("id") if isinstance(account, dict) else None
        if not account_id:
            return False
        profile = admin_client().get_user(str(account_id))
        attributes = profile.get("attributes") or {}
        if not isinstance(attributes, dict):
            return False
        app_password = attributes.get("mailu_app_password")
        if isinstance(app_password, list):
            # Only the first entry of a list-valued attribute is considered.
            app_password = app_password[0] if app_password else None
        return isinstance(app_password, str) and bool(app_password)
    except Exception:
        return False
def _advance_status(conn, request_code: str, username: str, status: str) -> str:
    """Advance an access request through automatic status transitions.

    At most one transition happens per call:

    * ``accounts_building`` -> ``awaiting_onboarding`` once automation is ready;
    * ``awaiting_onboarding`` -> ``ready`` once every required step is completed.

    Returns the resulting (normalized) status. Each UPDATE is conditioned on
    the row still holding the expected previous status.
    """
    status = _normalize_status(status)
    if status == "accounts_building" and _automation_ready(conn, request_code, username):
        conn.execute(
            "UPDATE access_requests SET status = 'awaiting_onboarding' WHERE request_code = %s AND status = 'accounts_building'",
            (request_code,),
        )
        return "awaiting_onboarding"
    if status == "awaiting_onboarding":
        completed = _completed_onboarding_steps(conn, request_code, username)
        required_steps = set(ONBOARDING_REQUIRED_STEPS)
        grandfathered, _ = _vaultwarden_grandfathered(conn, request_code, username)
        # The Keycloak-mirrored Vaultwarden status only matters for
        # grandfathered requests; short-circuit so everyone else skips the
        # extra user lookups.
        if grandfathered and _vaultwarden_status_for_user(username) == "grandfathered":
            required_steps.add("vaultwarden_store_temp_password")
        if required_steps.issubset(completed):
            conn.execute(
                "UPDATE access_requests SET status = 'ready' WHERE request_code = %s AND status = 'awaiting_onboarding'",
                (request_code,),
            )
            return "ready"
    return status
def _onboarding_payload(conn, request_code: str, username: str) -> dict[str, Any]:
    """Build the onboarding progress payload returned to the frontend.

    The payload lists required, optional and completed steps, whether a
    Keycloak password rotation was requested, and the Vaultwarden
    grandfathering state. ``vaultwarden_store_temp_password`` is appended to
    the required steps only when the request is grandfathered and the user's
    mirrored Vaultwarden status is ``grandfathered``.
    """
    completed_steps = sorted(_completed_onboarding_steps(conn, request_code, username))
    password_rotation_requested = _password_rotation_requested(conn, request_code)
    grandfathered, contact_email = _vaultwarden_grandfathered(conn, request_code, username)
    recovery_email = _resolve_recovery_email(username, contact_email) if grandfathered else ""
    # The mirrored status is irrelevant unless the request is grandfathered,
    # so short-circuit to avoid the extra Keycloak lookups in the common case.
    vaultwarden_matched = grandfathered and _vaultwarden_status_for_user(username) == "grandfathered"
    required_steps = list(ONBOARDING_REQUIRED_STEPS)
    if vaultwarden_matched:
        required_steps.append("vaultwarden_store_temp_password")
    return {
        "required_steps": required_steps,
        "optional_steps": sorted(ONBOARDING_OPTIONAL_STEPS),
        "completed_steps": completed_steps,
        "keycloak": {
            "password_rotation_requested": password_rotation_requested,
        },
        "vaultwarden": {
            "grandfathered": grandfathered,
            "recovery_email": recovery_email,
            "matched": vaultwarden_matched,
        },
    }
# Keep the historical access_requests module patch surface intact for tests and
# callers while the route handlers live in smaller focused modules.
# This exports every module global whose name does not start with a double
# underscore, so single-underscore helpers and imported names are included
# in star-imports.
__all__ = [name for name in globals() if not name.startswith("__")]