ensure tokenizer and feature extractor are correctly loaded

This commit is contained in:
Josh Hawkins 2024-10-09 14:06:13 -05:00
parent 896dad5cc9
commit dd6b0c9bcd

View File

@@ -119,21 +119,27 @@ class GenericONNXEmbedding:
) )
def _load_tokenizer(self): def _load_tokenizer(self):
tokenizer_path = os.path.join(f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer") if not self.tokenizer:
return AutoTokenizer.from_pretrained( tokenizer_path = os.path.join(
self.model_name, f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer"
cache_dir=tokenizer_path, )
trust_remote_code=True, self.tokenizer = AutoTokenizer.from_pretrained(
clean_up_tokenization_spaces=True, self.model_name,
) cache_dir=tokenizer_path,
trust_remote_code=True,
clean_up_tokenization_spaces=True,
)
def _load_feature_extractor(self): def _load_feature_extractor(self):
feature_extractor_path = os.path.join( if not self.feature_extractor:
f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor" feature_extractor_path = os.path.join(
) f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor"
return AutoFeatureExtractor.from_pretrained( )
self.model_name, trust_remote_code=True, cache_dir=feature_extractor_path self.feature_extractor = AutoFeatureExtractor.from_pretrained(
) self.model_name,
trust_remote_code=True,
cache_dir=feature_extractor_path,
)
def _load_model(self, path: str, providers: List[str]): def _load_model(self, path: str, providers: List[str]):
if os.path.exists(path): if os.path.exists(path):