atlasbot: broaden intent routing and retriever selection

This commit is contained in:
Brad Stein 2026-02-05 12:40:40 -03:00
parent c356abdec0
commit d85ac70c83
3 changed files with 98 additions and 24 deletions

View File

@ -110,6 +110,8 @@ class ScoreContext:
sub_questions: list[str]
retries: int
parallelism: int
select_best: bool
fast_model: str
class AnswerEngine:
@ -155,6 +157,7 @@ class AnswerEngine:
"route",
"decompose",
"chunk_score",
"chunk_select",
"fact_select",
"synth",
"subanswer",
@ -265,7 +268,7 @@ class AnswerEngine:
if spine_line:
key_facts = _merge_fact_lines([spine_line], key_facts)
metric_facts = _merge_fact_lines([spine_line], metric_facts)
if spine_answer and mode == "fast":
if spine_answer and mode in {"fast", "quick"}:
scores = _default_scores()
meta = _build_meta(mode, call_count, call_cap, limit_hit, classify, tool_hint, started)
return AnswerResult(spine_answer, scores, meta)
@ -1343,6 +1346,8 @@ async def _score_chunks(
sub_questions=sub_questions,
retries=max(1, plan.score_retries),
parallelism=plan.parallelism,
select_best=plan.score_retries > 1,
fast_model=plan.fast_model,
)
if ctx.parallelism <= 1 or len(groups) * ctx.retries <= 1:
return await _score_groups_serial(call_llm, groups, ctx)
@ -1357,7 +1362,11 @@ async def _score_groups_serial(
scores: dict[str, float] = {}
for grp in groups:
runs = [await _score_chunk_group(call_llm, grp, ctx.question, ctx.sub_questions) for _ in range(ctx.retries)]
scores.update(_merge_score_runs(runs))
if ctx.select_best and len(runs) > 1:
best = await _select_best_score_run(call_llm, grp, runs, ctx)
scores.update(best)
else:
scores.update(_merge_score_runs(runs))
return scores
@ -1375,8 +1384,13 @@ async def _score_groups_parallel(
for idx, result in results:
grouped.setdefault(idx, []).append(result)
scores: dict[str, float] = {}
for runs in grouped.values():
scores.update(_merge_score_runs(runs))
for idx, runs in grouped.items():
if ctx.select_best and len(runs) > 1:
group = groups[idx]
best = await _select_best_score_run(call_llm, group, runs, ctx)
scores.update(best)
else:
scores.update(_merge_score_runs(runs))
return scores
@ -1434,6 +1448,38 @@ def _merge_score_runs(runs: list[dict[str, float]]) -> dict[str, float]:
return {key: totals[key] / counts[key] for key in totals}
async def _select_best_score_run(
    call_llm: Callable[..., Any],
    group: list[dict[str, Any]],
    runs: list[dict[str, float]],
    ctx: ScoreContext,
) -> dict[str, float]:
    """Ask the fast model to pick the best of several chunk-scoring runs.

    Sends the question, sub-questions, chunk group, and all candidate runs
    to the LLM and expects JSON with a ``selected_index`` field. Any
    missing, malformed, or out-of-range selection falls back to the first
    run, so a bad LLM reply can never raise past this helper.

    Returns the chosen run's ``{chunk_id: score}`` mapping (``{}`` when
    ``runs`` is empty).
    """
    if not runs:
        return {}
    if len(runs) == 1:
        # Only one candidate -- skip the extra LLM round-trip entirely.
        return runs[0]
    prompt = (
        prompts.RETRIEVER_SELECT_PROMPT
        + "\nQuestion: "
        + ctx.question
        + "\nSubQuestions: "
        + json.dumps(ctx.sub_questions)
        + "\nChunks: "
        + json.dumps(group)
        + "\nRuns: "
        + json.dumps(runs)
    )
    raw = await call_llm(
        prompts.RETRIEVER_SELECT_SYSTEM,
        prompt,
        model=ctx.fast_model,
        tag="chunk_select",
    )
    data = parse_json(raw)
    idx = 0
    if isinstance(data, dict):
        try:
            # `or 0` also normalizes None/"" before the int() conversion.
            idx = int(data.get("selected_index") or 0)
        except (TypeError, ValueError):
            idx = 0
    if not 0 <= idx < len(runs):
        # Out-of-range selection from the model: default to the first run.
        idx = 0
    return runs[idx]
def _keyword_hits(
ranked: list[dict[str, Any]],
head: dict[str, Any],
@ -1574,7 +1620,16 @@ def _spine_lines(lines: list[str]) -> dict[str, str]:
if nodes_line:
spine["nodes_count"] = nodes_line
spine["nodes_ready"] = nodes_line
else:
nodes_total = _line_starting_with(lines, "nodes_total:")
nodes_ready = _line_starting_with(lines, "nodes_ready:")
if nodes_total:
spine["nodes_count"] = nodes_total
if nodes_ready:
spine["nodes_ready"] = nodes_ready
hardware_line = _line_starting_with(lines, "hardware_nodes:")
if not hardware_line:
hardware_line = _line_starting_with(lines, "hardware:")
if hardware_line:
spine["nodes_non_rpi"] = hardware_line
hottest_line = _line_starting_with(lines, "hottest:")
@ -1613,7 +1668,14 @@ def _parse_group_line(line: str) -> dict[str, list[str]]:
if not part or "=" not in part:
continue
key, value = part.split("=", 1)
nodes = [item.strip() for item in value.split(",") if item.strip()]
value = value.strip()
nodes: list[str] = []
if "(" in value and ")" in value:
inner = value[value.find("(") + 1 : value.rfind(")")]
nodes = [item.strip() for item in inner.split(",") if item.strip()]
if not nodes:
cleaned = re.sub(r"^[0-9]+", "", value).strip()
nodes = [item.strip() for item in cleaned.split(",") if item.strip()]
groups[key.strip()] = nodes
return groups

View File

@ -10,22 +10,22 @@ class IntentMatch:
score: int
_COUNT_TERMS = r"(how many|count|number of|total|totals|tally|amount of|quantity|sum of|overall|in total|all up)"
_NODE_TERMS = r"(nodes?|workers?|worker nodes?|cluster nodes?|machines?|hosts?|members?|instances?)"
_READY_TERMS = r"(ready|unready|not ready|down|offline|not responding|missing)"
_HOTTEST_TERMS = r"(hottest|hot|highest|max(?:imum)?|peak|top|most|worst|spikiest|heaviest|largest)"
_CPU_TERMS = r"(cpu|processor|compute|core|cores|load|load avg|load average)"
_RAM_TERMS = r"(ram|memory|mem|heap)"
_NET_TERMS = r"(net|network|bandwidth|throughput|traffic|rx|tx|ingress|egress|bits|bytes)"
_IO_TERMS = r"(io|i/o|disk io|disk activity|read/write|read write|storage io|iops)"
_DISK_TERMS = r"(disk|storage|volume|pvc|filesystem|fs|capacity|space)"
_PG_TERMS = r"(postgres|postgresql|pg\\b|database|db|sql)"
_CONN_TERMS = r"(connections?|conn|pool|sessions?|clients?)"
_DB_HOT_TERMS = r"(hottest|busiest|most|largest|top|heaviest)"
_NAMESPACE_TERMS = r"(namespace|namespaces|ns\\b)"
_PODS_TERMS = r"(pods?|workloads?|tasks?|containers?)"
_NON_RPI_TERMS = r"(non[-\\s]?raspberry|not\\s+raspberry|non[-\\s]?rpi|not\\s+rpi|amd64|x86|x86_64|jetson)"
_PRESSURE_TERMS = r"(pressure|overload|hotspot|bottleneck|saturation|headroom|strain|stress)"
# --- Intent-routing term alternations ---------------------------------------
# Each constant is a regex alternation matched with re.search against the
# (lowercased) question text in route_intent below.
# NOTE(review): backslashes appear doubled here (e.g. "\\s+") exactly as in
# the diff rendering; the underlying source most likely uses single
# backslashes (r"\s+") -- confirm against the actual file before relying on
# these patterns verbatim.
# Counting phrasings ("how many", "number of", "in total", ...).
_COUNT_TERMS = r"(how\\s+many|count|number\\s+of|total|totals|tally|amount\\s+of|quantity|sum\\s+of|overall|in\\s+total|all\\s+up)"
# Cluster-node synonyms, now including servers/agents/control plane.
_NODE_TERMS = r"(nodes?|workers?|worker\\s+nodes?|cluster\\s+nodes?|machines?|hosts?|members?|instances?|servers?|agents?|control[-\\s]?plane|control\\s+plane)"
# Node readiness/health states, including drain/cordon forms.
_READY_TERMS = r"(ready|unready|not\\s+ready|down|offline|not\\s+responding|missing|lost|gone|drain(?:ed|ing)?|cordon(?:ed|ing)?)"
# Superlative "which is the hottest/busiest" qualifiers.
_HOTTEST_TERMS = r"(hottest|hot|highest|max(?:imum)?|peak|top|most|worst|spikiest|heaviest|largest|biggest|noisiest|loudest)"
# CPU / load-average vocabulary.
_CPU_TERMS = r"(cpu|processor|processors|compute|core|cores|load|load\\s+avg|load\\s+average|util(?:ization)?|usage)"
# Memory vocabulary (rss/resident/swap added in this commit).
_RAM_TERMS = r"(ram|memory|mem|heap|rss|resident|swap)"
# Network traffic vocabulary.
_NET_TERMS = r"(net|network|bandwidth|throughput|traffic|rx|tx|ingress|egress|bits|bytes|packets|pps|bps)"
# Disk/storage I/O activity vocabulary.
_IO_TERMS = r"(io|i/o|disk\\s+io|disk\\s+activity|read/?write|storage\\s+io|iops|latency)"
# Disk capacity vocabulary.
_DISK_TERMS = r"(disk|storage|volume|pvc|filesystem|fs|capacity|space|full|usage)"
# PostgreSQL / generic database vocabulary.
_PG_TERMS = r"(postgres|postgresql|pg\\b|database|db|sql|psql)"
# DB connection/session vocabulary.
_CONN_TERMS = r"(connections?|conn|pool|sessions?|clients?|active\\s+connections?|open\\s+connections?)"
# Superlatives specific to database-hotness questions.
_DB_HOT_TERMS = r"(hottest|busiest|most|largest|top|heaviest|noisiest|highest\\s+load)"
# Kubernetes namespace vocabulary.
_NAMESPACE_TERMS = r"(namespace|namespaces|ns\\b|tenant|workload\\s+namespace)"
# Pod/workload-kind vocabulary.
_PODS_TERMS = r"(pods?|workloads?|tasks?|containers?|deployments?|jobs?|cronjobs?|daemonsets?|statefulsets?)"
# Non-Raspberry-Pi hardware hints (x86/intel/ryzen/jetson, arm64 minus rpi).
_NON_RPI_TERMS = r"(non[-\\s]?raspberry|not\\s+raspberry|non[-\\s]?rpi|not\\s+rpi|amd64|x86|x86_64|intel|ryzen|jetson|arm64\\b(?!.*rpi))"
# Resource-pressure / saturation vocabulary.
_PRESSURE_TERMS = r"(pressure|overload|hotspot|bottleneck|saturation|headroom|strain|stress|critical|warning|at\\s+capacity|near\\s+limit)"
def route_intent(question: str) -> IntentMatch | None:
@ -33,11 +33,11 @@ def route_intent(question: str) -> IntentMatch | None:
if not text:
return None
if re.search(_COUNT_TERMS, text) and re.search(_NODE_TERMS, text):
if re.search(_COUNT_TERMS, text) and (re.search(_NODE_TERMS, text) or "cluster" in text):
return IntentMatch("nodes_count", 90)
if re.search(_READY_TERMS, text) and re.search(_NODE_TERMS, text):
if re.search(_READY_TERMS, text) and (re.search(_NODE_TERMS, text) or "cluster" in text or "workers" in text):
return IntentMatch("nodes_ready", 85)
if re.search(_NON_RPI_TERMS, text) and re.search(_NODE_TERMS, text):
if re.search(_NON_RPI_TERMS, text) and (re.search(_NODE_TERMS, text) or "cluster" in text):
return IntentMatch("nodes_non_rpi", 80)
if re.search(_HOTTEST_TERMS, text) and re.search(_CPU_TERMS, text):

View File

@ -59,6 +59,18 @@ CHUNK_SCORE_PROMPT = (
"Return JSON list of objects with: id, score, reason (<=12 words)."
)
# System prompt for the run-selection call: extends the shared cluster
# persona and constrains the reply to bare JSON.
RETRIEVER_SELECT_SYSTEM = CLUSTER_SYSTEM + (
    " Select the best relevance scoring run for the chunk group. "
    "Return JSON only."
)
# User prompt asking the model to choose one scoring run; expected reply
# schema: {"selected_index": int, "rationale": str (<=16 words)}.
RETRIEVER_SELECT_PROMPT = (
    "You are given multiple scoring runs for the same chunk group."
    " Choose the run that best aligns with the question and sub-questions."
    " Return JSON with fields: selected_index (int), rationale (<=16 words)."
)
METRIC_PREFIX_SYSTEM = (
CLUSTER_SYSTEM
+ " Select relevant metric prefixes from the available list. "