Fix gemini tool calling

This commit is contained in:
Nicolas Mowen 2026-04-29 17:03:11 -06:00
parent 33abaaa9f8
commit b47a47c44a

View File

@ -136,11 +136,29 @@ class GeminiClient(GenAIClient):
)
)
elif role == "assistant":
gemini_messages.append(
types.Content(
role="model", parts=[types.Part.from_text(text=content)]
)
)
parts: list[types.Part] = []
if content:
parts.append(types.Part.from_text(text=content))
for tc in msg.get("tool_calls") or []:
func = tc.get("function") or {}
tc_name = func.get("name") or ""
tc_args: Any = func.get("arguments")
if isinstance(tc_args, str):
try:
tc_args = json.loads(tc_args)
except (json.JSONDecodeError, TypeError):
tc_args = {}
if not isinstance(tc_args, dict):
tc_args = {}
if tc_name:
parts.append(
types.Part.from_function_call(
name=tc_name, args=tc_args
)
)
if not parts:
parts.append(types.Part.from_text(text=" "))
gemini_messages.append(types.Content(role="model", parts=parts))
elif role == "tool":
# Handle tool response
response_payload = (
@ -151,7 +169,9 @@ class GeminiClient(GenAIClient):
role="function",
parts=[
types.Part.from_function_response(
name=msg.get("name", ""),
name=msg.get("name")
or msg.get("tool_call_id")
or "",
response=response_payload,
)
],
@ -345,11 +365,29 @@ class GeminiClient(GenAIClient):
)
)
elif role == "assistant":
gemini_messages.append(
types.Content(
role="model", parts=[types.Part.from_text(text=content)]
)
)
parts: list[types.Part] = []
if content:
parts.append(types.Part.from_text(text=content))
for tc in msg.get("tool_calls") or []:
func = tc.get("function") or {}
tc_name = func.get("name") or ""
tc_args: Any = func.get("arguments")
if isinstance(tc_args, str):
try:
tc_args = json.loads(tc_args)
except (json.JSONDecodeError, TypeError):
tc_args = {}
if not isinstance(tc_args, dict):
tc_args = {}
if tc_name:
parts.append(
types.Part.from_function_call(
name=tc_name, args=tc_args
)
)
if not parts:
parts.append(types.Part.from_text(text=" "))
gemini_messages.append(types.Content(role="model", parts=parts))
elif role == "tool":
# Handle tool response
response_payload = (
@ -360,7 +398,9 @@ class GeminiClient(GenAIClient):
role="function",
parts=[
types.Part.from_function_response(
name=msg.get("name", ""),
name=msg.get("name")
or msg.get("tool_call_id")
or "",
response=response_payload,
)
],