# Azure OpenAI provider (frigate/genai/azure-openai.py): streaming counterpart
# of chat_with_tools, a method of OpenAIClient(GenAIClient).
async def chat_with_tools_stream(
    self,
    messages: list[dict[str, Any]],
    tools: Optional[list[dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
):
    """
    Stream chat with tools; yields content deltas then final message.

    Implements streaming function calling/tool usage for Azure OpenAI models.

    Args:
        messages: Chat messages, forwarded unchanged as ``messages``.
        tools: Optional tool definitions, forwarded unchanged as ``tools``.
        tool_choice: ``"none"``, ``"auto"`` or ``"required"``. Any other
            value sends no ``tool_choice``. It is only sent together with
            a non-empty ``tools``.

    Yields:
        ``("content_delta", text)`` for every non-empty content fragment,
        followed by exactly one ``("message", {...})`` whose dict has the
        keys ``content`` (joined and stripped text, or None when empty),
        ``tool_calls`` (list of ``{"id", "name", "arguments"}`` or None) and
        ``finish_reason``. If anything raises, the final message carries
        ``content=None``, ``tool_calls=None`` and ``finish_reason="error"``.
    """
    # Function-scope import keeps this method self-contained.
    import asyncio

    try:
        request_params = {
            "model": self.genai_config.model,
            "messages": messages,
            "timeout": self.timeout,
            "stream": True,
        }

        if tools:
            request_params["tools"] = tools
            if tool_choice in ("none", "auto", "required"):
                request_params["tool_choice"] = tool_choice

        content_parts: list[str] = []
        tool_calls_by_index: dict[int, dict[str, Any]] = {}
        finish_reason = "stop"

        # The client call and the chunk iteration are blocking. Running them
        # in a worker thread keeps the event loop free to serve other tasks
        # while the model is generating.
        stream = await asyncio.to_thread(
            self.provider.chat.completions.create, **request_params
        )
        chunk_iter = iter(stream)
        # Sentinel default for next(): StopIteration must not cross the
        # thread/future boundary.
        exhausted = object()

        while True:
            chunk = await asyncio.to_thread(next, chunk_iter, exhausted)
            if chunk is exhausted:
                break
            if not chunk or not chunk.choices:
                continue

            choice = chunk.choices[0]
            delta = choice.delta

            if choice.finish_reason:
                finish_reason = choice.finish_reason

            if delta.content:
                content_parts.append(delta.content)
                yield ("content_delta", delta.content)

            if not delta.tool_calls:
                continue

            # Tool calls arrive in fragments keyed by index: the id and name
            # come once, the JSON argument text is split across chunks.
            for tc in delta.tool_calls:
                fn = tc.function
                entry = tool_calls_by_index.setdefault(
                    tc.index, {"id": "", "name": "", "arguments": ""}
                )
                if tc.id:
                    entry["id"] = tc.id
                if fn and fn.name:
                    entry["name"] = fn.name
                if fn and fn.arguments:
                    entry["arguments"] += fn.arguments

        full_content = "".join(content_parts).strip() or None

        tool_calls_list = None
        if tool_calls_by_index:
            tool_calls_list = []
            for entry in tool_calls_by_index.values():
                try:
                    parsed_args = json.loads(entry["arguments"])
                except json.JSONDecodeError:
                    # Keep the raw text when the model produced invalid JSON.
                    parsed_args = entry["arguments"]

                tool_calls_list.append(
                    {
                        "id": entry["id"],
                        "name": entry["name"],
                        "arguments": parsed_args,
                    }
                )
            # Any tool call overrides whatever finish reason was streamed.
            finish_reason = "tool_calls"

        yield (
            "message",
            {
                "content": full_content,
                "tool_calls": tool_calls_list,
                "finish_reason": finish_reason,
            },
        )

    except Exception as e:
        logger.warning("Azure OpenAI streaming returned an error: %s", str(e))
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )
# Gemini provider (frigate/genai/gemini.py): streaming counterpart of
# chat_with_tools, a method of GeminiClient(GenAIClient).
async def chat_with_tools_stream(
    self,
    messages: list[dict[str, Any]],
    tools: Optional[list[dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
):
    """
    Stream chat with tools; yields content deltas then final message.

    Implements streaming function calling/tool usage for Gemini models.

    Args:
        messages: Dicts with ``role`` and ``content``. ``system`` text is
            prepended to the first message when that message is a user
            message, otherwise added as a user message. ``assistant`` maps
            to the ``model`` role. ``tool`` becomes a function response named
            by the message's ``name``. Any other role is treated as user.
        tools: Optional tool definitions. Entries with ``type == "function"``
            are converted from their ``function`` dict (``name``,
            ``description``, ``parameters``); other entries are ignored.
        tool_choice: ``"none"``, ``"auto"`` or ``"required"``, mapped to the
            function-calling modes ``NONE``, ``AUTO`` and ``ANY``. Any other
            value sends no tool config.

    Yields:
        ``("content_delta", text)`` for every text part, followed by exactly
        one ``("message", {...})`` whose dict has the keys ``content``
        (joined and stripped text, or None when empty), ``tool_calls`` (list
        of ``{"id", "name", "arguments"}`` or None) and ``finish_reason``.
        If anything raises, the final message carries ``content=None``,
        ``tool_calls=None`` and ``finish_reason="error"``.
    """
    try:
        # Hoisted out of the chunk loop; only needed to translate finish reasons.
        from google.genai.types import FinishReason

        # Convert messages to Gemini format
        gemini_messages = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                # Gemini doesn't have system role, prepend to first user message
                if gemini_messages and gemini_messages[0].role == "user":
                    first_part = gemini_messages[0].parts[0]
                    first_part.text = f"{content}\n\n{first_part.text}"
                else:
                    gemini_messages.append(
                        types.Content(
                            role="user", parts=[types.Part.from_text(text=content)]
                        )
                    )
            elif role == "assistant":
                gemini_messages.append(
                    types.Content(
                        role="model", parts=[types.Part.from_text(text=content)]
                    )
                )
            elif role == "tool":
                # Pass name/response as keywords, and make sure the payload is
                # a dict: plain results are wrapped under the "output" key.
                if isinstance(content, dict):
                    response_payload = content
                else:
                    response_payload = {"output": content}
                gemini_messages.append(
                    types.Content(
                        # NOTE(review): role kept as "function" from the
                        # existing code - confirm the API accepts it.
                        role="function",
                        parts=[
                            types.Part.from_function_response(
                                name=msg.get("name", ""),
                                response=response_payload,
                            )
                        ],
                    )
                )
            else:  # user
                gemini_messages.append(
                    types.Content(
                        role="user", parts=[types.Part.from_text(text=content)]
                    )
                )

        # Convert tools to Gemini format
        gemini_tools = None
        if tools:
            gemini_tools = []
            for tool in tools:
                if tool.get("type") != "function":
                    continue
                func = tool.get("function", {})
                gemini_tools.append(
                    types.Tool(
                        function_declarations=[
                            types.FunctionDeclaration(
                                name=func.get("name", ""),
                                description=func.get("description", ""),
                                parameters=func.get("parameters", {}),
                            )
                        ]
                    )
                )

        # Configure tool choice
        tool_config = None
        calling_mode = {"none": "NONE", "auto": "AUTO", "required": "ANY"}.get(
            tool_choice
        )
        if calling_mode:
            tool_config = types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(
                    mode=calling_mode
                )
            )

        # Build request config
        config_params = {"candidate_count": 1}

        if gemini_tools:
            config_params["tools"] = gemini_tools

        if tool_config:
            config_params["tool_config"] = tool_config

        # Merge runtime_options
        if isinstance(self.genai_config.runtime_options, dict):
            config_params.update(self.genai_config.runtime_options)

        content_parts: list[str] = []
        tool_calls_list: list[dict[str, Any]] = []
        finish_reason = "stop"

        # `async for` needs an async iterator, so the stream has to come from
        # the client's async surface (`.aio`), not the blocking generator.
        response = await self.provider.aio.models.generate_content_stream(
            model=self.genai_config.model,
            contents=gemini_messages,
            config=types.GenerateContentConfig(**config_params),
        )

        async for chunk in response:
            if not chunk or not chunk.candidates:
                continue

            candidate = chunk.candidates[0]

            # Translate Gemini finish reasons; unknown ones leave the
            # current value untouched.
            reason = getattr(candidate, "finish_reason", None)
            if reason:
                if reason == FinishReason.STOP:
                    finish_reason = "stop"
                elif reason == FinishReason.MAX_TOKENS:
                    finish_reason = "length"
                elif reason in [FinishReason.SAFETY, FinishReason.RECITATION]:
                    finish_reason = "error"

            if not (candidate.content and candidate.content.parts):
                continue

            for part in candidate.content.parts:
                if part.text:
                    content_parts.append(part.text)
                    yield ("content_delta", part.text)
                elif part.function_call:
                    call = part.function_call
                    try:
                        arguments = dict(call.args) if call.args else {}
                    except (TypeError, ValueError):
                        arguments = {}

                    # Every function_call part is recorded as its own call,
                    # so two calls to the same function stay separate.
                    # The function name doubles as the id.
                    call_name = call.name or ""
                    tool_calls_list.append(
                        {
                            "id": call_name,
                            "name": call_name,
                            "arguments": arguments,
                        }
                    )

        full_content = "".join(content_parts).strip() or None

        if tool_calls_list:
            # Any tool call overrides whatever finish reason was streamed.
            finish_reason = "tool_calls"

        yield (
            "message",
            {
                "content": full_content,
                "tool_calls": tool_calls_list or None,
                "finish_reason": finish_reason,
            },
        )

    except errors.APIError as e:
        logger.warning("Gemini API error during streaming: %s", str(e))
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )
    except Exception as e:
        logger.warning(
            "Gemini returned an error during chat_with_tools_stream: %s", str(e)
        )
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )
# OpenAI provider (frigate/genai/openai.py): streaming counterpart of
# chat_with_tools, a method of OpenAIClient(GenAIClient).
async def chat_with_tools_stream(
    self,
    messages: list[dict[str, Any]],
    tools: Optional[list[dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
):
    """
    Stream chat with tools; yields content deltas then final message.

    Implements streaming function calling/tool usage for OpenAI models.

    Args:
        messages: Chat messages, forwarded unchanged as ``messages``.
        tools: Optional tool definitions, forwarded unchanged as ``tools``.
        tool_choice: ``"none"``, ``"auto"`` or ``"required"``. Any other
            value sends no ``tool_choice``. It is only sent together with
            a non-empty ``tools``.

    Yields:
        ``("content_delta", text)`` for every non-empty content fragment,
        followed by exactly one ``("message", {...})`` whose dict has the
        keys ``content`` (joined and stripped text, or None when empty),
        ``tool_calls`` (list of ``{"id", "name", "arguments"}`` or None) and
        ``finish_reason``. If anything raises, the final message carries
        ``content=None``, ``tool_calls=None`` and ``finish_reason="error"``.
    """
    # Function-scope import keeps this method self-contained.
    import asyncio

    try:
        request_params = {
            "model": self.genai_config.model,
            "messages": messages,
            "timeout": self.timeout,
            "stream": True,
        }

        if tools:
            request_params["tools"] = tools
            if tool_choice in ("none", "auto", "required"):
                request_params["tool_choice"] = tool_choice

        # Configured provider options are merged last and may override the
        # defaults above. "context_size" is filtered out of the request.
        if isinstance(self.genai_config.provider_options, dict):
            excluded_options = {"context_size"}
            request_params.update(
                {
                    key: value
                    for key, value in self.genai_config.provider_options.items()
                    if key not in excluded_options
                }
            )

        content_parts: list[str] = []
        tool_calls_by_index: dict[int, dict[str, Any]] = {}
        finish_reason = "stop"

        # The client call and the chunk iteration are blocking. Running them
        # in a worker thread keeps the event loop free to serve other tasks
        # while the model is generating.
        stream = await asyncio.to_thread(
            self.provider.chat.completions.create, **request_params
        )
        chunk_iter = iter(stream)
        # Sentinel default for next(): StopIteration must not cross the
        # thread/future boundary.
        exhausted = object()

        while True:
            chunk = await asyncio.to_thread(next, chunk_iter, exhausted)
            if chunk is exhausted:
                break
            if not chunk or not chunk.choices:
                continue

            choice = chunk.choices[0]
            delta = choice.delta

            if choice.finish_reason:
                finish_reason = choice.finish_reason

            if delta.content:
                content_parts.append(delta.content)
                yield ("content_delta", delta.content)

            if not delta.tool_calls:
                continue

            # Tool calls arrive in fragments keyed by index: the id and name
            # come once, the JSON argument text is split across chunks.
            for tc in delta.tool_calls:
                fn = tc.function
                entry = tool_calls_by_index.setdefault(
                    tc.index, {"id": "", "name": "", "arguments": ""}
                )
                if tc.id:
                    entry["id"] = tc.id
                if fn and fn.name:
                    entry["name"] = fn.name
                if fn and fn.arguments:
                    entry["arguments"] += fn.arguments

        full_content = "".join(content_parts).strip() or None

        tool_calls_list = None
        if tool_calls_by_index:
            tool_calls_list = []
            for entry in tool_calls_by_index.values():
                try:
                    parsed_args = json.loads(entry["arguments"])
                except json.JSONDecodeError:
                    # Keep the raw text when the model produced invalid JSON.
                    parsed_args = entry["arguments"]

                tool_calls_list.append(
                    {
                        "id": entry["id"],
                        "name": entry["name"],
                        "arguments": parsed_args,
                    }
                )
            # Any tool call overrides whatever finish reason was streamed.
            finish_reason = "tool_calls"

        yield (
            "message",
            {
                "content": full_content,
                "tool_calls": tool_calls_list,
                "finish_reason": finish_reason,
            },
        )

    except TimeoutException as e:
        logger.warning("OpenAI streaming request timed out: %s", str(e))
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )
    except Exception as e:
        logger.warning("OpenAI streaming returned an error: %s", str(e))
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )