ensure tokenizer and feature extractor are correctly loaded

Josh Hawkins 2024-10-09 14:06:13 -05:00
parent 896dad5cc9
commit dd6b0c9bcd


@@ -119,21 +119,27 @@ class GenericONNXEmbedding:
             )
 
     def _load_tokenizer(self):
-        tokenizer_path = os.path.join(f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer")
-        return AutoTokenizer.from_pretrained(
-            self.model_name,
-            cache_dir=tokenizer_path,
-            trust_remote_code=True,
-            clean_up_tokenization_spaces=True,
-        )
+        if not self.tokenizer:
+            tokenizer_path = os.path.join(
+                f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer"
+            )
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.model_name,
+                cache_dir=tokenizer_path,
+                trust_remote_code=True,
+                clean_up_tokenization_spaces=True,
+            )
 
     def _load_feature_extractor(self):
-        feature_extractor_path = os.path.join(
-            f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor"
-        )
-        return AutoFeatureExtractor.from_pretrained(
-            self.model_name, trust_remote_code=True, cache_dir=feature_extractor_path
-        )
+        if not self.feature_extractor:
+            feature_extractor_path = os.path.join(
+                f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor"
+            )
+            self.feature_extractor = AutoFeatureExtractor.from_pretrained(
+                self.model_name,
+                trust_remote_code=True,
+                cache_dir=feature_extractor_path,
+            )
 
     def _load_model(self, path: str, providers: List[str]):
         if os.path.exists(path):
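
The change replaces the return-based loaders with guarded assignments on the instance, so the tokenizer and feature extractor are instantiated once and reused on later calls. Below is a minimal standalone sketch of the same lazy-load pattern, assuming the Hugging Face transformers package; the class name LazyLoaders, the getter names, the CACHE_DIR path, and the example model id are placeholders for illustration, not part of Frigate's code.

import os

from transformers import AutoFeatureExtractor, AutoTokenizer

CACHE_DIR = "/tmp/model_cache"  # placeholder; Frigate resolves its own MODEL_CACHE_DIR


class LazyLoaders:
    """Load the tokenizer and feature extractor once, then reuse them."""

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.tokenizer = None
        self.feature_extractor = None

    def get_tokenizer(self):
        # First call downloads/deserializes; later calls return the cached object.
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=os.path.join(CACHE_DIR, self.model_name, "tokenizer"),
                trust_remote_code=True,
                clean_up_tokenization_spaces=True,
            )
        return self.tokenizer

    def get_feature_extractor(self):
        # Same guard for the image preprocessor.
        if self.feature_extractor is None:
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(
                self.model_name,
                cache_dir=os.path.join(CACHE_DIR, self.model_name, "feature_extractor"),
                trust_remote_code=True,
            )
        return self.feature_extractor


# Example usage (model id is illustrative):
# loaders = LazyLoaders("openai/clip-vit-base-patch32")
# tokenizer = loaders.get_tokenizer()  # loads on first call
# tokenizer = loaders.get_tokenizer()  # cached, no reload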