99 lines
2.7 KiB
Python

from __future__ import annotations
import logging
import time
from typing import Any
import httpx
from flask import jsonify, request
from . import settings
# Module-level logger named after this module; request failures and 5xx
# upstream responses below are reported through it.
logger: logging.Logger = logging.getLogger(name=__name__)
class AriadneError(Exception):
    """Failure reaching or configuring the Ariadne upstream.

    Attributes:
        status_code: HTTP status associated with the failure (502 unless
            given); ``proxy`` uses it as the status of its error response.
    """

    def __init__(self, message: str, status_code: int = 502) -> None:
        # Record the status first; the message goes to Exception so that
        # str(exc) and exc.args expose it.
        self.status_code = status_code
        super().__init__(message)
def enabled() -> bool:
    """Return True when an Ariadne base URL is configured (``settings.ARIADNE_URL`` is truthy)."""
    configured_url = settings.ARIADNE_URL
    return bool(configured_url)
def _auth_headers() -> dict[str, str]:
    """Build headers that forward the current request's ``Authorization`` upstream.

    Returns an empty dict when the incoming Flask request has no
    Authorization header, or one that is only whitespace, so no blank header
    is sent.
    """
    token = request.headers.get("Authorization", "").strip()
    if not token:
        return {}
    return {"Authorization": token}
def _url(path: str) -> str:
    """Join ``path`` onto the configured Ariadne base URL with a single slash.

    Trailing slashes on the base and leading slashes on ``path`` are trimmed.
    An empty (or all-slash) ``path`` yields the bare base URL.
    """
    root = settings.ARIADNE_URL.rstrip("/")
    tail = path.lstrip("/")
    if not tail:
        return root
    return root + "/" + tail
def request_raw(
    method: str,
    path: str,
    *,
    payload: Any | None = None,
    params: dict[str, Any] | None = None,
) -> httpx.Response:
    """Send a request to Ariadne, retrying on transport-level failures.

    Args:
        method: HTTP method passed to ``httpx.Client.request``.
        path: Path joined onto ``settings.ARIADNE_URL`` via ``_url``.
        payload: Optional body, sent as JSON.
        params: Optional query-string parameters.

    Returns:
        The upstream response regardless of status. Responses with status
        >= 500 are logged as warnings but returned as-is; they are not retried.

    Raises:
        AriadneError: status 503 when Ariadne is not configured; status 502
            when every attempt fails with an ``httpx.RequestError``.
    """
    if not enabled():
        raise AriadneError("ariadne not configured", 503)
    url = _url(path)
    # Always make at least one attempt, even if the configured count is 0 or negative.
    attempts = max(1, settings.ARIADNE_RETRY_COUNT)
    for attempt in range(1, attempts + 1):
        try:
            with httpx.Client(timeout=settings.ARIADNE_TIMEOUT_SEC) as client:
                resp = client.request(
                    method,
                    url,
                    headers=_auth_headers(),
                    json=payload,
                    params=params,
                )
        except httpx.RequestError as exc:
            # NOTE(review): every method is retried, including POST, and
            # RequestError also covers read timeouts where the request may
            # already have reached the server — confirm upstream endpoints
            # are idempotent before relying on this.
            logger.warning(
                "ariadne request failed",
                extra={
                    "method": method,
                    "path": path,
                    "attempt": attempt,
                    "timeout_sec": settings.ARIADNE_TIMEOUT_SEC,
                    "error": str(exc),
                },
            )
            if attempt >= attempts:
                raise AriadneError("ariadne unavailable", 502) from exc
            # Linear backoff: the wait grows with each failed attempt.
            time.sleep(settings.ARIADNE_RETRY_BACKOFF_SEC * attempt)
            continue
        if resp.status_code >= 500:
            logger.warning(
                "ariadne error response",
                extra={
                    "method": method,
                    "path": path,
                    "status": resp.status_code,
                    "attempt": attempt,
                },
            )
        return resp
    # Unreachable: attempts >= 1, and the final iteration always returns or
    # raises. Kept explicit so the declared return type has no implicit
    # ``None`` path.
    raise AriadneError("ariadne unavailable", 502)
def proxy(
    method: str,
    path: str,
    *,
    payload: Any | None = None,
    params: dict[str, Any] | None = None,
) -> tuple[Any, int]:
    """Forward a request to Ariadne and return a Flask ``(response, status)`` pair.

    If ``request_raw`` raises ``AriadneError``, the result is
    ``{"error": <message>}`` with that error's status. Otherwise the
    upstream status is passed through with its JSON body.
    """
    try:
        resp = request_raw(method, path, payload=payload, params=params)
    except AriadneError as err:
        return jsonify({"error": str(err)}), err.status_code

    try:
        body = resp.json()
    except ValueError:
        # Non-JSON body: surface its stripped text (or a generic message when
        # empty) under "error".
        # NOTE(review): this also happens for 2xx statuses, so a successful
        # plain-text reply is reported as an "error" — confirm this is intended.
        body = {"error": resp.text.strip() or "upstream error"}
    return jsonify(body), resp.status_code