Send model type

This commit is contained in:
Nicolas Mowen 2025-09-18 13:48:01 -06:00
parent 20c300df95
commit 161ed46c55
7 changed files with 48 additions and 12 deletions

View File

@ -6,6 +6,8 @@ from abc import ABC, abstractmethod
from typing import Any from typing import Any
import numpy as np import numpy as np
from frigate.detectors.detector_types import ModelTypeEnum
from frigate.embeddings.types import EnrichmentModelTypeEnum
import onnxruntime as ort import onnxruntime as ort
from frigate.util.model import get_ort_providers from frigate.util.model import get_ort_providers
@ -101,6 +103,15 @@ class CudaGraphRunner(BaseModelRunner):
for more complex models like CLIP or PaddleOCR. for more complex models like CLIP or PaddleOCR.
""" """
@staticmethod
def is_complex_model(model_type: str) -> bool:
    """Return True when this model family is too complex for CUDA-graph capture.

    Complex models (YOLO-NAS, PaddleOCR, Jina v1/v2) fall back to the plain
    ONNX runner instead of the CudaGraphRunner fast path.
    """
    complex_types = {
        ModelTypeEnum.yolonas.value,
        EnrichmentModelTypeEnum.paddleocr.value,
        EnrichmentModelTypeEnum.jina_v1.value,
        EnrichmentModelTypeEnum.jina_v2.value,
    }
    return model_type in complex_types
def __init__(self, session: ort.InferenceSession, cuda_device_id: int): def __init__(self, session: ort.InferenceSession, cuda_device_id: int):
self._session = session self._session = session
self._cuda_device_id = cuda_device_id self._cuda_device_id = cuda_device_id
@ -156,10 +167,14 @@ class CudaGraphRunner(BaseModelRunner):
class OpenVINOModelRunner(BaseModelRunner): class OpenVINOModelRunner(BaseModelRunner):
"""OpenVINO model runner that handles inference efficiently.""" """OpenVINO model runner that handles inference efficiently."""
def __init__(self, model_path: str, device: str, complex_model: bool, **kwargs): @staticmethod
def is_complex_model(model_type: str) -> bool:
    """Return True when OpenVINO must treat this model as complex (PaddleOCR only)."""
    return model_type == EnrichmentModelTypeEnum.paddleocr.value
def __init__(self, model_path: str, device: str, model_type: str, **kwargs):
self.model_path = model_path self.model_path = model_path
self.device = device self.device = device
self.complex_model = complex_model self.complex_model = OpenVINOModelRunner.is_complex_model(model_type)
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
raise FileNotFoundError(f"OpenVINO model file {model_path} not found.") raise FileNotFoundError(f"OpenVINO model file {model_path} not found.")
@ -381,7 +396,7 @@ class RKNNModelRunner(BaseModelRunner):
def get_optimized_runner( def get_optimized_runner(
model_path: str, device: str | None, complex_model: bool = True, **kwargs model_path: str, device: str | None, model_type: str, **kwargs
) -> BaseModelRunner: ) -> BaseModelRunner:
"""Get an optimized runner for the hardware.""" """Get an optimized runner for the hardware."""
device = device or "AUTO" device = device or "AUTO"
@ -398,7 +413,7 @@ def get_optimized_runner(
# In other images we will get CUDA / ROCm which are preferred over OpenVINO # In other images we will get CUDA / ROCm which are preferred over OpenVINO
# There is currently no way to prioritize OpenVINO over CUDA / ROCm in these images # There is currently no way to prioritize OpenVINO over CUDA / ROCm in these images
if device != "CPU" and is_openvino_gpu_npu_available(): if device != "CPU" and is_openvino_gpu_npu_available():
return OpenVINOModelRunner(model_path, device, complex_model, **kwargs) return OpenVINOModelRunner(model_path, device, model_type, **kwargs)
ortSession = ort.InferenceSession( ortSession = ort.InferenceSession(
model_path, model_path,
@ -406,7 +421,10 @@ def get_optimized_runner(
provider_options=options, provider_options=options,
) )
if not complex_model and providers[0] == "CUDAExecutionProvider": if (
not CudaGraphRunner.is_complex_model(model_type)
and providers[0] == "CUDAExecutionProvider"
):
return CudaGraphRunner(ortSession, options[0]["device_id"]) return CudaGraphRunner(ortSession, options[0]["device_id"])
return ONNXModelRunner(ortSession) return ONNXModelRunner(ortSession)

View File

@ -18,11 +18,6 @@ from frigate.util.downloader import ModelDownloader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EmbeddingTypeEnum(str, Enum):
thumbnail = "thumbnail"
description = "description"
class BaseEmbedding(ABC): class BaseEmbedding(ABC):
"""Base embedding class.""" """Base embedding class."""

View File

@ -7,6 +7,7 @@ import numpy as np
from frigate.const import MODEL_CACHE_DIR from frigate.const import MODEL_CACHE_DIR
from frigate.detectors.detection_runners import get_optimized_runner from frigate.detectors.detection_runners import get_optimized_runner
from frigate.embeddings.types import EnrichmentModelTypeEnum
from frigate.log import redirect_output_to_logger from frigate.log import redirect_output_to_logger
from frigate.util.downloader import ModelDownloader from frigate.util.downloader import ModelDownloader
@ -151,7 +152,7 @@ class ArcfaceEmbedding(BaseEmbedding):
self.runner = get_optimized_runner( self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file), os.path.join(self.download_path, self.model_file),
device=self.config.device or "GPU", device=self.config.device or "GPU",
complex_model=False, model_type=EnrichmentModelTypeEnum.arcface.value,
) )
def _preprocess_inputs(self, raw_inputs): def _preprocess_inputs(self, raw_inputs):

View File

@ -7,6 +7,7 @@ import warnings
# importing this without pytorch or others causes a warning # importing this without pytorch or others causes a warning
# https://github.com/huggingface/transformers/issues/27214 # https://github.com/huggingface/transformers/issues/27214
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1 # suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
from frigate.embeddings.types import EnrichmentModelTypeEnum
from transformers import AutoFeatureExtractor, AutoTokenizer from transformers import AutoFeatureExtractor, AutoTokenizer
from transformers.utils.logging import disable_progress_bar from transformers.utils.logging import disable_progress_bar
@ -128,6 +129,7 @@ class JinaV1TextEmbedding(BaseEmbedding):
self.runner = get_optimized_runner( self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file), os.path.join(self.download_path, self.model_file),
self.device, self.device,
model_type=EnrichmentModelTypeEnum.jina_v1.value,
) )
def _preprocess_inputs(self, raw_inputs): def _preprocess_inputs(self, raw_inputs):
@ -206,6 +208,7 @@ class JinaV1ImageEmbedding(BaseEmbedding):
self.runner = get_optimized_runner( self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file), os.path.join(self.download_path, self.model_file),
self.device, self.device,
model_type=EnrichmentModelTypeEnum.jina_v1.value,
) )
def _preprocess_inputs(self, raw_inputs): def _preprocess_inputs(self, raw_inputs):

View File

@ -6,6 +6,7 @@ import os
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from frigate.embeddings.types import EnrichmentModelTypeEnum
from transformers import AutoTokenizer from transformers import AutoTokenizer
from transformers.utils.logging import disable_progress_bar, set_verbosity_error from transformers.utils.logging import disable_progress_bar, set_verbosity_error
@ -128,6 +129,7 @@ class JinaV2Embedding(BaseEmbedding):
self.runner = get_optimized_runner( self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file), os.path.join(self.download_path, self.model_file),
self.device, self.device,
model_type=EnrichmentModelTypeEnum.jina_v2.value,
) )
def _preprocess_image(self, image_data: bytes | Image.Image) -> np.ndarray: def _preprocess_image(self, image_data: bytes | Image.Image) -> np.ndarray:

View File

@ -8,6 +8,7 @@ import numpy as np
from frigate.comms.inter_process import InterProcessRequestor from frigate.comms.inter_process import InterProcessRequestor
from frigate.const import MODEL_CACHE_DIR from frigate.const import MODEL_CACHE_DIR
from frigate.detectors.detection_runners import BaseModelRunner, get_optimized_runner from frigate.detectors.detection_runners import BaseModelRunner, get_optimized_runner
from frigate.embeddings.types import EnrichmentModelTypeEnum
from frigate.types import ModelStatusTypesEnum from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader from frigate.util.downloader import ModelDownloader
@ -79,6 +80,7 @@ class PaddleOCRDetection(BaseEmbedding):
self.runner = get_optimized_runner( self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file), os.path.join(self.download_path, self.model_file),
self.device, self.device,
model_type=EnrichmentModelTypeEnum.paddleocr.value,
) )
def _preprocess_inputs(self, raw_inputs): def _preprocess_inputs(self, raw_inputs):
@ -138,6 +140,7 @@ class PaddleOCRClassification(BaseEmbedding):
self.runner = get_optimized_runner( self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file), os.path.join(self.download_path, self.model_file),
self.device, self.device,
model_type=EnrichmentModelTypeEnum.paddleocr.value,
) )
def _preprocess_inputs(self, raw_inputs): def _preprocess_inputs(self, raw_inputs):
@ -198,6 +201,7 @@ class PaddleOCRRecognition(BaseEmbedding):
self.runner = get_optimized_runner( self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file), os.path.join(self.download_path, self.model_file),
self.device, self.device,
model_type=EnrichmentModelTypeEnum.paddleocr.value,
) )
def _preprocess_inputs(self, raw_inputs): def _preprocess_inputs(self, raw_inputs):
@ -258,7 +262,7 @@ class LicensePlateDetector(BaseEmbedding):
self.runner = get_optimized_runner( self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file), os.path.join(self.download_path, self.model_file),
self.device, self.device,
complex_model=False, model_type="yolov9",
) )
def _preprocess_inputs(self, raw_inputs): def _preprocess_inputs(self, raw_inputs):

View File

@ -0,0 +1,13 @@
from enum import Enum
# Kinds of embeddings stored for tracked objects.
# Declared with the functional Enum API; `type=str` reproduces the
# `class EmbeddingTypeEnum(str, Enum)` mixin behavior exactly.
EmbeddingTypeEnum = Enum(
    "EmbeddingTypeEnum",
    [("thumbnail", "thumbnail"), ("description", "description")],
    type=str,
)
# Identifiers for the enrichment model families (face, embedding, OCR).
# Declared with the functional Enum API; each member's value equals its
# name, and `type=str` matches the original `(str, Enum)` mixin.
EnrichmentModelTypeEnum = Enum(
    "EnrichmentModelTypeEnum",
    {name: name for name in ("arcface", "facenet", "jina_v1", "jina_v2", "paddleocr")},
    type=str,
)