mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-05 14:47:40 +03:00
Enable and fix inner genai
This commit is contained in:
parent
4c41b47aad
commit
42560de302
@ -3,7 +3,7 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from openai import AzureOpenAI
|
||||
@ -20,13 +20,13 @@ class OpenAIClient(GenAIClient):
|
||||
|
||||
provider: AzureOpenAI
|
||||
|
||||
def _init_provider(self):
|
||||
def _init_provider(self) -> AzureOpenAI | None:
|
||||
"""Initialize the client."""
|
||||
try:
|
||||
parsed_url = urlparse(self.genai_config.base_url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
query_params = parse_qs(parsed_url.query) # type: ignore[type-var]
|
||||
api_version = query_params.get("api-version", [None])[0]
|
||||
azure_endpoint = f"{parsed_url.scheme}://{parsed_url.netloc}/"
|
||||
azure_endpoint = f"{parsed_url.scheme}://{parsed_url.netloc}/" # type: ignore[str-bytes-safe]
|
||||
|
||||
if not api_version:
|
||||
logger.warning("Azure OpenAI url is missing API version.")
|
||||
@ -36,7 +36,7 @@ class OpenAIClient(GenAIClient):
|
||||
logger.warning("Error parsing Azure OpenAI url: %s", str(e))
|
||||
return None
|
||||
|
||||
return AzureOpenAI(
|
||||
return AzureOpenAI( # type: ignore[call-overload,no-any-return]
|
||||
api_key=self.genai_config.api_key,
|
||||
api_version=api_version,
|
||||
azure_endpoint=azure_endpoint,
|
||||
@ -79,7 +79,7 @@ class OpenAIClient(GenAIClient):
|
||||
logger.warning("Azure OpenAI returned an error: %s", str(e))
|
||||
return None
|
||||
if len(result.choices) > 0:
|
||||
return result.choices[0].message.content.strip()
|
||||
return str(result.choices[0].message.content.strip())
|
||||
return None
|
||||
|
||||
def get_context_size(self) -> int:
|
||||
@ -113,7 +113,7 @@ class OpenAIClient(GenAIClient):
|
||||
if openai_tool_choice is not None:
|
||||
request_params["tool_choice"] = openai_tool_choice
|
||||
|
||||
result = self.provider.chat.completions.create(**request_params)
|
||||
result = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
|
||||
|
||||
if (
|
||||
result is None
|
||||
@ -181,7 +181,7 @@ class OpenAIClient(GenAIClient):
|
||||
messages: list[dict[str, Any]],
|
||||
tools: Optional[list[dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = "auto",
|
||||
):
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""
|
||||
Stream chat with tools; yields content deltas then final message.
|
||||
|
||||
@ -214,7 +214,7 @@ class OpenAIClient(GenAIClient):
|
||||
tool_calls_by_index: dict[int, dict[str, Any]] = {}
|
||||
finish_reason = "stop"
|
||||
|
||||
stream = self.provider.chat.completions.create(**request_params)
|
||||
stream = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
|
||||
|
||||
for chunk in stream:
|
||||
if not chunk or not chunk.choices:
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
from google import genai
|
||||
from google.genai import errors, types
|
||||
@ -19,10 +19,10 @@ class GeminiClient(GenAIClient):
|
||||
|
||||
provider: genai.Client
|
||||
|
||||
def _init_provider(self):
|
||||
def _init_provider(self) -> genai.Client:
|
||||
"""Initialize the client."""
|
||||
# Merge provider_options into HttpOptions
|
||||
http_options_dict = {
|
||||
http_options_dict: dict[str, Any] = {
|
||||
"timeout": int(self.timeout * 1000), # requires milliseconds
|
||||
"retry_options": types.HttpRetryOptions(
|
||||
attempts=3,
|
||||
@ -54,7 +54,7 @@ class GeminiClient(GenAIClient):
|
||||
] + [prompt]
|
||||
try:
|
||||
# Merge runtime_options into generation_config if provided
|
||||
generation_config_dict = {"candidate_count": 1}
|
||||
generation_config_dict: dict[str, Any] = {"candidate_count": 1}
|
||||
generation_config_dict.update(self.genai_config.runtime_options)
|
||||
|
||||
if response_format and response_format.get("type") == "json_schema":
|
||||
@ -65,7 +65,7 @@ class GeminiClient(GenAIClient):
|
||||
|
||||
response = self.provider.models.generate_content(
|
||||
model=self.genai_config.model,
|
||||
contents=contents,
|
||||
contents=contents, # type: ignore[arg-type]
|
||||
config=types.GenerateContentConfig(
|
||||
**generation_config_dict,
|
||||
),
|
||||
@ -78,7 +78,7 @@ class GeminiClient(GenAIClient):
|
||||
return None
|
||||
|
||||
try:
|
||||
description = response.text.strip()
|
||||
description = response.text.strip() # type: ignore[union-attr]
|
||||
except (ValueError, AttributeError):
|
||||
# No description was generated
|
||||
return None
|
||||
@ -102,7 +102,7 @@ class GeminiClient(GenAIClient):
|
||||
"""
|
||||
try:
|
||||
# Convert messages to Gemini format
|
||||
gemini_messages = []
|
||||
gemini_messages: list[types.Content] = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
@ -111,9 +111,9 @@ class GeminiClient(GenAIClient):
|
||||
if role == "system":
|
||||
# Gemini doesn't have system role, prepend to first user message
|
||||
if gemini_messages and gemini_messages[0].role == "user":
|
||||
gemini_messages[0].parts[
|
||||
gemini_messages[0].parts[ # type: ignore[index]
|
||||
0
|
||||
].text = f"{content}\n\n{gemini_messages[0].parts[0].text}"
|
||||
].text = f"{content}\n\n{gemini_messages[0].parts[0].text}" # type: ignore[index]
|
||||
else:
|
||||
gemini_messages.append(
|
||||
types.Content(
|
||||
@ -136,7 +136,7 @@ class GeminiClient(GenAIClient):
|
||||
types.Content(
|
||||
role="function",
|
||||
parts=[
|
||||
types.Part.from_function_response(function_response)
|
||||
types.Part.from_function_response(function_response) # type: ignore[misc,call-arg,arg-type]
|
||||
],
|
||||
)
|
||||
)
|
||||
@ -171,19 +171,19 @@ class GeminiClient(GenAIClient):
|
||||
if tool_choice:
|
||||
if tool_choice == "none":
|
||||
tool_config = types.ToolConfig(
|
||||
function_calling_config=types.FunctionCallingConfig(mode="NONE")
|
||||
function_calling_config=types.FunctionCallingConfig(mode="NONE") # type: ignore[arg-type]
|
||||
)
|
||||
elif tool_choice == "auto":
|
||||
tool_config = types.ToolConfig(
|
||||
function_calling_config=types.FunctionCallingConfig(mode="AUTO")
|
||||
function_calling_config=types.FunctionCallingConfig(mode="AUTO") # type: ignore[arg-type]
|
||||
)
|
||||
elif tool_choice == "required":
|
||||
tool_config = types.ToolConfig(
|
||||
function_calling_config=types.FunctionCallingConfig(mode="ANY")
|
||||
function_calling_config=types.FunctionCallingConfig(mode="ANY") # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Build request config
|
||||
config_params = {"candidate_count": 1}
|
||||
config_params: dict[str, Any] = {"candidate_count": 1}
|
||||
|
||||
if gemini_tools:
|
||||
config_params["tools"] = gemini_tools
|
||||
@ -197,7 +197,7 @@ class GeminiClient(GenAIClient):
|
||||
|
||||
response = self.provider.models.generate_content(
|
||||
model=self.genai_config.model,
|
||||
contents=gemini_messages,
|
||||
contents=gemini_messages, # type: ignore[arg-type]
|
||||
config=types.GenerateContentConfig(**config_params),
|
||||
)
|
||||
|
||||
@ -291,7 +291,7 @@ class GeminiClient(GenAIClient):
|
||||
messages: list[dict[str, Any]],
|
||||
tools: Optional[list[dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = "auto",
|
||||
):
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""
|
||||
Stream chat with tools; yields content deltas then final message.
|
||||
|
||||
@ -299,7 +299,7 @@ class GeminiClient(GenAIClient):
|
||||
"""
|
||||
try:
|
||||
# Convert messages to Gemini format
|
||||
gemini_messages = []
|
||||
gemini_messages: list[types.Content] = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
@ -308,9 +308,9 @@ class GeminiClient(GenAIClient):
|
||||
if role == "system":
|
||||
# Gemini doesn't have system role, prepend to first user message
|
||||
if gemini_messages and gemini_messages[0].role == "user":
|
||||
gemini_messages[0].parts[
|
||||
gemini_messages[0].parts[ # type: ignore[index]
|
||||
0
|
||||
].text = f"{content}\n\n{gemini_messages[0].parts[0].text}"
|
||||
].text = f"{content}\n\n{gemini_messages[0].parts[0].text}" # type: ignore[index]
|
||||
else:
|
||||
gemini_messages.append(
|
||||
types.Content(
|
||||
@ -333,7 +333,7 @@ class GeminiClient(GenAIClient):
|
||||
types.Content(
|
||||
role="function",
|
||||
parts=[
|
||||
types.Part.from_function_response(function_response)
|
||||
types.Part.from_function_response(function_response) # type: ignore[misc,call-arg,arg-type]
|
||||
],
|
||||
)
|
||||
)
|
||||
@ -368,19 +368,19 @@ class GeminiClient(GenAIClient):
|
||||
if tool_choice:
|
||||
if tool_choice == "none":
|
||||
tool_config = types.ToolConfig(
|
||||
function_calling_config=types.FunctionCallingConfig(mode="NONE")
|
||||
function_calling_config=types.FunctionCallingConfig(mode="NONE") # type: ignore[arg-type]
|
||||
)
|
||||
elif tool_choice == "auto":
|
||||
tool_config = types.ToolConfig(
|
||||
function_calling_config=types.FunctionCallingConfig(mode="AUTO")
|
||||
function_calling_config=types.FunctionCallingConfig(mode="AUTO") # type: ignore[arg-type]
|
||||
)
|
||||
elif tool_choice == "required":
|
||||
tool_config = types.ToolConfig(
|
||||
function_calling_config=types.FunctionCallingConfig(mode="ANY")
|
||||
function_calling_config=types.FunctionCallingConfig(mode="ANY") # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Build request config
|
||||
config_params = {"candidate_count": 1}
|
||||
config_params: dict[str, Any] = {"candidate_count": 1}
|
||||
|
||||
if gemini_tools:
|
||||
config_params["tools"] = gemini_tools
|
||||
@ -399,7 +399,7 @@ class GeminiClient(GenAIClient):
|
||||
|
||||
stream = await self.provider.aio.models.generate_content_stream(
|
||||
model=self.genai_config.model,
|
||||
contents=gemini_messages,
|
||||
contents=gemini_messages, # type: ignore[arg-type]
|
||||
config=types.GenerateContentConfig(**config_params),
|
||||
)
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
@ -23,7 +23,7 @@ def _to_jpeg(img_bytes: bytes) -> bytes | None:
|
||||
try:
|
||||
img = Image.open(io.BytesIO(img_bytes))
|
||||
if img.mode != "RGB":
|
||||
img = img.convert("RGB")
|
||||
img = img.convert("RGB") # type: ignore[assignment]
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="JPEG", quality=85)
|
||||
return buf.getvalue()
|
||||
@ -36,10 +36,10 @@ def _to_jpeg(img_bytes: bytes) -> bytes | None:
|
||||
class LlamaCppClient(GenAIClient):
|
||||
"""Generative AI client for Frigate using llama.cpp server."""
|
||||
|
||||
provider: str # base_url
|
||||
provider: str | None # base_url
|
||||
provider_options: dict[str, Any]
|
||||
|
||||
def _init_provider(self):
|
||||
def _init_provider(self) -> str | None:
|
||||
"""Initialize the client."""
|
||||
self.provider_options = {
|
||||
**self.genai_config.provider_options,
|
||||
@ -75,7 +75,7 @@ class LlamaCppClient(GenAIClient):
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"image_url": { # type: ignore[dict-item]
|
||||
"url": f"data:image/jpeg;base64,{encoded_image}",
|
||||
},
|
||||
}
|
||||
@ -111,7 +111,7 @@ class LlamaCppClient(GenAIClient):
|
||||
):
|
||||
choice = result["choices"][0]
|
||||
if "message" in choice and "content" in choice["message"]:
|
||||
return choice["message"]["content"].strip()
|
||||
return str(choice["message"]["content"].strip())
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("llama.cpp returned an error: %s", str(e))
|
||||
@ -229,7 +229,7 @@ class LlamaCppClient(GenAIClient):
|
||||
content.append(
|
||||
{
|
||||
"prompt_string": "<__media__>\n",
|
||||
"multimodal_data": [encoded],
|
||||
"multimodal_data": [encoded], # type: ignore[dict-item]
|
||||
}
|
||||
)
|
||||
|
||||
@ -367,7 +367,7 @@ class LlamaCppClient(GenAIClient):
|
||||
messages: list[dict[str, Any]],
|
||||
tools: Optional[list[dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = "auto",
|
||||
):
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""Stream chat with tools via OpenAI-compatible streaming API."""
|
||||
if self.provider is None:
|
||||
logger.warning(
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
from httpx import RemoteProtocolError, TimeoutException
|
||||
from ollama import AsyncClient as OllamaAsyncClient
|
||||
@ -28,10 +28,10 @@ class OllamaClient(GenAIClient):
|
||||
},
|
||||
}
|
||||
|
||||
provider: ApiClient
|
||||
provider: ApiClient | None
|
||||
provider_options: dict[str, Any]
|
||||
|
||||
def _init_provider(self):
|
||||
def _init_provider(self) -> ApiClient | None:
|
||||
"""Initialize the client."""
|
||||
self.provider_options = {
|
||||
**self.LOCAL_OPTIMIZED_OPTIONS,
|
||||
@ -73,7 +73,7 @@ class OllamaClient(GenAIClient):
|
||||
"exclusiveMinimum",
|
||||
"exclusiveMaximum",
|
||||
}
|
||||
result = {}
|
||||
result: dict[str, Any] = {}
|
||||
for key, value in schema.items():
|
||||
if not _is_properties and key in STRIP_KEYS:
|
||||
continue
|
||||
@ -122,7 +122,7 @@ class OllamaClient(GenAIClient):
|
||||
logger.debug(
|
||||
f"Ollama tokens used: eval_count={result.get('eval_count')}, prompt_eval_count={result.get('prompt_eval_count')}"
|
||||
)
|
||||
return result["response"].strip()
|
||||
return str(result["response"]).strip()
|
||||
except (
|
||||
TimeoutException,
|
||||
ResponseError,
|
||||
@ -263,7 +263,7 @@ class OllamaClient(GenAIClient):
|
||||
messages: list[dict[str, Any]],
|
||||
tools: Optional[list[dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = "auto",
|
||||
):
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""Stream chat with tools; yields content deltas then final message.
|
||||
|
||||
When tools are provided, Ollama streaming does not include tool_calls
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
from httpx import TimeoutException
|
||||
from openai import OpenAI
|
||||
@ -21,7 +21,7 @@ class OpenAIClient(GenAIClient):
|
||||
provider: OpenAI
|
||||
context_size: Optional[int] = None
|
||||
|
||||
def _init_provider(self):
|
||||
def _init_provider(self) -> OpenAI:
|
||||
"""Initialize the client."""
|
||||
# Extract context_size from provider_options as it's not a valid OpenAI client parameter
|
||||
# It will be used in get_context_size() instead
|
||||
@ -81,7 +81,7 @@ class OpenAIClient(GenAIClient):
|
||||
and hasattr(result, "choices")
|
||||
and len(result.choices) > 0
|
||||
):
|
||||
return result.choices[0].message.content.strip()
|
||||
return str(result.choices[0].message.content.strip())
|
||||
return None
|
||||
except (TimeoutException, Exception) as e:
|
||||
logger.warning("OpenAI returned an error: %s", str(e))
|
||||
@ -171,7 +171,7 @@ class OpenAIClient(GenAIClient):
|
||||
}
|
||||
request_params.update(provider_opts)
|
||||
|
||||
result = self.provider.chat.completions.create(**request_params)
|
||||
result = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
|
||||
|
||||
if (
|
||||
result is None
|
||||
@ -245,7 +245,7 @@ class OpenAIClient(GenAIClient):
|
||||
messages: list[dict[str, Any]],
|
||||
tools: Optional[list[dict[str, Any]]] = None,
|
||||
tool_choice: Optional[str] = "auto",
|
||||
):
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""
|
||||
Stream chat with tools; yields content deltas then final message.
|
||||
|
||||
@ -287,7 +287,7 @@ class OpenAIClient(GenAIClient):
|
||||
tool_calls_by_index: dict[int, dict[str, Any]] = {}
|
||||
finish_reason = "stop"
|
||||
|
||||
stream = self.provider.chat.completions.create(**request_params)
|
||||
stream = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
|
||||
|
||||
for chunk in stream:
|
||||
if not chunk or not chunk.choices:
|
||||
|
||||
@ -41,7 +41,7 @@ ignore_errors = false
|
||||
[mypy-frigate.events]
|
||||
ignore_errors = false
|
||||
|
||||
[mypy-frigate.genai]
|
||||
[mypy-frigate.genai.*]
|
||||
ignore_errors = false
|
||||
|
||||
[mypy-frigate.jobs]
|
||||
@ -50,6 +50,15 @@ ignore_errors = false
|
||||
[mypy-frigate.motion]
|
||||
ignore_errors = false
|
||||
|
||||
[mypy-frigate.object_detection]
|
||||
ignore_errors = false
|
||||
|
||||
[mypy-frigate.output]
|
||||
ignore_errors = false
|
||||
|
||||
[mypy-frigate.ptz]
|
||||
ignore_errors = false
|
||||
|
||||
[mypy-frigate.log]
|
||||
ignore_errors = false
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user