Don't run download

This commit is contained in:
Nicolas Mowen 2024-10-10 13:09:31 -06:00
parent 6287253563
commit 8f5e05fc71

View File

@ -59,15 +59,19 @@ class GenericONNXEmbedding:
self.feature_extractor = None self.feature_extractor = None
self.session = None self.session = None
print("starting model download") if not all(os.path.exists(os.path.join(self.download_path, n)) for n in self.download_urls.keys()):
self.downloader = ModelDownloader( print("starting model download")
self.downloader = ModelDownloader(
model_name=self.model_name, model_name=self.model_name,
download_path=self.download_path, download_path=self.download_path,
file_names=list(self.download_urls.keys()) file_names=list(self.download_urls.keys())
+ ([self.tokenizer_file] if self.tokenizer_file else []), + ([self.tokenizer_file] if self.tokenizer_file else []),
download_func=self._download_model, download_func=self._download_model,
) )
self.downloader.ensure_model_files() self.downloader.ensure_model_files()
else:
self.downloader = None
print("models are already downloaded")
def _download_model(self, path: str): def _download_model(self, path: str):
try: try:
@ -104,7 +108,8 @@ class GenericONNXEmbedding:
def _load_model_and_tokenizer(self): def _load_model_and_tokenizer(self):
if self.session is None: if self.session is None:
self.downloader.wait_for_download() if self.downloader:
self.downloader.wait_for_download()
if self.model_type == "text": if self.model_type == "text":
self.tokenizer = self._load_tokenizer() self.tokenizer = self._load_tokenizer()
else: else: