Implement streaming for other providers

This commit is contained in:
Nicolas Mowen 2026-02-28 07:38:00 -07:00
parent 7abf0ab1eb
commit 264089974e
3 changed files with 496 additions and 0 deletions

View File

@ -167,3 +167,123 @@ class OpenAIClient(GenAIClient):
"tool_calls": None,
"finish_reason": "error",
}
async def chat_with_tools_stream(
    self,
    messages: list[dict[str, Any]],
    tools: Optional[list[dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
):
    """Stream a chat completion with optional tool calling (Azure OpenAI).

    Yields ``("content_delta", str)`` tuples as text chunks arrive, then a
    single ``("message", dict)`` tuple with the aggregated content, parsed
    tool calls, and finish reason. On any provider error, a final
    ``("message", ...)`` with ``finish_reason == "error"`` is yielded
    instead of raising.

    Args:
        messages: Chat history in OpenAI message format.
        tools: Optional OpenAI-format tool definitions.
        tool_choice: One of "none", "auto", or "required"; any other
            value is ignored (no tool_choice is sent to the API).
    """
    try:
        # Only forward tool_choice values the OpenAI API understands.
        openai_tool_choice = (
            tool_choice
            if tool_choice in ("none", "auto", "required")
            else None
        )

        request_params = {
            "model": self.genai_config.model,
            "messages": messages,
            "timeout": self.timeout,
            "stream": True,
        }
        if tools:
            request_params["tools"] = tools
        if openai_tool_choice is not None:
            request_params["tool_choice"] = openai_tool_choice

        content_parts: list[str] = []
        # Tool calls stream in as fragments keyed by index; the JSON
        # argument text is accumulated until the stream completes.
        tool_calls_by_index: dict[int, dict[str, Any]] = {}
        finish_reason = "stop"

        # NOTE(review): this iterates a synchronous stream inside an
        # async generator, blocking the event loop between chunks —
        # confirm the provider client is intentionally synchronous.
        stream = self.provider.chat.completions.create(**request_params)

        for chunk in stream:
            if not chunk or not chunk.choices:
                continue

            choice = chunk.choices[0]
            delta = choice.delta

            # The finish reason arrives on the final chunk.
            if choice.finish_reason:
                finish_reason = choice.finish_reason

            # Text deltas are forwarded immediately and buffered for the
            # final message.
            if delta.content:
                content_parts.append(delta.content)
                yield ("content_delta", delta.content)

            # Merge tool-call fragments: id/name may arrive once, while
            # the arguments string accumulates across chunks.
            if delta.tool_calls:
                for tc in delta.tool_calls:
                    idx = tc.index
                    fn = tc.function
                    if idx not in tool_calls_by_index:
                        tool_calls_by_index[idx] = {
                            "id": tc.id or "",
                            "name": fn.name if fn and fn.name else "",
                            "arguments": "",
                        }
                    t = tool_calls_by_index[idx]
                    if tc.id:
                        t["id"] = tc.id
                    if fn and fn.name:
                        t["name"] = fn.name
                    if fn and fn.arguments:
                        t["arguments"] += fn.arguments

        full_content = "".join(content_parts).strip() or None

        tool_calls_list = None
        if tool_calls_by_index:
            tool_calls_list = []
            for tc in tool_calls_by_index.values():
                try:
                    # Arguments stream in as JSON text; fall back to the
                    # raw string if the model emitted invalid JSON.
                    parsed_args = json.loads(tc["arguments"])
                except json.JSONDecodeError:
                    parsed_args = tc["arguments"]
                tool_calls_list.append(
                    {
                        "id": tc["id"],
                        "name": tc["name"],
                        "arguments": parsed_args,
                    }
                )
            finish_reason = "tool_calls"

        yield (
            "message",
            {
                "content": full_content,
                "tool_calls": tool_calls_list,
                "finish_reason": finish_reason,
            },
        )
    except Exception as e:
        logger.warning("Azure OpenAI streaming returned an error: %s", str(e))
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )

View File

@ -1,5 +1,6 @@
"""Gemini Provider for Frigate AI."""
import json
import logging
from typing import Any, Optional
@ -273,3 +274,239 @@ class GeminiClient(GenAIClient):
"tool_calls": None,
"finish_reason": "error",
}
async def chat_with_tools_stream(
    self,
    messages: list[dict[str, Any]],
    tools: Optional[list[dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
):
    """Stream a chat completion with optional tool calling (Gemini).

    Converts OpenAI-style messages/tools to Gemini format, streams the
    response, and yields ``("content_delta", str)`` tuples as text
    arrives followed by a single ``("message", dict)`` tuple with the
    aggregated content, parsed tool calls, and finish reason. On any
    error a final ``("message", ...)`` with ``finish_reason == "error"``
    is yielded instead of raising.

    Args:
        messages: Chat history in OpenAI message format (roles "system",
            "user", "assistant", "tool").
        tools: Optional OpenAI-format tool definitions.
        tool_choice: "none" / "auto" / "required", mapped to Gemini's
            NONE / AUTO / ANY function-calling modes.
    """
    try:
        # --- Convert messages to Gemini Content objects -----------------
        gemini_messages = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                # Gemini has no system role: fold the system text into
                # the first user message when one exists.
                if gemini_messages and gemini_messages[0].role == "user":
                    gemini_messages[0].parts[
                        0
                    ].text = f"{content}\n\n{gemini_messages[0].parts[0].text}"
                else:
                    gemini_messages.append(
                        types.Content(
                            role="user", parts=[types.Part.from_text(text=content)]
                        )
                    )
            elif role == "assistant":
                gemini_messages.append(
                    types.Content(
                        role="model", parts=[types.Part.from_text(text=content)]
                    )
                )
            elif role == "tool":
                # Tool result from a previous function call.
                # NOTE(review): from_function_response is called with a
                # single dict positionally — confirm against the SDK,
                # which may expect name=/response= keyword arguments.
                function_response = {
                    "name": msg.get("name", ""),
                    "response": content,
                }
                gemini_messages.append(
                    types.Content(
                        role="function",
                        parts=[
                            types.Part.from_function_response(function_response)
                        ],
                    )
                )
            else:  # user
                gemini_messages.append(
                    types.Content(
                        role="user", parts=[types.Part.from_text(text=content)]
                    )
                )

        # --- Convert OpenAI-style tool specs to Gemini declarations -----
        gemini_tools = None
        if tools:
            gemini_tools = []
            for tool in tools:
                if tool.get("type") == "function":
                    func = tool.get("function", {})
                    gemini_tools.append(
                        types.Tool(
                            function_declarations=[
                                types.FunctionDeclaration(
                                    name=func.get("name", ""),
                                    description=func.get("description", ""),
                                    parameters=func.get("parameters", {}),
                                )
                            ]
                        )
                    )

        # --- Map tool_choice onto Gemini's function-calling modes --------
        tool_config = None
        if tool_choice in ("none", "auto", "required"):
            # "required" corresponds to Gemini's ANY mode.
            mode = {"none": "NONE", "auto": "AUTO", "required": "ANY"}[tool_choice]
            tool_config = types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(mode=mode)
            )

        # --- Build request config ---------------------------------------
        config_params = {"candidate_count": 1}
        if gemini_tools:
            config_params["tools"] = gemini_tools
        if tool_config:
            config_params["tool_config"] = tool_config
        # User-provided runtime options override the defaults above.
        if isinstance(self.genai_config.runtime_options, dict):
            config_params.update(self.genai_config.runtime_options)

        content_parts: list[str] = []
        tool_calls_by_index: dict[int, dict[str, Any]] = {}
        finish_reason = "stop"

        # NOTE(review): this uses `async for` on the result of the
        # non-aio models.generate_content_stream — confirm self.provider
        # exposes an async-iterable stream here.
        response = self.provider.models.generate_content_stream(
            model=self.genai_config.model,
            contents=gemini_messages,
            config=types.GenerateContentConfig(**config_params),
        )

        async for chunk in response:
            if not chunk or not chunk.candidates:
                continue

            candidate = chunk.candidates[0]

            # Map Gemini finish reasons onto the common vocabulary.
            if hasattr(candidate, "finish_reason") and candidate.finish_reason:
                if candidate.finish_reason == types.FinishReason.STOP:
                    finish_reason = "stop"
                elif candidate.finish_reason == types.FinishReason.MAX_TOKENS:
                    finish_reason = "length"
                elif candidate.finish_reason in [
                    types.FinishReason.SAFETY,
                    types.FinishReason.RECITATION,
                ]:
                    finish_reason = "error"

            if candidate.content and candidate.content.parts:
                for part in candidate.content.parts:
                    if part.text:
                        # Forward text immediately; buffer for the final
                        # message as well.
                        content_parts.append(part.text)
                        yield ("content_delta", part.text)
                    elif part.function_call:
                        try:
                            arguments = (
                                dict(part.function_call.args)
                                if part.function_call.args
                                else {}
                            )
                        except Exception:
                            arguments = {}

                        # Gemini has no tool-call ids, so the function
                        # name doubles as the id and dedup key.
                        tool_call_id = part.function_call.name or ""
                        tool_call_name = part.function_call.name or ""

                        found_index = None
                        for idx, tc in tool_calls_by_index.items():
                            if tc["name"] == tool_call_name:
                                found_index = idx
                                break
                        if found_index is None:
                            found_index = len(tool_calls_by_index)
                            tool_calls_by_index[found_index] = {
                                "id": tool_call_id,
                                "name": tool_call_name,
                                "arguments": "",
                            }

                        # Accumulate arguments as JSON text, matching the
                        # OpenAI-style streaming representation.
                        if arguments:
                            tool_calls_by_index[found_index]["arguments"] += (
                                json.dumps(arguments)
                                if isinstance(arguments, dict)
                                else str(arguments)
                            )

        full_content = "".join(content_parts).strip() or None

        tool_calls_list = None
        if tool_calls_by_index:
            tool_calls_list = []
            for tc in tool_calls_by_index.values():
                try:
                    # Parse the accumulated JSON text back into a dict;
                    # keep the raw string if it is not valid JSON.
                    parsed_args = json.loads(tc["arguments"])
                except json.JSONDecodeError:
                    parsed_args = tc["arguments"]
                tool_calls_list.append(
                    {
                        "id": tc["id"],
                        "name": tc["name"],
                        "arguments": parsed_args,
                    }
                )
            finish_reason = "tool_calls"

        yield (
            "message",
            {
                "content": full_content,
                "tool_calls": tool_calls_list,
                "finish_reason": finish_reason,
            },
        )
    except errors.APIError as e:
        logger.warning("Gemini API error during streaming: %s", str(e))
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )
    except Exception as e:
        logger.warning(
            "Gemini returned an error during chat_with_tools_stream: %s", str(e)
        )
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )

View File

@ -227,3 +227,142 @@ class OpenAIClient(GenAIClient):
"tool_calls": None,
"finish_reason": "error",
}
async def chat_with_tools_stream(
    self,
    messages: list[dict[str, Any]],
    tools: Optional[list[dict[str, Any]]] = None,
    tool_choice: Optional[str] = "auto",
):
    """Stream a chat completion with optional tool calling (OpenAI).

    Yields ``("content_delta", str)`` tuples as text chunks arrive, then a
    single ``("message", dict)`` tuple with the aggregated content, parsed
    tool calls, and finish reason. Timeouts and other provider errors are
    converted to a final ``("message", ...)`` with
    ``finish_reason == "error"`` rather than raised.

    Args:
        messages: Chat history in OpenAI message format.
        tools: Optional OpenAI-format tool definitions.
        tool_choice: One of "none", "auto", or "required"; any other
            value is ignored (no tool_choice is sent to the API).
    """
    try:
        # Only forward tool_choice values the OpenAI API understands.
        openai_tool_choice = (
            tool_choice
            if tool_choice in ("none", "auto", "required")
            else None
        )

        request_params = {
            "model": self.genai_config.model,
            "messages": messages,
            "timeout": self.timeout,
            "stream": True,
        }
        if tools:
            request_params["tools"] = tools
        if openai_tool_choice is not None:
            request_params["tool_choice"] = openai_tool_choice

        # Pass through user-supplied provider options, except keys that
        # are Frigate-internal and not valid API parameters.
        if isinstance(self.genai_config.provider_options, dict):
            excluded_options = {"context_size"}
            provider_opts = {
                k: v
                for k, v in self.genai_config.provider_options.items()
                if k not in excluded_options
            }
            request_params.update(provider_opts)

        content_parts: list[str] = []
        # Tool calls stream in as fragments keyed by index; the JSON
        # argument text is accumulated until the stream completes.
        tool_calls_by_index: dict[int, dict[str, Any]] = {}
        finish_reason = "stop"

        # NOTE(review): this iterates a synchronous stream inside an
        # async generator, blocking the event loop between chunks —
        # confirm the provider client is intentionally synchronous.
        stream = self.provider.chat.completions.create(**request_params)

        for chunk in stream:
            if not chunk or not chunk.choices:
                continue

            choice = chunk.choices[0]
            delta = choice.delta

            # The finish reason arrives on the final chunk.
            if choice.finish_reason:
                finish_reason = choice.finish_reason

            # Text deltas are forwarded immediately and buffered for the
            # final message.
            if delta.content:
                content_parts.append(delta.content)
                yield ("content_delta", delta.content)

            # Merge tool-call fragments: id/name may arrive once, while
            # the arguments string accumulates across chunks.
            if delta.tool_calls:
                for tc in delta.tool_calls:
                    idx = tc.index
                    fn = tc.function
                    if idx not in tool_calls_by_index:
                        tool_calls_by_index[idx] = {
                            "id": tc.id or "",
                            "name": fn.name if fn and fn.name else "",
                            "arguments": "",
                        }
                    t = tool_calls_by_index[idx]
                    if tc.id:
                        t["id"] = tc.id
                    if fn and fn.name:
                        t["name"] = fn.name
                    if fn and fn.arguments:
                        t["arguments"] += fn.arguments

        full_content = "".join(content_parts).strip() or None

        tool_calls_list = None
        if tool_calls_by_index:
            tool_calls_list = []
            for tc in tool_calls_by_index.values():
                try:
                    # Arguments stream in as JSON text; fall back to the
                    # raw string if the model emitted invalid JSON.
                    parsed_args = json.loads(tc["arguments"])
                except json.JSONDecodeError:
                    parsed_args = tc["arguments"]
                tool_calls_list.append(
                    {
                        "id": tc["id"],
                        "name": tc["name"],
                        "arguments": parsed_args,
                    }
                )
            finish_reason = "tool_calls"

        yield (
            "message",
            {
                "content": full_content,
                "tool_calls": tool_calls_list,
                "finish_reason": finish_reason,
            },
        )
    except TimeoutException as e:
        logger.warning("OpenAI streaming request timed out: %s", str(e))
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )
    except Exception as e:
        logger.warning("OpenAI streaming returned an error: %s", str(e))
        yield (
            "message",
            {
                "content": None,
                "tool_calls": None,
                "finish_reason": "error",
            },
        )