mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-03-23 16:48:23 +03:00
Implement streaming for other providers
This commit is contained in:
parent
7abf0ab1eb
commit
264089974e
@ -167,3 +167,123 @@ class OpenAIClient(GenAIClient):
|
|||||||
"tool_calls": None,
|
"tool_calls": None,
|
||||||
"finish_reason": "error",
|
"finish_reason": "error",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def chat_with_tools_stream(
    self,
    messages: list[dict[str, Any]],
    tools: Optional[list[dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
):
    """
    Stream chat with tools; yields content deltas then final message.

    Implements streaming function calling/tool usage for Azure OpenAI models.

    Args:
        messages: Chat messages, passed to the provider unchanged.
        tools: Optional tool definitions, passed to the provider unchanged.
        tool_choice: "none", "auto" or "required"; any other value is not
            forwarded. Only sent when ``tools`` is non-empty.

    Yields:
        ``("content_delta", text)`` for every text fragment as it arrives,
        followed by exactly one ``("message", result)`` where ``result`` has
        the keys ``content``, ``tool_calls`` and ``finish_reason``. Any
        exception is logged and reported as a final message with
        ``finish_reason == "error"``.
    """
    # Function-scope import so this method does not depend on the module's
    # import block.
    import asyncio

    try:
        openai_tool_choice = (
            tool_choice if tool_choice in ("none", "auto", "required") else None
        )

        request_params = {
            "model": self.genai_config.model,
            "messages": messages,
            "timeout": self.timeout,
            "stream": True,
        }

        if tools:
            request_params["tools"] = tools
            if openai_tool_choice is not None:
                request_params["tool_choice"] = openai_tool_choice

        content_parts: list[str] = []
        tool_calls_by_index: dict[int, dict[str, Any]] = {}
        finish_reason = "stop"

        # The client used here is iterated synchronously, so both opening the
        # stream and waiting for each chunk block on network I/O. Run those
        # waits in a worker thread to keep the event loop responsive.
        stream = await asyncio.to_thread(
            self.provider.chat.completions.create, **request_params
        )
        chunk_iter = iter(stream)
        end_of_stream = object()

        while True:
            # A sentinel default avoids raising StopIteration across the
            # thread boundary.
            chunk = await asyncio.to_thread(next, chunk_iter, end_of_stream)
            if chunk is end_of_stream:
                break
            if not chunk or not chunk.choices:
                continue

            choice = chunk.choices[0]
            delta = choice.delta

            if choice.finish_reason:
                finish_reason = choice.finish_reason

            if delta.content:
                content_parts.append(delta.content)
                yield ("content_delta", delta.content)

            # Tool calls arrive as fragments keyed by index; id and name are
            # taken whenever present, argument text is concatenated.
            for tc in delta.tool_calls or ():
                fn = tc.function
                entry = tool_calls_by_index.setdefault(
                    tc.index, {"id": "", "name": "", "arguments": ""}
                )
                if tc.id:
                    entry["id"] = tc.id
                if fn and fn.name:
                    entry["name"] = fn.name
                if fn and fn.arguments:
                    entry["arguments"] += fn.arguments

        full_content = "".join(content_parts).strip() or None

        tool_calls_list = None
        if tool_calls_by_index:
            tool_calls_list = []
            for call in tool_calls_by_index.values():
                raw_args = call["arguments"]
                if not raw_args.strip():
                    # No argument text was streamed: report an empty mapping
                    # instead of an empty string.
                    parsed_args = {}
                else:
                    try:
                        parsed_args = json.loads(raw_args)
                    except (json.JSONDecodeError, RecursionError):
                        # Keep the raw text so the caller can still see what
                        # the model produced.
                        parsed_args = raw_args

                tool_calls_list.append(
                    {
                        "id": call["id"],
                        "name": call["name"],
                        "arguments": parsed_args,
                    }
                )
            finish_reason = "tool_calls"

        yield (
            "message",
            {
                "content": full_content,
                "tool_calls": tool_calls_list,
                "finish_reason": finish_reason,
            },
        )

    except Exception as e:
        logger.warning("Azure OpenAI streaming returned an error: %s", str(e))
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
"""Gemini Provider for Frigate AI."""
|
"""Gemini Provider for Frigate AI."""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
@ -273,3 +274,239 @@ class GeminiClient(GenAIClient):
|
|||||||
"tool_calls": None,
|
"tool_calls": None,
|
||||||
"finish_reason": "error",
|
"finish_reason": "error",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def chat_with_tools_stream(
    self,
    messages: list[dict[str, Any]],
    tools: Optional[list[dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
):
    """
    Stream chat with tools; yields content deltas then final message.

    Implements streaming function calling/tool usage for Gemini models.

    Args:
        messages: Chat messages with "role" and "content" keys (tool
            messages also carry "name"); converted to Gemini contents.
        tools: Optional tool definitions; entries whose "type" is
            "function" are converted to Gemini function declarations.
        tool_choice: "none", "auto" or "required", mapped to the Gemini
            function-calling modes NONE, AUTO and ANY; other values leave
            the tool config unset.

    Yields:
        ``("content_delta", text)`` for every text part as it arrives,
        followed by exactly one ``("message", result)`` where ``result`` has
        the keys ``content``, ``tool_calls`` and ``finish_reason``. Any
        exception is logged and reported as a final message with
        ``finish_reason == "error"``.
    """
    try:
        # Convert messages to Gemini format
        gemini_messages = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                # Gemini doesn't have system role, prepend to first user message
                if gemini_messages and gemini_messages[0].role == "user":
                    first_part = gemini_messages[0].parts[0]
                    first_part.text = f"{content}\n\n{first_part.text}"
                else:
                    gemini_messages.append(
                        types.Content(
                            role="user", parts=[types.Part.from_text(text=content)]
                        )
                    )
            elif role == "assistant":
                gemini_messages.append(
                    types.Content(
                        role="model", parts=[types.Part.from_text(text=content)]
                    )
                )
            elif role == "tool":
                # Part.from_function_response takes keyword-only ``name`` and
                # ``response`` and expects ``response`` to be a dict, so wrap
                # plain (string) tool output.
                response_payload = (
                    content if isinstance(content, dict) else {"result": content}
                )
                # NOTE(review): role "function" is kept from the original;
                # confirm the API still accepts it for function responses.
                gemini_messages.append(
                    types.Content(
                        role="function",
                        parts=[
                            types.Part.from_function_response(
                                name=msg.get("name", ""),
                                response=response_payload,
                            )
                        ],
                    )
                )
            else:  # user
                gemini_messages.append(
                    types.Content(
                        role="user", parts=[types.Part.from_text(text=content)]
                    )
                )

        # Convert tools to Gemini format
        gemini_tools = None
        if tools:
            gemini_tools = []
            for tool in tools:
                if tool.get("type") == "function":
                    func = tool.get("function", {})
                    gemini_tools.append(
                        types.Tool(
                            function_declarations=[
                                types.FunctionDeclaration(
                                    name=func.get("name", ""),
                                    description=func.get("description", ""),
                                    parameters=func.get("parameters", {}),
                                )
                            ]
                        )
                    )

        # Configure tool choice
        tool_config = None
        calling_mode = {"none": "NONE", "auto": "AUTO", "required": "ANY"}.get(
            tool_choice
        )
        if calling_mode is not None:
            tool_config = types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(
                    mode=calling_mode
                )
            )

        # Build request config
        config_params = {"candidate_count": 1}

        if gemini_tools:
            config_params["tools"] = gemini_tools

        if tool_config:
            config_params["tool_config"] = tool_config

        # Merge runtime_options
        if isinstance(self.genai_config.runtime_options, dict):
            config_params.update(self.genai_config.runtime_options)

        content_parts: list[str] = []
        tool_calls: list[dict[str, Any]] = []
        finish_reason = "stop"

        # ``async for`` needs an async iterator: use the client's async
        # surface (``aio``), whose streaming call is awaited to obtain it.
        stream = await self.provider.aio.models.generate_content_stream(
            model=self.genai_config.model,
            contents=gemini_messages,
            config=types.GenerateContentConfig(**config_params),
        )

        async for chunk in stream:
            if not chunk or not chunk.candidates:
                continue

            candidate = chunk.candidates[0]

            # Map Gemini finish reasons onto the common vocabulary
            reason = getattr(candidate, "finish_reason", None)
            if reason:
                if reason == types.FinishReason.STOP:
                    finish_reason = "stop"
                elif reason == types.FinishReason.MAX_TOKENS:
                    finish_reason = "length"
                elif reason in (
                    types.FinishReason.SAFETY,
                    types.FinishReason.RECITATION,
                ):
                    finish_reason = "error"

            if not (candidate.content and candidate.content.parts):
                continue

            for part in candidate.content.parts:
                if part.text:
                    content_parts.append(part.text)
                    yield ("content_delta", part.text)
                elif part.function_call:
                    function_call = part.function_call
                    try:
                        arguments = (
                            dict(function_call.args) if function_call.args else {}
                        )
                    except (TypeError, ValueError):
                        arguments = {}

                    # Each function_call part is recorded as its own call.
                    # Merging by name would fold repeated calls to the same
                    # function into one entry with unparseable arguments.
                    call_name = function_call.name or ""
                    tool_calls.append(
                        {
                            # Fall back to the name when no id is provided.
                            "id": getattr(function_call, "id", None) or call_name,
                            "name": call_name,
                            "arguments": arguments,
                        }
                    )

        # Build final message
        full_content = "".join(content_parts).strip() or None

        tool_calls_list = None
        if tool_calls:
            tool_calls_list = tool_calls
            finish_reason = "tool_calls"

        yield (
            "message",
            {
                "content": full_content,
                "tool_calls": tool_calls_list,
                "finish_reason": finish_reason,
            },
        )

    except errors.APIError as e:
        logger.warning("Gemini API error during streaming: %s", str(e))
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )
    except Exception as e:
        logger.warning(
            "Gemini returned an error during chat_with_tools_stream: %s", str(e)
        )
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )
|
||||||
|
|||||||
@ -227,3 +227,142 @@ class OpenAIClient(GenAIClient):
|
|||||||
"tool_calls": None,
|
"tool_calls": None,
|
||||||
"finish_reason": "error",
|
"finish_reason": "error",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def chat_with_tools_stream(
    self,
    messages: list[dict[str, Any]],
    tools: Optional[list[dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
):
    """
    Stream chat with tools; yields content deltas then final message.

    Implements streaming function calling/tool usage for OpenAI models.

    Args:
        messages: Chat messages, passed to the provider unchanged.
        tools: Optional tool definitions, passed to the provider unchanged.
        tool_choice: "none", "auto" or "required"; any other value is not
            forwarded. Only sent when ``tools`` is non-empty.

    Yields:
        ``("content_delta", text)`` for every text fragment as it arrives,
        followed by exactly one ``("message", result)`` where ``result`` has
        the keys ``content``, ``tool_calls`` and ``finish_reason``. Timeouts
        and other exceptions are logged and reported as a final message with
        ``finish_reason == "error"``.
    """
    # Function-scope import so this method does not depend on the module's
    # import block.
    import asyncio

    try:
        openai_tool_choice = (
            tool_choice if tool_choice in ("none", "auto", "required") else None
        )

        request_params = {
            "model": self.genai_config.model,
            "messages": messages,
            "timeout": self.timeout,
            "stream": True,
        }

        if tools:
            request_params["tools"] = tools
            if openai_tool_choice is not None:
                request_params["tool_choice"] = openai_tool_choice

        # Forward user-configured provider options as request parameters,
        # except the ones listed as excluded.
        if isinstance(self.genai_config.provider_options, dict):
            excluded_options = {"context_size"}
            request_params.update(
                {
                    k: v
                    for k, v in self.genai_config.provider_options.items()
                    if k not in excluded_options
                }
            )

        content_parts: list[str] = []
        tool_calls_by_index: dict[int, dict[str, Any]] = {}
        finish_reason = "stop"

        # The client used here is iterated synchronously, so both opening the
        # stream and waiting for each chunk block on network I/O. Run those
        # waits in a worker thread to keep the event loop responsive.
        stream = await asyncio.to_thread(
            self.provider.chat.completions.create, **request_params
        )
        chunk_iter = iter(stream)
        end_of_stream = object()

        while True:
            # A sentinel default avoids raising StopIteration across the
            # thread boundary.
            chunk = await asyncio.to_thread(next, chunk_iter, end_of_stream)
            if chunk is end_of_stream:
                break
            if not chunk or not chunk.choices:
                continue

            choice = chunk.choices[0]
            delta = choice.delta

            if choice.finish_reason:
                finish_reason = choice.finish_reason

            if delta.content:
                content_parts.append(delta.content)
                yield ("content_delta", delta.content)

            # Tool calls arrive as fragments keyed by index; id and name are
            # taken whenever present, argument text is concatenated.
            for tc in delta.tool_calls or ():
                fn = tc.function
                entry = tool_calls_by_index.setdefault(
                    tc.index, {"id": "", "name": "", "arguments": ""}
                )
                if tc.id:
                    entry["id"] = tc.id
                if fn and fn.name:
                    entry["name"] = fn.name
                if fn and fn.arguments:
                    entry["arguments"] += fn.arguments

        full_content = "".join(content_parts).strip() or None

        tool_calls_list = None
        if tool_calls_by_index:
            tool_calls_list = []
            for call in tool_calls_by_index.values():
                raw_args = call["arguments"]
                if not raw_args.strip():
                    # No argument text was streamed: report an empty mapping
                    # instead of an empty string.
                    parsed_args = {}
                else:
                    try:
                        parsed_args = json.loads(raw_args)
                    except (json.JSONDecodeError, RecursionError):
                        # Keep the raw text so the caller can still see what
                        # the model produced.
                        parsed_args = raw_args

                tool_calls_list.append(
                    {
                        "id": call["id"],
                        "name": call["name"],
                        "arguments": parsed_args,
                    }
                )
            finish_reason = "tool_calls"

        yield (
            "message",
            {
                "content": full_content,
                "tool_calls": tool_calls_list,
                "finish_reason": finish_reason,
            },
        )

    except TimeoutException as e:
        logger.warning("OpenAI streaming request timed out: %s", str(e))
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )
    except Exception as e:
        logger.warning("OpenAI streaming returned an error: %s", str(e))
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user