Fix chat tool calling and prompt breaking (#23457)
Some checks failed
CI / AMD64 Build (push) Has been cancelled
CI / ARM Build (push) Has been cancelled
CI / Jetson Jetpack 6 (push) Has been cancelled
CI / AMD64 Extra Build (push) Has been cancelled
CI / ARM Extra Build (push) Has been cancelled
CI / Synaptics Build (push) Has been cancelled
CI / Assemble and push default build (push) Has been cancelled

* Implement tool call history keeping

* Refactor to match single message implementation

* Simplify data representation

* Cleanup chat page rendering

* Include system message to not break cache

* Formatting

* Update tests and update .gitignore
This commit is contained in:
Nicolas Mowen 2026-06-12 06:48:43 -06:00 committed by GitHub
parent e6601d50a6
commit d7ad3ba699
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 430 additions and 295 deletions

View File

@ -7,7 +7,7 @@ import operator
import time import time
from datetime import datetime from datetime import datetime
from functools import reduce from functools import reduce
from typing import Any, Dict, List, Optional from typing import Any, Optional
import cv2 import cv2
from fastapi import APIRouter, Body, Depends, HTTPException, Request from fastapi import APIRouter, Body, Depends, HTTPException, Request
@ -59,7 +59,7 @@ class ToolExecuteRequest(BaseModel):
"""Request model for tool execution.""" """Request model for tool execution."""
tool_name: str tool_name: str
arguments: Dict[str, Any] arguments: dict[str, Any]
class VLMMonitorRequest(BaseModel): class VLMMonitorRequest(BaseModel):
@ -68,8 +68,8 @@ class VLMMonitorRequest(BaseModel):
camera: str camera: str
condition: str condition: str
max_duration_minutes: int = 60 max_duration_minutes: int = 60
labels: List[str] = [] labels: list[str] = []
zones: List[str] = [] zones: list[str] = []
@router.get( @router.get(
@ -91,10 +91,10 @@ def get_tools(request: Request) -> JSONResponse:
def _resolve_zones( def _resolve_zones(
zones: List[str], zones: list[str],
config: FrigateConfig, config: FrigateConfig,
target_cameras: List[str], target_cameras: list[str],
) -> List[str]: ) -> list[str]:
"""Map zone names to their canonical config keys, case-insensitively. """Map zone names to their canonical config keys, case-insensitively.
LLMs frequently echo a user's casing ("Front Yard") instead of the LLMs frequently echo a user's casing ("Front Yard") instead of the
@ -107,7 +107,7 @@ def _resolve_zones(
if not zones: if not zones:
return zones return zones
lookup: Dict[str, str] = {} lookup: dict[str, str] = {}
for camera_id in target_cameras: for camera_id in target_cameras:
camera_config = config.cameras.get(camera_id) camera_config = config.cameras.get(camera_id)
if camera_config is None: if camera_config is None:
@ -120,8 +120,8 @@ def _resolve_zones(
async def _execute_search_objects( async def _execute_search_objects(
request: Request, request: Request,
arguments: Dict[str, Any], arguments: dict[str, Any],
allowed_cameras: List[str], allowed_cameras: list[str],
) -> JSONResponse: ) -> JSONResponse:
""" """
Execute the search_objects tool. Execute the search_objects tool.
@ -213,8 +213,8 @@ async def _execute_search_objects(
async def _execute_search_objects_semantic( async def _execute_search_objects_semantic(
request: Request, request: Request,
arguments: Dict[str, Any], arguments: dict[str, Any],
allowed_cameras: List[str], allowed_cameras: list[str],
semantic_query: str, semantic_query: str,
) -> JSONResponse: ) -> JSONResponse:
"""Search objects via fused thumbnail + description embeddings. """Search objects via fused thumbnail + description embeddings.
@ -263,8 +263,8 @@ async def _execute_search_objects_semantic(
limit = int(arguments.get("limit", 25)) limit = int(arguments.get("limit", 25))
limit = max(1, min(limit, 100)) limit = max(1, min(limit, 100))
visual_distances: Dict[str, float] = {} visual_distances: dict[str, float] = {}
description_distances: Dict[str, float] = {} description_distances: dict[str, float] = {}
try: try:
rows = context.search_thumbnail(semantic_query) rows = context.search_thumbnail(semantic_query)
visual_distances = {row[0]: row[1] for row in rows} visual_distances = {row[0]: row[1] for row in rows}
@ -305,7 +305,7 @@ async def _execute_search_objects_semantic(
eligible = {e.id: e for e in Event.select().where(reduce(operator.and_, clauses))} eligible = {e.id: e for e in Event.select().where(reduce(operator.and_, clauses))}
scored: List[tuple[str, float]] = [] scored: list[tuple[str, float]] = []
for eid in eligible: for eid in eligible:
v_score = ( v_score = (
distance_to_score(visual_distances[eid], context.thumb_stats) distance_to_score(visual_distances[eid], context.thumb_stats)
@ -331,9 +331,9 @@ async def _execute_search_objects_semantic(
async def _execute_find_similar_objects( async def _execute_find_similar_objects(
request: Request, request: Request,
arguments: Dict[str, Any], arguments: dict[str, Any],
allowed_cameras: List[str], allowed_cameras: list[str],
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Execute the find_similar_objects tool. """Execute the find_similar_objects tool.
Returns a plain dict (not JSONResponse) so the chat loop can embed it Returns a plain dict (not JSONResponse) so the chat loop can embed it
@ -403,8 +403,8 @@ async def _execute_find_similar_objects(
# version (see frigate/embeddings/__init__.py). Mirror the pattern used by # version (see frigate/embeddings/__init__.py). Mirror the pattern used by
# frigate/api/event.py events_search: fetch top-k globally, then intersect # frigate/api/event.py events_search: fetch top-k globally, then intersect
# with the structured filters via Peewee. # with the structured filters via Peewee.
visual_distances: Dict[str, float] = {} visual_distances: dict[str, float] = {}
description_distances: Dict[str, float] = {} description_distances: dict[str, float] = {}
try: try:
if similarity_mode in ("visual", "fused"): if similarity_mode in ("visual", "fused"):
@ -462,7 +462,7 @@ async def _execute_find_similar_objects(
eligible = {e.id: e for e in Event.select().where(reduce(operator.and_, clauses))} eligible = {e.id: e for e in Event.select().where(reduce(operator.and_, clauses))}
# 6. Fuse and rank. # 6. Fuse and rank.
scored: List[tuple[str, float]] = [] scored: list[tuple[str, float]] = []
for eid in eligible: for eid in eligible:
v_score = ( v_score = (
distance_to_score(visual_distances[eid], context.thumb_stats) distance_to_score(visual_distances[eid], context.thumb_stats)
@ -503,7 +503,7 @@ async def _execute_find_similar_objects(
async def execute_tool( async def execute_tool(
request: Request, request: Request,
body: ToolExecuteRequest = Body(...), body: ToolExecuteRequest = Body(...),
allowed_cameras: List[str] = Depends(get_allowed_cameras_for_filter), allowed_cameras: list[str] = Depends(get_allowed_cameras_for_filter),
) -> JSONResponse: ) -> JSONResponse:
""" """
Execute a tool function call. Execute a tool function call.
@ -545,8 +545,8 @@ async def execute_tool(
async def _execute_get_live_context( async def _execute_get_live_context(
request: Request, request: Request,
camera: str, camera: str,
allowed_cameras: List[str], allowed_cameras: list[str],
) -> Dict[str, Any]: ) -> dict[str, Any]:
# Reject wildcards explicitly so models retry with a real camera name # Reject wildcards explicitly so models retry with a real camera name
# instead of silently fanning out across every camera. # instead of silently fanning out across every camera.
if camera in ("*", "all"): if camera in ("*", "all"):
@ -593,7 +593,7 @@ async def _execute_get_live_context(
"stationary": obj_dict.get("stationary", False), "stationary": obj_dict.get("stationary", False),
} }
result: Dict[str, Any] = { result: dict[str, Any] = {
"camera": camera, "camera": camera,
"timestamp": frame_time, "timestamp": frame_time,
"detections": list(tracked_objects_dict.values()), "detections": list(tracked_objects_dict.values()),
@ -620,7 +620,7 @@ async def _execute_get_live_context(
async def _get_live_frame_image_url( async def _get_live_frame_image_url(
request: Request, request: Request,
camera: str, camera: str,
allowed_cameras: List[str], allowed_cameras: list[str],
) -> Optional[str]: ) -> Optional[str]:
""" """
Fetch the current live frame for a camera as a base64 data URL. Fetch the current live frame for a camera as a base64 data URL.
@ -659,8 +659,8 @@ async def _get_live_frame_image_url(
async def _execute_set_camera_state( async def _execute_set_camera_state(
request: Request, request: Request,
arguments: Dict[str, Any], arguments: dict[str, Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
role = request.headers.get("remote-role", "") role = request.headers.get("remote-role", "")
if "admin" not in [r.strip() for r in role.split(",")]: if "admin" not in [r.strip() for r in role.split(",")]:
return {"error": "Admin privileges required to change camera settings."} return {"error": "Admin privileges required to change camera settings."}
@ -699,10 +699,10 @@ async def _execute_set_camera_state(
async def _execute_tool_internal( async def _execute_tool_internal(
tool_name: str, tool_name: str,
arguments: Dict[str, Any], arguments: dict[str, Any],
request: Request, request: Request,
allowed_cameras: List[str], allowed_cameras: list[str],
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """
Internal helper to execute a tool and return the result as a dict. Internal helper to execute a tool and return the result as a dict.
@ -763,8 +763,8 @@ async def _execute_tool_internal(
async def _execute_start_camera_watch( async def _execute_start_camera_watch(
request: Request, request: Request,
arguments: Dict[str, Any], arguments: dict[str, Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
camera = arguments.get("camera", "").strip() camera = arguments.get("camera", "").strip()
condition = arguments.get("condition", "").strip() condition = arguments.get("condition", "").strip()
max_duration_minutes = int(arguments.get("max_duration_minutes", 60)) max_duration_minutes = int(arguments.get("max_duration_minutes", 60))
@ -814,14 +814,14 @@ async def _execute_start_camera_watch(
} }
def _execute_stop_camera_watch() -> Dict[str, Any]: def _execute_stop_camera_watch() -> dict[str, Any]:
cancelled = stop_vlm_watch_job() cancelled = stop_vlm_watch_job()
if cancelled: if cancelled:
return {"success": True, "message": "Watch job cancelled."} return {"success": True, "message": "Watch job cancelled."}
return {"success": False, "message": "No active watch job to cancel."} return {"success": False, "message": "No active watch job to cancel."}
def _execute_get_profile_status(request: Request) -> Dict[str, Any]: def _execute_get_profile_status(request: Request) -> dict[str, Any]:
"""Return profile status including active profile and activation timestamps.""" """Return profile status including active profile and activation timestamps."""
profile_manager = getattr(request.app, "profile_manager", None) profile_manager = getattr(request.app, "profile_manager", None)
if profile_manager is None: if profile_manager is None:
@ -846,9 +846,9 @@ def _execute_get_profile_status(request: Request) -> Dict[str, Any]:
def _execute_get_recap( def _execute_get_recap(
arguments: Dict[str, Any], arguments: dict[str, Any],
allowed_cameras: List[str], allowed_cameras: list[str],
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Fetch review segments with GenAI metadata for a time period.""" """Fetch review segments with GenAI metadata for a time period."""
from functools import reduce from functools import reduce
@ -909,7 +909,7 @@ def _execute_get_recap(
.iterator() .iterator()
) )
events: List[Dict[str, Any]] = [] events: list[dict[str, Any]] = []
for row in rows: for row in rows:
data = row.get("data") or {} data = row.get("data") or {}
@ -920,7 +920,7 @@ def _execute_get_recap(
data = {} data = {}
camera = row["camera"] camera = row["camera"]
event: Dict[str, Any] = { event: dict[str, Any] = {
"camera": camera.replace("_", " ").title(), "camera": camera.replace("_", " ").title(),
"severity": row.get("severity", "detection"), "severity": row.get("severity", "detection"),
} }
@ -984,10 +984,10 @@ def _execute_get_recap(
async def _execute_pending_tools( async def _execute_pending_tools(
pending_tool_calls: List[Dict[str, Any]], pending_tool_calls: list[dict[str, Any]],
request: Request, request: Request,
allowed_cameras: List[str], allowed_cameras: list[str],
) -> tuple[List[ToolCall], List[Dict[str, Any]], List[Dict[str, Any]]]: ) -> tuple[list[ToolCall], list[dict[str, Any]], list[dict[str, Any]]]:
""" """
Execute a list of tool calls. Execute a list of tool calls.
@ -996,9 +996,9 @@ async def _execute_pending_tools(
tool result dicts for conversation, tool result dicts for conversation,
extra messages to inject after tool results e.g. user messages with images) extra messages to inject after tool results e.g. user messages with images)
""" """
tool_calls_out: List[ToolCall] = [] tool_calls_out: list[ToolCall] = []
tool_results: List[Dict[str, Any]] = [] tool_results: list[dict[str, Any]] = []
extra_messages: List[Dict[str, Any]] = [] extra_messages: list[dict[str, Any]] = []
for tool_call in pending_tool_calls: for tool_call in pending_tool_calls:
tool_name = tool_call["name"] tool_name = tool_call["name"]
tool_args = tool_call.get("arguments") or {} tool_args = tool_call.get("arguments") or {}
@ -1106,7 +1106,7 @@ async def _execute_pending_tools(
async def chat_completion( async def chat_completion(
request: Request, request: Request,
body: ChatCompletionRequest = Body(...), body: ChatCompletionRequest = Body(...),
allowed_cameras: List[str] = Depends(get_allowed_cameras_for_filter), allowed_cameras: list[str] = Depends(get_allowed_cameras_for_filter),
): ):
""" """
Chat completion endpoint with tool calling support. Chat completion endpoint with tool calling support.
@ -1138,19 +1138,23 @@ async def chat_completion(
) )
conversation = [] conversation = []
system_prompt = build_chat_system_prompt( # Build the system message only when the client hasn't already pinned one.
config=config, # The first turn has no system message; we generate it (with the current
allowed_cameras=allowed_cameras, # timestamp) and return the whole chain so the client persists it. Later
semantic_search_enabled=semantic_search_enabled, # turns send it back verbatim, freezing the timestamp so the prompt prefix
attribute_classifications=attribute_classifications, # stays byte-identical and the model server's prompt cache keeps hitting.
) if not body.messages or body.messages[0].role != "system":
conversation.append(
conversation.append( {
{ "role": "system",
"role": "system", "content": build_chat_system_prompt(
"content": system_prompt, config=config,
} allowed_cameras=allowed_cameras,
) semantic_search_enabled=semantic_search_enabled,
attribute_classifications=attribute_classifications,
),
}
)
for msg in body.messages: for msg in body.messages:
msg_dict = { msg_dict = {
@ -1161,11 +1165,13 @@ async def chat_completion(
msg_dict["tool_call_id"] = msg.tool_call_id msg_dict["tool_call_id"] = msg.tool_call_id
if msg.name: if msg.name:
msg_dict["name"] = msg.name msg_dict["name"] = msg.name
if msg.tool_calls is not None:
msg_dict["tool_calls"] = msg.tool_calls
conversation.append(msg_dict) conversation.append(msg_dict)
tool_iterations = 0 tool_iterations = 0
tool_calls: List[ToolCall] = [] tool_calls: list[ToolCall] = []
max_iterations = body.max_tool_iterations max_iterations = body.max_tool_iterations
logger.debug( logger.debug(
@ -1175,11 +1181,20 @@ async def chat_completion(
# True LLM streaming when client supports it and stream requested # True LLM streaming when client supports it and stream requested
if body.stream and hasattr(genai_client, "chat_with_tools_stream"): if body.stream and hasattr(genai_client, "chat_with_tools_stream"):
stream_tool_calls: List[ToolCall] = []
stream_iterations = 0 stream_iterations = 0
async def stream_body_llm(): async def stream_body_llm():
nonlocal conversation, stream_tool_calls, stream_iterations nonlocal conversation, stream_iterations
def _emit_chain(extra: Optional[list[dict[str, Any]]] = None):
# Return the full conversation (including the system message) so
# the client persists and replays it verbatim next turn.
chain = conversation + (extra or [])
return (
json.dumps({"type": "messages", "messages": chain}).encode("utf-8")
+ b"\n"
)
while stream_iterations < max_iterations: while stream_iterations < max_iterations:
if await request.is_disconnected(): if await request.is_disconnected():
logger.debug("Client disconnected, stopping chat stream") logger.debug("Client disconnected, stopping chat stream")
@ -1244,31 +1259,33 @@ async def chat_completion(
) )
return return
( (
executed_calls, _executed_calls,
tool_results, tool_results,
extra_msgs, extra_msgs,
) = await _execute_pending_tools( ) = await _execute_pending_tools(
pending, request, allowed_cameras pending, request, allowed_cameras
) )
stream_tool_calls.extend(executed_calls)
conversation.extend(tool_results) conversation.extend(tool_results)
conversation.extend(extra_msgs) conversation.extend(extra_msgs)
yield ( # Emit the running chain so the client can render tool
json.dumps( # calls live and replay them verbatim next turn.
{ yield _emit_chain()
"type": "tool_calls",
"tool_calls": [
tc.model_dump() for tc in stream_tool_calls
],
}
).encode("utf-8")
+ b"\n"
)
break break
else: else:
# Streaming never appends the final assistant message
# to the conversation, so add it to the chain.
yield _emit_chain(
extra=[
{
"role": "assistant",
"content": msg.get("content"),
}
]
)
yield (json.dumps({"type": "done"}).encode("utf-8") + b"\n") yield (json.dumps({"type": "done"}).encode("utf-8") + b"\n")
return return
else: else:
yield _emit_chain()
yield json.dumps({"type": "done"}).encode("utf-8") + b"\n" yield json.dumps({"type": "done"}).encode("utf-8") + b"\n"
return StreamingResponse( return StreamingResponse(
@ -1315,19 +1332,15 @@ async def chat_completion(
if body.stream: if body.stream:
final_reasoning = response.get("reasoning") final_reasoning = response.get("reasoning")
chain = list(conversation)
async def stream_body() -> Any: async def stream_body() -> Any:
if tool_calls: yield (
yield ( json.dumps({"type": "messages", "messages": chain}).encode(
json.dumps( "utf-8"
{
"type": "tool_calls",
"tool_calls": [
tc.model_dump() for tc in tool_calls
],
}
).encode("utf-8")
+ b"\n"
) )
+ b"\n"
)
# Emit the full reasoning trace up front when the # Emit the full reasoning trace up front when the
# underlying client did not stream it # underlying client did not stream it
if final_reasoning: if final_reasoning:
@ -1363,6 +1376,7 @@ async def chat_completion(
finish_reason=response.get("finish_reason", "stop"), finish_reason=response.get("finish_reason", "stop"),
tool_iterations=tool_iterations, tool_iterations=tool_iterations,
tool_calls=tool_calls, tool_calls=tool_calls,
messages=list(conversation),
).model_dump(), ).model_dump(),
) )
@ -1395,6 +1409,7 @@ async def chat_completion(
finish_reason="length", finish_reason="length",
tool_iterations=tool_iterations, tool_iterations=tool_iterations,
tool_calls=tool_calls, tool_calls=tool_calls,
messages=list(conversation),
).model_dump(), ).model_dump(),
) )

View File

@ -1,6 +1,6 @@
"""Chat API request models.""" """Chat API request models."""
from typing import Optional from typing import Any, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -11,13 +11,29 @@ class ChatMessage(BaseModel):
role: str = Field( role: str = Field(
description="Message role: 'user', 'assistant', 'system', or 'tool'" description="Message role: 'user', 'assistant', 'system', or 'tool'"
) )
content: str = Field(description="Message content") content: Optional[Any] = Field(
default=None,
description=(
"Message content. Usually a string, but may be a multimodal content "
"list (e.g. text + image_url) or null for assistant turns that only "
"request tool calls."
),
)
tool_call_id: Optional[str] = Field( tool_call_id: Optional[str] = Field(
default=None, description="For tool messages, the ID of the tool call" default=None, description="For tool messages, the ID of the tool call"
) )
name: Optional[str] = Field( name: Optional[str] = Field(
default=None, description="For tool messages, the tool name" default=None, description="For tool messages, the tool name"
) )
tool_calls: Optional[list[dict[str, Any]]] = Field(
default=None,
description=(
"For assistant messages replayed from prior turns, the OpenAI-format "
"tool calls the model previously requested. Replaying these verbatim "
"keeps the conversation prefix byte-for-byte identical so the model "
"server's prompt cache hits on follow-up turns."
),
)
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):

View File

@ -56,3 +56,12 @@ class ChatCompletionResponse(BaseModel):
default_factory=list, default_factory=list,
description="List of tool calls that were executed during this completion", description="List of tool calls that were executed during this completion",
) )
messages: list[dict[str, Any]] = Field(
default_factory=list,
description=(
"The full conversation chain, including the system message. Persist "
"and replay this verbatim on the next request so the prompt prefix "
"stays byte-identical and the model server's prompt cache keeps "
"hitting."
),
)

4
web/.gitignore vendored
View File

@ -12,6 +12,10 @@ dist
dist-ssr dist-ssr
*.local *.local
# Playwright
playwright-report
test-results
# Editor directories and files # Editor directories and files
.vscode/* .vscode/*
!.vscode/extensions.json !.vscode/extensions.json

View File

@ -92,6 +92,15 @@ test.describe("Chat — streaming @medium", () => {
await installChatStreamOverride(frigateApp, [ await installChatStreamOverride(frigateApp, [
{ type: "content", delta: "Hel" }, { type: "content", delta: "Hel" },
{ type: "content", delta: "lo" }, { type: "content", delta: "lo" },
{
type: "messages",
messages: [
{ role: "system", content: "sys" },
{ role: "user", content: "hello chat" },
{ role: "assistant", content: "Hello" },
],
},
{ type: "done" },
]); ]);
await frigateApp.goto("/chat"); await frigateApp.goto("/chat");
const input = frigateApp.page.getByPlaceholder(/ask/i); const input = frigateApp.page.getByPlaceholder(/ask/i);
@ -137,6 +146,15 @@ test.describe("Chat — streaming @medium", () => {
{ type: "content", delta: "Hel" }, { type: "content", delta: "Hel" },
{ type: "content", delta: "lo, " }, { type: "content", delta: "lo, " },
{ type: "content", delta: "world!" }, { type: "content", delta: "world!" },
{
type: "messages",
messages: [
{ role: "system", content: "sys" },
{ role: "user", content: "greet me" },
{ role: "assistant", content: "Hello, world!" },
],
},
{ type: "done" },
], ],
{ chunkDelayMs: 50 }, { chunkDelayMs: 50 },
); );
@ -151,19 +169,39 @@ test.describe("Chat — streaming @medium", () => {
}); });
}); });
test("tool_calls chunks render a ToolCallsGroup", async ({ frigateApp }) => { test("tool calls in the chain render a ToolCallsGroup", async ({
await installChatStreamOverride(frigateApp, [ frigateApp,
}) => {
const toolTurn = [
{ role: "system", content: "sys" },
{ role: "user", content: "find people" },
{ {
type: "tool_calls", role: "assistant",
content: null,
tool_calls: [ tool_calls: [
{ {
id: "call_1", id: "call_1",
name: "search_objects", type: "function",
arguments: { label: "person" }, function: {
name: "search_objects",
arguments: '{"label":"person"}',
},
}, },
], ],
}, },
{ role: "tool", tool_call_id: "call_1", content: "[]" },
];
await installChatStreamOverride(frigateApp, [
{ type: "messages", messages: toolTurn },
{ type: "content", delta: "Searching for people." }, { type: "content", delta: "Searching for people." },
{
type: "messages",
messages: [
...toolTurn,
{ role: "assistant", content: "Searching for people." },
],
},
{ type: "done" },
]); ]);
await frigateApp.goto("/chat"); await frigateApp.goto("/chat");
const input = frigateApp.page.getByPlaceholder(/ask/i); const input = frigateApp.page.getByPlaceholder(/ask/i);
@ -253,6 +291,15 @@ test.describe("Chat — attachment chip @medium", () => {
// We use the stream override so the first message completes quickly. // We use the stream override so the first message completes quickly.
await installChatStreamOverride(frigateApp, [ await installChatStreamOverride(frigateApp, [
{ type: "content", delta: "Done." }, { type: "content", delta: "Done." },
{
type: "messages",
messages: [
{ role: "system", content: "sys" },
{ role: "user", content: "hello" },
{ role: "assistant", content: "Done." },
],
},
{ type: "done" },
]); ]);
await frigateApp.goto("/chat"); await frigateApp.goto("/chat");

View File

@ -13,6 +13,7 @@ import { ChatComposer } from "@/components/chat/ChatComposer";
import ChatSettings from "@/components/chat/ChatSettings"; import ChatSettings from "@/components/chat/ChatSettings";
import type { import type {
ChatMessage, ChatMessage,
ChatStats,
GenAIModelsResponse, GenAIModelsResponse,
ShowStatsMode, ShowStatsMode,
} from "@/types/chat"; } from "@/types/chat";
@ -22,12 +23,28 @@ import {
getFindSimilarObjectsFromToolCalls, getFindSimilarObjectsFromToolCalls,
prependAttachment, prependAttachment,
streamChatCompletion, streamChatCompletion,
toolCallsForMessage,
toolResponsesById,
} from "@/utils/chatUtil"; } from "@/utils/chatUtil";
type StreamingTurn = {
content: string;
reasoning: string;
chain: ChatMessage[];
stats?: ChatStats;
};
const hasText = (content: unknown): content is string =>
typeof content === "string" && content.trim().length > 0;
const toWire = (messages: ChatMessage[]): ChatMessage[] =>
messages.map(({ reasoning: _r, stats: _s, ...rest }) => rest);
export default function ChatPage() { export default function ChatPage() {
const { t } = useTranslation(["views/chat"]); const { t } = useTranslation(["views/chat"]);
const [input, setInput] = useState(""); const [input, setInput] = useState("");
const [messages, setMessages] = useState<ChatMessage[]>([]); const [messages, setMessages] = useState<ChatMessage[]>([]);
const [streaming, setStreaming] = useState<StreamingTurn | null>(null);
const [isLoading, setIsLoading] = useState(false); const [isLoading, setIsLoading] = useState(false);
const [error, setError] = useState<string | null>(null); const [error, setError] = useState<string | null>(null);
const [attachedEventId, setAttachedEventId] = useState<string | null>(null); const [attachedEventId, setAttachedEventId] = useState<string | null>(null);
@ -72,28 +89,19 @@ export default function ChatPage() {
if (isNearBottom) { if (isNearBottom) {
el.scrollTo({ top: el.scrollHeight, behavior: "smooth" }); el.scrollTo({ top: el.scrollHeight, behavior: "smooth" });
} }
}, [messages, autoScroll]); }, [messages, streaming, autoScroll]);
const submitConversation = useCallback( const submitConversation = useCallback(
async (messagesToSend: ChatMessage[]) => { async (messagesToSend: ChatMessage[]) => {
if (isLoading) return; if (isLoading) return;
const last = messagesToSend[messagesToSend.length - 1]; const last = messagesToSend[messagesToSend.length - 1];
if (!last || last.role !== "user" || !last.content.trim()) return; if (!last || last.role !== "user" || !hasText(last.content)) return;
setError(null); setError(null);
const assistantPlaceholder: ChatMessage = { setMessages(messagesToSend);
role: "assistant", setStreaming({ content: "", reasoning: "", chain: [] });
content: "",
toolCalls: undefined,
};
setMessages([...messagesToSend, assistantPlaceholder]);
setIsLoading(true); setIsLoading(true);
const apiMessages = messagesToSend.map((m) => ({
role: m.role,
content: m.content,
}));
const baseURL = axios.defaults.baseURL ?? ""; const baseURL = axios.defaults.baseURL ?? "";
const url = `${baseURL}chat/completion`; const url = `${baseURL}chat/completion`;
const headers: Record<string, string> = { const headers: Record<string, string> = {
@ -104,16 +112,50 @@ export default function ChatPage() {
const controller = new AbortController(); const controller = new AbortController();
abortRef.current = controller; abortRef.current = controller;
let chain: ChatMessage[] = [];
let stats: ChatStats | undefined;
let reasoning = "";
let hadError = false;
await streamChatCompletion( await streamChatCompletion(
url, url,
headers, headers,
apiMessages, toWire(messagesToSend),
{ {
updateMessages: (updater) => setMessages(updater), onContentDelta: (delta) =>
onError: (message) => setError(message), setStreaming((s) => (s ? { ...s, content: s.content + delta } : s)),
onReasoningDelta: (delta) => {
reasoning += delta;
setStreaming((s) =>
s ? { ...s, reasoning: s.reasoning + delta } : s,
);
},
onChain: (fullChain) => {
chain = fullChain;
setStreaming((s) => (s ? { ...s, chain: fullChain } : s));
},
onStats: (s) => {
stats = s;
setStreaming((cur) => (cur ? { ...cur, stats: s } : cur));
},
onError: (message) => {
hadError = true;
setError(message);
},
onDone: () => { onDone: () => {
abortRef.current = null; abortRef.current = null;
setIsLoading(false); setIsLoading(false);
setStreaming(null);
const lastMsg = chain[chain.length - 1];
if (!hadError && lastMsg?.role === "assistant") {
setMessages(
chain.map((m, i) =>
i === chain.length - 1
? { ...m, reasoning: reasoning || undefined, stats }
: m,
),
);
}
}, },
defaultErrorMessage: t("error"), defaultErrorMessage: t("error"),
}, },
@ -125,12 +167,14 @@ export default function ChatPage() {
); );
const recentEventIds = useMemo(() => { const recentEventIds = useMemo(() => {
const responses = toolResponsesById(messages);
for (let i = messages.length - 1; i >= 0; i--) { for (let i = messages.length - 1; i >= 0; i--) {
const msg = messages[i]; const msg = messages[i];
if (msg.role !== "assistant" || !msg.toolCalls) continue; if (msg.role !== "assistant" || !msg.tool_calls?.length) continue;
const similar = getFindSimilarObjectsFromToolCalls(msg.toolCalls); const calls = toolCallsForMessage(msg, responses);
const similar = getFindSimilarObjectsFromToolCalls(calls);
if (similar) return similar.results.map((e) => e.id); if (similar) return similar.results.map((e) => e.id);
const events = getEventIdsFromSearchObjectsToolCalls(msg.toolCalls); const events = getEventIdsFromSearchObjectsToolCalls(calls);
if (events.length > 0) return events.map((e) => e.id); if (events.length > 0) return events.map((e) => e.id);
} }
return []; return [];
@ -154,12 +198,14 @@ export default function ChatPage() {
abortRef.current?.abort(); abortRef.current?.abort();
abortRef.current = null; abortRef.current = null;
setIsLoading(false); setIsLoading(false);
setStreaming(null);
}, []); }, []);
const startNewChat = useCallback(() => { const startNewChat = useCallback(() => {
abortRef.current?.abort(); abortRef.current?.abort();
abortRef.current = null; abortRef.current = null;
setIsLoading(false); setIsLoading(false);
setStreaming(null);
setMessages([]); setMessages([]);
setInput(""); setInput("");
setAttachedEventId(null); setAttachedEventId(null);
@ -181,7 +227,83 @@ export default function ChatPage() {
setAttachedEventId(null); setAttachedEventId(null);
}, []); }, []);
const hasStarted = messages.length > 0; const hasStarted = messages.length > 0 || streaming != null;
// While streaming, the backend's in-flight chain is the source of truth;
// otherwise the committed conversation is.
const renderList =
streaming && streaming.chain.length ? streaming.chain : messages;
const responses = toolResponsesById(renderList);
const renderTail = renderList[renderList.length - 1];
const finalShown =
renderTail?.role === "assistant" && hasText(renderTail.content);
const renderMessage = (msg: ChatMessage, i: number) => {
if (msg.role === "system" || msg.role === "tool") return null;
if (msg.role === "user") {
if (!hasText(msg.content)) return null;
return (
<div key={i} className="flex flex-col gap-2">
<MessageBubble
role="user"
content={msg.content}
messageIndex={i}
onEditSubmit={handleEditSubmit}
isComplete
showStats={showStats}
/>
</div>
);
}
const calls = toolCallsForMessage(msg, responses);
const contentText = hasText(msg.content) ? msg.content : "";
const similar = getFindSimilarObjectsFromToolCalls(calls);
const events = similar ? [] : getEventIdsFromSearchObjectsToolCalls(calls);
return (
<div key={i} className="flex flex-col gap-2">
{calls.length > 0 && <ToolCallsGroup toolCalls={calls} />}
{hasText(msg.reasoning) && (
<ReasoningBubble
reasoning={msg.reasoning}
answerStarted={!!contentText}
/>
)}
{contentText && (
<MessageBubble
role="assistant"
content={contentText}
messageIndex={i}
isComplete
stats={msg.stats}
showStats={showStats}
/>
)}
{similar ? (
<ChatEventThumbnailsRow
events={similar.results}
anchor={similar.anchor}
onAttach={setAttachedEventId}
/>
) : (
<ChatEventThumbnailsRow
events={events}
onAttach={setAttachedEventId}
/>
)}
</div>
);
};
const processingDots = (
<div className="flex items-center gap-2 self-start rounded-2xl bg-muted px-5 py-4">
<span className="size-2.5 animate-bounce rounded-full bg-muted-foreground/60 [animation-delay:-0.32s]" />
<span className="size-2.5 animate-bounce rounded-full bg-muted-foreground/60 [animation-delay:-0.16s]" />
<span className="size-2.5 animate-bounce rounded-full bg-muted-foreground/60" />
</div>
);
return ( return (
<div className="flex size-full flex-col"> <div className="flex size-full flex-col">
@ -212,102 +334,31 @@ export default function ChatPage() {
<div className="flex w-full flex-col xl:w-[50%] 3xl:w-[35%]"> <div className="flex w-full flex-col xl:w-[50%] 3xl:w-[35%]">
{hasStarted ? ( {hasStarted ? (
<div className="flex w-full flex-1 flex-col gap-3 pb-3"> <div className="flex w-full flex-1 flex-col gap-3 pb-3">
{messages.map((msg, i) => { {renderList.map((msg, i) => renderMessage(msg, i))}
const isLastAssistant = {streaming &&
i === messages.length - 1 && msg.role === "assistant"; !finalShown &&
const isComplete = (streaming.content || streaming.reasoning ? (
msg.role === "user" || !isLoading || !isLastAssistant; <div className="flex flex-col gap-2">
const hasToolCalls = {hasText(streaming.reasoning) && (
msg.toolCalls && msg.toolCalls.length > 0;
const hasContent = !!msg.content?.trim();
const hasReasoning = !!msg.reasoning?.trim();
const showProcessing =
isLastAssistant &&
isLoading &&
!hasContent &&
!hasReasoning;
// Hide empty placeholder only when there are no tool calls
// and no reasoning streaming yet
if (
isLastAssistant &&
isLoading &&
!hasContent &&
!hasToolCalls &&
!hasReasoning
)
return (
<div
key={i}
className="flex items-center gap-2 self-start rounded-2xl bg-muted px-5 py-4"
>
<span className="size-2.5 animate-bounce rounded-full bg-muted-foreground/60 [animation-delay:-0.32s]" />
<span className="size-2.5 animate-bounce rounded-full bg-muted-foreground/60 [animation-delay:-0.16s]" />
<span className="size-2.5 animate-bounce rounded-full bg-muted-foreground/60" />
</div>
);
return (
<div key={i} className="flex flex-col gap-2">
{msg.role === "assistant" && hasToolCalls && (
<ToolCallsGroup toolCalls={msg.toolCalls!} />
)}
{msg.role === "assistant" && hasReasoning && (
<ReasoningBubble <ReasoningBubble
reasoning={msg.reasoning!} reasoning={streaming.reasoning}
answerStarted={hasContent} answerStarted={!!streaming.content}
/> />
)} )}
{showProcessing ? ( {streaming.content && (
<div className="flex items-center gap-2 self-start rounded-2xl bg-muted px-5 py-4">
<span className="size-2 animate-bounce rounded-full bg-muted-foreground/60 [animation-delay:-0.3s]" />
<span className="size-2 animate-bounce rounded-full bg-muted-foreground/60 [animation-delay:-0.15s]" />
<span className="size-2 animate-bounce rounded-full bg-muted-foreground/60" />
</div>
) : msg.role === "assistant" &&
!hasContent &&
hasReasoning &&
!isComplete ? null : (
<MessageBubble <MessageBubble
role={msg.role} role="assistant"
content={msg.content} content={streaming.content}
messageIndex={i} messageIndex={-1}
onEditSubmit={ isComplete={false}
msg.role === "user" ? handleEditSubmit : undefined stats={streaming.stats}
}
isComplete={isComplete}
stats={msg.stats}
showStats={showStats} showStats={showStats}
/> />
)} )}
{msg.role === "assistant" &&
isComplete &&
(() => {
const similar = getFindSimilarObjectsFromToolCalls(
msg.toolCalls,
);
if (similar) {
return (
<ChatEventThumbnailsRow
events={similar.results}
anchor={similar.anchor}
onAttach={setAttachedEventId}
/>
);
}
const events = getEventIdsFromSearchObjectsToolCalls(
msg.toolCalls,
);
return (
<ChatEventThumbnailsRow
events={events}
onAttach={setAttachedEventId}
/>
);
})()}
</div> </div>
); ) : (
})} processingDots
))}
{error && ( {error && (
<p <p
className="flex items-center gap-1.5 self-start text-sm text-destructive" className="flex items-center gap-1.5 self-start text-sm text-destructive"

View File

@ -1,17 +1,30 @@
export type ToolCallFunction = {
name: string;
arguments: string;
};
export type WireToolCall = {
id: string;
type?: string;
function: ToolCallFunction;
};
export type ChatMessage = {
role: "system" | "user" | "assistant" | "tool";
content: unknown;
tool_call_id?: string;
name?: string;
tool_calls?: WireToolCall[];
reasoning?: string;
stats?: ChatStats;
};
export type ToolCall = { export type ToolCall = {
name: string; name: string;
arguments?: Record<string, unknown>; arguments?: Record<string, unknown>;
response?: string; response?: string;
}; };
export type ChatMessage = {
role: "user" | "assistant";
content: string;
reasoning?: string;
toolCalls?: ToolCall[];
stats?: ChatStats;
};
export type StartingRequest = { export type StartingRequest = {
label: string; label: string;
prompt: string; prompt: string;

View File

@ -1,16 +1,20 @@
import type { ChatMessage, ChatStats, ToolCall } from "@/types/chat"; import type { ChatMessage, ChatStats, ToolCall } from "@/types/chat";
export type StreamChatCallbacks = { export type StreamChatCallbacks = {
/** Update the messages array (e.g. pass to setState). */ /** Streamed delta of the assistant's final answer text. */
updateMessages: (updater: (prev: ChatMessage[]) => ChatMessage[]) => void; onContentDelta: (delta: string) => void;
/** Streamed delta of the assistant's reasoning trace. */
onReasoningDelta: (delta: string) => void;
/** The full conversation chain so far (system message, history, this turn's
* tool-call turns, tool results, and on the final emission the final
* assistant message). */
onChain: (chain: ChatMessage[]) => void;
/** Token/timing stats for the turn. */
onStats: (stats: ChatStats) => void;
/** Called when the stream sends an error or fetch fails. */ /** Called when the stream sends an error or fetch fails. */
onError: (message: string) => void; onError: (message: string) => void;
/** Called when the stream finishes (success or error). */ /** Called when the stream finishes (success or error). */
onDone: () => void; onDone: () => void;
/** Called when the stream emits token/timing stats. The stats are also
* attached to the last assistant message in updateMessages, so consumers
* can usually rely on the message itself rather than wiring this up. */
onStats?: (stats: ChatStats) => void;
/** Message used when fetch throws and no server error is available. */ /** Message used when fetch throws and no server error is available. */
defaultErrorMessage?: string; defaultErrorMessage?: string;
}; };
@ -25,7 +29,7 @@ type StatsChunk = {
type StreamChunk = type StreamChunk =
| { type: "error"; error: string } | { type: "error"; error: string }
| { type: "tool_calls"; tool_calls: ToolCall[] } | { type: "messages"; messages: ChatMessage[] }
| { type: "content"; delta: string } | { type: "content"; delta: string }
| { type: "reasoning"; delta: string } | { type: "reasoning"; delta: string }
| StatsChunk; | StatsChunk;
@ -41,16 +45,18 @@ export type StreamChatOptions = {
export async function streamChatCompletion( export async function streamChatCompletion(
url: string, url: string,
headers: Record<string, string>, headers: Record<string, string>,
apiMessages: { role: string; content: string }[], apiMessages: ChatMessage[],
callbacks: StreamChatCallbacks, callbacks: StreamChatCallbacks,
signal?: AbortSignal, signal?: AbortSignal,
options: StreamChatOptions = {}, options: StreamChatOptions = {},
): Promise<void> { ): Promise<void> {
const { const {
updateMessages, onContentDelta,
onReasoningDelta,
onChain,
onStats,
onError, onError,
onDone, onDone,
onStats,
defaultErrorMessage = "Something went wrong. Please try again.", defaultErrorMessage = "Something went wrong. Please try again.",
} = callbacks; } = callbacks;
@ -91,65 +97,27 @@ export async function streamChatCompletion(
const applyChunk = (data: StreamChunk) => { const applyChunk = (data: StreamChunk) => {
if (data.type === "error") { if (data.type === "error") {
onError(data.error); onError(data.error);
updateMessages((prev) =>
prev.filter((m) => !(m.role === "assistant" && m.content === "")),
);
return "break"; return "break";
} }
if (data.type === "tool_calls" && data.tool_calls?.length) { if (data.type === "messages") {
updateMessages((prev) => { onChain(data.messages ?? []);
const next = [...prev];
const lastMsg = next[next.length - 1];
if (lastMsg?.role === "assistant")
next[next.length - 1] = {
...lastMsg,
toolCalls: data.tool_calls,
};
return next;
});
return "continue"; return "continue";
} }
if (data.type === "content" && data.delta !== undefined) { if (data.type === "content" && data.delta !== undefined) {
updateMessages((prev) => { onContentDelta(data.delta);
const next = [...prev];
const lastMsg = next[next.length - 1];
if (lastMsg?.role === "assistant")
next[next.length - 1] = {
...lastMsg,
content: lastMsg.content + data.delta,
};
return next;
});
return "continue"; return "continue";
} }
if (data.type === "reasoning" && data.delta !== undefined) { if (data.type === "reasoning" && data.delta !== undefined) {
updateMessages((prev) => { onReasoningDelta(data.delta);
const next = [...prev];
const lastMsg = next[next.length - 1];
if (lastMsg?.role === "assistant")
next[next.length - 1] = {
...lastMsg,
reasoning: (lastMsg.reasoning ?? "") + data.delta,
};
return next;
});
return "continue"; return "continue";
} }
if (data.type === "stats") { if (data.type === "stats") {
const stats: ChatStats = { onStats({
promptTokens: data.prompt_tokens, promptTokens: data.prompt_tokens,
completionTokens: data.completion_tokens, completionTokens: data.completion_tokens,
completionDurationMs: data.completion_duration_ms, completionDurationMs: data.completion_duration_ms,
tokensPerSecond: data.tokens_per_second, tokensPerSecond: data.tokens_per_second,
};
updateMessages((prev) => {
const next = [...prev];
const lastMsg = next[next.length - 1];
if (lastMsg?.role === "assistant")
next[next.length - 1] = { ...lastMsg, stats };
return next;
}); });
onStats?.(stats);
return "continue"; return "continue";
} }
return "continue"; return "continue";
@ -165,9 +133,8 @@ export async function streamChatCompletion(
const trimmed = line.trim(); const trimmed = line.trim();
if (!trimmed) continue; if (!trimmed) continue;
try { try {
const data = JSON.parse(trimmed) as StreamChunk & { type: string }; const data = JSON.parse(trimmed) as StreamChunk;
const result = applyChunk(data as StreamChunk); if (applyChunk(data) === "break") {
if (result === "break") {
hadStreamError = true; hadStreamError = true;
break; break;
} }
@ -181,50 +148,63 @@ export async function streamChatCompletion(
// Flush remaining buffer // Flush remaining buffer
if (!hadStreamError && buffer.trim()) { if (!hadStreamError && buffer.trim()) {
try { try {
const data = JSON.parse(buffer.trim()) as StreamChunk & { const data = JSON.parse(buffer.trim()) as StreamChunk;
type: string; applyChunk(data);
delta?: string;
};
if (data.type === "content" && data.delta !== undefined) {
updateMessages((prev) => {
const next = [...prev];
const lastMsg = next[next.length - 1];
if (lastMsg?.role === "assistant")
next[next.length - 1] = {
...lastMsg,
content: lastMsg.content + data.delta!,
};
return next;
});
}
} catch { } catch {
// ignore final malformed chunk // ignore final malformed chunk
} }
} }
if (!hadStreamError) {
updateMessages((prev) => {
const next = [...prev];
const lastMsg = next[next.length - 1];
if (lastMsg?.role === "assistant" && lastMsg.content === "")
next[next.length - 1] = { ...lastMsg, content: " " };
return next;
});
}
} catch (err) { } catch (err) {
if (err instanceof DOMException && err.name === "AbortError") { if (err instanceof DOMException && err.name === "AbortError") {
// User stopped generation — not an error // User stopped generation — not an error
} else { } else {
onError(defaultErrorMessage); onError(defaultErrorMessage);
updateMessages((prev) =>
prev.filter((m) => !(m.role === "assistant" && m.content === "")),
);
} }
} finally { } finally {
onDone(); onDone();
} }
} }
/** Map each tool result message to its tool_call_id for response lookup. */
export function toolResponsesById(
messages: ChatMessage[],
): Map<string, string> {
const map = new Map<string, string>();
for (const m of messages) {
if (m.role === "tool" && typeof m.tool_call_id === "string") {
map.set(
m.tool_call_id,
typeof m.content === "string" ? m.content : JSON.stringify(m.content),
);
}
}
return map;
}
/** Derive the display tool calls for one assistant message. */
export function toolCallsForMessage(
message: ChatMessage,
responses: Map<string, string>,
): ToolCall[] {
if (!message.tool_calls?.length) return [];
return message.tool_calls.map((tc) => {
let args: Record<string, unknown> | undefined;
const raw = tc.function?.arguments;
if (typeof raw === "string") {
try {
args = JSON.parse(raw) as Record<string, unknown>;
} catch {
args = undefined;
}
}
return {
name: tc.function?.name ?? "",
arguments: args,
response: responses.get(tc.id),
};
});
}
/** /**
* Parse search_objects tool call response(s) into event ids for thumbnails. * Parse search_objects tool call response(s) into event ids for thumbnails.
*/ */