diff --git a/frigate/api/chat.py b/frigate/api/chat.py index a54af2ac66..2e6283a525 100644 --- a/frigate/api/chat.py +++ b/frigate/api/chat.py @@ -7,7 +7,7 @@ import operator import time from datetime import datetime from functools import reduce -from typing import Any, Dict, List, Optional +from typing import Any, Optional import cv2 from fastapi import APIRouter, Body, Depends, HTTPException, Request @@ -59,7 +59,7 @@ class ToolExecuteRequest(BaseModel): """Request model for tool execution.""" tool_name: str - arguments: Dict[str, Any] + arguments: dict[str, Any] class VLMMonitorRequest(BaseModel): @@ -68,8 +68,8 @@ class VLMMonitorRequest(BaseModel): camera: str condition: str max_duration_minutes: int = 60 - labels: List[str] = [] - zones: List[str] = [] + labels: list[str] = [] + zones: list[str] = [] @router.get( @@ -91,10 +91,10 @@ def get_tools(request: Request) -> JSONResponse: def _resolve_zones( - zones: List[str], + zones: list[str], config: FrigateConfig, - target_cameras: List[str], -) -> List[str]: + target_cameras: list[str], +) -> list[str]: """Map zone names to their canonical config keys, case-insensitively. LLMs frequently echo a user's casing ("Front Yard") instead of the @@ -107,7 +107,7 @@ def _resolve_zones( if not zones: return zones - lookup: Dict[str, str] = {} + lookup: dict[str, str] = {} for camera_id in target_cameras: camera_config = config.cameras.get(camera_id) if camera_config is None: @@ -120,8 +120,8 @@ def _resolve_zones( async def _execute_search_objects( request: Request, - arguments: Dict[str, Any], - allowed_cameras: List[str], + arguments: dict[str, Any], + allowed_cameras: list[str], ) -> JSONResponse: """ Execute the search_objects tool. @@ -213,8 +213,8 @@ async def _execute_search_objects( async def _execute_search_objects_semantic( request: Request, - arguments: Dict[str, Any], - allowed_cameras: List[str], + arguments: dict[str, Any], + allowed_cameras: list[str], semantic_query: str, ) -> JSONResponse: """Search objects via fused thumbnail + description embeddings. @@ -263,8 +263,8 @@ async def _execute_search_objects_semantic( limit = int(arguments.get("limit", 25)) limit = max(1, min(limit, 100)) - visual_distances: Dict[str, float] = {} - description_distances: Dict[str, float] = {} + visual_distances: dict[str, float] = {} + description_distances: dict[str, float] = {} try: rows = context.search_thumbnail(semantic_query) visual_distances = {row[0]: row[1] for row in rows} @@ -305,7 +305,7 @@ async def _execute_search_objects_semantic( eligible = {e.id: e for e in Event.select().where(reduce(operator.and_, clauses))} - scored: List[tuple[str, float]] = [] + scored: list[tuple[str, float]] = [] for eid in eligible: v_score = ( distance_to_score(visual_distances[eid], context.thumb_stats) @@ -331,9 +331,9 @@ async def _execute_search_objects_semantic( async def _execute_find_similar_objects( request: Request, - arguments: Dict[str, Any], - allowed_cameras: List[str], -) -> Dict[str, Any]: + arguments: dict[str, Any], + allowed_cameras: list[str], +) -> dict[str, Any]: """Execute the find_similar_objects tool. Returns a plain dict (not JSONResponse) so the chat loop can embed it @@ -403,8 +403,8 @@ async def _execute_find_similar_objects( # version (see frigate/embeddings/__init__.py). Mirror the pattern used by # frigate/api/event.py events_search: fetch top-k globally, then intersect # with the structured filters via Peewee. - visual_distances: Dict[str, float] = {} - description_distances: Dict[str, float] = {} + visual_distances: dict[str, float] = {} + description_distances: dict[str, float] = {} try: if similarity_mode in ("visual", "fused"): @@ -462,7 +462,7 @@ async def _execute_find_similar_objects( eligible = {e.id: e for e in Event.select().where(reduce(operator.and_, clauses))} # 6. Fuse and rank. - scored: List[tuple[str, float]] = [] + scored: list[tuple[str, float]] = [] for eid in eligible: v_score = ( distance_to_score(visual_distances[eid], context.thumb_stats) @@ -503,7 +503,7 @@ async def _execute_find_similar_objects( async def execute_tool( request: Request, body: ToolExecuteRequest = Body(...), - allowed_cameras: List[str] = Depends(get_allowed_cameras_for_filter), + allowed_cameras: list[str] = Depends(get_allowed_cameras_for_filter), ) -> JSONResponse: """ Execute a tool function call. @@ -545,8 +545,8 @@ async def execute_tool( async def _execute_get_live_context( request: Request, camera: str, - allowed_cameras: List[str], -) -> Dict[str, Any]: + allowed_cameras: list[str], +) -> dict[str, Any]: # Reject wildcards explicitly so models retry with a real camera name # instead of silently fanning out across every camera. if camera in ("*", "all"): @@ -593,7 +593,7 @@ async def _execute_get_live_context( "stationary": obj_dict.get("stationary", False), } - result: Dict[str, Any] = { + result: dict[str, Any] = { "camera": camera, "timestamp": frame_time, "detections": list(tracked_objects_dict.values()), @@ -620,7 +620,7 @@ async def _execute_get_live_context( async def _get_live_frame_image_url( request: Request, camera: str, - allowed_cameras: List[str], + allowed_cameras: list[str], ) -> Optional[str]: """ Fetch the current live frame for a camera as a base64 data URL. @@ -659,8 +659,8 @@ async def _get_live_frame_image_url( async def _execute_set_camera_state( request: Request, - arguments: Dict[str, Any], -) -> Dict[str, Any]: + arguments: dict[str, Any], +) -> dict[str, Any]: role = request.headers.get("remote-role", "") if "admin" not in [r.strip() for r in role.split(",")]: return {"error": "Admin privileges required to change camera settings."} @@ -699,10 +699,10 @@ async def _execute_set_camera_state( async def _execute_tool_internal( tool_name: str, - arguments: Dict[str, Any], + arguments: dict[str, Any], request: Request, - allowed_cameras: List[str], -) -> Dict[str, Any]: + allowed_cameras: list[str], +) -> dict[str, Any]: """ Internal helper to execute a tool and return the result as a dict. @@ -763,8 +763,8 @@ async def _execute_tool_internal( async def _execute_start_camera_watch( request: Request, - arguments: Dict[str, Any], -) -> Dict[str, Any]: + arguments: dict[str, Any], +) -> dict[str, Any]: camera = arguments.get("camera", "").strip() condition = arguments.get("condition", "").strip() max_duration_minutes = int(arguments.get("max_duration_minutes", 60)) @@ -814,14 +814,14 @@ async def _execute_start_camera_watch( } -def _execute_stop_camera_watch() -> Dict[str, Any]: +def _execute_stop_camera_watch() -> dict[str, Any]: cancelled = stop_vlm_watch_job() if cancelled: return {"success": True, "message": "Watch job cancelled."} return {"success": False, "message": "No active watch job to cancel."} -def _execute_get_profile_status(request: Request) -> Dict[str, Any]: +def _execute_get_profile_status(request: Request) -> dict[str, Any]: """Return profile status including active profile and activation timestamps.""" profile_manager = getattr(request.app, "profile_manager", None) if profile_manager is None: @@ -846,9 +846,9 @@ def _execute_get_profile_status(request: Request) -> Dict[str, Any]: def _execute_get_recap( - arguments: Dict[str, Any], - allowed_cameras: List[str], -) -> Dict[str, Any]: + arguments: dict[str, Any], + allowed_cameras: list[str], +) -> dict[str, Any]: """Fetch review segments with GenAI metadata for a time period.""" from functools import reduce @@ -909,7 +909,7 @@ def _execute_get_recap( .iterator() ) - events: List[Dict[str, Any]] = [] + events: list[dict[str, Any]] = [] for row in rows: data = row.get("data") or {} @@ -920,7 +920,7 @@ def _execute_get_recap( data = {} camera = row["camera"] - event: Dict[str, Any] = { + event: dict[str, Any] = { "camera": camera.replace("_", " ").title(), "severity": row.get("severity", "detection"), } @@ -984,10 +984,10 @@ def _execute_get_recap( async def _execute_pending_tools( - pending_tool_calls: List[Dict[str, Any]], + pending_tool_calls: list[dict[str, Any]], request: Request, - allowed_cameras: List[str], -) -> tuple[List[ToolCall], List[Dict[str, Any]], List[Dict[str, Any]]]: + allowed_cameras: list[str], +) -> tuple[list[ToolCall], list[dict[str, Any]], list[dict[str, Any]]]: """ Execute a list of tool calls. @@ -996,9 +996,9 @@ async def _execute_pending_tools( tool result dicts for conversation, extra messages to inject after tool results — e.g. user messages with images) """ - tool_calls_out: List[ToolCall] = [] - tool_results: List[Dict[str, Any]] = [] - extra_messages: List[Dict[str, Any]] = [] + tool_calls_out: list[ToolCall] = [] + tool_results: list[dict[str, Any]] = [] + extra_messages: list[dict[str, Any]] = [] for tool_call in pending_tool_calls: tool_name = tool_call["name"] tool_args = tool_call.get("arguments") or {} @@ -1106,7 +1106,7 @@ async def _execute_pending_tools( async def chat_completion( request: Request, body: ChatCompletionRequest = Body(...), - allowed_cameras: List[str] = Depends(get_allowed_cameras_for_filter), + allowed_cameras: list[str] = Depends(get_allowed_cameras_for_filter), ): """ Chat completion endpoint with tool calling support. @@ -1153,7 +1153,7 @@ async def chat_completion( ) for msg in body.messages: - msg_dict: Dict[str, Any] = { + msg_dict = { "role": msg.role, "content": msg.content, } @@ -1161,20 +1161,16 @@ async def chat_completion( msg_dict["tool_call_id"] = msg.tool_call_id if msg.name: msg_dict["name"] = msg.name - # Replayed assistant turns carry their original tool_calls so the - # rendered prefix matches the prior turn exactly (prompt caching). if msg.tool_calls is not None: msg_dict["tool_calls"] = msg.tool_calls conversation.append(msg_dict) - # Everything appended from here on belongs to the assistant turn we are - # about to generate. We hand this slice back to the client so it can replay - # it verbatim on the next turn, keeping the cached prompt prefix intact. + # Messages appended past this point form this turn's replay record. turn_start_len = len(conversation) tool_iterations = 0 - tool_calls: List[ToolCall] = [] + tool_calls: list[ToolCall] = [] max_iterations = body.max_tool_iterations logger.debug( @@ -1184,17 +1180,12 @@ async def chat_completion( # True LLM streaming when client supports it and stream requested if body.stream and hasattr(genai_client, "chat_with_tools_stream"): - stream_tool_calls: List[ToolCall] = [] stream_iterations = 0 async def stream_body_llm(): - nonlocal conversation, stream_tool_calls, stream_iterations + nonlocal conversation, stream_iterations - def _emit_replay_messages(extra: Optional[List[Dict[str, Any]]] = None): - # Hand the client the exact messages appended for this assistant - # turn (assistant tool-call turns, tool results, injected image - # messages, and the final assistant message) so it can replay - # them verbatim next turn and keep the prompt cache warm. + def _emit_replay_messages(extra: Optional[list[dict[str, Any]]] = None): turn_messages = conversation[turn_start_len:] + (extra or []) return ( json.dumps({"type": "messages", "messages": turn_messages}).encode( @@ -1267,41 +1258,32 @@ async def chat_completion( ) return ( - executed_calls, + _executed_calls, tool_results, extra_msgs, ) = await _execute_pending_tools( pending, request, allowed_cameras ) - stream_tool_calls.extend(executed_calls) conversation.extend(tool_results) conversation.extend(extra_msgs) - yield ( - json.dumps( - { - "type": "tool_calls", - "tool_calls": [ - tc.model_dump() for tc in stream_tool_calls - ], - } - ).encode("utf-8") - + b"\n" - ) + # Running turn slice: lets the client render tool + # calls live and replay them verbatim next turn. + yield _emit_replay_messages() break else: - # Final answer: the streaming loop never appends the - # last assistant message to `conversation`, so add it - # to the replay slice explicitly. - final_assistant = { - "role": "assistant", - "content": msg.get("content"), - } - yield _emit_replay_messages(extra=[final_assistant]) + # Streaming never appends the final assistant message + # to the conversation, so add it to the replay slice. + yield _emit_replay_messages( + extra=[ + { + "role": "assistant", + "content": msg.get("content"), + } + ] + ) yield (json.dumps({"type": "done"}).encode("utf-8") + b"\n") return else: - # Max iterations reached: replay whatever we accumulated so the - # next turn still starts from a cache-friendly prefix. yield _emit_replay_messages() yield json.dumps({"type": "done"}).encode("utf-8") + b"\n" @@ -1349,19 +1331,15 @@ async def chat_completion( if body.stream: final_reasoning = response.get("reasoning") + turn_messages = conversation[turn_start_len:] + async def stream_body() -> Any: - if tool_calls: - yield ( - json.dumps( - { - "type": "tool_calls", - "tool_calls": [ - tc.model_dump() for tc in tool_calls - ], - } - ).encode("utf-8") - + b"\n" - ) + yield ( + json.dumps( + {"type": "messages", "messages": turn_messages} + ).encode("utf-8") + + b"\n" + ) # Emit the full reasoning trace up front when the # underlying client did not stream it if final_reasoning: diff --git a/frigate/api/defs/response/chat_response.py b/frigate/api/defs/response/chat_response.py index c2b3e6b1f2..105104baa4 100644 --- a/frigate/api/defs/response/chat_response.py +++ b/frigate/api/defs/response/chat_response.py @@ -56,3 +56,12 @@ class ChatCompletionResponse(BaseModel): default_factory=list, description="List of tool calls that were executed during this completion", ) + messages: list[dict[str, Any]] = Field( + default_factory=list, + description=( + "The exact conversation messages appended for this assistant turn " + "(assistant tool-call turns, tool results, and the final assistant " + "message). Replay these verbatim as conversation history on the next " + "request to keep the model server's prompt cache prefix intact." + ), + )