Adapt to new Gemini format

This commit is contained in:
Nicolas Mowen 2026-02-25 09:19:56 -07:00
parent 84760c42cb
commit 5f02e33e55

View File

@ -1,6 +1,5 @@
"""Gemini Provider for Frigate AI.""" """Gemini Provider for Frigate AI."""
import json
import logging import logging
from typing import Any, Optional from typing import Any, Optional
@ -84,147 +83,169 @@ class GeminiClient(GenAIClient):
tools: Optional[list[dict[str, Any]]] = None, tools: Optional[list[dict[str, Any]]] = None,
tool_choice: Optional[str] = "auto", tool_choice: Optional[str] = "auto",
) -> dict[str, Any]: ) -> dict[str, Any]:
"""
Send chat messages to Gemini with optional tool definitions.
Implements function calling/tool usage for Gemini models.
"""
try: try:
# Convert messages to Gemini format
gemini_messages = []
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
# Map roles to Gemini format
if role == "system":
# Gemini doesn't have system role, prepend to first user message
if gemini_messages and gemini_messages[0].role == "user":
gemini_messages[0].parts[
0
].text = f"{content}\n\n{gemini_messages[0].parts[0].text}"
else:
gemini_messages.append(
types.Content(
role="user", parts=[types.Part.from_text(text=content)]
)
)
elif role == "assistant":
gemini_messages.append(
types.Content(
role="model", parts=[types.Part.from_text(text=content)]
)
)
elif role == "tool":
# Handle tool response
function_response = {
"name": msg.get("name", ""),
"response": content,
}
gemini_messages.append(
types.Content(
role="function",
parts=[
types.Part.from_function_response(function_response)
],
)
)
else: # user
gemini_messages.append(
types.Content(
role="user", parts=[types.Part.from_text(text=content)]
)
)
# Convert tools to Gemini format
gemini_tools = None
if tools: if tools:
function_declarations = [] gemini_tools = []
for tool in tools: for tool in tools:
if tool.get("type") == "function": if tool.get("type") == "function":
func_def = tool.get("function", {}) func = tool.get("function", {})
function_declarations.append( gemini_tools.append(
genai.protos.FunctionDeclaration( types.Tool(
name=func_def.get("name"), function_declarations=[
description=func_def.get("description"), types.FunctionDeclaration(
parameters=genai.protos.Schema( name=func.get("name", ""),
type=genai.protos.Type.OBJECT, description=func.get("description", ""),
properties={ parameters=func.get("parameters", {}),
prop_name: genai.protos.Schema( )
type=_convert_json_type_to_gemini( ]
prop.get("type")
),
description=prop.get("description"),
)
for prop_name, prop in func_def.get(
"parameters", {}
)
.get("properties", {})
.items()
},
required=func_def.get("parameters", {}).get(
"required", []
),
),
) )
) )
tool_config = genai.protos.Tool( # Configure tool choice
function_declarations=function_declarations tool_config = None
) if tool_choice:
if tool_choice == "none": if tool_choice == "none":
function_calling_config = genai.protos.FunctionCallingConfig( tool_config = types.ToolConfig(
mode=genai.protos.FunctionCallingConfig.Mode.NONE function_calling_config=types.FunctionCallingConfig(mode="NONE")
)
elif tool_choice == "auto":
tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(mode="AUTO")
) )
elif tool_choice == "required": elif tool_choice == "required":
function_calling_config = genai.protos.FunctionCallingConfig( tool_config = types.ToolConfig(
mode=genai.protos.FunctionCallingConfig.Mode.ANY function_calling_config=types.FunctionCallingConfig(mode="ANY")
)
else:
function_calling_config = genai.protos.FunctionCallingConfig(
mode=genai.protos.FunctionCallingConfig.Mode.AUTO
)
else:
tool_config = None
function_calling_config = None
contents = []
for msg in messages:
role = msg.get("role")
content = msg.get("content", "")
if role == "system":
continue
elif role == "user":
contents.append({"role": "user", "parts": [content]})
elif role == "assistant":
parts = [content] if content else []
if "tool_calls" in msg:
for tc in msg["tool_calls"]:
parts.append(
genai.protos.FunctionCall(
name=tc["function"]["name"],
args=json.loads(tc["function"]["arguments"]),
)
)
contents.append({"role": "model", "parts": parts})
elif role == "tool":
tool_name = msg.get("name", "")
tool_result = (
json.loads(content) if isinstance(content, str) else content
)
contents.append(
{
"role": "function",
"parts": [
genai.protos.FunctionResponse(
name=tool_name,
response=tool_result,
)
],
}
) )
generation_config = genai.types.GenerationConfig( # Build request config
candidate_count=1, config_params = {"candidate_count": 1}
)
if function_calling_config:
generation_config.function_calling_config = function_calling_config
response = self.provider.generate_content( if gemini_tools:
contents, config_params["tools"] = gemini_tools
tools=[tool_config] if tool_config else None,
generation_config=generation_config, if tool_config:
request_options=genai.types.RequestOptions(timeout=self.timeout), config_params["tool_config"] = tool_config
# Merge runtime_options
if isinstance(self.genai_config.runtime_options, dict):
config_params.update(self.genai_config.runtime_options)
response = self.provider.models.generate_content(
model=self.genai_config.model,
contents=gemini_messages,
config=types.GenerateContentConfig(**config_params),
) )
# Check if response is valid
if not response or not response.candidates:
return {
"content": None,
"tool_calls": None,
"finish_reason": "error",
}
candidate = response.candidates[0]
content = None content = None
tool_calls = None tool_calls = None
if response.candidates and response.candidates[0].content: # Extract content and tool calls from response
parts = response.candidates[0].content.parts if candidate.content and candidate.content.parts:
text_parts = [p.text for p in parts if hasattr(p, "text") and p.text] for part in candidate.content.parts:
if text_parts: if part.text:
content = " ".join(text_parts).strip() content = part.text.strip()
elif part.function_call:
# Handle function call
if tool_calls is None:
tool_calls = []
try:
arguments = (
dict(part.function_call.args)
if part.function_call.args
else {}
)
except Exception:
arguments = {}
function_calls = [
p.function_call
for p in parts
if hasattr(p, "function_call") and p.function_call
]
if function_calls:
tool_calls = []
for fc in function_calls:
tool_calls.append( tool_calls.append(
{ {
"id": f"call_{hash(fc.name)}", "id": part.function_call.name or "",
"name": fc.name, "name": part.function_call.name or "",
"arguments": dict(fc.args) "arguments": arguments,
if hasattr(fc, "args")
else {},
} }
) )
# Determine finish reason
finish_reason = "error" finish_reason = "error"
if response.candidates: if hasattr(candidate, "finish_reason") and candidate.finish_reason:
finish_reason_map = { from google.genai.types import FinishReason
genai.types.FinishReason.STOP: "stop",
genai.types.FinishReason.MAX_TOKENS: "length", if candidate.finish_reason == FinishReason.STOP:
genai.types.FinishReason.SAFETY: "stop", finish_reason = "stop"
genai.types.FinishReason.RECITATION: "stop", elif candidate.finish_reason == FinishReason.MAX_TOKENS:
genai.types.FinishReason.OTHER: "error", finish_reason = "length"
} elif candidate.finish_reason in [
finish_reason = finish_reason_map.get( FinishReason.SAFETY,
response.candidates[0].finish_reason, "error" FinishReason.RECITATION,
) ]:
finish_reason = "error"
elif tool_calls:
finish_reason = "tool_calls"
elif content:
finish_reason = "stop"
elif tool_calls: elif tool_calls:
finish_reason = "tool_calls" finish_reason = "tool_calls"
elif content: elif content:
@ -236,29 +257,19 @@ class GeminiClient(GenAIClient):
"finish_reason": finish_reason, "finish_reason": finish_reason,
} }
except GoogleAPICallError as e: except errors.APIError as e:
logger.warning("Gemini returned an error: %s", str(e)) logger.warning("Gemini API error during chat_with_tools: %s", str(e))
return { return {
"content": None, "content": None,
"tool_calls": None, "tool_calls": None,
"finish_reason": "error", "finish_reason": "error",
} }
except Exception as e: except Exception as e:
logger.warning("Unexpected error in Gemini chat_with_tools: %s", str(e)) logger.warning(
"Gemini returned an error during chat_with_tools: %s", str(e)
)
return { return {
"content": None, "content": None,
"tool_calls": None, "tool_calls": None,
"finish_reason": "error", "finish_reason": "error",
} }
def _convert_json_type_to_gemini(json_type: str) -> genai.protos.Type:
type_map = {
"string": genai.protos.Type.STRING,
"integer": genai.protos.Type.INTEGER,
"number": genai.protos.Type.NUMBER,
"boolean": genai.protos.Type.BOOLEAN,
"array": genai.protos.Type.ARRAY,
"object": genai.protos.Type.OBJECT,
}
return type_map.get(json_type, genai.protos.Type.STRING)