bstein-dev-home/backend/tests/test_access_request_submission.py

387 lines
16 KiB
Python
Raw Permalink Normal View History

from __future__ import annotations
from contextlib import contextmanager
from types import SimpleNamespace
from typing import Any
from flask import Flask, request
import psycopg
from atlas_portal.routes.access_request_submission import register_access_request_submission
class DummyResult:
    """Minimal stand-in for a cursor result; only ``fetchone`` is supported."""

    def __init__(self, row: dict[str, Any] | None = None) -> None:
        # The single canned row handed back by fetchone(); None means "no match".
        self.row = row

    def fetchone(self) -> dict[str, Any] | None:
        """Return the canned row, or ``None`` when nothing was configured."""
        return self.row


class DummyConn:
    """Scripted database connection for route tests.

    ``rows_by_query`` maps a SQL substring to the row that ``fetchone`` should
    yield for the first executed statement containing that substring (dict
    insertion order decides ties). Every ``execute`` call is recorded in
    ``executed`` so tests can assert on the SQL that was issued.
    """

    def __init__(self, rows_by_query: dict[str, dict[str, Any] | None] | None = None) -> None:
        self.rows_by_query = rows_by_query or {}
        self.executed: list[tuple[str, object | None]] = []
        # Failure switches tests can flip before issuing a request.
        self.raise_on_any = False
        self.raise_unique_on_insert = False
        self.rolled_back = False

    def execute(self, query: str, params: object | None = None) -> DummyResult:
        """Record the statement, optionally fail, then return the scripted row."""
        self.executed.append((query, params))
        if self.raise_on_any:
            raise RuntimeError("database failed")
        if self.raise_unique_on_insert and "INSERT INTO access_requests" in query:
            raise psycopg.errors.UniqueViolation("duplicate")
        # An unmatched query yields an empty result, exactly like a matched None row.
        matched = next(
            (row for fragment, row in self.rows_by_query.items() if fragment in query),
            None,
        )
        return DummyResult(matched)

    def rollback(self) -> None:
        """Remember that the caller rolled the transaction back."""
        self.rolled_back = True


class UniqueRaceConn(DummyConn):
    """Connection that simulates losing an insert race.

    Every ``INSERT INTO access_requests`` raises ``UniqueViolation``. The
    ``SELECT request_code, status`` lookup returns ``row_after_rollback`` only
    once ``rollback()`` has been called; every other statement yields no row.
    """

    def __init__(self, row_after_rollback: dict[str, Any] | None) -> None:
        super().__init__()
        self.row_after_rollback = row_after_rollback

    def execute(self, query: str, params: object | None = None) -> DummyResult:
        self.executed.append((query, params))
        if "INSERT INTO access_requests" in query:
            raise psycopg.errors.UniqueViolation("duplicate")
        is_lookup = "SELECT request_code, status" in query
        visible = self.row_after_rollback if is_lookup and self.rolled_back else None
        return DummyResult(visible)


class DummyAdmin:
    """Fake admin client: readiness and user lookups come from constructor arguments.

    The lookup methods ignore their argument and always return the canned value.
    """

    def __init__(
        self,
        *,
        ready: bool = False,
        user: dict[str, Any] | None = None,
        email_user: dict[str, Any] | None = None,
    ) -> None:
        self._ready = ready
        self.user = user
        self.email_user = email_user

    def ready(self) -> bool:
        """Report the readiness flag given at construction."""
        return self._ready

    def find_user(self, username: str) -> dict[str, Any] | None:
        """Return the canned username match (same for every *username*)."""
        return self.user

    def find_user_by_email(self, email: str) -> dict[str, Any] | None:
        """Return the canned email match (same for every *email*)."""
        return self.email_user


class MailerError(Exception):
    """Fake mailer failure, exposed to the routes as ``deps.MailerError``."""


class VerificationError(Exception):
    """Fake verification failure carrying the status code the route should respond with."""

    def __init__(self, status_code: int, message: str) -> None:
        self.status_code = status_code
        self.message = message
        super().__init__(message)


class DummyDeps:
    """Test double passed to ``register_access_request_submission``.

    Plain attributes (``configured_value``, ``rate_limit_results``,
    ``fail_connect``, ``fail_send``, ``verify_status``, ``verify_error``,
    ``verify_runtime_error``) let each test steer the stubbed helpers below.
    """

    EMAIL_VERIFY_PENDING_STATUS = "pending_email_verification"
    MailerError = MailerError
    VerificationError = VerificationError

    def __init__(self, conn: DummyConn | None = None) -> None:
        self.settings = SimpleNamespace(
            ACCESS_REQUEST_ENABLED=True,
            ACCESS_REQUEST_SUBMIT_RATE_LIMIT=5,
            ACCESS_REQUEST_SUBMIT_RATE_WINDOW_SEC=60,
            MAILU_DOMAIN="bstein.dev",
            ACCESS_REQUEST_INTERNAL_EMAIL_ALLOWLIST={"allowed@bstein.dev"},
        )
        self.conn = conn or DummyConn()
        self.admin = DummyAdmin()
        self.configured_value = True
        # Scripted rate-limit decisions, consumed front-to-back; empty means "allow".
        self.rate_limit_results: list[bool] = []
        # (request_code, email, token) for every verification email "sent".
        self.sent: list[tuple[str, str, str]] = []
        self.fail_connect = False
        self.fail_send = False
        self.verify_status = "pending"
        self.verify_error: VerificationError | None = None
        self.verify_runtime_error = False

    def configured(self) -> bool:
        return self.configured_value

    def admin_client(self) -> DummyAdmin:
        return self.admin

    def _client_ip(self) -> str:
        return "203.0.113.30"

    def rate_limit_allow(self, *args, **kwargs) -> bool:
        """Pop the next scripted decision; allow everything once the script is exhausted."""
        return self.rate_limit_results.pop(0) if self.rate_limit_results else True

    @contextmanager
    def connect(self):
        """Yield the scripted connection, or fail before yielding when ``fail_connect`` is set."""
        if self.fail_connect:
            raise RuntimeError("database offline")
        yield self.conn

    def _extract_request_payload(self) -> tuple[str, str, str, str, str]:
        """Read the JSON body and return stripped (username, email, note, first_name, last_name)."""
        body = request.get_json(silent=True) or {}

        def field(name: str) -> str:
            # Missing keys and JSON nulls both collapse to "".
            return (body.get(name) or "").strip()

        return field("username"), field("email"), field("note"), field("first_name"), field("last_name")

    def _normalize_name(self, value: str) -> str:
        # str.split() with no argument already drops leading/trailing whitespace.
        return " ".join(value.split())

    def _validate_username(self, username: str) -> str | None:
        if not username or username == "bad":
            return "username is required"
        return None

    def _validate_name(self, value: str, *, label: str, required: bool) -> str | None:
        if value == "bad":
            return f"{label} is invalid"
        return f"{label} is required" if required and not value else None

    def _normalize_status(self, status: str) -> str:
        if status == "approved":
            return "accounts_building"
        return status or "unknown"

    def _random_request_code(self, username: str) -> str:
        # Deterministic so tests can predict the generated code.
        return f"{username}~CODE"

    def _hash_verification_token(self, token: str) -> str:
        return f"hash:{token}"

    def _send_verification_email(self, *, request_code: str, email: str, token: str) -> None:
        if self.fail_send:
            raise self.MailerError("send failed")
        self.sent.append((request_code, email, token))

    def _verify_request(self, conn: DummyConn, code: str, token: str) -> str:
        """Raise the configured failure (runtime error wins), else return ``verify_status``."""
        if self.verify_runtime_error:
            raise RuntimeError("verify failed")
        if self.verify_error:
            raise self.verify_error
        return self.verify_status


def make_client(deps: DummyDeps):
    """Build a fresh Flask app with the access-request routes bound to *deps*; return its test client."""
    application = Flask(__name__)
    register_access_request_submission(application, deps)
    return application.test_client()


def request_payload(**overrides: str) -> dict[str, str]:
    """Return a fresh, valid submission body for ``alice`` with *overrides* layered on top."""
    defaults = {
        "username": "alice",
        "email": "alice@example.dev",
        "first_name": "Alice",
        "last_name": "Atlas",
        "note": "please",
    }
    return {**defaults, **overrides}


def test_availability_preflight_existing_and_available_paths() -> None:
    """The availability endpoint is gated, validates the name, and reports existing users and requests."""
    deps = DummyDeps()
    client = make_client(deps)

    def availability(username: str):
        return client.get(f"/api/access/request/availability?username={username}")

    # Feature disabled, then backend not configured: both are 503.
    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert availability("alice").status_code == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True
    deps.configured_value = False
    assert availability("alice").status_code == 503
    deps.configured_value = True

    assert availability("bad").get_json()["reason"] == "invalid"

    # A ready admin that finds the user reports an existing account.
    deps.admin = DummyAdmin(ready=True, user={"id": "user-1"})
    assert availability("alice").get_json()["reason"] == "exists"

    # An approved request row is surfaced with its normalized status.
    deps.admin = DummyAdmin()
    deps.conn = DummyConn({"SELECT status": {"status": "approved"}})
    assert availability("alice").get_json() == {
        "available": False,
        "reason": "requested",
        "status": "accounts_building",
    }

    deps.conn = DummyConn()
    assert availability("alice").get_json() == {"available": True}

    deps.fail_connect = True
    assert availability("alice").status_code == 502


def test_submit_preflight_validation_and_admin_conflicts() -> None:
    """Submission is gated and rate limited, validates input, and rejects names or emails already in use."""
    deps = DummyDeps()
    client = make_client(deps)

    def submit_status(**overrides: str) -> int:
        return client.post("/api/access/request", json=request_payload(**overrides)).status_code

    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert submit_status() == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True
    deps.configured_value = False
    assert submit_status() == 503
    deps.configured_value = True

    deps.rate_limit_results = [False]
    assert submit_status() == 429

    invalid_inputs = (
        {"username": "bad"},
        {"first_name": "bad"},
        {"last_name": ""},
        {"email": ""},
        {"email": "not-email"},
        # Presumably rejected for being on MAILU_DOMAIN but absent from the internal allowlist.
        {"email": "new@bstein.dev"},
    )
    for override in invalid_inputs:
        assert submit_status(**override) == 400

    deps.admin = DummyAdmin(ready=True, user={"id": "user-1"})
    assert submit_status() == 409
    deps.admin = DummyAdmin(ready=True, email_user={"id": "user-2"})
    assert submit_status() == 409


def test_submit_existing_pending_new_unique_and_failure_paths() -> None:
    """Cover reuse of existing requests, fresh inserts, unique-violation races, and failure mapping."""
    deps = DummyDeps(
        DummyConn({"SELECT request_code, status": {"request_code": "alice~OLD", "status": "pending"}})
    )
    client = make_client(deps)

    def submit(**overrides: str):
        return client.post("/api/access/request", json=request_payload(**overrides))

    # An already-pending request is returned unchanged.
    assert submit().get_json() == {
        "ok": True,
        "request_code": "alice~OLD",
        "status": "pending",
    }

    # A request still awaiting email verification is reused: the row is updated and the email re-sent.
    verifying = DummyConn(
        {"SELECT request_code, status": {"request_code": "alice~VERIFY", "status": deps.EMAIL_VERIFY_PENDING_STATUS}}
    )
    deps.conn = verifying
    assert submit().get_json()["request_code"] == "alice~VERIFY"
    assert deps.sent[-1][0] == "alice~VERIFY"
    assert any("UPDATE access_requests" in sql for sql, _ in verifying.executed)
    deps.fail_send = True
    assert submit().status_code == 502
    deps.fail_send = False

    # No prior row: a new request is inserted with a generated code.
    fresh = DummyConn()
    deps.conn = fresh
    assert submit(username="brad").get_json()["request_code"] == "brad~CODE"
    assert any("INSERT INTO access_requests" in sql for sql, _ in fresh.executed)
    deps.fail_send = True
    assert submit(username="casey").status_code == 502
    deps.fail_send = False

    # Lost insert race: after rollback the conflicting row is looked up and returned.
    race = UniqueRaceConn({"request_code": "alice~RACE", "status": "pending"})
    deps.conn = race
    assert submit(username="dana").get_json()["request_code"] == "alice~RACE"
    assert race.rolled_back is True

    # Lost race but the conflicting row is still not visible.
    deps.conn = UniqueRaceConn(None)
    assert submit(username="erin").status_code == 502

    deps.fail_connect = True
    assert submit(username="fran").status_code == 502


def test_verify_and_verify_link_paths() -> None:
    """Exercise POST /verify and GET /verify-link: gating, validation, rate limits, and error mapping."""
    deps = DummyDeps()
    client = make_client(deps)
    valid = {"request_code": "code", "token": "tok"}
    # "code" is accepted as an alias for "request_code".
    aliased = {"code": "code", "token": "tok"}
    link = "/api/access/request/verify-link?code=code&token=tok"

    def verify(payload: dict[str, str]):
        return client.post("/api/access/request/verify", json=payload)

    def link_location() -> str:
        return client.get(link).headers["Location"]

    # --- POST /verify ---
    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert verify(valid).status_code == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True
    deps.configured_value = False
    assert verify(valid).status_code == 503
    deps.configured_value = True

    deps.rate_limit_results = [False]
    assert verify(valid).status_code == 429
    assert verify({"token": "tok"}).status_code == 400
    assert verify({"request_code": "code"}).status_code == 400

    # The first rate-limit decision allows and the second denies. Send a well-formed
    # payload so only the limiter can reject it, independent of whether token
    # validation runs before or after the second limiter check.
    deps.rate_limit_results = [True, False]
    assert verify(valid).status_code == 429

    assert verify(aliased).get_json() == {
        "ok": True,
        "status": "pending",
    }
    # VerificationError carries the status code the route responds with.
    deps.verify_error = VerificationError(410, "expired")
    assert verify(aliased).status_code == 410
    deps.verify_error = None
    deps.verify_runtime_error = True
    assert verify(aliased).status_code == 502
    deps.verify_runtime_error = False

    # --- GET /verify-link redirects, reporting outcome in the Location query ---
    assert client.get("/api/access/request/verify-link").headers["Location"].endswith("verify_error=missing+token")
    assert "verified=1" in link_location()
    deps.verify_error = VerificationError(401, "bad token")
    assert "bad%20token" in link_location()
    deps.verify_error = None
    deps.verify_runtime_error = True
    assert "failed+to+verify" in link_location()
    deps.verify_runtime_error = False

    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert client.get(link).status_code == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True
    deps.configured_value = False
    assert client.get(link).status_code == 503


def test_resend_preflight_success_and_failure_paths() -> None:
    """Resend is gated and rate limited, handles unknown/finished requests, and re-sends pending ones."""
    deps = DummyDeps()
    client = make_client(deps)
    lookup = "SELECT status, contact_email"
    by_code = {"request_code": "code"}

    def resend(payload: dict[str, str]):
        return client.post("/api/access/request/resend", json=payload)

    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert resend(by_code).status_code == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True
    deps.configured_value = False
    assert resend(by_code).status_code == 503
    deps.configured_value = True

    deps.rate_limit_results = [False]
    assert resend(by_code).status_code == 429
    assert resend({}).status_code == 400
    # First rate-limit decision allows, second denies.
    deps.rate_limit_results = [True, False]
    assert resend(by_code).status_code == 429

    # Unknown request code.
    deps.conn = DummyConn({lookup: None})
    assert resend({"request_code": "missing"}).status_code == 404

    # Already-approved requests just report their normalized status.
    deps.conn = DummyConn({lookup: {"status": "approved", "contact_email": "alice@example.dev"}})
    assert resend(by_code).get_json()["status"] == "accounts_building"

    # Pending verification but no contact email on file.
    deps.conn = DummyConn({lookup: {"status": deps.EMAIL_VERIFY_PENDING_STATUS, "contact_email": ""}})
    assert resend(by_code).status_code == 409

    pending_conn = DummyConn(
        {lookup: {"status": deps.EMAIL_VERIFY_PENDING_STATUS, "contact_email": "alice@example.dev"}}
    )
    deps.conn = pending_conn
    assert resend({"code": "code"}).get_json()["ok"] is True
    assert any("UPDATE access_requests" in sql for sql, _ in pending_conn.executed)

    deps.fail_send = True
    assert resend(by_code).status_code == 502
    deps.fail_send = False
    deps.fail_connect = True
    assert resend(by_code).status_code == 502