Send model type

This commit is contained in:
Nicolas Mowen 2025-09-18 13:48:01 -06:00
parent 20c300df95
commit 161ed46c55
7 changed files with 48 additions and 12 deletions

View File

@ -6,6 +6,8 @@ from abc import ABC, abstractmethod
from typing import Any
import numpy as np
from frigate.detectors.detector_types import ModelTypeEnum
from frigate.embeddings.types import EnrichmentModelTypeEnum
import onnxruntime as ort
from frigate.util.model import get_ort_providers
@ -101,6 +103,15 @@ class CudaGraphRunner(BaseModelRunner):
for more complex models like CLIP or PaddleOCR.
"""
@staticmethod
def is_complex_model(model_type: str) -> bool:
    """Return whether ``model_type`` is listed as a complex model.

    Args:
        model_type: Model type identifier; it is compared against the
            ``.value`` of the enum members listed below.

    Returns:
        True when ``model_type`` is yolonas, paddleocr, jina_v1 or jina_v2,
        False for anything else.
    """
    # A tuple is enough here: the collection is fixed and only used for
    # a membership test.
    complex_model_types = (
        ModelTypeEnum.yolonas.value,
        EnrichmentModelTypeEnum.paddleocr.value,
        EnrichmentModelTypeEnum.jina_v1.value,
        EnrichmentModelTypeEnum.jina_v2.value,
    )
    return model_type in complex_model_types
def __init__(self, session: ort.InferenceSession, cuda_device_id: int):
self._session = session
self._cuda_device_id = cuda_device_id
@ -156,10 +167,14 @@ class CudaGraphRunner(BaseModelRunner):
class OpenVINOModelRunner(BaseModelRunner):
"""OpenVINO model runner that handles inference efficiently."""
def __init__(self, model_path: str, device: str, complex_model: bool, **kwargs):
@staticmethod
def is_complex_model(model_type: str) -> bool:
    """Return whether ``model_type`` is listed as a complex model.

    Args:
        model_type: Model type identifier; it is compared against the
            ``.value`` of the enum members listed below.

    Returns:
        True only when ``model_type`` is paddleocr, False otherwise.
    """
    # Kept as a collection (rather than ==) so further model types can be
    # appended without changing the shape of the check.
    complex_model_types = (EnrichmentModelTypeEnum.paddleocr.value,)
    return model_type in complex_model_types
def __init__(self, model_path: str, device: str, model_type: str, **kwargs):
self.model_path = model_path
self.device = device
self.complex_model = complex_model
self.complex_model = OpenVINOModelRunner.is_complex_model(model_type)
if not os.path.isfile(model_path):
raise FileNotFoundError(f"OpenVINO model file {model_path} not found.")
@ -381,7 +396,7 @@ class RKNNModelRunner(BaseModelRunner):
def get_optimized_runner(
model_path: str, device: str | None, complex_model: bool = True, **kwargs
model_path: str, device: str | None, model_type: str, **kwargs
) -> BaseModelRunner:
"""Get an optimized runner for the hardware."""
device = device or "AUTO"
@ -398,7 +413,7 @@ def get_optimized_runner(
# In other images we will get CUDA / ROCm which are preferred over OpenVINO
# There is currently no way to prioritize OpenVINO over CUDA / ROCm in these images
if device != "CPU" and is_openvino_gpu_npu_available():
return OpenVINOModelRunner(model_path, device, complex_model, **kwargs)
return OpenVINOModelRunner(model_path, device, model_type, **kwargs)
ortSession = ort.InferenceSession(
model_path,
@ -406,7 +421,10 @@ def get_optimized_runner(
provider_options=options,
)
if not complex_model and providers[0] == "CUDAExecutionProvider":
if (
not CudaGraphRunner.is_complex_model(model_type)
and providers[0] == "CUDAExecutionProvider"
):
return CudaGraphRunner(ortSession, options[0]["device_id"])
return ONNXModelRunner(ortSession)

View File

@ -18,11 +18,6 @@ from frigate.util.downloader import ModelDownloader
logger = logging.getLogger(__name__)
class EmbeddingTypeEnum(str, Enum):
    """String-valued enumeration of embedding types.

    Members also subclass ``str``, so each one compares equal to its
    plain string value.
    """

    thumbnail = "thumbnail"
    description = "description"
class BaseEmbedding(ABC):
"""Base embedding class."""

View File

@ -7,6 +7,7 @@ import numpy as np
from frigate.const import MODEL_CACHE_DIR
from frigate.detectors.detection_runners import get_optimized_runner
from frigate.embeddings.types import EnrichmentModelTypeEnum
from frigate.log import redirect_output_to_logger
from frigate.util.downloader import ModelDownloader
@ -151,7 +152,7 @@ class ArcfaceEmbedding(BaseEmbedding):
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
device=self.config.device or "GPU",
complex_model=False,
model_type=EnrichmentModelTypeEnum.arcface.value,
)
def _preprocess_inputs(self, raw_inputs):

View File

@ -7,6 +7,7 @@ import warnings
# importing this without pytorch or others causes a warning
# https://github.com/huggingface/transformers/issues/27214
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
from frigate.embeddings.types import EnrichmentModelTypeEnum
from transformers import AutoFeatureExtractor, AutoTokenizer
from transformers.utils.logging import disable_progress_bar
@ -128,6 +129,7 @@ class JinaV1TextEmbedding(BaseEmbedding):
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
self.device,
model_type=EnrichmentModelTypeEnum.jina_v1.value,
)
def _preprocess_inputs(self, raw_inputs):
@ -206,6 +208,7 @@ class JinaV1ImageEmbedding(BaseEmbedding):
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
self.device,
model_type=EnrichmentModelTypeEnum.jina_v1.value,
)
def _preprocess_inputs(self, raw_inputs):

View File

@ -6,6 +6,7 @@ import os
import numpy as np
from PIL import Image
from frigate.embeddings.types import EnrichmentModelTypeEnum
from transformers import AutoTokenizer
from transformers.utils.logging import disable_progress_bar, set_verbosity_error
@ -128,6 +129,7 @@ class JinaV2Embedding(BaseEmbedding):
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
self.device,
model_type=EnrichmentModelTypeEnum.jina_v2.value,
)
def _preprocess_image(self, image_data: bytes | Image.Image) -> np.ndarray:

View File

@ -8,6 +8,7 @@ import numpy as np
from frigate.comms.inter_process import InterProcessRequestor
from frigate.const import MODEL_CACHE_DIR
from frigate.detectors.detection_runners import BaseModelRunner, get_optimized_runner
from frigate.embeddings.types import EnrichmentModelTypeEnum
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader
@ -79,6 +80,7 @@ class PaddleOCRDetection(BaseEmbedding):
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
self.device,
model_type=EnrichmentModelTypeEnum.paddleocr.value,
)
def _preprocess_inputs(self, raw_inputs):
@ -138,6 +140,7 @@ class PaddleOCRClassification(BaseEmbedding):
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
self.device,
model_type=EnrichmentModelTypeEnum.paddleocr.value,
)
def _preprocess_inputs(self, raw_inputs):
@ -198,6 +201,7 @@ class PaddleOCRRecognition(BaseEmbedding):
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
self.device,
model_type=EnrichmentModelTypeEnum.paddleocr.value,
)
def _preprocess_inputs(self, raw_inputs):
@ -258,7 +262,7 @@ class LicensePlateDetector(BaseEmbedding):
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
self.device,
complex_model=False,
model_type="yolov9",
)
def _preprocess_inputs(self, raw_inputs):

View File

@ -0,0 +1,13 @@
from enum import Enum
class EmbeddingTypeEnum(str, Enum):
    """String-valued enumeration of embedding types.

    Members also subclass ``str``, so each one compares equal to its
    plain string value.
    """

    thumbnail = "thumbnail"
    description = "description"
class EnrichmentModelTypeEnum(str, Enum):
    """String-valued enumeration of enrichment model identifiers.

    Members also subclass ``str``, so each one compares equal to its
    plain string value; callers can use either the member or its
    ``.value`` where a string model type is expected.
    """

    arcface = "arcface"
    facenet = "facenet"
    jina_v1 = "jina_v1"
    jina_v2 = "jina_v2"
    paddleocr = "paddleocr"