327 lines
12 KiB
Python
327 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from contextlib import contextmanager
|
|
from unittest import TestCase, mock
|
|
|
|
from atlas_portal.app_factory import create_app
|
|
from atlas_portal.routes import access_requests as ar
|
|
|
|
|
|
class DummyResult:
    """Minimal fake query result, as returned by ``DummyConn.execute``.

    ``fetchone`` yields the single canned row supplied at construction (or
    ``None``); ``fetchall`` always reports that there are no rows.
    """

    def __init__(self, row=None):
        # The canned row handed back by fetchone(); None means "nothing matched".
        self._row = row

    def fetchone(self):
        """Return the canned row, or ``None`` when none was given."""
        return self._row

    def fetchall(self):
        """Return a fresh empty list; this fake never produces multiple rows."""
        return list()
|
|
|
|
|
|
class DummyConn:
    """Fake database connection that records queries and answers from canned rows.

    Args:
        rows_by_query: Mapping of SQL text fragment -> row. ``execute`` returns
            the row for the first fragment (in the mapping's iteration order)
            that occurs as a substring of the query, or an empty result if no
            fragment matches. ``None`` or an empty mapping means "no rows".
    """

    def __init__(self, rows_by_query=None):
        self._rows_by_query = rows_by_query if rows_by_query else {}
        # Every (query, params) pair passed to execute(), in call order; tests inspect it.
        self.executed = []

    def execute(self, query, params=None):
        """Record the call and return a ``DummyResult`` for the first matching fragment."""
        self.executed.append((query, params))
        matched_row = next(
            (row for fragment, row in self._rows_by_query.items() if fragment in query),
            None,
        )
        return DummyResult(matched_row)
|
|
|
|
|
|
class DummyAdmin:
    """Fake admin client that reports itself not ready and never finds any user."""

    def ready(self):
        """Report that the admin backend is unavailable."""
        return False

    def find_user(self, username):
        """Look up a user by ``username``; this fake knows no users."""
        return None

    def find_user_by_email(self, email):
        """Look up a user by ``email``; this fake knows no users."""
        return None
|
|
|
|
|
|
@contextmanager
def dummy_connect(rows_by_query=None):
    """Stand-in for the module's ``connect()`` context manager.

    Yields a new ``DummyConn`` built from ``rows_by_query``; nothing is closed
    or committed on exit.

    Args:
        rows_by_query: Canned rows forwarded unchanged to ``DummyConn``.
    """
    conn = DummyConn(rows_by_query=rows_by_query)
    yield conn
|
|
|
|
|
|
class AccessRequestTests(TestCase):
    """Tests for the access-request routes and helpers in ``atlas_portal.routes.access_requests``.

    ``setUp`` patches ``ar.configured`` to return True, ``ar.rate_limit_allow`` to
    always allow, and ``ar.admin_client`` to return a ``DummyAdmin``. Individual
    tests then patch ``ar.connect`` (and other helpers as needed) so queries are
    answered from canned rows by a ``DummyConn``. ``DummyConn`` matches rows by
    substring, so a key such as ``"SELECT status"`` answers every query that
    contains that text.
    """

    @classmethod
    def setUpClass(cls):
        # One app instance and test client are shared by every test in the class.
        cls.app = create_app()
        cls.client = cls.app.test_client()

    def setUp(self):
        self.configured_patch = mock.patch.object(ar, "configured", lambda: True)
        self.rate_patch = mock.patch.object(ar, "rate_limit_allow", lambda *args, **kwargs: True)
        self.admin_patch = mock.patch.object(ar, "admin_client", lambda: DummyAdmin())
        # Register stop() as a cleanup immediately after each start(). unittest
        # runs cleanups even when setUp fails part-way (tearDown is skipped in
        # that case), so a started patch can never leak into later tests.
        for patcher in (self.configured_patch, self.rate_patch, self.admin_patch):
            patcher.start()
            self.addCleanup(patcher.stop)

    # ------------------------------------------------------------------ helpers

    def _post_json(self, path, payload):
        """POST ``payload`` to ``path`` as a JSON body and return the response."""
        return self.client.post(
            path,
            data=json.dumps(payload),
            content_type="application/json",
        )

    @staticmethod
    def _access_request_payload(**overrides):
        """Return a complete access-request body for "alice", with ``overrides`` applied."""
        payload = {
            "username": "alice",
            "email": "alice@example.com",
            "first_name": "Alice",
            "last_name": "Atlas",
            "note": "",
        }
        payload.update(overrides)
        return payload

    @staticmethod
    def _verification_rows(token):
        """Canned rows for a request pending email verification whose stored hash matches ``token``."""
        return {
            "SELECT status, email_verification_token_hash": {
                "status": "pending_email_verification",
                "email_verification_token_hash": ar._hash_verification_token(token),
                "email_verification_sent_at": ar.datetime.now(ar.timezone.utc),
            }
        }

    @staticmethod
    def _awaiting_onboarding_rows():
        """Canned status row for an ``awaiting_onboarding`` request holding an unrevealed initial password."""
        return {
            "SELECT status": {
                "status": "awaiting_onboarding",
                "username": "alice",
                "initial_password": "temp-pass",
                "initial_password_revealed_at": None,
                "email_verified_at": None,
            }
        }

    # --------------------------------------------------------- request creation

    def test_request_access_requires_last_name(self):
        with mock.patch.object(ar, "connect", lambda: dummy_connect()):
            resp = self._post_json("/api/access/request", self._access_request_payload(last_name=""))
        data = resp.get_json()
        self.assertEqual(resp.status_code, 400)
        self.assertIn("last name is required", data.get("error", ""))

    def test_request_access_sends_verification_email(self):
        sent = {}

        def fake_send_email(*, request_code, email, token):
            sent["request_code"] = request_code
            sent["email"] = email

        with (
            mock.patch.object(ar, "_random_request_code", lambda username: f"{username}~CODE123"),
            mock.patch.object(ar, "_send_verification_email", fake_send_email),
            mock.patch.object(ar, "connect", lambda: dummy_connect()),
        ):
            resp = self._post_json("/api/access/request", self._access_request_payload())
        data = resp.get_json()
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(data.get("request_code"), "alice~CODE123")
        self.assertEqual(data.get("status"), "pending_email_verification")
        self.assertEqual(sent.get("request_code"), "alice~CODE123")
        self.assertEqual(sent.get("email"), "alice@example.com")

    def test_request_access_email_failure_returns_request_code(self):
        def fake_send_email(*, request_code, email, token):
            raise ar.MailerError("failed")

        with (
            mock.patch.object(ar, "_random_request_code", lambda username: f"{username}~CODE123"),
            mock.patch.object(ar, "_send_verification_email", fake_send_email),
            mock.patch.object(ar, "connect", lambda: dummy_connect()),
        ):
            resp = self._post_json("/api/access/request", self._access_request_payload())
        data = resp.get_json()
        # The request code is still returned so the caller can retry or check status.
        self.assertEqual(resp.status_code, 502)
        self.assertEqual(data.get("request_code"), "alice~CODE123")
        self.assertIn("failed to send verification email", data.get("error", ""))

    def test_request_access_resend_sends_email(self):
        sent = {}

        def fake_send_email(*, request_code, email, token):
            sent["request_code"] = request_code
            sent["email"] = email

        rows = {
            "SELECT status, contact_email": {
                "status": "pending_email_verification",
                "contact_email": "alice@example.com",
            }
        }

        with (
            mock.patch.object(ar, "_send_verification_email", fake_send_email),
            mock.patch.object(ar, "connect", lambda: dummy_connect(rows)),
        ):
            resp = self._post_json("/api/access/request/resend", {"request_code": "alice~CODE123"})
        data = resp.get_json()
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(data.get("status"), "pending_email_verification")
        self.assertEqual(sent.get("request_code"), "alice~CODE123")
        self.assertEqual(sent.get("email"), "alice@example.com")

    # ------------------------------------------------------- email verification

    def test_verify_request_updates_status(self):
        token = "tok-123"
        with dummy_connect(self._verification_rows(token)) as conn:
            status = ar._verify_request(conn, "alice~CODE123", token)
        self.assertEqual(status, "pending")

    def test_verify_link_redirects(self):
        token = "tok-123"
        rows = self._verification_rows(token)
        with mock.patch.object(ar, "connect", lambda: dummy_connect(rows)):
            resp = self.client.get(f"/api/access/request/verify-link?code=alice~CODE123&token={token}")
        self.assertEqual(resp.status_code, 302)
        self.assertIn("verified=1", resp.headers.get("Location", ""))

    # ------------------------------------------------------------- status checks

    def test_status_includes_email_verified(self):
        rows = {
            "SELECT status": {
                "status": "pending",
                "username": "alice",
                "initial_password": None,
                "initial_password_revealed_at": None,
                "email_verified_at": ar.datetime.now(ar.timezone.utc),
            }
        }
        with mock.patch.object(ar, "connect", lambda: dummy_connect(rows)):
            resp = self._post_json("/api/access/request/status", {"request_code": "alice~CODE123"})
        data = resp.get_json()
        self.assertEqual(resp.status_code, 200)
        self.assertTrue(data.get("email_verified"))

    def test_status_hides_initial_password_without_reveal_flag(self):
        rows = self._awaiting_onboarding_rows()
        with (
            mock.patch.object(ar, "connect", lambda: dummy_connect(rows)),
            mock.patch.object(ar, "_advance_status", lambda *args, **kwargs: "awaiting_onboarding"),
        ):
            resp = self._post_json("/api/access/request/status", {"request_code": "alice~CODE123"})
        data = resp.get_json()
        self.assertEqual(resp.status_code, 200)
        self.assertIsNone(data.get("initial_password"))

    def test_status_reveals_initial_password_with_flag(self):
        rows = self._awaiting_onboarding_rows()
        with (
            mock.patch.object(ar, "connect", lambda: dummy_connect(rows)),
            mock.patch.object(ar, "_advance_status", lambda *args, **kwargs: "awaiting_onboarding"),
        ):
            resp = self._post_json(
                "/api/access/request/status",
                {"request_code": "alice~CODE123", "reveal_initial_password": True},
            )
        data = resp.get_json()
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(data.get("initial_password"), "temp-pass")

    # ---------------------------------------------------------------- onboarding

    def test_onboarding_payload_includes_vaultwarden_grandfathered(self):
        rows = {
            "SELECT approval_flags": {
                "approval_flags": ["vaultwarden_grandfathered"],
                "contact_email": "alice@example.com",
            }
        }
        conn = DummyConn(rows_by_query=rows)
        with (
            mock.patch.object(ar, "_completed_onboarding_steps", lambda *args, **kwargs: set()),
            mock.patch.object(ar, "_password_rotation_requested", lambda *args, **kwargs: False),
        ):
            payload = ar._onboarding_payload(conn, "alice~CODE123", "alice")
        vault = payload.get("vaultwarden") or {}
        self.assertTrue(vault.get("grandfathered"))
        self.assertEqual(vault.get("recovery_email"), "alice@example.com")

    # --------------------------------------------------------------------- retry

    def test_retry_request_fallback_updates_tasks(self):
        rows = {"SELECT status": {"status": "accounts_building"}}
        # Keep a handle on the connection so the executed SQL can be inspected afterwards.
        conn = DummyConn(rows_by_query=rows)

        @contextmanager
        def connect_override():
            yield conn

        # With ariadne_client.enabled() patched to False the retry endpoint is
        # expected to take its local path and record provision_attempted_at.
        with (
            mock.patch.object(ar.ariadne_client, "enabled", lambda: False),
            mock.patch.object(ar, "connect", lambda: connect_override()),
            mock.patch.object(ar, "provision_access_request", lambda *_args, **_kwargs: None),
        ):
            resp = self._post_json("/api/access/request/retry", {"request_code": "alice~CODE123"})
        data = resp.get_json()
        self.assertEqual(resp.status_code, 200)
        self.assertTrue(data.get("ok"))
        self.assertTrue(any("provision_attempted_at" in query for query, _params in conn.executed))

    def test_retry_request_rejects_non_retryable(self):
        rows = {"SELECT status": {"status": "ready"}}
        with (
            mock.patch.object(ar.ariadne_client, "enabled", lambda: False),
            mock.patch.object(ar, "connect", lambda: dummy_connect(rows)),
        ):
            resp = self._post_json("/api/access/request/retry", {"request_code": "alice~CODE123"})
        self.assertEqual(resp.status_code, 409)
|