diff --git a/services/comms/scripts/atlasbot/bot.py b/services/comms/scripts/atlasbot/bot.py index 96765b1..43f578b 100644 --- a/services/comms/scripts/atlasbot/bot.py +++ b/services/comms/scripts/atlasbot/bot.py @@ -3169,6 +3169,23 @@ def _preferred_tags_for_prompt(prompt: str) -> set[str]: return tags & _ALLOWED_INSIGHT_TAGS +def _primary_tags_for_prompt(prompt: str) -> set[str]: + q = normalize_query(prompt) + if any(word in q for word in ("cpu", "ram", "memory", "net", "network", "io", "disk", "hottest", "busy", "usage", "utilization", "load")): + return {"utilization"} + if any(word in q for word in ("postgres", "database", "db", "connections")): + return {"database"} + if any(word in q for word in ("pod", "pods", "deployment", "job", "cronjob")): + return {"pods"} + if any(word in q for word in ("workload", "service", "namespace")): + return {"workloads"} + if any(word in q for word in ("ready", "not ready", "down", "unreachable", "availability")): + return {"availability"} + if any(word in q for word in ("node", "nodes", "hardware", "arch", "architecture", "rpi", "jetson", "amd64", "arm64", "worker", "control-plane")): + return {"hardware", "inventory", "architecture"} + return set() + + _TAG_KEYWORDS: dict[str, tuple[str, ...]] = { "utilization": ("cpu", "ram", "memory", "net", "network", "io", "disk", "usage", "utilization", "hottest", "busy"), "database": ("postgres", "db", "database", "connections"), @@ -3745,25 +3762,43 @@ def _fast_fact_lines( *, focus_tags: set[str], avoid_tags: set[str], + primary_tags: set[str] | None = None, limit: int = 10, ) -> list[str]: if not fact_lines: return [] - selected: list[str] = [] + primary_tags = primary_tags or set() + scored: list[tuple[int, int, str]] = [] for idx, line in enumerate(fact_lines): fid = f"F{idx + 1}" tags = set(fact_meta.get(fid, {}).get("tags") or []) - if focus_tags and not (focus_tags & tags): - continue if avoid_tags and (avoid_tags & tags): continue - selected.append(line) + score = 0 + if primary_tags: + score += 4 * len(tags & primary_tags) + if focus_tags: + score += 2 * len(tags & focus_tags) + scored.append((score, idx, line)) + scored.sort(key=lambda item: (-item[0], item[1])) + selected: list[str] = [] + for score, _, line in scored: + if score <= 0 and selected: + break + if score > 0: + selected.append(line) if len(selected) >= limit: break - if selected: - return selected - trimmed = fact_lines[:limit] - return trimmed or fact_lines + if not selected: + selected = [line for _, _, line in scored[:limit]] + elif len(selected) < limit: + for _, _, line in scored: + if line in selected: + continue + selected.append(line) + if len(selected) >= limit: + break + return selected def _open_ended_fast_single( @@ -3799,6 +3834,7 @@ def _open_ended_fast( ) -> str: model = _model_for_mode("fast") subjective = _is_subjective_query(prompt) + primary_tags = _primary_tags_for_prompt(prompt) focus_tags = _preferred_tags_for_prompt(prompt) if not focus_tags and subjective: focus_tags = set(_ALLOWED_INSIGHT_TAGS) @@ -3808,15 +3844,15 @@ def _open_ended_fast( fact_meta, focus_tags=focus_tags, avoid_tags=avoid_tags, + primary_tags=primary_tags, ) selected_meta = _fact_pack_meta(selected_lines) selected_pack = _fact_pack_text(selected_lines, selected_meta) - context = _append_history_context(selected_pack, history_lines) if state: state.total_steps = _open_ended_total_steps("fast") return _open_ended_fast_single( prompt, - context=context, + context=selected_pack, state=state, model=model, )