This commit is contained in:
Josh Hawkins 2024-10-09 14:08:41 -05:00
parent dd6b0c9bcd
commit beae4b9423

View File

@ -119,11 +119,8 @@ class GenericONNXEmbedding:
) )
def _load_tokenizer(self): def _load_tokenizer(self):
if not self.tokenizer: tokenizer_path = os.path.join(f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer")
tokenizer_path = os.path.join( return AutoTokenizer.from_pretrained(
f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer"
)
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name, self.model_name,
cache_dir=tokenizer_path, cache_dir=tokenizer_path,
trust_remote_code=True, trust_remote_code=True,
@ -131,14 +128,11 @@ class GenericONNXEmbedding:
) )
def _load_feature_extractor(self): def _load_feature_extractor(self):
if not self.feature_extractor:
feature_extractor_path = os.path.join( feature_extractor_path = os.path.join(
f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor" f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor"
) )
self.feature_extractor = AutoFeatureExtractor.from_pretrained( return AutoFeatureExtractor.from_pretrained(
self.model_name, self.model_name, trust_remote_code=True, cache_dir=feature_extractor_path
trust_remote_code=True,
cache_dir=feature_extractor_path,
) )
def _load_model(self, path: str, providers: List[str]): def _load_model(self, path: str, providers: List[str]):