manually download and cache feature extractor config

This commit is contained in:
Josh Hawkins 2024-10-09 15:43:53 -05:00
parent beae4b9423
commit c5d4e301d1
3 changed files with 26 additions and 32 deletions

View File

@@ -91,7 +91,7 @@ class Embeddings:
"jinaai/jina-clip-v1-text_model_fp16.onnx",
"jinaai/jina-clip-v1-tokenizer",
"jinaai/jina-clip-v1-vision_model_fp16.onnx",
"jinaai/jina-clip-v1-feature_extractor",
"jinaai/jina-clip-v1-preprocessor_config.json",
]
for model in models:
@@ -114,7 +114,7 @@ class Embeddings:
model_file="text_model_fp16.onnx",
tokenizer_file="tokenizer",
download_urls={
"text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx"
"text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
},
embedding_function=jina_text_embedding_function,
model_type="text",
@@ -124,9 +124,9 @@ class Embeddings:
self.vision_embedding = GenericONNXEmbedding(
model_name="jinaai/jina-clip-v1",
model_file="vision_model_fp16.onnx",
tokenizer_file="feature_extractor",
download_urls={
"vision_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/vision_model_fp16.onnx"
"vision_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/vision_model_fp16.onnx",
"preprocessor_config.json": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/preprocessor_config.json",
},
embedding_function=jina_vision_embedding_function,
model_type="vision",

View File

@@ -2,7 +2,7 @@ import logging
import os
import warnings
from io import BytesIO
from typing import Callable, Dict, List, Union
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import onnxruntime as ort
@@ -37,11 +37,11 @@ class GenericONNXEmbedding:
self,
model_name: str,
model_file: str,
tokenizer_file: str,
download_urls: Dict[str, str],
embedding_function: Callable[[List[np.ndarray]], np.ndarray],
model_type: str,
preferred_providers: List[str] = ["CPUExecutionProvider"],
tokenizer_file: Optional[str] = None,
):
self.model_name = model_name
self.model_file = model_file
@@ -59,7 +59,8 @@ class GenericONNXEmbedding:
self.downloader = ModelDownloader(
model_name=self.model_name,
download_path=self.download_path,
file_names=[self.model_file, self.tokenizer_file],
file_names=list(self.download_urls.keys())
+ ([self.tokenizer_file] if self.tokenizer_file else []),
download_func=self._download_model,
)
self.downloader.ensure_model_files()
@@ -69,26 +70,22 @@ class GenericONNXEmbedding:
file_name = os.path.basename(path)
if file_name in self.download_urls:
ModelDownloader.download_from_url(self.download_urls[file_name], path)
elif file_name == self.tokenizer_file:
if self.model_type == "text":
if not os.path.exists(path + "/" + self.model_name):
logger.info(f"Downloading {self.model_name} tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=True,
cache_dir=f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer",
clean_up_tokenization_spaces=True,
)
tokenizer.save_pretrained(path)
else:
if not os.path.exists(path + "/" + self.model_name):
logger.info(f"Downloading {self.model_name} feature extractor")
feature_extractor = AutoFeatureExtractor.from_pretrained(
self.model_name,
trust_remote_code=True,
cache_dir=f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor",
)
feature_extractor.save_pretrained(path)
elif file_name == self.tokenizer_file and self.model_type == "text":
if not os.path.exists(path + "/" + self.model_name):
logger.info(f"Downloading {self.model_name} tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=True,
cache_dir=f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer",
clean_up_tokenization_spaces=True,
)
tokenizer.save_pretrained(path)
else:
if not os.path.exists(path + "/" + self.model_name):
logger.info(f"Downloading {self.model_name} feature extractor")
self.feature_extractor = AutoFeatureExtractor.from_pretrained(
f"{MODEL_CACHE_DIR}/{self.model_name}",
)
self.downloader.requestor.send_data(
UPDATE_MODEL_STATE,
@@ -128,11 +125,8 @@
)
def _load_feature_extractor(self):
feature_extractor_path = os.path.join(
f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor"
)
return AutoFeatureExtractor.from_pretrained(
self.model_name, trust_remote_code=True, cache_dir=feature_extractor_path
f"{MODEL_CACHE_DIR}/{self.model_name}",
)
def _load_model(self, path: str, providers: List[str]):

View File

@@ -194,7 +194,7 @@ export default function Explore() {
"jinaai/jina-clip-v1-vision_model_fp16.onnx",
);
const { payload: visionFeatureExtractorState } = useModelState(
"jinaai/jina-clip-v1-feature_extractor",
"jinaai/jina-clip-v1-preprocessor_config.json",
);
const allModelsLoaded = useMemo(() => {