# Source: ariadne/ariadne/k8s/exec.py
# (file-viewer metadata: 124 lines, 3.5 KiB, Python)

from __future__ import annotations
from dataclasses import dataclass
import shlex
import time
from typing import Any
# Optional dependency: remember an import failure instead of raising here, so
# this module stays importable; _ensure_client() reports the error on use.
_IMPORT_ERROR = None
try:
    from kubernetes import client, config
    from kubernetes.stream import stream
except Exception as exc:  # pragma: no cover - import checked at runtime
    client = config = stream = None
    _IMPORT_ERROR = exc
from .pods import PodSelectionError, select_pod
from ..utils.logging import get_logger
# Logger for this module, obtained from the project's get_logger helper.
logger = get_logger(__name__)
# Cached CoreV1Api instance; starts as None and is created and reused by
# _ensure_client().
_CORE_API = None
@dataclass(frozen=True)
class ExecResult:
    """Immutable record of what a command executed in a pod produced.

    Attributes:
        stdout: Text collected from the command's standard output.
        stderr: Text collected from the command's standard error.
        exit_code: Exit status of the command, or ``None`` when no exit
            status was obtained.
    """

    stdout: str
    stderr: str
    exit_code: int | None

    @property
    def ok(self) -> bool:
        """Return True when the exit code is zero or unknown (``None``)."""
        code = self.exit_code
        return code is None or code == 0
class ExecError(RuntimeError):
    """Raised when a checked pod exec finishes with a non-ok result."""
def _ensure_client() -> Any:
    """Return the shared ``CoreV1Api`` client, creating it on first use.

    Configuration is loaded from the in-cluster environment first; if that
    raises, the local kubeconfig is loaded instead. The resulting client is
    cached in the module-level ``_CORE_API`` and reused by later calls.

    Raises:
        RuntimeError: The ``kubernetes`` package could not be imported; the
            original import error is chained as the cause.
    """
    global _CORE_API
    if _IMPORT_ERROR:
        message = f"kubernetes client missing: {_IMPORT_ERROR}"
        raise RuntimeError(message) from _IMPORT_ERROR
    if _CORE_API is None:
        try:
            config.load_incluster_config()
        except Exception:
            # Not running inside a cluster (or in-cluster config is unusable):
            # fall back to the kubeconfig file.
            config.load_kube_config()
        _CORE_API = client.CoreV1Api()
    return _CORE_API
def _build_command(command: list[str] | str, env: dict[str, str] | None) -> list[str]:
    """Wrap *command* in a ``/bin/sh -c`` invocation, optionally with env vars.

    Args:
        command: Either a ready-made shell string (used verbatim) or an argv
            list, which is quoted into a single shell string with
            ``shlex.join``.
        env: Optional environment variables for the command. Values are
            shell-quoted; keys are emitted exactly as given. ``None`` or an
            empty mapping adds nothing.

    Returns:
        The argv list ``["/bin/sh", "-c", <script>]``.
    """
    script = command if isinstance(command, str) else shlex.join(command)
    if env:
        assignments = " ".join(
            f"{key}={shlex.quote(value)}" for key, value in env.items()
        )
        # Use ``export ...;`` instead of a bare ``KEY=value cmd`` prefix: a
        # prefix assignment is not visible to ``$KEY`` expansions on the same
        # command line and only reaches the first simple command, so string
        # commands such as ``echo "$KEY"`` or ``a | b`` would miss the values.
        script = f"export {assignments}; {script}"
    return ["/bin/sh", "-c", script]
class PodExecutor:
    """Run shell commands inside a pod picked by namespace and label selector.

    A pod is selected anew on every :meth:`exec` call via ``select_pod``.
    """

    # Upper bound, in seconds, for one blocking wait on the exec stream.
    _POLL_INTERVAL_SEC = 1.0

    def __init__(self, namespace: str, label_selector: str, container: str | None = None) -> None:
        """Store the pod selection criteria.

        Args:
            namespace: Namespace passed to ``select_pod``.
            label_selector: Label selector passed to ``select_pod``.
            container: Container name forwarded as-is to the exec call
                (may be ``None``).
        """
        self._namespace = namespace
        self._label_selector = label_selector
        self._container = container

    @staticmethod
    def _poll_timeout(deadline: float | None, now: float) -> float:
        """Return how long the next stream wait may block, in seconds.

        The wait never exceeds the poll interval and never runs past
        *deadline* (a ``time.monotonic()`` value, or ``None`` for no limit).
        """
        interval = PodExecutor._POLL_INTERVAL_SEC
        if deadline is None:
            return interval
        return min(interval, max(0.0, deadline - now))

    def exec(
        self,
        command: list[str] | str,
        env: dict[str, str] | None = None,
        timeout_sec: float | None = None,
        check: bool = True,
    ) -> ExecResult:
        """Execute *command* in a selected pod and collect its output.

        Args:
            command: Shell string or argv list, turned into a ``/bin/sh -c``
                invocation by ``_build_command``.
            env: Optional environment variables for the command.
            timeout_sec: Maximum number of seconds to wait for the command;
                ``None`` disables the limit.
            check: When true, raise ``ExecError`` if the result is not ``ok``.

        Returns:
            An ``ExecResult`` holding the collected stdout, stderr and the
            exit code (``None`` if none could be obtained).

        Raises:
            TimeoutError: The command ran longer than ``timeout_sec``.
            ExecError: ``check`` is true and the result is not ``ok``.
        """
        pod = select_pod(self._namespace, self._label_selector)
        cmd = _build_command(command, env)
        api = _ensure_client()
        resp = stream(
            api.connect_get_namespaced_pod_exec,
            pod.name,
            pod.namespace,
            command=cmd,
            container=self._container,
            stderr=True,
            stdin=False,
            stdout=True,
            tty=False,
            _preload_content=False,
        )
        stdout_parts: list[str] = []
        stderr_parts: list[str] = []
        exit_code: int | None = None
        deadline = None if timeout_sec is None else time.monotonic() + timeout_sec
        try:
            while resp.is_open():
                # Clamp the wait to the time left so a timeout shorter than
                # the poll interval is not overshot by a full interval.
                resp.update(timeout=self._poll_timeout(deadline, time.monotonic()))
                if resp.peek_stdout():
                    stdout_parts.append(resp.read_stdout())
                if resp.peek_stderr():
                    stderr_parts.append(resp.read_stderr())
                # Not every stream object offers an exit-code channel API;
                # use it only when present.
                if hasattr(resp, "peek_exit_code") and resp.peek_exit_code():
                    exit_code = resp.read_exit_code()
                    break
                if deadline is not None and time.monotonic() > deadline:
                    raise TimeoutError("pod exec timed out")
        finally:
            resp.close()
        if exit_code is None:
            # Fall back to the stream's ``returncode`` attribute, if it has one.
            exit_code = getattr(resp, "returncode", None)
        result = ExecResult("".join(stdout_parts), "".join(stderr_parts), exit_code)
        if check and not result.ok:
            raise ExecError(f"pod exec failed exit_code={result.exit_code} stderr={result.stderr.strip()}")
        return result