atlasbot: wire quick smart genius modes
This commit is contained in:
parent
f04f032721
commit
a1e90f4600
@ -17,12 +17,15 @@ BOT_USER_QUICK = os.environ.get("BOT_USER_QUICK", "").strip()
|
|||||||
BOT_PASS_QUICK = os.environ.get("BOT_PASS_QUICK", "").strip()
|
BOT_PASS_QUICK = os.environ.get("BOT_PASS_QUICK", "").strip()
|
||||||
BOT_USER_SMART = os.environ.get("BOT_USER_SMART", "").strip()
|
BOT_USER_SMART = os.environ.get("BOT_USER_SMART", "").strip()
|
||||||
BOT_PASS_SMART = os.environ.get("BOT_PASS_SMART", "").strip()
|
BOT_PASS_SMART = os.environ.get("BOT_PASS_SMART", "").strip()
|
||||||
|
BOT_USER_GENIUS = os.environ.get("BOT_USER_GENIUS", "").strip()
|
||||||
|
BOT_PASS_GENIUS = os.environ.get("BOT_PASS_GENIUS", "").strip()
|
||||||
ROOM_ALIAS = "#othrys:live.bstein.dev"
|
ROOM_ALIAS = "#othrys:live.bstein.dev"
|
||||||
|
|
||||||
OLLAMA_URL = os.environ.get("OLLAMA_URL", "https://chat.ai.bstein.dev/")
|
OLLAMA_URL = os.environ.get("OLLAMA_URL", "https://chat.ai.bstein.dev/")
|
||||||
MODEL = os.environ.get("OLLAMA_MODEL", "qwen2.5:14b-instruct")
|
MODEL = os.environ.get("OLLAMA_MODEL", "qwen2.5:14b-instruct")
|
||||||
MODEL_FAST = os.environ.get("ATLASBOT_MODEL_FAST", "")
|
MODEL_FAST = os.environ.get("ATLASBOT_MODEL_FAST", "")
|
||||||
MODEL_DEEP = os.environ.get("ATLASBOT_MODEL_DEEP", "")
|
# Model used for "smart" answers. ATLASBOT_MODEL_DEEP is still read as a
# fallback name. Each candidate is stripped and chained with `or` so that a
# variable which is set but blank falls through to the next candidate instead
# of pinning the value to an empty string.
MODEL_SMART = (
    os.environ.get("ATLASBOT_MODEL_SMART", "").strip()
    or os.environ.get("ATLASBOT_MODEL_DEEP", "").strip()
)
# Model used for "genius" answers; inherits the smart model when unset or blank.
MODEL_GENIUS = os.environ.get("ATLASBOT_MODEL_GENIUS", "").strip() or MODEL_SMART
|
||||||
FALLBACK_MODEL = os.environ.get("OLLAMA_FALLBACK_MODEL", "")
|
FALLBACK_MODEL = os.environ.get("OLLAMA_FALLBACK_MODEL", "")
|
||||||
API_KEY = os.environ.get("CHAT_API_KEY", "")
|
API_KEY = os.environ.get("CHAT_API_KEY", "")
|
||||||
OLLAMA_TIMEOUT_SEC = float(os.environ.get("OLLAMA_TIMEOUT_SEC", "480"))
|
OLLAMA_TIMEOUT_SEC = float(os.environ.get("OLLAMA_TIMEOUT_SEC", "480"))
|
||||||
@ -43,6 +46,9 @@ MAX_TOOL_CHARS = int(os.environ.get("ATLASBOT_MAX_TOOL_CHARS", "2500"))
|
|||||||
MAX_FACTS_CHARS = int(os.environ.get("ATLASBOT_MAX_FACTS_CHARS", "8000"))
|
MAX_FACTS_CHARS = int(os.environ.get("ATLASBOT_MAX_FACTS_CHARS", "8000"))
|
||||||
MAX_CONTEXT_CHARS = int(os.environ.get("ATLASBOT_MAX_CONTEXT_CHARS", "12000"))
|
MAX_CONTEXT_CHARS = int(os.environ.get("ATLASBOT_MAX_CONTEXT_CHARS", "12000"))
|
||||||
THINKING_INTERVAL_SEC = int(os.environ.get("ATLASBOT_THINKING_INTERVAL_SEC", "120"))
|
THINKING_INTERVAL_SEC = int(os.environ.get("ATLASBOT_THINKING_INTERVAL_SEC", "120"))
|
||||||
|
QUICK_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_QUICK_TIME_BUDGET_SEC", "15"))
|
||||||
|
SMART_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_SMART_TIME_BUDGET_SEC", "45"))
|
||||||
|
GENIUS_TIME_BUDGET_SEC = float(os.environ.get("ATLASBOT_GENIUS_TIME_BUDGET_SEC", "180"))
|
||||||
OLLAMA_RETRIES = int(os.environ.get("ATLASBOT_OLLAMA_RETRIES", "2"))
|
OLLAMA_RETRIES = int(os.environ.get("ATLASBOT_OLLAMA_RETRIES", "2"))
|
||||||
OLLAMA_SERIALIZE = os.environ.get("ATLASBOT_OLLAMA_SERIALIZE", "true").lower() != "false"
|
OLLAMA_SERIALIZE = os.environ.get("ATLASBOT_OLLAMA_SERIALIZE", "true").lower() != "false"
|
||||||
|
|
||||||
@ -384,16 +390,20 @@ def _strip_bot_mention(text: str) -> str:
|
|||||||
return cleaned or text.strip()
|
return cleaned or text.strip()
|
||||||
|
|
||||||
|
|
||||||
def _detect_mode_from_body(body: str, *, default: str = "smart") -> str:
    """Infer the answer mode from the text of a message.

    The body is passed through ``normalize_query`` and then checked in two
    passes: first for an ``atlas_<name>`` / ``atlas-<name>`` token anywhere in
    the text, then for a leading mode keyword followed by a space.

    Args:
        body: Raw message text; a falsy value is treated as an empty string.
        default: Mode returned when no rule matches.

    Returns:
        "fast", "smart", "genius", or ``default``.
    """
    text = normalize_query(body or "")

    # Pass 1: a bot-name token anywhere in the text decides the mode.
    mention_rules = (
        ("fast", ("atlas_quick", "atlas-quick")),
        ("smart", ("atlas_smart", "atlas-smart")),
        ("genius", ("atlas_genius", "atlas-genius")),
    )
    for mode, needles in mention_rules:
        if any(needle in text for needle in needles):
            return mode

    # Pass 2: a leading keyword. Note that "deep " selects the genius mode.
    prefix_rules = (
        ("fast", ("quick ", "fast ")),
        ("smart", ("smart ",)),
        ("genius", ("genius ", "deep ")),
    )
    for mode, prefixes in prefix_rules:
        if text.startswith(prefixes):
            return mode

    return default
|
||||||
|
|
||||||
|
|
||||||
@ -401,7 +411,7 @@ def _detect_mode(
|
|||||||
content: dict[str, Any],
|
content: dict[str, Any],
|
||||||
body: str,
|
body: str,
|
||||||
*,
|
*,
|
||||||
default: str = "deep",
|
default: str = "smart",
|
||||||
account_user: str = "",
|
account_user: str = "",
|
||||||
) -> str:
|
) -> str:
|
||||||
mode = _detect_mode_from_body(body, default=default)
|
mode = _detect_mode_from_body(body, default=default)
|
||||||
@ -412,24 +422,72 @@ def _detect_mode(
|
|||||||
if BOT_USER_QUICK and normalize_user_id(BOT_USER_QUICK).lower() in normalized:
|
if BOT_USER_QUICK and normalize_user_id(BOT_USER_QUICK).lower() in normalized:
|
||||||
return "fast"
|
return "fast"
|
||||||
if BOT_USER_SMART and normalize_user_id(BOT_USER_SMART).lower() in normalized:
|
if BOT_USER_SMART and normalize_user_id(BOT_USER_SMART).lower() in normalized:
|
||||||
return "deep"
|
return "smart"
|
||||||
|
if BOT_USER_GENIUS and normalize_user_id(BOT_USER_GENIUS).lower() in normalized:
|
||||||
|
return "genius"
|
||||||
if BOT_USER and normalize_user_id(BOT_USER).lower() in normalized:
|
if BOT_USER and normalize_user_id(BOT_USER).lower() in normalized:
|
||||||
return "deep"
|
return "smart"
|
||||||
if account_user and BOT_USER_QUICK and normalize_user_id(account_user) == normalize_user_id(BOT_USER_QUICK):
|
if account_user and BOT_USER_QUICK and normalize_user_id(account_user) == normalize_user_id(BOT_USER_QUICK):
|
||||||
return "fast"
|
return "fast"
|
||||||
if account_user and BOT_USER_SMART and normalize_user_id(account_user) == normalize_user_id(BOT_USER_SMART):
|
if account_user and BOT_USER_SMART and normalize_user_id(account_user) == normalize_user_id(BOT_USER_SMART):
|
||||||
return "deep"
|
return "smart"
|
||||||
|
if account_user and BOT_USER_GENIUS and normalize_user_id(account_user) == normalize_user_id(BOT_USER_GENIUS):
|
||||||
|
return "genius"
|
||||||
return mode
|
return mode
|
||||||
|
|
||||||
|
|
||||||
def _model_for_mode(mode: str) -> str:
    """Return the model name configured for ``mode``.

    A mode-specific model is used only when it is configured (non-empty);
    otherwise, and for any unrecognised mode, the base ``MODEL`` is returned.
    The mode string is compared as given, without normalisation.
    """
    # Checked in order; "deep" resolves to the smart model.
    candidates = (
        ("fast", MODEL_FAST),
        ("smart", MODEL_SMART),
        ("genius", MODEL_GENIUS),
        ("deep", MODEL_SMART),
    )
    for name, configured in candidates:
        if mode == name and configured:
            return configured
    return MODEL
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_mode(mode: str) -> str:
|
||||||
|
normalized = (mode or "").strip().lower()
|
||||||
|
if normalized in {"quick", "fast"}:
|
||||||
|
return "fast"
|
||||||
|
if normalized in {"smart"}:
|
||||||
|
return "smart"
|
||||||
|
if normalized in {"genius", "deep"}:
|
||||||
|
return "genius"
|
||||||
|
return "smart"
|
||||||
|
|
||||||
|
|
||||||
|
def _mode_time_budget_sec(mode: str) -> float:
    """Return the time budget, in seconds, for ``mode``.

    The mode is normalised first, so aliases and unknown names are accepted.
    The configured budget is clamped to a minimum of one second.
    """
    normalized = _normalize_mode(mode)
    if normalized == "fast":
        configured = QUICK_TIME_BUDGET_SEC
    elif normalized == "genius":
        configured = GENIUS_TIME_BUDGET_SEC
    else:
        # "smart", which is also what any unrecognised mode normalises to.
        configured = SMART_TIME_BUDGET_SEC
    return max(1.0, configured)
|
||||||
|
|
||||||
|
|
||||||
|
def _mode_ollama_timeout_sec(mode: str) -> float:
    """Return the per-request Ollama timeout, in seconds, for ``mode``.

    The timeout is the mode's time budget minus a per-mode reserve, capped at
    ``OLLAMA_TIMEOUT_SEC`` and then raised to a per-mode minimum. The minimum
    is applied last, so it wins over both the budget and the global cap.
    """
    normalized = _normalize_mode(mode)
    budget = _mode_time_budget_sec(normalized)

    # (minimum timeout, seconds held back from the budget) for each mode.
    if normalized == "fast":
        floor, reserve = 6.0, 2.0
    elif normalized == "genius":
        floor, reserve = 20.0, 10.0
    else:
        floor, reserve = 12.0, 5.0

    capped = min(budget - reserve, OLLAMA_TIMEOUT_SEC)
    return max(floor, capped)
|
||||||
|
|
||||||
|
|
||||||
|
def _mode_heartbeat_sec(mode: str) -> int:
    """Return the heartbeat interval, in seconds, for ``mode``.

    The interval is one third of the mode's time budget (truncated to an
    integer, at least 5), limited to ``THINKING_INTERVAL_SEC`` and never less
    than 5 seconds.
    """
    normalized = _normalize_mode(mode)
    budget = _mode_time_budget_sec(normalized)
    third_of_budget = int(max(5.0, budget / 3.0))
    interval = min(THINKING_INTERVAL_SEC, third_of_budget)
    return max(5, interval)
|
||||||
|
|
||||||
|
|
||||||
# Matrix HTTP helper.
|
# Matrix HTTP helper.
|
||||||
def req(method: str, path: str, token: str | None = None, body=None, timeout=60, base: str | None = None):
|
def req(method: str, path: str, token: str | None = None, body=None, timeout=60, base: str | None = None):
|
||||||
url = (base or BASE) + path
|
url = (base or BASE) + path
|
||||||
@ -3842,9 +3900,12 @@ def _open_ended_multi(
|
|||||||
|
|
||||||
|
|
||||||
def _open_ended_total_steps(mode: str) -> int:
    """Return the number of steps reported for an open-ended answer in ``mode``.

    The mode is normalised first: "fast" yields 2, "smart" yields 3, and any
    other normalised mode yields 4.
    """
    steps_by_mode = {"fast": 2, "smart": 3}
    return steps_by_mode.get(_normalize_mode(mode), 4)
|
||||||
|
|
||||||
|
|
||||||
def _fast_fact_lines(
|
def _fast_fact_lines(
|
||||||
@ -4179,6 +4240,7 @@ def _open_ended_fast_single(
|
|||||||
use_history=False,
|
use_history=False,
|
||||||
system_override=_open_ended_system(),
|
system_override=_open_ended_system(),
|
||||||
model=model,
|
model=model,
|
||||||
|
timeout=_mode_ollama_timeout_sec("fast"),
|
||||||
)
|
)
|
||||||
if not _has_body_lines(reply):
|
if not _has_body_lines(reply):
|
||||||
reply = _ollama_call(
|
reply = _ollama_call(
|
||||||
@ -4188,6 +4250,7 @@ def _open_ended_fast_single(
|
|||||||
use_history=False,
|
use_history=False,
|
||||||
system_override=_open_ended_system(),
|
system_override=_open_ended_system(),
|
||||||
model=model,
|
model=model,
|
||||||
|
timeout=_mode_ollama_timeout_sec("fast"),
|
||||||
)
|
)
|
||||||
fallback = _fallback_fact_answer(prompt, context)
|
fallback = _fallback_fact_answer(prompt, context)
|
||||||
if fallback and (_is_quantitative_prompt(prompt) or not _has_body_lines(reply)):
|
if fallback and (_is_quantitative_prompt(prompt) or not _has_body_lines(reply)):
|
||||||
@ -4248,16 +4311,53 @@ def _open_ended_deep(
|
|||||||
fact_lines: list[str],
|
fact_lines: list[str],
|
||||||
fact_meta: dict[str, dict[str, Any]],
|
fact_meta: dict[str, dict[str, Any]],
|
||||||
history_lines: list[str],
|
history_lines: list[str],
|
||||||
|
mode: str,
|
||||||
state: ThoughtState | None = None,
|
state: ThoughtState | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
return _open_ended_multi(
|
normalized = _normalize_mode(mode)
|
||||||
prompt,
|
model = _model_for_mode(normalized)
|
||||||
fact_pack=fact_pack,
|
subjective = _is_subjective_query(prompt)
|
||||||
fact_lines=fact_lines,
|
primary_tags = _primary_tags_for_prompt(prompt)
|
||||||
fact_meta=fact_meta,
|
focus_tags = _preferred_tags_for_prompt(prompt)
|
||||||
history_lines=history_lines,
|
if not focus_tags and subjective:
|
||||||
state=state,
|
focus_tags = set(_ALLOWED_INSIGHT_TAGS)
|
||||||
|
avoid_tags = _history_focus_tags(history_lines) if (subjective or _is_followup_query(prompt)) else set()
|
||||||
|
limit = 12 if normalized == "smart" else 18
|
||||||
|
selected_lines = _fast_fact_lines(
|
||||||
|
fact_lines,
|
||||||
|
fact_meta,
|
||||||
|
focus_tags=focus_tags,
|
||||||
|
avoid_tags=avoid_tags,
|
||||||
|
primary_tags=primary_tags,
|
||||||
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
selected_meta = _fact_pack_meta(selected_lines)
|
||||||
|
selected_pack = _fact_pack_text(selected_lines, selected_meta)
|
||||||
|
if _needs_full_fact_pack(prompt) or not selected_lines or normalized == "genius":
|
||||||
|
selected_pack = fact_pack
|
||||||
|
fallback = _fallback_fact_answer(prompt, selected_pack)
|
||||||
|
if not subjective and fallback:
|
||||||
|
if state:
|
||||||
|
state.update("done", step=_open_ended_total_steps(normalized))
|
||||||
|
return _ensure_scores(fallback)
|
||||||
|
if state:
|
||||||
|
state.update("drafting", step=1, note="synthesizing")
|
||||||
|
reply = _ollama_call(
|
||||||
|
("atlasbot_deep", "atlasbot_deep"),
|
||||||
|
prompt,
|
||||||
|
context=_append_history_context(selected_pack, history_lines),
|
||||||
|
use_history=False,
|
||||||
|
system_override=_open_ended_system(),
|
||||||
|
model=model,
|
||||||
|
timeout=_mode_ollama_timeout_sec(normalized),
|
||||||
|
)
|
||||||
|
if fallback and (_is_quantitative_prompt(prompt) or not _has_body_lines(reply)):
|
||||||
|
reply = fallback
|
||||||
|
if not _has_body_lines(reply):
|
||||||
|
reply = "I don't have enough data in the current snapshot to answer that."
|
||||||
|
if state:
|
||||||
|
state.update("done", step=_open_ended_total_steps(normalized))
|
||||||
|
return _ensure_scores(reply)
|
||||||
|
|
||||||
|
|
||||||
def open_ended_answer(
|
def open_ended_answer(
|
||||||
@ -4285,7 +4385,8 @@ def open_ended_answer(
|
|||||||
return _ensure_scores("I don't have enough data to answer that.")
|
return _ensure_scores("I don't have enough data to answer that.")
|
||||||
fact_meta = _fact_pack_meta(lines)
|
fact_meta = _fact_pack_meta(lines)
|
||||||
fact_pack = _fact_pack_text(lines, fact_meta)
|
fact_pack = _fact_pack_text(lines, fact_meta)
|
||||||
if mode == "fast":
|
normalized = _normalize_mode(mode)
|
||||||
|
if normalized == "fast":
|
||||||
return _open_ended_fast(
|
return _open_ended_fast(
|
||||||
prompt,
|
prompt,
|
||||||
fact_pack=fact_pack,
|
fact_pack=fact_pack,
|
||||||
@ -4300,6 +4401,7 @@ def open_ended_answer(
|
|||||||
fact_lines=lines,
|
fact_lines=lines,
|
||||||
fact_meta=fact_meta,
|
fact_meta=fact_meta,
|
||||||
history_lines=history_lines,
|
history_lines=history_lines,
|
||||||
|
mode=normalized,
|
||||||
state=state,
|
state=state,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -4321,6 +4423,7 @@ def _non_cluster_reply(prompt: str, *, history_lines: list[str], mode: str) -> s
|
|||||||
use_history=False,
|
use_history=False,
|
||||||
system_override=system,
|
system_override=system,
|
||||||
model=model,
|
model=model,
|
||||||
|
timeout=_mode_ollama_timeout_sec(mode),
|
||||||
)
|
)
|
||||||
reply = re.sub(r"\bconfidence\s*:\s*(high|medium|low)\b\.?\s*", "", reply, flags=re.IGNORECASE).strip()
|
reply = re.sub(r"\bconfidence\s*:\s*(high|medium|low)\b\.?\s*", "", reply, flags=re.IGNORECASE).strip()
|
||||||
return _ensure_scores(reply)
|
return _ensure_scores(reply)
|
||||||
@ -4372,13 +4475,7 @@ class _AtlasbotHandler(BaseHTTPRequestHandler):
|
|||||||
self._write_json(400, {"error": "missing_prompt"})
|
self._write_json(400, {"error": "missing_prompt"})
|
||||||
return
|
return
|
||||||
cleaned = _strip_bot_mention(prompt)
|
cleaned = _strip_bot_mention(prompt)
|
||||||
mode = str(payload.get("mode") or "deep").lower()
|
mode = _normalize_mode(str(payload.get("mode") or "smart"))
|
||||||
if mode in ("quick", "fast"):
|
|
||||||
mode = "fast"
|
|
||||||
elif mode in ("smart", "deep"):
|
|
||||||
mode = "deep"
|
|
||||||
else:
|
|
||||||
mode = "deep"
|
|
||||||
snapshot = _snapshot_state()
|
snapshot = _snapshot_state()
|
||||||
inventory = _snapshot_inventory(snapshot) or node_inventory_live()
|
inventory = _snapshot_inventory(snapshot) or node_inventory_live()
|
||||||
workloads = _snapshot_workloads(snapshot)
|
workloads = _snapshot_workloads(snapshot)
|
||||||
@ -4669,6 +4766,7 @@ def _ollama_call(
|
|||||||
use_history: bool = True,
|
use_history: bool = True,
|
||||||
system_override: str | None = None,
|
system_override: str | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
system = system_override or (
|
system = system_override or (
|
||||||
"System: You are Atlas, the Titan lab assistant for Atlas/Othrys. "
|
"System: You are Atlas, the Titan lab assistant for Atlas/Othrys. "
|
||||||
@ -4702,6 +4800,7 @@ def _ollama_call(
|
|||||||
messages.append({"role": "user", "content": prompt})
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
model_name = model or MODEL
|
model_name = model or MODEL
|
||||||
|
request_timeout = timeout if timeout is not None else OLLAMA_TIMEOUT_SEC
|
||||||
payload = {"model": model_name, "messages": messages, "stream": False}
|
payload = {"model": model_name, "messages": messages, "stream": False}
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
if API_KEY:
|
if API_KEY:
|
||||||
@ -4712,13 +4811,13 @@ def _ollama_call(
|
|||||||
lock.acquire()
|
lock.acquire()
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
with request.urlopen(r, timeout=OLLAMA_TIMEOUT_SEC) as resp:
|
with request.urlopen(r, timeout=request_timeout) as resp:
|
||||||
data = json.loads(resp.read().decode())
|
data = json.loads(resp.read().decode())
|
||||||
except error.HTTPError as exc:
|
except error.HTTPError as exc:
|
||||||
if exc.code == 404 and FALLBACK_MODEL and FALLBACK_MODEL != payload["model"]:
|
if exc.code == 404 and FALLBACK_MODEL and FALLBACK_MODEL != payload["model"]:
|
||||||
payload["model"] = FALLBACK_MODEL
|
payload["model"] = FALLBACK_MODEL
|
||||||
r = request.Request(endpoint, data=json.dumps(payload).encode(), headers=headers)
|
r = request.Request(endpoint, data=json.dumps(payload).encode(), headers=headers)
|
||||||
with request.urlopen(r, timeout=OLLAMA_TIMEOUT_SEC) as resp:
|
with request.urlopen(r, timeout=request_timeout) as resp:
|
||||||
data = json.loads(resp.read().decode())
|
data = json.loads(resp.read().decode())
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
@ -4743,6 +4842,7 @@ def ollama_reply(
|
|||||||
fallback: str = "",
|
fallback: str = "",
|
||||||
use_history: bool = True,
|
use_history: bool = True,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
last_error = None
|
last_error = None
|
||||||
for attempt in range(max(1, OLLAMA_RETRIES + 1)):
|
for attempt in range(max(1, OLLAMA_RETRIES + 1)):
|
||||||
@ -4753,6 +4853,7 @@ def ollama_reply(
|
|||||||
context=context,
|
context=context,
|
||||||
use_history=use_history,
|
use_history=use_history,
|
||||||
model=model,
|
model=model,
|
||||||
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
last_error = exc
|
last_error = exc
|
||||||
@ -4773,6 +4874,7 @@ def ollama_reply_with_thinking(
|
|||||||
fallback: str,
|
fallback: str,
|
||||||
use_history: bool = True,
|
use_history: bool = True,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
result: dict[str, str] = {"reply": ""}
|
result: dict[str, str] = {"reply": ""}
|
||||||
done = threading.Event()
|
done = threading.Event()
|
||||||
@ -4785,6 +4887,7 @@ def ollama_reply_with_thinking(
|
|||||||
fallback=fallback,
|
fallback=fallback,
|
||||||
use_history=use_history,
|
use_history=use_history,
|
||||||
model=model,
|
model=model,
|
||||||
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
done.set()
|
done.set()
|
||||||
|
|
||||||
@ -4841,7 +4944,7 @@ def open_ended_with_thinking(
|
|||||||
thread.start()
|
thread.start()
|
||||||
if not done.wait(2.0):
|
if not done.wait(2.0):
|
||||||
send_msg(token, room, "Thinking…")
|
send_msg(token, room, "Thinking…")
|
||||||
heartbeat = max(10, THINKING_INTERVAL_SEC)
|
heartbeat = _mode_heartbeat_sec(mode)
|
||||||
next_heartbeat = time.monotonic() + heartbeat
|
next_heartbeat = time.monotonic() + heartbeat
|
||||||
while not done.wait(max(0, next_heartbeat - time.monotonic())):
|
while not done.wait(max(0, next_heartbeat - time.monotonic())):
|
||||||
send_msg(token, room, state.status_line())
|
send_msg(token, room, state.status_line())
|
||||||
@ -4906,7 +5009,7 @@ def sync_loop(token: str, room_id: str, *, account_user: str, default_mode: str)
|
|||||||
mode = _detect_mode(
|
mode = _detect_mode(
|
||||||
content,
|
content,
|
||||||
body,
|
body,
|
||||||
default=default_mode if default_mode in ("fast", "deep") else "deep",
|
default=_normalize_mode(default_mode),
|
||||||
account_user=account_user,
|
account_user=account_user,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -4972,14 +5075,14 @@ def sync_loop(token: str, room_id: str, *, account_user: str, default_mode: str)
|
|||||||
snapshot=snapshot,
|
snapshot=snapshot,
|
||||||
workloads=workloads,
|
workloads=workloads,
|
||||||
history_lines=history[hist_key],
|
history_lines=history[hist_key],
|
||||||
mode=mode if mode in ("fast", "deep") else "deep",
|
mode=_normalize_mode(mode),
|
||||||
allow_tools=allow_tools,
|
allow_tools=allow_tools,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
reply = _non_cluster_reply(
|
reply = _non_cluster_reply(
|
||||||
cleaned_body,
|
cleaned_body,
|
||||||
history_lines=history[hist_key],
|
history_lines=history[hist_key],
|
||||||
mode=mode if mode in ("fast", "deep") else "deep",
|
mode=_normalize_mode(mode),
|
||||||
)
|
)
|
||||||
send_msg(token, rid, reply)
|
send_msg(token, rid, reply)
|
||||||
history[hist_key].append(f"Atlas: {reply}")
|
history[hist_key].append(f"Atlas: {reply}")
|
||||||
@ -5003,11 +5106,13 @@ def _bot_accounts() -> list[dict[str, str]]:
|
|||||||
return
|
return
|
||||||
accounts.append({"user": user, "password": password, "mode": mode})
|
accounts.append({"user": user, "password": password, "mode": mode})
|
||||||
|
|
||||||
add(BOT_USER_SMART or BOT_USER, BOT_PASS_SMART or BOT_PASS, "deep")
|
add(BOT_USER_SMART or BOT_USER, BOT_PASS_SMART or BOT_PASS, "smart")
|
||||||
if BOT_USER_QUICK and BOT_PASS_QUICK:
|
if BOT_USER_QUICK and BOT_PASS_QUICK:
|
||||||
add(BOT_USER_QUICK, BOT_PASS_QUICK, "fast")
|
add(BOT_USER_QUICK, BOT_PASS_QUICK, "fast")
|
||||||
|
if BOT_USER_GENIUS and BOT_PASS_GENIUS:
|
||||||
|
add(BOT_USER_GENIUS, BOT_PASS_GENIUS, "genius")
|
||||||
if BOT_USER and BOT_PASS and all(acc["user"] != BOT_USER for acc in accounts):
|
if BOT_USER and BOT_PASS and all(acc["user"] != BOT_USER for acc in accounts):
|
||||||
add(BOT_USER, BOT_PASS, "deep")
|
add(BOT_USER, BOT_PASS, "smart")
|
||||||
|
|
||||||
seen: set[str] = set()
|
seen: set[str] = set()
|
||||||
unique: list[dict[str, str]] = []
|
unique: list[dict[str, str]] = []
|
||||||
|
|||||||
107
services/comms/scripts/tests/test_atlasbot_modes.py
Normal file
107
services/comms/scripts/tests/test_atlasbot_modes.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest import TestCase, mock
|
||||||
|
|
||||||
|
|
||||||
|
BOT_PATH = Path(__file__).resolve().parents[1] / "atlasbot" / "bot.py"
|
||||||
|
|
||||||
|
|
||||||
|
def load_bot_module():
    """Execute ``bot.py`` from ``BOT_PATH`` as a fresh module and return it.

    The module is executed while ``os.environ`` is temporarily overlaid with a
    fixed set of test values, so the constants it reads at import time are
    deterministic. Existing environment variables not listed here are left in
    place, and the overlay is removed when the ``with`` block exits.
    """
    overrides = {
        # Bot accounts: the base account doubles as the smart account.
        "BOT_USER": "atlas-smart",
        "BOT_PASS": "smart-pass",
        "BOT_USER_QUICK": "atlas-quick",
        "BOT_PASS_QUICK": "quick-pass",
        "BOT_USER_SMART": "atlas-smart",
        "BOT_PASS_SMART": "smart-pass",
        "BOT_USER_GENIUS": "atlas-genius",
        "BOT_PASS_GENIUS": "genius-pass",
        # Model selection.
        "OLLAMA_URL": "http://ollama.invalid",
        "OLLAMA_MODEL": "base-model",
        "ATLASBOT_MODEL_FAST": "fast-model",
        "ATLASBOT_MODEL_SMART": "smart-model",
        "ATLASBOT_MODEL_GENIUS": "genius-model",
        # Per-mode time budgets.
        "ATLASBOT_QUICK_TIME_BUDGET_SEC": "15",
        "ATLASBOT_SMART_TIME_BUDGET_SEC": "45",
        "ATLASBOT_GENIUS_TIME_BUDGET_SEC": "180",
        # Other settings, blanked or pointed at unresolvable hosts.
        "KB_DIR": "",
        "VM_URL": "http://vm.invalid",
        "ARIADNE_STATE_URL": "",
        "ARIADNE_STATE_TOKEN": "",
    }
    patched_env = mock.patch.dict(os.environ, overrides, clear=False)
    with patched_env:
        bot_spec = importlib.util.spec_from_file_location("atlasbot_bot", BOT_PATH)
        bot_module = importlib.util.module_from_spec(bot_spec)
        assert bot_spec.loader is not None
        bot_spec.loader.exec_module(bot_module)
    return bot_module
|
||||||
|
|
||||||
|
|
||||||
|
class AtlasbotModeTests(TestCase):
    """Mode-wiring tests for the quick, smart and genius atlasbot accounts."""

    def setUp(self):
        # Every test runs against a freshly executed copy of the bot module.
        self.bot = load_bot_module()

    def test_bot_accounts_include_genius_mode(self):
        """Each configured account is listed with its own mode."""
        mode_by_user = {}
        for entry in self.bot._bot_accounts():
            mode_by_user[entry["user"]] = entry["mode"]

        self.assertEqual(mode_by_user["atlas-quick"], "fast")
        self.assertEqual(mode_by_user["atlas-smart"], "smart")
        self.assertEqual(mode_by_user["atlas-genius"], "genius")

    def test_objective_cluster_question_uses_fact_pack_without_llm(self):
        """An objective question is answered from the fact pack alone."""
        facts = [
            "hottest_cpu: longhorn-system (6.69)",
            "hottest_ram: longhorn-system (36.05 GB)",
        ]
        # Any call into the LLM fails the test immediately.
        no_llm = AssertionError("LLM should not be called")

        with mock.patch.object(self.bot, "_fact_pack_lines", return_value=facts):
            with mock.patch.object(self.bot, "_ollama_call", side_effect=no_llm):
                answer = self.bot.open_ended_answer(
                    "what is the hottest cpu node in titan lab currently?",
                    inventory=[],
                    snapshot=None,
                    workloads=[],
                    history_lines=[],
                    mode="smart",
                    allow_tools=True,
                )

        self.assertIn("longhorn-system", answer)
        self.assertIn("Confidence:", answer)

    def test_subjective_genius_answer_uses_genius_model(self):
        """A subjective genius-mode question reaches the LLM with the genius model."""
        facts = [
            "hottest_cpu: longhorn-system (6.69)",
            "worker_nodes: titan-01, titan-02, titan-03",
        ]
        seen: dict[str, object] = {}

        def record_ollama_call(
            hist_key,
            prompt,
            *,
            context,
            use_history=True,
            system_override=None,
            model=None,
            timeout=None,
        ):
            # Record the arguments under test and return a canned reply.
            seen.update(model=model, timeout=timeout, context=context)
            return "The worker spread stands out because Titan keeps meaningful capacity on the same cluster. Confidence: high"

        with mock.patch.object(self.bot, "_fact_pack_lines", return_value=facts):
            with mock.patch.object(self.bot, "_ollama_call", side_effect=record_ollama_call):
                answer = self.bot.open_ended_answer(
                    "what stands out about titan lab?",
                    inventory=[],
                    snapshot=None,
                    workloads=[],
                    history_lines=[],
                    mode="genius",
                    allow_tools=True,
                )

        self.assertIn("The worker spread stands out", answer)
        self.assertEqual(seen["model"], "genius-model")
        self.assertLessEqual(float(seen["timeout"]), 180.0)
|
||||||
Loading…
x
Reference in New Issue
Block a user