# Tests for the access-request routes registered by
# ``register_access_request_submission``. Every route is driven through a
# Flask test client wired to in-memory fakes (``DummyDeps`` and friends); each
# test flips attributes on those fakes to steer the route into a particular
# branch and then checks the status code, JSON body, or redirect location.
from __future__ import annotations

from contextlib import contextmanager
from types import SimpleNamespace
from typing import Any

from flask import Flask, request
import psycopg

from atlas_portal.routes.access_request_submission import register_access_request_submission


class DummyResult:
    """Query-result stub exposing only ``fetchone``."""

    def __init__(self, row: dict[str, Any] | None = None) -> None:
        self.row = row

    def fetchone(self) -> dict[str, Any] | None:
        """Hand back the single canned row (possibly ``None``)."""
        return self.row


class DummyConn:
    """Connection stub that logs every query and answers by substring match.

    ``rows_by_query`` maps a SQL fragment to the row returned whenever that
    fragment appears in an executed query. The first matching fragment (in
    insertion order) wins; an unmatched query yields an empty result.
    """

    def __init__(self, rows_by_query: dict[str, dict[str, Any] | None] | None = None) -> None:
        self.rows_by_query = rows_by_query if rows_by_query else {}
        self.executed: list[tuple[str, object | None]] = []
        # Failure switches a test may flip before hitting a route.
        self.raise_unique_on_insert = False
        self.raise_on_any = False
        self.rolled_back = False

    def execute(self, query: str, params: object | None = None) -> DummyResult:
        """Record ``(query, params)``, then raise a configured failure or return the canned row."""
        self.executed.append((query, params))
        if self.raise_on_any:
            raise RuntimeError("database failed")
        is_insert = "INSERT INTO access_requests" in query
        if is_insert and self.raise_unique_on_insert:
            raise psycopg.errors.UniqueViolation("duplicate")
        hit = next((row for fragment, row in self.rows_by_query.items() if fragment in query), None)
        return DummyResult(hit)

    def rollback(self) -> None:
        """Remember that the caller rolled back."""
        self.rolled_back = True


class UniqueRaceConn(DummyConn):
    """Connection whose INSERT always hits a unique violation.

    The ``SELECT request_code, status`` lookup returns nothing until
    ``rollback`` has been called, after which it returns
    ``row_after_rollback``. This lets a test model another writer winning
    the race.
    """

    def __init__(self, row_after_rollback: dict[str, Any] | None) -> None:
        super().__init__()
        self.row_after_rollback = row_after_rollback

    def execute(self, query: str, params: object | None = None) -> DummyResult:
        """Record the query; fail every INSERT and gate the lookup on ``rolled_back``."""
        self.executed.append((query, params))
        if "INSERT INTO access_requests" in query:
            raise psycopg.errors.UniqueViolation("duplicate")
        if "SELECT request_code, status" not in query:
            return DummyResult()
        visible = self.row_after_rollback if self.rolled_back else None
        return DummyResult(visible)


class DummyAdmin:
    """Admin-client stub with a fixed readiness flag and canned user lookups."""

    def __init__(
        self,
        *,
        ready: bool = False,
        user: dict[str, Any] | None = None,
        email_user: dict[str, Any] | None = None,
    ) -> None:
        self._ready = ready
        self.user = user
        self.email_user = email_user

    def ready(self) -> bool:
        """Report the readiness flag given at construction."""
        return self._ready

    def find_user(self, username: str) -> dict[str, Any] | None:
        """Return the canned ``user``, ignoring ``username``."""
        return self.user

    def find_user_by_email(self, email: str) -> dict[str, Any] | None:
        """Return the canned ``email_user``, ignoring ``email``."""
        return self.email_user


class MailerError(Exception):
    """Raised by the fake mailer when ``DummyDeps.fail_send`` is set."""


class VerificationError(Exception):
    """Fake verification failure carrying an HTTP status code and message."""

    def __init__(self, status_code: int, message: str) -> None:
        super().__init__(message)
        self.status_code = status_code
        self.message = message


class DummyDeps:
    """Dependency bundle handed to ``register_access_request_submission``.

    Tests flip the public attributes (``settings``, ``conn``, ``admin``,
    ``fail_*``, ``verify_*``, ``rate_limit_results``, ``configured_value``)
    to drive the routes down specific branches.
    """

    EMAIL_VERIFY_PENDING_STATUS = "pending_email_verification"
    MailerError = MailerError
    VerificationError = VerificationError

    def __init__(self, conn: DummyConn | None = None) -> None:
        self.settings = SimpleNamespace(
            ACCESS_REQUEST_ENABLED=True,
            ACCESS_REQUEST_SUBMIT_RATE_LIMIT=5,
            ACCESS_REQUEST_SUBMIT_RATE_WINDOW_SEC=60,
            MAILU_DOMAIN="bstein.dev",
            ACCESS_REQUEST_INTERNAL_EMAIL_ALLOWLIST={"allowed@bstein.dev"},
        )
        self.conn = conn or DummyConn()
        self.configured_value = True
        self.admin = DummyAdmin()
        # Queued answers for rate_limit_allow; once drained every call is allowed.
        self.rate_limit_results: list[bool] = []
        self.sent: list[tuple[str, str, str]] = []
        self.fail_connect = False
        self.fail_send = False
        self.verify_status = "pending"
        self.verify_error: VerificationError | None = None
        self.verify_runtime_error = False

    def configured(self) -> bool:
        return self.configured_value

    def admin_client(self) -> DummyAdmin:
        return self.admin

    def _client_ip(self) -> str:
        return "203.0.113.30"

    def rate_limit_allow(self, *args, **kwargs) -> bool:
        """Pop the next queued decision, or allow when the queue is empty."""
        queue = self.rate_limit_results
        return queue.pop(0) if queue else True

    @contextmanager
    def connect(self):
        """Yield the current fake connection unless ``fail_connect`` is set."""
        if self.fail_connect:
            raise RuntimeError("database offline")
        yield self.conn

    def _extract_request_payload(self) -> tuple[str, str, str, str, str]:
        """Read the JSON body and return stripped (username, email, note, first_name, last_name)."""
        body = request.get_json(silent=True) or {}
        field_names = ("username", "email", "note", "first_name", "last_name")
        return tuple((body.get(name) or "").strip() for name in field_names)

    def _normalize_name(self, value: str) -> str:
        """Collapse all whitespace runs to single spaces and trim the ends."""
        return " ".join(value.split())

    def _validate_username(self, username: str) -> str | None:
        """Reject empty usernames and the sentinel ``"bad"``."""
        if not username or username == "bad":
            return "username is required"
        return None

    def _validate_name(self, value: str, *, label: str, required: bool) -> str | None:
        """Reject the sentinel ``"bad"``, then enforce presence when ``required``."""
        if value == "bad":
            return f"{label} is invalid"
        if required and not value:
            return f"{label} is required"
        return None

    def _normalize_status(self, status: str) -> str:
        """Map ``"approved"`` to ``"accounts_building"`` and empty values to ``"unknown"``."""
        if status == "approved":
            return "accounts_building"
        return status or "unknown"

    def _random_request_code(self, username: str) -> str:
        return f"{username}~CODE"

    def _hash_verification_token(self, token: str) -> str:
        return f"hash:{token}"

    def _send_verification_email(self, *, request_code: str, email: str, token: str) -> None:
        """Record the outgoing email in ``sent``, or raise ``MailerError`` when ``fail_send`` is set."""
        if self.fail_send:
            raise self.MailerError("send failed")
        self.sent.append((request_code, email, token))

    def _verify_request(self, conn: DummyConn, code: str, token: str) -> str:
        """Raise the configured failure (runtime error first), else return ``verify_status``."""
        if self.verify_runtime_error:
            raise RuntimeError("verify failed")
        if self.verify_error:
            raise self.verify_error
        return self.verify_status


def make_client(deps: DummyDeps):
    """Build a fresh Flask app with the access-request routes and return its test client."""
    flask_app = Flask(__name__)
    register_access_request_submission(flask_app, deps)
    return flask_app.test_client()


def request_payload(**overrides: str) -> dict[str, str]:
    """Return a valid submission body for ``alice`` with ``overrides`` applied on top."""
    defaults = {
        "username": "alice",
        "email": "alice@example.dev",
        "first_name": "Alice",
        "last_name": "Atlas",
        "note": "please",
    }
    return {**defaults, **overrides}


def test_availability_preflight_existing_and_available_paths() -> None:
    deps = DummyDeps()
    client = make_client(deps)

    def lookup(username: str = "alice"):
        return client.get(f"/api/access/request/availability?username={username}")

    # Feature switched off, then backend unconfigured: both are 503.
    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert lookup().status_code == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True
    deps.configured_value = False
    assert lookup().status_code == 503
    deps.configured_value = True

    assert lookup("bad").get_json()["reason"] == "invalid"

    deps.admin = DummyAdmin(ready=True, user={"id": "user-1"})
    assert lookup().get_json()["reason"] == "exists"

    deps.admin = DummyAdmin()
    deps.conn = DummyConn({"SELECT status": {"status": "approved"}})
    assert lookup().get_json() == {"available": False, "reason": "requested", "status": "accounts_building"}

    deps.conn = DummyConn()
    assert lookup().get_json() == {"available": True}

    deps.fail_connect = True
    assert lookup().status_code == 502


def test_submit_preflight_validation_and_admin_conflicts() -> None:
    deps = DummyDeps()
    client = make_client(deps)

    def submit_status(**overrides: str) -> int:
        return client.post("/api/access/request", json=request_payload(**overrides)).status_code

    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert submit_status() == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True
    deps.configured_value = False
    assert submit_status() == 503
    deps.configured_value = True

    deps.rate_limit_results = [False]
    assert submit_status() == 429

    # Each of these payload tweaks must be rejected as a bad request (order preserved).
    invalid_overrides = (
        {"username": "bad"},
        {"first_name": "bad"},
        {"last_name": ""},
        {"email": ""},
        {"email": "not-email"},
        {"email": "new@bstein.dev"},
    )
    for overrides in invalid_overrides:
        assert submit_status(**overrides) == 400

    deps.admin = DummyAdmin(ready=True, user={"id": "user-1"})
    assert submit_status() == 409
    deps.admin = DummyAdmin(ready=True, email_user={"id": "user-2"})
    assert submit_status() == 409


def test_submit_existing_pending_new_unique_and_failure_paths() -> None:
    pending = DummyConn({"SELECT request_code, status": {"request_code": "alice~OLD", "status": "pending"}})
    deps = DummyDeps(pending)
    client = make_client(deps)

    def submit(username: str = "alice"):
        return client.post("/api/access/request", json=request_payload(username=username))

    assert submit().get_json() == {
        "ok": True,
        "request_code": "alice~OLD",
        "status": "pending",
    }

    awaiting = DummyConn(
        {"SELECT request_code, status": {"request_code": "alice~VERIFY", "status": deps.EMAIL_VERIFY_PENDING_STATUS}}
    )
    deps.conn = awaiting
    assert submit().get_json()["request_code"] == "alice~VERIFY"
    assert deps.sent[-1][0] == "alice~VERIFY"
    assert any("UPDATE access_requests" in sql for sql, _ in awaiting.executed)

    deps.fail_send = True
    assert submit().status_code == 502
    deps.fail_send = False

    fresh = DummyConn()
    deps.conn = fresh
    assert submit("brad").get_json()["request_code"] == "brad~CODE"
    assert any("INSERT INTO access_requests" in sql for sql, _ in fresh.executed)

    deps.fail_send = True
    assert submit("casey").status_code == 502
    deps.fail_send = False

    race = UniqueRaceConn({"request_code": "alice~RACE", "status": "pending"})
    deps.conn = race
    assert submit("dana").get_json()["request_code"] == "alice~RACE"
    assert race.rolled_back is True

    deps.conn = UniqueRaceConn(None)
    assert submit("erin").status_code == 502

    deps.fail_connect = True
    assert submit("fran").status_code == 502


def test_verify_and_verify_link_paths() -> None:
    deps = DummyDeps()
    client = make_client(deps)
    full_body = {"request_code": "code", "token": "tok"}
    short_keys_body = {"code": "code", "token": "tok"}

    def verify(body: dict[str, str]):
        return client.post("/api/access/request/verify", json=body)

    def follow(query: str = "?code=code&token=tok"):
        return client.get(f"/api/access/request/verify-link{query}")

    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert verify(full_body).status_code == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True
    deps.configured_value = False
    assert verify(full_body).status_code == 503
    deps.configured_value = True

    deps.rate_limit_results = [False]
    assert verify(full_body).status_code == 429
    assert verify({"token": "tok"}).status_code == 400
    assert verify({"request_code": "code"}).status_code == 400

    deps.rate_limit_results = [True, False]
    assert verify({"request_code": "code", "verify": "tok"}).status_code == 429

    assert verify(short_keys_body).get_json() == {
        "ok": True,
        "status": "pending",
    }
    deps.verify_error = VerificationError(410, "expired")
    assert verify(short_keys_body).status_code == 410
    deps.verify_error = None
    deps.verify_runtime_error = True
    assert verify(short_keys_body).status_code == 502
    deps.verify_runtime_error = False

    assert follow("").headers["Location"].endswith("verify_error=missing+token")
    assert "verified=1" in follow().headers["Location"]
    deps.verify_error = VerificationError(401, "bad token")
    assert "bad%20token" in follow().headers["Location"]
    deps.verify_error = None
    deps.verify_runtime_error = True
    assert "failed+to+verify" in follow().headers["Location"]
    deps.verify_runtime_error = False

    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert follow().status_code == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True
    deps.configured_value = False
    assert follow().status_code == 503


def test_resend_preflight_success_and_failure_paths() -> None:
    deps = DummyDeps()
    client = make_client(deps)
    by_code = {"request_code": "code"}

    def resend(body: dict[str, str]):
        return client.post("/api/access/request/resend", json=body)

    def status_lookup(row: dict[str, Any] | None) -> DummyConn:
        return DummyConn({"SELECT status, contact_email": row})

    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert resend(by_code).status_code == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True
    deps.configured_value = False
    assert resend(by_code).status_code == 503
    deps.configured_value = True

    deps.rate_limit_results = [False]
    assert resend(by_code).status_code == 429
    assert resend({}).status_code == 400
    deps.rate_limit_results = [True, False]
    assert resend(by_code).status_code == 429

    deps.conn = status_lookup(None)
    assert resend({"request_code": "missing"}).status_code == 404

    deps.conn = status_lookup({"status": "approved", "contact_email": "alice@example.dev"})
    assert resend(by_code).get_json()["status"] == "accounts_building"

    deps.conn = status_lookup({"status": deps.EMAIL_VERIFY_PENDING_STATUS, "contact_email": ""})
    assert resend(by_code).status_code == 409

    delivered = status_lookup({"status": deps.EMAIL_VERIFY_PENDING_STATUS, "contact_email": "alice@example.dev"})
    deps.conn = delivered
    assert resend({"code": "code"}).get_json()["ok"] is True
    assert any("UPDATE access_requests" in sql for sql, _ in delivered.executed)

    deps.fail_send = True
    assert resend(by_code).status_code == 502
    deps.fail_send = False

    deps.fail_connect = True
    assert resend(by_code).status_code == 502