diff --git a/backend/app.py b/backend/app.py
index da1e323..e2b1aec 100644
--- a/backend/app.py
+++ b/backend/app.py
@@ -41,6 +41,8 @@ AI_K8S_LABEL = os.getenv("AI_K8S_LABEL", "app=ollama")
 AI_K8S_NAMESPACE = os.getenv("AI_K8S_NAMESPACE", "ai")
 AI_MODEL_ANNOTATION = os.getenv("AI_MODEL_ANNOTATION", "ai.bstein.dev/model")
 AI_GPU_ANNOTATION = os.getenv("AI_GPU_ANNOTATION", "ai.bstein.dev/gpu")
+AI_WARM_INTERVAL_SEC = float(os.getenv("AI_WARM_INTERVAL_SEC", "300"))
+AI_WARM_ENABLED = os.getenv("AI_WARM_ENABLED", "true").lower() in ("1", "true", "yes")
 
 _LAB_STATUS_CACHE: dict[str, Any] = {"ts": 0.0, "value": None}
 
@@ -256,6 +258,51 @@ def _discover_ai_meta() -> dict[str, str]:
     return meta
 
 
+# Guards against starting a second keep-warm thread if the module is
+# imported twice (e.g. under the Flask debug reloader).
+_WARM_STARTED = False
+
+
+def _keep_warm() -> None:
+    """Start a daemon thread that periodically pings the chat model.
+
+    Pinging every AI_WARM_INTERVAL_SEC seconds keeps the model loaded so
+    real requests avoid the cold-start cost.  Warming is strictly
+    best-effort: every ping failure is swallowed.  The function is a
+    no-op when AI_WARM_ENABLED is false or the interval is non-positive,
+    and it is idempotent so repeated calls cannot spawn a second thread.
+    """
+    global _WARM_STARTED
+    if not AI_WARM_ENABLED or AI_WARM_INTERVAL_SEC <= 0 or _WARM_STARTED:
+        return
+
+    import threading
+
+    def loop() -> None:
+        while True:
+            time.sleep(AI_WARM_INTERVAL_SEC)
+            try:
+                body = {
+                    "model": AI_CHAT_MODEL,
+                    "messages": [{"role": "user", "content": "ping"}],
+                    "stream": False,
+                }
+                # Cap the timeout: a hung model server must not stall the
+                # pinger for the full (potentially long) chat timeout.
+                with httpx.Client(timeout=min(AI_CHAT_TIMEOUT_SEC, 15)) as client:
+                    client.post(f"{AI_CHAT_API}/api/chat", json=body)
+            except Exception:
+                # Best-effort: a failed ping just waits for the next cycle.
+                pass
+
+    threading.Thread(target=loop, daemon=True, name="ai-keep-warm").start()
+    _WARM_STARTED = True
+
+
+# Start keep-warm loop on import.
+_keep_warm()
+
+
 @app.route("/", defaults={"path": ""})
 @app.route("/")
 def serve_frontend(path: str) -> Any: