diff --git a/frigate/genai/plugins/azure-openai.py b/frigate/genai/plugins/azure-openai.py index 6330e539df..4cd4084ae3 100644 --- a/frigate/genai/plugins/azure-openai.py +++ b/frigate/genai/plugins/azure-openai.py @@ -1,28 +1,39 @@ -"""Azure OpenAI Provider for Frigate AI.""" +"""Azure OpenAI Provider for Frigate AI. + +Azure OpenAI exposes the same chat completions API as OpenAI once the +client is constructed, so this provider inherits all transport, streaming, +reasoning, and tool-calling logic from :class:`OpenAIClient` and only +overrides what is genuinely Azure-specific: + +- Client construction: parses ``api-version`` out of the configured + ``base_url`` query string and instantiates :class:`openai.AzureOpenAI` + with ``azure_endpoint`` instead of ``base_url``. +- Context size: Azure does not expose a per-model ``max_model_len`` field + reliably, so we keep the historical 128K default rather than the + model-name heuristic used by OpenAI. +""" -import base64 -import json import logging -from typing import Any, AsyncGenerator, Optional +from typing import Optional from urllib.parse import parse_qs, urlparse from openai import AzureOpenAI from frigate.config import GenAIProviderEnum -from frigate.genai import GenAIClient, register_genai_provider -from frigate.genai.plugins.openai import _stats_from_openai_usage +from frigate.genai import register_genai_provider +from frigate.genai.plugins.openai import OpenAIClient logger = logging.getLogger(__name__) @register_genai_provider(GenAIProviderEnum.azure_openai) -class OpenAIClient(GenAIClient): +class AzureOpenAIClient(OpenAIClient): """Generative AI client for Frigate using Azure OpenAI.""" - provider: AzureOpenAI + provider: AzureOpenAI # type: ignore[assignment] - def _init_provider(self) -> AzureOpenAI | None: - """Initialize the client.""" + def _init_provider(self) -> Optional[AzureOpenAI]: + """Initialize the AzureOpenAI client from the configured base_url.""" try: parsed_url = urlparse(self.genai_config.base_url or "") query_params = parse_qs(parsed_url.query) @@ -32,7 +43,6 @@ class OpenAIClient(GenAIClient): if not api_version: logger.warning("Azure OpenAI url is missing API version.") return None - except Exception as e: logger.warning("Error parsing Azure OpenAI url: %s", str(e)) return None @@ -43,275 +53,6 @@ class OpenAIClient(GenAIClient): azure_endpoint=azure_endpoint, ) - def _send( - self, - prompt: str, - images: list[bytes], - response_format: Optional[dict] = None, - ) -> Optional[str]: - """Submit a request to Azure OpenAI.""" - encoded_images = [base64.b64encode(image).decode("utf-8") for image in images] - try: - request_params = { - "model": self.genai_config.model, - "messages": [ - { - "role": "user", - "content": [{"type": "text", "text": prompt}] - + [ - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image}", - "detail": "low", - }, - } - for image in encoded_images - ], - }, - ], - "timeout": self.timeout, - **self.genai_config.runtime_options, - } - if response_format: - request_params["response_format"] = response_format - result = self.provider.chat.completions.create(**request_params) - except Exception as e: - logger.warning("Azure OpenAI returned an error: %s", str(e)) - return None - if len(result.choices) > 0: - return str(result.choices[0].message.content.strip()) - return None - - def list_models(self) -> list[str]: - """Return available model IDs from Azure OpenAI.""" - try: - return sorted(m.id for m in self.provider.models.list().data) - except Exception as e: - logger.warning("Failed to list Azure OpenAI models: %s", e) - return [] - def get_context_size(self) -> int: - """Get the context window size for Azure OpenAI.""" + """Azure does not reliably surface per-model context size; use 128K.""" return 128000 - - def chat_with_tools( - self, - messages: list[dict[str, Any]], - tools: Optional[list[dict[str, Any]]] = None, - tool_choice: Optional[str] = "auto", - ) -> dict[str, Any]: - try: - openai_tool_choice = None - if tool_choice: - if tool_choice == "none": - openai_tool_choice = "none" - elif tool_choice == "auto": - openai_tool_choice = "auto" - elif tool_choice == "required": - openai_tool_choice = "required" - - request_params = { - "model": self.genai_config.model, - "messages": messages, - "timeout": self.timeout, - **self.genai_config.runtime_options, - } - - if tools: - request_params["tools"] = tools - if openai_tool_choice is not None: - request_params["tool_choice"] = openai_tool_choice - - result = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload] - - if ( - result is None - or not hasattr(result, "choices") - or len(result.choices) == 0 - ): - return { - "content": None, - "tool_calls": None, - "finish_reason": "error", - } - - choice = result.choices[0] - message = choice.message - - content = message.content.strip() if message.content else None - - tool_calls = None - if message.tool_calls: - tool_calls = [] - for tool_call in message.tool_calls: - try: - arguments = json.loads(tool_call.function.arguments) - except (json.JSONDecodeError, AttributeError) as e: - logger.warning( - f"Failed to parse tool call arguments: {e}, " - f"tool: {tool_call.function.name if hasattr(tool_call.function, 'name') else 'unknown'}" - ) - arguments = {} - - tool_calls.append( - { - "id": tool_call.id if hasattr(tool_call, "id") else "", - "name": tool_call.function.name - if hasattr(tool_call.function, "name") - else "", - "arguments": arguments, - } - ) - - finish_reason = "error" - if hasattr(choice, "finish_reason") and choice.finish_reason: - finish_reason = choice.finish_reason - elif tool_calls: - finish_reason = "tool_calls" - elif content: - finish_reason = "stop" - - return { - "content": content, - "tool_calls": tool_calls, - "finish_reason": finish_reason, - } - - except Exception as e: - logger.warning("Azure OpenAI returned an error: %s", str(e)) - return { - "content": None, - "tool_calls": None, - "finish_reason": "error", - } - - async def chat_with_tools_stream( - self, - messages: list[dict[str, Any]], - tools: Optional[list[dict[str, Any]]] = None, - tool_choice: Optional[str] = "auto", - ) -> AsyncGenerator[tuple[str, Any], None]: - """ - Stream chat with tools; yields content deltas then final message. - - Implements streaming function calling/tool usage for Azure OpenAI models. - """ - try: - openai_tool_choice = None - if tool_choice: - if tool_choice == "none": - openai_tool_choice = "none" - elif tool_choice == "auto": - openai_tool_choice = "auto" - elif tool_choice == "required": - openai_tool_choice = "required" - - request_params = { - "model": self.genai_config.model, - "messages": messages, - "timeout": self.timeout, - "stream": True, - "stream_options": {"include_usage": True}, - **self.genai_config.runtime_options, - } - - if tools: - request_params["tools"] = tools - if openai_tool_choice is not None: - request_params["tool_choice"] = openai_tool_choice - - # Use streaming API - content_parts: list[str] = [] - tool_calls_by_index: dict[int, dict[str, Any]] = {} - finish_reason = "stop" - usage_stats: Optional[dict[str, Any]] = None - - stream = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload] - - for chunk in stream: - chunk_usage = getattr(chunk, "usage", None) - if chunk_usage is not None: - usage_stats = _stats_from_openai_usage(chunk_usage) - - if not chunk or not chunk.choices: - continue - - choice = chunk.choices[0] - delta = choice.delta - - # Check for finish reason - if choice.finish_reason: - finish_reason = choice.finish_reason - - # Extract content deltas - if delta.content: - content_parts.append(delta.content) - yield ("content_delta", delta.content) - - # Extract tool calls - if delta.tool_calls: - for tc in delta.tool_calls: - idx = tc.index - fn = tc.function - - if idx not in tool_calls_by_index: - tool_calls_by_index[idx] = { - "id": tc.id or "", - "name": fn.name if fn and fn.name else "", - "arguments": "", - } - - t = tool_calls_by_index[idx] - if tc.id: - t["id"] = tc.id - if fn and fn.name: - t["name"] = fn.name - if fn and fn.arguments: - t["arguments"] += fn.arguments - - # Build final message - full_content = "".join(content_parts).strip() or None - - # Convert tool calls to list format - tool_calls_list = None - if tool_calls_by_index: - tool_calls_list = [] - for tc in tool_calls_by_index.values(): - try: - # Parse accumulated arguments as JSON - parsed_args = json.loads(tc["arguments"]) - except (json.JSONDecodeError, Exception): - parsed_args = tc["arguments"] - - tool_calls_list.append( - { - "id": tc["id"], - "name": tc["name"], - "arguments": parsed_args, - } - ) - finish_reason = "tool_calls" - - if usage_stats is not None: - yield ("stats", usage_stats) - - yield ( - "message", - { - "content": full_content, - "tool_calls": tool_calls_list, - "finish_reason": finish_reason, - }, - ) - - except Exception as e: - logger.warning("Azure OpenAI streaming returned an error: %s", str(e)) - yield ( - "message", - { - "content": None, - "tool_calls": None, - "finish_reason": "error", - }, - )