623 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import collections
import json
import os
import re
import ssl
import time
from typing import Any
from urllib import error, parse, request
# --- Runtime configuration (all values come from the environment) ---
# Matrix homeserver base URL; the default base for req().
BASE = os.environ.get("MATRIX_BASE", "http://othrys-synapse-matrix-synapse:8008")
# Base URL that login() sends the password login request to.
AUTH_BASE = os.environ.get("AUTH_BASE", "http://matrix-authentication-service:8080")
# Bot credentials. Both are mandatory: a missing variable raises KeyError at import time.
USER = os.environ["BOT_USER"]
PASSWORD = os.environ["BOT_PASS"]
# Room alias that main() resolves and joins at startup.
ROOM_ALIAS = "#othrys:live.bstein.dev"
# Chat backend endpoint and model name used by ollama_reply().
OLLAMA_URL = os.environ.get("OLLAMA_URL", "https://chat.ai.bstein.dev/")
MODEL = os.environ.get("OLLAMA_MODEL", "qwen2.5-coder:7b-instruct-q4_0")
# Optional API key; when non-empty it is sent as the x-api-key header.
API_KEY = os.environ.get("CHAT_API_KEY", "")
# Directory holding the pre-rendered KB JSON; an empty value makes load_kb() a no-op.
KB_DIR = os.environ.get("KB_DIR", "")
# VictoriaMetrics base URL used by vm_query().
VM_URL = os.environ.get("VM_URL", "http://victoria-metrics-single-server.monitoring.svc.cluster.local:8428")
# Comma-separated names / user IDs that count as a mention of the bot.
BOT_MENTIONS = os.environ.get("BOT_MENTIONS", f"{USER},atlas")
# Server name appended by normalize_user_id() to tokens that carry no ":server" part.
SERVER_NAME = os.environ.get("MATRIX_SERVER_NAME", "live.bstein.dev")
# Character budgets: KB/context text and live pod output respectively.
MAX_KB_CHARS = int(os.environ.get("ATLASBOT_MAX_KB_CHARS", "2500"))
MAX_TOOL_CHARS = int(os.environ.get("ATLASBOT_MAX_TOOL_CHARS", "2500"))
# Word-like tokens: one alphanumeric character followed by at least one more
# character from [a-z0-9_.-], matched case-insensitively.
TOKEN_RE = re.compile(r"[a-z0-9][a-z0-9_.-]{1,}", re.IGNORECASE)
# Dotted hostnames made of two or more labels. The dot is escaped with a single
# backslash: inside a raw string, "\\." means "a literal backslash followed by
# any character", which never matches a real hostname.
HOST_RE = re.compile(r"(?i)([a-z0-9-]+(?:\.[a-z0-9-]+)+)")
# Filler words that _tokens() drops before any keyword matching.
STOPWORDS = {
    # articles, conjunctions, prepositions, demonstratives
    "the", "and", "for", "with", "this", "that", "from", "into",
    # interrogatives
    "what", "how", "why", "when", "where", "which", "who",
    # modal verbs and politeness
    "can", "could", "should", "would", "please", "help",
    # project names
    "atlas", "othrys",
}
# Substrings whose presence in a prompt makes build_context() add the
# VictoriaMetrics restart list and cluster snapshot.
METRIC_HINT_WORDS = {
    # general health wording
    "health", "status", "down", "slow", "unreachable", "latency",
    # failures
    "error", "unknown_error", "timeout", "crash", "crashloop",
    # pod lifecycle
    "restart", "restarts", "pending",
}
def _tokens(text: str) -> list[str]:
    """Split *text* into lower-cased TOKEN_RE matches, dropping stopwords.

    A falsy *text* is treated as the empty string. Order and duplicates of the
    surviving tokens are preserved.
    """
    kept: list[str] = []
    for raw in TOKEN_RE.findall(text or ""):
        word = raw.lower()
        if len(word) >= 2 and word not in STOPWORDS:
            kept.append(word)
    return kept
# Mention detection (Matrix rich mentions + plain @atlas).
# Configured mention tokens, trimmed, with empty entries dropped.
MENTION_TOKENS = [m.strip() for m in BOT_MENTIONS.split(",") if m.strip()]
# Localpart of each token: leading "@" and any ":server" suffix removed.
MENTION_LOCALPARTS = [m.lstrip("@").split(":", 1)[0] for m in MENTION_TOKENS]
# Empty localparts (e.g. from a bare "@" token) are left out of the pattern:
# an empty alternative would make every "@" count as a mention.
_MENTION_ALTERNATIVES = "|".join(re.escape(lp) for lp in MENTION_LOCALPARTS if lp)
if _MENTION_ALTERNATIVES:
    # "@name" or "@name:server", not glued to a word character on either side.
    # Escapes use a single backslash: in a raw string "\\w" is a literal
    # backslash followed by "w", which disables the boundary checks.
    _MENTION_PATTERN = r"(?<!\w)@(?:" + _MENTION_ALTERNATIVES + r")(?::[^\s]+)?(?!\w)"
else:
    # No usable names configured: use a pattern that can never match.
    _MENTION_PATTERN = r"(?!)"
MENTION_RE = re.compile(_MENTION_PATTERN, re.IGNORECASE)
def normalize_user_id(token: str) -> str:
    """Turn a mention token or bare name into a full ``@local:server`` user ID.

    Returns "" for a blank token. A token that already starts with "@" and
    contains ":" is returned unchanged (after trimming). Otherwise leading "@"
    characters are removed and one "@" is put back; SERVER_NAME is appended
    only when the token carries no ":" of its own.
    """
    cleaned = token.strip()
    if cleaned == "":
        return ""
    qualified = ":" in cleaned
    if qualified and cleaned[0] == "@":
        return cleaned
    bare = cleaned.lstrip("@")
    return f"@{bare}" if qualified else f"@{bare}:{SERVER_NAME}"
# Lower-cased full user IDs for every configured mention token; tokens that
# normalize to "" are left out.
MENTION_USER_IDS = {uid.lower() for uid in map(normalize_user_id, MENTION_TOKENS) if uid}
def is_mentioned(content: dict, body: str) -> bool:
    """Report whether a message addresses the bot.

    Args:
        content: The event's ``content`` object.
        body: The plain-text message body (may be empty or None).

    Returns:
        True when *body* matches MENTION_RE, or when
        ``content["m.mentions"]["user_ids"]`` lists one of MENTION_USER_IDS
        (compared case-insensitively). Otherwise False.
    """
    if MENTION_RE.search(body or "") is not None:
        return True
    # Event content comes from other users, so "m.mentions" can have any JSON
    # shape; treat anything that is not the expected object as "no mention".
    mentions = content.get("m.mentions")
    if not isinstance(mentions, dict):
        return False
    user_ids = mentions.get("user_ids")
    if not isinstance(user_ids, list):
        return False
    return any(isinstance(uid, str) and uid.lower() in MENTION_USER_IDS for uid in user_ids)
# Matrix HTTP helper.
def req(method: str, path: str, token: str | None = None, body=None, timeout=60, base: str | None = None):
    """Send one HTTP request and return the decoded JSON response.

    Args:
        method: HTTP verb.
        path: Path (and query string) appended to the base URL.
        token: Optional bearer token for the Authorization header.
        body: Optional JSON-serialisable payload; None sends no body.
        timeout: Timeout in seconds passed to urlopen.
        base: Base URL override; BASE is used when falsy.

    Returns:
        The parsed JSON response, or {} when the response body is empty.
    """
    target = (base or BASE) + path
    payload = None if body is None else json.dumps(body).encode()
    hdrs = {}
    if payload is not None:
        hdrs["Content-Type"] = "application/json"
    if token:
        hdrs["Authorization"] = f"Bearer {token}"
    http_request = request.Request(target, data=payload, headers=hdrs, method=method)
    with request.urlopen(http_request, timeout=timeout) as resp:
        raw = resp.read()
    if not raw:
        return {}
    return json.loads(raw.decode())
def login() -> str:
    """Log in with the bot's password via AUTH_BASE and return the access token.

    Raises whatever req() raises, and KeyError if the response has no
    "access_token".
    """
    identifier = {"type": "m.id.user", "user": normalize_user_id(USER)}
    credentials = {
        "type": "m.login.password",
        "identifier": identifier,
        "password": PASSWORD,
    }
    response = req("POST", "/_matrix/client/v3/login", body=credentials, base=AUTH_BASE)
    return response["access_token"]
def resolve_alias(token: str, alias: str) -> str:
    """Look up *alias* in the room directory and return its room ID."""
    quoted_alias = parse.quote(alias)
    lookup = req("GET", "/_matrix/client/v3/directory/room/" + quoted_alias, token)
    return lookup["room_id"]
def join_room(token: str, room: str):
    """Join *room* by POSTing an empty JSON object to its /join endpoint."""
    join_path = "/_matrix/client/v3/rooms/" + parse.quote(room) + "/join"
    req("POST", join_path, token, body={})
def send_msg(token: str, room: str, text: str):
    """Post *text* to *room* as an ``m.text`` message."""
    message = {"msgtype": "m.text", "body": text}
    quoted_room = parse.quote(room)
    req("POST", f"/_matrix/client/v3/rooms/{quoted_room}/send/m.room.message", token, body=message)
# Atlas KB loader (no external deps; files are pre-rendered JSON via scripts/knowledge_render_atlas.py).
# KB holds the loaded catalog and the runbook list; both stay empty until
# load_kb() runs with KB_DIR set.
KB = {"catalog": {}, "runbooks": []}
# Lower-cased endpoint host -> catalog endpoint entries (rebuilt by load_kb()).
_HOST_INDEX: dict[str, list[dict]] = {}
# Lower-cased service and workload names from the catalog (rebuilt by load_kb()).
_NAME_INDEX: set[str] = set()
def _load_json_file(path: str) -> Any | None:
try:
with open(path, "rb") as f:
return json.loads(f.read().decode("utf-8"))
except Exception:
return None
def load_kb():
    """Load the catalog and runbooks from KB_DIR and rebuild the lookup indexes.

    Does nothing when KB_DIR is empty. Reads ``catalog/atlas.json`` and
    ``catalog/runbooks.json``; a file that is missing, unparsable, or of the
    wrong top-level type is replaced by an empty dict / list, so KB always
    holds a dict catalog and a list of runbooks. Rebinds the module globals
    KB, _HOST_INDEX and _NAME_INDEX.
    """
    global KB, _HOST_INDEX, _NAME_INDEX
    if not KB_DIR:
        return
    catalog = _load_json_file(os.path.join(KB_DIR, "catalog", "atlas.json"))
    runbooks = _load_json_file(os.path.join(KB_DIR, "catalog", "runbooks.json"))
    # Normalise the top-level shapes once so later lookups can rely on them.
    if not isinstance(catalog, dict):
        catalog = {}
    if not isinstance(runbooks, list):
        runbooks = []
    KB = {"catalog": catalog, "runbooks": runbooks}

    def section(key: str) -> list:
        """Return catalog[key] when it is a list, else an empty list."""
        value = catalog.get(key)
        return value if isinstance(value, list) else []

    host_index: dict[str, list[dict]] = collections.defaultdict(list)
    for ep in section("http_endpoints"):
        # Same defensive check the service/workload loops apply.
        if not isinstance(ep, dict):
            continue
        host = str(ep.get("host") or "").lower()
        if host:
            host_index[host].append(ep)
    # Plain dict with sorted keys, as before.
    _HOST_INDEX = {host: host_index[host] for host in sorted(host_index)}

    names: set[str] = set()
    for key in ("services", "workloads"):
        for item in section(key):
            if isinstance(item, dict) and item.get("name"):
                names.add(str(item["name"]).lower())
    _NAME_INDEX = names
def _str_items(value: Any) -> list[str]:
    """Return the non-empty strings in *value* (a list/tuple, a single string, or anything else)."""
    if isinstance(value, str):
        return [value] if value else []
    if isinstance(value, (list, tuple)):
        return [item for item in value if isinstance(item, str) and item]
    return []


def kb_retrieve(query: str, *, limit: int = 3) -> str:
    """Return a text block with the runbooks that best match *query*.

    Scoring per runbook: each distinct query token found in the runbook text
    adds 3 when it also appears in the title, else 1; each entrypoint string
    contained in the lower-cased query adds 4. The top *limit* runbooks are
    rendered as "- title (path)" plus up to 900 characters of body, until the
    MAX_KB_CHARS budget would be exceeded.

    Returns "" when the query is blank, has no usable tokens, no runbook
    scores, or no rendered chunk fits the budget.
    """
    q = (query or "").strip()
    if not q or not KB.get("runbooks"):
        return ""
    ql = q.lower()
    q_tokens = set(_tokens(q))
    if not q_tokens:
        return ""
    scored: list[tuple[int, dict]] = []
    for doc in KB.get("runbooks", []):
        if not isinstance(doc, dict):
            continue
        title = str(doc.get("title") or "")
        body = str(doc.get("body") or "")
        # Keep only real, non-empty strings: joining raw lists would raise on
        # a non-string item, and an empty entrypoint would match every query.
        tags = _str_items(doc.get("tags"))
        entrypoints = _str_items(doc.get("entrypoints"))
        title_l = title.lower()
        hay = (title + "\n" + " ".join(tags) + "\n" + " ".join(entrypoints) + "\n" + body).lower()
        score = 0
        for t in q_tokens:
            if t in hay:
                score += 3 if t in title_l else 1
        for h in entrypoints:
            if h.lower() in ql:
                score += 4
        if score:
            scored.append((score, doc))
    # Stable sort: runbooks with equal scores keep their file order.
    scored.sort(key=lambda x: x[0], reverse=True)
    picked = [d for _, d in scored[:limit]]
    if not picked:
        return ""
    parts: list[str] = ["Atlas KB (retrieved):"]
    used = 0
    for d in picked:
        path = d.get("path") or ""
        title = d.get("title") or path
        body = str(d.get("body") or "").strip()
        snippet = body[:900].strip()
        chunk = f"- {title} ({path})\n{snippet}"
        if used + len(chunk) > MAX_KB_CHARS:
            break
        parts.append(chunk)
        used += len(chunk)
    if len(parts) == 1:
        # Nothing fit the budget; a bare header would be noise.
        return ""
    return "\n".join(parts).strip()
def catalog_hints(query: str) -> tuple[str, list[tuple[str, str]]]:
    """Describe the catalog endpoints relevant to *query*.

    Hosts are collected from hostnames in the query that end in "bstein.dev",
    plus the hosts of endpoints whose backend service equals a query token
    present in _NAME_INDEX.

    Returns:
        A pair ``(text, edges)``. *text* lists up to 20 lines of the form
        "- host/path -> namespace/service -> Kind:name, ..." under a header,
        and *edges* holds ``(namespace, workload_name)`` pairs for the listed
        endpoints. Returns ``("", [])`` when the query is blank, no catalog is
        loaded, or nothing matches.
    """
    q = (query or "").strip()
    if not q or not KB.get("catalog"):
        return "", []
    ql = q.lower()
    hosts = {m.group(1).lower() for m in HOST_RE.finditer(ql) if m.group(1).lower().endswith("bstein.dev")}
    # Also match by known workload/service names.
    endpoints = KB["catalog"].get("http_endpoints") or []
    for t in _tokens(ql):
        if t not in _NAME_INDEX:
            continue
        for ep in endpoints:
            if not isinstance(ep, dict):
                continue
            # Tolerate "backend": null and entries without a usable host.
            backend = ep.get("backend") or {}
            host = ep.get("host")
            if isinstance(backend, dict) and backend.get("service") == t and isinstance(host, str) and host:
                hosts.add(host.lower())
    edges: list[tuple[str, str]] = []
    lines: list[str] = []
    for host in sorted(hosts):
        for ep in _HOST_INDEX.get(host, []):
            if not isinstance(ep, dict):
                continue
            backend = ep.get("backend") or {}
            if not isinstance(backend, dict):
                continue
            ns = backend.get("namespace") or ""
            svc = backend.get("service") or ""
            path = ep.get("path") or "/"
            if not svc:
                continue
            wk = backend.get("workloads") or []
            if not isinstance(wk, list):
                wk = []
            named = [w for w in wk if isinstance(w, dict) and w.get("name")]
            wk_str = ", ".join(f"{w.get('kind')}:{w.get('name')}" for w in named) or "unknown"
            # Explicit separators: without them the URL, the namespace/service
            # and the workload list run together into one unreadable token.
            lines.append(f"- {host}{path} -> {ns}/{svc} -> {wk_str}")
            edges.extend((ns, str(w["name"])) for w in named)
    if not lines:
        return "", []
    return "Atlas endpoints (from GitOps):\n" + "\n".join(lines[:20]), edges
# Kubernetes API (read-only). RBAC is provided via ServiceAccount atlasbot.
# Module-level caches, both None until first use: the service-account token
# kept by _k8s_token() and the TLS context built by _k8s_context().
_K8S_TOKEN: str | None = None
_K8S_CTX: ssl.SSLContext | None = None
def _k8s_context() -> ssl.SSLContext:
    """Return the TLS context for the Kubernetes API, building it on first use.

    The context trusts the service-account CA bundle and is cached in the
    module global _K8S_CTX.
    """
    global _K8S_CTX
    if _K8S_CTX is None:
        _K8S_CTX = ssl.create_default_context(
            cafile="/var/run/secrets/kubernetes.io/serviceaccount/ca.crt"
        )
    return _K8S_CTX
def _k8s_token() -> str:
    """Return the current service-account bearer token.

    The token file is re-read on every call, because projected
    service-account tokens are rotated on disk and a value cached for the
    life of the process would eventually be rejected. The last good value is
    kept in _K8S_TOKEN and used as a fallback when the file cannot be read.

    Raises:
        OSError: The token file is unreadable and no token was read earlier.
    """
    global _K8S_TOKEN
    token_path = "/var/run/secrets/kubernetes.io/serviceaccount/token"
    try:
        with open(token_path, "r", encoding="utf-8") as f:
            fresh = f.read().strip()
    except OSError:
        if _K8S_TOKEN:
            return _K8S_TOKEN
        raise
    if fresh:
        _K8S_TOKEN = fresh
    # An empty file mid-rotation falls back to the previous token.
    return _K8S_TOKEN or fresh
def k8s_get(path: str, timeout: int = 8) -> dict:
    """GET *path* from the in-cluster Kubernetes API and return the parsed JSON.

    The API address comes from KUBERNETES_SERVICE_HOST and
    KUBERNETES_SERVICE_PORT_HTTPS / KUBERNETES_SERVICE_PORT (default "443").

    Returns:
        The decoded JSON body, or {} when the body is empty.

    Raises:
        RuntimeError: KUBERNETES_SERVICE_HOST is unset or empty.
    """
    env = os.environ
    host = env.get("KUBERNETES_SERVICE_HOST")
    if not host:
        raise RuntimeError("k8s host missing")
    port = env.get("KUBERNETES_SERVICE_PORT_HTTPS") or env.get("KUBERNETES_SERVICE_PORT") or "443"
    api_request = request.Request(
        f"https://{host}:{port}{path}",
        headers={"Authorization": f"Bearer {_k8s_token()}"},
        method="GET",
    )
    with request.urlopen(api_request, timeout=timeout, context=_k8s_context()) as resp:
        payload = resp.read()
    if not payload:
        return {}
    return json.loads(payload.decode())
def k8s_pods(namespace: str) -> list[dict]:
    """Return the pod objects of *namespace* (at most 500 are requested).

    Yields an empty list when the response has no list under "items".
    """
    listing = k8s_get(f"/api/v1/namespaces/{parse.quote(namespace)}/pods?limit=500")
    pods = listing.get("items") or []
    if isinstance(pods, list):
        return pods
    return []
def summarize_pods(namespace: str, prefixes: set[str] | None = None) -> str:
try:
pods = k8s_pods(namespace)
except Exception:
return ""
out: list[str] = []
for p in pods:
md = p.get("metadata") or {}
st = p.get("status") or {}
name = md.get("name") or ""
if prefixes and not any(name.startswith(pref + "-") or name == pref or name.startswith(pref) for pref in prefixes):
continue
phase = st.get("phase") or "?"
cs = st.get("containerStatuses") or []
restarts = 0
ready = 0
total = 0
reason = st.get("reason") or ""
for c in cs if isinstance(cs, list) else []:
if not isinstance(c, dict):
continue
total += 1
restarts += int(c.get("restartCount") or 0)
if c.get("ready"):
ready += 1
state = c.get("state") or {}
if not reason and isinstance(state, dict):
waiting = state.get("waiting") or {}
if isinstance(waiting, dict) and waiting.get("reason"):
reason = waiting.get("reason")
extra = f" ({reason})" if reason else ""
out.append(f"- {namespace}/{name}: {phase} {ready}/{total} restarts={restarts}{extra}")
return "\n".join(out[:20])
def flux_not_ready() -> str:
    """List Flux kustomizations in flux-system whose Ready condition is not true.

    Returns up to 10 lines "- flux kustomization/<name>: Ready=<status> <message>".
    A kustomization with no Ready condition is reported with Ready=None.
    Returns "" when the API call fails or nothing is unhealthy.
    """
    endpoint = "/apis/kustomize.toolkit.fluxcd.io/v1/namespaces/flux-system/kustomizations?limit=200"
    try:
        listing = k8s_get(endpoint)
    except Exception:
        return ""
    items = listing.get("items") or []
    if not isinstance(items, list):
        items = []
    unhealthy: list[str] = []
    for item in items:
        meta = item.get("metadata") or {}
        status = item.get("status") or {}
        name = meta.get("name") or ""
        conditions = status.get("conditions") or []
        if not isinstance(conditions, list):
            conditions = []
        ready, message = None, ""
        # No early exit: when several Ready conditions exist the last one wins.
        for cond in conditions:
            if isinstance(cond, dict) and cond.get("type") == "Ready":
                ready = cond.get("status")
                message = cond.get("message") or ""
        if ready in ("True", True):
            continue
        unhealthy.append(f"- flux kustomization/{name}: Ready={ready} {message}".strip())
    return "\n".join(unhealthy[:10])
# VictoriaMetrics (PromQL) helpers.
def vm_query(query: str, timeout: int = 8) -> dict | None:
try:
url = VM_URL.rstrip("/") + "/api/v1/query?" + parse.urlencode({"query": query})
with request.urlopen(url, timeout=timeout) as resp:
return json.loads(resp.read().decode())
except Exception:
return None
def _vm_value_series(res: dict) -> list[dict]:
    """Return ``res["data"]["result"]`` from a successful query response.

    Yields [] when *res* is falsy, its status is not "success", or the result
    is missing or not a list.
    """
    if res and res.get("status") == "success":
        payload = res.get("data") or {}
        series = payload.get("result") or []
        if isinstance(series, list):
            return series
    return []
def vm_render_result(res: dict | None, limit: int = 12) -> str:
if not res:
return ""
series = _vm_value_series(res)
if not series:
return ""
out: list[str] = []
for r in series[:limit]:
if not isinstance(r, dict):
continue
metric = r.get("metric") or {}
value = r.get("value") or []
val = value[1] if isinstance(value, list) and len(value) > 1 else ""
# Prefer common labels if present.
label_parts = []
for k in ("namespace", "pod", "container", "node", "instance", "job", "phase"):
if isinstance(metric, dict) and metric.get(k):
label_parts.append(f"{k}={metric.get(k)}")
if not label_parts and isinstance(metric, dict):
for k in sorted(metric.keys()):
if k.startswith("__"):
continue
label_parts.append(f"{k}={metric.get(k)}")
if len(label_parts) >= 4:
break
labels = ", ".join(label_parts) if label_parts else "series"
out.append(f"- {labels}: {val}")
return "\n".join(out)
def vm_top_restarts(hours: int = 1) -> str:
    """List the five pods with the largest restart-count increase over *hours*.

    Returns lines "- restarts(<hours>h): ns/pod = value", or "" when the
    query fails or is not successful. Series without a pod label are skipped.
    """
    promql = f"topk(5, sum by (namespace,pod) (increase(kube_pod_container_status_restarts_total[{hours}h])))"
    res = vm_query(promql)
    if not res or res.get("status") != "success":
        return ""
    rows = (res.get("data") or {}).get("result") or []
    lines: list[str] = []
    for row in rows:
        if not isinstance(row, dict):
            continue
        labels = row.get("metric") or {}
        sample = row.get("value") or []
        namespace = (labels.get("namespace") or "").strip()
        pod = (labels.get("pod") or "").strip()
        if not pod:
            continue
        count = sample[1] if isinstance(sample, list) and len(sample) > 1 else ""
        lines.append(f"- restarts({hours}h): {namespace}/{pod} = {count}")
    return "\n".join(lines)
def vm_cluster_snapshot() -> str:
    """Summarise node readiness and pod phase counts from VictoriaMetrics.

    The node line is added only when both readiness queries return data and
    their first sample can be read; the pod-phase section only when that
    query renders any lines. Returns "" when neither is available.
    """
    lines: list[str] = []
    # Node readiness (kube-state-metrics).
    ready_res = vm_query('sum(kube_node_status_condition{condition="Ready",status="true"})')
    not_ready_res = vm_query('sum(kube_node_status_condition{condition="Ready",status="false"})')
    if ready_res and not_ready_res:
        try:
            ready_n = _vm_value_series(ready_res)[0]["value"][1]
            not_ready_n = _vm_value_series(not_ready_res)[0]["value"][1]
        except Exception:
            pass
        else:
            lines.append(f"- nodes ready: {ready_n} (not ready: {not_ready_n})")
    phase_text = vm_render_result(vm_query("sum by (phase) (kube_pod_status_phase)"), limit=8)
    if phase_text:
        lines.extend(["Pod phases:", phase_text])
    return "\n".join(lines).strip()
# Conversation state.
# Maps a key built by key_for() to a list of transcript lines. sync_loop()
# appends "<sender>: <body>" lines and trims each list to its last 80 entries;
# ollama_reply() appends "Atlas: <reply>" lines and reads the last 24.
history = collections.defaultdict(list) # (room_id, sender|None) -> list[str] (short transcript)
def key_for(room_id: str, sender: str, is_dm: bool):
    """Build the history key: one transcript per room in DMs, per (room, sender) otherwise."""
    if is_dm:
        return (room_id, None)
    return (room_id, sender)
def build_context(prompt: str, *, allow_tools: bool, targets: list[tuple[str, str]]) -> str:
    """Assemble the grounding context for one prompt.

    Always includes the KB retrieval and catalog endpoint hints when they are
    non-empty. When *allow_tools* is true it also adds live pod summaries for
    the namespaces named in *targets* and in the catalog edges, the Flux
    not-ready list, and — if the prompt contains a METRIC_HINT_WORDS
    substring — the VictoriaMetrics restart list and cluster snapshot.

    Returns the non-empty sections joined by blank lines.
    """
    sections: list[str] = []

    kb_text = kb_retrieve(prompt)
    if kb_text:
        sections.append(kb_text)

    endpoint_text, edges = catalog_hints(prompt)
    if endpoint_text:
        sections.append(endpoint_text)

    if allow_tools:
        # Scope pod summaries to relevant namespaces/workloads when possible.
        names_by_ns: dict[str, set[str]] = collections.defaultdict(set)
        for namespace, workload in (targets or []) + (edges or []):
            if namespace and workload:
                names_by_ns[namespace].add(workload)

        pod_blocks: list[str] = []
        for namespace in sorted(names_by_ns):
            pod_text = summarize_pods(namespace, names_by_ns[namespace])
            if pod_text:
                pod_blocks.append(f"Pods (live):\n{pod_text}")
        if pod_blocks:
            sections.append("\n".join(pod_blocks)[:MAX_TOOL_CHARS])

        flux_text = flux_not_ready()
        if flux_text:
            sections.append("Flux (not ready):\n" + flux_text)

        lowered = (prompt or "").lower()
        wants_metrics = any(word in lowered for word in METRIC_HINT_WORDS)
        if wants_metrics:
            restart_text = vm_top_restarts(1)
            if restart_text:
                sections.append("VictoriaMetrics (top restarts 1h):\n" + restart_text)
            snapshot_text = vm_cluster_snapshot()
            if snapshot_text:
                sections.append("VictoriaMetrics (cluster snapshot):\n" + snapshot_text)

    return "\n\n".join([section for section in sections if section]).strip()
def ollama_reply(hist_key, prompt: str, *, context: str) -> str:
    """Ask the chat backend for a reply and record it in the transcript.

    The request body is ``{"model": MODEL, "message": <transcript>}`` sent to
    OLLAMA_URL, where the transcript is the system prompt, the optional
    context (cut to MAX_KB_CHARS), the last 24 history lines for *hist_key*,
    and the user prompt.

    Args:
        hist_key: Key into ``history`` (see key_for()).
        prompt: The user's message.
        context: Grounding text from build_context(); may be empty.

    Returns:
        The reply text, always a non-empty string. Any failure (network, bad
        JSON, unexpected shape) yields a fixed apology instead of raising.
    """
    default_reply = "I'm here to help."
    try:
        system = (
            "System: You are Atlas, the Titan lab assistant for Atlas/Othrys. "
            "Be helpful, direct, and concise. "
            "Prefer answering with exact repo paths and Kubernetes resource names. "
            "Never include or request secret values."
        )
        transcript_parts = [system]
        if context:
            transcript_parts.append("Context (grounded):\n" + context[:MAX_KB_CHARS])
        transcript_parts.extend(history[hist_key][-24:])
        transcript_parts.append(f"User: {prompt}")
        transcript = "\n".join(transcript_parts)
        payload = {"model": MODEL, "message": transcript}
        headers = {"Content-Type": "application/json"}
        if API_KEY:
            headers["x-api-key"] = API_KEY
        r = request.Request(OLLAMA_URL, data=json.dumps(payload).encode(), headers=headers)
        with request.urlopen(r, timeout=20) as resp:
            data = json.loads(resp.read().decode())
        reply = data.get("message") or data.get("response") or data.get("reply") or default_reply
        # The reply is forwarded as a Matrix message body, so it must be text.
        # NOTE(review): some chat APIs return "message" as an object with a
        # "content" field — confirm against the backend in use.
        if isinstance(reply, dict):
            reply = reply.get("content") or default_reply
        if not isinstance(reply, str):
            reply = str(reply)
        if not reply.strip():
            reply = default_reply
        history[hist_key].append(f"Atlas: {reply}")
        return reply
    except Exception:
        return "I'm here — but I couldn't reach the model backend."
# "promql: <expr>" or "promql <expr>"; group 1 is the expression. Escapes use a
# single backslash: in a raw string "\\s" is a literal backslash plus "s", so
# the doubled form never matches a normal message.
_PROMQL_RE = re.compile(r"(?is)^\s*promql\s*(?::|\s)\s*(.+?)\s*$")


def _room_has_member_events(room_data: dict) -> bool:
    """Return True when a room's sync chunk carries any m.room.member event."""
    for section in ("state", "timeline"):
        events = (room_data.get(section) or {}).get("events") or []
        if any(isinstance(ev, dict) and ev.get("type") == "m.room.member" for ev in events):
            return True
    return False


def _joined_member_count(token: str, rid: str, room_data: dict, cache: dict[str, int]) -> int | None:
    """Best-effort joined-member count for room *rid*, or None when unknown.

    Uses the sync summary when present, then *cache*, then the room's
    joined_members endpoint. Successful lookups are stored in *cache*.
    """
    count = (room_data.get("summary") or {}).get("m.joined_member_count")
    if isinstance(count, int):
        cache[rid] = count
        return count
    if rid in cache:
        return cache[rid]
    try:
        res = req("GET", f"/_matrix/client/v3/rooms/{parse.quote(rid)}/joined_members", token, timeout=10)
    except Exception:
        return None
    joined = res.get("joined")
    if not isinstance(joined, dict):
        return None
    cache[rid] = len(joined)
    return cache[rid]


def _workload_targets(text: str) -> list[tuple[str, str]]:
    """Map hostnames mentioned in *text* to (namespace, workload name) pairs via _HOST_INDEX."""
    targets: list[tuple[str, str]] = []
    for m in HOST_RE.finditer(text.lower()):
        host = m.group(1).lower()
        for ep in _HOST_INDEX.get(host, []):
            backend = ep.get("backend") or {}
            ns = backend.get("namespace") or ""
            for w in backend.get("workloads") or []:
                if isinstance(w, dict) and w.get("name"):
                    targets.append((ns, str(w["name"])))
    return targets


def sync_loop(token: str, room_id: str):
    """Long-poll /sync forever: accept invites and answer messages.

    An initial zero-timeout sync only records the starting position, so
    messages sent before startup are not answered. After that, every invite
    is joined (errors ignored) and every text message from another user is
    added to the transcript; the bot replies when the room is a DM (two or
    fewer joined members) or when it is mentioned. Live cluster lookups and
    the "promql" command are enabled in DMs only.

    Args:
        token: Matrix access token.
        room_id: Accepted for the caller's convenience; not used in the loop.

    This function does not return. A failed /sync is retried after 5 seconds;
    a failure while sending a reply propagates to the caller.
    """
    # Same normalisation login() uses, so the self-filter also works when
    # BOT_USER is a full user ID or MATRIX_SERVER_NAME is overridden.
    own_user_id = normalize_user_id(USER).lower()
    # The sync "summary" may be omitted when unchanged, so member counts are
    # remembered per room instead of being read from each response.
    member_counts: dict[str, int] = {}
    since = None
    try:
        res = req("GET", "/_matrix/client/v3/sync?timeout=0", token, timeout=10)
        since = res.get("next_batch")
    except Exception:
        pass
    while True:
        params = {"timeout": 30000}
        if since:
            params["since"] = since
        query = parse.urlencode(params)
        try:
            res = req("GET", f"/_matrix/client/v3/sync?{query}", token, timeout=35)
        except Exception:
            time.sleep(5)
            continue
        since = res.get("next_batch", since)
        rooms = res.get("rooms", {})
        # invites
        for rid in rooms.get("invite", {}):
            try:
                join_room(token, rid)
            except Exception:
                pass
        # messages
        for rid, data in rooms.get("join", {}).items():
            if _room_has_member_events(data):
                # Membership changed: the remembered count may be stale.
                member_counts.pop(rid, None)
            timeline = data.get("timeline", {}).get("events", [])
            is_dm = None  # resolved lazily, at most once per room per batch
            for ev in timeline:
                if ev.get("type") != "m.room.message":
                    continue
                content = ev.get("content", {})
                if not isinstance(content, dict):
                    continue
                raw_body = content.get("body")
                # Event content is user-controlled; ignore non-text bodies.
                if not isinstance(raw_body, str):
                    continue
                body = raw_body.strip()
                if not body:
                    continue
                sender = str(ev.get("sender") or "")
                if sender.lower() == own_user_id:
                    continue
                if is_dm is None:
                    count = _joined_member_count(token, rid, data, member_counts)
                    is_dm = count is not None and count <= 2
                mentioned = is_mentioned(content, body)
                hist_key = key_for(rid, sender, is_dm)
                history[hist_key].append(f"{sender}: {body}")
                history[hist_key] = history[hist_key][-80:]
                if not (is_dm or mentioned):
                    continue
                # Only do live cluster/metrics introspection in DMs.
                allow_tools = is_dm
                promql = ""
                if allow_tools:
                    m = _PROMQL_RE.match(body)
                    if m:
                        promql = m.group(1).strip()
                # Attempt to scope tools to the most likely workloads when hostnames are mentioned.
                targets = _workload_targets(body)
                context = build_context(body, allow_tools=allow_tools, targets=targets)
                if allow_tools and promql:
                    res = vm_query(promql, timeout=20)
                    rendered = vm_render_result(res, limit=15) or "(no results)"
                    extra = "VictoriaMetrics (PromQL result):\n" + rendered
                    context = (context + "\n\n" + extra).strip() if context else extra
                reply = ollama_reply(hist_key, body, context=context)
                send_msg(token, rid, reply)
def login_with_retry(attempts: int = 10):
    """Call login() until it succeeds, backing off exponentially between tries.

    Waits ``min(30, 2 ** attempt)`` seconds after each failed attempt except
    the last one, where sleeping would only delay the error.

    Args:
        attempts: Maximum number of login attempts (at least one is made).

    Returns:
        The access token from the first successful login().

    Raises:
        Exception: The error from the final failed attempt.
    """
    total = max(1, attempts)
    last_err = None
    for attempt in range(total):
        try:
            return login()
        except Exception as exc:  # noqa: BLE001
            last_err = exc
            if attempt < total - 1:
                time.sleep(min(30, 2 ** attempt))
    raise last_err
def main():
    """Load the KB, log in, try to join the configured room, then run the sync loop.

    Failing to resolve or join ROOM_ALIAS is not fatal: the loop then starts
    with a room ID of None.
    """
    load_kb()
    token = login_with_retry()
    room_id = None
    try:
        resolved = resolve_alias(token, ROOM_ALIAS)
        join_room(token, resolved)
        room_id = resolved
    except Exception:
        room_id = None
    sync_loop(token, room_id)


if __name__ == "__main__":
    main()