182 lines
6.1 KiB
Python
182 lines
6.1 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import json
|
||
|
|
from contextlib import contextmanager
|
||
|
|
from unittest import TestCase, mock
|
||
|
|
|
||
|
|
from atlas_portal.app_factory import create_app
|
||
|
|
from atlas_portal.routes import access_requests as ar
|
||
|
|
|
||
|
|
|
||
|
|
class DummyResult:
    """Minimal stand-in for a database cursor result.

    ``fetchone`` hands back the single row supplied at construction (``None``
    when no row was given); ``fetchall`` always reports an empty result set.
    """

    def __init__(self, row=None):
        # The one row this fake result will ever return.
        self._row = row

    def fetchone(self):
        """Return the stored row, or ``None`` when none was supplied."""
        stored = self._row
        return stored

    def fetchall(self):
        """Return an empty list; this fake never produces multiple rows."""
        return list()
|
||
|
|
|
||
|
|
|
||
|
|
class DummyConn:
    """Fake database connection that records queries and serves canned rows.

    ``rows_by_query`` maps a query substring to the row returned whenever an
    executed query contains that substring. The first matching entry in the
    dict's iteration order wins. Queries with no matching entry yield a result
    with no row.
    """

    def __init__(self, rows_by_query=None):
        self._rows_by_query = rows_by_query or {}
        # Every (query, params) pair passed to execute(), in call order.
        self.executed = []

    def execute(self, query, params=None):
        """Record the call and return a DummyResult for the first matching key."""
        self.executed.append((query, params))
        candidates = (
            canned_row
            for fragment, canned_row in self._rows_by_query.items()
            if fragment in query
        )
        # DummyResult(None) is the same as DummyResult(), i.e. "no row".
        return DummyResult(next(candidates, None))
|
||
|
|
|
||
|
|
|
||
|
|
class DummyAdmin:
    """Admin-client fake: it is never ready and never finds any user."""

    def ready(self):
        """Report the admin backend as not ready (always ``False``)."""
        return False

    def find_user(self, username):
        """Look up a user by *username*; this fake always returns ``None``."""
        return None

    def find_user_by_email(self, email):
        """Look up a user by *email*; this fake always returns ``None``."""
        return None
|
||
|
|
|
||
|
|
|
||
|
|
@contextmanager
def dummy_connect(rows_by_query=None):
    """Context manager that yields a fresh DummyConn.

    Used to replace the route module's ``connect``. *rows_by_query* is
    forwarded unchanged to ``DummyConn``.
    """
    connection = DummyConn(rows_by_query=rows_by_query)
    yield connection
|
||
|
|
|
||
|
|
|
||
|
|
class AccessRequestTests(TestCase):
    """Route-level tests for the access-request endpoints served via ``ar``.

    Class-level fixture:
        The Flask app is created once per class, with
        ``atlas_portal.app_factory.ensure_schema`` replaced by a no-op.

    Per-test fixture:
        ``ar.configured`` and ``ar.rate_limit_allow`` are forced to return
        ``True``, and ``ar.admin_client`` returns a ``DummyAdmin``.

    Each test additionally patches ``ar.connect``, and where relevant the
    email sender and request-code generator.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.schema_patch = mock.patch("atlas_portal.app_factory.ensure_schema", lambda: None)
        cls.schema_patch.start()
        # Registered before create_app() so the patch is undone even if app
        # creation fails: tearDownClass is not run when setUpClass raises.
        cls.addClassCleanup(cls.schema_patch.stop)
        cls.app = create_app()
        cls.client = cls.app.test_client()

    def setUp(self):
        super().setUp()
        # Each patch is stopped via addCleanup, so a failure while starting a
        # later patch cannot leak the earlier ones (tearDown would be skipped).
        self.configured_patch = self._start_patch(
            mock.patch.object(ar, "configured", lambda: True)
        )
        self.rate_patch = self._start_patch(
            mock.patch.object(ar, "rate_limit_allow", lambda *args, **kwargs: True)
        )
        self.admin_patch = self._start_patch(
            mock.patch.object(ar, "admin_client", lambda: DummyAdmin())
        )

    def _start_patch(self, patcher):
        """Start *patcher*, schedule its stop as a test cleanup, and return it."""
        patcher.start()
        self.addCleanup(patcher.stop)
        return patcher

    @staticmethod
    def _request_payload(**overrides):
        """Return a valid access-request body for "alice", with *overrides* applied."""
        payload = {
            "username": "alice",
            "email": "alice@example.com",
            "first_name": "Alice",
            "last_name": "Atlas",
            "note": "",
        }
        payload.update(overrides)
        return payload

    @staticmethod
    def _recording_sender(sent):
        """Return a fake ``_send_verification_email`` that records into *sent*."""

        def fake_send_email(*, request_code, email, token):
            sent["request_code"] = request_code
            sent["email"] = email

        return fake_send_email

    def _post_json(self, path, payload):
        """POST *payload* as JSON to *path*; return ``(response, decoded_json)``."""
        resp = self.client.post(
            path,
            data=json.dumps(payload),
            content_type="application/json",
        )
        data = resp.get_json()
        # get_json() returns None for a non-JSON body. Fail with a readable
        # message instead of an AttributeError on data.get(...) later.
        self.assertIsNotNone(
            data, f"expected a JSON body, got {resp.status_code}: {resp.data!r}"
        )
        return resp, data

    def test_request_access_requires_last_name(self):
        with (
            mock.patch.object(ar, "_send_verification_email") as send_email,
            mock.patch.object(ar, "connect", lambda: dummy_connect()),
        ):
            resp, data = self._post_json(
                "/api/access/request", self._request_payload(last_name="")
            )
        self.assertEqual(resp.status_code, 400)
        self.assertIn("last name is required", data.get("error", ""))
        # A request rejected by validation must not trigger a verification email.
        send_email.assert_not_called()

    def test_request_access_sends_verification_email(self):
        sent = {}
        with (
            mock.patch.object(ar, "_random_request_code", lambda username: f"{username}~CODE123"),
            mock.patch.object(ar, "_send_verification_email", self._recording_sender(sent)),
            mock.patch.object(ar, "connect", lambda: dummy_connect()),
        ):
            resp, data = self._post_json("/api/access/request", self._request_payload())
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(data.get("request_code"), "alice~CODE123")
        self.assertEqual(data.get("status"), "pending_email_verification")
        self.assertEqual(sent.get("request_code"), "alice~CODE123")
        self.assertEqual(sent.get("email"), "alice@example.com")

    def test_request_access_email_failure_returns_request_code(self):
        def fake_send_email(*, request_code, email, token):
            raise ar.MailerError("failed")

        with (
            mock.patch.object(ar, "_random_request_code", lambda username: f"{username}~CODE123"),
            mock.patch.object(ar, "_send_verification_email", fake_send_email),
            mock.patch.object(ar, "connect", lambda: dummy_connect()),
        ):
            resp, data = self._post_json("/api/access/request", self._request_payload())
        self.assertEqual(resp.status_code, 502)
        self.assertEqual(data.get("request_code"), "alice~CODE123")
        self.assertIn("failed to send verification email", data.get("error", ""))

    def test_request_access_resend_sends_email(self):
        sent = {}
        # Row served for any query containing this fragment (see DummyConn).
        rows = {
            "SELECT status, contact_email": {
                "status": "pending_email_verification",
                "contact_email": "alice@example.com",
            }
        }
        with (
            mock.patch.object(ar, "_send_verification_email", self._recording_sender(sent)),
            mock.patch.object(ar, "connect", lambda: dummy_connect(rows)),
        ):
            resp, data = self._post_json(
                "/api/access/request/resend", {"request_code": "alice~CODE123"}
            )
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(data.get("status"), "pending_email_verification")
        self.assertEqual(sent.get("request_code"), "alice~CODE123")
        self.assertEqual(sent.get("email"), "alice@example.com")
|