diff --git a/frigate/api/chat.py b/frigate/api/chat.py index 4e6bdbd3b4..f523d9991c 100644 --- a/frigate/api/chat.py +++ b/frigate/api/chat.py @@ -7,7 +7,7 @@ import operator import time from datetime import datetime from functools import reduce -from typing import Any, Dict, List, Optional +from typing import Any, Optional import cv2 from fastapi import APIRouter, Body, Depends, HTTPException, Request @@ -59,7 +59,7 @@ class ToolExecuteRequest(BaseModel): """Request model for tool execution.""" tool_name: str - arguments: Dict[str, Any] + arguments: dict[str, Any] class VLMMonitorRequest(BaseModel): @@ -68,8 +68,8 @@ class VLMMonitorRequest(BaseModel): camera: str condition: str max_duration_minutes: int = 60 - labels: List[str] = [] - zones: List[str] = [] + labels: list[str] = [] + zones: list[str] = [] @router.get( @@ -91,10 +91,10 @@ def get_tools(request: Request) -> JSONResponse: def _resolve_zones( - zones: List[str], + zones: list[str], config: FrigateConfig, - target_cameras: List[str], -) -> List[str]: + target_cameras: list[str], +) -> list[str]: """Map zone names to their canonical config keys, case-insensitively. LLMs frequently echo a user's casing ("Front Yard") instead of the @@ -107,7 +107,7 @@ def _resolve_zones( if not zones: return zones - lookup: Dict[str, str] = {} + lookup: dict[str, str] = {} for camera_id in target_cameras: camera_config = config.cameras.get(camera_id) if camera_config is None: @@ -120,8 +120,8 @@ def _resolve_zones( async def _execute_search_objects( request: Request, - arguments: Dict[str, Any], - allowed_cameras: List[str], + arguments: dict[str, Any], + allowed_cameras: list[str], ) -> JSONResponse: """ Execute the search_objects tool. @@ -213,8 +213,8 @@ async def _execute_search_objects( async def _execute_search_objects_semantic( request: Request, - arguments: Dict[str, Any], - allowed_cameras: List[str], + arguments: dict[str, Any], + allowed_cameras: list[str], semantic_query: str, ) -> JSONResponse: """Search objects via fused thumbnail + description embeddings. @@ -263,8 +263,8 @@ async def _execute_search_objects_semantic( limit = int(arguments.get("limit", 25)) limit = max(1, min(limit, 100)) - visual_distances: Dict[str, float] = {} - description_distances: Dict[str, float] = {} + visual_distances: dict[str, float] = {} + description_distances: dict[str, float] = {} try: rows = context.search_thumbnail(semantic_query) visual_distances = {row[0]: row[1] for row in rows} @@ -305,7 +305,7 @@ async def _execute_search_objects_semantic( eligible = {e.id: e for e in Event.select().where(reduce(operator.and_, clauses))} - scored: List[tuple[str, float]] = [] + scored: list[tuple[str, float]] = [] for eid in eligible: v_score = ( distance_to_score(visual_distances[eid], context.thumb_stats) @@ -331,9 +331,9 @@ async def _execute_search_objects_semantic( async def _execute_find_similar_objects( request: Request, - arguments: Dict[str, Any], - allowed_cameras: List[str], -) -> Dict[str, Any]: + arguments: dict[str, Any], + allowed_cameras: list[str], +) -> dict[str, Any]: """Execute the find_similar_objects tool. Returns a plain dict (not JSONResponse) so the chat loop can embed it @@ -403,8 +403,8 @@ async def _execute_find_similar_objects( # version (see frigate/embeddings/__init__.py). Mirror the pattern used by # frigate/api/event.py events_search: fetch top-k globally, then intersect # with the structured filters via Peewee. - visual_distances: Dict[str, float] = {} - description_distances: Dict[str, float] = {} + visual_distances: dict[str, float] = {} + description_distances: dict[str, float] = {} try: if similarity_mode in ("visual", "fused"): @@ -462,7 +462,7 @@ async def _execute_find_similar_objects( eligible = {e.id: e for e in Event.select().where(reduce(operator.and_, clauses))} # 6. Fuse and rank. - scored: List[tuple[str, float]] = [] + scored: list[tuple[str, float]] = [] for eid in eligible: v_score = ( distance_to_score(visual_distances[eid], context.thumb_stats) @@ -503,7 +503,7 @@ async def _execute_find_similar_objects( async def execute_tool( request: Request, body: ToolExecuteRequest = Body(...), - allowed_cameras: List[str] = Depends(get_allowed_cameras_for_filter), + allowed_cameras: list[str] = Depends(get_allowed_cameras_for_filter), ) -> JSONResponse: """ Execute a tool function call. @@ -545,8 +545,8 @@ async def execute_tool( async def _execute_get_live_context( request: Request, camera: str, - allowed_cameras: List[str], -) -> Dict[str, Any]: + allowed_cameras: list[str], +) -> dict[str, Any]: # Reject wildcards explicitly so models retry with a real camera name # instead of silently fanning out across every camera. if camera in ("*", "all"): @@ -593,7 +593,7 @@ async def _execute_get_live_context( "stationary": obj_dict.get("stationary", False), } - result: Dict[str, Any] = { + result: dict[str, Any] = { "camera": camera, "timestamp": frame_time, "detections": list(tracked_objects_dict.values()), @@ -620,7 +620,7 @@ async def _execute_get_live_context( async def _get_live_frame_image_url( request: Request, camera: str, - allowed_cameras: List[str], + allowed_cameras: list[str], ) -> Optional[str]: """ Fetch the current live frame for a camera as a base64 data URL. @@ -659,8 +659,8 @@ async def _get_live_frame_image_url( async def _execute_set_camera_state( request: Request, - arguments: Dict[str, Any], -) -> Dict[str, Any]: + arguments: dict[str, Any], +) -> dict[str, Any]: role = request.headers.get("remote-role", "") if "admin" not in [r.strip() for r in role.split(",")]: return {"error": "Admin privileges required to change camera settings."} @@ -699,10 +699,10 @@ async def _execute_set_camera_state( async def _execute_tool_internal( tool_name: str, - arguments: Dict[str, Any], + arguments: dict[str, Any], request: Request, - allowed_cameras: List[str], -) -> Dict[str, Any]: + allowed_cameras: list[str], +) -> dict[str, Any]: """ Internal helper to execute a tool and return the result as a dict. @@ -763,8 +763,8 @@ async def _execute_tool_internal( async def _execute_start_camera_watch( request: Request, - arguments: Dict[str, Any], -) -> Dict[str, Any]: + arguments: dict[str, Any], +) -> dict[str, Any]: camera = arguments.get("camera", "").strip() condition = arguments.get("condition", "").strip() max_duration_minutes = int(arguments.get("max_duration_minutes", 60)) @@ -814,14 +814,14 @@ async def _execute_start_camera_watch( } -def _execute_stop_camera_watch() -> Dict[str, Any]: +def _execute_stop_camera_watch() -> dict[str, Any]: cancelled = stop_vlm_watch_job() if cancelled: return {"success": True, "message": "Watch job cancelled."} return {"success": False, "message": "No active watch job to cancel."} -def _execute_get_profile_status(request: Request) -> Dict[str, Any]: +def _execute_get_profile_status(request: Request) -> dict[str, Any]: """Return profile status including active profile and activation timestamps.""" profile_manager = getattr(request.app, "profile_manager", None) if profile_manager is None: @@ -846,9 +846,9 @@ def _execute_get_profile_status(request: Request) -> Dict[str, Any]: def _execute_get_recap( - arguments: Dict[str, Any], - allowed_cameras: List[str], -) -> Dict[str, Any]: + arguments: dict[str, Any], + allowed_cameras: list[str], +) -> dict[str, Any]: """Fetch review segments with GenAI metadata for a time period.""" from functools import reduce @@ -909,7 +909,7 @@ def _execute_get_recap( .iterator() ) - events: List[Dict[str, Any]] = [] + events: list[dict[str, Any]] = [] for row in rows: data = row.get("data") or {} @@ -920,7 +920,7 @@ def _execute_get_recap( data = {} camera = row["camera"] - event: Dict[str, Any] = { + event: dict[str, Any] = { "camera": camera.replace("_", " ").title(), "severity": row.get("severity", "detection"), } @@ -984,10 +984,10 @@ def _execute_get_recap( async def _execute_pending_tools( - pending_tool_calls: List[Dict[str, Any]], + pending_tool_calls: list[dict[str, Any]], request: Request, - allowed_cameras: List[str], -) -> tuple[List[ToolCall], List[Dict[str, Any]], List[Dict[str, Any]]]: + allowed_cameras: list[str], +) -> tuple[list[ToolCall], list[dict[str, Any]], list[dict[str, Any]]]: """ Execute a list of tool calls. @@ -996,9 +996,9 @@ async def _execute_pending_tools( tool result dicts for conversation, extra messages to inject after tool results — e.g. user messages with images) """ - tool_calls_out: List[ToolCall] = [] - tool_results: List[Dict[str, Any]] = [] - extra_messages: List[Dict[str, Any]] = [] + tool_calls_out: list[ToolCall] = [] + tool_results: list[dict[str, Any]] = [] + extra_messages: list[dict[str, Any]] = [] for tool_call in pending_tool_calls: tool_name = tool_call["name"] tool_args = tool_call.get("arguments") or {} @@ -1106,7 +1106,7 @@ async def _execute_pending_tools( async def chat_completion( request: Request, body: ChatCompletionRequest = Body(...), - allowed_cameras: List[str] = Depends(get_allowed_cameras_for_filter), + allowed_cameras: list[str] = Depends(get_allowed_cameras_for_filter), ): """ Chat completion endpoint with tool calling support. @@ -1138,19 +1138,23 @@ async def chat_completion( ) conversation = [] - system_prompt = build_chat_system_prompt( - config=config, - allowed_cameras=allowed_cameras, - semantic_search_enabled=semantic_search_enabled, - attribute_classifications=attribute_classifications, - ) - - conversation.append( - { - "role": "system", - "content": system_prompt, - } - ) + # Build the system message only when the client hasn't already pinned one. + # The first turn has no system message; we generate it (with the current + # timestamp) and return the whole chain so the client persists it. Later + # turns send it back verbatim, freezing the timestamp so the prompt prefix + # stays byte-identical and the model server's prompt cache keeps hitting. + if not body.messages or body.messages[0].role != "system": + conversation.append( + { + "role": "system", + "content": build_chat_system_prompt( + config=config, + allowed_cameras=allowed_cameras, + semantic_search_enabled=semantic_search_enabled, + attribute_classifications=attribute_classifications, + ), + } + ) for msg in body.messages: msg_dict = { @@ -1161,11 +1165,13 @@ async def chat_completion( msg_dict["tool_call_id"] = msg.tool_call_id if msg.name: msg_dict["name"] = msg.name + if msg.tool_calls is not None: + msg_dict["tool_calls"] = msg.tool_calls conversation.append(msg_dict) tool_iterations = 0 - tool_calls: List[ToolCall] = [] + tool_calls: list[ToolCall] = [] max_iterations = body.max_tool_iterations logger.debug( @@ -1175,11 +1181,20 @@ async def chat_completion( # True LLM streaming when client supports it and stream requested if body.stream and hasattr(genai_client, "chat_with_tools_stream"): - stream_tool_calls: List[ToolCall] = [] stream_iterations = 0 async def stream_body_llm(): - nonlocal conversation, stream_tool_calls, stream_iterations + nonlocal conversation, stream_iterations + + def _emit_chain(extra: Optional[list[dict[str, Any]]] = None): + # Return the full conversation (including the system message) so + # the client persists and replays it verbatim next turn. + chain = conversation + (extra or []) + return ( + json.dumps({"type": "messages", "messages": chain}).encode("utf-8") + + b"\n" + ) + while stream_iterations < max_iterations: if await request.is_disconnected(): logger.debug("Client disconnected, stopping chat stream") @@ -1244,31 +1259,33 @@ async def chat_completion( ) return ( - executed_calls, + _executed_calls, tool_results, extra_msgs, ) = await _execute_pending_tools( pending, request, allowed_cameras ) - stream_tool_calls.extend(executed_calls) conversation.extend(tool_results) conversation.extend(extra_msgs) - yield ( - json.dumps( - { - "type": "tool_calls", - "tool_calls": [ - tc.model_dump() for tc in stream_tool_calls - ], - } - ).encode("utf-8") - + b"\n" - ) + # Emit the running chain so the client can render tool + # calls live and replay them verbatim next turn. + yield _emit_chain() break else: + # Streaming never appends the final assistant message + # to the conversation, so add it to the chain. + yield _emit_chain( + extra=[ + { + "role": "assistant", + "content": msg.get("content"), + } + ] + ) yield (json.dumps({"type": "done"}).encode("utf-8") + b"\n") return else: + yield _emit_chain() yield json.dumps({"type": "done"}).encode("utf-8") + b"\n" return StreamingResponse( @@ -1315,19 +1332,15 @@ async def chat_completion( if body.stream: final_reasoning = response.get("reasoning") + chain = list(conversation) + async def stream_body() -> Any: - if tool_calls: - yield ( - json.dumps( - { - "type": "tool_calls", - "tool_calls": [ - tc.model_dump() for tc in tool_calls - ], - } - ).encode("utf-8") - + b"\n" + yield ( + json.dumps({"type": "messages", "messages": chain}).encode( + "utf-8" ) + + b"\n" + ) # Emit the full reasoning trace up front when the # underlying client did not stream it if final_reasoning: @@ -1363,6 +1376,7 @@ async def chat_completion( finish_reason=response.get("finish_reason", "stop"), tool_iterations=tool_iterations, tool_calls=tool_calls, + messages=list(conversation), ).model_dump(), ) @@ -1395,6 +1409,7 @@ async def chat_completion( finish_reason="length", tool_iterations=tool_iterations, tool_calls=tool_calls, + messages=list(conversation), ).model_dump(), ) diff --git a/frigate/api/defs/request/chat_body.py b/frigate/api/defs/request/chat_body.py index 228781c80b..04b168b9fa 100644 --- a/frigate/api/defs/request/chat_body.py +++ b/frigate/api/defs/request/chat_body.py @@ -1,6 +1,6 @@ """Chat API request models.""" -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel, Field @@ -11,13 +11,29 @@ class ChatMessage(BaseModel): role: str = Field( description="Message role: 'user', 'assistant', 'system', or 'tool'" ) - content: str = Field(description="Message content") + content: Optional[Any] = Field( + default=None, + description=( + "Message content. Usually a string, but may be a multimodal content " + "list (e.g. text + image_url) or null for assistant turns that only " + "request tool calls." + ), + ) tool_call_id: Optional[str] = Field( default=None, description="For tool messages, the ID of the tool call" ) name: Optional[str] = Field( default=None, description="For tool messages, the tool name" ) + tool_calls: Optional[list[dict[str, Any]]] = Field( + default=None, + description=( + "For assistant messages replayed from prior turns, the OpenAI-format " + "tool calls the model previously requested. Replaying these verbatim " + "keeps the conversation prefix byte-for-byte identical so the model " + "server's prompt cache hits on follow-up turns." + ), + ) class ChatCompletionRequest(BaseModel): diff --git a/frigate/api/defs/response/chat_response.py b/frigate/api/defs/response/chat_response.py index c2b3e6b1f2..59c8549e73 100644 --- a/frigate/api/defs/response/chat_response.py +++ b/frigate/api/defs/response/chat_response.py @@ -56,3 +56,12 @@ class ChatCompletionResponse(BaseModel): default_factory=list, description="List of tool calls that were executed during this completion", ) + messages: list[dict[str, Any]] = Field( + default_factory=list, + description=( + "The full conversation chain, including the system message. Persist " + "and replay this verbatim on the next request so the prompt prefix " + "stays byte-identical and the model server's prompt cache keeps " + "hitting." + ), + ) diff --git a/web/.gitignore b/web/.gitignore index 1cac5597ea..ca98c7b96a 100644 --- a/web/.gitignore +++ b/web/.gitignore @@ -12,6 +12,10 @@ dist dist-ssr *.local +# Playwright +playwright-report +test-results + # Editor directories and files .vscode/* !.vscode/extensions.json diff --git a/web/e2e/specs/chat.spec.ts b/web/e2e/specs/chat.spec.ts index fcd10782c1..cab49828ab 100644 --- a/web/e2e/specs/chat.spec.ts +++ b/web/e2e/specs/chat.spec.ts @@ -92,6 +92,15 @@ test.describe("Chat — streaming @medium", () => { await installChatStreamOverride(frigateApp, [ { type: "content", delta: "Hel" }, { type: "content", delta: "lo" }, + { + type: "messages", + messages: [ + { role: "system", content: "sys" }, + { role: "user", content: "hello chat" }, + { role: "assistant", content: "Hello" }, + ], + }, + { type: "done" }, ]); await frigateApp.goto("/chat"); const input = frigateApp.page.getByPlaceholder(/ask/i); @@ -137,6 +146,15 @@ test.describe("Chat — streaming @medium", () => { { type: "content", delta: "Hel" }, { type: "content", delta: "lo, " }, { type: "content", delta: "world!" }, + { + type: "messages", + messages: [ + { role: "system", content: "sys" }, + { role: "user", content: "greet me" }, + { role: "assistant", content: "Hello, world!" }, + ], + }, + { type: "done" }, ], { chunkDelayMs: 50 }, ); @@ -151,19 +169,39 @@ test.describe("Chat — streaming @medium", () => { }); }); - test("tool_calls chunks render a ToolCallsGroup", async ({ frigateApp }) => { - await installChatStreamOverride(frigateApp, [ + test("tool calls in the chain render a ToolCallsGroup", async ({ + frigateApp, + }) => { + const toolTurn = [ + { role: "system", content: "sys" }, + { role: "user", content: "find people" }, { - type: "tool_calls", + role: "assistant", + content: null, tool_calls: [ { id: "call_1", - name: "search_objects", - arguments: { label: "person" }, + type: "function", + function: { + name: "search_objects", + arguments: '{"label":"person"}', + }, }, ], }, + { role: "tool", tool_call_id: "call_1", content: "[]" }, + ]; + await installChatStreamOverride(frigateApp, [ + { type: "messages", messages: toolTurn }, { type: "content", delta: "Searching for people." }, + { + type: "messages", + messages: [ + ...toolTurn, + { role: "assistant", content: "Searching for people." }, + ], + }, + { type: "done" }, ]); await frigateApp.goto("/chat"); const input = frigateApp.page.getByPlaceholder(/ask/i); @@ -253,6 +291,15 @@ test.describe("Chat — attachment chip @medium", () => { // We use the stream override so the first message completes quickly. await installChatStreamOverride(frigateApp, [ { type: "content", delta: "Done." }, + { + type: "messages", + messages: [ + { role: "system", content: "sys" }, + { role: "user", content: "hello" }, + { role: "assistant", content: "Done." }, + ], + }, + { type: "done" }, ]); await frigateApp.goto("/chat"); diff --git a/web/src/pages/Chat.tsx b/web/src/pages/Chat.tsx index 7103a189d1..bd92097873 100644 --- a/web/src/pages/Chat.tsx +++ b/web/src/pages/Chat.tsx @@ -13,6 +13,7 @@ import { ChatComposer } from "@/components/chat/ChatComposer"; import ChatSettings from "@/components/chat/ChatSettings"; import type { ChatMessage, + ChatStats, GenAIModelsResponse, ShowStatsMode, } from "@/types/chat"; @@ -22,12 +23,28 @@ import { getFindSimilarObjectsFromToolCalls, prependAttachment, streamChatCompletion, + toolCallsForMessage, + toolResponsesById, } from "@/utils/chatUtil"; +type StreamingTurn = { + content: string; + reasoning: string; + chain: ChatMessage[]; + stats?: ChatStats; +}; + +const hasText = (content: unknown): content is string => + typeof content === "string" && content.trim().length > 0; + +const toWire = (messages: ChatMessage[]): ChatMessage[] => + messages.map(({ reasoning: _r, stats: _s, ...rest }) => rest); + export default function ChatPage() { const { t } = useTranslation(["views/chat"]); const [input, setInput] = useState(""); const [messages, setMessages] = useState([]); + const [streaming, setStreaming] = useState(null); const [isLoading, setIsLoading] = useState(false); const [error, setError] = useState(null); const [attachedEventId, setAttachedEventId] = useState(null); @@ -72,28 +89,19 @@ export default function ChatPage() { if (isNearBottom) { el.scrollTo({ top: el.scrollHeight, behavior: "smooth" }); } - }, [messages, autoScroll]); + }, [messages, streaming, autoScroll]); const submitConversation = useCallback( async (messagesToSend: ChatMessage[]) => { if (isLoading) return; const last = messagesToSend[messagesToSend.length - 1]; - if (!last || last.role !== "user" || !last.content.trim()) return; + if (!last || last.role !== "user" || !hasText(last.content)) return; setError(null); - const assistantPlaceholder: ChatMessage = { - role: "assistant", - content: "", - toolCalls: undefined, - }; - setMessages([...messagesToSend, assistantPlaceholder]); + setMessages(messagesToSend); + setStreaming({ content: "", reasoning: "", chain: [] }); setIsLoading(true); - const apiMessages = messagesToSend.map((m) => ({ - role: m.role, - content: m.content, - })); - const baseURL = axios.defaults.baseURL ?? ""; const url = `${baseURL}chat/completion`; const headers: Record = { @@ -104,16 +112,50 @@ export default function ChatPage() { const controller = new AbortController(); abortRef.current = controller; + let chain: ChatMessage[] = []; + let stats: ChatStats | undefined; + let reasoning = ""; + let hadError = false; + await streamChatCompletion( url, headers, - apiMessages, + toWire(messagesToSend), { - updateMessages: (updater) => setMessages(updater), - onError: (message) => setError(message), + onContentDelta: (delta) => + setStreaming((s) => (s ? { ...s, content: s.content + delta } : s)), + onReasoningDelta: (delta) => { + reasoning += delta; + setStreaming((s) => + s ? { ...s, reasoning: s.reasoning + delta } : s, + ); + }, + onChain: (fullChain) => { + chain = fullChain; + setStreaming((s) => (s ? { ...s, chain: fullChain } : s)); + }, + onStats: (s) => { + stats = s; + setStreaming((cur) => (cur ? { ...cur, stats: s } : cur)); + }, + onError: (message) => { + hadError = true; + setError(message); + }, onDone: () => { abortRef.current = null; setIsLoading(false); + setStreaming(null); + const lastMsg = chain[chain.length - 1]; + if (!hadError && lastMsg?.role === "assistant") { + setMessages( + chain.map((m, i) => + i === chain.length - 1 + ? { ...m, reasoning: reasoning || undefined, stats } + : m, + ), + ); + } }, defaultErrorMessage: t("error"), }, @@ -125,12 +167,14 @@ export default function ChatPage() { ); const recentEventIds = useMemo(() => { + const responses = toolResponsesById(messages); for (let i = messages.length - 1; i >= 0; i--) { const msg = messages[i]; - if (msg.role !== "assistant" || !msg.toolCalls) continue; - const similar = getFindSimilarObjectsFromToolCalls(msg.toolCalls); + if (msg.role !== "assistant" || !msg.tool_calls?.length) continue; + const calls = toolCallsForMessage(msg, responses); + const similar = getFindSimilarObjectsFromToolCalls(calls); if (similar) return similar.results.map((e) => e.id); - const events = getEventIdsFromSearchObjectsToolCalls(msg.toolCalls); + const events = getEventIdsFromSearchObjectsToolCalls(calls); if (events.length > 0) return events.map((e) => e.id); } return []; @@ -154,12 +198,14 @@ export default function ChatPage() { abortRef.current?.abort(); abortRef.current = null; setIsLoading(false); + setStreaming(null); }, []); const startNewChat = useCallback(() => { abortRef.current?.abort(); abortRef.current = null; setIsLoading(false); + setStreaming(null); setMessages([]); setInput(""); setAttachedEventId(null); @@ -181,7 +227,83 @@ export default function ChatPage() { setAttachedEventId(null); }, []); - const hasStarted = messages.length > 0; + const hasStarted = messages.length > 0 || streaming != null; + + // While streaming, the backend's in-flight chain is the source of truth; + // otherwise the committed conversation is. + const renderList = + streaming && streaming.chain.length ? streaming.chain : messages; + const responses = toolResponsesById(renderList); + const renderTail = renderList[renderList.length - 1]; + const finalShown = + renderTail?.role === "assistant" && hasText(renderTail.content); + + const renderMessage = (msg: ChatMessage, i: number) => { + if (msg.role === "system" || msg.role === "tool") return null; + + if (msg.role === "user") { + if (!hasText(msg.content)) return null; + return ( +
+ +
+ ); + } + + const calls = toolCallsForMessage(msg, responses); + const contentText = hasText(msg.content) ? msg.content : ""; + const similar = getFindSimilarObjectsFromToolCalls(calls); + const events = similar ? [] : getEventIdsFromSearchObjectsToolCalls(calls); + + return ( +
+ {calls.length > 0 && } + {hasText(msg.reasoning) && ( + + )} + {contentText && ( + + )} + {similar ? ( + + ) : ( + + )} +
+ ); + }; + + const processingDots = ( +
+ + + +
+ ); return (
@@ -212,102 +334,31 @@ export default function ChatPage() {
{hasStarted ? (
- {messages.map((msg, i) => { - const isLastAssistant = - i === messages.length - 1 && msg.role === "assistant"; - const isComplete = - msg.role === "user" || !isLoading || !isLastAssistant; - const hasToolCalls = - msg.toolCalls && msg.toolCalls.length > 0; - const hasContent = !!msg.content?.trim(); - const hasReasoning = !!msg.reasoning?.trim(); - const showProcessing = - isLastAssistant && - isLoading && - !hasContent && - !hasReasoning; - - // Hide empty placeholder only when there are no tool calls - // and no reasoning streaming yet - if ( - isLastAssistant && - isLoading && - !hasContent && - !hasToolCalls && - !hasReasoning - ) - return ( -
- - - -
- ); - - return ( -
- {msg.role === "assistant" && hasToolCalls && ( - - )} - {msg.role === "assistant" && hasReasoning && ( + {renderList.map((msg, i) => renderMessage(msg, i))} + {streaming && + !finalShown && + (streaming.content || streaming.reasoning ? ( +
+ {hasText(streaming.reasoning) && ( )} - {showProcessing ? ( -
- - - -
- ) : msg.role === "assistant" && - !hasContent && - hasReasoning && - !isComplete ? null : ( + {streaming.content && ( )} - {msg.role === "assistant" && - isComplete && - (() => { - const similar = getFindSimilarObjectsFromToolCalls( - msg.toolCalls, - ); - if (similar) { - return ( - - ); - } - const events = getEventIdsFromSearchObjectsToolCalls( - msg.toolCalls, - ); - return ( - - ); - })()}
- ); - })} + ) : ( + processingDots + ))} {error && (

; response?: string; }; -export type ChatMessage = { - role: "user" | "assistant"; - content: string; - reasoning?: string; - toolCalls?: ToolCall[]; - stats?: ChatStats; -}; - export type StartingRequest = { label: string; prompt: string; diff --git a/web/src/utils/chatUtil.ts b/web/src/utils/chatUtil.ts index 73e5c213b6..b7aeb8088d 100644 --- a/web/src/utils/chatUtil.ts +++ b/web/src/utils/chatUtil.ts @@ -1,16 +1,20 @@ import type { ChatMessage, ChatStats, ToolCall } from "@/types/chat"; export type StreamChatCallbacks = { - /** Update the messages array (e.g. pass to setState). */ - updateMessages: (updater: (prev: ChatMessage[]) => ChatMessage[]) => void; + /** Streamed delta of the assistant's final answer text. */ + onContentDelta: (delta: string) => void; + /** Streamed delta of the assistant's reasoning trace. */ + onReasoningDelta: (delta: string) => void; + /** The full conversation chain so far (system message, history, this turn's + * tool-call turns, tool results, and — on the final emission — the final + * assistant message). */ + onChain: (chain: ChatMessage[]) => void; + /** Token/timing stats for the turn. */ + onStats: (stats: ChatStats) => void; /** Called when the stream sends an error or fetch fails. */ onError: (message: string) => void; /** Called when the stream finishes (success or error). */ onDone: () => void; - /** Called when the stream emits token/timing stats. The stats are also - * attached to the last assistant message in updateMessages, so consumers - * can usually rely on the message itself rather than wiring this up. */ - onStats?: (stats: ChatStats) => void; /** Message used when fetch throws and no server error is available. */ defaultErrorMessage?: string; }; @@ -25,7 +29,7 @@ type StatsChunk = { type StreamChunk = | { type: "error"; error: string } - | { type: "tool_calls"; tool_calls: ToolCall[] } + | { type: "messages"; messages: ChatMessage[] } | { type: "content"; delta: string } | { type: "reasoning"; delta: string } | StatsChunk; @@ -41,16 +45,18 @@ export type StreamChatOptions = { export async function streamChatCompletion( url: string, headers: Record, - apiMessages: { role: string; content: string }[], + apiMessages: ChatMessage[], callbacks: StreamChatCallbacks, signal?: AbortSignal, options: StreamChatOptions = {}, ): Promise { const { - updateMessages, + onContentDelta, + onReasoningDelta, + onChain, + onStats, onError, onDone, - onStats, defaultErrorMessage = "Something went wrong. Please try again.", } = callbacks; @@ -91,65 +97,27 @@ export async function streamChatCompletion( const applyChunk = (data: StreamChunk) => { if (data.type === "error") { onError(data.error); - updateMessages((prev) => - prev.filter((m) => !(m.role === "assistant" && m.content === "")), - ); return "break"; } - if (data.type === "tool_calls" && data.tool_calls?.length) { - updateMessages((prev) => { - const next = [...prev]; - const lastMsg = next[next.length - 1]; - if (lastMsg?.role === "assistant") - next[next.length - 1] = { - ...lastMsg, - toolCalls: data.tool_calls, - }; - return next; - }); + if (data.type === "messages") { + onChain(data.messages ?? []); return "continue"; } if (data.type === "content" && data.delta !== undefined) { - updateMessages((prev) => { - const next = [...prev]; - const lastMsg = next[next.length - 1]; - if (lastMsg?.role === "assistant") - next[next.length - 1] = { - ...lastMsg, - content: lastMsg.content + data.delta, - }; - return next; - }); + onContentDelta(data.delta); return "continue"; } if (data.type === "reasoning" && data.delta !== undefined) { - updateMessages((prev) => { - const next = [...prev]; - const lastMsg = next[next.length - 1]; - if (lastMsg?.role === "assistant") - next[next.length - 1] = { - ...lastMsg, - reasoning: (lastMsg.reasoning ?? "") + data.delta, - }; - return next; - }); + onReasoningDelta(data.delta); return "continue"; } if (data.type === "stats") { - const stats: ChatStats = { + onStats({ promptTokens: data.prompt_tokens, completionTokens: data.completion_tokens, completionDurationMs: data.completion_duration_ms, tokensPerSecond: data.tokens_per_second, - }; - updateMessages((prev) => { - const next = [...prev]; - const lastMsg = next[next.length - 1]; - if (lastMsg?.role === "assistant") - next[next.length - 1] = { ...lastMsg, stats }; - return next; }); - onStats?.(stats); return "continue"; } return "continue"; @@ -165,9 +133,8 @@ export async function streamChatCompletion( const trimmed = line.trim(); if (!trimmed) continue; try { - const data = JSON.parse(trimmed) as StreamChunk & { type: string }; - const result = applyChunk(data as StreamChunk); - if (result === "break") { + const data = JSON.parse(trimmed) as StreamChunk; + if (applyChunk(data) === "break") { hadStreamError = true; break; } @@ -181,50 +148,63 @@ export async function streamChatCompletion( // Flush remaining buffer if (!hadStreamError && buffer.trim()) { try { - const data = JSON.parse(buffer.trim()) as StreamChunk & { - type: string; - delta?: string; - }; - if (data.type === "content" && data.delta !== undefined) { - updateMessages((prev) => { - const next = [...prev]; - const lastMsg = next[next.length - 1]; - if (lastMsg?.role === "assistant") - next[next.length - 1] = { - ...lastMsg, - content: lastMsg.content + data.delta!, - }; - return next; - }); - } + const data = JSON.parse(buffer.trim()) as StreamChunk; + applyChunk(data); } catch { // ignore final malformed chunk } } - - if (!hadStreamError) { - updateMessages((prev) => { - const next = [...prev]; - const lastMsg = next[next.length - 1]; - if (lastMsg?.role === "assistant" && lastMsg.content === "") - next[next.length - 1] = { ...lastMsg, content: " " }; - return next; - }); - } } catch (err) { if (err instanceof DOMException && err.name === "AbortError") { // User stopped generation — not an error } else { onError(defaultErrorMessage); - updateMessages((prev) => - prev.filter((m) => !(m.role === "assistant" && m.content === "")), - ); } } finally { onDone(); } } +/** Map each tool result message to its tool_call_id for response lookup. */ +export function toolResponsesById( + messages: ChatMessage[], +): Map { + const map = new Map(); + for (const m of messages) { + if (m.role === "tool" && typeof m.tool_call_id === "string") { + map.set( + m.tool_call_id, + typeof m.content === "string" ? m.content : JSON.stringify(m.content), + ); + } + } + return map; +} + +/** Derive the display tool calls for one assistant message. */ +export function toolCallsForMessage( + message: ChatMessage, + responses: Map, +): ToolCall[] { + if (!message.tool_calls?.length) return []; + return message.tool_calls.map((tc) => { + let args: Record | undefined; + const raw = tc.function?.arguments; + if (typeof raw === "string") { + try { + args = JSON.parse(raw) as Record; + } catch { + args = undefined; + } + } + return { + name: tc.function?.name ?? "", + arguments: args, + response: responses.get(tc.id), + }; + }); +} + /** * Parse search_objects tool call response(s) into event ids for thumbnails. */