mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-03-26 18:18:22 +03:00
Increase mypy coverage and fixes (#22632)
This commit is contained in:
parent
04a2f42d11
commit
80c4ce2b5d
@ -5,7 +5,7 @@ import importlib
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import Any, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from playhouse.shortcuts import model_to_dict
|
from playhouse.shortcuts import model_to_dict
|
||||||
@ -31,10 +31,10 @@ __all__ = [
|
|||||||
PROVIDERS = {}
|
PROVIDERS = {}
|
||||||
|
|
||||||
|
|
||||||
def register_genai_provider(key: GenAIProviderEnum):
|
def register_genai_provider(key: GenAIProviderEnum) -> Callable:
|
||||||
"""Register a GenAI provider."""
|
"""Register a GenAI provider."""
|
||||||
|
|
||||||
def decorator(cls):
|
def decorator(cls: type) -> type:
|
||||||
PROVIDERS[key] = cls
|
PROVIDERS[key] = cls
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@ -297,7 +297,7 @@ Guidelines:
|
|||||||
"""Generate a description for the frame."""
|
"""Generate a description for the frame."""
|
||||||
try:
|
try:
|
||||||
prompt = camera_config.objects.genai.object_prompts.get(
|
prompt = camera_config.objects.genai.object_prompts.get(
|
||||||
event.label,
|
str(event.label),
|
||||||
camera_config.objects.genai.prompt,
|
camera_config.objects.genai.prompt,
|
||||||
).format(**model_to_dict(event))
|
).format(**model_to_dict(event))
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
@ -307,7 +307,7 @@ Guidelines:
|
|||||||
logger.debug(f"Sending images to genai provider with prompt: {prompt}")
|
logger.debug(f"Sending images to genai provider with prompt: {prompt}")
|
||||||
return self._send(prompt, thumbnails)
|
return self._send(prompt, thumbnails)
|
||||||
|
|
||||||
def _init_provider(self):
|
def _init_provider(self) -> Any:
|
||||||
"""Initialize the client."""
|
"""Initialize the client."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -402,7 +402,7 @@ Guidelines:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_providers():
|
def load_providers() -> None:
|
||||||
package_dir = os.path.dirname(__file__)
|
package_dir = os.path.dirname(__file__)
|
||||||
for filename in os.listdir(package_dir):
|
for filename in os.listdir(package_dir):
|
||||||
if filename.endswith(".py") and filename != "__init__.py":
|
if filename.endswith(".py") and filename != "__init__.py":
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import Any, AsyncGenerator, Optional
|
||||||
from urllib.parse import parse_qs, urlparse
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
from openai import AzureOpenAI
|
from openai import AzureOpenAI
|
||||||
@ -20,10 +20,10 @@ class OpenAIClient(GenAIClient):
|
|||||||
|
|
||||||
provider: AzureOpenAI
|
provider: AzureOpenAI
|
||||||
|
|
||||||
def _init_provider(self):
|
def _init_provider(self) -> AzureOpenAI | None:
|
||||||
"""Initialize the client."""
|
"""Initialize the client."""
|
||||||
try:
|
try:
|
||||||
parsed_url = urlparse(self.genai_config.base_url)
|
parsed_url = urlparse(self.genai_config.base_url or "")
|
||||||
query_params = parse_qs(parsed_url.query)
|
query_params = parse_qs(parsed_url.query)
|
||||||
api_version = query_params.get("api-version", [None])[0]
|
api_version = query_params.get("api-version", [None])[0]
|
||||||
azure_endpoint = f"{parsed_url.scheme}://{parsed_url.netloc}/"
|
azure_endpoint = f"{parsed_url.scheme}://{parsed_url.netloc}/"
|
||||||
@ -79,7 +79,7 @@ class OpenAIClient(GenAIClient):
|
|||||||
logger.warning("Azure OpenAI returned an error: %s", str(e))
|
logger.warning("Azure OpenAI returned an error: %s", str(e))
|
||||||
return None
|
return None
|
||||||
if len(result.choices) > 0:
|
if len(result.choices) > 0:
|
||||||
return result.choices[0].message.content.strip()
|
return str(result.choices[0].message.content.strip())
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_context_size(self) -> int:
|
def get_context_size(self) -> int:
|
||||||
@ -113,7 +113,7 @@ class OpenAIClient(GenAIClient):
|
|||||||
if openai_tool_choice is not None:
|
if openai_tool_choice is not None:
|
||||||
request_params["tool_choice"] = openai_tool_choice
|
request_params["tool_choice"] = openai_tool_choice
|
||||||
|
|
||||||
result = self.provider.chat.completions.create(**request_params)
|
result = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
|
||||||
|
|
||||||
if (
|
if (
|
||||||
result is None
|
result is None
|
||||||
@ -181,7 +181,7 @@ class OpenAIClient(GenAIClient):
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: Optional[list[dict[str, Any]]] = None,
|
tools: Optional[list[dict[str, Any]]] = None,
|
||||||
tool_choice: Optional[str] = "auto",
|
tool_choice: Optional[str] = "auto",
|
||||||
):
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
"""
|
"""
|
||||||
Stream chat with tools; yields content deltas then final message.
|
Stream chat with tools; yields content deltas then final message.
|
||||||
|
|
||||||
@ -214,7 +214,7 @@ class OpenAIClient(GenAIClient):
|
|||||||
tool_calls_by_index: dict[int, dict[str, Any]] = {}
|
tool_calls_by_index: dict[int, dict[str, Any]] = {}
|
||||||
finish_reason = "stop"
|
finish_reason = "stop"
|
||||||
|
|
||||||
stream = self.provider.chat.completions.create(**request_params)
|
stream = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
|
||||||
|
|
||||||
for chunk in stream:
|
for chunk in stream:
|
||||||
if not chunk or not chunk.choices:
|
if not chunk or not chunk.choices:
|
||||||
|
|||||||
@ -2,10 +2,11 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import Any, AsyncGenerator, Optional
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
from google.genai import errors, types
|
from google.genai import errors, types
|
||||||
|
from google.genai.types import FunctionCallingConfigMode
|
||||||
|
|
||||||
from frigate.config import GenAIProviderEnum
|
from frigate.config import GenAIProviderEnum
|
||||||
from frigate.genai import GenAIClient, register_genai_provider
|
from frigate.genai import GenAIClient, register_genai_provider
|
||||||
@ -19,10 +20,10 @@ class GeminiClient(GenAIClient):
|
|||||||
|
|
||||||
provider: genai.Client
|
provider: genai.Client
|
||||||
|
|
||||||
def _init_provider(self):
|
def _init_provider(self) -> genai.Client:
|
||||||
"""Initialize the client."""
|
"""Initialize the client."""
|
||||||
# Merge provider_options into HttpOptions
|
# Merge provider_options into HttpOptions
|
||||||
http_options_dict = {
|
http_options_dict: dict[str, Any] = {
|
||||||
"timeout": int(self.timeout * 1000), # requires milliseconds
|
"timeout": int(self.timeout * 1000), # requires milliseconds
|
||||||
"retry_options": types.HttpRetryOptions(
|
"retry_options": types.HttpRetryOptions(
|
||||||
attempts=3,
|
attempts=3,
|
||||||
@ -54,7 +55,7 @@ class GeminiClient(GenAIClient):
|
|||||||
] + [prompt]
|
] + [prompt]
|
||||||
try:
|
try:
|
||||||
# Merge runtime_options into generation_config if provided
|
# Merge runtime_options into generation_config if provided
|
||||||
generation_config_dict = {"candidate_count": 1}
|
generation_config_dict: dict[str, Any] = {"candidate_count": 1}
|
||||||
generation_config_dict.update(self.genai_config.runtime_options)
|
generation_config_dict.update(self.genai_config.runtime_options)
|
||||||
|
|
||||||
if response_format and response_format.get("type") == "json_schema":
|
if response_format and response_format.get("type") == "json_schema":
|
||||||
@ -65,7 +66,7 @@ class GeminiClient(GenAIClient):
|
|||||||
|
|
||||||
response = self.provider.models.generate_content(
|
response = self.provider.models.generate_content(
|
||||||
model=self.genai_config.model,
|
model=self.genai_config.model,
|
||||||
contents=contents,
|
contents=contents, # type: ignore[arg-type]
|
||||||
config=types.GenerateContentConfig(
|
config=types.GenerateContentConfig(
|
||||||
**generation_config_dict,
|
**generation_config_dict,
|
||||||
),
|
),
|
||||||
@ -78,6 +79,8 @@ class GeminiClient(GenAIClient):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if response.text is None:
|
||||||
|
return None
|
||||||
description = response.text.strip()
|
description = response.text.strip()
|
||||||
except (ValueError, AttributeError):
|
except (ValueError, AttributeError):
|
||||||
# No description was generated
|
# No description was generated
|
||||||
@ -102,7 +105,7 @@ class GeminiClient(GenAIClient):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Convert messages to Gemini format
|
# Convert messages to Gemini format
|
||||||
gemini_messages = []
|
gemini_messages: list[types.Content] = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
role = msg.get("role", "user")
|
role = msg.get("role", "user")
|
||||||
content = msg.get("content", "")
|
content = msg.get("content", "")
|
||||||
@ -110,7 +113,11 @@ class GeminiClient(GenAIClient):
|
|||||||
# Map roles to Gemini format
|
# Map roles to Gemini format
|
||||||
if role == "system":
|
if role == "system":
|
||||||
# Gemini doesn't have system role, prepend to first user message
|
# Gemini doesn't have system role, prepend to first user message
|
||||||
if gemini_messages and gemini_messages[0].role == "user":
|
if (
|
||||||
|
gemini_messages
|
||||||
|
and gemini_messages[0].role == "user"
|
||||||
|
and gemini_messages[0].parts
|
||||||
|
):
|
||||||
gemini_messages[0].parts[
|
gemini_messages[0].parts[
|
||||||
0
|
0
|
||||||
].text = f"{content}\n\n{gemini_messages[0].parts[0].text}"
|
].text = f"{content}\n\n{gemini_messages[0].parts[0].text}"
|
||||||
@ -136,7 +143,7 @@ class GeminiClient(GenAIClient):
|
|||||||
types.Content(
|
types.Content(
|
||||||
role="function",
|
role="function",
|
||||||
parts=[
|
parts=[
|
||||||
types.Part.from_function_response(function_response)
|
types.Part.from_function_response(function_response) # type: ignore[misc,call-arg,arg-type]
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -171,19 +178,25 @@ class GeminiClient(GenAIClient):
|
|||||||
if tool_choice:
|
if tool_choice:
|
||||||
if tool_choice == "none":
|
if tool_choice == "none":
|
||||||
tool_config = types.ToolConfig(
|
tool_config = types.ToolConfig(
|
||||||
function_calling_config=types.FunctionCallingConfig(mode="NONE")
|
function_calling_config=types.FunctionCallingConfig(
|
||||||
|
mode=FunctionCallingConfigMode.NONE
|
||||||
|
)
|
||||||
)
|
)
|
||||||
elif tool_choice == "auto":
|
elif tool_choice == "auto":
|
||||||
tool_config = types.ToolConfig(
|
tool_config = types.ToolConfig(
|
||||||
function_calling_config=types.FunctionCallingConfig(mode="AUTO")
|
function_calling_config=types.FunctionCallingConfig(
|
||||||
|
mode=FunctionCallingConfigMode.AUTO
|
||||||
|
)
|
||||||
)
|
)
|
||||||
elif tool_choice == "required":
|
elif tool_choice == "required":
|
||||||
tool_config = types.ToolConfig(
|
tool_config = types.ToolConfig(
|
||||||
function_calling_config=types.FunctionCallingConfig(mode="ANY")
|
function_calling_config=types.FunctionCallingConfig(
|
||||||
|
mode=FunctionCallingConfigMode.ANY
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build request config
|
# Build request config
|
||||||
config_params = {"candidate_count": 1}
|
config_params: dict[str, Any] = {"candidate_count": 1}
|
||||||
|
|
||||||
if gemini_tools:
|
if gemini_tools:
|
||||||
config_params["tools"] = gemini_tools
|
config_params["tools"] = gemini_tools
|
||||||
@ -197,7 +210,7 @@ class GeminiClient(GenAIClient):
|
|||||||
|
|
||||||
response = self.provider.models.generate_content(
|
response = self.provider.models.generate_content(
|
||||||
model=self.genai_config.model,
|
model=self.genai_config.model,
|
||||||
contents=gemini_messages,
|
contents=gemini_messages, # type: ignore[arg-type]
|
||||||
config=types.GenerateContentConfig(**config_params),
|
config=types.GenerateContentConfig(**config_params),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -291,7 +304,7 @@ class GeminiClient(GenAIClient):
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: Optional[list[dict[str, Any]]] = None,
|
tools: Optional[list[dict[str, Any]]] = None,
|
||||||
tool_choice: Optional[str] = "auto",
|
tool_choice: Optional[str] = "auto",
|
||||||
):
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
"""
|
"""
|
||||||
Stream chat with tools; yields content deltas then final message.
|
Stream chat with tools; yields content deltas then final message.
|
||||||
|
|
||||||
@ -299,7 +312,7 @@ class GeminiClient(GenAIClient):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Convert messages to Gemini format
|
# Convert messages to Gemini format
|
||||||
gemini_messages = []
|
gemini_messages: list[types.Content] = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
role = msg.get("role", "user")
|
role = msg.get("role", "user")
|
||||||
content = msg.get("content", "")
|
content = msg.get("content", "")
|
||||||
@ -307,7 +320,11 @@ class GeminiClient(GenAIClient):
|
|||||||
# Map roles to Gemini format
|
# Map roles to Gemini format
|
||||||
if role == "system":
|
if role == "system":
|
||||||
# Gemini doesn't have system role, prepend to first user message
|
# Gemini doesn't have system role, prepend to first user message
|
||||||
if gemini_messages and gemini_messages[0].role == "user":
|
if (
|
||||||
|
gemini_messages
|
||||||
|
and gemini_messages[0].role == "user"
|
||||||
|
and gemini_messages[0].parts
|
||||||
|
):
|
||||||
gemini_messages[0].parts[
|
gemini_messages[0].parts[
|
||||||
0
|
0
|
||||||
].text = f"{content}\n\n{gemini_messages[0].parts[0].text}"
|
].text = f"{content}\n\n{gemini_messages[0].parts[0].text}"
|
||||||
@ -333,7 +350,7 @@ class GeminiClient(GenAIClient):
|
|||||||
types.Content(
|
types.Content(
|
||||||
role="function",
|
role="function",
|
||||||
parts=[
|
parts=[
|
||||||
types.Part.from_function_response(function_response)
|
types.Part.from_function_response(function_response) # type: ignore[misc,call-arg,arg-type]
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -368,19 +385,25 @@ class GeminiClient(GenAIClient):
|
|||||||
if tool_choice:
|
if tool_choice:
|
||||||
if tool_choice == "none":
|
if tool_choice == "none":
|
||||||
tool_config = types.ToolConfig(
|
tool_config = types.ToolConfig(
|
||||||
function_calling_config=types.FunctionCallingConfig(mode="NONE")
|
function_calling_config=types.FunctionCallingConfig(
|
||||||
|
mode=FunctionCallingConfigMode.NONE
|
||||||
|
)
|
||||||
)
|
)
|
||||||
elif tool_choice == "auto":
|
elif tool_choice == "auto":
|
||||||
tool_config = types.ToolConfig(
|
tool_config = types.ToolConfig(
|
||||||
function_calling_config=types.FunctionCallingConfig(mode="AUTO")
|
function_calling_config=types.FunctionCallingConfig(
|
||||||
|
mode=FunctionCallingConfigMode.AUTO
|
||||||
|
)
|
||||||
)
|
)
|
||||||
elif tool_choice == "required":
|
elif tool_choice == "required":
|
||||||
tool_config = types.ToolConfig(
|
tool_config = types.ToolConfig(
|
||||||
function_calling_config=types.FunctionCallingConfig(mode="ANY")
|
function_calling_config=types.FunctionCallingConfig(
|
||||||
|
mode=FunctionCallingConfigMode.ANY
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build request config
|
# Build request config
|
||||||
config_params = {"candidate_count": 1}
|
config_params: dict[str, Any] = {"candidate_count": 1}
|
||||||
|
|
||||||
if gemini_tools:
|
if gemini_tools:
|
||||||
config_params["tools"] = gemini_tools
|
config_params["tools"] = gemini_tools
|
||||||
@ -399,7 +422,7 @@ class GeminiClient(GenAIClient):
|
|||||||
|
|
||||||
stream = await self.provider.aio.models.generate_content_stream(
|
stream = await self.provider.aio.models.generate_content_stream(
|
||||||
model=self.genai_config.model,
|
model=self.genai_config.model,
|
||||||
contents=gemini_messages,
|
contents=gemini_messages, # type: ignore[arg-type]
|
||||||
config=types.GenerateContentConfig(**config_params),
|
config=types.GenerateContentConfig(**config_params),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import base64
|
|||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import Any, AsyncGenerator, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -23,7 +23,7 @@ def _to_jpeg(img_bytes: bytes) -> bytes | None:
|
|||||||
try:
|
try:
|
||||||
img = Image.open(io.BytesIO(img_bytes))
|
img = Image.open(io.BytesIO(img_bytes))
|
||||||
if img.mode != "RGB":
|
if img.mode != "RGB":
|
||||||
img = img.convert("RGB")
|
img = img.convert("RGB") # type: ignore[assignment]
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
img.save(buf, format="JPEG", quality=85)
|
img.save(buf, format="JPEG", quality=85)
|
||||||
return buf.getvalue()
|
return buf.getvalue()
|
||||||
@ -36,10 +36,10 @@ def _to_jpeg(img_bytes: bytes) -> bytes | None:
|
|||||||
class LlamaCppClient(GenAIClient):
|
class LlamaCppClient(GenAIClient):
|
||||||
"""Generative AI client for Frigate using llama.cpp server."""
|
"""Generative AI client for Frigate using llama.cpp server."""
|
||||||
|
|
||||||
provider: str # base_url
|
provider: str | None # base_url
|
||||||
provider_options: dict[str, Any]
|
provider_options: dict[str, Any]
|
||||||
|
|
||||||
def _init_provider(self):
|
def _init_provider(self) -> str | None:
|
||||||
"""Initialize the client."""
|
"""Initialize the client."""
|
||||||
self.provider_options = {
|
self.provider_options = {
|
||||||
**self.genai_config.provider_options,
|
**self.genai_config.provider_options,
|
||||||
@ -75,7 +75,7 @@ class LlamaCppClient(GenAIClient):
|
|||||||
content.append(
|
content.append(
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": { # type: ignore[dict-item]
|
||||||
"url": f"data:image/jpeg;base64,{encoded_image}",
|
"url": f"data:image/jpeg;base64,{encoded_image}",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -111,7 +111,7 @@ class LlamaCppClient(GenAIClient):
|
|||||||
):
|
):
|
||||||
choice = result["choices"][0]
|
choice = result["choices"][0]
|
||||||
if "message" in choice and "content" in choice["message"]:
|
if "message" in choice and "content" in choice["message"]:
|
||||||
return choice["message"]["content"].strip()
|
return str(choice["message"]["content"].strip())
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("llama.cpp returned an error: %s", str(e))
|
logger.warning("llama.cpp returned an error: %s", str(e))
|
||||||
@ -229,7 +229,7 @@ class LlamaCppClient(GenAIClient):
|
|||||||
content.append(
|
content.append(
|
||||||
{
|
{
|
||||||
"prompt_string": "<__media__>\n",
|
"prompt_string": "<__media__>\n",
|
||||||
"multimodal_data": [encoded],
|
"multimodal_data": [encoded], # type: ignore[dict-item]
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -367,7 +367,7 @@ class LlamaCppClient(GenAIClient):
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: Optional[list[dict[str, Any]]] = None,
|
tools: Optional[list[dict[str, Any]]] = None,
|
||||||
tool_choice: Optional[str] = "auto",
|
tool_choice: Optional[str] = "auto",
|
||||||
):
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
"""Stream chat with tools via OpenAI-compatible streaming API."""
|
"""Stream chat with tools via OpenAI-compatible streaming API."""
|
||||||
if self.provider is None:
|
if self.provider is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import Any, AsyncGenerator, Optional
|
||||||
|
|
||||||
from httpx import RemoteProtocolError, TimeoutException
|
from httpx import RemoteProtocolError, TimeoutException
|
||||||
from ollama import AsyncClient as OllamaAsyncClient
|
from ollama import AsyncClient as OllamaAsyncClient
|
||||||
@ -28,10 +28,10 @@ class OllamaClient(GenAIClient):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
provider: ApiClient
|
provider: ApiClient | None
|
||||||
provider_options: dict[str, Any]
|
provider_options: dict[str, Any]
|
||||||
|
|
||||||
def _init_provider(self):
|
def _init_provider(self) -> ApiClient | None:
|
||||||
"""Initialize the client."""
|
"""Initialize the client."""
|
||||||
self.provider_options = {
|
self.provider_options = {
|
||||||
**self.LOCAL_OPTIMIZED_OPTIONS,
|
**self.LOCAL_OPTIMIZED_OPTIONS,
|
||||||
@ -73,7 +73,7 @@ class OllamaClient(GenAIClient):
|
|||||||
"exclusiveMinimum",
|
"exclusiveMinimum",
|
||||||
"exclusiveMaximum",
|
"exclusiveMaximum",
|
||||||
}
|
}
|
||||||
result = {}
|
result: dict[str, Any] = {}
|
||||||
for key, value in schema.items():
|
for key, value in schema.items():
|
||||||
if not _is_properties and key in STRIP_KEYS:
|
if not _is_properties and key in STRIP_KEYS:
|
||||||
continue
|
continue
|
||||||
@ -122,7 +122,7 @@ class OllamaClient(GenAIClient):
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"Ollama tokens used: eval_count={result.get('eval_count')}, prompt_eval_count={result.get('prompt_eval_count')}"
|
f"Ollama tokens used: eval_count={result.get('eval_count')}, prompt_eval_count={result.get('prompt_eval_count')}"
|
||||||
)
|
)
|
||||||
return result["response"].strip()
|
return str(result["response"]).strip()
|
||||||
except (
|
except (
|
||||||
TimeoutException,
|
TimeoutException,
|
||||||
ResponseError,
|
ResponseError,
|
||||||
@ -263,7 +263,7 @@ class OllamaClient(GenAIClient):
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: Optional[list[dict[str, Any]]] = None,
|
tools: Optional[list[dict[str, Any]]] = None,
|
||||||
tool_choice: Optional[str] = "auto",
|
tool_choice: Optional[str] = "auto",
|
||||||
):
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
"""Stream chat with tools; yields content deltas then final message.
|
"""Stream chat with tools; yields content deltas then final message.
|
||||||
|
|
||||||
When tools are provided, Ollama streaming does not include tool_calls
|
When tools are provided, Ollama streaming does not include tool_calls
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import Any, AsyncGenerator, Optional
|
||||||
|
|
||||||
from httpx import TimeoutException
|
from httpx import TimeoutException
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
@ -21,7 +21,7 @@ class OpenAIClient(GenAIClient):
|
|||||||
provider: OpenAI
|
provider: OpenAI
|
||||||
context_size: Optional[int] = None
|
context_size: Optional[int] = None
|
||||||
|
|
||||||
def _init_provider(self):
|
def _init_provider(self) -> OpenAI:
|
||||||
"""Initialize the client."""
|
"""Initialize the client."""
|
||||||
# Extract context_size from provider_options as it's not a valid OpenAI client parameter
|
# Extract context_size from provider_options as it's not a valid OpenAI client parameter
|
||||||
# It will be used in get_context_size() instead
|
# It will be used in get_context_size() instead
|
||||||
@ -81,7 +81,7 @@ class OpenAIClient(GenAIClient):
|
|||||||
and hasattr(result, "choices")
|
and hasattr(result, "choices")
|
||||||
and len(result.choices) > 0
|
and len(result.choices) > 0
|
||||||
):
|
):
|
||||||
return result.choices[0].message.content.strip()
|
return str(result.choices[0].message.content.strip())
|
||||||
return None
|
return None
|
||||||
except (TimeoutException, Exception) as e:
|
except (TimeoutException, Exception) as e:
|
||||||
logger.warning("OpenAI returned an error: %s", str(e))
|
logger.warning("OpenAI returned an error: %s", str(e))
|
||||||
@ -171,7 +171,7 @@ class OpenAIClient(GenAIClient):
|
|||||||
}
|
}
|
||||||
request_params.update(provider_opts)
|
request_params.update(provider_opts)
|
||||||
|
|
||||||
result = self.provider.chat.completions.create(**request_params)
|
result = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
|
||||||
|
|
||||||
if (
|
if (
|
||||||
result is None
|
result is None
|
||||||
@ -245,7 +245,7 @@ class OpenAIClient(GenAIClient):
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: Optional[list[dict[str, Any]]] = None,
|
tools: Optional[list[dict[str, Any]]] = None,
|
||||||
tool_choice: Optional[str] = "auto",
|
tool_choice: Optional[str] = "auto",
|
||||||
):
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
"""
|
"""
|
||||||
Stream chat with tools; yields content deltas then final message.
|
Stream chat with tools; yields content deltas then final message.
|
||||||
|
|
||||||
@ -287,7 +287,7 @@ class OpenAIClient(GenAIClient):
|
|||||||
tool_calls_by_index: dict[int, dict[str, Any]] = {}
|
tool_calls_by_index: dict[int, dict[str, Any]] = {}
|
||||||
finish_reason = "stop"
|
finish_reason = "stop"
|
||||||
|
|
||||||
stream = self.provider.chat.completions.create(**request_params)
|
stream = self.provider.chat.completions.create(**request_params) # type: ignore[call-overload]
|
||||||
|
|
||||||
for chunk in stream:
|
for chunk in stream:
|
||||||
if not chunk or not chunk.choices:
|
if not chunk or not chunk.choices:
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import os
|
|||||||
import threading
|
import threading
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional, cast
|
||||||
|
|
||||||
from frigate.comms.inter_process import InterProcessRequestor
|
from frigate.comms.inter_process import InterProcessRequestor
|
||||||
from frigate.const import CONFIG_DIR, UPDATE_JOB_STATE
|
from frigate.const import CONFIG_DIR, UPDATE_JOB_STATE
|
||||||
@ -122,7 +122,7 @@ def start_media_sync_job(
|
|||||||
if job_is_running("media_sync"):
|
if job_is_running("media_sync"):
|
||||||
current = get_current_job("media_sync")
|
current = get_current_job("media_sync")
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Media sync job {current.id} is already running. Rejecting new request."
|
f"Media sync job {current.id if current else 'unknown'} is already running. Rejecting new request."
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -146,9 +146,9 @@ def start_media_sync_job(
|
|||||||
|
|
||||||
def get_current_media_sync_job() -> Optional[MediaSyncJob]:
|
def get_current_media_sync_job() -> Optional[MediaSyncJob]:
|
||||||
"""Get the current running/queued media sync job, if any."""
|
"""Get the current running/queued media sync job, if any."""
|
||||||
return get_current_job("media_sync")
|
return cast(Optional[MediaSyncJob], get_current_job("media_sync"))
|
||||||
|
|
||||||
|
|
||||||
def get_media_sync_job_by_id(job_id: str) -> Optional[MediaSyncJob]:
|
def get_media_sync_job_by_id(job_id: str) -> Optional[MediaSyncJob]:
|
||||||
"""Get media sync job by ID. Currently only tracks the current job."""
|
"""Get media sync job by ID. Currently only tracks the current job."""
|
||||||
return get_job_by_id("media_sync", job_id)
|
return cast(Optional[MediaSyncJob], get_job_by_id("media_sync", job_id))
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import threading
|
|||||||
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -96,7 +96,7 @@ def create_polygon_mask(
|
|||||||
dtype=np.int32,
|
dtype=np.int32,
|
||||||
)
|
)
|
||||||
mask = np.zeros((frame_height, frame_width), dtype=np.uint8)
|
mask = np.zeros((frame_height, frame_width), dtype=np.uint8)
|
||||||
cv2.fillPoly(mask, [motion_points], 255)
|
cv2.fillPoly(mask, [motion_points], (255,))
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
@ -116,7 +116,7 @@ def compute_roi_bbox_normalized(
|
|||||||
|
|
||||||
|
|
||||||
def heatmap_overlaps_roi(
|
def heatmap_overlaps_roi(
|
||||||
heatmap: dict[str, int], roi_bbox: tuple[float, float, float, float]
|
heatmap: object, roi_bbox: tuple[float, float, float, float]
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if a sparse motion heatmap has any overlap with the ROI bounding box.
|
"""Check if a sparse motion heatmap has any overlap with the ROI bounding box.
|
||||||
|
|
||||||
@ -155,9 +155,9 @@ def segment_passes_activity_gate(recording: Recordings) -> bool:
|
|||||||
Returns True if any of motion, objects, or regions is non-zero/non-null.
|
Returns True if any of motion, objects, or regions is non-zero/non-null.
|
||||||
Returns True if all are null (old segments without data).
|
Returns True if all are null (old segments without data).
|
||||||
"""
|
"""
|
||||||
motion = recording.motion
|
motion: Any = recording.motion
|
||||||
objects = recording.objects
|
objects: Any = recording.objects
|
||||||
regions = recording.regions
|
regions: Any = recording.regions
|
||||||
|
|
||||||
# Old segments without metadata - pass through (conservative)
|
# Old segments without metadata - pass through (conservative)
|
||||||
if motion is None and objects is None and regions is None:
|
if motion is None and objects is None and regions is None:
|
||||||
@ -278,6 +278,9 @@ class MotionSearchRunner(threading.Thread):
|
|||||||
frame_width = camera_config.detect.width
|
frame_width = camera_config.detect.width
|
||||||
frame_height = camera_config.detect.height
|
frame_height = camera_config.detect.height
|
||||||
|
|
||||||
|
if frame_width is None or frame_height is None:
|
||||||
|
raise ValueError(f"Camera {camera_name} detect dimensions not configured")
|
||||||
|
|
||||||
# Create polygon mask
|
# Create polygon mask
|
||||||
polygon_mask = create_polygon_mask(
|
polygon_mask = create_polygon_mask(
|
||||||
self.job.polygon_points, frame_width, frame_height
|
self.job.polygon_points, frame_width, frame_height
|
||||||
@ -415,11 +418,13 @@ class MotionSearchRunner(threading.Thread):
|
|||||||
if self._should_stop():
|
if self._should_stop():
|
||||||
break
|
break
|
||||||
|
|
||||||
|
rec_start: float = recording.start_time # type: ignore[assignment]
|
||||||
|
rec_end: float = recording.end_time # type: ignore[assignment]
|
||||||
future = executor.submit(
|
future = executor.submit(
|
||||||
self._process_recording_for_motion,
|
self._process_recording_for_motion,
|
||||||
recording.path,
|
str(recording.path),
|
||||||
recording.start_time,
|
rec_start,
|
||||||
recording.end_time,
|
rec_end,
|
||||||
self.job.start_time_range,
|
self.job.start_time_range,
|
||||||
self.job.end_time_range,
|
self.job.end_time_range,
|
||||||
polygon_mask,
|
polygon_mask,
|
||||||
@ -524,10 +529,12 @@ class MotionSearchRunner(threading.Thread):
|
|||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
rec_start: float = recording.start_time # type: ignore[assignment]
|
||||||
|
rec_end: float = recording.end_time # type: ignore[assignment]
|
||||||
results, frames = self._process_recording_for_motion(
|
results, frames = self._process_recording_for_motion(
|
||||||
recording.path,
|
str(recording.path),
|
||||||
recording.start_time,
|
rec_start,
|
||||||
recording.end_time,
|
rec_end,
|
||||||
self.job.start_time_range,
|
self.job.start_time_range,
|
||||||
self.job.end_time_range,
|
self.job.end_time_range,
|
||||||
polygon_mask,
|
polygon_mask,
|
||||||
@ -672,7 +679,9 @@ class MotionSearchRunner(threading.Thread):
|
|||||||
# Handle frame dimension changes
|
# Handle frame dimension changes
|
||||||
if gray.shape != polygon_mask.shape:
|
if gray.shape != polygon_mask.shape:
|
||||||
resized_mask = cv2.resize(
|
resized_mask = cv2.resize(
|
||||||
polygon_mask, (gray.shape[1], gray.shape[0]), cv2.INTER_NEAREST
|
polygon_mask,
|
||||||
|
(gray.shape[1], gray.shape[0]),
|
||||||
|
interpolation=cv2.INTER_NEAREST,
|
||||||
)
|
)
|
||||||
current_bbox = cv2.boundingRect(resized_mask)
|
current_bbox = cv2.boundingRect(resized_mask)
|
||||||
else:
|
else:
|
||||||
@ -698,7 +707,7 @@ class MotionSearchRunner(threading.Thread):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if prev_frame_gray is not None:
|
if prev_frame_gray is not None:
|
||||||
diff = cv2.absdiff(prev_frame_gray, masked_gray)
|
diff = cv2.absdiff(prev_frame_gray, masked_gray) # type: ignore[unreachable]
|
||||||
diff_blurred = cv2.GaussianBlur(diff, (3, 3), 0)
|
diff_blurred = cv2.GaussianBlur(diff, (3, 3), 0)
|
||||||
_, thresh = cv2.threshold(
|
_, thresh = cv2.threshold(
|
||||||
diff_blurred, threshold, 255, cv2.THRESH_BINARY
|
diff_blurred, threshold, 255, cv2.THRESH_BINARY
|
||||||
@ -825,7 +834,7 @@ def get_motion_search_job(job_id: str) -> Optional[MotionSearchJob]:
|
|||||||
if job_entry:
|
if job_entry:
|
||||||
return job_entry[0]
|
return job_entry[0]
|
||||||
# Check completed jobs via manager
|
# Check completed jobs via manager
|
||||||
return get_job_by_id("motion_search", job_id)
|
return cast(Optional[MotionSearchJob], get_job_by_id("motion_search", job_id))
|
||||||
|
|
||||||
|
|
||||||
def cancel_motion_search_job(job_id: str) -> bool:
|
def cancel_motion_search_job(job_id: str) -> bool:
|
||||||
|
|||||||
@ -54,9 +54,9 @@ class VLMWatchRunner(threading.Thread):
|
|||||||
job: VLMWatchJob,
|
job: VLMWatchJob,
|
||||||
config: FrigateConfig,
|
config: FrigateConfig,
|
||||||
cancel_event: threading.Event,
|
cancel_event: threading.Event,
|
||||||
frame_processor,
|
frame_processor: Any,
|
||||||
genai_manager,
|
genai_manager: Any,
|
||||||
dispatcher,
|
dispatcher: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(daemon=True, name=f"vlm_watch_{job.id}")
|
super().__init__(daemon=True, name=f"vlm_watch_{job.id}")
|
||||||
self.job = job
|
self.job = job
|
||||||
@ -226,9 +226,12 @@ class VLMWatchRunner(threading.Thread):
|
|||||||
remaining = deadline - time.time()
|
remaining = deadline - time.time()
|
||||||
if remaining <= 0:
|
if remaining <= 0:
|
||||||
break
|
break
|
||||||
topic, payload = self.detection_subscriber.check_for_update(
|
result = self.detection_subscriber.check_for_update(
|
||||||
timeout=min(1.0, remaining)
|
timeout=min(1.0, remaining)
|
||||||
)
|
)
|
||||||
|
if result is None:
|
||||||
|
continue
|
||||||
|
topic, payload = result
|
||||||
if topic is None or payload is None:
|
if topic is None or payload is None:
|
||||||
continue
|
continue
|
||||||
# payload = (camera, frame_name, frame_time, tracked_objects, motion_boxes, regions)
|
# payload = (camera, frame_name, frame_time, tracked_objects, motion_boxes, regions)
|
||||||
@ -328,9 +331,9 @@ def start_vlm_watch_job(
|
|||||||
condition: str,
|
condition: str,
|
||||||
max_duration_minutes: int,
|
max_duration_minutes: int,
|
||||||
config: FrigateConfig,
|
config: FrigateConfig,
|
||||||
frame_processor,
|
frame_processor: Any,
|
||||||
genai_manager,
|
genai_manager: Any,
|
||||||
dispatcher,
|
dispatcher: Any,
|
||||||
labels: list[str] | None = None,
|
labels: list[str] | None = None,
|
||||||
zones: list[str] | None = None,
|
zones: list[str] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|||||||
@ -13,10 +13,10 @@ class MotionDetector(ABC):
|
|||||||
frame_shape: Tuple[int, int, int],
|
frame_shape: Tuple[int, int, int],
|
||||||
config: MotionConfig,
|
config: MotionConfig,
|
||||||
fps: int,
|
fps: int,
|
||||||
improve_contrast,
|
improve_contrast: bool,
|
||||||
threshold,
|
threshold: int,
|
||||||
contour_area,
|
contour_area: int | None,
|
||||||
):
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -25,7 +25,7 @@ class MotionDetector(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def is_calibrating(self):
|
def is_calibrating(self) -> bool:
|
||||||
"""Return if motion is recalibrating."""
|
"""Return if motion is recalibrating."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -35,6 +35,6 @@ class MotionDetector(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def stop(self):
|
def stop(self) -> None:
|
||||||
"""Stop any ongoing work and processes."""
|
"""Stop any ongoing work and processes."""
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -41,6 +41,24 @@ ignore_errors = false
|
|||||||
[mypy-frigate.events]
|
[mypy-frigate.events]
|
||||||
ignore_errors = false
|
ignore_errors = false
|
||||||
|
|
||||||
|
[mypy-frigate.genai.*]
|
||||||
|
ignore_errors = false
|
||||||
|
|
||||||
|
[mypy-frigate.jobs.*]
|
||||||
|
ignore_errors = false
|
||||||
|
|
||||||
|
[mypy-frigate.motion]
|
||||||
|
ignore_errors = false
|
||||||
|
|
||||||
|
[mypy-frigate.object_detection]
|
||||||
|
ignore_errors = false
|
||||||
|
|
||||||
|
[mypy-frigate.output]
|
||||||
|
ignore_errors = false
|
||||||
|
|
||||||
|
[mypy-frigate.ptz]
|
||||||
|
ignore_errors = false
|
||||||
|
|
||||||
[mypy-frigate.log]
|
[mypy-frigate.log]
|
||||||
ignore_errors = false
|
ignore_errors = false
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user