diff --git a/frigate/detectors/detection_runners.py b/frigate/detectors/detection_runners.py index 5e7eb9e32..d2ad07c02 100644 --- a/frigate/detectors/detection_runners.py +++ b/frigate/detectors/detection_runners.py @@ -6,6 +6,8 @@ from abc import ABC, abstractmethod from typing import Any import numpy as np +from frigate.detectors.detector_types import ModelTypeEnum +from frigate.embeddings.types import EnrichmentModelTypeEnum import onnxruntime as ort from frigate.util.model import get_ort_providers @@ -101,6 +103,15 @@ class CudaGraphRunner(BaseModelRunner): for more complex models like CLIP or PaddleOCR. """ + @staticmethod + def is_complex_model(model_type: str) -> bool: + return model_type in [ + ModelTypeEnum.yolonas.value, + EnrichmentModelTypeEnum.paddleocr.value, + EnrichmentModelTypeEnum.jina_v1.value, + EnrichmentModelTypeEnum.jina_v2.value, + ] + def __init__(self, session: ort.InferenceSession, cuda_device_id: int): self._session = session self._cuda_device_id = cuda_device_id @@ -156,10 +167,14 @@ class CudaGraphRunner(BaseModelRunner): class OpenVINOModelRunner(BaseModelRunner): """OpenVINO model runner that handles inference efficiently.""" - def __init__(self, model_path: str, device: str, complex_model: bool, **kwargs): + @staticmethod + def is_complex_model(model_type: str) -> bool: + return model_type in [EnrichmentModelTypeEnum.paddleocr.value] + + def __init__(self, model_path: str, device: str, model_type: str, **kwargs): self.model_path = model_path self.device = device - self.complex_model = complex_model + self.complex_model = OpenVINOModelRunner.is_complex_model(model_type) if not os.path.isfile(model_path): raise FileNotFoundError(f"OpenVINO model file {model_path} not found.") @@ -381,7 +396,7 @@ class RKNNModelRunner(BaseModelRunner): def get_optimized_runner( - model_path: str, device: str | None, complex_model: bool = True, **kwargs + model_path: str, device: str | None, model_type: str, **kwargs ) -> BaseModelRunner: """Get an optimized runner for the hardware.""" device = device or "AUTO" @@ -398,7 +413,7 @@ def get_optimized_runner( # In other images we will get CUDA / ROCm which are preferred over OpenVINO # There is currently no way to prioritize OpenVINO over CUDA / ROCm in these images if device != "CPU" and is_openvino_gpu_npu_available(): - return OpenVINOModelRunner(model_path, device, complex_model, **kwargs) + return OpenVINOModelRunner(model_path, device, model_type, **kwargs) ortSession = ort.InferenceSession( model_path, @@ -406,7 +421,10 @@ def get_optimized_runner( provider_options=options, ) - if not complex_model and providers[0] == "CUDAExecutionProvider": + if ( + not CudaGraphRunner.is_complex_model(model_type) + and providers[0] == "CUDAExecutionProvider" + ): return CudaGraphRunner(ortSession, options[0]["device_id"]) return ONNXModelRunner(ortSession) diff --git a/frigate/embeddings/onnx/base_embedding.py b/frigate/embeddings/onnx/base_embedding.py index fcadd2852..bd15c77a8 100644 --- a/frigate/embeddings/onnx/base_embedding.py +++ b/frigate/embeddings/onnx/base_embedding.py @@ -18,11 +18,6 @@ from frigate.util.downloader import ModelDownloader logger = logging.getLogger(__name__) -class EmbeddingTypeEnum(str, Enum): - thumbnail = "thumbnail" - description = "description" - - class BaseEmbedding(ABC): """Base embedding class.""" diff --git a/frigate/embeddings/onnx/face_embedding.py b/frigate/embeddings/onnx/face_embedding.py index 4e7e142fc..77f2dbdca 100644 --- a/frigate/embeddings/onnx/face_embedding.py +++ b/frigate/embeddings/onnx/face_embedding.py @@ -7,6 +7,7 @@ import numpy as np from frigate.const import MODEL_CACHE_DIR from frigate.detectors.detection_runners import get_optimized_runner +from frigate.embeddings.types import EnrichmentModelTypeEnum from frigate.log import redirect_output_to_logger from frigate.util.downloader import ModelDownloader @@ -151,7 +152,7 @@ class ArcfaceEmbedding(BaseEmbedding): self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), device=self.config.device or "GPU", - complex_model=False, + model_type=EnrichmentModelTypeEnum.arcface.value, ) def _preprocess_inputs(self, raw_inputs): diff --git a/frigate/embeddings/onnx/jina_v1_embedding.py b/frigate/embeddings/onnx/jina_v1_embedding.py index 169ee453d..80466511d 100644 --- a/frigate/embeddings/onnx/jina_v1_embedding.py +++ b/frigate/embeddings/onnx/jina_v1_embedding.py @@ -7,6 +7,7 @@ import warnings # importing this without pytorch or others causes a warning # https://github.com/huggingface/transformers/issues/27214 # suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1 +from frigate.embeddings.types import EnrichmentModelTypeEnum from transformers import AutoFeatureExtractor, AutoTokenizer from transformers.utils.logging import disable_progress_bar @@ -128,6 +129,7 @@ class JinaV1TextEmbedding(BaseEmbedding): self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), self.device, + model_type=EnrichmentModelTypeEnum.jina_v1.value, ) def _preprocess_inputs(self, raw_inputs): @@ -206,6 +208,7 @@ class JinaV1ImageEmbedding(BaseEmbedding): self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), self.device, + model_type=EnrichmentModelTypeEnum.jina_v1.value, ) def _preprocess_inputs(self, raw_inputs): diff --git a/frigate/embeddings/onnx/jina_v2_embedding.py b/frigate/embeddings/onnx/jina_v2_embedding.py index 94e608512..d8fd96c50 100644 --- a/frigate/embeddings/onnx/jina_v2_embedding.py +++ b/frigate/embeddings/onnx/jina_v2_embedding.py @@ -6,6 +6,7 @@ import os import numpy as np from PIL import Image +from frigate.embeddings.types import EnrichmentModelTypeEnum from transformers import AutoTokenizer from transformers.utils.logging import disable_progress_bar, set_verbosity_error @@ -128,6 +129,7 @@ class JinaV2Embedding(BaseEmbedding): self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), self.device, + model_type=EnrichmentModelTypeEnum.jina_v2.value, ) def _preprocess_image(self, image_data: bytes | Image.Image) -> np.ndarray: diff --git a/frigate/embeddings/onnx/lpr_embedding.py b/frigate/embeddings/onnx/lpr_embedding.py index 30b8d372c..d41531d19 100644 --- a/frigate/embeddings/onnx/lpr_embedding.py +++ b/frigate/embeddings/onnx/lpr_embedding.py @@ -8,6 +8,7 @@ import numpy as np from frigate.comms.inter_process import InterProcessRequestor from frigate.const import MODEL_CACHE_DIR from frigate.detectors.detection_runners import BaseModelRunner, get_optimized_runner +from frigate.embeddings.types import EnrichmentModelTypeEnum from frigate.types import ModelStatusTypesEnum from frigate.util.downloader import ModelDownloader @@ -79,6 +80,7 @@ class PaddleOCRDetection(BaseEmbedding): self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), self.device, + model_type=EnrichmentModelTypeEnum.paddleocr.value, ) def _preprocess_inputs(self, raw_inputs): @@ -138,6 +140,7 @@ class PaddleOCRClassification(BaseEmbedding): self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), self.device, + model_type=EnrichmentModelTypeEnum.paddleocr.value, ) def _preprocess_inputs(self, raw_inputs): @@ -198,6 +201,7 @@ class PaddleOCRRecognition(BaseEmbedding): self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), self.device, + model_type=EnrichmentModelTypeEnum.paddleocr.value, ) def _preprocess_inputs(self, raw_inputs): @@ -258,7 +262,7 @@ class LicensePlateDetector(BaseEmbedding): self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), self.device, - complex_model=False, + model_type="yolov9", ) def _preprocess_inputs(self, raw_inputs): diff --git a/frigate/embeddings/types.py b/frigate/embeddings/types.py new file mode 100644 index 000000000..7fd7e43fa --- /dev/null +++ b/frigate/embeddings/types.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class EmbeddingTypeEnum(str, Enum): + thumbnail = "thumbnail" + description = "description" + +class EnrichmentModelTypeEnum(str, Enum): + arcface = "arcface" + facenet = "facenet" + jina_v1 = "jina_v1" + jina_v2 = "jina_v2" + paddleocr = "paddleocr" \ No newline at end of file