disable download progress bar

Josh Hawkins 2024-10-09 13:29:13 -05:00
parent e275c205f9
commit 0eabfa2763


@@ -13,6 +13,7 @@ from PIL import Image
 # https://github.com/huggingface/transformers/issues/27214
 # suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
 from transformers import AutoFeatureExtractor, AutoTokenizer
+from transformers.utils.logging import disable_progress_bar
 
 from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
 from frigate.types import ModelStatusTypesEnum
@@ -24,7 +25,8 @@ warnings.filterwarnings(
     message="The class CLIPFeatureExtractor is deprecated",
 )
 
+# disables the progress bar for downloading tokenizers and feature extractors
+disable_progress_bar()
 
 logger = logging.getLogger(__name__)
@@ -69,11 +71,9 @@ class GenericONNXEmbedding:
             ModelDownloader.download_from_url(self.download_urls[file_name], path)
         elif file_name == self.tokenizer_file:
             logger.info(path + "/" + self.model_name)
-            if not os.path.exists(path + "/" + self.model_name):
-                logger.info(
-                    f"Downloading {self.model_name} tokenizer/feature extractor"
-                )
             if self.model_type == "text":
+                if not os.path.exists(path + "/" + self.model_name):
+                    logger.info(f"Downloading {self.model_name} tokenizer")
                 tokenizer = AutoTokenizer.from_pretrained(
                     self.model_name,
                     cache_dir=f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer",
@@ -81,6 +81,8 @@ class GenericONNXEmbedding:
                 )
                 tokenizer.save_pretrained(path)
             else:
+                if not os.path.exists(path + "/" + self.model_name):
+                    logger.info(f"Downloading {self.model_name} feature extractor")
                 feature_extractor = AutoFeatureExtractor.from_pretrained(
                     self.model_name,
                     cache_dir=f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor",