This commit is contained in:
Josh Hawkins 2024-10-09 14:08:41 -05:00
parent dd6b0c9bcd
commit beae4b9423

View File

@ -119,27 +119,21 @@ class GenericONNXEmbedding:
) )
def _load_tokenizer(self): def _load_tokenizer(self):
if not self.tokenizer: tokenizer_path = os.path.join(f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer")
tokenizer_path = os.path.join( return AutoTokenizer.from_pretrained(
f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer" self.model_name,
) cache_dir=tokenizer_path,
self.tokenizer = AutoTokenizer.from_pretrained( trust_remote_code=True,
self.model_name, clean_up_tokenization_spaces=True,
cache_dir=tokenizer_path, )
trust_remote_code=True,
clean_up_tokenization_spaces=True,
)
def _load_feature_extractor(self): def _load_feature_extractor(self):
if not self.feature_extractor: feature_extractor_path = os.path.join(
feature_extractor_path = os.path.join( f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor"
f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor" )
) return AutoFeatureExtractor.from_pretrained(
self.feature_extractor = AutoFeatureExtractor.from_pretrained( self.model_name, trust_remote_code=True, cache_dir=feature_extractor_path
self.model_name, )
trust_remote_code=True,
cache_dir=feature_extractor_path,
)
def _load_model(self, path: str, providers: List[str]): def _load_model(self, path: str, providers: List[str]):
if os.path.exists(path): if os.path.exists(path):