diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py
index 396145380..ee6420201 100644
--- a/frigate/embeddings/functions/onnx.py
+++ b/frigate/embeddings/functions/onnx.py
@@ -18,6 +18,7 @@ from transformers.utils.logging import disable_progress_bar
 from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
 from frigate.types import ModelStatusTypesEnum
 from frigate.util.downloader import ModelDownloader
+from frigate.util.model import get_ort_providers
 
 warnings.filterwarnings(
     "ignore",
@@ -40,8 +41,8 @@ class GenericONNXEmbedding:
         download_urls: Dict[str, str],
         embedding_function: Callable[[List[np.ndarray]], np.ndarray],
         model_type: str,
-        preferred_providers: List[str] = ["CPUExecutionProvider"],
         tokenizer_file: Optional[str] = None,
+        force_cpu: bool = False,
     ):
         self.model_name = model_name
         self.model_file = model_file
@@ -49,7 +50,7 @@ class GenericONNXEmbedding:
         self.download_urls = download_urls
         self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
-        self.preferred_providers = preferred_providers
+        self.providers, self.provider_options = get_ort_providers(force_cpu=force_cpu)
         self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
 
         self.tokenizer = None
@@ -105,8 +106,7 @@ class GenericONNXEmbedding:
         else:
             self.feature_extractor = self._load_feature_extractor()
         self.session = self._load_model(
-            os.path.join(self.download_path, self.model_file),
-            self.preferred_providers,
+            os.path.join(self.download_path, self.model_file)
         )
 
     def _load_tokenizer(self):
@@ -123,7 +123,7 @@ class GenericONNXEmbedding:
             f"{MODEL_CACHE_DIR}/{self.model_name}",
         )
 
-    def _load_model(self, path: str, providers: List[str]):
+    def _load_model(self, path: str):
         if os.path.exists(path):
-            return ort.InferenceSession(path, providers=providers)
+            return ort.InferenceSession(path, providers=self.providers, provider_options=self.provider_options)
         else: