diff --git a/backend/tests/test_access_request_status.py b/backend/tests/test_access_request_status.py new file mode 100644 index 0000000..07f165e --- /dev/null +++ b/backend/tests/test_access_request_status.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +from contextlib import contextmanager +from datetime import datetime, timezone +from types import SimpleNamespace +from typing import Any + +from flask import Flask, jsonify + +from atlas_portal.routes.access_request_status import register_access_request_status + + +class DummyResult: + def __init__(self, row: dict[str, Any] | None = None, rows: list[dict[str, Any]] | None = None) -> None: + self.row = row + self.rows = rows or [] + + def fetchone(self) -> dict[str, Any] | None: + return self.row + + def fetchall(self) -> list[dict[str, Any]]: + return self.rows + + +class DummyConn: + def __init__( + self, + *, + rows_by_query: dict[str, dict[str, Any] | None] | None = None, + many_by_query: dict[str, list[dict[str, Any]]] | None = None, + ) -> None: + self.rows_by_query = rows_by_query or {} + self.many_by_query = many_by_query or {} + self.executed: list[tuple[str, object | None]] = [] + + def execute(self, query: str, params: object | None = None) -> DummyResult: + self.executed.append((query, params)) + for key, rows in self.many_by_query.items(): + if key in query: + return DummyResult(rows=rows) + for key, row in self.rows_by_query.items(): + if key in query: + return DummyResult(row=row) + return DummyResult() + + +class DummyAriadne: + def __init__(self, *, enabled: bool = False) -> None: + self._enabled = enabled + self.proxy_calls: list[tuple[str, str, object | None]] = [] + + def enabled(self) -> bool: + return self._enabled + + def proxy(self, method: str, path: str, payload: object | None = None): + self.proxy_calls.append((method, path, payload)) + return jsonify({"proxied": True, "method": method, "path": path, "payload": payload}) + + +class DummyDeps: + def __init__(self, conn: DummyConn | None = None) -> None: + self.settings = SimpleNamespace( + ACCESS_REQUEST_ENABLED=True, + ACCESS_REQUEST_STATUS_RATE_LIMIT=5, + ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC=60, + ) + self.conn = conn or DummyConn() + self.ariadne_client = DummyAriadne() + self.rate_limit_results: list[bool] = [] + self.provisioned: list[str] = [] + self.fail_connect = False + self.fail_provision = False + self.configured_value = True + + def configured(self) -> bool: + return self.configured_value + + def _client_ip(self) -> str: + return "203.0.113.20" + + def rate_limit_allow(self, *args, **kwargs) -> bool: + if self.rate_limit_results: + return self.rate_limit_results.pop(0) + return True + + @contextmanager + def connect(self): + if self.fail_connect: + raise RuntimeError("database offline") + yield self.conn + + def _normalize_status(self, status: str) -> str: + return "accounts_building" if status == "approved" else (status or "unknown") + + def _advance_status(self, conn: DummyConn, code: str, username: str, status: str) -> str: + return self._normalize_status(status) + + def provision_access_request(self, code: str) -> None: + self.provisioned.append(code) + if self.fail_provision: + raise RuntimeError("provision failed") + + def provision_tasks_complete(self, conn: DummyConn, code: str) -> bool: + return True + + def _onboarding_payload(self, conn: DummyConn, code: str, username: str) -> dict[str, str]: + return {"code": code, "username": username} + + +def make_client(deps: DummyDeps): + app = Flask(__name__) + register_access_request_status(app, deps) + return app.test_client() + + +def test_status_preflight_and_rate_limit_paths() -> None: + deps = DummyDeps() + client = make_client(deps) + + deps.settings.ACCESS_REQUEST_ENABLED = False + assert client.post("/api/access/request/status", json={"request_code": "code"}).status_code == 503 + deps.settings.ACCESS_REQUEST_ENABLED = True + + deps.configured_value = False + assert client.post("/api/access/request/status", json={"request_code": "code"}).status_code == 503 + deps.configured_value = True + + deps.rate_limit_results = [False] + assert client.post("/api/access/request/status", json={"request_code": "code"}).status_code == 429 + assert client.post("/api/access/request/status", json={}).status_code == 400 + + deps.rate_limit_results = [True, False] + assert client.post("/api/access/request/status", json={"request_code": "code"}).status_code == 429 + + +def test_status_returns_tasks_onboarding_and_reveals_password() -> None: + now = datetime(2026, 4, 20, tzinfo=timezone.utc) + conn = DummyConn( + rows_by_query={ + "SELECT status,": { + "status": "awaiting_onboarding", + "username": "alice", + "initial_password": "temp-pass", + "initial_password_revealed_at": now, + "email_verified_at": now, + } + }, + many_by_query={ + "SELECT task, status, detail, updated_at": [ + {"task": "mail", "status": "error", "detail": "smtp failed", "updated_at": now}, + {"task": "apps", "status": "ok", "detail": "", "updated_at": "not-a-date"}, + ] + }, + ) + deps = DummyDeps(conn) + client = make_client(deps) + + response = client.post("/api/access/request/status", json={"request_code": "code", "reveal_initial_password": True}) + data = response.get_json() + + assert response.status_code == 200 + assert data["status"] == "awaiting_onboarding" + assert data["email_verified"] is True + assert data["blocked"] is True + assert data["automation_complete"] is True + assert data["tasks"][0]["detail"] == "smtp failed" + assert data["tasks"][0]["updated_at"].startswith("2026-04-20T00:00:00") + assert data["initial_password_revealed_at"].startswith("2026-04-20T00:00:00") + assert "initial_password" not in data + assert data["onboarding_url"] == "/onboarding?code=code" + assert data["onboarding"] == {"code": "code", "username": "alice"} + + +def test_status_autoprovisions_and_handles_failure_paths() -> None: + conn = DummyConn(rows_by_query={"SELECT status,": {"status": "approved", "username": "alice"}}) + deps = DummyDeps(conn) + deps.fail_provision = True + client = make_client(deps) + + response = client.post("/api/access/request/status", json={"request_code": "code", "reveal_password": True}) + assert response.status_code == 200 + assert deps.provisioned == ["code"] + + not_found = DummyDeps(DummyConn(rows_by_query={"SELECT status,": None})) + assert make_client(not_found).post("/api/access/request/status", json={"request_code": "missing"}).status_code == 404 + + broken = DummyDeps() + broken.fail_connect = True + assert make_client(broken).post("/api/access/request/status", json={"request_code": "code"}).status_code == 502 + + +def test_retry_preflight_proxy_and_validation_paths() -> None: + deps = DummyDeps() + client = make_client(deps) + + deps.settings.ACCESS_REQUEST_ENABLED = False + assert client.post("/api/access/request/retry", json={"request_code": "code"}).status_code == 503 + deps.settings.ACCESS_REQUEST_ENABLED = True + + deps.configured_value = False + assert client.post("/api/access/request/retry", json={"request_code": "code"}).status_code == 503 + deps.configured_value = True + + deps.rate_limit_results = [False] + assert client.post("/api/access/request/retry", json={"request_code": "code"}).status_code == 429 + assert client.post("/api/access/request/retry", json={}).status_code == 400 + + deps.ariadne_client = DummyAriadne(enabled=True) + response = client.post("/api/access/request/retry", json={"request_code": "code", "tasks": ["mail", "", 5]}) + assert response.status_code == 200 + assert deps.ariadne_client.proxy_calls == [ + ("POST", "/api/access/requests/code/retry", {"tasks": ["mail"]}) + ] + + +def test_retry_updates_failed_tasks_and_swallows_provision_errors() -> None: + conn = DummyConn(rows_by_query={"SELECT status FROM access_requests": {"status": "accounts_building"}}) + deps = DummyDeps(conn) + deps.fail_provision = True + client = make_client(deps) + + response = client.post("/api/access/request/retry", json={"request_code": "code", "tasks": ["mail"]}) + data = response.get_json() + + assert response.status_code == 200 + assert data == {"ok": True, "status": "accounts_building"} + assert any("task = ANY" in query for query, _ in conn.executed) + + no_tasks_conn = DummyConn(rows_by_query={"SELECT status FROM access_requests": {"status": "approved"}}) + no_tasks = DummyDeps(no_tasks_conn) + assert make_client(no_tasks).post("/api/access/request/retry", json={"request_code": "code"}).status_code == 200 + assert any("WHERE request_code = %s AND status = 'error'" in query for query, _ in no_tasks_conn.executed) + + missing = DummyDeps(DummyConn(rows_by_query={"SELECT status FROM access_requests": None})) + assert make_client(missing).post("/api/access/request/retry", json={"request_code": "missing"}).status_code == 404 + + rejected = DummyDeps(DummyConn(rows_by_query={"SELECT status FROM access_requests": {"status": "ready"}})) + assert make_client(rejected).post("/api/access/request/retry", json={"request_code": "code"}).status_code == 409 + + broken = DummyDeps() + broken.fail_connect = True + assert make_client(broken).post("/api/access/request/retry", json={"request_code": "code"}).status_code == 502