diff --git a/backend/tests/test_access_request_submission.py b/backend/tests/test_access_request_submission.py new file mode 100644 index 0000000..71edc1b --- /dev/null +++ b/backend/tests/test_access_request_submission.py @@ -0,0 +1,386 @@ +from __future__ import annotations + +from contextlib import contextmanager +from types import SimpleNamespace +from typing import Any + +from flask import Flask, request +import psycopg + +from atlas_portal.routes.access_request_submission import register_access_request_submission + + +class DummyResult: + def __init__(self, row: dict[str, Any] | None = None) -> None: + self.row = row + + def fetchone(self) -> dict[str, Any] | None: + return self.row + + +class DummyConn: + def __init__(self, rows_by_query: dict[str, dict[str, Any] | None] | None = None) -> None: + self.rows_by_query = rows_by_query or {} + self.executed: list[tuple[str, object | None]] = [] + self.raise_unique_on_insert = False + self.raise_on_any = False + self.rolled_back = False + + def execute(self, query: str, params: object | None = None) -> DummyResult: + self.executed.append((query, params)) + if self.raise_on_any: + raise RuntimeError("database failed") + if self.raise_unique_on_insert and "INSERT INTO access_requests" in query: + raise psycopg.errors.UniqueViolation("duplicate") + for key, row in self.rows_by_query.items(): + if key in query: + return DummyResult(row) + return DummyResult() + + def rollback(self) -> None: + self.rolled_back = True + + +class UniqueRaceConn(DummyConn): + def __init__(self, row_after_rollback: dict[str, Any] | None) -> None: + super().__init__() + self.row_after_rollback = row_after_rollback + + def execute(self, query: str, params: object | None = None) -> DummyResult: + self.executed.append((query, params)) + if "INSERT INTO access_requests" in query: + raise psycopg.errors.UniqueViolation("duplicate") + if "SELECT request_code, status" in query: + return DummyResult(self.row_after_rollback if self.rolled_back else None) + return DummyResult() + + +class DummyAdmin: + def __init__( + self, + *, + ready: bool = False, + user: dict[str, Any] | None = None, + email_user: dict[str, Any] | None = None, + ) -> None: + self._ready = ready + self.user = user + self.email_user = email_user + + def ready(self) -> bool: + return self._ready + + def find_user(self, username: str) -> dict[str, Any] | None: + return self.user + + def find_user_by_email(self, email: str) -> dict[str, Any] | None: + return self.email_user + + +class MailerError(Exception): + pass + + +class VerificationError(Exception): + def __init__(self, status_code: int, message: str) -> None: + super().__init__(message) + self.status_code = status_code + self.message = message + + +class DummyDeps: + EMAIL_VERIFY_PENDING_STATUS = "pending_email_verification" + MailerError = MailerError + VerificationError = VerificationError + + def __init__(self, conn: DummyConn | None = None) -> None: + self.settings = SimpleNamespace( + ACCESS_REQUEST_ENABLED=True, + ACCESS_REQUEST_SUBMIT_RATE_LIMIT=5, + ACCESS_REQUEST_SUBMIT_RATE_WINDOW_SEC=60, + MAILU_DOMAIN="bstein.dev", + ACCESS_REQUEST_INTERNAL_EMAIL_ALLOWLIST={"allowed@bstein.dev"}, + ) + self.conn = conn or DummyConn() + self.configured_value = True + self.admin = DummyAdmin() + self.rate_limit_results: list[bool] = [] + self.sent: list[tuple[str, str, str]] = [] + self.fail_connect = False + self.fail_send = False + self.verify_status = "pending" + self.verify_error: VerificationError | None = None + self.verify_runtime_error = False + + def configured(self) -> bool: + return self.configured_value + + def admin_client(self) -> DummyAdmin: + return self.admin + + def _client_ip(self) -> str: + return "203.0.113.30" + + def rate_limit_allow(self, *args, **kwargs) -> bool: + if self.rate_limit_results: + return self.rate_limit_results.pop(0) + return True + + @contextmanager + def connect(self): + if self.fail_connect: + raise RuntimeError("database offline") + yield self.conn + + def _extract_request_payload(self) -> tuple[str, str, str, str, str]: + payload = request.get_json(silent=True) or {} + return ( + (payload.get("username") or "").strip(), + (payload.get("email") or "").strip(), + (payload.get("note") or "").strip(), + (payload.get("first_name") or "").strip(), + (payload.get("last_name") or "").strip(), + ) + + def _normalize_name(self, value: str) -> str: + return " ".join(value.strip().split()) + + def _validate_username(self, username: str) -> str | None: + return None if username and username != "bad" else "username is required" + + def _validate_name(self, value: str, *, label: str, required: bool) -> str | None: + if value == "bad": + return f"{label} is invalid" + if required and not value: + return f"{label} is required" + return None + + def _normalize_status(self, status: str) -> str: + return "accounts_building" if status == "approved" else (status or "unknown") + + def _random_request_code(self, username: str) -> str: + return f"{username}~CODE" + + def _hash_verification_token(self, token: str) -> str: + return f"hash:{token}" + + def _send_verification_email(self, *, request_code: str, email: str, token: str) -> None: + if self.fail_send: + raise self.MailerError("send failed") + self.sent.append((request_code, email, token)) + + def _verify_request(self, conn: DummyConn, code: str, token: str) -> str: + if self.verify_runtime_error: + raise RuntimeError("verify failed") + if self.verify_error: + raise self.verify_error + return self.verify_status + + +def make_client(deps: DummyDeps): + app = Flask(__name__) + register_access_request_submission(app, deps) + return app.test_client() + + +def request_payload(**overrides: str) -> dict[str, str]: + payload = { + "username": "alice", + "email": "alice@example.dev", + "first_name": "Alice", + "last_name": "Atlas", + "note": "please", + } + payload.update(overrides) + return payload + + +def test_availability_preflight_existing_and_available_paths() -> None: + deps = DummyDeps() + client = make_client(deps) + + deps.settings.ACCESS_REQUEST_ENABLED = False + assert client.get("/api/access/request/availability?username=alice").status_code == 503 + deps.settings.ACCESS_REQUEST_ENABLED = True + + deps.configured_value = False + assert client.get("/api/access/request/availability?username=alice").status_code == 503 + deps.configured_value = True + + assert client.get("/api/access/request/availability?username=bad").get_json()["reason"] == "invalid" + + deps.admin = DummyAdmin(ready=True, user={"id": "user-1"}) + assert client.get("/api/access/request/availability?username=alice").get_json()["reason"] == "exists" + deps.admin = DummyAdmin() + + deps.conn = DummyConn({"SELECT status": {"status": "approved"}}) + data = client.get("/api/access/request/availability?username=alice").get_json() + assert data == {"available": False, "reason": "requested", "status": "accounts_building"} + + deps.conn = DummyConn() + assert client.get("/api/access/request/availability?username=alice").get_json() == {"available": True} + + deps.fail_connect = True + assert client.get("/api/access/request/availability?username=alice").status_code == 502 + + +def test_submit_preflight_validation_and_admin_conflicts() -> None: + deps = DummyDeps() + client = make_client(deps) + + deps.settings.ACCESS_REQUEST_ENABLED = False + assert client.post("/api/access/request", json=request_payload()).status_code == 503 + deps.settings.ACCESS_REQUEST_ENABLED = True + + deps.configured_value = False + assert client.post("/api/access/request", json=request_payload()).status_code == 503 + deps.configured_value = True + + deps.rate_limit_results = [False] + assert client.post("/api/access/request", json=request_payload()).status_code == 429 + assert client.post("/api/access/request", json=request_payload(username="bad")).status_code == 400 + assert client.post("/api/access/request", json=request_payload(first_name="bad")).status_code == 400 + assert client.post("/api/access/request", json=request_payload(last_name="")).status_code == 400 + assert client.post("/api/access/request", json=request_payload(email="")).status_code == 400 + assert client.post("/api/access/request", json=request_payload(email="not-email")).status_code == 400 + assert client.post("/api/access/request", json=request_payload(email="new@bstein.dev")).status_code == 400 + + deps.admin = DummyAdmin(ready=True, user={"id": "user-1"}) + assert client.post("/api/access/request", json=request_payload()).status_code == 409 + deps.admin = DummyAdmin(ready=True, email_user={"id": "user-2"}) + assert client.post("/api/access/request", json=request_payload()).status_code == 409 + + +def test_submit_existing_pending_new_unique_and_failure_paths() -> None: + pending = DummyConn({"SELECT request_code, status": {"request_code": "alice~OLD", "status": "pending"}}) + deps = DummyDeps(pending) + client = make_client(deps) + assert client.post("/api/access/request", json=request_payload()).get_json() == { + "ok": True, + "request_code": "alice~OLD", + "status": "pending", + } + + existing = DummyConn( + {"SELECT request_code, status": {"request_code": "alice~VERIFY", "status": deps.EMAIL_VERIFY_PENDING_STATUS}} + ) + deps.conn = existing + data = client.post("/api/access/request", json=request_payload()).get_json() + assert data["request_code"] == "alice~VERIFY" + assert deps.sent[-1][0] == "alice~VERIFY" + assert any("UPDATE access_requests" in query for query, _ in existing.executed) + + deps.fail_send = True + assert client.post("/api/access/request", json=request_payload()).status_code == 502 + deps.fail_send = False + + new_conn = DummyConn() + deps.conn = new_conn + data = client.post("/api/access/request", json=request_payload(username="brad")).get_json() + assert data["request_code"] == "brad~CODE" + assert any("INSERT INTO access_requests" in query for query, _ in new_conn.executed) + + deps.fail_send = True + assert client.post("/api/access/request", json=request_payload(username="casey")).status_code == 502 + deps.fail_send = False + + unique = UniqueRaceConn({"request_code": "alice~RACE", "status": "pending"}) + deps.conn = unique + assert client.post("/api/access/request", json=request_payload(username="dana")).get_json()["request_code"] == "alice~RACE" + assert unique.rolled_back is True + + unique_missing = UniqueRaceConn(None) + deps.conn = unique_missing + assert client.post("/api/access/request", json=request_payload(username="erin")).status_code == 502 + + deps.fail_connect = True + assert client.post("/api/access/request", json=request_payload(username="fran")).status_code == 502 + + +def test_verify_and_verify_link_paths() -> None: + deps = DummyDeps() + client = make_client(deps) + + deps.settings.ACCESS_REQUEST_ENABLED = False + assert client.post("/api/access/request/verify", json={"request_code": "code", "token": "tok"}).status_code == 503 + deps.settings.ACCESS_REQUEST_ENABLED = True + + deps.configured_value = False + assert client.post("/api/access/request/verify", json={"request_code": "code", "token": "tok"}).status_code == 503 + deps.configured_value = True + + deps.rate_limit_results = [False] + assert client.post("/api/access/request/verify", json={"request_code": "code", "token": "tok"}).status_code == 429 + assert client.post("/api/access/request/verify", json={"token": "tok"}).status_code == 400 + assert client.post("/api/access/request/verify", json={"request_code": "code"}).status_code == 400 + deps.rate_limit_results = [True, False] + assert client.post("/api/access/request/verify", json={"request_code": "code", "verify": "tok"}).status_code == 429 + + assert client.post("/api/access/request/verify", json={"code": "code", "token": "tok"}).get_json() == { + "ok": True, + "status": "pending", + } + deps.verify_error = VerificationError(410, "expired") + assert client.post("/api/access/request/verify", json={"code": "code", "token": "tok"}).status_code == 410 + deps.verify_error = None + deps.verify_runtime_error = True + assert client.post("/api/access/request/verify", json={"code": "code", "token": "tok"}).status_code == 502 + deps.verify_runtime_error = False + + assert client.get("/api/access/request/verify-link").headers["Location"].endswith("verify_error=missing+token") + assert "verified=1" in client.get("/api/access/request/verify-link?code=code&token=tok").headers["Location"] + deps.verify_error = VerificationError(401, "bad token") + assert "bad%20token" in client.get("/api/access/request/verify-link?code=code&token=tok").headers["Location"] + deps.verify_error = None + deps.verify_runtime_error = True + assert "failed+to+verify" in client.get("/api/access/request/verify-link?code=code&token=tok").headers["Location"] + deps.verify_runtime_error = False + + deps.settings.ACCESS_REQUEST_ENABLED = False + assert client.get("/api/access/request/verify-link?code=code&token=tok").status_code == 503 + deps.settings.ACCESS_REQUEST_ENABLED = True + + deps.configured_value = False + assert client.get("/api/access/request/verify-link?code=code&token=tok").status_code == 503 + + +def test_resend_preflight_success_and_failure_paths() -> None: + deps = DummyDeps() + client = make_client(deps) + + deps.settings.ACCESS_REQUEST_ENABLED = False + assert client.post("/api/access/request/resend", json={"request_code": "code"}).status_code == 503 + deps.settings.ACCESS_REQUEST_ENABLED = True + + deps.configured_value = False + assert client.post("/api/access/request/resend", json={"request_code": "code"}).status_code == 503 + deps.configured_value = True + + deps.rate_limit_results = [False] + assert client.post("/api/access/request/resend", json={"request_code": "code"}).status_code == 429 + assert client.post("/api/access/request/resend", json={}).status_code == 400 + deps.rate_limit_results = [True, False] + assert client.post("/api/access/request/resend", json={"request_code": "code"}).status_code == 429 + + deps.conn = DummyConn({"SELECT status, contact_email": None}) + assert client.post("/api/access/request/resend", json={"request_code": "missing"}).status_code == 404 + + deps.conn = DummyConn({"SELECT status, contact_email": {"status": "approved", "contact_email": "alice@example.dev"}}) + assert client.post("/api/access/request/resend", json={"request_code": "code"}).get_json()["status"] == "accounts_building" + + deps.conn = DummyConn({"SELECT status, contact_email": {"status": deps.EMAIL_VERIFY_PENDING_STATUS, "contact_email": ""}}) + assert client.post("/api/access/request/resend", json={"request_code": "code"}).status_code == 409 + + success_conn = DummyConn( + {"SELECT status, contact_email": {"status": deps.EMAIL_VERIFY_PENDING_STATUS, "contact_email": "alice@example.dev"}} + ) + deps.conn = success_conn + assert client.post("/api/access/request/resend", json={"code": "code"}).get_json()["ok"] is True + assert any("UPDATE access_requests" in query for query, _ in success_conn.executed) + + deps.fail_send = True + assert client.post("/api/access/request/resend", json={"request_code": "code"}).status_code == 502 + deps.fail_send = False + + deps.fail_connect = True + assert client.post("/api/access/request/resend", json={"request_code": "code"}).status_code == 502