# Source listing: 243 lines, 9.4 KiB, Python.
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from contextlib import contextmanager
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
from types import SimpleNamespace
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
from flask import Flask, jsonify
|
||
|
|
|
||
|
|
from atlas_portal.routes.access_request_status import register_access_request_status
|
||
|
|
|
||
|
|
|
||
|
|
class DummyResult:
    """Fake query result returned by ``DummyConn.execute``.

    Carries either a single row (served by ``fetchone``) or a list of rows
    (served by ``fetchall``).
    """

    def __init__(self, row: dict[str, Any] | None = None, rows: list[dict[str, Any]] | None = None) -> None:
        # Single-row payload; ``None`` stands for "no matching row".
        self.row = row
        # Multi-row payload; any falsy argument collapses to a fresh empty list.
        self.rows = rows if rows else []

    def fetchone(self) -> dict[str, Any] | None:
        """Return the stored single row (may be ``None``)."""
        return self.row

    def fetchall(self) -> list[dict[str, Any]]:
        """Return the stored list of rows."""
        return self.rows
|
||
|
|
|
||
|
|
|
||
|
|
class DummyConn:
    """Fake connection that answers queries from canned, substring-keyed results.

    ``many_by_query`` maps a query fragment to a list of rows (for ``fetchall``);
    ``rows_by_query`` maps a fragment to a single row (for ``fetchone``).
    Every ``execute`` call is recorded in ``executed``.
    """

    def __init__(
        self,
        *,
        rows_by_query: dict[str, dict[str, Any] | None] | None = None,
        many_by_query: dict[str, list[dict[str, Any]]] | None = None,
    ) -> None:
        self.rows_by_query = rows_by_query if rows_by_query else {}
        self.many_by_query = many_by_query if many_by_query else {}
        # (query, params) pairs in call order, so tests can inspect issued SQL.
        self.executed: list[tuple[str, object | None]] = []

    def execute(self, query: str, params: object | None = None) -> DummyResult:
        """Record the call and return the first canned result whose key occurs in ``query``.

        Multi-row fragments are checked before single-row fragments, each in
        dict insertion order. With no match, an empty ``DummyResult`` is returned.
        """
        self.executed.append((query, params))
        for fragment, batch in self.many_by_query.items():
            if fragment in query:
                return DummyResult(rows=batch)
        hits = [single for fragment, single in self.rows_by_query.items() if fragment in query]
        if hits:
            return DummyResult(row=hits[0])
        return DummyResult()
|
||
|
|
|
||
|
|
|
||
|
|
class DummyAriadne:
    """Fake Ariadne client that records proxy calls and echoes them back as JSON."""

    def __init__(self, *, enabled: bool = False) -> None:
        self._enabled = enabled
        # (method, path, payload) for every proxy() invocation, in call order.
        self.proxy_calls: list[tuple[str, str, object | None]] = []

    def enabled(self) -> bool:
        """Return the ``enabled`` flag supplied at construction."""
        return self._enabled

    def proxy(self, method: str, path: str, payload: object | None = None):
        """Record the call, then return a ``jsonify`` response echoing its arguments."""
        self.proxy_calls.append((method, path, payload))
        echo = {"proxied": True, "method": method, "path": path, "payload": payload}
        return jsonify(echo)
|
||
|
|
|
||
|
|
|
||
|
|
class DummyDeps:
    """Fake dependency bundle handed to ``register_access_request_status``.

    Every hook is a test-controllable stub. Flip ``fail_connect``,
    ``fail_provision`` or ``configured_value``, or queue verdicts in
    ``rate_limit_results``, to steer a request down a particular path.
    """

    def __init__(self, conn: DummyConn | None = None) -> None:
        defaults = {
            "ACCESS_REQUEST_ENABLED": True,
            "ACCESS_REQUEST_STATUS_RATE_LIMIT": 5,
            "ACCESS_REQUEST_STATUS_RATE_WINDOW_SEC": 60,
        }
        self.settings = SimpleNamespace(**defaults)
        self.conn = conn if conn else DummyConn()
        self.ariadne_client = DummyAriadne()
        # Queued answers for rate_limit_allow(); once drained, every call is allowed.
        self.rate_limit_results: list[bool] = []
        # Request codes passed to provision_access_request(), in call order.
        self.provisioned: list[str] = []
        self.fail_connect = False
        self.fail_provision = False
        self.configured_value = True

    def configured(self) -> bool:
        """Return the test-controlled ``configured_value`` flag."""
        return self.configured_value

    def _client_ip(self) -> str:
        # Fixed address inside the RFC 5737 TEST-NET-3 documentation range.
        return "203.0.113.20"

    def rate_limit_allow(self, *args, **kwargs) -> bool:
        """Pop the next queued verdict, or return ``True`` when none are queued.

        All arguments are accepted and ignored.
        """
        queue = self.rate_limit_results
        return queue.pop(0) if queue else True

    @contextmanager
    def connect(self):
        """Yield ``self.conn``; raise ``RuntimeError`` instead when ``fail_connect`` is set."""
        if self.fail_connect:
            raise RuntimeError("database offline")
        yield self.conn

    def _normalize_status(self, status: str) -> str:
        """Map ``approved`` to ``accounts_building``; map falsy statuses to ``unknown``."""
        if status == "approved":
            return "accounts_building"
        return status or "unknown"

    def _advance_status(self, conn: DummyConn, code: str, username: str, status: str) -> str:
        """Return the normalized status; ``conn``, ``code`` and ``username`` are unused."""
        return self._normalize_status(status)

    def provision_access_request(self, code: str) -> None:
        """Record ``code``, then raise ``RuntimeError`` if ``fail_provision`` is set."""
        self.provisioned.append(code)
        if self.fail_provision:
            raise RuntimeError("provision failed")

    def provision_tasks_complete(self, conn: DummyConn, code: str) -> bool:
        """Always report provisioning tasks as complete."""
        return True

    def _onboarding_payload(self, conn: DummyConn, code: str, username: str) -> dict[str, str]:
        """Echo the request code and username as the onboarding payload."""
        return {"code": code, "username": username}
|
||
|
|
|
||
|
|
|
||
|
|
def make_client(deps: DummyDeps):
    """Build a throwaway Flask app wired via ``register_access_request_status``.

    Returns the app's test client, so tests can issue requests against ``deps``.
    """
    application = Flask(__name__)
    register_access_request_status(application, deps)
    return application.test_client()
|
||
|
|
|
||
|
|
|
||
|
|
def test_status_preflight_and_rate_limit_paths() -> None:
    """Status endpoint rejects disabled, unconfigured, rate-limited and invalid requests."""
    deps = DummyDeps()
    client = make_client(deps)
    url = "/api/access/request/status"
    payload = {"request_code": "code"}

    # ACCESS_REQUEST_ENABLED off -> 503.
    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert client.post(url, json=payload).status_code == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True

    # deps.configured() reports False -> 503.
    deps.configured_value = False
    assert client.post(url, json=payload).status_code == 503
    deps.configured_value = True

    # A single queued denial -> 429. The queue is then drained, so a body
    # lacking request_code is answered with 400.
    deps.rate_limit_results = [False]
    assert client.post(url, json=payload).status_code == 429
    assert client.post(url, json={}).status_code == 400

    # The second queued verdict denies -> 429.
    deps.rate_limit_results = [True, False]
    assert client.post(url, json=payload).status_code == 429
|
||
|
|
|
||
|
|
|
||
|
|
def test_status_returns_tasks_onboarding_and_reveals_password() -> None:
    """Status endpoint returns task details, onboarding info and the reveal timestamp."""
    stamp = datetime(2026, 4, 20, tzinfo=timezone.utc)
    request_row = {
        "status": "awaiting_onboarding",
        "username": "alice",
        "initial_password": "temp-pass",
        "initial_password_revealed_at": stamp,
        "email_verified_at": stamp,
    }
    task_rows = [
        {"task": "mail", "status": "error", "detail": "smtp failed", "updated_at": stamp},
        # Non-datetime timestamp; presumably checks the route tolerates bad values.
        {"task": "apps", "status": "ok", "detail": "", "updated_at": "not-a-date"},
    ]
    conn = DummyConn(
        rows_by_query={"SELECT status,": request_row},
        many_by_query={"SELECT task, status, detail, updated_at": task_rows},
    )
    client = make_client(DummyDeps(conn))

    response = client.post(
        "/api/access/request/status",
        json={"request_code": "code", "reveal_initial_password": True},
    )
    body = response.get_json()

    assert response.status_code == 200
    assert body["status"] == "awaiting_onboarding"
    assert body["email_verified"] is True
    assert body["blocked"] is True
    assert body["automation_complete"] is True
    first_task = body["tasks"][0]
    assert first_task["detail"] == "smtp failed"
    assert first_task["updated_at"].startswith("2026-04-20T00:00:00")
    assert body["initial_password_revealed_at"].startswith("2026-04-20T00:00:00")
    # Presumably withheld because initial_password_revealed_at is already set.
    assert "initial_password" not in body
    assert body["onboarding_url"] == "/onboarding?code=code"
    assert body["onboarding"] == {"code": "code", "username": "alice"}
|
||
|
|
|
||
|
|
|
||
|
|
def test_status_autoprovisions_and_handles_failure_paths() -> None:
    """Approved requests trigger provisioning; missing rows and DB errors map to 404/502."""
    url = "/api/access/request/status"

    # Approved request: provisioning is attempted and raises, yet the call returns 200.
    approved = DummyDeps(DummyConn(rows_by_query={"SELECT status,": {"status": "approved", "username": "alice"}}))
    approved.fail_provision = True
    client = make_client(approved)
    # NOTE(review): other tests send "reveal_initial_password"; confirm whether
    # "reveal_password" here is intentional.
    response = client.post(url, json={"request_code": "code", "reveal_password": True})
    assert response.status_code == 200
    assert approved.provisioned == ["code"]

    # No row for the request code -> 404.
    missing = DummyDeps(DummyConn(rows_by_query={"SELECT status,": None}))
    assert make_client(missing).post(url, json={"request_code": "missing"}).status_code == 404

    # connect() raises -> 502.
    offline = DummyDeps()
    offline.fail_connect = True
    assert make_client(offline).post(url, json={"request_code": "code"}).status_code == 502
|
||
|
|
|
||
|
|
|
||
|
|
def test_retry_preflight_proxy_and_validation_paths() -> None:
    """Retry endpoint preflight checks, plus proxying to Ariadne with filtered tasks."""
    deps = DummyDeps()
    client = make_client(deps)
    url = "/api/access/request/retry"

    def status_of(payload: dict[str, Any]) -> int:
        # POST to the retry endpoint and return only the HTTP status code.
        return client.post(url, json=payload).status_code

    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert status_of({"request_code": "code"}) == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True

    deps.configured_value = False
    assert status_of({"request_code": "code"}) == 503
    deps.configured_value = True

    deps.rate_limit_results = [False]
    assert status_of({"request_code": "code"}) == 429
    assert status_of({}) == 400

    # With Ariadne enabled the request is forwarded; the "" and 5 entries are
    # absent from the proxied task list.
    deps.ariadne_client = DummyAriadne(enabled=True)
    response = client.post(url, json={"request_code": "code", "tasks": ["mail", "", 5]})
    assert response.status_code == 200
    assert deps.ariadne_client.proxy_calls == [
        ("POST", "/api/access/requests/code/retry", {"tasks": ["mail"]}),
    ]
|
||
|
|
|
||
|
|
|
||
|
|
def test_retry_updates_failed_tasks_and_swallows_provision_errors() -> None:
    """Retry resets failed tasks, tolerates provisioning errors, and maps edge cases."""
    url = "/api/access/request/retry"
    lookup = "SELECT status FROM access_requests"

    # Explicit task list: a "task = ANY" query is issued; the provisioning failure is swallowed.
    conn = DummyConn(rows_by_query={lookup: {"status": "accounts_building"}})
    deps = DummyDeps(conn)
    deps.fail_provision = True
    response = make_client(deps).post(url, json={"request_code": "code", "tasks": ["mail"]})
    data = response.get_json()
    assert response.status_code == 200
    assert data == {"ok": True, "status": "accounts_building"}
    assert any("task = ANY" in sql for sql, _ in conn.executed)

    # No task list: a query filtered on status = 'error' for the request is issued.
    no_tasks_conn = DummyConn(rows_by_query={lookup: {"status": "approved"}})
    no_tasks = DummyDeps(no_tasks_conn)
    assert make_client(no_tasks).post(url, json={"request_code": "code"}).status_code == 200
    assert any("WHERE request_code = %s AND status = 'error'" in sql for sql, _ in no_tasks_conn.executed)

    # No row for the request code -> 404.
    missing = DummyDeps(DummyConn(rows_by_query={lookup: None}))
    assert make_client(missing).post(url, json={"request_code": "missing"}).status_code == 404

    # Status "ready" is rejected -> 409.
    rejected = DummyDeps(DummyConn(rows_by_query={lookup: {"status": "ready"}}))
    assert make_client(rejected).post(url, json={"request_code": "code"}).status_code == 409

    # connect() raises -> 502.
    broken = DummyDeps()
    broken.fail_connect = True
    assert make_client(broken).post(url, json={"request_code": "code"}).status_code == 502
|