Fix chat tool calling and prompt breaking (#23457)
Some checks failed
CI / AMD64 Build (push) Has been cancelled
CI / ARM Build (push) Has been cancelled
CI / Jetson Jetpack 6 (push) Has been cancelled
CI / AMD64 Extra Build (push) Has been cancelled
CI / ARM Extra Build (push) Has been cancelled
CI / Synaptics Build (push) Has been cancelled
CI / Assemble and push default build (push) Has been cancelled

* Implement tool call history keeping

* Refactor to match single message implementation

* Simplify data representation

* Cleanup chat page rendering

* Include system message to not break cache

* Formatting

* Update tests and update .gitignore
This commit is contained in:
Nicolas Mowen 2026-06-12 06:48:43 -06:00 committed by GitHub
parent e6601d50a6
commit d7ad3ba699
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 430 additions and 295 deletions

View File

@ -7,7 +7,7 @@ import operator
import time
from datetime import datetime
from functools import reduce
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import cv2
from fastapi import APIRouter, Body, Depends, HTTPException, Request
@ -59,7 +59,7 @@ class ToolExecuteRequest(BaseModel):
"""Request model for tool execution."""
tool_name: str
arguments: Dict[str, Any]
arguments: dict[str, Any]
class VLMMonitorRequest(BaseModel):
@ -68,8 +68,8 @@ class VLMMonitorRequest(BaseModel):
camera: str
condition: str
max_duration_minutes: int = 60
labels: List[str] = []
zones: List[str] = []
labels: list[str] = []
zones: list[str] = []
@router.get(
@ -91,10 +91,10 @@ def get_tools(request: Request) -> JSONResponse:
def _resolve_zones(
zones: List[str],
zones: list[str],
config: FrigateConfig,
target_cameras: List[str],
) -> List[str]:
target_cameras: list[str],
) -> list[str]:
"""Map zone names to their canonical config keys, case-insensitively.
LLMs frequently echo a user's casing ("Front Yard") instead of the
@ -107,7 +107,7 @@ def _resolve_zones(
if not zones:
return zones
lookup: Dict[str, str] = {}
lookup: dict[str, str] = {}
for camera_id in target_cameras:
camera_config = config.cameras.get(camera_id)
if camera_config is None:
@ -120,8 +120,8 @@ def _resolve_zones(
async def _execute_search_objects(
request: Request,
arguments: Dict[str, Any],
allowed_cameras: List[str],
arguments: dict[str, Any],
allowed_cameras: list[str],
) -> JSONResponse:
"""
Execute the search_objects tool.
@ -213,8 +213,8 @@ async def _execute_search_objects(
async def _execute_search_objects_semantic(
request: Request,
arguments: Dict[str, Any],
allowed_cameras: List[str],
arguments: dict[str, Any],
allowed_cameras: list[str],
semantic_query: str,
) -> JSONResponse:
"""Search objects via fused thumbnail + description embeddings.
@ -263,8 +263,8 @@ async def _execute_search_objects_semantic(
limit = int(arguments.get("limit", 25))
limit = max(1, min(limit, 100))
visual_distances: Dict[str, float] = {}
description_distances: Dict[str, float] = {}
visual_distances: dict[str, float] = {}
description_distances: dict[str, float] = {}
try:
rows = context.search_thumbnail(semantic_query)
visual_distances = {row[0]: row[1] for row in rows}
@ -305,7 +305,7 @@ async def _execute_search_objects_semantic(
eligible = {e.id: e for e in Event.select().where(reduce(operator.and_, clauses))}
scored: List[tuple[str, float]] = []
scored: list[tuple[str, float]] = []
for eid in eligible:
v_score = (
distance_to_score(visual_distances[eid], context.thumb_stats)
@ -331,9 +331,9 @@ async def _execute_search_objects_semantic(
async def _execute_find_similar_objects(
request: Request,
arguments: Dict[str, Any],
allowed_cameras: List[str],
) -> Dict[str, Any]:
arguments: dict[str, Any],
allowed_cameras: list[str],
) -> dict[str, Any]:
"""Execute the find_similar_objects tool.
Returns a plain dict (not JSONResponse) so the chat loop can embed it
@ -403,8 +403,8 @@ async def _execute_find_similar_objects(
# version (see frigate/embeddings/__init__.py). Mirror the pattern used by
# frigate/api/event.py events_search: fetch top-k globally, then intersect
# with the structured filters via Peewee.
visual_distances: Dict[str, float] = {}
description_distances: Dict[str, float] = {}
visual_distances: dict[str, float] = {}
description_distances: dict[str, float] = {}
try:
if similarity_mode in ("visual", "fused"):
@ -462,7 +462,7 @@ async def _execute_find_similar_objects(
eligible = {e.id: e for e in Event.select().where(reduce(operator.and_, clauses))}
# 6. Fuse and rank.
scored: List[tuple[str, float]] = []
scored: list[tuple[str, float]] = []
for eid in eligible:
v_score = (
distance_to_score(visual_distances[eid], context.thumb_stats)
@ -503,7 +503,7 @@ async def _execute_find_similar_objects(
async def execute_tool(
request: Request,
body: ToolExecuteRequest = Body(...),
allowed_cameras: List[str] = Depends(get_allowed_cameras_for_filter),
allowed_cameras: list[str] = Depends(get_allowed_cameras_for_filter),
) -> JSONResponse:
"""
Execute a tool function call.
@ -545,8 +545,8 @@ async def execute_tool(
async def _execute_get_live_context(
request: Request,
camera: str,
allowed_cameras: List[str],
) -> Dict[str, Any]:
allowed_cameras: list[str],
) -> dict[str, Any]:
# Reject wildcards explicitly so models retry with a real camera name
# instead of silently fanning out across every camera.
if camera in ("*", "all"):
@ -593,7 +593,7 @@ async def _execute_get_live_context(
"stationary": obj_dict.get("stationary", False),
}
result: Dict[str, Any] = {
result: dict[str, Any] = {
"camera": camera,
"timestamp": frame_time,
"detections": list(tracked_objects_dict.values()),
@ -620,7 +620,7 @@ async def _execute_get_live_context(
async def _get_live_frame_image_url(
request: Request,
camera: str,
allowed_cameras: List[str],
allowed_cameras: list[str],
) -> Optional[str]:
"""
Fetch the current live frame for a camera as a base64 data URL.
@ -659,8 +659,8 @@ async def _get_live_frame_image_url(
async def _execute_set_camera_state(
request: Request,
arguments: Dict[str, Any],
) -> Dict[str, Any]:
arguments: dict[str, Any],
) -> dict[str, Any]:
role = request.headers.get("remote-role", "")
if "admin" not in [r.strip() for r in role.split(",")]:
return {"error": "Admin privileges required to change camera settings."}
@ -699,10 +699,10 @@ async def _execute_set_camera_state(
async def _execute_tool_internal(
tool_name: str,
arguments: Dict[str, Any],
arguments: dict[str, Any],
request: Request,
allowed_cameras: List[str],
) -> Dict[str, Any]:
allowed_cameras: list[str],
) -> dict[str, Any]:
"""
Internal helper to execute a tool and return the result as a dict.
@ -763,8 +763,8 @@ async def _execute_tool_internal(
async def _execute_start_camera_watch(
request: Request,
arguments: Dict[str, Any],
) -> Dict[str, Any]:
arguments: dict[str, Any],
) -> dict[str, Any]:
camera = arguments.get("camera", "").strip()
condition = arguments.get("condition", "").strip()
max_duration_minutes = int(arguments.get("max_duration_minutes", 60))
@ -814,14 +814,14 @@ async def _execute_start_camera_watch(
}
def _execute_stop_camera_watch() -> Dict[str, Any]:
def _execute_stop_camera_watch() -> dict[str, Any]:
cancelled = stop_vlm_watch_job()
if cancelled:
return {"success": True, "message": "Watch job cancelled."}
return {"success": False, "message": "No active watch job to cancel."}
def _execute_get_profile_status(request: Request) -> Dict[str, Any]:
def _execute_get_profile_status(request: Request) -> dict[str, Any]:
"""Return profile status including active profile and activation timestamps."""
profile_manager = getattr(request.app, "profile_manager", None)
if profile_manager is None:
@ -846,9 +846,9 @@ def _execute_get_profile_status(request: Request) -> Dict[str, Any]:
def _execute_get_recap(
arguments: Dict[str, Any],
allowed_cameras: List[str],
) -> Dict[str, Any]:
arguments: dict[str, Any],
allowed_cameras: list[str],
) -> dict[str, Any]:
"""Fetch review segments with GenAI metadata for a time period."""
from functools import reduce
@ -909,7 +909,7 @@ def _execute_get_recap(
.iterator()
)
events: List[Dict[str, Any]] = []
events: list[dict[str, Any]] = []
for row in rows:
data = row.get("data") or {}
@ -920,7 +920,7 @@ def _execute_get_recap(
data = {}
camera = row["camera"]
event: Dict[str, Any] = {
event: dict[str, Any] = {
"camera": camera.replace("_", " ").title(),
"severity": row.get("severity", "detection"),
}
@ -984,10 +984,10 @@ def _execute_get_recap(
async def _execute_pending_tools(
pending_tool_calls: List[Dict[str, Any]],
pending_tool_calls: list[dict[str, Any]],
request: Request,
allowed_cameras: List[str],
) -> tuple[List[ToolCall], List[Dict[str, Any]], List[Dict[str, Any]]]:
allowed_cameras: list[str],
) -> tuple[list[ToolCall], list[dict[str, Any]], list[dict[str, Any]]]:
"""
Execute a list of tool calls.
@ -996,9 +996,9 @@ async def _execute_pending_tools(
tool result dicts for conversation,
extra messages to inject after tool results e.g. user messages with images)
"""
tool_calls_out: List[ToolCall] = []
tool_results: List[Dict[str, Any]] = []
extra_messages: List[Dict[str, Any]] = []
tool_calls_out: list[ToolCall] = []
tool_results: list[dict[str, Any]] = []
extra_messages: list[dict[str, Any]] = []
for tool_call in pending_tool_calls:
tool_name = tool_call["name"]
tool_args = tool_call.get("arguments") or {}
@ -1106,7 +1106,7 @@ async def _execute_pending_tools(
async def chat_completion(
request: Request,
body: ChatCompletionRequest = Body(...),
allowed_cameras: List[str] = Depends(get_allowed_cameras_for_filter),
allowed_cameras: list[str] = Depends(get_allowed_cameras_for_filter),
):
"""
Chat completion endpoint with tool calling support.
@ -1138,19 +1138,23 @@ async def chat_completion(
)
conversation = []
system_prompt = build_chat_system_prompt(
config=config,
allowed_cameras=allowed_cameras,
semantic_search_enabled=semantic_search_enabled,
attribute_classifications=attribute_classifications,
)
conversation.append(
{
"role": "system",
"content": system_prompt,
}
)
# Build the system message only when the client hasn't already pinned one.
# The first turn has no system message; we generate it (with the current
# timestamp) and return the whole chain so the client persists it. Later
# turns send it back verbatim, freezing the timestamp so the prompt prefix
# stays byte-identical and the model server's prompt cache keeps hitting.
if not body.messages or body.messages[0].role != "system":
conversation.append(
{
"role": "system",
"content": build_chat_system_prompt(
config=config,
allowed_cameras=allowed_cameras,
semantic_search_enabled=semantic_search_enabled,
attribute_classifications=attribute_classifications,
),
}
)
for msg in body.messages:
msg_dict = {
@ -1161,11 +1165,13 @@ async def chat_completion(
msg_dict["tool_call_id"] = msg.tool_call_id
if msg.name:
msg_dict["name"] = msg.name
if msg.tool_calls is not None:
msg_dict["tool_calls"] = msg.tool_calls
conversation.append(msg_dict)
tool_iterations = 0
tool_calls: List[ToolCall] = []
tool_calls: list[ToolCall] = []
max_iterations = body.max_tool_iterations
logger.debug(
@ -1175,11 +1181,20 @@ async def chat_completion(
# True LLM streaming when client supports it and stream requested
if body.stream and hasattr(genai_client, "chat_with_tools_stream"):
stream_tool_calls: List[ToolCall] = []
stream_iterations = 0
async def stream_body_llm():
nonlocal conversation, stream_tool_calls, stream_iterations
nonlocal conversation, stream_iterations
def _emit_chain(extra: Optional[list[dict[str, Any]]] = None):
# Return the full conversation (including the system message) so
# the client persists and replays it verbatim next turn.
chain = conversation + (extra or [])
return (
json.dumps({"type": "messages", "messages": chain}).encode("utf-8")
+ b"\n"
)
while stream_iterations < max_iterations:
if await request.is_disconnected():
logger.debug("Client disconnected, stopping chat stream")
@ -1244,31 +1259,33 @@ async def chat_completion(
)
return
(
executed_calls,
_executed_calls,
tool_results,
extra_msgs,
) = await _execute_pending_tools(
pending, request, allowed_cameras
)
stream_tool_calls.extend(executed_calls)
conversation.extend(tool_results)
conversation.extend(extra_msgs)
yield (
json.dumps(
{
"type": "tool_calls",
"tool_calls": [
tc.model_dump() for tc in stream_tool_calls
],
}
).encode("utf-8")
+ b"\n"
)
# Emit the running chain so the client can render tool
# calls live and replay them verbatim next turn.
yield _emit_chain()
break
else:
# Streaming never appends the final assistant message
# to the conversation, so add it to the chain.
yield _emit_chain(
extra=[
{
"role": "assistant",
"content": msg.get("content"),
}
]
)
yield (json.dumps({"type": "done"}).encode("utf-8") + b"\n")
return
else:
yield _emit_chain()
yield json.dumps({"type": "done"}).encode("utf-8") + b"\n"
return StreamingResponse(
@ -1315,19 +1332,15 @@ async def chat_completion(
if body.stream:
final_reasoning = response.get("reasoning")
chain = list(conversation)
async def stream_body() -> Any:
if tool_calls:
yield (
json.dumps(
{
"type": "tool_calls",
"tool_calls": [
tc.model_dump() for tc in tool_calls
],
}
).encode("utf-8")
+ b"\n"
yield (
json.dumps({"type": "messages", "messages": chain}).encode(
"utf-8"
)
+ b"\n"
)
# Emit the full reasoning trace up front when the
# underlying client did not stream it
if final_reasoning:
@ -1363,6 +1376,7 @@ async def chat_completion(
finish_reason=response.get("finish_reason", "stop"),
tool_iterations=tool_iterations,
tool_calls=tool_calls,
messages=list(conversation),
).model_dump(),
)
@ -1395,6 +1409,7 @@ async def chat_completion(
finish_reason="length",
tool_iterations=tool_iterations,
tool_calls=tool_calls,
messages=list(conversation),
).model_dump(),
)

View File

@ -1,6 +1,6 @@
"""Chat API request models."""
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel, Field
@ -11,13 +11,29 @@ class ChatMessage(BaseModel):
role: str = Field(
description="Message role: 'user', 'assistant', 'system', or 'tool'"
)
content: str = Field(description="Message content")
content: Optional[Any] = Field(
default=None,
description=(
"Message content. Usually a string, but may be a multimodal content "
"list (e.g. text + image_url) or null for assistant turns that only "
"request tool calls."
),
)
tool_call_id: Optional[str] = Field(
default=None, description="For tool messages, the ID of the tool call"
)
name: Optional[str] = Field(
default=None, description="For tool messages, the tool name"
)
tool_calls: Optional[list[dict[str, Any]]] = Field(
default=None,
description=(
"For assistant messages replayed from prior turns, the OpenAI-format "
"tool calls the model previously requested. Replaying these verbatim "
"keeps the conversation prefix byte-for-byte identical so the model "
"server's prompt cache hits on follow-up turns."
),
)
class ChatCompletionRequest(BaseModel):

View File

@ -56,3 +56,12 @@ class ChatCompletionResponse(BaseModel):
default_factory=list,
description="List of tool calls that were executed during this completion",
)
messages: list[dict[str, Any]] = Field(
default_factory=list,
description=(
"The full conversation chain, including the system message. Persist "
"and replay this verbatim on the next request so the prompt prefix "
"stays byte-identical and the model server's prompt cache keeps "
"hitting."
),
)

4
web/.gitignore vendored
View File

@ -12,6 +12,10 @@ dist
dist-ssr
*.local
# Playwright
playwright-report
test-results
# Editor directories and files
.vscode/*
!.vscode/extensions.json

View File

@ -92,6 +92,15 @@ test.describe("Chat — streaming @medium", () => {
await installChatStreamOverride(frigateApp, [
{ type: "content", delta: "Hel" },
{ type: "content", delta: "lo" },
{
type: "messages",
messages: [
{ role: "system", content: "sys" },
{ role: "user", content: "hello chat" },
{ role: "assistant", content: "Hello" },
],
},
{ type: "done" },
]);
await frigateApp.goto("/chat");
const input = frigateApp.page.getByPlaceholder(/ask/i);
@ -137,6 +146,15 @@ test.describe("Chat — streaming @medium", () => {
{ type: "content", delta: "Hel" },
{ type: "content", delta: "lo, " },
{ type: "content", delta: "world!" },
{
type: "messages",
messages: [
{ role: "system", content: "sys" },
{ role: "user", content: "greet me" },
{ role: "assistant", content: "Hello, world!" },
],
},
{ type: "done" },
],
{ chunkDelayMs: 50 },
);
@ -151,19 +169,39 @@ test.describe("Chat — streaming @medium", () => {
});
});
test("tool_calls chunks render a ToolCallsGroup", async ({ frigateApp }) => {
await installChatStreamOverride(frigateApp, [
test("tool calls in the chain render a ToolCallsGroup", async ({
frigateApp,
}) => {
const toolTurn = [
{ role: "system", content: "sys" },
{ role: "user", content: "find people" },
{
type: "tool_calls",
role: "assistant",
content: null,
tool_calls: [
{
id: "call_1",
name: "search_objects",
arguments: { label: "person" },
type: "function",
function: {
name: "search_objects",
arguments: '{"label":"person"}',
},
},
],
},
{ role: "tool", tool_call_id: "call_1", content: "[]" },
];
await installChatStreamOverride(frigateApp, [
{ type: "messages", messages: toolTurn },
{ type: "content", delta: "Searching for people." },
{
type: "messages",
messages: [
...toolTurn,
{ role: "assistant", content: "Searching for people." },
],
},
{ type: "done" },
]);
await frigateApp.goto("/chat");
const input = frigateApp.page.getByPlaceholder(/ask/i);
@ -253,6 +291,15 @@ test.describe("Chat — attachment chip @medium", () => {
// We use the stream override so the first message completes quickly.
await installChatStreamOverride(frigateApp, [
{ type: "content", delta: "Done." },
{
type: "messages",
messages: [
{ role: "system", content: "sys" },
{ role: "user", content: "hello" },
{ role: "assistant", content: "Done." },
],
},
{ type: "done" },
]);
await frigateApp.goto("/chat");

View File

@ -13,6 +13,7 @@ import { ChatComposer } from "@/components/chat/ChatComposer";
import ChatSettings from "@/components/chat/ChatSettings";
import type {
ChatMessage,
ChatStats,
GenAIModelsResponse,
ShowStatsMode,
} from "@/types/chat";
@ -22,12 +23,28 @@ import {
getFindSimilarObjectsFromToolCalls,
prependAttachment,
streamChatCompletion,
toolCallsForMessage,
toolResponsesById,
} from "@/utils/chatUtil";
type StreamingTurn = {
content: string;
reasoning: string;
chain: ChatMessage[];
stats?: ChatStats;
};
const hasText = (content: unknown): content is string =>
typeof content === "string" && content.trim().length > 0;
const toWire = (messages: ChatMessage[]): ChatMessage[] =>
messages.map(({ reasoning: _r, stats: _s, ...rest }) => rest);
export default function ChatPage() {
const { t } = useTranslation(["views/chat"]);
const [input, setInput] = useState("");
const [messages, setMessages] = useState<ChatMessage[]>([]);
const [streaming, setStreaming] = useState<StreamingTurn | null>(null);
const [isLoading, setIsLoading] = useState(false);
const [error, setError] = useState<string | null>(null);
const [attachedEventId, setAttachedEventId] = useState<string | null>(null);
@ -72,28 +89,19 @@ export default function ChatPage() {
if (isNearBottom) {
el.scrollTo({ top: el.scrollHeight, behavior: "smooth" });
}
}, [messages, autoScroll]);
}, [messages, streaming, autoScroll]);
const submitConversation = useCallback(
async (messagesToSend: ChatMessage[]) => {
if (isLoading) return;
const last = messagesToSend[messagesToSend.length - 1];
if (!last || last.role !== "user" || !last.content.trim()) return;
if (!last || last.role !== "user" || !hasText(last.content)) return;
setError(null);
const assistantPlaceholder: ChatMessage = {
role: "assistant",
content: "",
toolCalls: undefined,
};
setMessages([...messagesToSend, assistantPlaceholder]);
setMessages(messagesToSend);
setStreaming({ content: "", reasoning: "", chain: [] });
setIsLoading(true);
const apiMessages = messagesToSend.map((m) => ({
role: m.role,
content: m.content,
}));
const baseURL = axios.defaults.baseURL ?? "";
const url = `${baseURL}chat/completion`;
const headers: Record<string, string> = {
@ -104,16 +112,50 @@ export default function ChatPage() {
const controller = new AbortController();
abortRef.current = controller;
let chain: ChatMessage[] = [];
let stats: ChatStats | undefined;
let reasoning = "";
let hadError = false;
await streamChatCompletion(
url,
headers,
apiMessages,
toWire(messagesToSend),
{
updateMessages: (updater) => setMessages(updater),
onError: (message) => setError(message),
onContentDelta: (delta) =>
setStreaming((s) => (s ? { ...s, content: s.content + delta } : s)),
onReasoningDelta: (delta) => {
reasoning += delta;
setStreaming((s) =>
s ? { ...s, reasoning: s.reasoning + delta } : s,
);
},
onChain: (fullChain) => {
chain = fullChain;
setStreaming((s) => (s ? { ...s, chain: fullChain } : s));
},
onStats: (s) => {
stats = s;
setStreaming((cur) => (cur ? { ...cur, stats: s } : cur));
},
onError: (message) => {
hadError = true;
setError(message);
},
onDone: () => {
abortRef.current = null;
setIsLoading(false);
setStreaming(null);
const lastMsg = chain[chain.length - 1];
if (!hadError && lastMsg?.role === "assistant") {
setMessages(
chain.map((m, i) =>
i === chain.length - 1
? { ...m, reasoning: reasoning || undefined, stats }
: m,
),
);
}
},
defaultErrorMessage: t("error"),
},
@ -125,12 +167,14 @@ export default function ChatPage() {
);
const recentEventIds = useMemo(() => {
const responses = toolResponsesById(messages);
for (let i = messages.length - 1; i >= 0; i--) {
const msg = messages[i];
if (msg.role !== "assistant" || !msg.toolCalls) continue;
const similar = getFindSimilarObjectsFromToolCalls(msg.toolCalls);
if (msg.role !== "assistant" || !msg.tool_calls?.length) continue;
const calls = toolCallsForMessage(msg, responses);
const similar = getFindSimilarObjectsFromToolCalls(calls);
if (similar) return similar.results.map((e) => e.id);
const events = getEventIdsFromSearchObjectsToolCalls(msg.toolCalls);
const events = getEventIdsFromSearchObjectsToolCalls(calls);
if (events.length > 0) return events.map((e) => e.id);
}
return [];
@ -154,12 +198,14 @@ export default function ChatPage() {
abortRef.current?.abort();
abortRef.current = null;
setIsLoading(false);
setStreaming(null);
}, []);
const startNewChat = useCallback(() => {
abortRef.current?.abort();
abortRef.current = null;
setIsLoading(false);
setStreaming(null);
setMessages([]);
setInput("");
setAttachedEventId(null);
@ -181,7 +227,83 @@ export default function ChatPage() {
setAttachedEventId(null);
}, []);
const hasStarted = messages.length > 0;
const hasStarted = messages.length > 0 || streaming != null;
// While streaming, the backend's in-flight chain is the source of truth;
// otherwise the committed conversation is.
const renderList =
streaming && streaming.chain.length ? streaming.chain : messages;
const responses = toolResponsesById(renderList);
const renderTail = renderList[renderList.length - 1];
const finalShown =
renderTail?.role === "assistant" && hasText(renderTail.content);
const renderMessage = (msg: ChatMessage, i: number) => {
if (msg.role === "system" || msg.role === "tool") return null;
if (msg.role === "user") {
if (!hasText(msg.content)) return null;
return (
<div key={i} className="flex flex-col gap-2">
<MessageBubble
role="user"
content={msg.content}
messageIndex={i}
onEditSubmit={handleEditSubmit}
isComplete
showStats={showStats}
/>
</div>
);
}
const calls = toolCallsForMessage(msg, responses);
const contentText = hasText(msg.content) ? msg.content : "";
const similar = getFindSimilarObjectsFromToolCalls(calls);
const events = similar ? [] : getEventIdsFromSearchObjectsToolCalls(calls);
return (
<div key={i} className="flex flex-col gap-2">
{calls.length > 0 && <ToolCallsGroup toolCalls={calls} />}
{hasText(msg.reasoning) && (
<ReasoningBubble
reasoning={msg.reasoning}
answerStarted={!!contentText}
/>
)}
{contentText && (
<MessageBubble
role="assistant"
content={contentText}
messageIndex={i}
isComplete
stats={msg.stats}
showStats={showStats}
/>
)}
{similar ? (
<ChatEventThumbnailsRow
events={similar.results}
anchor={similar.anchor}
onAttach={setAttachedEventId}
/>
) : (
<ChatEventThumbnailsRow
events={events}
onAttach={setAttachedEventId}
/>
)}
</div>
);
};
const processingDots = (
<div className="flex items-center gap-2 self-start rounded-2xl bg-muted px-5 py-4">
<span className="size-2.5 animate-bounce rounded-full bg-muted-foreground/60 [animation-delay:-0.32s]" />
<span className="size-2.5 animate-bounce rounded-full bg-muted-foreground/60 [animation-delay:-0.16s]" />
<span className="size-2.5 animate-bounce rounded-full bg-muted-foreground/60" />
</div>
);
return (
<div className="flex size-full flex-col">
@ -212,102 +334,31 @@ export default function ChatPage() {
<div className="flex w-full flex-col xl:w-[50%] 3xl:w-[35%]">
{hasStarted ? (
<div className="flex w-full flex-1 flex-col gap-3 pb-3">
{messages.map((msg, i) => {
const isLastAssistant =
i === messages.length - 1 && msg.role === "assistant";
const isComplete =
msg.role === "user" || !isLoading || !isLastAssistant;
const hasToolCalls =
msg.toolCalls && msg.toolCalls.length > 0;
const hasContent = !!msg.content?.trim();
const hasReasoning = !!msg.reasoning?.trim();
const showProcessing =
isLastAssistant &&
isLoading &&
!hasContent &&
!hasReasoning;
// Hide empty placeholder only when there are no tool calls
// and no reasoning streaming yet
if (
isLastAssistant &&
isLoading &&
!hasContent &&
!hasToolCalls &&
!hasReasoning
)
return (
<div
key={i}
className="flex items-center gap-2 self-start rounded-2xl bg-muted px-5 py-4"
>
<span className="size-2.5 animate-bounce rounded-full bg-muted-foreground/60 [animation-delay:-0.32s]" />
<span className="size-2.5 animate-bounce rounded-full bg-muted-foreground/60 [animation-delay:-0.16s]" />
<span className="size-2.5 animate-bounce rounded-full bg-muted-foreground/60" />
</div>
);
return (
<div key={i} className="flex flex-col gap-2">
{msg.role === "assistant" && hasToolCalls && (
<ToolCallsGroup toolCalls={msg.toolCalls!} />
)}
{msg.role === "assistant" && hasReasoning && (
{renderList.map((msg, i) => renderMessage(msg, i))}
{streaming &&
!finalShown &&
(streaming.content || streaming.reasoning ? (
<div className="flex flex-col gap-2">
{hasText(streaming.reasoning) && (
<ReasoningBubble
reasoning={msg.reasoning!}
answerStarted={hasContent}
reasoning={streaming.reasoning}
answerStarted={!!streaming.content}
/>
)}
{showProcessing ? (
<div className="flex items-center gap-2 self-start rounded-2xl bg-muted px-5 py-4">
<span className="size-2 animate-bounce rounded-full bg-muted-foreground/60 [animation-delay:-0.3s]" />
<span className="size-2 animate-bounce rounded-full bg-muted-foreground/60 [animation-delay:-0.15s]" />
<span className="size-2 animate-bounce rounded-full bg-muted-foreground/60" />
</div>
) : msg.role === "assistant" &&
!hasContent &&
hasReasoning &&
!isComplete ? null : (
{streaming.content && (
<MessageBubble
role={msg.role}
content={msg.content}
messageIndex={i}
onEditSubmit={
msg.role === "user" ? handleEditSubmit : undefined
}
isComplete={isComplete}
stats={msg.stats}
role="assistant"
content={streaming.content}
messageIndex={-1}
isComplete={false}
stats={streaming.stats}
showStats={showStats}
/>
)}
{msg.role === "assistant" &&
isComplete &&
(() => {
const similar = getFindSimilarObjectsFromToolCalls(
msg.toolCalls,
);
if (similar) {
return (
<ChatEventThumbnailsRow
events={similar.results}
anchor={similar.anchor}
onAttach={setAttachedEventId}
/>
);
}
const events = getEventIdsFromSearchObjectsToolCalls(
msg.toolCalls,
);
return (
<ChatEventThumbnailsRow
events={events}
onAttach={setAttachedEventId}
/>
);
})()}
</div>
);
})}
) : (
processingDots
))}
{error && (
<p
className="flex items-center gap-1.5 self-start text-sm text-destructive"

View File

@ -1,17 +1,30 @@
export type ToolCallFunction = {
name: string;
arguments: string;
};
export type WireToolCall = {
id: string;
type?: string;
function: ToolCallFunction;
};
export type ChatMessage = {
role: "system" | "user" | "assistant" | "tool";
content: unknown;
tool_call_id?: string;
name?: string;
tool_calls?: WireToolCall[];
reasoning?: string;
stats?: ChatStats;
};
export type ToolCall = {
name: string;
arguments?: Record<string, unknown>;
response?: string;
};
export type ChatMessage = {
role: "user" | "assistant";
content: string;
reasoning?: string;
toolCalls?: ToolCall[];
stats?: ChatStats;
};
export type StartingRequest = {
label: string;
prompt: string;

View File

@ -1,16 +1,20 @@
import type { ChatMessage, ChatStats, ToolCall } from "@/types/chat";
export type StreamChatCallbacks = {
/** Update the messages array (e.g. pass to setState). */
updateMessages: (updater: (prev: ChatMessage[]) => ChatMessage[]) => void;
/** Streamed delta of the assistant's final answer text. */
onContentDelta: (delta: string) => void;
/** Streamed delta of the assistant's reasoning trace. */
onReasoningDelta: (delta: string) => void;
/** The full conversation chain so far (system message, history, this turn's
* tool-call turns, tool results, and on the final emission the final
* assistant message). */
onChain: (chain: ChatMessage[]) => void;
/** Token/timing stats for the turn. */
onStats: (stats: ChatStats) => void;
/** Called when the stream sends an error or fetch fails. */
onError: (message: string) => void;
/** Called when the stream finishes (success or error). */
onDone: () => void;
/** Called when the stream emits token/timing stats. The stats are also
* attached to the last assistant message in updateMessages, so consumers
* can usually rely on the message itself rather than wiring this up. */
onStats?: (stats: ChatStats) => void;
/** Message used when fetch throws and no server error is available. */
defaultErrorMessage?: string;
};
@ -25,7 +29,7 @@ type StatsChunk = {
type StreamChunk =
| { type: "error"; error: string }
| { type: "tool_calls"; tool_calls: ToolCall[] }
| { type: "messages"; messages: ChatMessage[] }
| { type: "content"; delta: string }
| { type: "reasoning"; delta: string }
| StatsChunk;
@ -41,16 +45,18 @@ export type StreamChatOptions = {
export async function streamChatCompletion(
url: string,
headers: Record<string, string>,
apiMessages: { role: string; content: string }[],
apiMessages: ChatMessage[],
callbacks: StreamChatCallbacks,
signal?: AbortSignal,
options: StreamChatOptions = {},
): Promise<void> {
const {
updateMessages,
onContentDelta,
onReasoningDelta,
onChain,
onStats,
onError,
onDone,
onStats,
defaultErrorMessage = "Something went wrong. Please try again.",
} = callbacks;
@ -91,65 +97,27 @@ export async function streamChatCompletion(
const applyChunk = (data: StreamChunk) => {
if (data.type === "error") {
onError(data.error);
updateMessages((prev) =>
prev.filter((m) => !(m.role === "assistant" && m.content === "")),
);
return "break";
}
if (data.type === "tool_calls" && data.tool_calls?.length) {
updateMessages((prev) => {
const next = [...prev];
const lastMsg = next[next.length - 1];
if (lastMsg?.role === "assistant")
next[next.length - 1] = {
...lastMsg,
toolCalls: data.tool_calls,
};
return next;
});
if (data.type === "messages") {
onChain(data.messages ?? []);
return "continue";
}
if (data.type === "content" && data.delta !== undefined) {
updateMessages((prev) => {
const next = [...prev];
const lastMsg = next[next.length - 1];
if (lastMsg?.role === "assistant")
next[next.length - 1] = {
...lastMsg,
content: lastMsg.content + data.delta,
};
return next;
});
onContentDelta(data.delta);
return "continue";
}
if (data.type === "reasoning" && data.delta !== undefined) {
updateMessages((prev) => {
const next = [...prev];
const lastMsg = next[next.length - 1];
if (lastMsg?.role === "assistant")
next[next.length - 1] = {
...lastMsg,
reasoning: (lastMsg.reasoning ?? "") + data.delta,
};
return next;
});
onReasoningDelta(data.delta);
return "continue";
}
if (data.type === "stats") {
const stats: ChatStats = {
onStats({
promptTokens: data.prompt_tokens,
completionTokens: data.completion_tokens,
completionDurationMs: data.completion_duration_ms,
tokensPerSecond: data.tokens_per_second,
};
updateMessages((prev) => {
const next = [...prev];
const lastMsg = next[next.length - 1];
if (lastMsg?.role === "assistant")
next[next.length - 1] = { ...lastMsg, stats };
return next;
});
onStats?.(stats);
return "continue";
}
return "continue";
@ -165,9 +133,8 @@ export async function streamChatCompletion(
const trimmed = line.trim();
if (!trimmed) continue;
try {
const data = JSON.parse(trimmed) as StreamChunk & { type: string };
const result = applyChunk(data as StreamChunk);
if (result === "break") {
const data = JSON.parse(trimmed) as StreamChunk;
if (applyChunk(data) === "break") {
hadStreamError = true;
break;
}
@ -181,50 +148,63 @@ export async function streamChatCompletion(
// Flush remaining buffer
if (!hadStreamError && buffer.trim()) {
try {
const data = JSON.parse(buffer.trim()) as StreamChunk & {
type: string;
delta?: string;
};
if (data.type === "content" && data.delta !== undefined) {
updateMessages((prev) => {
const next = [...prev];
const lastMsg = next[next.length - 1];
if (lastMsg?.role === "assistant")
next[next.length - 1] = {
...lastMsg,
content: lastMsg.content + data.delta!,
};
return next;
});
}
const data = JSON.parse(buffer.trim()) as StreamChunk;
applyChunk(data);
} catch {
// ignore final malformed chunk
}
}
if (!hadStreamError) {
updateMessages((prev) => {
const next = [...prev];
const lastMsg = next[next.length - 1];
if (lastMsg?.role === "assistant" && lastMsg.content === "")
next[next.length - 1] = { ...lastMsg, content: " " };
return next;
});
}
} catch (err) {
if (err instanceof DOMException && err.name === "AbortError") {
// User stopped generation — not an error
} else {
onError(defaultErrorMessage);
updateMessages((prev) =>
prev.filter((m) => !(m.role === "assistant" && m.content === "")),
);
}
} finally {
onDone();
}
}
/** Map each tool result message to its tool_call_id for response lookup. */
export function toolResponsesById(
messages: ChatMessage[],
): Map<string, string> {
const map = new Map<string, string>();
for (const m of messages) {
if (m.role === "tool" && typeof m.tool_call_id === "string") {
map.set(
m.tool_call_id,
typeof m.content === "string" ? m.content : JSON.stringify(m.content),
);
}
}
return map;
}
/** Derive the display tool calls for one assistant message. */
export function toolCallsForMessage(
message: ChatMessage,
responses: Map<string, string>,
): ToolCall[] {
if (!message.tool_calls?.length) return [];
return message.tool_calls.map((tc) => {
let args: Record<string, unknown> | undefined;
const raw = tc.function?.arguments;
if (typeof raw === "string") {
try {
args = JSON.parse(raw) as Record<string, unknown>;
} catch {
args = undefined;
}
}
return {
name: tc.function?.name ?? "",
arguments: args,
response: responses.get(tc.id),
};
});
}
/**
* Parse search_objects tool call response(s) into event ids for thumbnails.
*/