# (removed code-viewer residue: "460 lines / 15 KiB / Python" header repeated twice)
from __future__ import annotations

import re
from typing import Any

from atlasbot.llm import prompts
from atlasbot.llm.client import parse_json

from ._base import *
from .retrieval_ext import _dedupe_lines


def _merge_fact_lines(primary: list[str], fallback: list[str]) -> list[str]:
|
|
merged: list[str] = []
|
|
for line in primary + fallback:
|
|
value = (line or "").strip()
|
|
if value and value not in merged:
|
|
merged.append(value)
|
|
return merged
|
|
|
|
|
|
def _strip_unknown_entities(reply: str, unknown_nodes: list[str], unknown_namespaces: list[str]) -> str:
|
|
if not reply:
|
|
return reply
|
|
if not unknown_nodes and not unknown_namespaces:
|
|
return reply
|
|
sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", reply) if s.strip()]
|
|
if not sentences:
|
|
return reply
|
|
lowered_nodes = [node.lower() for node in unknown_nodes]
|
|
lowered_namespaces = [ns.lower() for ns in unknown_namespaces]
|
|
kept: list[str] = []
|
|
for sent in sentences:
|
|
lower = sent.lower()
|
|
if lowered_nodes and any(node in lower for node in lowered_nodes):
|
|
continue
|
|
if lowered_namespaces and any(f"namespace {ns}" in lower for ns in lowered_namespaces):
|
|
continue
|
|
kept.append(sent)
|
|
cleaned = " ".join(kept).strip()
|
|
return cleaned or reply
|
|
|
|
|
|
def _needs_evidence_guard(reply: str, facts: list[str]) -> bool:
|
|
if not reply or not facts:
|
|
return False
|
|
lower_reply = reply.lower()
|
|
fact_text = " ".join(facts).lower()
|
|
node_pattern = re.compile(r"\b(titan-[0-9a-z]+|node-?\d+)\b", re.IGNORECASE)
|
|
nodes = {m.group(1).lower() for m in node_pattern.finditer(reply)}
|
|
if nodes:
|
|
missing = [node for node in nodes if node not in fact_text]
|
|
if missing:
|
|
return True
|
|
pressure_terms = ("pressure", "diskpressure", "memorypressure", "pidpressure", "headroom")
|
|
if any(term in lower_reply for term in pressure_terms) and not any(term in fact_text for term in pressure_terms):
|
|
return True
|
|
arch_terms = ("amd64", "arm64", "rpi", "rpi4", "rpi5", "jetson")
|
|
return any(term in lower_reply for term in arch_terms) and not any(term in fact_text for term in arch_terms)
|
|
|
|
|
|
async def _contradiction_decision(ctx: ContradictionContext, attempts: int = 1) -> dict[str, Any]:
    """Ask the fast model whether the drafted reply contradicts the facts.

    Runs up to *attempts* prompt variants and keeps the verdict with the
    highest self-reported confidence (later variants win ties).

    Args:
        ctx: Question, draft reply, facts, model plan, and LLM call function.
        attempts: Number of prompt variants to try; values < 1 behave as 1.

    Returns:
        Dict with ``use_facts`` (bool) and ``confidence`` (int); defaults to
        ``{"use_facts": True, "confidence": 50}`` when no call improves on it.
    """
    best = {"use_facts": True, "confidence": 50}
    # Cap the evidence block so the prompt stays small.
    facts_block = "\n".join(ctx.facts[:12])
    for idx in range(max(1, attempts)):
        # Label each retry so the model does not see identical prompts.
        variant = f"Variant: {idx + 1}" if attempts > 1 else ""
        prompt = (
            prompts.CONTRADICTION_PROMPT.format(question=ctx.question, draft=ctx.reply, facts=facts_block)
            + ("\n" + variant if variant else "")
        )
        raw = await ctx.call_llm(
            prompts.CONTRADICTION_SYSTEM,
            prompt,
            model=ctx.plan.fast_model,
            tag="contradiction",
        )
        data = _parse_json_block(raw, fallback={})
        try:
            confidence = int(data.get("confidence", 50))
        except Exception:  # model output may be any shape; fall back to neutral
            confidence = 50
        use_facts = bool(data.get("use_facts", True))
        # ">=" so an equal-confidence later variant replaces an earlier one.
        if confidence >= best.get("confidence", 0):
            best = {"use_facts": use_facts, "confidence": confidence}
    return best


def _filter_lines_by_keywords(lines: list[str], keywords: list[str], max_lines: int) -> list[str]:
|
|
if not lines:
|
|
return []
|
|
tokens = _expand_tokens(keywords)
|
|
if not tokens:
|
|
return lines[:max_lines]
|
|
filtered = [line for line in lines if any(tok in line.lower() for tok in tokens)]
|
|
return (filtered or lines)[:max_lines]
|
|
|
|
|
|
def _rank_metric_lines(lines: list[str], tokens: set[str], max_lines: int) -> list[str]:
|
|
if not lines or not tokens:
|
|
return []
|
|
ranked: list[tuple[int, int, str]] = []
|
|
for line in lines:
|
|
lower = line.lower()
|
|
hits = sum(1 for tok in tokens if tok in lower)
|
|
if not hits:
|
|
continue
|
|
has_number = 1 if re.search(r"\d", line) else 0
|
|
ranked.append((has_number, hits, line))
|
|
ranked.sort(key=lambda item: (-item[0], -item[1], item[2]))
|
|
return [item[2] for item in ranked[:max_lines]]
|
|
|
|
|
|
def _select_metric_line(lines: list[str], question: str, tokens: list[str] | set[str]) -> str | None:
|
|
if not lines or not tokens:
|
|
return None
|
|
token_set = {str(tok).lower() for tok in tokens if tok}
|
|
ranked = _rank_metric_lines(lines, token_set, max_lines=6)
|
|
if not ranked:
|
|
return None
|
|
question_lower = (question or "").lower()
|
|
if any(term in question_lower for term in ("how many", "count", "total")):
|
|
for line in ranked:
|
|
lower = line.lower()
|
|
if "total" in lower or "count" in lower:
|
|
return line
|
|
return ranked[0]
|
|
|
|
|
|
def _format_direct_metric_line(line: str) -> str:
|
|
if not line:
|
|
return ""
|
|
if ":" in line:
|
|
formatted = _format_colon_metric(line)
|
|
if formatted:
|
|
return formatted
|
|
if "=" in line:
|
|
formatted = _format_equals_metric(line)
|
|
if formatted:
|
|
return formatted
|
|
return line
|
|
|
|
|
|
def _format_colon_metric(line: str) -> str | None:
|
|
key, value = line.split(":", 1)
|
|
key = key.strip().replace("_", " ")
|
|
value = value.strip()
|
|
if not value:
|
|
return None
|
|
if key == "nodes":
|
|
formatted = _format_nodes_value(value)
|
|
if formatted:
|
|
return formatted
|
|
if key in {"nodes total", "nodes_total"}:
|
|
return f"Atlas has {value} total nodes."
|
|
return f"{key} is {value}."
|
|
|
|
|
|
def _format_equals_metric(line: str) -> str | None:
|
|
pairs: list[str] = []
|
|
for part in line.split(","):
|
|
if "=" not in part:
|
|
continue
|
|
key, value = part.split("=", 1)
|
|
key = key.strip().replace("_", " ")
|
|
value = value.strip()
|
|
if not value:
|
|
continue
|
|
if key in {"nodes total", "nodes_total"}:
|
|
return f"Atlas has {value} total nodes."
|
|
pairs.append(f"{key} is {value}")
|
|
if not pairs:
|
|
return None
|
|
if len(pairs) == 1:
|
|
return f"{pairs[0]}."
|
|
return "; ".join(pairs) + "."
|
|
|
|
|
|
def _format_nodes_value(value: str) -> str | None:
|
|
parts = [p.strip() for p in value.split(",") if p.strip()]
|
|
total = None
|
|
rest: list[str] = []
|
|
for part in parts:
|
|
if part.startswith("total="):
|
|
total = part.split("=", 1)[1]
|
|
else:
|
|
rest.append(part.replace("_", " "))
|
|
if not total:
|
|
return None
|
|
if rest:
|
|
return f"Atlas has {total} total nodes ({'; '.join(rest)})."
|
|
return f"Atlas has {total} total nodes."
|
|
|
|
|
|
def _global_facts(lines: list[str]) -> list[str]:
|
|
if not lines:
|
|
return []
|
|
wanted = ("nodes_total", "nodes_ready", "cluster_name", "cluster", "nodes_not_ready")
|
|
facts: list[str] = []
|
|
for line in lines:
|
|
lower = line.lower()
|
|
if any(key in lower for key in wanted):
|
|
facts.append(line)
|
|
return _dedupe_lines(facts, limit=6)
|
|
|
|
|
|
def _has_keyword_overlap(lines: list[str], keywords: list[str]) -> bool:
|
|
if not lines or not keywords:
|
|
return False
|
|
tokens = _expand_tokens(keywords)
|
|
if not tokens:
|
|
return False
|
|
for line in lines:
|
|
lower = line.lower()
|
|
if any(tok in lower for tok in tokens):
|
|
return True
|
|
return False
|
|
|
|
|
|
def _merge_tokens(primary: list[str], secondary: list[str], third: list[str] | None = None) -> list[str]:
|
|
merged: list[str] = []
|
|
for token in primary + secondary + (third or []):
|
|
if not token:
|
|
continue
|
|
if token not in merged:
|
|
merged.append(token)
|
|
return merged
|
|
|
|
|
|
def _extract_question_tokens(question: str) -> list[str]:
|
|
if not question:
|
|
return []
|
|
tokens: list[str] = []
|
|
for part in re.split(r"[^a-zA-Z0-9_-]+", question.lower()):
|
|
if len(part) < TOKEN_MIN_LEN:
|
|
continue
|
|
if part not in tokens:
|
|
tokens.append(part)
|
|
return tokens
|
|
|
|
|
|
def _expand_tokens(tokens: list[str]) -> list[str]:
|
|
if not tokens:
|
|
return []
|
|
expanded: list[str] = []
|
|
for token in tokens:
|
|
if not isinstance(token, str):
|
|
continue
|
|
for part in re.split(r"[^a-zA-Z0-9_-]+", token.lower()):
|
|
if len(part) < TOKEN_MIN_LEN:
|
|
continue
|
|
if part not in expanded:
|
|
expanded.append(part)
|
|
return expanded
|
|
|
|
|
|
def _ensure_token_coverage(lines: list[str], tokens: list[str], summary_lines: list[str], max_add: int = 4) -> list[str]:
|
|
if not lines or not tokens or not summary_lines:
|
|
return lines
|
|
hay = " ".join(lines).lower()
|
|
missing = [tok for tok in tokens if tok and tok.lower() not in hay]
|
|
if not missing:
|
|
return lines
|
|
added: list[str] = []
|
|
for token in missing:
|
|
token_lower = token.lower()
|
|
for line in summary_lines:
|
|
if token_lower in line.lower() and line not in lines and line not in added:
|
|
added.append(line)
|
|
break
|
|
if len(added) >= max_add:
|
|
break
|
|
if not added:
|
|
return lines
|
|
return _merge_fact_lines(added, lines)
|
|
|
|
|
|
def _best_keyword_line(lines: list[str], keywords: list[str]) -> str | None:
|
|
if not lines or not keywords:
|
|
return None
|
|
tokens = _expand_tokens(keywords)
|
|
if not tokens:
|
|
return None
|
|
best = None
|
|
best_score = 0
|
|
for line in lines:
|
|
lower = line.lower()
|
|
score = sum(1 for tok in tokens if tok in lower)
|
|
if score > best_score:
|
|
best_score = score
|
|
best = line
|
|
return best if best_score > 0 else None
|
|
|
|
|
|
def _line_starting_with(lines: list[str], prefix: str) -> str | None:
|
|
if not lines or not prefix:
|
|
return None
|
|
lower_prefix = prefix.lower()
|
|
for line in lines:
|
|
if str(line).lower().startswith(lower_prefix):
|
|
return line
|
|
return None
|
|
|
|
|
|
def _non_rpi_nodes(summary: dict[str, Any]) -> dict[str, list[str]]:
|
|
hardware = summary.get("hardware_by_node") if isinstance(summary, dict) else None
|
|
if not isinstance(hardware, dict):
|
|
return {}
|
|
grouped: dict[str, list[str]] = {}
|
|
for node, hw in hardware.items():
|
|
if not isinstance(node, str) or not isinstance(hw, str):
|
|
continue
|
|
if hw.startswith("rpi"):
|
|
continue
|
|
grouped.setdefault(hw, []).append(node)
|
|
for nodes in grouped.values():
|
|
nodes.sort()
|
|
return grouped
|
|
|
|
|
|
def _format_hardware_groups(groups: dict[str, list[str]], label: str) -> str:
|
|
if not groups:
|
|
return ""
|
|
parts = []
|
|
for hw, nodes in sorted(groups.items()):
|
|
parts.append(f"{hw} ({', '.join(nodes)})")
|
|
return f"{label}: " + "; ".join(parts) + "."
|
|
|
|
|
|
def _lexicon_context(summary: dict[str, Any]) -> str: # noqa: C901
|
|
if not isinstance(summary, dict):
|
|
return ""
|
|
lexicon = summary.get("lexicon")
|
|
if not isinstance(lexicon, dict):
|
|
return ""
|
|
terms = lexicon.get("terms")
|
|
aliases = lexicon.get("aliases")
|
|
lines: list[str] = []
|
|
if isinstance(terms, list):
|
|
for entry in terms[:8]:
|
|
if not isinstance(entry, dict):
|
|
continue
|
|
term = entry.get("term")
|
|
meaning = entry.get("meaning")
|
|
if term and meaning:
|
|
lines.append(f"{term}: {meaning}")
|
|
if isinstance(aliases, dict):
|
|
for key, value in list(aliases.items())[:6]:
|
|
if key and value:
|
|
lines.append(f"alias {key} -> {value}")
|
|
if not lines:
|
|
return ""
|
|
return "Lexicon:\n" + "\n".join(lines)
|
|
|
|
|
|
def _parse_json_block(text: str, *, fallback: dict[str, Any]) -> dict[str, Any]:
    """Parse the first ``{...}`` span in *text* as JSON, falling back to the
    whole stripped text, then to *fallback* on parse failure."""
    stripped = text.strip()
    found = re.search(r"\{.*\}", stripped, flags=re.S)
    candidate = found.group(0) if found else stripped
    return parse_json(candidate, fallback=fallback)


def _parse_json_list(text: str) -> list[dict[str, Any]]:
    """Parse the first ``[...]`` span in *text* (or the whole stripped text)
    as JSON and return only its dict entries; anything else yields []."""
    stripped = text.strip()
    found = re.search(r"\[.*\]", stripped, flags=re.S)
    candidate = found.group(0) if found else stripped
    parsed = parse_json(candidate, fallback={})
    if not isinstance(parsed, list):
        return []
    return [item for item in parsed if isinstance(item, dict)]


def _scores_from_json(data: dict[str, Any]) -> AnswerScores:
    """Map a parsed scoring payload onto AnswerScores, defaulting each
    numeric field to 60 and hallucination risk to "medium"."""
    risk = str(data.get("hallucination_risk") or "medium")
    return AnswerScores(
        confidence=_coerce_int(data.get("confidence"), 60),
        relevance=_coerce_int(data.get("relevance"), 60),
        satisfaction=_coerce_int(data.get("satisfaction"), 60),
        hallucination_risk=risk,
    )


def _coerce_int(value: Any, default: int) -> int:
|
|
try:
|
|
return int(float(value))
|
|
except (TypeError, ValueError):
|
|
return default
|
|
|
|
|
|
def _default_scores() -> AnswerScores:
    """Neutral fallback scores used when the grader returns nothing usable."""
    return AnswerScores(
        confidence=60,
        relevance=60,
        satisfaction=60,
        hallucination_risk="medium",
    )


def _style_hint(classify: dict[str, Any]) -> str:
|
|
style = (classify.get("answer_style") or "").strip().lower()
|
|
qtype = (classify.get("question_type") or "").strip().lower()
|
|
if style == "insightful" or qtype in {"open_ended", "planning"}:
|
|
return "insightful"
|
|
return "direct"
|
|
|
|
|
|
def _needs_evidence_fix(reply: str, classify: dict[str, Any]) -> bool:
|
|
if not reply:
|
|
return False
|
|
lowered = reply.lower()
|
|
missing_markers = (
|
|
"don't have",
|
|
"do not have",
|
|
"don't know",
|
|
"cannot",
|
|
"can't",
|
|
"need to",
|
|
"would need",
|
|
"does not provide",
|
|
"does not mention",
|
|
"not mention",
|
|
"not provided",
|
|
"not in context",
|
|
"not referenced",
|
|
"missing",
|
|
"no specific",
|
|
"no information",
|
|
)
|
|
if classify.get("needs_snapshot") and any(marker in lowered for marker in missing_markers):
|
|
return True
|
|
return classify.get("question_type") in {"metric", "diagnostic"} and not re.search(r"\d", reply)
|
|
|
|
|
|
def _should_use_insight_guard(classify: dict[str, Any]) -> bool:
|
|
style = (classify.get("answer_style") or "").strip().lower()
|
|
qtype = (classify.get("question_type") or "").strip().lower()
|
|
return style == "insightful" or qtype in {"open_ended", "planning"}
|
|
|
|
|
|
async def _apply_insight_guard(inputs: InsightGuardInput) -> str:
    """Run the insight-guard LLM check on a drafted reply and fix it if needed.

    When the guard judges the reply acceptable (``ok: true``) the draft is
    returned unchanged; otherwise a second LLM call rewrites the answer,
    optionally grounded in up to six fact lines.

    Args:
        inputs: Bundle of question, draft reply, classification, facts,
            context, model plan, and the LLM call function.

    Returns:
        The original reply, or the guard-driven rewrite.
    """
    # The guard only applies to insightful/open-ended answers; skip otherwise.
    if not inputs.reply or not _should_use_insight_guard(inputs.classify):
        return inputs.reply
    guard_prompt = prompts.INSIGHT_GUARD_PROMPT.format(question=inputs.question, answer=inputs.reply)
    # Fast model: this is a cheap pass/fail quality check.
    guard_raw = await inputs.call_llm(
        prompts.INSIGHT_GUARD_SYSTEM,
        guard_prompt,
        context=inputs.context,
        model=inputs.plan.fast_model,
        tag="insight_guard",
    )
    guard = _parse_json_block(guard_raw, fallback={})
    # Strict identity check: any value other than True triggers the fix path.
    if guard.get("ok") is True:
        return inputs.reply
    fix_prompt = prompts.INSIGHT_FIX_PROMPT.format(question=inputs.question, answer=inputs.reply)
    if inputs.facts:
        # Ground the rewrite in a handful of retrieved facts.
        fix_prompt = fix_prompt + "\nFacts:\n" + "\n".join(inputs.facts[:6])
    # Full-quality model performs the actual rewrite.
    return await inputs.call_llm(
        prompts.INSIGHT_FIX_SYSTEM,
        fix_prompt,
        context=inputs.context,
        model=inputs.plan.model,
        tag="insight_fix",
    )


# Export every single-underscore name (excluding dunders) so sibling modules
# can star-import these helpers. NOTE(review): this also re-exports names
# imported above (e.g. _dedupe_lines and anything from ._base) — confirm
# that is intended rather than limiting to names defined in this module.
__all__ = [name for name in globals() if name.startswith("_") and not name.startswith("__")]