diff --git a/services/comms/scripts/atlasbot/bot.py b/services/comms/scripts/atlasbot/bot.py index 4fa67d4..8806d2a 100644 --- a/services/comms/scripts/atlasbot/bot.py +++ b/services/comms/scripts/atlasbot/bot.py @@ -3854,6 +3854,18 @@ def _fallback_fact_answer(prompt: str, context: str) -> str: return sentence +def _is_quantitative_prompt(prompt: str) -> bool: + q = normalize_query(prompt) + if not q: + return False + tokens = set(_tokens(prompt)) + if "how many" in q or "count" in tokens or "total" in tokens: + return True + if tokens & {"highest", "lowest", "hottest", "most", "least"}: + return True + return False + + def _open_ended_fast_single( prompt: str, *, @@ -3880,10 +3892,9 @@ def _open_ended_fast_single( system_override=_open_ended_system(), model=model, ) - if not _has_body_lines(reply): - fallback = _fallback_fact_answer(prompt, context) - if fallback: - reply = fallback + fallback = _fallback_fact_answer(prompt, context) + if fallback and (_is_quantitative_prompt(prompt) or not _has_body_lines(reply)): + reply = fallback if state: state.update("done", step=_open_ended_total_steps("fast")) return _ensure_scores(reply)