Use optimized runner

This commit is contained in:
Nicolas Mowen 2025-09-13 20:11:06 -06:00
parent b2321cf1e8
commit 9a3fca6b2b
4 changed files with 17 additions and 17 deletions

View File

@ -6,12 +6,12 @@ import os
import numpy as np
from frigate.const import MODEL_CACHE_DIR
from frigate.detectors.base_runner import get_optimized_runner
from frigate.log import redirect_output_to_logger
from frigate.util.downloader import ModelDownloader
from ...config import FaceRecognitionConfig
from .base_embedding import BaseEmbedding
from .runner import ONNXModelRunner
try:
from tflite_runtime.interpreter import Interpreter
@ -148,7 +148,7 @@ class ArcfaceEmbedding(BaseEmbedding):
if self.downloader:
self.downloader.wait_for_download()
self.runner = ONNXModelRunner(
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
device=self.config.device or "GPU",
)

View File

@ -7,6 +7,7 @@ import warnings
# importing this without pytorch or others causes a warning
# https://github.com/huggingface/transformers/issues/27214
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
from frigate.detectors.base_runner import BaseModelRunner, get_optimized_runner
from transformers import AutoFeatureExtractor, AutoTokenizer
from transformers.utils.logging import disable_progress_bar
@ -16,7 +17,6 @@ from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader
from .base_embedding import BaseEmbedding
from .runner import ONNXModelRunner
warnings.filterwarnings(
"ignore",
@ -125,7 +125,7 @@ class JinaV1TextEmbedding(BaseEmbedding):
clean_up_tokenization_spaces=True,
)
self.runner = ONNXModelRunner(
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
self.device,
)
@ -170,7 +170,7 @@ class JinaV1ImageEmbedding(BaseEmbedding):
self.device = device
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
self.feature_extractor = None
self.runner: ONNXModelRunner | None = None
self.runner: BaseModelRunner | None = None
files_names = list(self.download_urls.keys())
if not all(
os.path.exists(os.path.join(self.download_path, n)) for n in files_names
@ -203,7 +203,7 @@ class JinaV1ImageEmbedding(BaseEmbedding):
f"{MODEL_CACHE_DIR}/{self.model_name}",
)
self.runner = ONNXModelRunner(
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
self.device,
)

View File

@ -6,6 +6,7 @@ import os
import numpy as np
from PIL import Image
from frigate.detectors.base_runner import get_optimized_runner
from transformers import AutoTokenizer
from transformers.utils.logging import disable_progress_bar, set_verbosity_error
@ -15,7 +16,6 @@ from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader
from .base_embedding import BaseEmbedding
from .runner import ONNXModelRunner
# disables the progress bar and download logging for downloading tokenizers and image processors
disable_progress_bar()
@ -125,7 +125,7 @@ class JinaV2Embedding(BaseEmbedding):
clean_up_tokenization_spaces=True,
)
self.runner = ONNXModelRunner(
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
self.device,
)

View File

@ -7,11 +7,11 @@ import numpy as np
from frigate.comms.inter_process import InterProcessRequestor
from frigate.const import MODEL_CACHE_DIR
from frigate.detectors.base_runner import BaseModelRunner, get_optimized_runner
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader
from .base_embedding import BaseEmbedding
from .runner import ONNXModelRunner
warnings.filterwarnings(
"ignore",
@ -47,7 +47,7 @@ class PaddleOCRDetection(BaseEmbedding):
self.model_size = model_size
self.device = device
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
self.runner: ONNXModelRunner | None = None
self.runner: BaseModelRunner | None = None
files_names = list(self.download_urls.keys())
if not all(
os.path.exists(os.path.join(self.download_path, n)) for n in files_names
@ -76,7 +76,7 @@ class PaddleOCRDetection(BaseEmbedding):
if self.downloader:
self.downloader.wait_for_download()
self.runner = ONNXModelRunner(
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
self.device,
self.model_size,
@ -107,7 +107,7 @@ class PaddleOCRClassification(BaseEmbedding):
self.model_size = model_size
self.device = device
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
self.runner: ONNXModelRunner | None = None
self.runner: BaseModelRunner | None = None
files_names = list(self.download_urls.keys())
if not all(
os.path.exists(os.path.join(self.download_path, n)) for n in files_names
@ -136,7 +136,7 @@ class PaddleOCRClassification(BaseEmbedding):
if self.downloader:
self.downloader.wait_for_download()
self.runner = ONNXModelRunner(
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
self.device,
self.model_size,
@ -168,7 +168,7 @@ class PaddleOCRRecognition(BaseEmbedding):
self.model_size = model_size
self.device = device
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
self.runner: ONNXModelRunner | None = None
self.runner: BaseModelRunner | None = None
files_names = list(self.download_urls.keys())
if not all(
os.path.exists(os.path.join(self.download_path, n)) for n in files_names
@ -197,7 +197,7 @@ class PaddleOCRRecognition(BaseEmbedding):
if self.downloader:
self.downloader.wait_for_download()
self.runner = ONNXModelRunner(
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
self.device,
self.model_size,
@ -229,7 +229,7 @@ class LicensePlateDetector(BaseEmbedding):
self.model_size = model_size
self.device = device
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
self.runner: ONNXModelRunner | None = None
self.runner: BaseModelRunner | None = None
files_names = list(self.download_urls.keys())
if not all(
os.path.exists(os.path.join(self.download_path, n)) for n in files_names
@ -258,7 +258,7 @@ class LicensePlateDetector(BaseEmbedding):
if self.downloader:
self.downloader.wait_for_download()
self.runner = ONNXModelRunner(
self.runner = get_optimized_runner(
os.path.join(self.download_path, self.model_file),
self.device,
self.model_size,