Enable and fix inner genai

This commit is contained in:
Nicolas Mowen 2026-03-25 08:59:15 -06:00
parent 4c41b47aad
commit 42560de302
6 changed files with 64 additions and 55 deletions

View File

@ -3,7 +3,7 @@
import base64 import base64
import json import json
import logging import logging
from typing import Any, Optional from typing import Any, AsyncGenerator, Optional
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
from openai import AzureOpenAI from openai import AzureOpenAI
@ -20,13 +20,13 @@ class OpenAIClient(GenAIClient):
provider: AzureOpenAI provider: AzureOpenAI
def _init_provider(self): def _init_provider(self) -> AzureOpenAI | None:
"""Initialize the client.""" """Initialize the client."""
try: try:
parsed_url = urlparse(self.genai_config.base_url) parsed_url = urlparse(self.genai_config.base_url)
query_params = parse_qs(parsed_url.query) query_params = parse_qs(parsed_url.query) # type: ignore[type-var]
api_version = query_params.get("api-version", [None])[0] api_version = query_params.get("api-version", [None])[0]
azure_endpoint = f"{parsed_url.scheme}://{parsed_url.netloc}/" azure_endpoint = f"{parsed_url.scheme}://{parsed_url.netloc}/" # type: ignore[str-bytes-safe]
if not api_version: if not api_version:
logger.warning("Azure OpenAI url is missing API version.") logger.warning("Azure OpenAI url is missing API version.")
@ -36,7 +36,7 @@ class OpenAIClient(GenAIClient):
logger.warning("Error parsing Azure OpenAI url: %s", str(e)) logger.warning("Error parsing Azure OpenAI url: %s", str(e))
return None return None
return AzureOpenAI( return AzureOpenAI( # type: ignore[call-overload,no-any-return]
api_key=self.genai_config.api_key, api_key=self.genai_config.api_key,
api_version=api_version, api_version=api_version,
azure_endpoint=azure_endpoint, azure_endpoint=azure_endpoint,
@ -79,7 +79,7 @@ class OpenAIClient(GenAIClient):
logger.warning("Azure OpenAI returned an error: %s", str(e)) logger.warning("Azure OpenAI returned an error: %s", str(e))
return None return None
if len(result.choices) > 0: if len(result.choices) > 0:
return result.choices[0].message.content.strip() return str(result.choices[0].message.content.strip())
return None return None
def get_context_size(self) -> int: def get_context_size(self) -> int:
@ -113,7 +113,7 @@ class OpenAIClient(GenAIClient):
if openai_tool_choice is not None: if openai_tool_choice is not None:
request_params["tool_choice"] = openai_tool_choice request_params["tool_choice"] = openai_tool_choice
result = self.provider.chat.completions.create(**request_params) result = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
if ( if (
result is None result is None
@ -181,7 +181,7 @@ class OpenAIClient(GenAIClient):
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
tools: Optional[list[dict[str, Any]]] = None, tools: Optional[list[dict[str, Any]]] = None,
tool_choice: Optional[str] = "auto", tool_choice: Optional[str] = "auto",
): ) -> AsyncGenerator[tuple[str, Any], None]:
""" """
Stream chat with tools; yields content deltas then final message. Stream chat with tools; yields content deltas then final message.
@ -214,7 +214,7 @@ class OpenAIClient(GenAIClient):
tool_calls_by_index: dict[int, dict[str, Any]] = {} tool_calls_by_index: dict[int, dict[str, Any]] = {}
finish_reason = "stop" finish_reason = "stop"
stream = self.provider.chat.completions.create(**request_params) stream = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
for chunk in stream: for chunk in stream:
if not chunk or not chunk.choices: if not chunk or not chunk.choices:

View File

@ -2,7 +2,7 @@
import json import json
import logging import logging
from typing import Any, Optional from typing import Any, AsyncGenerator, Optional
from google import genai from google import genai
from google.genai import errors, types from google.genai import errors, types
@ -19,10 +19,10 @@ class GeminiClient(GenAIClient):
provider: genai.Client provider: genai.Client
def _init_provider(self): def _init_provider(self) -> genai.Client:
"""Initialize the client.""" """Initialize the client."""
# Merge provider_options into HttpOptions # Merge provider_options into HttpOptions
http_options_dict = { http_options_dict: dict[str, Any] = {
"timeout": int(self.timeout * 1000), # requires milliseconds "timeout": int(self.timeout * 1000), # requires milliseconds
"retry_options": types.HttpRetryOptions( "retry_options": types.HttpRetryOptions(
attempts=3, attempts=3,
@ -54,7 +54,7 @@ class GeminiClient(GenAIClient):
] + [prompt] ] + [prompt]
try: try:
# Merge runtime_options into generation_config if provided # Merge runtime_options into generation_config if provided
generation_config_dict = {"candidate_count": 1} generation_config_dict: dict[str, Any] = {"candidate_count": 1}
generation_config_dict.update(self.genai_config.runtime_options) generation_config_dict.update(self.genai_config.runtime_options)
if response_format and response_format.get("type") == "json_schema": if response_format and response_format.get("type") == "json_schema":
@ -65,7 +65,7 @@ class GeminiClient(GenAIClient):
response = self.provider.models.generate_content( response = self.provider.models.generate_content(
model=self.genai_config.model, model=self.genai_config.model,
contents=contents, contents=contents, # type: ignore[arg-type]
config=types.GenerateContentConfig( config=types.GenerateContentConfig(
**generation_config_dict, **generation_config_dict,
), ),
@ -78,7 +78,7 @@ class GeminiClient(GenAIClient):
return None return None
try: try:
description = response.text.strip() description = response.text.strip() # type: ignore[union-attr]
except (ValueError, AttributeError): except (ValueError, AttributeError):
# No description was generated # No description was generated
return None return None
@ -102,7 +102,7 @@ class GeminiClient(GenAIClient):
""" """
try: try:
# Convert messages to Gemini format # Convert messages to Gemini format
gemini_messages = [] gemini_messages: list[types.Content] = []
for msg in messages: for msg in messages:
role = msg.get("role", "user") role = msg.get("role", "user")
content = msg.get("content", "") content = msg.get("content", "")
@ -111,9 +111,9 @@ class GeminiClient(GenAIClient):
if role == "system": if role == "system":
# Gemini doesn't have system role, prepend to first user message # Gemini doesn't have system role, prepend to first user message
if gemini_messages and gemini_messages[0].role == "user": if gemini_messages and gemini_messages[0].role == "user":
gemini_messages[0].parts[ gemini_messages[0].parts[ # type: ignore[index]
0 0
].text = f"{content}\n\n{gemini_messages[0].parts[0].text}" ].text = f"{content}\n\n{gemini_messages[0].parts[0].text}" # type: ignore[index]
else: else:
gemini_messages.append( gemini_messages.append(
types.Content( types.Content(
@ -136,7 +136,7 @@ class GeminiClient(GenAIClient):
types.Content( types.Content(
role="function", role="function",
parts=[ parts=[
types.Part.from_function_response(function_response) types.Part.from_function_response(function_response) # type: ignore[misc,call-arg,arg-type]
], ],
) )
) )
@ -171,19 +171,19 @@ class GeminiClient(GenAIClient):
if tool_choice: if tool_choice:
if tool_choice == "none": if tool_choice == "none":
tool_config = types.ToolConfig( tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(mode="NONE") function_calling_config=types.FunctionCallingConfig(mode="NONE") # type: ignore[arg-type]
) )
elif tool_choice == "auto": elif tool_choice == "auto":
tool_config = types.ToolConfig( tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(mode="AUTO") function_calling_config=types.FunctionCallingConfig(mode="AUTO") # type: ignore[arg-type]
) )
elif tool_choice == "required": elif tool_choice == "required":
tool_config = types.ToolConfig( tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(mode="ANY") function_calling_config=types.FunctionCallingConfig(mode="ANY") # type: ignore[arg-type]
) )
# Build request config # Build request config
config_params = {"candidate_count": 1} config_params: dict[str, Any] = {"candidate_count": 1}
if gemini_tools: if gemini_tools:
config_params["tools"] = gemini_tools config_params["tools"] = gemini_tools
@ -197,7 +197,7 @@ class GeminiClient(GenAIClient):
response = self.provider.models.generate_content( response = self.provider.models.generate_content(
model=self.genai_config.model, model=self.genai_config.model,
contents=gemini_messages, contents=gemini_messages, # type: ignore[arg-type]
config=types.GenerateContentConfig(**config_params), config=types.GenerateContentConfig(**config_params),
) )
@ -291,7 +291,7 @@ class GeminiClient(GenAIClient):
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
tools: Optional[list[dict[str, Any]]] = None, tools: Optional[list[dict[str, Any]]] = None,
tool_choice: Optional[str] = "auto", tool_choice: Optional[str] = "auto",
): ) -> AsyncGenerator[tuple[str, Any], None]:
""" """
Stream chat with tools; yields content deltas then final message. Stream chat with tools; yields content deltas then final message.
@ -299,7 +299,7 @@ class GeminiClient(GenAIClient):
""" """
try: try:
# Convert messages to Gemini format # Convert messages to Gemini format
gemini_messages = [] gemini_messages: list[types.Content] = []
for msg in messages: for msg in messages:
role = msg.get("role", "user") role = msg.get("role", "user")
content = msg.get("content", "") content = msg.get("content", "")
@ -308,9 +308,9 @@ class GeminiClient(GenAIClient):
if role == "system": if role == "system":
# Gemini doesn't have system role, prepend to first user message # Gemini doesn't have system role, prepend to first user message
if gemini_messages and gemini_messages[0].role == "user": if gemini_messages and gemini_messages[0].role == "user":
gemini_messages[0].parts[ gemini_messages[0].parts[ # type: ignore[index]
0 0
].text = f"{content}\n\n{gemini_messages[0].parts[0].text}" ].text = f"{content}\n\n{gemini_messages[0].parts[0].text}" # type: ignore[index]
else: else:
gemini_messages.append( gemini_messages.append(
types.Content( types.Content(
@ -333,7 +333,7 @@ class GeminiClient(GenAIClient):
types.Content( types.Content(
role="function", role="function",
parts=[ parts=[
types.Part.from_function_response(function_response) types.Part.from_function_response(function_response) # type: ignore[misc,call-arg,arg-type]
], ],
) )
) )
@ -368,19 +368,19 @@ class GeminiClient(GenAIClient):
if tool_choice: if tool_choice:
if tool_choice == "none": if tool_choice == "none":
tool_config = types.ToolConfig( tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(mode="NONE") function_calling_config=types.FunctionCallingConfig(mode="NONE") # type: ignore[arg-type]
) )
elif tool_choice == "auto": elif tool_choice == "auto":
tool_config = types.ToolConfig( tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(mode="AUTO") function_calling_config=types.FunctionCallingConfig(mode="AUTO") # type: ignore[arg-type]
) )
elif tool_choice == "required": elif tool_choice == "required":
tool_config = types.ToolConfig( tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(mode="ANY") function_calling_config=types.FunctionCallingConfig(mode="ANY") # type: ignore[arg-type]
) )
# Build request config # Build request config
config_params = {"candidate_count": 1} config_params: dict[str, Any] = {"candidate_count": 1}
if gemini_tools: if gemini_tools:
config_params["tools"] = gemini_tools config_params["tools"] = gemini_tools
@ -399,7 +399,7 @@ class GeminiClient(GenAIClient):
stream = await self.provider.aio.models.generate_content_stream( stream = await self.provider.aio.models.generate_content_stream(
model=self.genai_config.model, model=self.genai_config.model,
contents=gemini_messages, contents=gemini_messages, # type: ignore[arg-type]
config=types.GenerateContentConfig(**config_params), config=types.GenerateContentConfig(**config_params),
) )

View File

@ -4,7 +4,7 @@ import base64
import io import io
import json import json
import logging import logging
from typing import Any, Optional from typing import Any, AsyncGenerator, Optional
import httpx import httpx
import numpy as np import numpy as np
@ -23,7 +23,7 @@ def _to_jpeg(img_bytes: bytes) -> bytes | None:
try: try:
img = Image.open(io.BytesIO(img_bytes)) img = Image.open(io.BytesIO(img_bytes))
if img.mode != "RGB": if img.mode != "RGB":
img = img.convert("RGB") img = img.convert("RGB") # type: ignore[assignment]
buf = io.BytesIO() buf = io.BytesIO()
img.save(buf, format="JPEG", quality=85) img.save(buf, format="JPEG", quality=85)
return buf.getvalue() return buf.getvalue()
@ -36,10 +36,10 @@ def _to_jpeg(img_bytes: bytes) -> bytes | None:
class LlamaCppClient(GenAIClient): class LlamaCppClient(GenAIClient):
"""Generative AI client for Frigate using llama.cpp server.""" """Generative AI client for Frigate using llama.cpp server."""
provider: str # base_url provider: str | None # base_url
provider_options: dict[str, Any] provider_options: dict[str, Any]
def _init_provider(self): def _init_provider(self) -> str | None:
"""Initialize the client.""" """Initialize the client."""
self.provider_options = { self.provider_options = {
**self.genai_config.provider_options, **self.genai_config.provider_options,
@ -75,7 +75,7 @@ class LlamaCppClient(GenAIClient):
content.append( content.append(
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": { # type: ignore[dict-item]
"url": f"data:image/jpeg;base64,{encoded_image}", "url": f"data:image/jpeg;base64,{encoded_image}",
}, },
} }
@ -111,7 +111,7 @@ class LlamaCppClient(GenAIClient):
): ):
choice = result["choices"][0] choice = result["choices"][0]
if "message" in choice and "content" in choice["message"]: if "message" in choice and "content" in choice["message"]:
return choice["message"]["content"].strip() return str(choice["message"]["content"].strip())
return None return None
except Exception as e: except Exception as e:
logger.warning("llama.cpp returned an error: %s", str(e)) logger.warning("llama.cpp returned an error: %s", str(e))
@ -229,7 +229,7 @@ class LlamaCppClient(GenAIClient):
content.append( content.append(
{ {
"prompt_string": "<__media__>\n", "prompt_string": "<__media__>\n",
"multimodal_data": [encoded], "multimodal_data": [encoded], # type: ignore[dict-item]
} }
) )
@ -367,7 +367,7 @@ class LlamaCppClient(GenAIClient):
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
tools: Optional[list[dict[str, Any]]] = None, tools: Optional[list[dict[str, Any]]] = None,
tool_choice: Optional[str] = "auto", tool_choice: Optional[str] = "auto",
): ) -> AsyncGenerator[tuple[str, Any], None]:
"""Stream chat with tools via OpenAI-compatible streaming API.""" """Stream chat with tools via OpenAI-compatible streaming API."""
if self.provider is None: if self.provider is None:
logger.warning( logger.warning(

View File

@ -2,7 +2,7 @@
import json import json
import logging import logging
from typing import Any, Optional from typing import Any, AsyncGenerator, Optional
from httpx import RemoteProtocolError, TimeoutException from httpx import RemoteProtocolError, TimeoutException
from ollama import AsyncClient as OllamaAsyncClient from ollama import AsyncClient as OllamaAsyncClient
@ -28,10 +28,10 @@ class OllamaClient(GenAIClient):
}, },
} }
provider: ApiClient provider: ApiClient | None
provider_options: dict[str, Any] provider_options: dict[str, Any]
def _init_provider(self): def _init_provider(self) -> ApiClient | None:
"""Initialize the client.""" """Initialize the client."""
self.provider_options = { self.provider_options = {
**self.LOCAL_OPTIMIZED_OPTIONS, **self.LOCAL_OPTIMIZED_OPTIONS,
@ -73,7 +73,7 @@ class OllamaClient(GenAIClient):
"exclusiveMinimum", "exclusiveMinimum",
"exclusiveMaximum", "exclusiveMaximum",
} }
result = {} result: dict[str, Any] = {}
for key, value in schema.items(): for key, value in schema.items():
if not _is_properties and key in STRIP_KEYS: if not _is_properties and key in STRIP_KEYS:
continue continue
@ -122,7 +122,7 @@ class OllamaClient(GenAIClient):
logger.debug( logger.debug(
f"Ollama tokens used: eval_count={result.get('eval_count')}, prompt_eval_count={result.get('prompt_eval_count')}" f"Ollama tokens used: eval_count={result.get('eval_count')}, prompt_eval_count={result.get('prompt_eval_count')}"
) )
return result["response"].strip() return str(result["response"]).strip()
except ( except (
TimeoutException, TimeoutException,
ResponseError, ResponseError,
@ -263,7 +263,7 @@ class OllamaClient(GenAIClient):
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
tools: Optional[list[dict[str, Any]]] = None, tools: Optional[list[dict[str, Any]]] = None,
tool_choice: Optional[str] = "auto", tool_choice: Optional[str] = "auto",
): ) -> AsyncGenerator[tuple[str, Any], None]:
"""Stream chat with tools; yields content deltas then final message. """Stream chat with tools; yields content deltas then final message.
When tools are provided, Ollama streaming does not include tool_calls When tools are provided, Ollama streaming does not include tool_calls

View File

@ -3,7 +3,7 @@
import base64 import base64
import json import json
import logging import logging
from typing import Any, Optional from typing import Any, AsyncGenerator, Optional
from httpx import TimeoutException from httpx import TimeoutException
from openai import OpenAI from openai import OpenAI
@ -21,7 +21,7 @@ class OpenAIClient(GenAIClient):
provider: OpenAI provider: OpenAI
context_size: Optional[int] = None context_size: Optional[int] = None
def _init_provider(self): def _init_provider(self) -> OpenAI:
"""Initialize the client.""" """Initialize the client."""
# Extract context_size from provider_options as it's not a valid OpenAI client parameter # Extract context_size from provider_options as it's not a valid OpenAI client parameter
# It will be used in get_context_size() instead # It will be used in get_context_size() instead
@ -81,7 +81,7 @@ class OpenAIClient(GenAIClient):
and hasattr(result, "choices") and hasattr(result, "choices")
and len(result.choices) > 0 and len(result.choices) > 0
): ):
return result.choices[0].message.content.strip() return str(result.choices[0].message.content.strip())
return None return None
except (TimeoutException, Exception) as e: except (TimeoutException, Exception) as e:
logger.warning("OpenAI returned an error: %s", str(e)) logger.warning("OpenAI returned an error: %s", str(e))
@ -171,7 +171,7 @@ class OpenAIClient(GenAIClient):
} }
request_params.update(provider_opts) request_params.update(provider_opts)
result = self.provider.chat.completions.create(**request_params) result = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
if ( if (
result is None result is None
@ -245,7 +245,7 @@ class OpenAIClient(GenAIClient):
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
tools: Optional[list[dict[str, Any]]] = None, tools: Optional[list[dict[str, Any]]] = None,
tool_choice: Optional[str] = "auto", tool_choice: Optional[str] = "auto",
): ) -> AsyncGenerator[tuple[str, Any], None]:
""" """
Stream chat with tools; yields content deltas then final message. Stream chat with tools; yields content deltas then final message.
@ -287,7 +287,7 @@ class OpenAIClient(GenAIClient):
tool_calls_by_index: dict[int, dict[str, Any]] = {} tool_calls_by_index: dict[int, dict[str, Any]] = {}
finish_reason = "stop" finish_reason = "stop"
stream = self.provider.chat.completions.create(**request_params) stream = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
for chunk in stream: for chunk in stream:
if not chunk or not chunk.choices: if not chunk or not chunk.choices:

View File

@ -41,7 +41,7 @@ ignore_errors = false
[mypy-frigate.events] [mypy-frigate.events]
ignore_errors = false ignore_errors = false
[mypy-frigate.genai] [mypy-frigate.genai.*]
ignore_errors = false ignore_errors = false
[mypy-frigate.jobs] [mypy-frigate.jobs]
@ -50,6 +50,15 @@ ignore_errors = false
[mypy-frigate.motion] [mypy-frigate.motion]
ignore_errors = false ignore_errors = false
[mypy-frigate.object_detection]
ignore_errors = false
[mypy-frigate.output]
ignore_errors = false
[mypy-frigate.ptz]
ignore_errors = false
[mypy-frigate.log] [mypy-frigate.log]
ignore_errors = false ignore_errors = false