Refactor to match single message implementation

This commit is contained in:
Nicolas Mowen 2026-06-11 16:22:46 -06:00
parent 7039dc5cb4
commit 222a26f720
2 changed files with 83 additions and 96 deletions

View File

@ -7,7 +7,7 @@ import operator
import time
from datetime import datetime
from functools import reduce
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import cv2
from fastapi import APIRouter, Body, Depends, HTTPException, Request
@ -59,7 +59,7 @@ class ToolExecuteRequest(BaseModel):
"""Request model for tool execution."""
tool_name: str
arguments: Dict[str, Any]
arguments: dict[str, Any]
class VLMMonitorRequest(BaseModel):
@ -68,8 +68,8 @@ class VLMMonitorRequest(BaseModel):
camera: str
condition: str
max_duration_minutes: int = 60
labels: List[str] = []
zones: List[str] = []
labels: list[str] = []
zones: list[str] = []
@router.get(
@ -91,10 +91,10 @@ def get_tools(request: Request) -> JSONResponse:
def _resolve_zones(
zones: List[str],
zones: list[str],
config: FrigateConfig,
target_cameras: List[str],
) -> List[str]:
target_cameras: list[str],
) -> list[str]:
"""Map zone names to their canonical config keys, case-insensitively.
LLMs frequently echo a user's casing ("Front Yard") instead of the
@ -107,7 +107,7 @@ def _resolve_zones(
if not zones:
return zones
lookup: Dict[str, str] = {}
lookup: dict[str, str] = {}
for camera_id in target_cameras:
camera_config = config.cameras.get(camera_id)
if camera_config is None:
@ -120,8 +120,8 @@ def _resolve_zones(
async def _execute_search_objects(
request: Request,
arguments: Dict[str, Any],
allowed_cameras: List[str],
arguments: dict[str, Any],
allowed_cameras: list[str],
) -> JSONResponse:
"""
Execute the search_objects tool.
@ -213,8 +213,8 @@ async def _execute_search_objects(
async def _execute_search_objects_semantic(
request: Request,
arguments: Dict[str, Any],
allowed_cameras: List[str],
arguments: dict[str, Any],
allowed_cameras: list[str],
semantic_query: str,
) -> JSONResponse:
"""Search objects via fused thumbnail + description embeddings.
@ -263,8 +263,8 @@ async def _execute_search_objects_semantic(
limit = int(arguments.get("limit", 25))
limit = max(1, min(limit, 100))
visual_distances: Dict[str, float] = {}
description_distances: Dict[str, float] = {}
visual_distances: dict[str, float] = {}
description_distances: dict[str, float] = {}
try:
rows = context.search_thumbnail(semantic_query)
visual_distances = {row[0]: row[1] for row in rows}
@ -305,7 +305,7 @@ async def _execute_search_objects_semantic(
eligible = {e.id: e for e in Event.select().where(reduce(operator.and_, clauses))}
scored: List[tuple[str, float]] = []
scored: list[tuple[str, float]] = []
for eid in eligible:
v_score = (
distance_to_score(visual_distances[eid], context.thumb_stats)
@ -331,9 +331,9 @@ async def _execute_search_objects_semantic(
async def _execute_find_similar_objects(
request: Request,
arguments: Dict[str, Any],
allowed_cameras: List[str],
) -> Dict[str, Any]:
arguments: dict[str, Any],
allowed_cameras: list[str],
) -> dict[str, Any]:
"""Execute the find_similar_objects tool.
Returns a plain dict (not JSONResponse) so the chat loop can embed it
@ -403,8 +403,8 @@ async def _execute_find_similar_objects(
# version (see frigate/embeddings/__init__.py). Mirror the pattern used by
# frigate/api/event.py events_search: fetch top-k globally, then intersect
# with the structured filters via Peewee.
visual_distances: Dict[str, float] = {}
description_distances: Dict[str, float] = {}
visual_distances: dict[str, float] = {}
description_distances: dict[str, float] = {}
try:
if similarity_mode in ("visual", "fused"):
@ -462,7 +462,7 @@ async def _execute_find_similar_objects(
eligible = {e.id: e for e in Event.select().where(reduce(operator.and_, clauses))}
# 6. Fuse and rank.
scored: List[tuple[str, float]] = []
scored: list[tuple[str, float]] = []
for eid in eligible:
v_score = (
distance_to_score(visual_distances[eid], context.thumb_stats)
@ -503,7 +503,7 @@ async def _execute_find_similar_objects(
async def execute_tool(
request: Request,
body: ToolExecuteRequest = Body(...),
allowed_cameras: List[str] = Depends(get_allowed_cameras_for_filter),
allowed_cameras: list[str] = Depends(get_allowed_cameras_for_filter),
) -> JSONResponse:
"""
Execute a tool function call.
@ -545,8 +545,8 @@ async def execute_tool(
async def _execute_get_live_context(
request: Request,
camera: str,
allowed_cameras: List[str],
) -> Dict[str, Any]:
allowed_cameras: list[str],
) -> dict[str, Any]:
# Reject wildcards explicitly so models retry with a real camera name
# instead of silently fanning out across every camera.
if camera in ("*", "all"):
@ -593,7 +593,7 @@ async def _execute_get_live_context(
"stationary": obj_dict.get("stationary", False),
}
result: Dict[str, Any] = {
result: dict[str, Any] = {
"camera": camera,
"timestamp": frame_time,
"detections": list(tracked_objects_dict.values()),
@ -620,7 +620,7 @@ async def _execute_get_live_context(
async def _get_live_frame_image_url(
request: Request,
camera: str,
allowed_cameras: List[str],
allowed_cameras: list[str],
) -> Optional[str]:
"""
Fetch the current live frame for a camera as a base64 data URL.
@ -659,8 +659,8 @@ async def _get_live_frame_image_url(
async def _execute_set_camera_state(
request: Request,
arguments: Dict[str, Any],
) -> Dict[str, Any]:
arguments: dict[str, Any],
) -> dict[str, Any]:
role = request.headers.get("remote-role", "")
if "admin" not in [r.strip() for r in role.split(",")]:
return {"error": "Admin privileges required to change camera settings."}
@ -699,10 +699,10 @@ async def _execute_set_camera_state(
async def _execute_tool_internal(
tool_name: str,
arguments: Dict[str, Any],
arguments: dict[str, Any],
request: Request,
allowed_cameras: List[str],
) -> Dict[str, Any]:
allowed_cameras: list[str],
) -> dict[str, Any]:
"""
Internal helper to execute a tool and return the result as a dict.
@ -763,8 +763,8 @@ async def _execute_tool_internal(
async def _execute_start_camera_watch(
request: Request,
arguments: Dict[str, Any],
) -> Dict[str, Any]:
arguments: dict[str, Any],
) -> dict[str, Any]:
camera = arguments.get("camera", "").strip()
condition = arguments.get("condition", "").strip()
max_duration_minutes = int(arguments.get("max_duration_minutes", 60))
@ -814,14 +814,14 @@ async def _execute_start_camera_watch(
}
def _execute_stop_camera_watch() -> Dict[str, Any]:
def _execute_stop_camera_watch() -> dict[str, Any]:
cancelled = stop_vlm_watch_job()
if cancelled:
return {"success": True, "message": "Watch job cancelled."}
return {"success": False, "message": "No active watch job to cancel."}
def _execute_get_profile_status(request: Request) -> Dict[str, Any]:
def _execute_get_profile_status(request: Request) -> dict[str, Any]:
"""Return profile status including active profile and activation timestamps."""
profile_manager = getattr(request.app, "profile_manager", None)
if profile_manager is None:
@ -846,9 +846,9 @@ def _execute_get_profile_status(request: Request) -> Dict[str, Any]:
def _execute_get_recap(
arguments: Dict[str, Any],
allowed_cameras: List[str],
) -> Dict[str, Any]:
arguments: dict[str, Any],
allowed_cameras: list[str],
) -> dict[str, Any]:
"""Fetch review segments with GenAI metadata for a time period."""
from functools import reduce
@ -909,7 +909,7 @@ def _execute_get_recap(
.iterator()
)
events: List[Dict[str, Any]] = []
events: list[dict[str, Any]] = []
for row in rows:
data = row.get("data") or {}
@ -920,7 +920,7 @@ def _execute_get_recap(
data = {}
camera = row["camera"]
event: Dict[str, Any] = {
event: dict[str, Any] = {
"camera": camera.replace("_", " ").title(),
"severity": row.get("severity", "detection"),
}
@ -984,10 +984,10 @@ def _execute_get_recap(
async def _execute_pending_tools(
pending_tool_calls: List[Dict[str, Any]],
pending_tool_calls: list[dict[str, Any]],
request: Request,
allowed_cameras: List[str],
) -> tuple[List[ToolCall], List[Dict[str, Any]], List[Dict[str, Any]]]:
allowed_cameras: list[str],
) -> tuple[list[ToolCall], list[dict[str, Any]], list[dict[str, Any]]]:
"""
Execute a list of tool calls.
@ -996,9 +996,9 @@ async def _execute_pending_tools(
tool result dicts for conversation,
extra messages to inject after tool results, e.g. user messages with images)
"""
tool_calls_out: List[ToolCall] = []
tool_results: List[Dict[str, Any]] = []
extra_messages: List[Dict[str, Any]] = []
tool_calls_out: list[ToolCall] = []
tool_results: list[dict[str, Any]] = []
extra_messages: list[dict[str, Any]] = []
for tool_call in pending_tool_calls:
tool_name = tool_call["name"]
tool_args = tool_call.get("arguments") or {}
@ -1106,7 +1106,7 @@ async def _execute_pending_tools(
async def chat_completion(
request: Request,
body: ChatCompletionRequest = Body(...),
allowed_cameras: List[str] = Depends(get_allowed_cameras_for_filter),
allowed_cameras: list[str] = Depends(get_allowed_cameras_for_filter),
):
"""
Chat completion endpoint with tool calling support.
@ -1153,7 +1153,7 @@ async def chat_completion(
)
for msg in body.messages:
msg_dict: Dict[str, Any] = {
msg_dict = {
"role": msg.role,
"content": msg.content,
}
@ -1161,20 +1161,16 @@ async def chat_completion(
msg_dict["tool_call_id"] = msg.tool_call_id
if msg.name:
msg_dict["name"] = msg.name
# Replayed assistant turns carry their original tool_calls so the
# rendered prefix matches the prior turn exactly (prompt caching).
if msg.tool_calls is not None:
msg_dict["tool_calls"] = msg.tool_calls
conversation.append(msg_dict)
# Everything appended from here on belongs to the assistant turn we are
# about to generate. We hand this slice back to the client so it can replay
# it verbatim on the next turn, keeping the cached prompt prefix intact.
# Messages appended past this point form this turn's replay record.
turn_start_len = len(conversation)
tool_iterations = 0
tool_calls: List[ToolCall] = []
tool_calls: list[ToolCall] = []
max_iterations = body.max_tool_iterations
logger.debug(
@ -1184,17 +1180,12 @@ async def chat_completion(
# True LLM streaming when the client supports it and streaming is requested
if body.stream and hasattr(genai_client, "chat_with_tools_stream"):
stream_tool_calls: List[ToolCall] = []
stream_iterations = 0
async def stream_body_llm():
nonlocal conversation, stream_tool_calls, stream_iterations
nonlocal conversation, stream_iterations
def _emit_replay_messages(extra: Optional[List[Dict[str, Any]]] = None):
# Hand the client the exact messages appended for this assistant
# turn (assistant tool-call turns, tool results, injected image
# messages, and the final assistant message) so it can replay
# them verbatim next turn and keep the prompt cache warm.
def _emit_replay_messages(extra: Optional[list[dict[str, Any]]] = None):
turn_messages = conversation[turn_start_len:] + (extra or [])
return (
json.dumps({"type": "messages", "messages": turn_messages}).encode(
@ -1267,41 +1258,32 @@ async def chat_completion(
)
return
(
executed_calls,
_executed_calls,
tool_results,
extra_msgs,
) = await _execute_pending_tools(
pending, request, allowed_cameras
)
stream_tool_calls.extend(executed_calls)
conversation.extend(tool_results)
conversation.extend(extra_msgs)
yield (
json.dumps(
{
"type": "tool_calls",
"tool_calls": [
tc.model_dump() for tc in stream_tool_calls
],
}
).encode("utf-8")
+ b"\n"
)
# Running turn slice: lets the client render tool
# calls live and replay them verbatim next turn.
yield _emit_replay_messages()
break
else:
# Final answer: the streaming loop never appends the
# last assistant message to `conversation`, so add it
# to the replay slice explicitly.
final_assistant = {
"role": "assistant",
"content": msg.get("content"),
}
yield _emit_replay_messages(extra=[final_assistant])
# Streaming never appends the final assistant message
# to the conversation, so add it to the replay slice.
yield _emit_replay_messages(
extra=[
{
"role": "assistant",
"content": msg.get("content"),
}
]
)
yield (json.dumps({"type": "done"}).encode("utf-8") + b"\n")
return
else:
# Max iterations reached: replay whatever we accumulated so the
# next turn still starts from a cache-friendly prefix.
yield _emit_replay_messages()
yield json.dumps({"type": "done"}).encode("utf-8") + b"\n"
@ -1349,19 +1331,15 @@ async def chat_completion(
if body.stream:
final_reasoning = response.get("reasoning")
turn_messages = conversation[turn_start_len:]
async def stream_body() -> Any:
if tool_calls:
yield (
json.dumps(
{
"type": "tool_calls",
"tool_calls": [
tc.model_dump() for tc in tool_calls
],
}
).encode("utf-8")
+ b"\n"
)
yield (
json.dumps(
{"type": "messages", "messages": turn_messages}
).encode("utf-8")
+ b"\n"
)
# Emit the full reasoning trace up front when the
# underlying client did not stream it
if final_reasoning:

View File

@ -56,3 +56,12 @@ class ChatCompletionResponse(BaseModel):
default_factory=list,
description="List of tool calls that were executed during this completion",
)
messages: list[dict[str, Any]] = Field(
default_factory=list,
description=(
"The exact conversation messages appended for this assistant turn "
"(assistant tool-call turns, tool results, and the final assistant "
"message). Replay these verbatim as conversation history on the next "
"request to keep the model server's prompt cache prefix intact."
),
)