From 4eeac987b8cbec020ae80f55f6a0a4580c2863ea Mon Sep 17 00:00:00 2001
From: Nicolas Mowen
Date: Tue, 20 Jan 2026 07:57:11 -0700
Subject: [PATCH] Implement other providers

---
 frigate/api/chat.py           |  22 ++--
 frigate/genai/azure-openai.py |  97 +++++++++++++++-
 frigate/genai/gemini.py       | 234 +++++++++++++++++++++++++++++++++++-
 frigate/genai/ollama.py       | 139 ++++++++++++++++++++++
 4 files changed, 477 insertions(+), 15 deletions(-)

diff --git a/frigate/api/chat.py b/frigate/api/chat.py
index bc092a9b1..eeff3ab6d 100644
--- a/frigate/api/chat.py
+++ b/frigate/api/chat.py
@@ -306,7 +306,7 @@ Always be accurate with time calculations based on the current date provided."""
         tool_iterations = 0
         max_iterations = body.max_tool_iterations
 
-        logger.info(
+        logger.debug(
             f"Starting chat completion with {len(conversation)} message(s), "
             f"{len(tools)} tool(s) available, max_iterations={max_iterations}"
         )
@@ -352,7 +352,7 @@ Always be accurate with time calculations based on the current date provided."""
 
             tool_calls = response.get("tool_calls")
             if not tool_calls:
-                logger.info(
+                logger.debug(
                     f"Chat completion finished with final answer (iterations: {tool_iterations})"
                 )
                 return JSONResponse(
@@ -369,7 +369,7 @@ Always be accurate with time calculations based on the current date provided."""
 
             # Execute tools
             tool_iterations += 1
-            logger.info(
+            logger.debug(
                 f"Tool calls detected (iteration {tool_iterations}/{max_iterations}): "
                 f"{len(tool_calls)} tool(s) to execute"
             )
@@ -380,7 +380,7 @@ Always be accurate with time calculations based on the current date provided."""
                 tool_args = tool_call["arguments"]
                 tool_call_id = tool_call["id"]
 
-                logger.info(
+                logger.debug(
                     f"Executing tool: {tool_name} (id: {tool_call_id}) with arguments: {json.dumps(tool_args, indent=2)}"
                 )
 
@@ -402,19 +402,19 @@ Always be accurate with time calculations based on the current date provided."""
                                 if result_count > 0
                                 else [],
                             }
-                        logger.info(
+                        logger.debug(
                             f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
                             f"Result: {json.dumps(result_summary, indent=2)}"
                         )
                     elif isinstance(tool_result, str):
                         result_content = tool_result
-                        logger.info(
+                        logger.debug(
                             f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
                             f"Result length: {len(result_content)} characters"
                         )
                     else:
                         result_content = str(tool_result)
-                        logger.info(
+                        logger.debug(
                             f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
                             f"Result type: {type(tool_result).__name__}"
                         )
@@ -441,16 +441,12 @@ Always be accurate with time calculations based on the current date provided."""
                             "content": error_content,
                         }
                     )
-                    logger.info(
+                    logger.debug(
                         f"Tool {tool_name} (id: {tool_call_id}) failed. Error result added to conversation."
                     )
 
             conversation.extend(tool_results)
-            logger.info(
-                f"Added {len(tool_results)} tool result(s) to conversation. "
-                f"Continuing with next LLM call..."
-            )
-            logger.info(
+            logger.debug(
                 f"Added {len(tool_results)} tool result(s) to conversation. "
                 f"Continuing with next LLM call..."
             )
diff --git a/frigate/genai/azure-openai.py b/frigate/genai/azure-openai.py
index eba8b47c0..78ed376e5 100644
--- a/frigate/genai/azure-openai.py
+++ b/frigate/genai/azure-openai.py
@@ -1,8 +1,9 @@
 """Azure OpenAI Provider for Frigate AI."""
 
 import base64
+import json
 import logging
-from typing import Optional
+from typing import Any, Optional
 from urllib.parse import parse_qs, urlparse
 
 from openai import AzureOpenAI
@@ -75,3 +76,97 @@ class OpenAIClient(GenAIClient):
     def get_context_size(self) -> int:
         """Get the context window size for Azure OpenAI."""
         return 128000
+
+    def chat_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: Optional[list[dict[str, Any]]] = None,
+        tool_choice: Optional[str] = "auto",
+    ) -> dict[str, Any]:
+        """Run one chat completion turn, optionally offering tools.
+
+        Args:
+            messages: Conversation passed to the API as-is.
+            tools: Tool definitions; tool calling is only requested when
+                this is non-empty.
+            tool_choice: "auto", "none" or "required". Any other value is
+                not forwarded.
+
+        Returns:
+            A dict with "content" (stripped text or None), "tool_calls"
+            (list of {"id", "name", "arguments"} dicts or None) and
+            "finish_reason". Failures are logged and reported as
+            finish_reason "error" rather than raised.
+        """
+        error_response: dict[str, Any] = {
+            "content": None,
+            "tool_calls": None,
+            "finish_reason": "error",
+        }
+
+        try:
+            request_params: dict[str, Any] = {
+                "model": self.genai_config.model,
+                "messages": messages,
+                "timeout": self.timeout,
+            }
+
+            if tools:
+                request_params["tools"] = tools
+                if tool_choice in ("none", "auto", "required"):
+                    request_params["tool_choice"] = tool_choice
+
+            result = self.provider.chat.completions.create(**request_params)
+
+            if result is None or not getattr(result, "choices", None):
+                return error_response
+
+            choice = result.choices[0]
+            message = choice.message
+
+            content = message.content.strip() if message.content else None
+
+            tool_calls = None
+            if message.tool_calls:
+                tool_calls = []
+                for tool_call in message.tool_calls:
+                    function = getattr(tool_call, "function", None)
+                    name = getattr(function, "name", None) or ""
+                    try:
+                        arguments = json.loads(function.arguments)
+                    except (json.JSONDecodeError, TypeError, AttributeError) as e:
+                        # TypeError covers arguments=None, which would
+                        # otherwise escape and discard the whole response.
+                        logger.warning(
+                            "Failed to parse tool call arguments: %s, tool: %s",
+                            e,
+                            name or "unknown",
+                        )
+                        arguments = {}
+
+                    tool_calls.append(
+                        {
+                            "id": getattr(tool_call, "id", None) or "",
+                            "name": name,
+                            "arguments": arguments,
+                        }
+                    )
+
+            if getattr(choice, "finish_reason", None):
+                finish_reason = choice.finish_reason
+            elif tool_calls:
+                finish_reason = "tool_calls"
+            elif content:
+                finish_reason = "stop"
+            else:
+                finish_reason = "error"
+
+            return {
+                "content": content,
+                "tool_calls": tool_calls,
+                "finish_reason": finish_reason,
+            }
+
+        except Exception as e:
+            logger.warning("Azure OpenAI returned an error: %s", str(e))
+            return error_response
diff --git a/frigate/genai/gemini.py b/frigate/genai/gemini.py
index f94448d75..fb462f003 100644
--- a/frigate/genai/gemini.py
+++ b/frigate/genai/gemini.py
@@ -1,7 +1,9 @@
 """Gemini Provider for Frigate AI."""
 
+import json
 import logging
-from typing import Optional
+import uuid
+from typing import Any, Optional
 
 import google.generativeai as genai
 from google.api_core.exceptions import GoogleAPICallError
@@ -58,3 +60,233 @@ class GeminiClient(GenAIClient):
         """Get the context window size for Gemini."""
         # Gemini Pro Vision has a 1M token context window
         return 1000000
+
+    def chat_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: Optional[list[dict[str, Any]]] = None,
+        tool_choice: Optional[str] = "auto",
+    ) -> dict[str, Any]:
+        """Run one Gemini chat turn, optionally offering tools.
+
+        Args:
+            messages: OpenAI-style conversation, converted to Gemini contents.
+            tools: OpenAI-style tool definitions. Entries whose "type" is
+                not "function" are ignored.
+            tool_choice: "none" and "required" select the NONE and ANY
+                function calling modes; anything else selects AUTO.
+
+        Returns:
+            A dict with "content" (text or None), "tool_calls" (list of
+            {"id", "name", "arguments"} dicts or None) and "finish_reason".
+            Failures are logged and reported as finish_reason "error"
+            rather than raised.
+        """
+        error_response: dict[str, Any] = {
+            "content": None,
+            "tool_calls": None,
+            "finish_reason": "error",
+        }
+
+        try:
+            gemini_tools = None
+            tool_config = None
+            if tools:
+                gemini_tools = [
+                    genai.protos.Tool(
+                        function_declarations=[
+                            _build_function_declaration(tool.get("function", {}))
+                            for tool in tools
+                            if tool.get("type") == "function"
+                        ]
+                    )
+                ]
+
+                modes = genai.protos.FunctionCallingConfig.Mode
+                if tool_choice == "none":
+                    mode = modes.NONE
+                elif tool_choice == "required":
+                    mode = modes.ANY
+                else:
+                    mode = modes.AUTO
+
+                # The calling mode is a tool-level setting, so it is sent as
+                # tool_config instead of being attached to generation_config.
+                tool_config = genai.protos.ToolConfig(
+                    function_calling_config=genai.protos.FunctionCallingConfig(
+                        mode=mode
+                    )
+                )
+
+            response = self.provider.generate_content(
+                _build_contents(messages),
+                tools=gemini_tools,
+                tool_config=tool_config,
+                generation_config=genai.types.GenerationConfig(candidate_count=1),
+                request_options=genai.types.RequestOptions(timeout=self.timeout),
+            )
+
+            candidate = response.candidates[0] if response.candidates else None
+            content = None
+            tool_calls = None
+
+            if candidate is not None and candidate.content:
+                parts = candidate.content.parts
+                text_parts = [p.text for p in parts if getattr(p, "text", None)]
+                if text_parts:
+                    content = " ".join(text_parts).strip()
+
+                function_calls = [
+                    p.function_call
+                    for p in parts
+                    if getattr(p, "function_call", None)
+                ]
+                if function_calls:
+                    # hash(name) varies between processes and repeats when a
+                    # tool is called twice, so ids are generated instead.
+                    tool_calls = [
+                        {
+                            "id": f"call_{uuid.uuid4().hex}",
+                            "name": fc.name,
+                            "arguments": dict(fc.args)
+                            if getattr(fc, "args", None)
+                            else {},
+                        }
+                        for fc in function_calls
+                    ]
+
+            if tool_calls:
+                # Checked first so a tool request is not reported as "stop".
+                finish_reason = "tool_calls"
+            elif candidate is not None:
+                reasons = genai.protos.Candidate.FinishReason
+                finish_reason = {
+                    reasons.STOP: "stop",
+                    reasons.MAX_TOKENS: "length",
+                    reasons.SAFETY: "stop",
+                    reasons.RECITATION: "stop",
+                    reasons.OTHER: "error",
+                }.get(candidate.finish_reason, "error")
+            else:
+                finish_reason = "error"
+
+            return {
+                "content": content,
+                "tool_calls": tool_calls,
+                "finish_reason": finish_reason,
+            }
+
+        except GoogleAPICallError as e:
+            logger.warning("Gemini returned an error: %s", str(e))
+            return error_response
+        except Exception as e:
+            logger.warning("Unexpected error in Gemini chat_with_tools: %s", str(e))
+            return error_response
+
+
+def _convert_json_type_to_gemini(json_type: Optional[str]) -> genai.protos.Type:
+    """Map a JSON Schema type name to a Gemini type, defaulting to STRING."""
+    type_map = {
+        "string": genai.protos.Type.STRING,
+        "integer": genai.protos.Type.INTEGER,
+        "number": genai.protos.Type.NUMBER,
+        "boolean": genai.protos.Type.BOOLEAN,
+        "array": genai.protos.Type.ARRAY,
+        "object": genai.protos.Type.OBJECT,
+    }
+    return type_map.get(json_type, genai.protos.Type.STRING)
+
+
+def _build_function_declaration(
+    func_def: dict[str, Any],
+) -> genai.protos.FunctionDeclaration:
+    """Translate one OpenAI-style function definition for Gemini.
+
+    Only the top-level property types and descriptions and the "required"
+    list of the parameter schema are carried over.
+    """
+    parameters = func_def.get("parameters") or {}
+    properties = parameters.get("properties") or {}
+    return genai.protos.FunctionDeclaration(
+        name=func_def.get("name"),
+        description=func_def.get("description"),
+        parameters=genai.protos.Schema(
+            type=genai.protos.Type.OBJECT,
+            properties={
+                prop_name: genai.protos.Schema(
+                    type=_convert_json_type_to_gemini(prop.get("type")),
+                    description=prop.get("description"),
+                )
+                for prop_name, prop in properties.items()
+            },
+            required=parameters.get("required") or [],
+        ),
+    )
+
+
+def _tool_result_to_struct(content: Any) -> dict[str, Any]:
+    """Coerce a tool result into the JSON object FunctionResponse needs.
+
+    JSON text is decoded; plain text and non-object values are wrapped
+    under "result".
+    """
+    if isinstance(content, str):
+        try:
+            content = json.loads(content)
+        except json.JSONDecodeError:
+            return {"result": content}
+    return content if isinstance(content, dict) else {"result": content}
+
+
+def _build_contents(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    """Convert an OpenAI-style conversation into Gemini contents.
+
+    System text is folded into the next user turn, since this client
+    passes no system instruction to generate_content().
+    """
+    contents: list[dict[str, Any]] = []
+    pending_system: list[str] = []
+
+    for msg in messages:
+        role = msg.get("role")
+        content = msg.get("content", "")
+
+        if role == "system":
+            if content:
+                pending_system.append(content)
+        elif role == "user":
+            contents.append({"role": "user", "parts": [*pending_system, content]})
+            pending_system = []
+        elif role == "assistant":
+            parts: list[Any] = [content] if content else []
+            for tc in msg.get("tool_calls") or []:
+                arguments = tc["function"]["arguments"]
+                if isinstance(arguments, str):
+                    arguments = json.loads(arguments)
+                parts.append(
+                    genai.protos.FunctionCall(
+                        name=tc["function"]["name"],
+                        args=arguments,
+                    )
+                )
+            # An assistant turn without text or calls has nothing to send.
+            if parts:
+                contents.append({"role": "model", "parts": parts})
+        elif role == "tool":
+            contents.append(
+                {
+                    "role": "function",
+                    "parts": [
+                        genai.protos.FunctionResponse(
+                            name=msg.get("name", ""),
+                            response=_tool_result_to_struct(content),
+                        )
+                    ],
+                }
+            )
+
+    if pending_system:
+        # No user turn followed the system text; send it as a leading turn.
+        contents.insert(0, {"role": "user", "parts": pending_system})
+
+    return contents
diff --git a/frigate/genai/ollama.py b/frigate/genai/ollama.py
index 9f9c8a750..84f5148e5 100644
--- a/frigate/genai/ollama.py
+++ b/frigate/genai/ollama.py
@@ -1,5 +1,6 @@
 """Ollama Provider for Frigate AI."""
 
+import json
 import logging
 from typing import Any, Optional
 
@@ -77,3 +78,141 @@ class OllamaClient(GenAIClient):
         return self.genai_config.provider_options.get("options", {}).get(
             "num_ctx", 4096
         )
+
+    def chat_with_tools(
+        self,
+        messages: list[dict[str, Any]],
+        tools: Optional[list[dict[str, Any]]] = None,
+        tool_choice: Optional[str] = "auto",
+    ) -> dict[str, Any]:
+        """Run one Ollama chat turn, optionally offering tools.
+
+        Args:
+            messages: OpenAI-style conversation.
+            tools: Tool definitions; left out of the request when empty or
+                when tool_choice is "none".
+            tool_choice: Only "none" has an effect here; other values leave
+                the decision to the model.
+
+        Returns:
+            A dict with "content" (stripped text or None), "tool_calls"
+            (list of {"id", "name", "arguments"} dicts or None) and
+            "finish_reason". Failures are logged and reported as
+            finish_reason "error" rather than raised.
+        """
+        error_response: dict[str, Any] = {
+            "content": None,
+            "tool_calls": None,
+            "finish_reason": "error",
+        }
+
+        if self.provider is None:
+            logger.warning(
+                "Ollama provider has not been initialized. Check your Ollama configuration."
+            )
+            return error_response
+
+        try:
+            request_messages = []
+            for msg in messages:
+                msg_dict = {
+                    "role": msg.get("role"),
+                    "content": msg.get("content", ""),
+                }
+                if msg.get("tool_call_id"):
+                    msg_dict["tool_call_id"] = msg["tool_call_id"]
+                if msg.get("name"):
+                    msg_dict["name"] = msg["name"]
+                if msg.get("tool_calls"):
+                    msg_dict["tool_calls"] = [
+                        _with_decoded_arguments(tc) for tc in msg["tool_calls"]
+                    ]
+                request_messages.append(msg_dict)
+
+            request_params: dict[str, Any] = {
+                "model": self.genai_config.model,
+                "messages": request_messages,
+            }
+
+            # tool_choice is not forwarded because the ollama client's chat()
+            # has no such keyword; "none" is honoured by withholding the tools.
+            if tools and tool_choice != "none":
+                request_params["tools"] = tools
+
+            request_params.update(self.provider_options)
+
+            response = self.provider.chat(**request_params)
+
+            if not response or "message" not in response:
+                return error_response
+
+            message = response["message"]
+            raw_content = message.get("content")
+            content = raw_content.strip() if raw_content else None
+
+            tool_calls = None
+            if message.get("tool_calls"):
+                tool_calls = []
+                for tool_call in message["tool_calls"]:
+                    function_data = tool_call.get("function") or {}
+                    raw_arguments = function_data.get("arguments") or {}
+                    try:
+                        # Ollama returns arguments as an object; JSON text is
+                        # still accepted in case a server sends a string.
+                        arguments = (
+                            json.loads(raw_arguments)
+                            if isinstance(raw_arguments, str)
+                            else dict(raw_arguments)
+                        )
+                    except (TypeError, ValueError) as e:
+                        logger.warning(
+                            "Failed to parse tool call arguments: %s, tool: %s",
+                            e,
+                            function_data.get("name", "unknown"),
+                        )
+                        arguments = {}
+
+                    tool_calls.append(
+                        {
+                            "id": tool_call.get("id", ""),
+                            "name": function_data.get("name", ""),
+                            "arguments": arguments,
+                        }
+                    )
+
+            if tool_calls:
+                finish_reason = "tool_calls"
+            elif content:
+                finish_reason = "stop"
+            else:
+                finish_reason = "error"
+
+            return {
+                "content": content,
+                "tool_calls": tool_calls,
+                "finish_reason": finish_reason,
+            }
+
+        except (TimeoutException, ResponseError, ConnectionError) as e:
+            logger.warning("Ollama returned an error: %s", str(e))
+            return error_response
+        except Exception as e:
+            logger.warning("Unexpected error in Ollama chat_with_tools: %s", str(e))
+            return error_response
+
+
+def _with_decoded_arguments(tool_call: Any) -> Any:
+    """Return the tool call with JSON-text arguments decoded to a dict.
+
+    OpenAI-style history carries arguments as JSON text, whereas Ollama
+    expects an object. Tool calls in any other shape are returned as-is.
+    """
+    function = tool_call.get("function") if isinstance(tool_call, dict) else None
+    arguments = function.get("arguments") if isinstance(function, dict) else None
+    if not isinstance(arguments, str):
+        return tool_call
+    try:
+        decoded = json.loads(arguments)
+    except json.JSONDecodeError:
+        decoded = {}
+    return {**tool_call, "function": {**function, "arguments": decoded}}