Implement other providers

This commit is contained in:
Nicolas Mowen 2026-01-20 07:57:11 -07:00
parent 3acd12bc56
commit 4eeac987b8
4 changed files with 407 additions and 15 deletions

View File

@ -306,7 +306,7 @@ Always be accurate with time calculations based on the current date provided."""
tool_iterations = 0 tool_iterations = 0
max_iterations = body.max_tool_iterations max_iterations = body.max_tool_iterations
logger.info( logger.debug(
f"Starting chat completion with {len(conversation)} message(s), " f"Starting chat completion with {len(conversation)} message(s), "
f"{len(tools)} tool(s) available, max_iterations={max_iterations}" f"{len(tools)} tool(s) available, max_iterations={max_iterations}"
) )
@ -352,7 +352,7 @@ Always be accurate with time calculations based on the current date provided."""
tool_calls = response.get("tool_calls") tool_calls = response.get("tool_calls")
if not tool_calls: if not tool_calls:
logger.info( logger.debug(
f"Chat completion finished with final answer (iterations: {tool_iterations})" f"Chat completion finished with final answer (iterations: {tool_iterations})"
) )
return JSONResponse( return JSONResponse(
@ -369,7 +369,7 @@ Always be accurate with time calculations based on the current date provided."""
# Execute tools # Execute tools
tool_iterations += 1 tool_iterations += 1
logger.info( logger.debug(
f"Tool calls detected (iteration {tool_iterations}/{max_iterations}): " f"Tool calls detected (iteration {tool_iterations}/{max_iterations}): "
f"{len(tool_calls)} tool(s) to execute" f"{len(tool_calls)} tool(s) to execute"
) )
@ -380,7 +380,7 @@ Always be accurate with time calculations based on the current date provided."""
tool_args = tool_call["arguments"] tool_args = tool_call["arguments"]
tool_call_id = tool_call["id"] tool_call_id = tool_call["id"]
logger.info( logger.debug(
f"Executing tool: {tool_name} (id: {tool_call_id}) with arguments: {json.dumps(tool_args, indent=2)}" f"Executing tool: {tool_name} (id: {tool_call_id}) with arguments: {json.dumps(tool_args, indent=2)}"
) )
@ -402,19 +402,19 @@ Always be accurate with time calculations based on the current date provided."""
if result_count > 0 if result_count > 0
else [], else [],
} }
logger.info( logger.debug(
f"Tool {tool_name} (id: {tool_call_id}) completed successfully. " f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
f"Result: {json.dumps(result_summary, indent=2)}" f"Result: {json.dumps(result_summary, indent=2)}"
) )
elif isinstance(tool_result, str): elif isinstance(tool_result, str):
result_content = tool_result result_content = tool_result
logger.info( logger.debug(
f"Tool {tool_name} (id: {tool_call_id}) completed successfully. " f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
f"Result length: {len(result_content)} characters" f"Result length: {len(result_content)} characters"
) )
else: else:
result_content = str(tool_result) result_content = str(tool_result)
logger.info( logger.debug(
f"Tool {tool_name} (id: {tool_call_id}) completed successfully. " f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
f"Result type: {type(tool_result).__name__}" f"Result type: {type(tool_result).__name__}"
) )
@ -441,16 +441,12 @@ Always be accurate with time calculations based on the current date provided."""
"content": error_content, "content": error_content,
} }
) )
logger.info( logger.debug(
f"Tool {tool_name} (id: {tool_call_id}) failed. Error result added to conversation." f"Tool {tool_name} (id: {tool_call_id}) failed. Error result added to conversation."
) )
conversation.extend(tool_results) conversation.extend(tool_results)
logger.info( logger.debug(
f"Added {len(tool_results)} tool result(s) to conversation. "
f"Continuing with next LLM call..."
)
logger.info(
f"Added {len(tool_results)} tool result(s) to conversation. " f"Added {len(tool_results)} tool result(s) to conversation. "
f"Continuing with next LLM call..." f"Continuing with next LLM call..."
) )

View File

@ -1,8 +1,9 @@
"""Azure OpenAI Provider for Frigate AI.""" """Azure OpenAI Provider for Frigate AI."""
import base64 import base64
import json
import logging import logging
from typing import Optional from typing import Any, Optional
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
from openai import AzureOpenAI from openai import AzureOpenAI
@ -75,3 +76,93 @@ class OpenAIClient(GenAIClient):
def get_context_size(self) -> int: def get_context_size(self) -> int:
"""Get the context window size for Azure OpenAI.""" """Get the context window size for Azure OpenAI."""
return 128000 return 128000
def chat_with_tools(
    self,
    messages: list[dict[str, Any]],
    tools: Optional[list[dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
) -> dict[str, Any]:
    """Run one chat completion against Azure OpenAI, optionally with tools.

    Args:
        messages: Conversation so far, passed to the API unchanged.
        tools: Tool definitions, passed to the API unchanged. When empty or
            None no tool parameters are sent at all.
        tool_choice: "none", "auto" or "required". Any other value
            (including None) is not forwarded.

    Returns:
        A dict with three keys:
            "content": stripped assistant text, or None when there is none.
            "tool_calls": None, or a list of
                {"id": ..., "name": ..., "arguments": dict}.
            "finish_reason": the value reported by the API, otherwise
                "tool_calls" / "stop" derived from the reply, or "error".

        Every exception is logged and turned into the error form
        (content None, tool_calls None, finish_reason "error").
    """
    error_response = {
        "content": None,
        "tool_calls": None,
        "finish_reason": "error",
    }

    try:
        request_params = {
            "model": self.genai_config.model,
            "messages": messages,
            "timeout": self.timeout,
        }

        if tools:
            request_params["tools"] = tools
            # tool_choice is only meaningful next to a tools list, so it is
            # never sent on its own.
            if tool_choice in ("none", "auto", "required"):
                request_params["tool_choice"] = tool_choice

        result = self.provider.chat.completions.create(**request_params)

        if result is None or not getattr(result, "choices", None):
            return error_response

        choice = result.choices[0]
        message = choice.message
        content = message.content.strip() if message.content else None

        tool_calls = None
        if message.tool_calls:
            tool_calls = []
            for tool_call in message.tool_calls:
                function = getattr(tool_call, "function", None)
                tool_name = getattr(function, "name", "")
                raw_arguments = getattr(function, "arguments", None)

                # A call without arguments (None or "") is valid and simply
                # means "no parameters"; it must not be treated as a parse
                # failure.
                arguments = {}
                if raw_arguments:
                    try:
                        parsed = json.loads(raw_arguments)
                    except (json.JSONDecodeError, TypeError) as e:
                        logger.warning(
                            f"Failed to parse tool call arguments: {e}, "
                            f"tool: {tool_name or 'unknown'}"
                        )
                    else:
                        if isinstance(parsed, dict):
                            arguments = parsed
                        else:
                            # Valid JSON that is not an object cannot be
                            # used as keyword arguments.
                            logger.warning(
                                "Tool call arguments are not a JSON object, "
                                f"tool: {tool_name or 'unknown'}"
                            )

                tool_calls.append(
                    {
                        "id": getattr(tool_call, "id", ""),
                        "name": tool_name,
                        "arguments": arguments,
                    }
                )

        # Prefer the API's own finish reason; derive one only when missing.
        finish_reason = getattr(choice, "finish_reason", None)
        if not finish_reason:
            if tool_calls:
                finish_reason = "tool_calls"
            elif content:
                finish_reason = "stop"
            else:
                finish_reason = "error"

        return {
            "content": content,
            "tool_calls": tool_calls,
            "finish_reason": finish_reason,
        }
    except Exception as e:
        logger.warning("Azure OpenAI returned an error: %s", str(e))
        return error_response

View File

@ -1,7 +1,8 @@
"""Gemini Provider for Frigate AI.""" """Gemini Provider for Frigate AI."""
import json
import logging import logging
from typing import Optional from typing import Any, Optional
import google.generativeai as genai import google.generativeai as genai
from google.api_core.exceptions import GoogleAPICallError from google.api_core.exceptions import GoogleAPICallError
@ -58,3 +59,189 @@ class GeminiClient(GenAIClient):
"""Get the context window size for Gemini.""" """Get the context window size for Gemini."""
# Gemini Pro Vision has a 1M token context window # Gemini Pro Vision has a 1M token context window
return 1000000 return 1000000
def chat_with_tools(
    self,
    messages: list[dict[str, Any]],
    tools: Optional[list[dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
) -> dict[str, Any]:
    """Run one Gemini generation, optionally with function calling.

    Args:
        messages: Conversation with "user", "assistant" and "tool" roles.
            Assistant messages may carry "tool_calls" whose
            function arguments are a JSON string or a dict. Messages with
            role "system" are skipped here.
        tools: Tool definitions; only entries whose "type" is "function" are
            converted to Gemini function declarations.
        tool_choice: "none" maps to mode NONE, "required" to mode ANY, and
            anything else (including None) to mode AUTO.

    Returns:
        A dict with three keys:
            "content": text parts joined with spaces, or None.
            "tool_calls": None, or a list of
                {"id": ..., "name": ..., "arguments": dict}.
            "finish_reason": "tool_calls" when the model requested tools,
                otherwise the mapped Gemini finish reason, or "error".

        Every exception is logged and turned into the error form
        (content None, tool_calls None, finish_reason "error").
    """
    # Function-scope import: ids for tool calls need a source of uniqueness
    # that the module's visible imports do not provide.
    import uuid

    error_response = {
        "content": None,
        "tool_calls": None,
        "finish_reason": "error",
    }

    try:
        gemini_tool = None
        tool_config = None

        if tools:
            function_declarations = []
            for tool in tools:
                if tool.get("type") != "function":
                    continue

                func_def = tool.get("function", {})
                parameters = func_def.get("parameters", {})
                function_declarations.append(
                    genai.protos.FunctionDeclaration(
                        name=func_def.get("name"),
                        description=func_def.get("description"),
                        parameters=genai.protos.Schema(
                            type=genai.protos.Type.OBJECT,
                            properties={
                                prop_name: genai.protos.Schema(
                                    type=_convert_json_type_to_gemini(
                                        prop.get("type")
                                    ),
                                    description=prop.get("description"),
                                )
                                for prop_name, prop in parameters.get(
                                    "properties", {}
                                ).items()
                            },
                            required=parameters.get("required", []),
                        ),
                    )
                )

            if function_declarations:
                gemini_tool = genai.protos.Tool(
                    function_declarations=function_declarations
                )

                if tool_choice == "none":
                    mode = genai.protos.FunctionCallingConfig.Mode.NONE
                elif tool_choice == "required":
                    mode = genai.protos.FunctionCallingConfig.Mode.ANY
                else:
                    mode = genai.protos.FunctionCallingConfig.Mode.AUTO

                # The calling mode is sent through generate_content's
                # tool_config argument. Setting it as an attribute on
                # GenerationConfig does not put it on the request.
                tool_config = genai.protos.ToolConfig(
                    function_calling_config=genai.protos.FunctionCallingConfig(
                        mode=mode
                    )
                )

        contents = []
        for msg in messages:
            role = msg.get("role")
            content = msg.get("content", "")

            if role == "user":
                contents.append({"role": "user", "parts": [content]})
            elif role == "assistant":
                parts = [content] if content else []
                for tc in msg.get("tool_calls") or []:
                    call_args = tc["function"]["arguments"]
                    if isinstance(call_args, str):
                        call_args = json.loads(call_args)
                    parts.append(
                        genai.protos.FunctionCall(
                            name=tc["function"]["name"],
                            args=call_args,
                        )
                    )
                contents.append({"role": "model", "parts": parts})
            elif role == "tool":
                tool_result = content
                if isinstance(content, str):
                    try:
                        tool_result = json.loads(content)
                    except json.JSONDecodeError:
                        # Plain-text results (e.g. error messages) are
                        # forwarded as text instead of failing the request.
                        tool_result = content

                # FunctionResponse.response has to be an object; wrap lists,
                # strings and scalars.
                if not isinstance(tool_result, dict):
                    tool_result = {"result": tool_result}

                contents.append(
                    {
                        "role": "function",
                        "parts": [
                            genai.protos.FunctionResponse(
                                name=msg.get("name", ""),
                                response=tool_result,
                            )
                        ],
                    }
                )
            # NOTE(review): "system" messages are dropped here; confirm the
            # system prompt reaches the model some other way.

        response = self.provider.generate_content(
            contents,
            tools=[gemini_tool] if gemini_tool else None,
            tool_config=tool_config,
            generation_config=genai.types.GenerationConfig(candidate_count=1),
            request_options=genai.types.RequestOptions(timeout=self.timeout),
        )

        content = None
        tool_calls = None

        if response.candidates and response.candidates[0].content:
            parts = response.candidates[0].content.parts

            text_parts = [p.text for p in parts if hasattr(p, "text") and p.text]
            if text_parts:
                content = " ".join(text_parts).strip()

            function_calls = [
                p.function_call
                for p in parts
                if hasattr(p, "function_call") and p.function_call
            ]
            if function_calls:
                # The id is generated locally. It must be unique per call so
                # two calls to the same function can be told apart.
                tool_calls = [
                    {
                        "id": f"call_{uuid.uuid4().hex}",
                        "name": fc.name,
                        "arguments": dict(fc.args) if hasattr(fc, "args") else {},
                    }
                    for fc in function_calls
                ]

        finish_reason = "error"
        if tool_calls:
            finish_reason = "tool_calls"
        elif response.candidates:
            finish_reason_map = {
                genai.types.FinishReason.STOP: "stop",
                genai.types.FinishReason.MAX_TOKENS: "length",
                genai.types.FinishReason.SAFETY: "stop",
                genai.types.FinishReason.RECITATION: "stop",
                genai.types.FinishReason.OTHER: "error",
            }
            finish_reason = finish_reason_map.get(
                response.candidates[0].finish_reason, "error"
            )

        return {
            "content": content,
            "tool_calls": tool_calls,
            "finish_reason": finish_reason,
        }
    except GoogleAPICallError as e:
        logger.warning("Gemini returned an error: %s", str(e))
        return error_response
    except Exception as e:
        logger.warning("Unexpected error in Gemini chat_with_tools: %s", str(e))
        return error_response
def _convert_json_type_to_gemini(json_type: Any) -> genai.protos.Type:
    """Map a JSON Schema ``type`` value to the matching Gemini schema type.

    Args:
        json_type: Usually one of "string", "integer", "number", "boolean",
            "array" or "object". A list/tuple of type names (for example
            ["string", "null"]) is also accepted; its first entry other than
            "null" is used.

    Returns:
        The corresponding ``genai.protos.Type`` member, or
        ``genai.protos.Type.STRING`` for missing, unknown or non-string
        values.
    """
    type_map = {
        "string": genai.protos.Type.STRING,
        "integer": genai.protos.Type.INTEGER,
        "number": genai.protos.Type.NUMBER,
        "boolean": genai.protos.Type.BOOLEAN,
        "array": genai.protos.Type.ARRAY,
        "object": genai.protos.Type.OBJECT,
    }

    if isinstance(json_type, (list, tuple)):
        # A list-valued type would be unhashable in the lookup below, so
        # reduce it to a single name first.
        json_type = next((name for name in json_type if name != "null"), None)

    if not isinstance(json_type, str):
        return genai.protos.Type.STRING

    return type_map.get(json_type, genai.protos.Type.STRING)

View File

@ -1,5 +1,6 @@
"""Ollama Provider for Frigate AI.""" """Ollama Provider for Frigate AI."""
import json
import logging import logging
from typing import Any, Optional from typing import Any, Optional
@ -77,3 +78,120 @@ class OllamaClient(GenAIClient):
return self.genai_config.provider_options.get("options", {}).get( return self.genai_config.provider_options.get("options", {}).get(
"num_ctx", 4096 "num_ctx", 4096
) )
def chat_with_tools(
    self,
    messages: list[dict[str, Any]],
    tools: Optional[list[dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
) -> dict[str, Any]:
    """Run one Ollama chat request, optionally with tools.

    Args:
        messages: Conversation so far. Only "role", "content",
            "tool_call_id", "name" and "tool_calls" are forwarded.
            JSON-string arguments inside forwarded tool calls are decoded
            to objects; the caller's messages are not modified.
        tools: Tool definitions, forwarded unchanged.
        tool_choice: "none" withholds the tools from the request. Other
            values cannot be expressed to Ollama and are treated as "auto".

    Returns:
        A dict with three keys:
            "content": stripped assistant text, or None when there is none.
            "tool_calls": None, or a list of
                {"id": ..., "name": ..., "arguments": dict}.
            "finish_reason": "tool_calls", "stop" or "error".

        A missing provider and every exception are logged and turned into
        the error form (content None, tool_calls None, finish_reason
        "error").
    """
    error_response = {
        "content": None,
        "tool_calls": None,
        "finish_reason": "error",
    }

    if self.provider is None:
        logger.warning(
            "Ollama provider has not been initialized. Check your Ollama configuration."
        )
        return error_response

    try:
        request_messages = []
        for msg in messages:
            msg_dict = {
                "role": msg.get("role"),
                "content": msg.get("content", ""),
            }
            for key in ("tool_call_id", "name"):
                if msg.get(key):
                    msg_dict[key] = msg[key]

            if msg.get("tool_calls"):
                forwarded_calls = []
                for call in msg["tool_calls"]:
                    function = call.get("function") if isinstance(call, dict) else None
                    if isinstance(function, dict) and isinstance(
                        function.get("arguments"), str
                    ):
                        # Ollama expects arguments as an object, while the
                        # conversation may hold them as a JSON string.
                        try:
                            decoded = json.loads(function["arguments"])
                        except json.JSONDecodeError:
                            decoded = None
                        if isinstance(decoded, dict):
                            # Build a copy so the caller's history is untouched.
                            call = {
                                **call,
                                "function": {**function, "arguments": decoded},
                            }
                    forwarded_calls.append(call)
                msg_dict["tool_calls"] = forwarded_calls

            request_messages.append(msg_dict)

        request_params = {
            "model": self.genai_config.model,
            "messages": request_messages,
        }

        # Ollama's chat call has no tool_choice parameter, so sending one
        # would make the call fail. "none" is emulated by withholding the
        # tools; "required" cannot be enforced.
        if tools and tool_choice != "none":
            request_params["tools"] = tools

        request_params.update(self.provider_options)

        response = self.provider.chat(**request_params)

        if not response or "message" not in response:
            return error_response

        message = response["message"]
        raw_content = message.get("content")
        content = raw_content.strip() if raw_content else None

        tool_calls = None
        if message.get("tool_calls"):
            tool_calls = []
            for tool_call in message["tool_calls"]:
                function_data = tool_call.get("function") or {}
                tool_name = function_data.get("name", "")
                raw_arguments = function_data.get("arguments")

                arguments = {}
                if isinstance(raw_arguments, str):
                    # Some servers may still send a JSON string.
                    if raw_arguments.strip():
                        try:
                            decoded = json.loads(raw_arguments)
                        except json.JSONDecodeError as e:
                            logger.warning(
                                f"Failed to parse tool call arguments: {e}, "
                                f"tool: {tool_name or 'unknown'}"
                            )
                        else:
                            if isinstance(decoded, dict):
                                arguments = decoded
                elif raw_arguments is not None:
                    # Normal case: Ollama returns arguments as an object.
                    try:
                        arguments = dict(raw_arguments)
                    except (TypeError, ValueError) as e:
                        logger.warning(
                            f"Failed to parse tool call arguments: {e}, "
                            f"tool: {tool_name or 'unknown'}"
                        )

                tool_calls.append(
                    {
                        "id": tool_call.get("id", ""),
                        "name": tool_name,
                        "arguments": arguments,
                    }
                )

        if tool_calls:
            finish_reason = "tool_calls"
        elif content:
            finish_reason = "stop"
        else:
            finish_reason = "error"

        return {
            "content": content,
            "tool_calls": tool_calls,
            "finish_reason": finish_reason,
        }
    except (TimeoutException, ResponseError, ConnectionError) as e:
        logger.warning("Ollama returned an error: %s", str(e))
        return error_response
    except Exception as e:
        logger.warning("Unexpected error in Ollama chat_with_tools: %s", str(e))
        return error_response