387 lines
16 KiB
Python
387 lines
16 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from contextlib import contextmanager
|
||
|
|
from types import SimpleNamespace
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
from flask import Flask, request
|
||
|
|
import psycopg
|
||
|
|
|
||
|
|
from atlas_portal.routes.access_request_submission import register_access_request_submission
|
||
|
|
|
||
|
|
|
||
|
|
class DummyResult:
    """Cursor-like result object that hands back one canned row."""

    def __init__(self, row: dict[str, Any] | None = None) -> None:
        # Row returned by fetchone(); None stands for "query matched nothing".
        self.row = row

    def fetchone(self) -> dict[str, Any] | None:
        """Return the stored row, or ``None`` when no row was configured."""
        return self.row
|
||
|
|
|
||
|
|
|
||
|
|
class DummyConn:
    """In-memory stand-in for a database connection.

    Every executed statement is recorded in ``executed``; ``execute`` returns a
    canned row chosen by substring match of the SQL text against the keys of
    ``rows_by_query``. Boolean flags let a test force failure paths.
    """

    def __init__(self, rows_by_query: dict[str, dict[str, Any] | None] | None = None) -> None:
        # Compare against None explicitly: ``rows_by_query or {}`` would swap an
        # explicitly supplied *empty* mapping for a fresh dict, so rows the
        # caller adds to their dict afterwards would never be seen.
        self.rows_by_query = {} if rows_by_query is None else rows_by_query
        # (query, params) pairs in execution order.
        self.executed: list[tuple[str, object | None]] = []
        # When True, INSERTs into access_requests raise UniqueViolation.
        self.raise_unique_on_insert = False
        # When True, every execute() raises RuntimeError (after being recorded).
        self.raise_on_any = False
        self.rolled_back = False

    def execute(self, query: str, params: object | None = None) -> DummyResult:
        """Record *query*/*params* and return the canned result for it.

        The first key of ``rows_by_query`` (in insertion order) that occurs as a
        substring of *query* selects the row; otherwise an empty result is
        returned.

        Raises:
            RuntimeError: if ``raise_on_any`` is set.
            psycopg.errors.UniqueViolation: if ``raise_unique_on_insert`` is set
                and *query* contains ``INSERT INTO access_requests``.
        """
        self.executed.append((query, params))
        if self.raise_on_any:
            raise RuntimeError("database failed")
        if self.raise_unique_on_insert and "INSERT INTO access_requests" in query:
            raise psycopg.errors.UniqueViolation("duplicate")
        for key, row in self.rows_by_query.items():
            if key in query:
                return DummyResult(row)
        return DummyResult()

    def rollback(self) -> None:
        """Remember that a rollback was requested."""
        self.rolled_back = True
|
||
|
|
|
||
|
|
|
||
|
|
class UniqueRaceConn(DummyConn):
    """Connection fake that simulates losing an INSERT race.

    Every ``INSERT INTO access_requests`` raises ``UniqueViolation``. The
    ``SELECT request_code, status`` lookup yields ``row_after_rollback`` only once
    ``rollback()`` has been called, and an empty result before that.
    """

    def __init__(self, row_after_rollback: dict[str, Any] | None) -> None:
        super().__init__()
        self.row_after_rollback = row_after_rollback

    def execute(self, query: str, params: object | None = None) -> DummyResult:
        """Record the statement, then raise, or return the race-dependent row."""
        self.executed.append((query, params))
        if "INSERT INTO access_requests" in query:
            raise psycopg.errors.UniqueViolation("duplicate")
        visible_row = None
        is_lookup = "SELECT request_code, status" in query
        if is_lookup and self.rolled_back:
            visible_row = self.row_after_rollback
        return DummyResult(visible_row)
|
||
|
|
|
||
|
|
|
||
|
|
class DummyAdmin:
    """Fake admin client with a fixed readiness flag and canned user lookups."""

    def __init__(
        self,
        *,
        ready: bool = False,
        user: dict[str, Any] | None = None,
        email_user: dict[str, Any] | None = None,
    ) -> None:
        self._ready = ready
        # Returned by find_user() whatever username is asked for.
        self.user = user
        # Returned by find_user_by_email() whatever address is asked for.
        self.email_user = email_user

    def ready(self) -> bool:
        """Return the readiness flag given to the constructor."""
        return self._ready

    def find_user(self, username: str) -> dict[str, Any] | None:
        """Return the canned ``user`` record; *username* is ignored."""
        return self.user

    def find_user_by_email(self, email: str) -> dict[str, Any] | None:
        """Return the canned ``email_user`` record; *email* is ignored."""
        return self.email_user
|
||
|
|
|
||
|
|
|
||
|
|
class MailerError(Exception):
    """Raised by the fake mailer in DummyDeps when a send is forced to fail."""
|
||
|
|
|
||
|
|
|
||
|
|
class VerificationError(Exception):
    """Verification failure carrying a status code alongside its message."""

    def __init__(self, status_code: int, message: str) -> None:
        super().__init__(message)
        # Kept as attributes so a handler can read them without parsing args.
        self.status_code, self.message = status_code, message
|
||
|
|
|
||
|
|
|
||
|
|
class DummyDeps:
    """Fake dependency bundle passed to ``register_access_request_submission``.

    Attributes and methods mirror what the route module is handed as ``deps``.
    The mutable flags set in ``__init__`` let a test steer requests into
    specific success and failure branches.
    """

    EMAIL_VERIFY_PENDING_STATUS = "pending_email_verification"
    # Module-level exception types re-exposed on the deps object, presumably so
    # the route can refer to them as ``deps.MailerError`` etc. — confirm in route.
    MailerError = MailerError
    VerificationError = VerificationError

    def __init__(self, conn: DummyConn | None = None) -> None:
        self.settings = SimpleNamespace(
            ACCESS_REQUEST_ENABLED=True,
            ACCESS_REQUEST_SUBMIT_RATE_LIMIT=5,
            ACCESS_REQUEST_SUBMIT_RATE_WINDOW_SEC=60,
            MAILU_DOMAIN="bstein.dev",
            ACCESS_REQUEST_INTERNAL_EMAIL_ALLOWLIST={"allowed@bstein.dev"},
        )
        self.conn = conn or DummyConn()
        self.configured_value = True
        self.admin = DummyAdmin()
        # Results handed out by rate_limit_allow(), front first; empty => allow.
        self.rate_limit_results: list[bool] = []
        # (request_code, email, token) for each verification email "sent".
        self.sent: list[tuple[str, str, str]] = []
        self.fail_connect = False
        self.fail_send = False
        self.verify_status = "pending"
        self.verify_error: VerificationError | None = None
        self.verify_runtime_error = False

    def configured(self) -> bool:
        """Return the test-controlled ``configured_value`` flag."""
        return self.configured_value

    def admin_client(self) -> DummyAdmin:
        """Return whichever fake admin client is currently installed."""
        return self.admin

    def _client_ip(self) -> str:
        """Return a fixed documentation-range client address."""
        return "203.0.113.30"

    def rate_limit_allow(self, *args, **kwargs) -> bool:
        """Pop the next queued rate-limit outcome, allowing once the queue is empty."""
        queued = self.rate_limit_results
        return queued.pop(0) if queued else True

    @contextmanager
    def connect(self):
        """Yield the current fake connection, or raise if ``fail_connect`` is set."""
        if self.fail_connect:
            raise RuntimeError("database offline")
        yield self.conn

    def _extract_request_payload(self) -> tuple[str, str, str, str, str]:
        """Read submission fields from the current request's JSON body.

        Missing or falsy fields become ``""`` and every value is stripped.
        Order: username, email, note, first_name, last_name.
        """
        body = request.get_json(silent=True) or {}
        username, email, note, first_name, last_name = (
            (body.get(field) or "").strip()
            for field in ("username", "email", "note", "first_name", "last_name")
        )
        return username, email, note, first_name, last_name

    def _normalize_name(self, value: str) -> str:
        """Collapse runs of whitespace to single spaces and trim the ends."""
        return " ".join(value.split())

    def _validate_username(self, username: str) -> str | None:
        """Reject empty usernames and the sentinel ``"bad"``; otherwise ``None``."""
        if not username or username == "bad":
            return "username is required"
        return None

    def _validate_name(self, value: str, *, label: str, required: bool) -> str | None:
        """Reject the sentinel ``"bad"``, and an empty value when *required*."""
        if required and not value:
            return f"{label} is required"
        return f"{label} is invalid" if value == "bad" else None

    def _normalize_status(self, status: str) -> str:
        """Map ``"approved"`` to ``"accounts_building"`` and empty to ``"unknown"``."""
        if status == "approved":
            return "accounts_building"
        return status or "unknown"

    def _random_request_code(self, username: str) -> str:
        """Return a deterministic request code derived from *username*."""
        return f"{username}~CODE"

    def _hash_verification_token(self, token: str) -> str:
        """Return a deterministic, human-readable stand-in hash of *token*."""
        return f"hash:{token}"

    def _send_verification_email(self, *, request_code: str, email: str, token: str) -> None:
        """Record the send in ``sent``, or raise ``MailerError`` when ``fail_send`` is set."""
        if not self.fail_send:
            self.sent.append((request_code, email, token))
            return
        raise self.MailerError("send failed")

    def _verify_request(self, conn: DummyConn, code: str, token: str) -> str:
        """Return ``verify_status`` unless a failure has been configured.

        Raises:
            RuntimeError: when ``verify_runtime_error`` is set (checked first).
            VerificationError: the configured ``verify_error``, if any.
        """
        if self.verify_runtime_error:
            raise RuntimeError("verify failed")
        configured_error = self.verify_error
        if configured_error:
            raise configured_error
        return self.verify_status
|
||
|
|
|
||
|
|
|
||
|
|
def make_client(deps: DummyDeps):
    """Return a Flask test client whose access-request routes use *deps*."""
    flask_app = Flask(__name__)
    register_access_request_submission(flask_app, deps)
    test_client = flask_app.test_client()
    return test_client
|
||
|
|
|
||
|
|
|
||
|
|
def request_payload(**overrides: str) -> dict[str, str]:
    """Build a valid submission body, with keyword *overrides* applied on top.

    Overridden keys keep their default position; extra keys are appended.
    A new dict is returned on every call.
    """
    defaults = {
        "username": "alice",
        "email": "alice@example.dev",
        "first_name": "Alice",
        "last_name": "Atlas",
        "note": "please",
    }
    return {**defaults, **overrides}
|
||
|
|
|
||
|
|
|
||
|
|
def test_availability_preflight_existing_and_available_paths() -> None:
    """Walk the availability endpoint through its preflight and lookup branches."""
    deps = DummyDeps()
    client = make_client(deps)
    endpoint = "/api/access/request/availability"

    def check(username: str):
        return client.get(f"{endpoint}?username={username}")

    # Feature disabled in settings.
    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert check("alice").status_code == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True

    # Dependencies report "not configured".
    deps.configured_value = False
    assert check("alice").status_code == 503
    deps.configured_value = True

    # "bad" is rejected by the fake username validator.
    assert check("bad").get_json()["reason"] == "invalid"

    # A ready admin client that already knows the username.
    deps.admin = DummyAdmin(ready=True, user={"id": "user-1"})
    assert check("alice").get_json()["reason"] == "exists"
    deps.admin = DummyAdmin()

    # An existing request row; the fake normalizes "approved" to "accounts_building".
    deps.conn = DummyConn({"SELECT status": {"status": "approved"}})
    answer = check("alice").get_json()
    assert answer == {"available": False, "reason": "requested", "status": "accounts_building"}

    # No matching row at all.
    deps.conn = DummyConn()
    assert check("alice").get_json() == {"available": True}

    # Database connection failure.
    deps.fail_connect = True
    assert check("alice").status_code == 502
|
||
|
|
|
||
|
|
|
||
|
|
def test_submit_preflight_validation_and_admin_conflicts() -> None:
    """Cover submit-endpoint preflight checks, field validation and admin conflicts."""
    deps = DummyDeps()
    client = make_client(deps)
    endpoint = "/api/access/request"

    def submit_status(**overrides: str) -> int:
        return client.post(endpoint, json=request_payload(**overrides)).status_code

    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert submit_status() == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True

    deps.configured_value = False
    assert submit_status() == 503
    deps.configured_value = True

    # rate_limit_allow() pops the queued False; once the queue is empty it allows.
    deps.rate_limit_results = [False]
    assert submit_status() == 429

    # Each invalid field on its own yields 400. "new@bstein.dev" is presumably
    # rejected for being on MAILU_DOMAIN without being in the internal allowlist.
    for bad_fields in (
        {"username": "bad"},
        {"first_name": "bad"},
        {"last_name": ""},
        {"email": ""},
        {"email": "not-email"},
        {"email": "new@bstein.dev"},
    ):
        assert submit_status(**bad_fields) == 400, bad_fields

    # Username already known to the admin client.
    deps.admin = DummyAdmin(ready=True, user={"id": "user-1"})
    assert submit_status() == 409
    # Email address already known to the admin client.
    deps.admin = DummyAdmin(ready=True, email_user={"id": "user-2"})
    assert submit_status() == 409
|
||
|
|
|
||
|
|
|
||
|
|
def test_submit_existing_pending_new_unique_and_failure_paths() -> None:
    """Cover submit outcomes for existing rows, new inserts, insert races and failures."""
    endpoint = "/api/access/request"
    lookup = "SELECT request_code, status"
    pending = DummyConn({lookup: {"request_code": "alice~OLD", "status": "pending"}})
    deps = DummyDeps(pending)
    client = make_client(deps)

    def submit(**overrides: str):
        return client.post(endpoint, json=request_payload(**overrides))

    # An already-pending request is echoed back.
    assert submit().get_json() == {
        "ok": True,
        "request_code": "alice~OLD",
        "status": "pending",
    }

    # A request awaiting email verification: email re-sent and the row updated.
    verifying = DummyConn(
        {lookup: {"request_code": "alice~VERIFY", "status": deps.EMAIL_VERIFY_PENDING_STATUS}}
    )
    deps.conn = verifying
    reply = submit().get_json()
    assert reply["request_code"] == "alice~VERIFY"
    assert deps.sent[-1][0] == "alice~VERIFY"
    assert any("UPDATE access_requests" in sql for sql, _ in verifying.executed)

    deps.fail_send = True
    assert submit().status_code == 502
    deps.fail_send = False

    # No existing row: a new request is inserted.
    fresh = DummyConn()
    deps.conn = fresh
    reply = submit(username="brad").get_json()
    assert reply["request_code"] == "brad~CODE"
    assert any("INSERT INTO access_requests" in sql for sql, _ in fresh.executed)

    deps.fail_send = True
    assert submit(username="casey").status_code == 502
    deps.fail_send = False

    # Lost insert race: the row becomes visible only after rollback.
    race = UniqueRaceConn({"request_code": "alice~RACE", "status": "pending"})
    deps.conn = race
    assert submit(username="dana").get_json()["request_code"] == "alice~RACE"
    assert race.rolled_back is True

    # Lost insert race but no row found afterwards.
    race_without_row = UniqueRaceConn(None)
    deps.conn = race_without_row
    assert submit(username="erin").status_code == 502

    deps.fail_connect = True
    assert submit(username="fran").status_code == 502
|
||
|
|
|
||
|
|
|
||
|
|
def test_verify_and_verify_link_paths() -> None:
    """Cover the JSON verify endpoint and the redirecting verify-link endpoint."""
    deps = DummyDeps()
    client = make_client(deps)
    verify_url = "/api/access/request/verify"
    link_url = "/api/access/request/verify-link"
    link_query = "?code=code&token=tok"
    canonical_body = {"request_code": "code", "token": "tok"}
    short_code_body = {"code": "code", "token": "tok"}

    def verify(body):
        return client.post(verify_url, json=body)

    def link_location(query: str = "") -> str:
        return client.get(link_url + query).headers["Location"]

    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert verify(canonical_body).status_code == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True

    deps.configured_value = False
    assert verify(canonical_body).status_code == 503
    deps.configured_value = True

    deps.rate_limit_results = [False]
    assert verify(canonical_body).status_code == 429
    assert verify({"token": "tok"}).status_code == 400
    assert verify({"request_code": "code"}).status_code == 400
    # First rate-limit check passes, the second one refuses.
    # NOTE(review): this body sends "verify" instead of "token"; confirm whether
    # that is a deliberate alias or a typo.
    deps.rate_limit_results = [True, False]
    assert verify({"request_code": "code", "verify": "tok"}).status_code == 429

    assert verify(short_code_body).get_json() == {
        "ok": True,
        "status": "pending",
    }
    deps.verify_error = VerificationError(410, "expired")
    assert verify(short_code_body).status_code == 410
    deps.verify_error = None
    deps.verify_runtime_error = True
    assert verify(short_code_body).status_code == 502
    deps.verify_runtime_error = False

    # The link endpoint reports outcomes through its redirect Location.
    assert link_location().endswith("verify_error=missing+token")
    assert "verified=1" in link_location(link_query)
    deps.verify_error = VerificationError(401, "bad token")
    assert "bad%20token" in link_location(link_query)
    deps.verify_error = None
    deps.verify_runtime_error = True
    assert "failed+to+verify" in link_location(link_query)
    deps.verify_runtime_error = False

    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert client.get(link_url + link_query).status_code == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True

    deps.configured_value = False
    assert client.get(link_url + link_query).status_code == 503
|
||
|
|
|
||
|
|
|
||
|
|
def test_resend_preflight_success_and_failure_paths() -> None:
    """Cover resend-endpoint preflight checks, lookup outcomes and send failures."""
    deps = DummyDeps()
    client = make_client(deps)
    resend_url = "/api/access/request/resend"
    lookup = "SELECT status, contact_email"
    by_request_code = {"request_code": "code"}

    def resend(body):
        return client.post(resend_url, json=body)

    deps.settings.ACCESS_REQUEST_ENABLED = False
    assert resend(by_request_code).status_code == 503
    deps.settings.ACCESS_REQUEST_ENABLED = True

    deps.configured_value = False
    assert resend(by_request_code).status_code == 503
    deps.configured_value = True

    deps.rate_limit_results = [False]
    assert resend(by_request_code).status_code == 429
    assert resend({}).status_code == 400
    # First rate-limit check passes, the second one refuses.
    deps.rate_limit_results = [True, False]
    assert resend(by_request_code).status_code == 429

    # Unknown request code.
    deps.conn = DummyConn({lookup: None})
    assert resend({"request_code": "missing"}).status_code == 404

    # Already approved; the fake normalizes "approved" to "accounts_building".
    deps.conn = DummyConn({lookup: {"status": "approved", "contact_email": "alice@example.dev"}})
    assert resend(by_request_code).get_json()["status"] == "accounts_building"

    # Awaiting verification but no contact email on record.
    deps.conn = DummyConn({lookup: {"status": deps.EMAIL_VERIFY_PENDING_STATUS, "contact_email": ""}})
    assert resend(by_request_code).status_code == 409

    resendable = DummyConn(
        {lookup: {"status": deps.EMAIL_VERIFY_PENDING_STATUS, "contact_email": "alice@example.dev"}}
    )
    deps.conn = resendable
    assert resend({"code": "code"}).get_json()["ok"] is True
    assert any("UPDATE access_requests" in sql for sql, _ in resendable.executed)

    deps.fail_send = True
    assert resend(by_request_code).status_code == 502
    deps.fail_send = False

    deps.fail_connect = True
    assert resend(by_request_code).status_code == 502
|