Make azure a subclass of openai

This commit is contained in:
Nicolas Mowen 2026-05-19 10:20:20 -06:00
parent 3ca97ed3d0
commit d7683fd797

View File

@ -1,28 +1,39 @@
"""Azure OpenAI Provider for Frigate AI."""
"""Azure OpenAI Provider for Frigate AI.
Azure OpenAI exposes the same chat completions API as OpenAI once the
client is constructed, so this provider inherits all transport, streaming,
reasoning, and tool-calling logic from :class:`OpenAIClient` and only
overrides what is genuinely Azure-specific:
- Client construction: parses ``api-version`` out of the configured
``base_url`` query string and instantiates :class:`openai.AzureOpenAI`
with ``azure_endpoint`` instead of ``base_url``.
- Context size: Azure does not expose a per-model ``max_model_len`` field
reliably, so we keep the historical 128K default rather than the
model-name heuristic used by OpenAI.
"""
import base64
import json
import logging
from typing import Any, AsyncGenerator, Optional
from typing import Optional
from urllib.parse import parse_qs, urlparse
from openai import AzureOpenAI
from frigate.config import GenAIProviderEnum
from frigate.genai import GenAIClient, register_genai_provider
from frigate.genai.plugins.openai import _stats_from_openai_usage
from frigate.genai import register_genai_provider
from frigate.genai.plugins.openai import OpenAIClient
logger = logging.getLogger(__name__)
@register_genai_provider(GenAIProviderEnum.azure_openai)
class OpenAIClient(GenAIClient):
class AzureOpenAIClient(OpenAIClient):
"""Generative AI client for Frigate using Azure OpenAI."""
provider: AzureOpenAI
provider: AzureOpenAI # type: ignore[assignment]
def _init_provider(self) -> AzureOpenAI | None:
"""Initialize the client."""
def _init_provider(self) -> Optional[AzureOpenAI]:
"""Initialize the AzureOpenAI client from the configured base_url."""
try:
parsed_url = urlparse(self.genai_config.base_url or "")
query_params = parse_qs(parsed_url.query)
@ -32,7 +43,6 @@ class OpenAIClient(GenAIClient):
if not api_version:
logger.warning("Azure OpenAI url is missing API version.")
return None
except Exception as e:
logger.warning("Error parsing Azure OpenAI url: %s", str(e))
return None
@ -43,275 +53,6 @@ class OpenAIClient(GenAIClient):
azure_endpoint=azure_endpoint,
)
def _send(
self,
prompt: str,
images: list[bytes],
response_format: Optional[dict] = None,
) -> Optional[str]:
"""Submit a request to Azure OpenAI."""
encoded_images = [base64.b64encode(image).decode("utf-8") for image in images]
try:
request_params = {
"model": self.genai_config.model,
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": prompt}]
+ [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image}",
"detail": "low",
},
}
for image in encoded_images
],
},
],
"timeout": self.timeout,
**self.genai_config.runtime_options,
}
if response_format:
request_params["response_format"] = response_format
result = self.provider.chat.completions.create(**request_params)
except Exception as e:
logger.warning("Azure OpenAI returned an error: %s", str(e))
return None
if len(result.choices) > 0:
return str(result.choices[0].message.content.strip())
return None
def list_models(self) -> list[str]:
"""Return available model IDs from Azure OpenAI."""
try:
return sorted(m.id for m in self.provider.models.list().data)
except Exception as e:
logger.warning("Failed to list Azure OpenAI models: %s", e)
return []
def get_context_size(self) -> int:
"""Get the context window size for Azure OpenAI."""
"""Azure does not reliably surface per-model context size; use 128K."""
return 128000
def chat_with_tools(
self,
messages: list[dict[str, Any]],
tools: Optional[list[dict[str, Any]]] = None,
tool_choice: Optional[str] = "auto",
) -> dict[str, Any]:
try:
openai_tool_choice = None
if tool_choice:
if tool_choice == "none":
openai_tool_choice = "none"
elif tool_choice == "auto":
openai_tool_choice = "auto"
elif tool_choice == "required":
openai_tool_choice = "required"
request_params = {
"model": self.genai_config.model,
"messages": messages,
"timeout": self.timeout,
**self.genai_config.runtime_options,
}
if tools:
request_params["tools"] = tools
if openai_tool_choice is not None:
request_params["tool_choice"] = openai_tool_choice
result = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
if (
result is None
or not hasattr(result, "choices")
or len(result.choices) == 0
):
return {
"content": None,
"tool_calls": None,
"finish_reason": "error",
}
choice = result.choices[0]
message = choice.message
content = message.content.strip() if message.content else None
tool_calls = None
if message.tool_calls:
tool_calls = []
for tool_call in message.tool_calls:
try:
arguments = json.loads(tool_call.function.arguments)
except (json.JSONDecodeError, AttributeError) as e:
logger.warning(
f"Failed to parse tool call arguments: {e}, "
f"tool: {tool_call.function.name if hasattr(tool_call.function, 'name') else 'unknown'}"
)
arguments = {}
tool_calls.append(
{
"id": tool_call.id if hasattr(tool_call, "id") else "",
"name": tool_call.function.name
if hasattr(tool_call.function, "name")
else "",
"arguments": arguments,
}
)
finish_reason = "error"
if hasattr(choice, "finish_reason") and choice.finish_reason:
finish_reason = choice.finish_reason
elif tool_calls:
finish_reason = "tool_calls"
elif content:
finish_reason = "stop"
return {
"content": content,
"tool_calls": tool_calls,
"finish_reason": finish_reason,
}
except Exception as e:
logger.warning("Azure OpenAI returned an error: %s", str(e))
return {
"content": None,
"tool_calls": None,
"finish_reason": "error",
}
async def chat_with_tools_stream(
self,
messages: list[dict[str, Any]],
tools: Optional[list[dict[str, Any]]] = None,
tool_choice: Optional[str] = "auto",
) -> AsyncGenerator[tuple[str, Any], None]:
"""
Stream chat with tools; yields content deltas then final message.
Implements streaming function calling/tool usage for Azure OpenAI models.
"""
try:
openai_tool_choice = None
if tool_choice:
if tool_choice == "none":
openai_tool_choice = "none"
elif tool_choice == "auto":
openai_tool_choice = "auto"
elif tool_choice == "required":
openai_tool_choice = "required"
request_params = {
"model": self.genai_config.model,
"messages": messages,
"timeout": self.timeout,
"stream": True,
"stream_options": {"include_usage": True},
**self.genai_config.runtime_options,
}
if tools:
request_params["tools"] = tools
if openai_tool_choice is not None:
request_params["tool_choice"] = openai_tool_choice
# Use streaming API
content_parts: list[str] = []
tool_calls_by_index: dict[int, dict[str, Any]] = {}
finish_reason = "stop"
usage_stats: Optional[dict[str, Any]] = None
stream = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
for chunk in stream:
chunk_usage = getattr(chunk, "usage", None)
if chunk_usage is not None:
usage_stats = _stats_from_openai_usage(chunk_usage)
if not chunk or not chunk.choices:
continue
choice = chunk.choices[0]
delta = choice.delta
# Check for finish reason
if choice.finish_reason:
finish_reason = choice.finish_reason
# Extract content deltas
if delta.content:
content_parts.append(delta.content)
yield ("content_delta", delta.content)
# Extract tool calls
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
fn = tc.function
if idx not in tool_calls_by_index:
tool_calls_by_index[idx] = {
"id": tc.id or "",
"name": fn.name if fn and fn.name else "",
"arguments": "",
}
t = tool_calls_by_index[idx]
if tc.id:
t["id"] = tc.id
if fn and fn.name:
t["name"] = fn.name
if fn and fn.arguments:
t["arguments"] += fn.arguments
# Build final message
full_content = "".join(content_parts).strip() or None
# Convert tool calls to list format
tool_calls_list = None
if tool_calls_by_index:
tool_calls_list = []
for tc in tool_calls_by_index.values():
try:
# Parse accumulated arguments as JSON
parsed_args = json.loads(tc["arguments"])
except (json.JSONDecodeError, Exception):
parsed_args = tc["arguments"]
tool_calls_list.append(
{
"id": tc["id"],
"name": tc["name"],
"arguments": parsed_args,
}
)
finish_reason = "tool_calls"
if usage_stats is not None:
yield ("stats", usage_stats)
yield (
"message",
{
"content": full_content,
"tool_calls": tool_calls_list,
"finish_reason": finish_reason,
},
)
except Exception as e:
logger.warning("Azure OpenAI streaming returned an error: %s", str(e))
yield (
"message",
{
"content": None,
"tool_calls": None,
"finish_reason": "error",
},
)