From d85ac70c833eefafd97604799a46e8433739d8ce Mon Sep 17 00:00:00 2001 From: Brad Stein Date: Thu, 5 Feb 2026 12:40:40 -0300 Subject: [PATCH] atlasbot: broaden intent routing and retriever selection --- atlasbot/engine/answerer.py | 72 +++++++++++++++++++++++++++++--- atlasbot/engine/intent_router.py | 38 ++++++++--------- atlasbot/llm/prompts.py | 12 ++++++ 3 files changed, 98 insertions(+), 24 deletions(-) diff --git a/atlasbot/engine/answerer.py b/atlasbot/engine/answerer.py index b05d60f..48e78db 100644 --- a/atlasbot/engine/answerer.py +++ b/atlasbot/engine/answerer.py @@ -110,6 +110,8 @@ class ScoreContext: sub_questions: list[str] retries: int parallelism: int + select_best: bool + fast_model: str class AnswerEngine: @@ -155,6 +157,7 @@ class AnswerEngine: "route", "decompose", "chunk_score", + "chunk_select", "fact_select", "synth", "subanswer", @@ -265,7 +268,7 @@ class AnswerEngine: if spine_line: key_facts = _merge_fact_lines([spine_line], key_facts) metric_facts = _merge_fact_lines([spine_line], metric_facts) - if spine_answer and mode == "fast": + if spine_answer and mode in {"fast", "quick"}: scores = _default_scores() meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started) return AnswerResult(spine_answer, scores, meta) @@ -1343,6 +1346,8 @@ async def _score_chunks( sub_questions=sub_questions, retries=max(1, plan.score_retries), parallelism=plan.parallelism, + select_best=plan.score_retries > 1, + fast_model=plan.fast_model, ) if ctx.parallelism <= 1 or len(groups) * ctx.retries <= 1: return await _score_groups_serial(call_llm, groups, ctx) @@ -1357,7 +1362,11 @@ async def _score_groups_serial( scores: dict[str, float] = {} for grp in groups: runs = [await _score_chunk_group(call_llm, grp, ctx.question, ctx.sub_questions) for _ in range(ctx.retries)] - scores.update(_merge_score_runs(runs)) + if ctx.select_best and len(runs) > 1: + best = await _select_best_score_run(call_llm, grp, runs, ctx) + scores.update(best) + else: + scores.update(_merge_score_runs(runs)) return scores @@ -1375,8 +1384,13 @@ async def _score_groups_parallel( for idx, result in results: grouped.setdefault(idx, []).append(result) scores: dict[str, float] = {} - for runs in grouped.values(): - scores.update(_merge_score_runs(runs)) + for idx, runs in grouped.items(): + if ctx.select_best and len(runs) > 1: + group = groups[idx] + best = await _select_best_score_run(call_llm, group, runs, ctx) + scores.update(best) + else: + scores.update(_merge_score_runs(runs)) return scores @@ -1434,6 +1448,38 @@ def _merge_score_runs(runs: list[dict[str, float]]) -> dict[str, float]: return {key: totals[key] / counts[key] for key in totals} +async def _select_best_score_run( + call_llm: Callable[..., Any], + group: list[dict[str, Any]], + runs: list[dict[str, float]], + ctx: ScoreContext, +) -> dict[str, float]: + if not runs: + return {} + prompt = ( + prompts.RETRIEVER_SELECT_PROMPT + + "\nQuestion: " + + ctx.question + + "\nSubQuestions: " + + json.dumps(ctx.sub_questions) + + "\nChunks: " + + json.dumps(group) + + "\nRuns: " + + json.dumps(runs) + ) + raw = await call_llm(prompts.RETRIEVER_SELECT_SYSTEM, prompt, model=ctx.fast_model, tag="chunk_select") + data = parse_json(raw) + idx = 0 + if isinstance(data, dict): + try: + idx = int(data.get("selected_index") or 0) + except (TypeError, ValueError): + idx = 0 + if idx < 0 or idx >= len(runs): + idx = 0 + return runs[idx] + + def _keyword_hits( ranked: list[dict[str, Any]], head: dict[str, Any], @@ -1574,7 +1620,16 @@ def _spine_lines(lines: list[str]) -> dict[str, str]: if nodes_line: spine["nodes_count"] = nodes_line spine["nodes_ready"] = nodes_line + else: + nodes_total = _line_starting_with(lines, "nodes_total:") + nodes_ready = _line_starting_with(lines, "nodes_ready:") + if nodes_total: + spine["nodes_count"] = nodes_total + if nodes_ready: + spine["nodes_ready"] = nodes_ready hardware_line = _line_starting_with(lines, "hardware_nodes:") + if not hardware_line: + hardware_line = _line_starting_with(lines, "hardware:") if hardware_line: spine["nodes_non_rpi"] = hardware_line hottest_line = _line_starting_with(lines, "hottest:") @@ -1613,7 +1668,14 @@ def _parse_group_line(line: str) -> dict[str, list[str]]: if not part or "=" not in part: continue key, value = part.split("=", 1) - nodes = [item.strip() for item in value.split(",") if item.strip()] + value = value.strip() + nodes: list[str] = [] + if "(" in value and ")" in value: + inner = value[value.find("(") + 1 : value.rfind(")")] + nodes = [item.strip() for item in inner.split(",") if item.strip()] + if not nodes: + cleaned = re.sub(r"^[0-9]+", "", value).strip() + nodes = [item.strip() for item in cleaned.split(",") if item.strip()] groups[key.strip()] = nodes return groups diff --git a/atlasbot/engine/intent_router.py b/atlasbot/engine/intent_router.py index c343b08..50175e2 100644 --- a/atlasbot/engine/intent_router.py +++ b/atlasbot/engine/intent_router.py @@ -10,22 +10,22 @@ class IntentMatch: score: int -_COUNT_TERMS = r"(how many|count|number of|total|totals|tally|amount of|quantity|sum of|overall|in total|all up)" -_NODE_TERMS = r"(nodes?|workers?|worker nodes?|cluster nodes?|machines?|hosts?|members?|instances?)" -_READY_TERMS = r"(ready|unready|not ready|down|offline|not responding|missing)" -_HOTTEST_TERMS = r"(hottest|hot|highest|max(?:imum)?|peak|top|most|worst|spikiest|heaviest|largest)" -_CPU_TERMS = r"(cpu|processor|compute|core|cores|load|load avg|load average)" -_RAM_TERMS = r"(ram|memory|mem|heap)" -_NET_TERMS = r"(net|network|bandwidth|throughput|traffic|rx|tx|ingress|egress|bits|bytes)" -_IO_TERMS = r"(io|i/o|disk io|disk activity|read/write|read write|storage io|iops)" -_DISK_TERMS = r"(disk|storage|volume|pvc|filesystem|fs|capacity|space)" -_PG_TERMS = r"(postgres|postgresql|pg\\b|database|db|sql)" -_CONN_TERMS = r"(connections?|conn|pool|sessions?|clients?)" -_DB_HOT_TERMS = r"(hottest|busiest|most|largest|top|heaviest)" -_NAMESPACE_TERMS = r"(namespace|namespaces|ns\\b)" -_PODS_TERMS = r"(pods?|workloads?|tasks?|containers?)" -_NON_RPI_TERMS = r"(non[-\\s]?raspberry|not\\s+raspberry|non[-\\s]?rpi|not\\s+rpi|amd64|x86|x86_64|jetson)" -_PRESSURE_TERMS = r"(pressure|overload|hotspot|bottleneck|saturation|headroom|strain|stress)" +_COUNT_TERMS = r"(how\\s+many|count|number\\s+of|total|totals|tally|amount\\s+of|quantity|sum\\s+of|overall|in\\s+total|all\\s+up)" +_NODE_TERMS = r"(nodes?|workers?|worker\\s+nodes?|cluster\\s+nodes?|machines?|hosts?|members?|instances?|servers?|agents?|control[-\\s]?plane|control\\s+plane)" +_READY_TERMS = r"(ready|unready|not\\s+ready|down|offline|not\\s+responding|missing|lost|gone|drain(?:ed|ing)?|cordon(?:ed|ing)?)" +_HOTTEST_TERMS = r"(hottest|hot|highest|max(?:imum)?|peak|top|most|worst|spikiest|heaviest|largest|biggest|noisiest|loudest)" +_CPU_TERMS = r"(cpu|processor|processors|compute|core|cores|load|load\\s+avg|load\\s+average|util(?:ization)?|usage)" +_RAM_TERMS = r"(ram|memory|mem|heap|rss|resident|swap)" +_NET_TERMS = r"(net|network|bandwidth|throughput|traffic|rx|tx|ingress|egress|bits|bytes|packets|pps|bps)" +_IO_TERMS = r"(io|i/o|disk\\s+io|disk\\s+activity|read/?write|storage\\s+io|iops|latency)" +_DISK_TERMS = r"(disk|storage|volume|pvc|filesystem|fs|capacity|space|full|usage)" +_PG_TERMS = r"(postgres|postgresql|pg\\b|database|db|sql|psql)" +_CONN_TERMS = r"(connections?|conn|pool|sessions?|clients?|active\\s+connections?|open\\s+connections?)" +_DB_HOT_TERMS = r"(hottest|busiest|most|largest|top|heaviest|noisiest|highest\\s+load)" +_NAMESPACE_TERMS = r"(namespace|namespaces|ns\\b|tenant|workload\\s+namespace)" +_PODS_TERMS = r"(pods?|workloads?|tasks?|containers?|deployments?|jobs?|cronjobs?|daemonsets?|statefulsets?)" +_NON_RPI_TERMS = r"(non[-\\s]?raspberry|not\\s+raspberry|non[-\\s]?rpi|not\\s+rpi|amd64|x86|x86_64|intel|ryzen|jetson|arm64\\b(?!.*rpi))" +_PRESSURE_TERMS = r"(pressure|overload|hotspot|bottleneck|saturation|headroom|strain|stress|critical|warning|at\\s+capacity|near\\s+limit)" def route_intent(question: str) -> IntentMatch | None: @@ -33,11 +33,11 @@ def route_intent(question: str) -> IntentMatch | None: if not text: return None - if re.search(_COUNT_TERMS, text) and re.search(_NODE_TERMS, text): + if re.search(_COUNT_TERMS, text) and (re.search(_NODE_TERMS, text) or "cluster" in text): return IntentMatch("nodes_count", 90) - if re.search(_READY_TERMS, text) and re.search(_NODE_TERMS, text): + if re.search(_READY_TERMS, text) and (re.search(_NODE_TERMS, text) or "cluster" in text or "workers" in text): return IntentMatch("nodes_ready", 85) - if re.search(_NON_RPI_TERMS, text) and re.search(_NODE_TERMS, text): + if re.search(_NON_RPI_TERMS, text) and (re.search(_NODE_TERMS, text) or "cluster" in text): return IntentMatch("nodes_non_rpi", 80) if re.search(_HOTTEST_TERMS, text) and re.search(_CPU_TERMS, text): diff --git a/atlasbot/llm/prompts.py b/atlasbot/llm/prompts.py index d827719..1b987a4 100644 --- a/atlasbot/llm/prompts.py +++ b/atlasbot/llm/prompts.py @@ -59,6 +59,18 @@ CHUNK_SCORE_PROMPT = ( "Return JSON list of objects with: id, score, reason (<=12 words)." ) +RETRIEVER_SELECT_SYSTEM = ( + CLUSTER_SYSTEM + + " Select the best relevance scoring run for the chunk group. " + + "Return JSON only." +) + +RETRIEVER_SELECT_PROMPT = ( + "You are given multiple scoring runs for the same chunk group. " + "Choose the run that best aligns with the question and sub-questions. " + "Return JSON with fields: selected_index (int), rationale (<=16 words)." +) + METRIC_PREFIX_SYSTEM = ( CLUSTER_SYSTEM + " Select relevant metric prefixes from the available list. "