"""Tests for the access-request routes in ``atlas_portal.routes.access_requests``.

The database connection, admin client, rate limiter and ``configured`` check
are replaced with in-process fakes; individual tests additionally patch the
verification mailer and request-code generator where needed.
"""

from __future__ import annotations

import json
from contextlib import contextmanager
from unittest import TestCase, mock

from atlas_portal.app_factory import create_app
from atlas_portal.routes import access_requests as ar

REQUEST_CODE = "alice~CODE123"


class DummyResult:
    """Stand-in for a query result that holds at most one row."""

    def __init__(self, row=None):
        self._row = row

    def fetchone(self):
        return self._row

    def fetchall(self):
        # Multi-row results are not needed by these tests.
        return []


class DummyConn:
    """Fake DB connection that records executed queries and returns canned rows.

    ``rows_by_query`` maps a SQL substring to the row ``fetchone()`` should
    return. The first key (in dict insertion order) contained in the executed
    query wins; unmatched queries produce an empty result.
    """

    def __init__(self, rows_by_query=None):
        self._rows_by_query = rows_by_query or {}
        self.executed = []

    def execute(self, query, params=None):
        self.executed.append((query, params))
        for key, row in self._rows_by_query.items():
            if key in query:
                return DummyResult(row)
        return DummyResult()


class DummyAdmin:
    """Admin client stub: reports not ready and never finds an existing user."""

    def ready(self):
        return False

    def find_user(self, username):
        return None

    def find_user_by_email(self, email):
        return None


@contextmanager
def dummy_connect(rows_by_query=None):
    """Yield a fresh :class:`DummyConn`; used as a replacement for ``ar.connect``."""
    yield DummyConn(rows_by_query=rows_by_query)


def _access_request_payload(**overrides):
    """Return a valid access-request body for ``alice``, with *overrides* applied.

    ``dict.update`` keeps the original key positions, so the serialized JSON
    matches a hand-written literal with the same values.
    """
    payload = {
        "username": "alice",
        "email": "alice@example.com",
        "first_name": "Alice",
        "last_name": "Atlas",
        "note": "",
    }
    payload.update(overrides)
    return payload


def _recording_mailer(sent):
    """Return a fake ``_send_verification_email`` that records its arguments into *sent*."""

    def fake_send_email(*, request_code, email, token):
        sent["request_code"] = request_code
        sent["email"] = email

    return fake_send_email


def _pending_verification_rows(token):
    """Canned row for a request awaiting email verification with *token*, sent just now."""
    return {
        "SELECT status, email_verification_token_hash": {
            "status": "pending_email_verification",
            "email_verification_token_hash": ar._hash_verification_token(token),
            "email_verification_sent_at": ar.datetime.now(ar.timezone.utc),
        }
    }


def _awaiting_onboarding_rows():
    """Canned status row for a request holding an unrevealed initial password."""
    return {
        "SELECT status": {
            "status": "awaiting_onboarding",
            "username": "alice",
            "initial_password": "temp-pass",
            "initial_password_revealed_at": None,
            "email_verified_at": None,
        }
    }


class AccessRequestTests(TestCase):
    """End-to-end route tests driven through the Flask test client."""

    @classmethod
    def setUpClass(cls):
        cls.schema_patch = mock.patch("atlas_portal.app_factory.ensure_schema", lambda: None)
        cls.schema_patch.start()
        # Registered before create_app(): tearDownClass is skipped when
        # setUpClass raises, but class cleanups still run, so the patch
        # cannot leak into other test modules.
        cls.addClassCleanup(cls.schema_patch.stop)
        cls.app = create_app()
        cls.client = cls.app.test_client()

    def setUp(self):
        self.configured_patch = self._start_patch("configured", lambda: True)
        self.rate_patch = self._start_patch("rate_limit_allow", lambda *args, **kwargs: True)
        self.admin_patch = self._start_patch("admin_client", lambda: DummyAdmin())

    def _start_patch(self, attribute, new):
        """Patch ``ar.<attribute>`` with *new* for this test and schedule its undo.

        Using addCleanup (rather than tearDown) guarantees already-started
        patches are stopped even if a later start in setUp raises.
        """
        patcher = mock.patch.object(ar, attribute, new)
        patcher.start()
        self.addCleanup(patcher.stop)
        return patcher

    def _post_json(self, path, payload):
        """POST *payload* as a JSON body to *path* and return the response."""
        return self.client.post(path, data=json.dumps(payload), content_type="application/json")

    def test_request_access_requires_last_name(self):
        """An empty last name is rejected with 400."""
        with mock.patch.object(ar, "connect", lambda: dummy_connect()):
            resp = self._post_json("/api/access/request", _access_request_payload(last_name=""))
        data = resp.get_json()
        self.assertEqual(resp.status_code, 400)
        self.assertIn("last name is required", data.get("error", ""))

    def test_request_access_sends_verification_email(self):
        """A valid request returns its code and triggers a verification email."""
        sent = {}
        with (
            mock.patch.object(ar, "_random_request_code", lambda username: f"{username}~CODE123"),
            mock.patch.object(ar, "_send_verification_email", _recording_mailer(sent)),
            mock.patch.object(ar, "connect", lambda: dummy_connect()),
        ):
            resp = self._post_json("/api/access/request", _access_request_payload())
        data = resp.get_json()
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(data.get("request_code"), REQUEST_CODE)
        self.assertEqual(data.get("status"), "pending_email_verification")
        self.assertEqual(sent.get("request_code"), REQUEST_CODE)
        self.assertEqual(sent.get("email"), "alice@example.com")

    def test_request_access_email_failure_returns_request_code(self):
        """A mailer failure yields 502 but still returns the request code."""

        def failing_send_email(*, request_code, email, token):
            raise ar.MailerError("failed")

        with (
            mock.patch.object(ar, "_random_request_code", lambda username: f"{username}~CODE123"),
            mock.patch.object(ar, "_send_verification_email", failing_send_email),
            mock.patch.object(ar, "connect", lambda: dummy_connect()),
        ):
            resp = self._post_json("/api/access/request", _access_request_payload())
        data = resp.get_json()
        self.assertEqual(resp.status_code, 502)
        self.assertEqual(data.get("request_code"), REQUEST_CODE)
        self.assertIn("failed to send verification email", data.get("error", ""))

    def test_request_access_resend_sends_email(self):
        """Resending for a pending request emails the stored contact address."""
        sent = {}
        rows = {
            "SELECT status, contact_email": {
                "status": "pending_email_verification",
                "contact_email": "alice@example.com",
            }
        }
        with (
            mock.patch.object(ar, "_send_verification_email", _recording_mailer(sent)),
            mock.patch.object(ar, "connect", lambda: dummy_connect(rows)),
        ):
            resp = self._post_json("/api/access/request/resend", {"request_code": REQUEST_CODE})
        data = resp.get_json()
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(data.get("status"), "pending_email_verification")
        self.assertEqual(sent.get("request_code"), REQUEST_CODE)
        self.assertEqual(sent.get("email"), "alice@example.com")

    def test_verify_request_updates_status(self):
        """A matching token moves the request to ``pending``."""
        token = "tok-123"
        with dummy_connect(_pending_verification_rows(token)) as conn:
            status = ar._verify_request(conn, REQUEST_CODE, token)
        self.assertEqual(status, "pending")

    def test_verify_link_redirects(self):
        """The emailed verification link redirects with ``verified=1``."""
        token = "tok-123"
        rows = _pending_verification_rows(token)
        with mock.patch.object(ar, "connect", lambda: dummy_connect(rows)):
            resp = self.client.get(f"/api/access/request/verify-link?code=alice~CODE123&token={token}")
        self.assertEqual(resp.status_code, 302)
        self.assertIn("verified=1", resp.headers.get("Location", ""))

    def test_status_includes_email_verified(self):
        """The status response exposes a truthy ``email_verified`` once verified."""
        rows = {
            "SELECT status": {
                "status": "pending",
                "username": "alice",
                "initial_password": None,
                "initial_password_revealed_at": None,
                "email_verified_at": ar.datetime.now(ar.timezone.utc),
            }
        }
        with mock.patch.object(ar, "connect", lambda: dummy_connect(rows)):
            resp = self._post_json("/api/access/request/status", {"request_code": REQUEST_CODE})
        data = resp.get_json()
        self.assertEqual(resp.status_code, 200)
        self.assertTrue(data.get("email_verified"))

    def test_status_hides_initial_password_without_reveal_flag(self):
        """Without ``reveal_initial_password`` the password is not returned."""
        rows = _awaiting_onboarding_rows()
        with (
            mock.patch.object(ar, "connect", lambda: dummy_connect(rows)),
            mock.patch.object(ar, "_advance_status", lambda *args, **kwargs: "awaiting_onboarding"),
        ):
            resp = self._post_json("/api/access/request/status", {"request_code": REQUEST_CODE})
        data = resp.get_json()
        self.assertEqual(resp.status_code, 200)
        self.assertIsNone(data.get("initial_password"))

    def test_status_reveals_initial_password_with_flag(self):
        """With ``reveal_initial_password`` the stored password is returned."""
        rows = _awaiting_onboarding_rows()
        with (
            mock.patch.object(ar, "connect", lambda: dummy_connect(rows)),
            mock.patch.object(ar, "_advance_status", lambda *args, **kwargs: "awaiting_onboarding"),
        ):
            resp = self._post_json(
                "/api/access/request/status",
                {"request_code": REQUEST_CODE, "reveal_initial_password": True},
            )
        data = resp.get_json()
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(data.get("initial_password"), "temp-pass")

    def test_onboarding_payload_includes_vaultwarden_grandfathered(self):
        """The ``vaultwarden_grandfathered`` flag surfaces in the onboarding payload."""
        rows = {
            "SELECT approval_flags": {
                "approval_flags": ["vaultwarden_grandfathered"],
                "contact_email": "alice@example.com",
            }
        }
        conn = DummyConn(rows_by_query=rows)
        with (
            mock.patch.object(ar, "_completed_onboarding_steps", lambda *args, **kwargs: set()),
            mock.patch.object(ar, "_password_rotation_requested", lambda *args, **kwargs: False),
        ):
            payload = ar._onboarding_payload(conn, REQUEST_CODE, "alice")
        vault = payload.get("vaultwarden") or {}
        self.assertTrue(vault.get("grandfathered"))
        self.assertEqual(vault.get("recovery_email"), "alice@example.com")