Mirror of https://github.com/blakeblackshear/frigate.git, synced 2026-03-10 10:33:11 +03:00

Implement streaming for other providers

This commit is contained in:
parent 7abf0ab1eb
commit 264089974e
@@ -167,3 +167,123 @@ class OpenAIClient(GenAIClient):
                "tool_calls": None,
                "finish_reason": "error",
            }

    async def chat_with_tools_stream(
        self,
        messages: list[dict[str, Any]],
        tools: Optional[list[dict[str, Any]]] = None,
        tool_choice: Optional[str] = "auto",
    ):
        """
        Stream chat with tools; yields content deltas then final message.

        Implements streaming function calling/tool usage for Azure OpenAI models.
        """
        try:
            openai_tool_choice = None
            if tool_choice:
                if tool_choice == "none":
                    openai_tool_choice = "none"
                elif tool_choice == "auto":
                    openai_tool_choice = "auto"
                elif tool_choice == "required":
                    openai_tool_choice = "required"

            request_params = {
                "model": self.genai_config.model,
                "messages": messages,
                "timeout": self.timeout,
                "stream": True,
            }

            if tools:
                request_params["tools"] = tools
                if openai_tool_choice is not None:
                    request_params["tool_choice"] = openai_tool_choice

            # Use streaming API
            content_parts: list[str] = []
            tool_calls_by_index: dict[int, dict[str, Any]] = {}
            finish_reason = "stop"

            stream = self.provider.chat.completions.create(**request_params)

            for chunk in stream:
                if not chunk or not chunk.choices:
                    continue

                choice = chunk.choices[0]
                delta = choice.delta

                # Check for finish reason
                if choice.finish_reason:
                    finish_reason = choice.finish_reason

                # Extract content deltas
                if delta.content:
                    content_parts.append(delta.content)
                    yield ("content_delta", delta.content)

                # Extract tool calls
                if delta.tool_calls:
                    for tc in delta.tool_calls:
                        idx = tc.index
                        fn = tc.function

                        if idx not in tool_calls_by_index:
                            tool_calls_by_index[idx] = {
                                "id": tc.id or "",
                                "name": fn.name if fn and fn.name else "",
                                "arguments": "",
                            }

                        t = tool_calls_by_index[idx]
                        if tc.id:
                            t["id"] = tc.id
                        if fn and fn.name:
                            t["name"] = fn.name
                        if fn and fn.arguments:
                            t["arguments"] += fn.arguments

            # Build final message
            full_content = "".join(content_parts).strip() or None

            # Convert tool calls to list format
            tool_calls_list = None
            if tool_calls_by_index:
                tool_calls_list = []
                for tc in tool_calls_by_index.values():
                    try:
                        # Parse accumulated arguments as JSON
                        parsed_args = json.loads(tc["arguments"])
                    except Exception:
                        # Keep the raw accumulated string if parsing fails
                        parsed_args = tc["arguments"]

                    tool_calls_list.append(
                        {
                            "id": tc["id"],
                            "name": tc["name"],
                            "arguments": parsed_args,
                        }
                    )
                finish_reason = "tool_calls"

            yield (
                "message",
                {
                    "content": full_content,
                    "tool_calls": tool_calls_list,
                    "finish_reason": finish_reason,
                },
            )

        except Exception as e:
            logger.warning("Azure OpenAI streaming returned an error: %s", str(e))
            yield (
                "message",
                {
                    "content": None,
                    "tool_calls": None,
                    "finish_reason": "error",
                },
            )
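A minimal consumer sketch for the async generator above, with client standing in for a constructed provider client (hypothetical); it shows the two event kinds the method yields: incremental "content_delta" strings, then a single final "message" dict.

import asyncio  # used by the commented run() call below

async def consume(client, messages):
    # chat_with_tools_stream yields ("content_delta", str) then ("message", dict)
    async for event, payload in client.chat_with_tools_stream(messages):
        if event == "content_delta":
            print(payload, end="", flush=True)
        elif event == "message":
            print()
            for call in payload["tool_calls"] or []:
                print(f"tool call: {call['name']} args={call['arguments']}")
            print(f"finish_reason: {payload['finish_reason']}")

# asyncio.run(consume(client, [{"role": "user", "content": "Hello"}]))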
@@ -1,5 +1,6 @@
"""Gemini Provider for Frigate AI."""

import json
import logging
from typing import Any, Optional
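The new json import supports the tool-argument round-trip in the streaming handler below: Gemini returns function-call arguments as a dict, which the handler serializes with json.dumps while accumulating and re-parses with json.loads when building the final message. A tiny illustration with hypothetical argument values:

import json

# Hypothetical function-call arguments, as Gemini would return them
args = {"camera": "front_door", "label": "person"}

accumulated = json.dumps(args)    # stored as a string while streaming
parsed = json.loads(accumulated)  # re-parsed for the final message
assert parsed == args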
@@ -273,3 +274,239 @@ class GeminiClient(GenAIClient):
                "tool_calls": None,
                "finish_reason": "error",
            }

    async def chat_with_tools_stream(
        self,
        messages: list[dict[str, Any]],
        tools: Optional[list[dict[str, Any]]] = None,
        tool_choice: Optional[str] = "auto",
    ):
        """
        Stream chat with tools; yields content deltas then final message.

        Implements streaming function calling/tool usage for Gemini models.
        """
        try:
            # Convert messages to Gemini format
            gemini_messages = []
            for msg in messages:
                role = msg.get("role", "user")
                content = msg.get("content", "")

                # Map roles to Gemini format
                if role == "system":
                    # Gemini doesn't have a system role, prepend to first user message
                    if gemini_messages and gemini_messages[0].role == "user":
                        gemini_messages[0].parts[
                            0
                        ].text = f"{content}\n\n{gemini_messages[0].parts[0].text}"
                    else:
                        gemini_messages.append(
                            types.Content(
                                role="user", parts=[types.Part.from_text(text=content)]
                            )
                        )
                elif role == "assistant":
                    gemini_messages.append(
                        types.Content(
                            role="model", parts=[types.Part.from_text(text=content)]
                        )
                    )
                elif role == "tool":
                    # Handle tool response; from_function_response takes
                    # keyword-only name/response, and response must be a dict
                    # ("output" here is an arbitrary wrapper key)
                    gemini_messages.append(
                        types.Content(
                            role="function",
                            parts=[
                                types.Part.from_function_response(
                                    name=msg.get("name", ""),
                                    response={"output": content},
                                )
                            ],
                        )
                    )
                else:  # user
                    gemini_messages.append(
                        types.Content(
                            role="user", parts=[types.Part.from_text(text=content)]
                        )
                    )

            # Convert tools to Gemini format
            gemini_tools = None
            if tools:
                gemini_tools = []
                for tool in tools:
                    if tool.get("type") == "function":
                        func = tool.get("function", {})
                        gemini_tools.append(
                            types.Tool(
                                function_declarations=[
                                    types.FunctionDeclaration(
                                        name=func.get("name", ""),
                                        description=func.get("description", ""),
                                        parameters=func.get("parameters", {}),
                                    )
                                ]
                            )
                        )

            # Configure tool choice
            tool_config = None
            if tool_choice:
                if tool_choice == "none":
                    tool_config = types.ToolConfig(
                        function_calling_config=types.FunctionCallingConfig(mode="NONE")
                    )
                elif tool_choice == "auto":
                    tool_config = types.ToolConfig(
                        function_calling_config=types.FunctionCallingConfig(mode="AUTO")
                    )
                elif tool_choice == "required":
                    tool_config = types.ToolConfig(
                        function_calling_config=types.FunctionCallingConfig(mode="ANY")
                    )

            # Build request config
            config_params = {"candidate_count": 1}

            if gemini_tools:
                config_params["tools"] = gemini_tools

            if tool_config:
                config_params["tool_config"] = tool_config

            # Merge runtime_options
            if isinstance(self.genai_config.runtime_options, dict):
                config_params.update(self.genai_config.runtime_options)

            # Use streaming API
            content_parts: list[str] = []
            tool_calls_by_index: dict[int, dict[str, Any]] = {}
            finish_reason = "stop"

            response = self.provider.models.generate_content_stream(
                model=self.genai_config.model,
                contents=gemini_messages,
                config=types.GenerateContentConfig(**config_params),
            )

            # generate_content_stream returns a synchronous iterator
            for chunk in response:
                if not chunk or not chunk.candidates:
                    continue

                candidate = chunk.candidates[0]

                # Check for finish reason
                if hasattr(candidate, "finish_reason") and candidate.finish_reason:
                    from google.genai.types import FinishReason

                    if candidate.finish_reason == FinishReason.STOP:
                        finish_reason = "stop"
                    elif candidate.finish_reason == FinishReason.MAX_TOKENS:
                        finish_reason = "length"
                    elif candidate.finish_reason in [
                        FinishReason.SAFETY,
                        FinishReason.RECITATION,
                    ]:
                        finish_reason = "error"

                # Extract content and tool calls from chunk
                if candidate.content and candidate.content.parts:
                    for part in candidate.content.parts:
                        if part.text:
                            content_parts.append(part.text)
                            yield ("content_delta", part.text)
                        elif part.function_call:
                            # Handle function call
                            try:
                                arguments = (
                                    dict(part.function_call.args)
                                    if part.function_call.args
                                    else {}
                                )
                            except Exception:
                                arguments = {}

                            # Store tool call
                            tool_call_id = part.function_call.name or ""
                            tool_call_name = part.function_call.name or ""

                            # Check if we already have this tool call
                            found_index = None
                            for idx, tc in tool_calls_by_index.items():
                                if tc["name"] == tool_call_name:
                                    found_index = idx
                                    break

                            if found_index is None:
                                found_index = len(tool_calls_by_index)
                                tool_calls_by_index[found_index] = {
                                    "id": tool_call_id,
                                    "name": tool_call_name,
                                    "arguments": "",
                                }

                            # Accumulate arguments
                            if arguments:
                                tool_calls_by_index[found_index]["arguments"] += (
                                    json.dumps(arguments)
                                    if isinstance(arguments, dict)
                                    else str(arguments)
                                )

            # Build final message
            full_content = "".join(content_parts).strip() or None

            # Convert tool calls to list format
            tool_calls_list = None
            if tool_calls_by_index:
                tool_calls_list = []
                for tc in tool_calls_by_index.values():
                    try:
                        # Try to parse accumulated arguments as JSON
                        parsed_args = json.loads(tc["arguments"])
                    except Exception:
                        # Keep the raw accumulated string if parsing fails
                        parsed_args = tc["arguments"]

                    tool_calls_list.append(
                        {
                            "id": tc["id"],
                            "name": tc["name"],
                            "arguments": parsed_args,
                        }
                    )
                finish_reason = "tool_calls"

            yield (
                "message",
                {
                    "content": full_content,
                    "tool_calls": tool_calls_list,
                    "finish_reason": finish_reason,
                },
            )

        except errors.APIError as e:
            logger.warning("Gemini API error during streaming: %s", str(e))
            yield (
                "message",
                {
                    "content": None,
                    "tool_calls": None,
                    "finish_reason": "error",
                },
            )
        except Exception as e:
            logger.warning(
                "Gemini returned an error during chat_with_tools_stream: %s", str(e)
            )
            yield (
                "message",
                {
                    "content": None,
                    "tool_calls": None,
                    "finish_reason": "error",
                },
            )
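A standalone sketch of the finish-reason mapping used in the Gemini loop above; it assumes the google-genai package is installed and mirrors the handler's branches.

from google.genai.types import FinishReason

def map_finish_reason(reason) -> str:
    # Same branches as the streaming loop: MAX_TOKENS -> "length",
    # SAFETY/RECITATION -> "error", everything else stays "stop"
    if reason == FinishReason.MAX_TOKENS:
        return "length"
    if reason in (FinishReason.SAFETY, FinishReason.RECITATION):
        return "error"
    return "stop"

assert map_finish_reason(FinishReason.STOP) == "stop"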
@@ -227,3 +227,142 @@ class OpenAIClient(GenAIClient):
                "tool_calls": None,
                "finish_reason": "error",
            }

    async def chat_with_tools_stream(
        self,
        messages: list[dict[str, Any]],
        tools: Optional[list[dict[str, Any]]] = None,
        tool_choice: Optional[str] = "auto",
    ):
        """
        Stream chat with tools; yields content deltas then final message.

        Implements streaming function calling/tool usage for OpenAI models.
        """
        try:
            openai_tool_choice = None
            if tool_choice:
                if tool_choice == "none":
                    openai_tool_choice = "none"
                elif tool_choice == "auto":
                    openai_tool_choice = "auto"
                elif tool_choice == "required":
                    openai_tool_choice = "required"

            request_params = {
                "model": self.genai_config.model,
                "messages": messages,
                "timeout": self.timeout,
                "stream": True,
            }

            if tools:
                request_params["tools"] = tools
                if openai_tool_choice is not None:
                    request_params["tool_choice"] = openai_tool_choice

            if isinstance(self.genai_config.provider_options, dict):
                excluded_options = {"context_size"}
                provider_opts = {
                    k: v
                    for k, v in self.genai_config.provider_options.items()
                    if k not in excluded_options
                }
                request_params.update(provider_opts)

            # Use streaming API
            content_parts: list[str] = []
            tool_calls_by_index: dict[int, dict[str, Any]] = {}
            finish_reason = "stop"

            stream = self.provider.chat.completions.create(**request_params)

            for chunk in stream:
                if not chunk or not chunk.choices:
                    continue

                choice = chunk.choices[0]
                delta = choice.delta

                # Check for finish reason
                if choice.finish_reason:
                    finish_reason = choice.finish_reason

                # Extract content deltas
                if delta.content:
                    content_parts.append(delta.content)
                    yield ("content_delta", delta.content)

                # Extract tool calls
                if delta.tool_calls:
                    for tc in delta.tool_calls:
                        idx = tc.index
                        fn = tc.function

                        if idx not in tool_calls_by_index:
                            tool_calls_by_index[idx] = {
                                "id": tc.id or "",
                                "name": fn.name if fn and fn.name else "",
                                "arguments": "",
                            }

                        t = tool_calls_by_index[idx]
                        if tc.id:
                            t["id"] = tc.id
                        if fn and fn.name:
                            t["name"] = fn.name
                        if fn and fn.arguments:
                            t["arguments"] += fn.arguments

            # Build final message
            full_content = "".join(content_parts).strip() or None

            # Convert tool calls to list format
            tool_calls_list = None
            if tool_calls_by_index:
                tool_calls_list = []
                for tc in tool_calls_by_index.values():
                    try:
                        # Parse accumulated arguments as JSON
                        parsed_args = json.loads(tc["arguments"])
                    except Exception:
                        # Keep the raw accumulated string if parsing fails
                        parsed_args = tc["arguments"]

                    tool_calls_list.append(
                        {
                            "id": tc["id"],
                            "name": tc["name"],
                            "arguments": parsed_args,
                        }
                    )
                finish_reason = "tool_calls"

            yield (
                "message",
                {
                    "content": full_content,
                    "tool_calls": tool_calls_list,
                    "finish_reason": finish_reason,
                },
            )

        except TimeoutException as e:
            logger.warning("OpenAI streaming request timed out: %s", str(e))
            yield (
                "message",
                {
                    "content": None,
                    "tool_calls": None,
                    "finish_reason": "error",
                },
            )
        except Exception as e:
            logger.warning("OpenAI streaming returned an error: %s", str(e))
            yield (
                "message",
                {
                    "content": None,
                    "tool_calls": None,
                    "finish_reason": "error",
                },
            )
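A standalone sketch of the delta-merging logic above, using plain dicts as stand-ins for the SDK's streamed tool-call deltas (hypothetical data); fragments accumulate by index until the arguments form valid JSON.

import json
from typing import Any

# Two hypothetical deltas for the same tool call, arriving in order
deltas = [
    {"index": 0, "id": "call_1", "name": "get_events", "arguments": '{"cam'},
    {"index": 0, "id": None, "name": None, "arguments": 'era": "front"}'},
]

tool_calls_by_index: dict[int, dict[str, Any]] = {}
for tc in deltas:
    t = tool_calls_by_index.setdefault(
        tc["index"], {"id": "", "name": "", "arguments": ""}
    )
    if tc["id"]:
        t["id"] = tc["id"]
    if tc["name"]:
        t["name"] = tc["name"]
    if tc["arguments"]:
        t["arguments"] += tc["arguments"]

assert json.loads(tool_calls_by_index[0]["arguments"]) == {"camera": "front"}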