diff --git a/frigate/embeddings/onnx/face_embedding.py b/frigate/embeddings/onnx/face_embedding.py index 10d5627d9..da3f468d4 100644 --- a/frigate/embeddings/onnx/face_embedding.py +++ b/frigate/embeddings/onnx/face_embedding.py @@ -6,12 +6,12 @@ import os import numpy as np from frigate.const import MODEL_CACHE_DIR +from frigate.detectors.base_runner import get_optimized_runner from frigate.log import redirect_output_to_logger from frigate.util.downloader import ModelDownloader from ...config import FaceRecognitionConfig from .base_embedding import BaseEmbedding -from .runner import ONNXModelRunner try: from tflite_runtime.interpreter import Interpreter @@ -148,7 +148,7 @@ class ArcfaceEmbedding(BaseEmbedding): if self.downloader: self.downloader.wait_for_download() - self.runner = ONNXModelRunner( + self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), device=self.config.device or "GPU", ) diff --git a/frigate/embeddings/onnx/jina_v1_embedding.py b/frigate/embeddings/onnx/jina_v1_embedding.py index d327fa8ba..51c075aa3 100644 --- a/frigate/embeddings/onnx/jina_v1_embedding.py +++ b/frigate/embeddings/onnx/jina_v1_embedding.py @@ -7,6 +7,7 @@ import warnings # importing this without pytorch or others causes a warning # https://github.com/huggingface/transformers/issues/27214 # suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1 +from frigate.detectors.base_runner import BaseModelRunner, get_optimized_runner from transformers import AutoFeatureExtractor, AutoTokenizer from transformers.utils.logging import disable_progress_bar @@ -16,7 +17,6 @@ from frigate.types import ModelStatusTypesEnum from frigate.util.downloader import ModelDownloader from .base_embedding import BaseEmbedding -from .runner import ONNXModelRunner warnings.filterwarnings( "ignore", @@ -125,7 +125,7 @@ class JinaV1TextEmbedding(BaseEmbedding): clean_up_tokenization_spaces=True, ) - self.runner = ONNXModelRunner( + self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), self.device, ) @@ -170,7 +170,7 @@ class JinaV1ImageEmbedding(BaseEmbedding): self.device = device self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name) self.feature_extractor = None - self.runner: ONNXModelRunner | None = None + self.runner: BaseModelRunner | None = None files_names = list(self.download_urls.keys()) if not all( os.path.exists(os.path.join(self.download_path, n)) for n in files_names @@ -203,7 +203,7 @@ class JinaV1ImageEmbedding(BaseEmbedding): f"{MODEL_CACHE_DIR}/{self.model_name}", ) - self.runner = ONNXModelRunner( + self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), self.device, ) diff --git a/frigate/embeddings/onnx/jina_v2_embedding.py b/frigate/embeddings/onnx/jina_v2_embedding.py index 50b503d76..6077b2faa 100644 --- a/frigate/embeddings/onnx/jina_v2_embedding.py +++ b/frigate/embeddings/onnx/jina_v2_embedding.py @@ -6,6 +6,7 @@ import os import numpy as np from PIL import Image +from frigate.detectors.base_runner import get_optimized_runner from transformers import AutoTokenizer from transformers.utils.logging import disable_progress_bar, set_verbosity_error @@ -15,7 +16,6 @@ from frigate.types import ModelStatusTypesEnum from frigate.util.downloader import ModelDownloader from .base_embedding import BaseEmbedding -from .runner import ONNXModelRunner # disables the progress bar and download logging for downloading tokenizers and image processors disable_progress_bar() @@ -125,7 +125,7 @@ class JinaV2Embedding(BaseEmbedding): clean_up_tokenization_spaces=True, ) - self.runner = ONNXModelRunner( + self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), self.device, ) diff --git a/frigate/embeddings/onnx/lpr_embedding.py b/frigate/embeddings/onnx/lpr_embedding.py index 1b5b9acd0..88b0ae07b 100644 --- a/frigate/embeddings/onnx/lpr_embedding.py +++ b/frigate/embeddings/onnx/lpr_embedding.py @@ -7,11 +7,11 @@ import numpy as np from frigate.comms.inter_process import InterProcessRequestor from frigate.const import MODEL_CACHE_DIR +from frigate.detectors.base_runner import BaseModelRunner, get_optimized_runner from frigate.types import ModelStatusTypesEnum from frigate.util.downloader import ModelDownloader from .base_embedding import BaseEmbedding -from .runner import ONNXModelRunner warnings.filterwarnings( "ignore", @@ -47,7 +47,7 @@ class PaddleOCRDetection(BaseEmbedding): self.model_size = model_size self.device = device self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name) - self.runner: ONNXModelRunner | None = None + self.runner: BaseModelRunner | None = None files_names = list(self.download_urls.keys()) if not all( os.path.exists(os.path.join(self.download_path, n)) for n in files_names @@ -76,7 +76,7 @@ class PaddleOCRDetection(BaseEmbedding): if self.downloader: self.downloader.wait_for_download() - self.runner = ONNXModelRunner( + self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), self.device, self.model_size, @@ -107,7 +107,7 @@ class PaddleOCRClassification(BaseEmbedding): self.model_size = model_size self.device = device self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name) - self.runner: ONNXModelRunner | None = None + self.runner: BaseModelRunner | None = None files_names = list(self.download_urls.keys()) if not all( os.path.exists(os.path.join(self.download_path, n)) for n in files_names @@ -136,7 +136,7 @@ class PaddleOCRClassification(BaseEmbedding): if self.downloader: self.downloader.wait_for_download() - self.runner = ONNXModelRunner( + self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), self.device, self.model_size, @@ -168,7 +168,7 @@ class PaddleOCRRecognition(BaseEmbedding): self.model_size = model_size self.device = device self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name) - self.runner: ONNXModelRunner | None = None + self.runner: BaseModelRunner | None = None files_names = list(self.download_urls.keys()) if not all( os.path.exists(os.path.join(self.download_path, n)) for n in files_names @@ -197,7 +197,7 @@ class PaddleOCRRecognition(BaseEmbedding): if self.downloader: self.downloader.wait_for_download() - self.runner = ONNXModelRunner( + self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), self.device, self.model_size, @@ -229,7 +229,7 @@ class LicensePlateDetector(BaseEmbedding): self.model_size = model_size self.device = device self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name) - self.runner: ONNXModelRunner | None = None + self.runner: BaseModelRunner | None = None files_names = list(self.download_urls.keys()) if not all( os.path.exists(os.path.join(self.download_path, n)) for n in files_names @@ -258,7 +258,7 @@ class LicensePlateDetector(BaseEmbedding): if self.downloader: self.downloader.wait_for_download() - self.runner = ONNXModelRunner( + self.runner = get_optimized_runner( os.path.join(self.download_path, self.model_file), self.device, self.model_size,