atlasbot: add genius mode and history context

This commit is contained in:
Brad Stein 2026-01-29 16:48:53 -03:00
parent 4dd7516778
commit 010cfdb07c
7 changed files with 215 additions and 19 deletions

View File

@ -19,6 +19,7 @@ class AnswerRequest(BaseModel):
text: str | None = None
content: str | None = None
mode: str | None = None
history: list[dict[str, str]] | None = None
class AnswerResponse(BaseModel):
@ -26,7 +27,11 @@ class AnswerResponse(BaseModel):
class Api:
def __init__(self, settings: Settings, answer_handler: Callable[[str, str], Awaitable[AnswerResult]]) -> None:
def __init__(
self,
settings: Settings,
answer_handler: Callable[[str, str, list[dict[str, str]] | None], Awaitable[AnswerResult]],
) -> None:
self._settings = settings
self._answer_handler = answer_handler
self.app = FastAPI()
@ -48,7 +53,7 @@ class Api:
if not question:
raise HTTPException(status_code=400, detail="missing question")
mode = (payload.mode or "quick").strip().lower()
result = await self._answer_handler(question, mode)
result = await self._answer_handler(question, mode, payload.history)
log.info(
"answer",
extra={

View File

@ -37,6 +37,7 @@ class Settings:
ollama_model: str
ollama_model_fast: str
ollama_model_smart: str
ollama_model_genius: str
ollama_fallback_model: str
ollama_timeout_sec: float
ollama_retries: int
@ -61,8 +62,10 @@ class Settings:
fast_max_angles: int
smart_max_angles: int
genius_max_angles: int
fast_max_candidates: int
smart_max_candidates: int
genius_max_candidates: int
@dataclass(frozen=True)
@ -85,6 +88,11 @@ def _load_matrix_bots(bot_mentions: tuple[str, ...]) -> tuple[MatrixBotConfig, .
if smart_user and smart_pass:
bots.append(MatrixBotConfig(smart_user, smart_pass, (smart_user,), "smart"))
genius_user = os.getenv("BOT_USER_GENIUS", "").strip()
genius_pass = os.getenv("BOT_PASS_GENIUS", "").strip()
if genius_user and genius_pass:
bots.append(MatrixBotConfig(genius_user, genius_pass, (genius_user,), "genius"))
if bots:
return tuple(bots)
@ -101,7 +109,7 @@ def load_settings() -> Settings:
bot_mentions = tuple(
[
item.strip()
for item in os.getenv("BOT_MENTIONS", "atlasbot,atlas-quick,atlas-smart").split(",")
for item in os.getenv("BOT_MENTIONS", "atlasbot,atlas-quick,atlas-smart,atlas-genius").split(",")
if item.strip()
]
)
@ -119,6 +127,7 @@ def load_settings() -> Settings:
ollama_model=os.getenv("OLLAMA_MODEL", "qwen2.5:14b-instruct"),
ollama_model_fast=os.getenv("ATLASBOT_MODEL_FAST", "qwen2.5:14b-instruct"),
ollama_model_smart=os.getenv("ATLASBOT_MODEL_SMART", "qwen2.5:14b-instruct"),
ollama_model_genius=os.getenv("ATLASBOT_MODEL_GENIUS", "qwen2.5:14b-instruct"),
ollama_fallback_model=os.getenv("OLLAMA_FALLBACK_MODEL", ""),
ollama_timeout_sec=_env_float("OLLAMA_TIMEOUT_SEC", "480"),
ollama_retries=_env_int("OLLAMA_RETRIES", "1"),
@ -139,6 +148,8 @@ def load_settings() -> Settings:
nats_result_bucket=os.getenv("ATLASBOT_NATS_RESULTS", "atlasbot_results"),
fast_max_angles=_env_int("ATLASBOT_FAST_MAX_ANGLES", "2"),
smart_max_angles=_env_int("ATLASBOT_SMART_MAX_ANGLES", "5"),
genius_max_angles=_env_int("ATLASBOT_GENIUS_MAX_ANGLES", "9"),
fast_max_candidates=_env_int("ATLASBOT_FAST_MAX_CANDIDATES", "2"),
smart_max_candidates=_env_int("ATLASBOT_SMART_MAX_CANDIDATES", "6"),
genius_max_candidates=_env_int("ATLASBOT_GENIUS_MAX_CANDIDATES", "10"),
)

View File

@ -47,6 +47,7 @@ class AnswerEngine:
question: str,
*,
mode: str,
history: list[dict[str, str]] | None = None,
observer: Callable[[str, str], None] | None = None,
) -> AnswerResult:
question = (question or "").strip()
@ -59,10 +60,12 @@ class AnswerEngine:
kb_summary = self._kb.summary()
runbooks = self._kb.runbook_titles(limit=4)
snapshot_ctx = summary_text(snapshot)
history_ctx = _format_history(history)
base_context = _join_context([
kb_summary,
runbooks,
f"ClusterSnapshot:{snapshot_ctx}" if snapshot_ctx else "",
history_ctx,
])
started = time.monotonic()
@ -96,7 +99,7 @@ class AnswerEngine:
)
if observer:
observer("synthesize", "synthesizing reply")
reply = await self._synthesize(question, best, base_context)
reply = await self._synthesize(question, best, base_context, classify, mode)
meta = {
"mode": mode,
"angles": angles,
@ -115,10 +118,13 @@ class AnswerEngine:
prompt = prompts.CLASSIFY_PROMPT + "\nQuestion: " + question
messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=context)
raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
return _parse_json_block(raw, fallback={"needs_snapshot": True})
data = _parse_json_block(raw, fallback={"needs_snapshot": True})
if "answer_style" not in data:
data["answer_style"] = "direct"
return data
async def _angles(self, question: str, classify: dict[str, Any], mode: str) -> list[dict[str, Any]]:
max_angles = self._settings.fast_max_angles if mode == "quick" else self._settings.smart_max_angles
max_angles = _angles_limit(self._settings, mode)
prompt = prompts.ANGLE_PROMPT.format(max_angles=max_angles) + "\nQuestion: " + question
messages = build_messages(prompts.CLUSTER_SYSTEM, prompt)
raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
@ -134,14 +140,15 @@ class AnswerEngine:
context: str,
mode: str,
) -> list[dict[str, Any]]:
limit = self._settings.fast_max_candidates if mode == "quick" else self._settings.smart_max_candidates
limit = _candidates_limit(self._settings, mode)
selected = angles[:limit]
tasks = []
model = _candidate_model(self._settings, mode)
for angle in selected:
angle_q = angle.get("question") or question
prompt = prompts.CANDIDATE_PROMPT + "\nQuestion: " + angle_q
messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=context)
tasks.append(self._llm.chat(messages, model=self._settings.ollama_model_smart))
tasks.append(self._llm.chat(messages, model=model))
replies = await asyncio.gather(*tasks)
candidates = []
for angle, reply in zip(selected, replies, strict=False):
@ -163,21 +170,39 @@ class AnswerEngine:
best = [entry for entry, _scores in scored[:3]]
return best, scored[0][1]
async def _synthesize(self, question: str, best: list[dict[str, Any]], context: str) -> str:
async def _synthesize(
self,
question: str,
best: list[dict[str, Any]],
context: str,
classify: dict[str, Any],
mode: str,
) -> str:
if not best:
return "I do not have enough information to answer that yet."
parts = []
for item in best:
parts.append(f"- {item['reply']}")
style = classify.get("answer_style") if isinstance(classify, dict) else None
intent = classify.get("intent") if isinstance(classify, dict) else None
ambiguity = classify.get("ambiguity") if isinstance(classify, dict) else None
style_line = f"AnswerStyle: {style}" if style else "AnswerStyle: default"
if intent:
style_line = f"{style_line}; Intent: {intent}"
if ambiguity is not None:
style_line = f"{style_line}; Ambiguity: {ambiguity}"
prompt = (
prompts.SYNTHESIZE_PROMPT
+ "\n"
+ style_line
+ "\nQuestion: "
+ question
+ "\nCandidate answers:\n"
+ "\n".join(parts)
)
messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=context)
reply = await self._llm.chat(messages, model=self._settings.ollama_model_smart)
model = _synthesis_model(self._settings, mode)
reply = await self._llm.chat(messages, model=model)
return reply
@ -186,6 +211,48 @@ def _join_context(parts: list[str]) -> str:
return text.strip()
def _format_history(history: list[dict[str, str]] | None) -> str:
if not history:
return ""
lines = ["Recent conversation:"]
for entry in history[-4:]:
question = entry.get("q") if isinstance(entry, dict) else None
answer = entry.get("a") if isinstance(entry, dict) else None
if question:
lines.append(f"Q: {question}")
if answer:
lines.append(f"A: {answer}")
return "\n".join(lines)
def _angles_limit(settings: Settings, mode: str) -> int:
if mode == "genius":
return settings.genius_max_angles
if mode == "quick":
return settings.fast_max_angles
return settings.smart_max_angles
def _candidates_limit(settings: Settings, mode: str) -> int:
if mode == "genius":
return settings.genius_max_candidates
if mode == "quick":
return settings.fast_max_candidates
return settings.smart_max_candidates
def _candidate_model(settings: Settings, mode: str) -> str:
if mode == "genius":
return settings.ollama_model_genius
return settings.ollama_model_smart
def _synthesis_model(settings: Settings, mode: str) -> str:
if mode == "genius":
return settings.ollama_model_genius
return settings.ollama_model_smart
def _parse_json_block(text: str, *, fallback: dict[str, Any]) -> dict[str, Any]:
raw = text.strip()
match = re.search(r"\{.*\}", raw, flags=re.S)

View File

@ -12,7 +12,8 @@ CLUSTER_SYSTEM = (
CLASSIFY_PROMPT = (
"Classify the user question. Return JSON with fields: "
"needs_snapshot (bool), needs_kb (bool), needs_metrics (bool), "
"needs_general (bool), intent (short string), ambiguity (0-1)."
"needs_general (bool), intent (short string), ambiguity (0-1), "
"answer_style (direct|insightful)."
)
ANGLE_PROMPT = (
@ -36,6 +37,9 @@ SYNTHESIZE_PROMPT = (
"Synthesize a final response from the best candidates. "
"Use a natural, helpful tone with light reasoning. "
"Avoid lists unless the user asked for lists. "
"If AnswerStyle is insightful, add one grounded insight or mild hypothesis, "
"but mark uncertainty briefly. "
"If AnswerStyle is direct, keep it short and factual. "
"Do not include confidence scores or evaluation metadata."
)

View File

@ -29,19 +29,24 @@ async def main() -> None:
engine = _build_engine(settings)
async def handler(payload: dict[str, str]) -> dict[str, object]:
result = await engine.answer(payload.get("question", ""), mode=payload.get("mode", "quick"))
async def handler(payload: dict[str, object]) -> dict[str, object]:
history = payload.get("history") if isinstance(payload, dict) else None
result = await engine.answer(
str(payload.get("question", "") or ""),
mode=str(payload.get("mode", "quick") or "quick"),
history=history if isinstance(history, list) else None,
)
return {"reply": result.reply, "scores": result.scores.__dict__}
queue = QueueManager(settings, handler)
await queue.start()
async def answer_handler(question: str, mode: str, observer=None) -> AnswerResult:
async def answer_handler(question: str, mode: str, history=None, observer=None) -> AnswerResult:
if settings.queue_enabled:
payload = await queue.submit({"question": question, "mode": mode})
payload = await queue.submit({"question": question, "mode": mode, "history": history or []})
reply = payload.get("reply", "") if isinstance(payload, dict) else ""
return AnswerResult(reply=reply or "", scores=result_scores(payload), meta={"mode": mode})
return await engine.answer(question, mode=mode, observer=observer)
return await engine.answer(question, mode=mode, history=history, observer=observer)
api = Api(settings, answer_handler)
server = uvicorn.Server(uvicorn.Config(api.app, host="0.0.0.0", port=settings.http_port, log_level="info"))

View File

@ -80,13 +80,15 @@ class MatrixBot:
settings: Settings,
bot: MatrixBotConfig,
engine: AnswerEngine,
answer_handler: Callable[[str, str, Callable[[str, str], None] | None], Awaitable[AnswerResult]] | None = None,
answer_handler: Callable[[str, str, list[dict[str, str]] | None, Callable[[str, str], None] | None], Awaitable[AnswerResult]]
| None = None,
) -> None:
self._settings = settings
self._bot = bot
self._engine = engine
self._client = MatrixClient(settings, bot)
self._answer_handler = answer_handler
self._history: dict[str, list[dict[str, str]]] = {}
async def run(self) -> None:
while True:
@ -153,8 +155,9 @@ class MatrixBot:
task = asyncio.create_task(heartbeat())
started = time.monotonic()
try:
handler = self._answer_handler or (lambda q, m, obs: self._engine.answer(q, mode=m, observer=obs))
result = await handler(question, mode, observer)
handler = self._answer_handler or (lambda q, m, h, obs: self._engine.answer(q, mode=m, history=h, observer=obs))
history = self._history.get(room_id, [])
result = await handler(question, mode, history, observer)
elapsed = time.monotonic() - started
await self._client.send_message(token, room_id, result.reply)
log.info(
@ -167,6 +170,8 @@ class MatrixBot:
}
},
)
history.append({"q": question, "a": result.reply})
self._history[room_id] = history[-4:]
finally:
stop.set()
task.cancel()
@ -180,6 +185,8 @@ def _extract_mode(body: str, mentions: tuple[str, ...], default_mode: str) -> tu
if not default_mode:
if "atlas-smart" in lower or "smart" in lower:
mode = "smart"
if "atlas-genius" in lower or "genius" in lower:
mode = "genius"
if "atlas-quick" in lower or "quick" in lower:
mode = "quick"
cleaned = body

View File

@ -77,6 +77,7 @@ def build_summary(snapshot: dict[str, Any] | None) -> dict[str, Any]:
summary.update(_build_pressure(snapshot))
summary.update(_build_hardware(nodes_detail))
summary.update(_build_hardware_by_node(nodes_detail))
summary.update(_build_hardware_usage(metrics, summary.get("hardware_by_node")))
summary.update(_build_node_facts(nodes_detail))
summary.update(_build_node_ages(nodes_detail))
summary.update(_build_node_taints(nodes_detail))
@ -99,6 +100,7 @@ def build_summary(snapshot: dict[str, Any] | None) -> dict[str, Any]:
summary.update(_build_root_disk_headroom(metrics))
summary.update(_build_node_load(metrics))
summary.update(_build_node_load_summary(metrics))
summary.update(_build_cluster_watchlist(summary))
summary.update(_build_workloads(snapshot))
summary.update(_build_flux(snapshot))
return summary
@ -161,6 +163,36 @@ def _build_hardware_by_node(nodes_detail: list[dict[str, Any]]) -> dict[str, Any
return {"hardware_by_node": mapping} if mapping else {}
def _build_hardware_usage(metrics: dict[str, Any], hardware_by_node: dict[str, Any] | None) -> dict[str, Any]:
if not isinstance(hardware_by_node, dict) or not hardware_by_node:
return {}
node_load = metrics.get("node_load") if isinstance(metrics.get("node_load"), list) else []
if not node_load:
return {}
buckets: dict[str, dict[str, list[float]]] = {}
for entry in node_load:
if not isinstance(entry, dict):
continue
node = entry.get("node")
if not isinstance(node, str) or not node:
continue
hardware = hardware_by_node.get(node, "unknown")
bucket = buckets.setdefault(str(hardware), {"load_index": [], "cpu": [], "ram": [], "net": [], "io": []})
for key in ("load_index", "cpu", "ram", "net", "io"):
value = entry.get(key)
if isinstance(value, (int, float)):
bucket[key].append(float(value))
output: list[dict[str, Any]] = []
for hardware, metrics_bucket in buckets.items():
row: dict[str, Any] = {"hardware": hardware}
for key, values in metrics_bucket.items():
if values:
row[key] = sum(values) / len(values)
output.append(row)
output.sort(key=lambda item: (-(item.get("load_index") or 0), item.get("hardware") or ""))
return {"hardware_usage_avg": output}
def _build_node_ages(nodes_detail: list[dict[str, Any]]) -> dict[str, Any]:
ages: list[dict[str, Any]] = []
for node in nodes_detail or []:
@ -1354,6 +1386,69 @@ def _append_node_load_summary(lines: list[str], summary: dict[str, Any]) -> None
lines.append("node_load_outliers: " + _format_names(names))
def _append_hardware_usage(lines: list[str], summary: dict[str, Any]) -> None:
usage = summary.get("hardware_usage_avg")
if not isinstance(usage, list) or not usage:
return
parts = []
for entry in usage[:5]:
if not isinstance(entry, dict):
continue
hardware = entry.get("hardware")
load = entry.get("load_index")
cpu = entry.get("cpu")
ram = entry.get("ram")
io = entry.get("io")
net = entry.get("net")
if not hardware:
continue
label = f"{hardware} idx={_format_float(load)}"
label += f" cpu={_format_float(cpu)} ram={_format_float(ram)}"
label += f" io={_format_rate_bytes(io)} net={_format_rate_bytes(net)}"
parts.append(label)
if parts:
lines.append("hardware_usage_avg: " + "; ".join(parts))
def _append_cluster_watchlist(lines: list[str], summary: dict[str, Any]) -> None:
watchlist = summary.get("cluster_watchlist")
if not isinstance(watchlist, list) or not watchlist:
return
lines.append("cluster_watchlist: " + "; ".join(watchlist))
def _build_cluster_watchlist(summary: dict[str, Any]) -> dict[str, Any]:
items: list[str] = []
nodes_summary = summary.get("nodes_summary") if isinstance(summary.get("nodes_summary"), dict) else {}
not_ready = int(nodes_summary.get("not_ready") or 0)
if not_ready > 0:
items.append(f"not_ready_nodes={not_ready}")
pressure = summary.get("pressure_nodes") if isinstance(summary.get("pressure_nodes"), dict) else {}
pressure_nodes = pressure.get("names") if isinstance(pressure.get("names"), list) else []
if pressure_nodes:
items.append(f"pressure_nodes={len(pressure_nodes)}")
pod_issues = summary.get("pod_issues") if isinstance(summary.get("pod_issues"), dict) else {}
pending_over = int(pod_issues.get("pending_over_15m") or 0)
if pending_over > 0:
items.append(f"pods_pending_over_15m={pending_over}")
workloads = summary.get("workloads_health") if isinstance(summary.get("workloads_health"), dict) else {}
deployments = workloads.get("deployments") if isinstance(workloads.get("deployments"), dict) else {}
statefulsets = workloads.get("statefulsets") if isinstance(workloads.get("statefulsets"), dict) else {}
daemonsets = workloads.get("daemonsets") if isinstance(workloads.get("daemonsets"), dict) else {}
total_not_ready = int(deployments.get("not_ready") or 0) + int(statefulsets.get("not_ready") or 0) + int(daemonsets.get("not_ready") or 0)
if total_not_ready > 0:
items.append(f"workloads_not_ready={total_not_ready}")
flux = summary.get("flux") if isinstance(summary.get("flux"), dict) else {}
flux_not_ready = int(flux.get("not_ready") or 0)
if flux_not_ready > 0:
items.append(f"flux_not_ready={flux_not_ready}")
pvc_usage = summary.get("pvc_usage_top") if isinstance(summary.get("pvc_usage_top"), list) else []
high_pvc = [entry for entry in pvc_usage if isinstance(entry, dict) and (entry.get("value") or 0) >= 90]
if high_pvc:
items.append("pvc_usage>=90%")
return {"cluster_watchlist": items} if items else {}
def _capacity_ratio_parts(entries: list[dict[str, Any]], ratio_key: str, usage_key: str, req_key: str) -> list[str]:
parts: list[str] = []
for entry in entries[:5]:
@ -1489,6 +1584,8 @@ def summary_text(snapshot: dict[str, Any] | None) -> str:
_append_workloads(lines, summary)
_append_workloads_by_namespace(lines, summary)
_append_node_load_summary(lines, summary)
_append_cluster_watchlist(lines, summary)
_append_hardware_usage(lines, summary)
_append_flux(lines, summary)
_append_units_windows(lines, summary)
return "\n".join(lines)