mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-05-01 19:17:41 +03:00
Fix gemini tool calling
This commit is contained in:
parent
33abaaa9f8
commit
b47a47c44a
@ -136,11 +136,29 @@ class GeminiClient(GenAIClient):
|
||||
)
|
||||
)
|
||||
elif role == "assistant":
|
||||
gemini_messages.append(
|
||||
types.Content(
|
||||
role="model", parts=[types.Part.from_text(text=content)]
|
||||
)
|
||||
)
|
||||
parts: list[types.Part] = []
|
||||
if content:
|
||||
parts.append(types.Part.from_text(text=content))
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
func = tc.get("function") or {}
|
||||
tc_name = func.get("name") or ""
|
||||
tc_args: Any = func.get("arguments")
|
||||
if isinstance(tc_args, str):
|
||||
try:
|
||||
tc_args = json.loads(tc_args)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
tc_args = {}
|
||||
if not isinstance(tc_args, dict):
|
||||
tc_args = {}
|
||||
if tc_name:
|
||||
parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=tc_name, args=tc_args
|
||||
)
|
||||
)
|
||||
if not parts:
|
||||
parts.append(types.Part.from_text(text=" "))
|
||||
gemini_messages.append(types.Content(role="model", parts=parts))
|
||||
elif role == "tool":
|
||||
# Handle tool response
|
||||
response_payload = (
|
||||
@ -151,7 +169,9 @@ class GeminiClient(GenAIClient):
|
||||
role="function",
|
||||
parts=[
|
||||
types.Part.from_function_response(
|
||||
name=msg.get("name", ""),
|
||||
name=msg.get("name")
|
||||
or msg.get("tool_call_id")
|
||||
or "",
|
||||
response=response_payload,
|
||||
)
|
||||
],
|
||||
@ -345,11 +365,29 @@ class GeminiClient(GenAIClient):
|
||||
)
|
||||
)
|
||||
elif role == "assistant":
|
||||
gemini_messages.append(
|
||||
types.Content(
|
||||
role="model", parts=[types.Part.from_text(text=content)]
|
||||
)
|
||||
)
|
||||
parts: list[types.Part] = []
|
||||
if content:
|
||||
parts.append(types.Part.from_text(text=content))
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
func = tc.get("function") or {}
|
||||
tc_name = func.get("name") or ""
|
||||
tc_args: Any = func.get("arguments")
|
||||
if isinstance(tc_args, str):
|
||||
try:
|
||||
tc_args = json.loads(tc_args)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
tc_args = {}
|
||||
if not isinstance(tc_args, dict):
|
||||
tc_args = {}
|
||||
if tc_name:
|
||||
parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=tc_name, args=tc_args
|
||||
)
|
||||
)
|
||||
if not parts:
|
||||
parts.append(types.Part.from_text(text=" "))
|
||||
gemini_messages.append(types.Content(role="model", parts=parts))
|
||||
elif role == "tool":
|
||||
# Handle tool response
|
||||
response_payload = (
|
||||
@ -360,7 +398,9 @@ class GeminiClient(GenAIClient):
|
||||
role="function",
|
||||
parts=[
|
||||
types.Part.from_function_response(
|
||||
name=msg.get("name", ""),
|
||||
name=msg.get("name")
|
||||
or msg.get("tool_call_id")
|
||||
or "",
|
||||
response=response_payload,
|
||||
)
|
||||
],
|
||||
|
||||
Loading…
Reference in New Issue
Block a user