From dd6b0c9bcd1bcb3a5729e6483997bcdd11007111 Mon Sep 17 00:00:00 2001 From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com> Date: Wed, 9 Oct 2024 14:06:13 -0500 Subject: [PATCH] ensure tokenizer and feature extractor are correctly loaded --- frigate/embeddings/functions/onnx.py | 32 +++++++++++++++++----------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py index 454fe3faf..c82b60517 100644 --- a/frigate/embeddings/functions/onnx.py +++ b/frigate/embeddings/functions/onnx.py @@ -119,21 +119,27 @@ class GenericONNXEmbedding: ) def _load_tokenizer(self): - tokenizer_path = os.path.join(f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer") - return AutoTokenizer.from_pretrained( - self.model_name, - cache_dir=tokenizer_path, - trust_remote_code=True, - clean_up_tokenization_spaces=True, - ) + if not self.tokenizer: + tokenizer_path = os.path.join( + f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer" + ) + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_name, + cache_dir=tokenizer_path, + trust_remote_code=True, + clean_up_tokenization_spaces=True, + ) def _load_feature_extractor(self): - feature_extractor_path = os.path.join( - f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor" - ) - return AutoFeatureExtractor.from_pretrained( - self.model_name, trust_remote_code=True, cache_dir=feature_extractor_path - ) + if not self.feature_extractor: + feature_extractor_path = os.path.join( + f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor" + ) + self.feature_extractor = AutoFeatureExtractor.from_pretrained( + self.model_name, + trust_remote_code=True, + cache_dir=feature_extractor_path, + ) def _load_model(self, path: str, providers: List[str]): if os.path.exists(path):