Implement other providers

This commit is contained in:
Nicolas Mowen 2026-01-20 07:57:11 -07:00
parent 3acd12bc56
commit 4eeac987b8
4 changed files with 407 additions and 15 deletions

View File

@ -306,7 +306,7 @@ Always be accurate with time calculations based on the current date provided."""
tool_iterations = 0
max_iterations = body.max_tool_iterations
logger.info(
logger.debug(
f"Starting chat completion with {len(conversation)} message(s), "
f"{len(tools)} tool(s) available, max_iterations={max_iterations}"
)
@ -352,7 +352,7 @@ Always be accurate with time calculations based on the current date provided."""
tool_calls = response.get("tool_calls")
if not tool_calls:
logger.info(
logger.debug(
f"Chat completion finished with final answer (iterations: {tool_iterations})"
)
return JSONResponse(
@ -369,7 +369,7 @@ Always be accurate with time calculations based on the current date provided."""
# Execute tools
tool_iterations += 1
logger.info(
logger.debug(
f"Tool calls detected (iteration {tool_iterations}/{max_iterations}): "
f"{len(tool_calls)} tool(s) to execute"
)
@ -380,7 +380,7 @@ Always be accurate with time calculations based on the current date provided."""
tool_args = tool_call["arguments"]
tool_call_id = tool_call["id"]
logger.info(
logger.debug(
f"Executing tool: {tool_name} (id: {tool_call_id}) with arguments: {json.dumps(tool_args, indent=2)}"
)
@ -402,19 +402,19 @@ Always be accurate with time calculations based on the current date provided."""
if result_count > 0
else [],
}
logger.info(
logger.debug(
f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
f"Result: {json.dumps(result_summary, indent=2)}"
)
elif isinstance(tool_result, str):
result_content = tool_result
logger.info(
logger.debug(
f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
f"Result length: {len(result_content)} characters"
)
else:
result_content = str(tool_result)
logger.info(
logger.debug(
f"Tool {tool_name} (id: {tool_call_id}) completed successfully. "
f"Result type: {type(tool_result).__name__}"
)
@ -441,16 +441,12 @@ Always be accurate with time calculations based on the current date provided."""
"content": error_content,
}
)
logger.info(
logger.debug(
f"Tool {tool_name} (id: {tool_call_id}) failed. Error result added to conversation."
)
conversation.extend(tool_results)
logger.info(
f"Added {len(tool_results)} tool result(s) to conversation. "
f"Continuing with next LLM call..."
)
logger.info(
logger.debug(
f"Added {len(tool_results)} tool result(s) to conversation. "
f"Continuing with next LLM call..."
)

View File

@ -1,8 +1,9 @@
"""Azure OpenAI Provider for Frigate AI."""
import base64
import json
import logging
from typing import Optional
from typing import Any, Optional
from urllib.parse import parse_qs, urlparse
from openai import AzureOpenAI
@ -75,3 +76,93 @@ class OpenAIClient(GenAIClient):
def get_context_size(self) -> int:
"""Get the context window size for Azure OpenAI."""
return 128000
def chat_with_tools(
self,
messages: list[dict[str, Any]],
tools: Optional[list[dict[str, Any]]] = None,
tool_choice: Optional[str] = "auto",
) -> dict[str, Any]:
try:
openai_tool_choice = None
if tool_choice:
if tool_choice == "none":
openai_tool_choice = "none"
elif tool_choice == "auto":
openai_tool_choice = "auto"
elif tool_choice == "required":
openai_tool_choice = "required"
request_params = {
"model": self.genai_config.model,
"messages": messages,
"timeout": self.timeout,
}
if tools:
request_params["tools"] = tools
if openai_tool_choice is not None:
request_params["tool_choice"] = openai_tool_choice
result = self.provider.chat.completions.create(**request_params)
if (
result is None
or not hasattr(result, "choices")
or len(result.choices) == 0
):
return {
"content": None,
"tool_calls": None,
"finish_reason": "error",
}
choice = result.choices[0]
message = choice.message
content = message.content.strip() if message.content else None
tool_calls = None
if message.tool_calls:
tool_calls = []
for tool_call in message.tool_calls:
try:
arguments = json.loads(tool_call.function.arguments)
except (json.JSONDecodeError, AttributeError) as e:
logger.warning(
f"Failed to parse tool call arguments: {e}, "
f"tool: {tool_call.function.name if hasattr(tool_call.function, 'name') else 'unknown'}"
)
arguments = {}
tool_calls.append(
{
"id": tool_call.id if hasattr(tool_call, "id") else "",
"name": tool_call.function.name
if hasattr(tool_call.function, "name")
else "",
"arguments": arguments,
}
)
finish_reason = "error"
if hasattr(choice, "finish_reason") and choice.finish_reason:
finish_reason = choice.finish_reason
elif tool_calls:
finish_reason = "tool_calls"
elif content:
finish_reason = "stop"
return {
"content": content,
"tool_calls": tool_calls,
"finish_reason": finish_reason,
}
except Exception as e:
logger.warning("Azure OpenAI returned an error: %s", str(e))
return {
"content": None,
"tool_calls": None,
"finish_reason": "error",
}

View File

@ -1,7 +1,8 @@
"""Gemini Provider for Frigate AI."""
import json
import logging
from typing import Optional
from typing import Any, Optional
import google.generativeai as genai
from google.api_core.exceptions import GoogleAPICallError
@ -58,3 +59,189 @@ class GeminiClient(GenAIClient):
"""Get the context window size for Gemini."""
# Gemini Pro Vision has a 1M token context window
return 1000000
def chat_with_tools(
self,
messages: list[dict[str, Any]],
tools: Optional[list[dict[str, Any]]] = None,
tool_choice: Optional[str] = "auto",
) -> dict[str, Any]:
try:
if tools:
function_declarations = []
for tool in tools:
if tool.get("type") == "function":
func_def = tool.get("function", {})
function_declarations.append(
genai.protos.FunctionDeclaration(
name=func_def.get("name"),
description=func_def.get("description"),
parameters=genai.protos.Schema(
type=genai.protos.Type.OBJECT,
properties={
prop_name: genai.protos.Schema(
type=_convert_json_type_to_gemini(
prop.get("type")
),
description=prop.get("description"),
)
for prop_name, prop in func_def.get(
"parameters", {}
)
.get("properties", {})
.items()
},
required=func_def.get("parameters", {}).get(
"required", []
),
),
)
)
tool_config = genai.protos.Tool(
function_declarations=function_declarations
)
if tool_choice == "none":
function_calling_config = genai.protos.FunctionCallingConfig(
mode=genai.protos.FunctionCallingConfig.Mode.NONE
)
elif tool_choice == "required":
function_calling_config = genai.protos.FunctionCallingConfig(
mode=genai.protos.FunctionCallingConfig.Mode.ANY
)
else:
function_calling_config = genai.protos.FunctionCallingConfig(
mode=genai.protos.FunctionCallingConfig.Mode.AUTO
)
else:
tool_config = None
function_calling_config = None
contents = []
for msg in messages:
role = msg.get("role")
content = msg.get("content", "")
if role == "system":
continue
elif role == "user":
contents.append({"role": "user", "parts": [content]})
elif role == "assistant":
parts = [content] if content else []
if "tool_calls" in msg:
for tc in msg["tool_calls"]:
parts.append(
genai.protos.FunctionCall(
name=tc["function"]["name"],
args=json.loads(tc["function"]["arguments"]),
)
)
contents.append({"role": "model", "parts": parts})
elif role == "tool":
tool_call_id = msg.get("tool_call_id")
tool_name = msg.get("name", "")
tool_result = (
json.loads(content) if isinstance(content, str) else content
)
contents.append(
{
"role": "function",
"parts": [
genai.protos.FunctionResponse(
name=tool_name,
response=tool_result,
)
],
}
)
generation_config = genai.types.GenerationConfig(
candidate_count=1,
)
if function_calling_config:
generation_config.function_calling_config = function_calling_config
response = self.provider.generate_content(
contents,
tools=[tool_config] if tool_config else None,
generation_config=generation_config,
request_options=genai.types.RequestOptions(timeout=self.timeout),
)
content = None
tool_calls = None
if response.candidates and response.candidates[0].content:
parts = response.candidates[0].content.parts
text_parts = [p.text for p in parts if hasattr(p, "text") and p.text]
if text_parts:
content = " ".join(text_parts).strip()
function_calls = [
p.function_call
for p in parts
if hasattr(p, "function_call") and p.function_call
]
if function_calls:
tool_calls = []
for fc in function_calls:
tool_calls.append(
{
"id": f"call_{hash(fc.name)}",
"name": fc.name,
"arguments": dict(fc.args)
if hasattr(fc, "args")
else {},
}
)
finish_reason = "error"
if response.candidates:
finish_reason_map = {
genai.types.FinishReason.STOP: "stop",
genai.types.FinishReason.MAX_TOKENS: "length",
genai.types.FinishReason.SAFETY: "stop",
genai.types.FinishReason.RECITATION: "stop",
genai.types.FinishReason.OTHER: "error",
}
finish_reason = finish_reason_map.get(
response.candidates[0].finish_reason, "error"
)
elif tool_calls:
finish_reason = "tool_calls"
elif content:
finish_reason = "stop"
return {
"content": content,
"tool_calls": tool_calls,
"finish_reason": finish_reason,
}
except GoogleAPICallError as e:
logger.warning("Gemini returned an error: %s", str(e))
return {
"content": None,
"tool_calls": None,
"finish_reason": "error",
}
except Exception as e:
logger.warning("Unexpected error in Gemini chat_with_tools: %s", str(e))
return {
"content": None,
"tool_calls": None,
"finish_reason": "error",
}
def _convert_json_type_to_gemini(json_type: str) -> genai.protos.Type:
type_map = {
"string": genai.protos.Type.STRING,
"integer": genai.protos.Type.INTEGER,
"number": genai.protos.Type.NUMBER,
"boolean": genai.protos.Type.BOOLEAN,
"array": genai.protos.Type.ARRAY,
"object": genai.protos.Type.OBJECT,
}
return type_map.get(json_type, genai.protos.Type.STRING)

View File

@ -1,5 +1,6 @@
"""Ollama Provider for Frigate AI."""
import json
import logging
from typing import Any, Optional
@ -77,3 +78,120 @@ class OllamaClient(GenAIClient):
return self.genai_config.provider_options.get("options", {}).get(
"num_ctx", 4096
)
def chat_with_tools(
self,
messages: list[dict[str, Any]],
tools: Optional[list[dict[str, Any]]] = None,
tool_choice: Optional[str] = "auto",
) -> dict[str, Any]:
if self.provider is None:
logger.warning(
"Ollama provider has not been initialized. Check your Ollama configuration."
)
return {
"content": None,
"tool_calls": None,
"finish_reason": "error",
}
try:
request_messages = []
for msg in messages:
msg_dict = {
"role": msg.get("role"),
"content": msg.get("content", ""),
}
if msg.get("tool_call_id"):
msg_dict["tool_call_id"] = msg["tool_call_id"]
if msg.get("name"):
msg_dict["name"] = msg["name"]
if msg.get("tool_calls"):
msg_dict["tool_calls"] = msg["tool_calls"]
request_messages.append(msg_dict)
request_params = {
"model": self.genai_config.model,
"messages": request_messages,
}
if tools:
request_params["tools"] = tools
if tool_choice:
if tool_choice == "none":
request_params["tool_choice"] = "none"
elif tool_choice == "required":
request_params["tool_choice"] = "required"
elif tool_choice == "auto":
request_params["tool_choice"] = "auto"
request_params.update(self.provider_options)
response = self.provider.chat(**request_params)
if not response or "message" not in response:
return {
"content": None,
"tool_calls": None,
"finish_reason": "error",
}
message = response["message"]
content = (
message.get("content", "").strip() if message.get("content") else None
)
tool_calls = None
if "tool_calls" in message and message["tool_calls"]:
tool_calls = []
for tool_call in message["tool_calls"]:
try:
function_data = tool_call.get("function", {})
arguments_str = function_data.get("arguments", "{}")
arguments = json.loads(arguments_str)
except (json.JSONDecodeError, KeyError, TypeError) as e:
logger.warning(
f"Failed to parse tool call arguments: {e}, "
f"tool: {function_data.get('name', 'unknown')}"
)
arguments = {}
tool_calls.append(
{
"id": tool_call.get("id", ""),
"name": function_data.get("name", ""),
"arguments": arguments,
}
)
finish_reason = "error"
if "done" in response and response["done"]:
if tool_calls:
finish_reason = "tool_calls"
elif content:
finish_reason = "stop"
elif tool_calls:
finish_reason = "tool_calls"
elif content:
finish_reason = "stop"
return {
"content": content,
"tool_calls": tool_calls,
"finish_reason": finish_reason,
}
except (TimeoutException, ResponseError, ConnectionError) as e:
logger.warning("Ollama returned an error: %s", str(e))
return {
"content": None,
"tool_calls": None,
"finish_reason": "error",
}
except Exception as e:
logger.warning("Unexpected error in Ollama chat_with_tools: %s", str(e))
return {
"content": None,
"tool_calls": None,
"finish_reason": "error",
}