# (file-listing residue from the export tool — commented out so the module parses)
# 460 lines · 15 KiB · Python

from __future__ import annotations
import re
from typing import Any
from atlasbot.llm import prompts
from atlasbot.llm.client import parse_json
from ._base import *
from .retrieval_ext import _dedupe_lines
def _merge_fact_lines(primary: list[str], fallback: list[str]) -> list[str]:
merged: list[str] = []
for line in primary + fallback:
value = (line or "").strip()
if value and value not in merged:
merged.append(value)
return merged
def _strip_unknown_entities(reply: str, unknown_nodes: list[str], unknown_namespaces: list[str]) -> str:
if not reply:
return reply
if not unknown_nodes and not unknown_namespaces:
return reply
sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", reply) if s.strip()]
if not sentences:
return reply
lowered_nodes = [node.lower() for node in unknown_nodes]
lowered_namespaces = [ns.lower() for ns in unknown_namespaces]
kept: list[str] = []
for sent in sentences:
lower = sent.lower()
if lowered_nodes and any(node in lower for node in lowered_nodes):
continue
if lowered_namespaces and any(f"namespace {ns}" in lower for ns in lowered_namespaces):
continue
kept.append(sent)
cleaned = " ".join(kept).strip()
return cleaned or reply
def _needs_evidence_guard(reply: str, facts: list[str]) -> bool:
if not reply or not facts:
return False
lower_reply = reply.lower()
fact_text = " ".join(facts).lower()
node_pattern = re.compile(r"\b(titan-[0-9a-z]+|node-?\d+)\b", re.IGNORECASE)
nodes = {m.group(1).lower() for m in node_pattern.finditer(reply)}
if nodes:
missing = [node for node in nodes if node not in fact_text]
if missing:
return True
pressure_terms = ("pressure", "diskpressure", "memorypressure", "pidpressure", "headroom")
if any(term in lower_reply for term in pressure_terms) and not any(term in fact_text for term in pressure_terms):
return True
arch_terms = ("amd64", "arm64", "rpi", "rpi4", "rpi5", "jetson")
return any(term in lower_reply for term in arch_terms) and not any(term in fact_text for term in arch_terms)
async def _contradiction_decision(ctx: ContradictionContext, attempts: int = 1) -> dict[str, Any]:
    """Ask the fast model whether the draft reply contradicts the facts.

    Runs up to *attempts* prompt variants and keeps the verdict that reports
    the highest confidence. Defaults to trusting the facts
    (use_facts=True, confidence=50) when output cannot be parsed.
    """
    best = {"use_facts": True, "confidence": 50}
    facts_block = "\n".join(ctx.facts[:12])
    for idx in range(max(1, attempts)):
        # Label repeated attempts so each prompt is distinct.
        variant = f"Variant: {idx + 1}" if attempts > 1 else ""
        prompt = (
            prompts.CONTRADICTION_PROMPT.format(question=ctx.question, draft=ctx.reply, facts=facts_block)
            + ("\n" + variant if variant else "")
        )
        raw = await ctx.call_llm(
            prompts.CONTRADICTION_SYSTEM,
            prompt,
            model=ctx.plan.fast_model,
            tag="contradiction",
        )
        data = _parse_json_block(raw, fallback={})
        try:
            confidence = int(data.get("confidence", 50))
        except Exception:
            confidence = 50  # malformed confidence value -> neutral default
        use_facts = bool(data.get("use_facts", True))
        # ">=" keeps the most recent verdict when confidences tie.
        if confidence >= best.get("confidence", 0):
            best = {"use_facts": use_facts, "confidence": confidence}
    return best
def _filter_lines_by_keywords(lines: list[str], keywords: list[str], max_lines: int) -> list[str]:
    """Keep lines matching any expanded keyword token, capped at *max_lines*.

    When no line matches (or no tokens expand), the original lines are
    returned, still capped at *max_lines*.
    """
    if not lines:
        return []
    tokens = _expand_tokens(keywords)
    if not tokens:
        return lines[:max_lines]
    matching = [candidate for candidate in lines if any(tok in candidate.lower() for tok in tokens)]
    return (matching or lines)[:max_lines]
def _rank_metric_lines(lines: list[str], tokens: set[str], max_lines: int) -> list[str]:
if not lines or not tokens:
return []
ranked: list[tuple[int, int, str]] = []
for line in lines:
lower = line.lower()
hits = sum(1 for tok in tokens if tok in lower)
if not hits:
continue
has_number = 1 if re.search(r"\d", line) else 0
ranked.append((has_number, hits, line))
ranked.sort(key=lambda item: (-item[0], -item[1], item[2]))
return [item[2] for item in ranked[:max_lines]]
def _select_metric_line(lines: list[str], question: str, tokens: list[str] | set[str]) -> str | None:
    """Pick the single best metric line for *question*.

    Count-style questions ("how many", "count", "total") prefer a line
    containing "total" or "count"; otherwise the top-ranked line wins.
    """
    if not lines or not tokens:
        return None
    normalized = {str(tok).lower() for tok in tokens if tok}
    candidates = _rank_metric_lines(lines, normalized, max_lines=6)
    if not candidates:
        return None
    question_text = (question or "").lower()
    if any(phrase in question_text for phrase in ("how many", "count", "total")):
        for candidate in candidates:
            lowered = candidate.lower()
            if "total" in lowered or "count" in lowered:
                return candidate
    return candidates[0]
def _format_direct_metric_line(line: str) -> str:
    """Render a raw metric line as a sentence, trying colon then equals formats.

    Falls back to the line itself when neither formatter produces output.
    """
    if not line:
        return ""
    if ":" in line:
        sentence = _format_colon_metric(line)
        if sentence:
            return sentence
    if "=" in line:
        sentence = _format_equals_metric(line)
        if sentence:
            return sentence
    return line
def _format_colon_metric(line: str) -> str | None:
key, value = line.split(":", 1)
key = key.strip().replace("_", " ")
value = value.strip()
if not value:
return None
if key == "nodes":
formatted = _format_nodes_value(value)
if formatted:
return formatted
if key in {"nodes total", "nodes_total"}:
return f"Atlas has {value} total nodes."
return f"{key} is {value}."
def _format_equals_metric(line: str) -> str | None:
pairs: list[str] = []
for part in line.split(","):
if "=" not in part:
continue
key, value = part.split("=", 1)
key = key.strip().replace("_", " ")
value = value.strip()
if not value:
continue
if key in {"nodes total", "nodes_total"}:
return f"Atlas has {value} total nodes."
pairs.append(f"{key} is {value}")
if not pairs:
return None
if len(pairs) == 1:
return f"{pairs[0]}."
return "; ".join(pairs) + "."
def _format_nodes_value(value: str) -> str | None:
parts = [p.strip() for p in value.split(",") if p.strip()]
total = None
rest: list[str] = []
for part in parts:
if part.startswith("total="):
total = part.split("=", 1)[1]
else:
rest.append(part.replace("_", " "))
if not total:
return None
if rest:
return f"Atlas has {total} total nodes ({'; '.join(rest)})."
return f"Atlas has {total} total nodes."
def _global_facts(lines: list[str]) -> list[str]:
    """Collect cluster-level summary lines (totals, readiness, cluster name), deduped to 6."""
    if not lines:
        return []
    markers = ("nodes_total", "nodes_ready", "cluster_name", "cluster", "nodes_not_ready")
    picked = [line for line in lines if any(marker in line.lower() for marker in markers)]
    return _dedupe_lines(picked, limit=6)
def _has_keyword_overlap(lines: list[str], keywords: list[str]) -> bool:
    """Return True when any line contains any expanded keyword token."""
    if not lines or not keywords:
        return False
    tokens = _expand_tokens(keywords)
    if not tokens:
        return False
    return any(any(tok in candidate.lower() for tok in tokens) for candidate in lines)
def _merge_tokens(primary: list[str], secondary: list[str], third: list[str] | None = None) -> list[str]:
merged: list[str] = []
for token in primary + secondary + (third or []):
if not token:
continue
if token not in merged:
merged.append(token)
return merged
def _extract_question_tokens(question: str) -> list[str]:
    """Split the question into unique lowercase word tokens of at least TOKEN_MIN_LEN chars."""
    if not question:
        return []
    seen: list[str] = []
    for word in re.split(r"[^a-zA-Z0-9_-]+", question.lower()):
        if len(word) >= TOKEN_MIN_LEN and word not in seen:
            seen.append(word)
    return seen
def _expand_tokens(tokens: list[str]) -> list[str]:
    """Split each keyword into unique lowercase sub-tokens of at least TOKEN_MIN_LEN chars."""
    if not tokens:
        return []
    result: list[str] = []
    for token in tokens:
        if not isinstance(token, str):
            continue  # tolerate non-string entries from loosely-typed callers
        for piece in re.split(r"[^a-zA-Z0-9_-]+", token.lower()):
            if len(piece) >= TOKEN_MIN_LEN and piece not in result:
                result.append(piece)
    return result
def _ensure_token_coverage(lines: list[str], tokens: list[str], summary_lines: list[str], max_add: int = 4) -> list[str]:
    """Prepend summary lines covering tokens missing from *lines* (at most *max_add* additions)."""
    if not lines or not tokens or not summary_lines:
        return lines
    haystack = " ".join(lines).lower()
    uncovered = [tok for tok in tokens if tok and tok.lower() not in haystack]
    if not uncovered:
        return lines
    additions: list[str] = []
    for token in uncovered:
        needle = token.lower()
        # First summary line that mentions the token and is not already present.
        for candidate in summary_lines:
            if needle in candidate.lower() and candidate not in lines and candidate not in additions:
                additions.append(candidate)
                break
        if len(additions) >= max_add:
            break
    if not additions:
        return lines
    return _merge_fact_lines(additions, lines)
def _best_keyword_line(lines: list[str], keywords: list[str]) -> str | None:
    """Return the line with the most expanded-keyword hits; None when nothing matches.

    Ties keep the earliest line.
    """
    if not lines or not keywords:
        return None
    tokens = _expand_tokens(keywords)
    if not tokens:
        return None
    winner: str | None = None
    top = 0
    for candidate in lines:
        lowered = candidate.lower()
        hits = sum(tok in lowered for tok in tokens)
        if hits > top:
            top = hits
            winner = candidate
    return winner
def _line_starting_with(lines: list[str], prefix: str) -> str | None:
if not lines or not prefix:
return None
lower_prefix = prefix.lower()
for line in lines:
if str(line).lower().startswith(lower_prefix):
return line
return None
def _non_rpi_nodes(summary: dict[str, Any]) -> dict[str, list[str]]:
hardware = summary.get("hardware_by_node") if isinstance(summary, dict) else None
if not isinstance(hardware, dict):
return {}
grouped: dict[str, list[str]] = {}
for node, hw in hardware.items():
if not isinstance(node, str) or not isinstance(hw, str):
continue
if hw.startswith("rpi"):
continue
grouped.setdefault(hw, []).append(node)
for nodes in grouped.values():
nodes.sort()
return grouped
def _format_hardware_groups(groups: dict[str, list[str]], label: str) -> str:
if not groups:
return ""
parts = []
for hw, nodes in sorted(groups.items()):
parts.append(f"{hw} ({', '.join(nodes)})")
return f"{label}: " + "; ".join(parts) + "."
def _lexicon_context(summary: dict[str, Any]) -> str: # noqa: C901
if not isinstance(summary, dict):
return ""
lexicon = summary.get("lexicon")
if not isinstance(lexicon, dict):
return ""
terms = lexicon.get("terms")
aliases = lexicon.get("aliases")
lines: list[str] = []
if isinstance(terms, list):
for entry in terms[:8]:
if not isinstance(entry, dict):
continue
term = entry.get("term")
meaning = entry.get("meaning")
if term and meaning:
lines.append(f"{term}: {meaning}")
if isinstance(aliases, dict):
for key, value in list(aliases.items())[:6]:
if key and value:
lines.append(f"alias {key} -> {value}")
if not lines:
return ""
return "Lexicon:\n" + "\n".join(lines)
def _parse_json_block(text: str, *, fallback: dict[str, Any]) -> dict[str, Any]:
    """Parse the first {...} span found in *text*; parse the whole text when none is found."""
    stripped = text.strip()
    found = re.search(r"\{.*\}", stripped, flags=re.S)
    candidate = found.group(0) if found else stripped
    return parse_json(candidate, fallback=fallback)
def _parse_json_list(text: str) -> list[dict[str, Any]]:
    """Parse a JSON [...] array out of *text*, keeping only its dict entries."""
    stripped = text.strip()
    found = re.search(r"\[.*\]", stripped, flags=re.S)
    parsed = parse_json(found.group(0) if found else stripped, fallback={})
    if not isinstance(parsed, list):
        return []
    return [item for item in parsed if isinstance(item, dict)]
def _scores_from_json(data: dict[str, Any]) -> AnswerScores:
    """Build AnswerScores from a parsed JSON dict, defaulting each score to 60."""
    risk = str(data.get("hallucination_risk") or "medium")
    return AnswerScores(
        confidence=_coerce_int(data.get("confidence"), 60),
        relevance=_coerce_int(data.get("relevance"), 60),
        satisfaction=_coerce_int(data.get("satisfaction"), 60),
        hallucination_risk=risk,
    )
def _coerce_int(value: Any, default: int) -> int:
try:
return int(float(value))
except (TypeError, ValueError):
return default
def _default_scores() -> AnswerScores:
    """Neutral fallback scores used when LLM scoring output cannot be parsed."""
    return AnswerScores(
        confidence=60,
        relevance=60,
        satisfaction=60,
        hallucination_risk="medium",
    )
def _style_hint(classify: dict[str, Any]) -> str:
style = (classify.get("answer_style") or "").strip().lower()
qtype = (classify.get("question_type") or "").strip().lower()
if style == "insightful" or qtype in {"open_ended", "planning"}:
return "insightful"
return "direct"
def _needs_evidence_fix(reply: str, classify: dict[str, Any]) -> bool:
if not reply:
return False
lowered = reply.lower()
missing_markers = (
"don't have",
"do not have",
"don't know",
"cannot",
"can't",
"need to",
"would need",
"does not provide",
"does not mention",
"not mention",
"not provided",
"not in context",
"not referenced",
"missing",
"no specific",
"no information",
)
if classify.get("needs_snapshot") and any(marker in lowered for marker in missing_markers):
return True
return classify.get("question_type") in {"metric", "diagnostic"} and not re.search(r"\d", reply)
def _should_use_insight_guard(classify: dict[str, Any]) -> bool:
style = (classify.get("answer_style") or "").strip().lower()
qtype = (classify.get("question_type") or "").strip().lower()
return style == "insightful" or qtype in {"open_ended", "planning"}
async def _apply_insight_guard(inputs: InsightGuardInput) -> str:
    """Run the insight-quality guard over a draft reply and rewrite it if it fails.

    Returns the reply unchanged when the guard does not apply or approves it
    (guard JSON has ok=True); otherwise asks the main model for a fixed
    answer, appending up to 6 fact lines to the fix prompt.
    """
    if not inputs.reply or not _should_use_insight_guard(inputs.classify):
        return inputs.reply
    guard_prompt = prompts.INSIGHT_GUARD_PROMPT.format(question=inputs.question, answer=inputs.reply)
    guard_raw = await inputs.call_llm(
        prompts.INSIGHT_GUARD_SYSTEM,
        guard_prompt,
        context=inputs.context,
        model=inputs.plan.fast_model,  # cheap model for the pass/fail check
        tag="insight_guard",
    )
    guard = _parse_json_block(guard_raw, fallback={})
    if guard.get("ok") is True:
        return inputs.reply
    fix_prompt = prompts.INSIGHT_FIX_PROMPT.format(question=inputs.question, answer=inputs.reply)
    if inputs.facts:
        fix_prompt = fix_prompt + "\nFacts:\n" + "\n".join(inputs.facts[:6])
    return await inputs.call_llm(
        prompts.INSIGHT_FIX_SYSTEM,
        fix_prompt,
        context=inputs.context,
        model=inputs.plan.model,  # full model for the rewrite
        tag="insight_fix",
    )
# Export all single-underscore helpers (dunders excluded) so star-imports of
# this module pick them up despite the leading-underscore naming.
__all__ = [name for name in globals() if name.startswith("_") and not name.startswith("__")]