496 lines
18 KiB
Python
496 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import datetime, timezone
|
|
import hashlib
|
|
import hmac
|
|
import re
|
|
import secrets
|
|
import string
|
|
from typing import Any
|
|
from urllib.parse import quote
|
|
|
|
from flask import request
|
|
|
|
from .. import ariadne_client
|
|
from ..db import connect, configured
|
|
from ..keycloak import admin_client, oidc_client
|
|
from ..mailer import MailerError, access_request_verification_body, send_text_email
|
|
from ..rate_limit import rate_limit_allow
|
|
from ..provisioning import provision_access_request, provision_tasks_complete
|
|
from .. import settings
|
|
from .access_request_onboarding_policy import (
|
|
KEYCLOAK_MANAGED_STEPS,
|
|
ONBOARDING_OPTIONAL_STEPS,
|
|
ONBOARDING_REQUIRED_STEPS,
|
|
ONBOARDING_STEP_PREREQUISITES,
|
|
ONBOARDING_STEPS,
|
|
VAULTWARDEN_GRANDFATHERED_FLAG,
|
|
_KEYCLOAK_PASSWORD_ROTATION_REQUESTED_ARTIFACT,
|
|
_VAULTWARDEN_READY_STATUSES,
|
|
)
|
|
|
|
def _extract_request_payload() -> tuple[str, str, str, str, str]:
    """Read the access-request form fields from the current JSON request body.

    Returns:
        ``(username, email, note, first_name, last_name)``, each stripped of
        surrounding whitespace.

    The body is client-controlled. A missing or unparsable body, a JSON body
    that is not an object (e.g. a list), or a field whose value is not a
    string all yield ``""`` for the affected field(s) instead of raising.
    """
    payload = request.get_json(silent=True)
    # A JSON array or scalar has no ``.get``; treat it like an empty object.
    if not isinstance(payload, dict):
        payload = {}

    def _field(key: str) -> str:
        value = payload.get(key)
        # Numbers, lists or null would crash ``.strip()``; treat them as absent.
        return value.strip() if isinstance(value, str) else ""

    return (
        _field("username"),
        _field("email"),
        _field("note"),
        _field("first_name"),
        _field("last_name"),
    )
|
|
|
|
|
|
def _normalize_name(value: str) -> str:
    """Trim *value* and collapse every internal run of whitespace to a single space."""
    words = value.split()
    return " ".join(words)
|
|
|
|
|
|
def _validate_name(value: str, *, label: str, required: bool) -> str | None:
    """Validate a human-name field; return an error message or ``None`` if valid.

    The length limit applies to the whitespace-normalized name
    (see ``_normalize_name``).

    Args:
        value: Raw field value.
        label: Field name used as the prefix of error messages.
        required: Whether an empty (or all-whitespace) value is an error.

    Returns:
        ``None`` when acceptable. Otherwise one of
        ``"<label> is required"``, ``"<label> must be 1-80 characters"`` or
        ``"<label> contains invalid whitespace"``.
    """
    cleaned = _normalize_name(value)
    if not cleaned:
        return f"{label} is required" if required else None
    if len(cleaned) > 80:
        return f"{label} must be 1-80 characters"
    # Check the un-normalized text. Normalization rewrites CR/LF/TAB into plain
    # spaces, so testing ``cleaned`` could never match. Leading and trailing
    # whitespace is ignored because it is always trimmed away.
    if any(ch in "\r\n\t" for ch in value.strip()):
        return f"{label} contains invalid whitespace"
    return None
|
|
|
|
|
|
def _validate_username(username: str) -> str | None:
|
|
if not username:
|
|
return "username is required"
|
|
if len(username) < 3 or len(username) > 32:
|
|
return "username must be 3-32 characters"
|
|
if not re.fullmatch(r"[a-zA-Z0-9._-]+", username):
|
|
return "username contains invalid characters"
|
|
return None
|
|
|
|
|
|
def _random_request_code(username: str) -> str:
    """Return ``"<username>~<suffix>"``, where *suffix* is 10 random uppercase letters/digits.

    The suffix characters are drawn with :func:`secrets.choice`.
    """
    alphabet = string.ascii_uppercase + string.digits
    picks = [secrets.choice(alphabet) for _ in range(10)]
    return f"{username}~{''.join(picks)}"
|
|
|
|
|
|
def _client_ip() -> str:
    """Best-effort client address for the current Flask request.

    Preference order: the first ``X-Forwarded-For`` entry (``"unknown"`` if
    that entry is blank), then ``X-Real-IP``, then ``request.remote_addr``,
    and finally ``"unknown"``.

    NOTE(review): both headers are client-supplied unless the fronting proxy
    overwrites them, so a direct caller can spoof the returned value. Confirm
    the proxy setup before relying on this for security decisions.
    """
    headers = request.headers
    forwarded = (headers.get("X-Forwarded-For") or "").strip()
    if forwarded:
        first_hop = forwarded.split(",", 1)[0].strip()
        return first_hop if first_hop else "unknown"
    real_ip = (headers.get("X-Real-IP") or "").strip()
    if real_ip:
        return real_ip
    return request.remote_addr or "unknown"
|
|
|
|
|
|
# Status of an access request whose contact email has not been confirmed yet.
# _verify_request only checks the token for requests in this status and moves
# them to "pending".
EMAIL_VERIFY_PENDING_STATUS = "pending_email_verification"
|
|
|
|
|
|
def _hash_verification_token(token: str) -> str:
    """Return the hex SHA-256 digest of the UTF-8 encoding of *token*."""
    digest = hashlib.sha256()
    digest.update(token.encode("utf-8"))
    return digest.hexdigest()
|
|
|
|
|
|
def _verify_url(request_code: str, token: str) -> str:
    """Return the absolute email-verification link for *request_code* and *token*.

    The base is ``settings.PORTAL_PUBLIC_BASE_URL`` with trailing slashes
    removed. Both query values are percent-encoded with
    :func:`urllib.parse.quote`.
    """
    root = settings.PORTAL_PUBLIC_BASE_URL.rstrip("/")
    encoded_code, encoded_token = quote(request_code), quote(token)
    return f"{root}/api/access/request/verify-link?code={encoded_code}&token={encoded_token}"
|
|
|
|
|
|
def _send_verification_email(*, request_code: str, email: str, token: str) -> None:
    """Email the verification link for *request_code* to *email*.

    Exceptions raised while building or sending the message are not caught here.
    """
    body = access_request_verification_body(
        request_code=request_code,
        verify_url=_verify_url(request_code, token),
    )
    send_text_email(to_addr=email, subject="Atlas: confirm your email", body=body)
|
|
|
|
|
|
class VerificationError(Exception):
    """Email verification failed; carries an HTTP status code and a message.

    Attributes:
        status_code: HTTP status code associated with the failure (e.g. 404, 410).
        message: Human-readable reason, also used as the exception's ``str()``.
    """

    def __init__(self, status_code: int, message: str) -> None:
        self.status_code = status_code
        self.message = message
        super().__init__(message)
|
|
|
|
|
|
def _verify_request(conn, code: str, token: str) -> str:
    """Validate an emailed verification token and advance a pending request.

    If the request is not in ``EMAIL_VERIFY_PENDING_STATUS``, its normalized
    status is returned without checking the token. Otherwise the SHA-256 of
    *token* is compared in constant time with the stored hash, and the token
    age is checked against ``settings.ACCESS_REQUEST_EMAIL_VERIFY_TTL_SEC``.
    The row then moves to ``'pending'`` with the token hash cleared.

    Args:
        conn: Database connection; ``execute`` results expose ``fetchone`` and
            rows are read with ``.get`` (dict-like).
        code: Access request code from the verification link.
        token: Plain-text verification token from the link.

    Returns:
        The request's normalized status after this call.

    Raises:
        VerificationError: 404 unknown code; 409 no stored token; 401 token
            mismatch (including a token replaced by a concurrent resend);
            410 token expired.
    """
    row = conn.execute(
        """
        SELECT status, email_verification_token_hash, email_verification_sent_at, email_verified_at
        FROM access_requests
        WHERE request_code = %s
        """,
        (code,),
    ).fetchone()
    if not row:
        raise VerificationError(404, "not found")

    status = _normalize_status(row.get("status") or "")
    if status != EMAIL_VERIFY_PENDING_STATUS:
        return status

    stored_hash = str(row.get("email_verification_token_hash") or "")
    if not stored_hash:
        raise VerificationError(409, "verification token missing")

    provided_hash = _hash_verification_token(token)
    if not hmac.compare_digest(stored_hash, provided_hash):
        raise VerificationError(401, "invalid token")

    sent_at = row.get("email_verification_sent_at")
    # NOTE(review): when sent_at is not a datetime (e.g. NULL), no expiry is enforced.
    if isinstance(sent_at, datetime):
        if sent_at.tzinfo is None:
            # Naive timestamps are interpreted as UTC.
            sent_at = sent_at.replace(tzinfo=timezone.utc)
        age_sec = (datetime.now(timezone.utc) - sent_at).total_seconds()
        if age_sec > settings.ACCESS_REQUEST_EMAIL_VERIFY_TTL_SEC:
            raise VerificationError(410, "verification token expired")

    # Compare-and-swap: advance only if the row is still pending AND still holds
    # the token hash validated above. The SELECT took no lock, so a concurrent
    # verification or a resend (new token) may have changed the row meanwhile.
    updated = conn.execute(
        """
        UPDATE access_requests
        SET status = 'pending',
            email_verified_at = NOW(),
            email_verification_token_hash = NULL
        WHERE request_code = %s AND status = %s AND email_verification_token_hash = %s
        RETURNING status
        """,
        (code, EMAIL_VERIFY_PENDING_STATUS, stored_hash),
    ).fetchone()
    if updated:
        return "pending"

    # Lost the race: report the row's current state instead of claiming success.
    current = conn.execute(
        "SELECT status FROM access_requests WHERE request_code = %s",
        (code,),
    ).fetchone()
    if not current:
        raise VerificationError(404, "not found")
    current_status = _normalize_status(current.get("status") or "")
    if current_status == EMAIL_VERIFY_PENDING_STATUS:
        # Still unverified but with a different token, so this link is stale.
        raise VerificationError(401, "invalid token")
    return current_status
|
|
|
|
|
|
def _normalize_status(status: str) -> str:
    """Canonicalize a stored request status.

    The value is trimmed and lower-cased. ``"approved"`` is reported as
    ``"accounts_building"``, and an empty or missing value as ``"unknown"``.
    """
    cleaned = (status or "").strip().lower()
    if not cleaned:
        return "unknown"
    return "accounts_building" if cleaned == "approved" else cleaned
|
|
|
|
|
|
def _fetch_completed_onboarding_steps(conn, request_code: str) -> set[str]:
    """Return the manually attested onboarding steps recorded for *request_code*.

    Only dict rows with a non-empty string ``step`` value are counted.
    """
    cursor = conn.execute(
        "SELECT step FROM access_request_onboarding_steps WHERE request_code = %s",
        (request_code,),
    )
    steps = (record.get("step") for record in cursor.fetchall() if isinstance(record, dict))
    return {step for step in steps if isinstance(step, str) and step}
|
|
|
|
|
|
def _normalize_flag_list(raw: Any) -> set[str]:
    """Coerce a stored flags value into a set of non-empty strings.

    A non-empty string becomes a one-element set. A list keeps its non-empty
    string items. Anything else yields an empty set.
    """
    if isinstance(raw, str):
        return {raw} if raw else set()
    if not isinstance(raw, list):
        return set()
    return {flag for flag in raw if isinstance(flag, str) and flag}
|
|
|
|
|
|
def _fetch_request_flags_and_email(conn, request_code: str) -> tuple[set[str], str]:
    """Return ``(approval_flags, contact_email)`` used by onboarding decisions.

    Flags are normalized with ``_normalize_flag_list`` and the email is
    stripped. A missing row, or one that is not a dict, yields ``(set(), "")``.
    A non-string ``contact_email`` is treated as empty.
    """
    row = conn.execute(
        "SELECT approval_flags, contact_email FROM access_requests WHERE request_code = %s",
        (request_code,),
    ).fetchone()
    # Check the row type once, before any ``.get`` call, so a non-dict row
    # cannot crash here.
    if not isinstance(row, dict):
        return set(), ""
    flags = _normalize_flag_list(row.get("approval_flags"))
    email = row.get("contact_email")
    return flags, (email.strip() if isinstance(email, str) else "")
|
|
|
|
|
|
def _user_in_group(username: str, group_name: str) -> bool:
    """Return whether Keycloak user *username* belongs to group *group_name*.

    Returns ``False`` if either argument is empty, the admin client is not
    ready, the user cannot be found, or looking up the user or their groups
    raises.
    """
    if not (username and group_name) or not admin_client().ready():
        return False
    try:
        found = admin_client().find_user(username) or {}
        uid = found.get("id") if isinstance(found, dict) else None
        if not (isinstance(uid, str) and uid):
            return False
        memberships = admin_client().list_user_groups(uid)
    except Exception:
        return False
    return group_name in memberships
|
|
|
|
|
|
def _vaultwarden_grandfathered(conn, request_code: str, username: str) -> tuple[bool, str]:
    """Return ``(is_grandfathered, contact_email)`` for a request.

    A request counts as grandfathered when its approval flags contain
    ``VAULTWARDEN_GRANDFATHERED_FLAG``. If they do not, the Keycloak user is
    checked for membership in a group of that name.
    """
    flags, contact_email = _fetch_request_flags_and_email(conn, request_code)
    is_grandfathered = VAULTWARDEN_GRANDFATHERED_FLAG in flags or _user_in_group(
        username, VAULTWARDEN_GRANDFATHERED_FLAG
    )
    return is_grandfathered, contact_email
|
|
|
|
|
|
def _resolve_recovery_email(username: str, fallback: str) -> str:
    """Find the best recovery email for Vaultwarden onboarding.

    Prefers the non-blank ``email`` of the Keycloak user (stripped). Otherwise,
    including on any Keycloak error, returns *fallback* stripped (``""`` when
    it is empty or ``None``).
    """
    keycloak_email = ""
    if username and admin_client().ready():
        try:
            match = admin_client().find_user(username) or {}
            uid = match.get("id") if isinstance(match, dict) else None
            if isinstance(uid, str) and uid:
                candidate = admin_client().get_user(uid).get("email")
                if isinstance(candidate, str):
                    keycloak_email = candidate.strip()
        except Exception:
            keycloak_email = ""
    return keycloak_email or (fallback or "").strip()
|
|
|
|
|
|
def _password_rotation_requested(conn, request_code: str) -> bool:
    """Return whether a Keycloak password-rotation marker exists for *request_code*."""
    marker = conn.execute(
        """
        SELECT 1
        FROM access_request_onboarding_artifacts
        WHERE request_code = %s AND artifact = %s
        LIMIT 1
        """,
        (request_code, _KEYCLOAK_PASSWORD_ROTATION_REQUESTED_ARTIFACT),
    ).fetchone()
    return True if marker else False
|
|
|
|
|
|
def _request_keycloak_password_rotation(conn, request_code: str, username: str) -> None:
    """Force a Keycloak password change for *username* and record the request.

    Ensures ``UPDATE_PASSWORD`` is among the user's required actions (other
    string actions are kept). Then inserts the rotation-requested artifact for
    *request_code*; the insert does nothing if the artifact already exists.

    Raises:
        ValueError: *username* is empty.
        RuntimeError: the admin client is not ready, or the user is not found.
    """
    if not username:
        raise ValueError("username missing")
    if not admin_client().ready():
        raise RuntimeError("keycloak admin unavailable")

    found = admin_client().find_user(username) or {}
    uid = found.get("id") if isinstance(found, dict) else None
    if not (isinstance(uid, str) and uid):
        raise RuntimeError("keycloak user not found")

    current = admin_client().get_user(uid).get("requiredActions")
    required = [a for a in current if isinstance(a, str)] if isinstance(current, list) else []
    if "UPDATE_PASSWORD" not in required:
        required.append("UPDATE_PASSWORD")
    admin_client().update_user_safe(uid, {"requiredActions": required})

    conn.execute(
        """
        INSERT INTO access_request_onboarding_artifacts (request_code, artifact, value_hash)
        VALUES (%s, %s, NOW()::text)
        ON CONFLICT (request_code, artifact) DO NOTHING
        """,
        (request_code, _KEYCLOAK_PASSWORD_ROTATION_REQUESTED_ARTIFACT),
    )
|
|
|
|
|
|
def _extract_attr(attrs: Any, key: str) -> str:
    """Return the first non-blank string stored under *key* in an attributes dict.

    The value may be a list (its first non-blank string item is used) or a
    single string. The result is stripped. ``""`` is returned when *attrs* is
    not a dict or nothing usable is found.
    """
    if not isinstance(attrs, dict):
        return ""
    raw = attrs.get(key)
    candidates = raw if isinstance(raw, list) else [raw]
    for item in candidates:
        if isinstance(item, str) and item.strip():
            return item.strip()
    return ""
|
|
|
|
|
|
def _vaultwarden_status_for_user(username: str) -> str:
    """Read the ``vaultwarden_status`` attribute mirrored on Keycloak user *username*.

    Returns ``""`` when *username* is empty, the admin client is not ready,
    the user cannot be found, or a Keycloak call raises.
    """
    if not username or not admin_client().ready():
        return ""
    try:
        found = admin_client().find_user(username) or {}
        uid = found.get("id") if isinstance(found, dict) else None
        if not (isinstance(uid, str) and uid):
            return ""
        profile = admin_client().get_user(uid)
        return _extract_attr(
            profile.get("attributes") if isinstance(profile, dict) else {},
            "vaultwarden_status",
        )
    except Exception:
        return ""
|
|
|
|
|
|
def _auto_completed_service_steps(attrs: Any) -> set[str]:
    """Infer onboarding steps already completed by backend service automation.

    Marker attributes checked in *attrs*:

    * ``vaultwarden_master_password``: ``vaultwarden_master_password_set_at``
      is set, or ``vaultwarden_status`` is in ``_VAULTWARDEN_READY_STATUSES``;
    * ``nextcloud_mail_integration``: ``nextcloud_mail_synced_at`` is set;
    * ``firefly_password_rotated``: ``firefly_password_rotated_at`` is set;
    * ``wger_password_rotated``: ``wger_password_rotated_at`` is set.

    Non-dict input yields an empty set.
    """
    if not isinstance(attrs, dict):
        return set()

    steps: set[str] = set()
    if _extract_attr(attrs, "vaultwarden_master_password_set_at") or (
        _extract_attr(attrs, "vaultwarden_status") in _VAULTWARDEN_READY_STATUSES
    ):
        steps.add("vaultwarden_master_password")

    timestamp_markers = (
        ("nextcloud_mail_synced_at", "nextcloud_mail_integration"),
        ("firefly_password_rotated_at", "firefly_password_rotated"),
        ("wger_password_rotated_at", "wger_password_rotated"),
    )
    steps.update(step for attr_name, step in timestamp_markers if _extract_attr(attrs, attr_name))
    return steps
|
|
|
|
|
|
def _auto_completed_keycloak_steps(conn, request_code: str, username: str) -> set[str]:
    """Infer completed onboarding steps from the user's Keycloak profile.

    Combines the service-automation markers from
    ``_auto_completed_service_steps`` with ``keycloak_password_rotated``. That
    step counts as done once a rotation was requested for *request_code* and
    ``UPDATE_PASSWORD`` is no longer a required action.

    Side effect: if ``CONFIGURE_TOTP`` is still a required action, it is removed
    (best effort; failures are ignored). Earlier accounts were created with it,
    and MFA enrollment at first login is no longer required.

    Returns an empty set when *username* or *request_code* is empty, the admin
    client is not ready, the user cannot be found, or any other step raises.
    """
    if not username or not admin_client().ready() or not request_code:
        return set()

    try:
        found = admin_client().find_user(username) or {}
        uid = found.get("id") if isinstance(found, dict) else None
        if not (isinstance(uid, str) and uid):
            return set()

        try:
            profile = admin_client().get_user(uid)
        except Exception:
            # Fall back to the search result when the full profile is unavailable.
            profile = found if isinstance(found, dict) else {}

        inferred = _auto_completed_service_steps(
            profile.get("attributes") if isinstance(profile, dict) else {}
        )

        raw_actions = profile.get("requiredActions")
        if isinstance(raw_actions, list):
            actions = [a for a in raw_actions if isinstance(a, str)]
        else:
            actions = []
        pending = set(actions)

        if _password_rotation_requested(conn, request_code) and "UPDATE_PASSWORD" not in pending:
            inferred.add("keycloak_password_rotated")

        # Backfill: strip the CONFIGURE_TOTP required action left on older accounts.
        if "CONFIGURE_TOTP" in pending:
            try:
                admin_client().update_user_safe(
                    uid,
                    {"requiredActions": [a for a in actions if a != "CONFIGURE_TOTP"]},
                )
            except Exception:
                pass
    except Exception:
        return set()

    return inferred
|
|
|
|
|
|
def _completed_onboarding_steps(conn, request_code: str, username: str) -> set[str]:
    """Return manually recorded onboarding steps plus those inferred from Keycloak."""
    recorded = _fetch_completed_onboarding_steps(conn, request_code)
    inferred = _auto_completed_keycloak_steps(conn, request_code, username)
    return recorded.union(inferred)
|
|
|
|
|
|
def _automation_ready(conn, request_code: str, username: str) -> bool:
    """Return whether account provisioning has finished enough for onboarding.

    Requires a username and a ready Keycloak admin client. If the request has
    rows in ``access_request_tasks``, the result of ``provision_tasks_complete``
    is returned. Otherwise (legacy requests) the Keycloak user must exist and
    the first ``mailu_app_password`` attribute value must be a non-empty
    string. Errors in that fallback yield ``False``.
    """
    if not username or not admin_client().ready():
        return False

    has_tasks = conn.execute(
        "SELECT 1 FROM access_request_tasks WHERE request_code = %s LIMIT 1",
        (request_code,),
    ).fetchone()
    if has_tasks:
        return provision_tasks_complete(conn, request_code)

    try:
        found = admin_client().find_user(username)
        if not found:
            return False
        uid = found.get("id") if isinstance(found, dict) else None
        if not uid:
            return False
        attrs = admin_client().get_user(str(uid)).get("attributes") or {}
        if not isinstance(attrs, dict):
            return False
        app_password = attrs.get("mailu_app_password")
        if isinstance(app_password, list):
            app_password = app_password[0] if app_password else None
        return isinstance(app_password, str) and bool(app_password)
    except Exception:
        return False
|
|
|
|
|
|
def _advance_status(conn, request_code: str, username: str, status: str) -> str:
    """Apply any automatic status transition that is now due; return the resulting status.

    * ``accounts_building`` -> ``awaiting_onboarding`` once ``_automation_ready``.
    * ``awaiting_onboarding`` -> ``ready`` once every required onboarding step
      is complete. Grandfathered Vaultwarden users whose Keycloak
      ``vaultwarden_status`` is ``"grandfathered"`` must also have completed
      ``vaultwarden_store_temp_password``.

    Each UPDATE is conditioned on the current status, so it cannot overwrite a
    row that has already moved on. Any other status is returned normalized
    and unchanged.
    """
    status = _normalize_status(status)

    if status == "accounts_building" and _automation_ready(conn, request_code, username):
        conn.execute(
            "UPDATE access_requests SET status = 'awaiting_onboarding' WHERE request_code = %s AND status = 'accounts_building'",
            (request_code,),
        )
        return "awaiting_onboarding"

    if status == "awaiting_onboarding":
        completed = _completed_onboarding_steps(conn, request_code, username)
        required_steps = set(ONBOARDING_REQUIRED_STEPS)
        grandfathered, _ = _vaultwarden_grandfathered(conn, request_code, username)
        # Only grandfathered users can need the extra step, so skip the Keycloak
        # lookup for everyone else; its result would be discarded anyway.
        if grandfathered and _vaultwarden_status_for_user(username) == "grandfathered":
            required_steps.add("vaultwarden_store_temp_password")
        if required_steps.issubset(completed):
            conn.execute(
                "UPDATE access_requests SET status = 'ready' WHERE request_code = %s AND status = 'awaiting_onboarding'",
                (request_code,),
            )
            return "ready"

    return status
|
|
|
|
|
|
def _onboarding_payload(conn, request_code: str, username: str) -> dict[str, Any]:
    """Build the onboarding progress payload returned to the frontend.

    Keys:
        required_steps: ``ONBOARDING_REQUIRED_STEPS``, plus
            ``vaultwarden_store_temp_password`` when ``vaultwarden.matched``.
        optional_steps: ``ONBOARDING_OPTIONAL_STEPS``, sorted.
        completed_steps: manually recorded and Keycloak-inferred steps, sorted.
        keycloak.password_rotation_requested: whether rotation was requested.
        vaultwarden.grandfathered: set by request flag or Keycloak group membership.
        vaultwarden.recovery_email: recovery address, only when grandfathered.
        vaultwarden.matched: grandfathered and the Keycloak
            ``vaultwarden_status`` is ``"grandfathered"``.
    """
    completed_steps = sorted(_completed_onboarding_steps(conn, request_code, username))
    password_rotation_requested = _password_rotation_requested(conn, request_code)
    grandfathered, contact_email = _vaultwarden_grandfathered(conn, request_code, username)
    recovery_email = _resolve_recovery_email(username, contact_email) if grandfathered else ""
    # The Vaultwarden status only matters for grandfathered users, so skip the
    # Keycloak lookup for everyone else.
    vaultwarden_matched = grandfathered and _vaultwarden_status_for_user(username) == "grandfathered"

    required_steps = list(ONBOARDING_REQUIRED_STEPS)
    if vaultwarden_matched:
        required_steps.append("vaultwarden_store_temp_password")

    return {
        "required_steps": required_steps,
        "optional_steps": sorted(ONBOARDING_OPTIONAL_STEPS),
        "completed_steps": completed_steps,
        "keycloak": {
            "password_rotation_requested": password_rotation_requested,
        },
        "vaultwarden": {
            "grandfathered": grandfathered,
            "recovery_email": recovery_email,
            "matched": vaultwarden_matched,
        },
    }
|
|
|
|
|
|
# Keep the historical access_requests module patch surface intact for tests and
# callers while the route handlers live in smaller focused modules.
# Every module-level name not starting with "__" is exported. That includes the
# "_"-prefixed helpers and imported names such as admin_client and settings, so
# a star import of this module re-exposes all of them.
# NOTE: this is evaluated here, so only names bound above this line are included.
__all__ = [name for name in globals() if not name.startswith("__")]
|