atlasbot: add conversation state
parent 0e471ecc37
commit 0476edae6a
@@ -20,6 +20,7 @@ class AnswerRequest(BaseModel):
     content: str | None = None
     mode: str | None = None
     history: list[dict[str, str]] | None = None
+    conversation_id: str | None = None


 class AnswerResponse(BaseModel):
@@ -30,7 +31,7 @@ class Api:
     def __init__(
         self,
         settings: Settings,
-        answer_handler: Callable[[str, str, list[dict[str, str]] | None], Awaitable[AnswerResult]],
+        answer_handler: Callable[[str, str, list[dict[str, str]] | None, str | None], Awaitable[AnswerResult]],
     ) -> None:
         self._settings = settings
         self._answer_handler = answer_handler
@@ -53,7 +54,8 @@ class Api:
         if not question:
             raise HTTPException(status_code=400, detail="missing question")
         mode = (payload.mode or "quick").strip().lower()
-        result = await self._answer_handler(question, mode, payload.history)
+        conversation_id = payload.conversation_id
+        result = await self._answer_handler(question, mode, payload.history, conversation_id)
         log.info(
             "answer",
             extra={
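The request model now carries an optional conversation_id, which the API threads through to the answer handler. A minimal sketch of a client call follows; the URL and route are assumptions, as is the idea that the question travels in the content field — only the field names in AnswerRequest are confirmed by the diff:

# Hypothetical client call; endpoint URL and payload values are invented.
import requests

resp = requests.post(
    "http://localhost:8080/answer",
    json={
        "content": "why is node3 under load?",
        "mode": "quick",
        "conversation_id": "room-42",  # reuse the same id so follow-ups hit stored claims
    },
    timeout=60,
)
print(resp.json())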
@@ -53,6 +53,8 @@ class Settings:

     snapshot_ttl_sec: int
     thinking_interval_sec: int
+    conversation_ttl_sec: int
+    snapshot_pin_enabled: bool

     queue_enabled: bool
     nats_url: str
@@ -141,6 +143,8 @@ def load_settings() -> Settings:
         ariadne_state_token=os.getenv("ARIADNE_STATE_TOKEN", ""),
         snapshot_ttl_sec=_env_int("ATLASBOT_SNAPSHOT_TTL_SEC", "30"),
         thinking_interval_sec=_env_int("ATLASBOT_THINKING_INTERVAL_SEC", "30"),
+        conversation_ttl_sec=_env_int("ATLASBOT_CONVERSATION_TTL_SEC", "900"),
+        snapshot_pin_enabled=_env_bool("ATLASBOT_SNAPSHOT_PIN_ENABLED", "false"),
         queue_enabled=_env_bool("ATLASBOT_QUEUE_ENABLED", "false"),
         nats_url=os.getenv("ATLASBOT_NATS_URL", "nats://nats.nats.svc.cluster.local:4222"),
         nats_stream=os.getenv("ATLASBOT_NATS_STREAM", "atlasbot"),
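Both new settings are read from the environment with the defaults shown above. A quick sketch of overriding them locally, assuming _env_bool accepts "true"/"false" as its string values the way the defaults suggest (load_settings and atlasbot.config are confirmed by the diff):

# Hypothetical local override of the two new knobs; values are illustrative.
import os

from atlasbot.config import load_settings

os.environ["ATLASBOT_CONVERSATION_TTL_SEC"] = "1800"   # keep claim state for 30 minutes
os.environ["ATLASBOT_SNAPSHOT_PIN_ENABLED"] = "true"   # pin each conversation to its first snapshot

settings = load_settings()
assert settings.conversation_ttl_sec == 1800
assert settings.snapshot_pin_enabled is True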
@@ -1,4 +1,5 @@
 import asyncio
+import json
 import logging
 import re
 import time
@@ -9,7 +10,7 @@ from atlasbot.config import Settings
 from atlasbot.knowledge.loader import KnowledgeBase
 from atlasbot.llm.client import LLMClient, build_messages, parse_json
 from atlasbot.llm import prompts
-from atlasbot.snapshot.builder import SnapshotProvider, summary_text
+from atlasbot.snapshot.builder import SnapshotProvider, build_summary, summary_text

 log = logging.getLogger(__name__)

@@ -29,6 +30,29 @@ class AnswerResult:
     meta: dict[str, Any]


+@dataclass
+class EvidenceItem:
+    path: str
+    reason: str
+    value: Any | None = None
+    value_at_claim: Any | None = None
+
+
+@dataclass
+class ClaimItem:
+    id: str
+    claim: str
+    evidence: list[EvidenceItem]
+
+
+@dataclass
+class ConversationState:
+    updated_at: float
+    claims: list[ClaimItem]
+    snapshot_id: str | None = None
+    snapshot: dict[str, Any] | None = None
+
+
 class AnswerEngine:
     def __init__(
         self,
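For orientation, one stored entry in the new per-conversation state might look like this; all values are invented, and the dot path mirrors the example given in CLAIM_MAP_PROMPT below:

# Illustrative ConversationState instance, assuming the dataclasses defined above.
import time

state = ConversationState(
    updated_at=time.monotonic(),
    claims=[
        ClaimItem(
            id="c1",
            claim="node3 carries the highest load",
            evidence=[
                EvidenceItem(
                    path="metrics.node_load[0].node",  # dot path into the snapshot summary
                    reason="top entry of node_load",
                    value="node3",
                    value_at_claim="node3",  # frozen at claim time for later delta checks
                )
            ],
        )
    ],
    snapshot_id="2024-05-01T12:00:00Z",
    snapshot=None,  # only populated when snapshot pinning is enabled
)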
@@ -41,6 +65,7 @@ class AnswerEngine:
         self._llm = llm
         self._kb = kb
         self._snapshot = snapshot
+        self._state: dict[str, ConversationState] = {}

     async def answer(
         self,
@@ -49,6 +74,7 @@ class AnswerEngine:
         mode: str,
         history: list[dict[str, str]] | None = None,
         observer: Callable[[str, str], None] | None = None,
+        conversation_id: str | None = None,
     ) -> AnswerResult:
         question = (question or "").strip()
         if not question:
@@ -56,10 +82,15 @@ class AnswerEngine:
         if mode == "stock":
             return await self._answer_stock(question)

+        state = self._get_state(conversation_id)
         snapshot = self._snapshot.get()
+        snapshot_used = snapshot
+        if self._settings.snapshot_pin_enabled and state and state.snapshot:
+            snapshot_used = state.snapshot
+        summary = build_summary(snapshot_used)
         kb_summary = self._kb.summary()
         runbooks = self._kb.runbook_titles(limit=4)
-        snapshot_ctx = summary_text(snapshot)
+        snapshot_ctx = summary_text(snapshot_used)
         history_ctx = _format_history(history)
         base_context = _join_context([
             kb_summary,
@@ -97,15 +128,30 @@ class AnswerEngine:
             "atlasbot_selection",
             extra={"extra": {"mode": mode, "selected": len(best), "scores": scores.__dict__}},
         )
+        if classify.get("follow_up") and state and state.claims:
+            if observer:
+                observer("followup", "answering follow-up")
+            reply = await self._answer_followup(question, state, summary, classify, mode)
+            meta = {
+                "mode": mode,
+                "follow_up": True,
+                "classify": classify,
+            }
+            return AnswerResult(reply, scores, meta)
+
         if observer:
             observer("synthesize", "synthesizing reply")
         reply = await self._synthesize(question, best, base_context, classify, mode)
+        claims = await self._extract_claims(question, reply, summary)
+        if conversation_id and claims:
+            self._store_state(conversation_id, claims, summary, snapshot_used)
         meta = {
             "mode": mode,
             "angles": angles,
             "scores": scores.__dict__,
             "classify": classify,
             "candidates": len(candidates),
+            "claims": len(claims),
         }
         return AnswerResult(reply, scores, meta)

@@ -121,6 +167,8 @@ class AnswerEngine:
         data = _parse_json_block(raw, fallback={"needs_snapshot": True})
         if "answer_style" not in data:
             data["answer_style"] = "direct"
+        if "follow_up_kind" not in data:
+            data["follow_up_kind"] = "other"
         return data

     async def _angles(self, question: str, classify: dict[str, Any], mode: str) -> list[dict[str, Any]]:
@@ -216,6 +264,114 @@ class AnswerEngine:
         refine_messages = build_messages(prompts.CLUSTER_SYSTEM, refine_prompt, context=context)
         return await self._llm.chat(refine_messages, model=model)

+    async def _extract_claims(
+        self,
+        question: str,
+        reply: str,
+        summary: dict[str, Any],
+    ) -> list[ClaimItem]:
+        if not reply or not summary:
+            return []
+        summary_json = _json_excerpt(summary)
+        prompt = prompts.CLAIM_MAP_PROMPT + "\nQuestion: " + question + "\nAnswer: " + reply
+        messages = build_messages(prompts.CLUSTER_SYSTEM, prompt, context=f"SnapshotSummaryJson:{summary_json}")
+        raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
+        data = _parse_json_block(raw, fallback={})
+        claims_raw = data.get("claims") if isinstance(data, dict) else None
+        claims: list[ClaimItem] = []
+        if isinstance(claims_raw, list):
+            for entry in claims_raw:
+                if not isinstance(entry, dict):
+                    continue
+                claim_text = str(entry.get("claim") or "").strip()
+                claim_id = str(entry.get("id") or "").strip() or f"c{len(claims)+1}"
+                evidence_items: list[EvidenceItem] = []
+                for ev in entry.get("evidence") or []:
+                    if not isinstance(ev, dict):
+                        continue
+                    path = str(ev.get("path") or "").strip()
+                    if not path:
+                        continue
+                    reason = str(ev.get("reason") or "").strip()
+                    value = _resolve_path(summary, path)
+                    evidence_items.append(EvidenceItem(path=path, reason=reason, value=value, value_at_claim=value))
+                if claim_text and evidence_items:
+                    claims.append(ClaimItem(id=claim_id, claim=claim_text, evidence=evidence_items))
+        return claims
+
+    async def _answer_followup(
+        self,
+        question: str,
+        state: ConversationState,
+        summary: dict[str, Any],
+        classify: dict[str, Any],
+        mode: str,
+    ) -> str:
+        follow_kind = classify.get("follow_up_kind") if isinstance(classify, dict) else "other"
+        claim_ids = await self._select_claims(question, state.claims)
+        selected = [claim for claim in state.claims if claim.id in claim_ids] if claim_ids else state.claims[:2]
+        evidence_lines = []
+        for claim in selected:
+            evidence_lines.append(f"Claim: {claim.claim}")
+            for ev in claim.evidence:
+                current = _resolve_path(summary, ev.path)
+                ev.value = current
+                delta_note = ""
+                if ev.value_at_claim is not None and current is not None and current != ev.value_at_claim:
+                    delta_note = f" (now {current})"
+                evidence_lines.append(f"- {ev.path}: {ev.value_at_claim}{delta_note} {('- ' + ev.reason) if ev.reason else ''}")
+        evidence_ctx = "\n".join(evidence_lines)
+        prompt = prompts.FOLLOWUP_EVIDENCE_PROMPT
+        if follow_kind in {"next_steps", "change"}:
+            prompt = prompts.FOLLOWUP_ACTION_PROMPT
+        prompt = prompt + "\nFollow-up: " + question + "\nEvidence:\n" + evidence_ctx
+        messages = build_messages(prompts.CLUSTER_SYSTEM, prompt)
+        model = _synthesis_model(self._settings, mode)
+        return await self._llm.chat(messages, model=model)
+
+    async def _select_claims(self, question: str, claims: list[ClaimItem]) -> list[str]:
+        if not claims:
+            return []
+        claims_brief = [{"id": claim.id, "claim": claim.claim} for claim in claims]
+        prompt = prompts.SELECT_CLAIMS_PROMPT + "\nFollow-up: " + question + "\nClaims: " + json.dumps(claims_brief)
+        messages = build_messages(prompts.CLUSTER_SYSTEM, prompt)
+        raw = await self._llm.chat(messages, model=self._settings.ollama_model_fast)
+        data = _parse_json_block(raw, fallback={})
+        ids = data.get("claim_ids") if isinstance(data, dict) else []
+        if isinstance(ids, list):
+            return [str(item) for item in ids if item]
+        return []
+
+    def _get_state(self, conversation_id: str | None) -> ConversationState | None:
+        if not conversation_id:
+            return None
+        self._cleanup_state()
+        return self._state.get(conversation_id)
+
+    def _store_state(
+        self,
+        conversation_id: str,
+        claims: list[ClaimItem],
+        summary: dict[str, Any],
+        snapshot: dict[str, Any] | None,
+    ) -> None:
+        snapshot_id = _snapshot_id(summary)
+        pinned_snapshot = snapshot if self._settings.snapshot_pin_enabled else None
+        self._state[conversation_id] = ConversationState(
+            updated_at=time.monotonic(),
+            claims=claims,
+            snapshot_id=snapshot_id,
+            snapshot=pinned_snapshot,
+        )
+        self._cleanup_state()
+
+    def _cleanup_state(self) -> None:
+        ttl = max(60, self._settings.conversation_ttl_sec)
+        now = time.monotonic()
+        expired = [key for key, state in self._state.items() if now - state.updated_at > ttl]
+        for key in expired:
+            self._state.pop(key, None)
+

 def _join_context(parts: list[str]) -> str:
     text = "\n".join([p for p in parts if p])
@@ -227,12 +383,19 @@ def _format_history(history: list[dict[str, str]] | None) -> str:
         return ""
     lines = ["Recent conversation:"]
     for entry in history[-4:]:
-        question = entry.get("q") if isinstance(entry, dict) else None
-        answer = entry.get("a") if isinstance(entry, dict) else None
+        if not isinstance(entry, dict):
+            continue
+        question = entry.get("q")
+        answer = entry.get("a")
+        role = entry.get("role")
+        content = entry.get("content")
         if question:
             lines.append(f"Q: {question}")
         if answer:
             lines.append(f"A: {answer}")
+        if role and content:
+            prefix = "Q" if role == "user" else "A"
+            lines.append(f"{prefix}: {content}")
     return "\n".join(lines)


@@ -312,3 +475,44 @@ def _coerce_int(value: Any, default: int) -> int:

 def _default_scores() -> AnswerScores:
     return AnswerScores(confidence=60, relevance=60, satisfaction=60, hallucination_risk="medium")
+
+
+def _resolve_path(data: Any, path: str) -> Any | None:
+    cursor = data
+    for part in re.split(r"\.(?![^\[]*\])", path):
+        if not part:
+            continue
+        match = re.match(r"^(\w+)(?:\[(\d+)\])?$", part)
+        if not match:
+            return None
+        key = match.group(1)
+        index = match.group(2)
+        if isinstance(cursor, dict):
+            cursor = cursor.get(key)
+        else:
+            return None
+        if index is not None:
+            try:
+                idx = int(index)
+                if isinstance(cursor, list) and 0 <= idx < len(cursor):
+                    cursor = cursor[idx]
+                else:
+                    return None
+            except ValueError:
+                return None
+    return cursor
+
+
+def _snapshot_id(summary: dict[str, Any]) -> str | None:
+    if not summary:
+        return None
+    for key in ("generated_at", "snapshot_ts", "snapshot_id"):
+        value = summary.get(key)
+        if isinstance(value, str) and value:
+            return value
+    return None
+
+
+def _json_excerpt(summary: dict[str, Any], max_chars: int = 12000) -> str:
+    raw = json.dumps(summary, ensure_ascii=False)
+    return raw[:max_chars]
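_resolve_path walks dot-notation paths with optional bracketed list indexes, the same notation CLAIM_MAP_PROMPT asks the model to emit. A quick illustration against an invented summary dict:

# Invented summary; shows how evidence paths resolve and how misses return None.
summary = {"metrics": {"node_load": [{"node": "node3", "load": 7.2}]}}

assert _resolve_path(summary, "metrics.node_load[0].node") == "node3"
assert _resolve_path(summary, "metrics.node_load[0].load") == 7.2
assert _resolve_path(summary, "metrics.node_load[1].node") is None  # index out of range
assert _resolve_path(summary, "metrics.missing") is None            # absent key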
@@ -16,7 +16,8 @@ CLASSIFY_PROMPT = (
     "needs_snapshot (bool), needs_kb (bool), needs_metrics (bool), "
     "needs_general (bool), intent (short string), ambiguity (0-1), "
     "answer_style (direct|insightful), topic_summary (short string), "
-    "follow_up (bool), question_type (metric|diagnostic|planning|open_ended)."
+    "follow_up (bool), follow_up_kind (evidence|why|clarify|next_steps|change|other), "
+    "question_type (metric|diagnostic|planning|open_ended)."
 )

 ANGLE_PROMPT = (
@@ -54,6 +55,37 @@ REFINE_PROMPT = (
     "Keep it grounded in the context. If you cannot add insight, say so explicitly."
 )
+
+CLAIM_MAP_PROMPT = (
+    "Extract a claim map from the answer. "
+    "Return JSON with fields: claims (list). "
+    "Each claim object: id (short string), claim (short sentence), "
+    "evidence (list of objects with path and reason). "
+    "Paths must point into the provided SnapshotSummary JSON using dot notation, "
+    "with list indexes in brackets, e.g. metrics.node_load[0].node. "
+    "Do not invent evidence; if no evidence exists, omit the claim."
+)
+
+SELECT_CLAIMS_PROMPT = (
+    "Pick which prior claim(s) the follow-up refers to. "
+    "Return JSON with fields: claim_ids (list of ids), follow_up_kind "
+    "(evidence|why|clarify|next_steps|change|other). "
+    "If none apply, return an empty list."
+)
+
+FOLLOWUP_EVIDENCE_PROMPT = (
+    "Answer the follow-up using only the provided claims and evidence. "
+    "Be conversational, not bullet-heavy. "
+    "If evidence does not support a claim, say so plainly. "
+    "Do not add new claims."
+)
+
+FOLLOWUP_ACTION_PROMPT = (
+    "Answer the follow-up using the provided claims and evidence. "
+    "You may suggest next steps or changes, but keep them tightly tied "
+    "to the evidence list. "
+    "Be conversational and concise."
+)
+
 STOCK_SYSTEM = (
     "You are Atlas, a helpful assistant. "
     "Be concise and truthful. "
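For context, CLAIM_MAP_PROMPT is written to elicit JSON of the shape below, which _extract_claims then validates entry by entry, keeping only claims that carry at least one evidence path (example values invented):

# Invented example of a well-formed claim-map response; only the field names
# (claims, id, claim, evidence, path, reason) come from the prompt itself.
claim_map = {
    "claims": [
        {
            "id": "c1",
            "claim": "node3 carries the highest load",
            "evidence": [
                {"path": "metrics.node_load[0].node", "reason": "top entry of node_load"}
            ],
        }
    ]
}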
@@ -31,22 +31,24 @@ async def main() -> None:

     async def handler(payload: dict[str, object]) -> dict[str, object]:
         history = payload.get("history") if isinstance(payload, dict) else None
+        conversation_id = payload.get("conversation_id") if isinstance(payload, dict) else None
         result = await engine.answer(
             str(payload.get("question", "") or ""),
             mode=str(payload.get("mode", "quick") or "quick"),
             history=history if isinstance(history, list) else None,
+            conversation_id=str(conversation_id) if isinstance(conversation_id, str) else None,
         )
         return {"reply": result.reply, "scores": result.scores.__dict__}

     queue = QueueManager(settings, handler)
     await queue.start()

-    async def answer_handler(question: str, mode: str, history=None, observer=None) -> AnswerResult:
+    async def answer_handler(question: str, mode: str, history=None, conversation_id=None, observer=None) -> AnswerResult:
         if settings.queue_enabled:
-            payload = await queue.submit({"question": question, "mode": mode, "history": history or []})
+            payload = await queue.submit({"question": question, "mode": mode, "history": history or [], "conversation_id": conversation_id})
             reply = payload.get("reply", "") if isinstance(payload, dict) else ""
             return AnswerResult(reply=reply or "", scores=result_scores(payload), meta={"mode": mode})
-        return await engine.answer(question, mode=mode, history=history, observer=observer)
+        return await engine.answer(question, mode=mode, history=history, observer=observer, conversation_id=conversation_id)

     api = Api(settings, answer_handler)
     server = uvicorn.Server(uvicorn.Config(api.app, host="0.0.0.0", port=settings.http_port, log_level="info"))
@@ -80,7 +80,10 @@ class MatrixBot:
         settings: Settings,
         bot: MatrixBotConfig,
         engine: AnswerEngine,
-        answer_handler: Callable[[str, str, list[dict[str, str]] | None, Callable[[str, str], None] | None], Awaitable[AnswerResult]]
+        answer_handler: Callable[
+            [str, str, list[dict[str, str]] | None, str | None, Callable[[str, str], None] | None],
+            Awaitable[AnswerResult],
+        ]
         | None = None,
     ) -> None:
         self._settings = settings
@@ -155,9 +158,11 @@ class MatrixBot:
         task = asyncio.create_task(heartbeat())
         started = time.monotonic()
         try:
-            handler = self._answer_handler or (lambda q, m, h, obs: self._engine.answer(q, mode=m, history=h, observer=obs))
+            handler = self._answer_handler or (
+                lambda q, m, h, cid, obs: self._engine.answer(q, mode=m, history=h, observer=obs, conversation_id=cid)
+            )
             history = self._history.get(room_id, [])
-            result = await handler(question, mode, history, observer)
+            result = await handler(question, mode, history, room_id, observer)
             elapsed = time.monotonic() - started
             await self._client.send_message(token, room_id, result.reply)
             log.info(
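Design note: the Matrix path passes room_id as the conversation id, so claim state is shared by everyone in a room and expires on the same ATLASBOT_CONVERSATION_TTL_SEC clock as API conversations.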