manually download and cache feature extractor config

Josh Hawkins 2024-10-09 15:43:53 -05:00
parent beae4b9423
commit c5d4e301d1
3 changed files with 26 additions and 32 deletions
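In outline, the change stops letting transformers resolve the JinaCLIP preprocessor config from the Hugging Face hub at load time: preprocessor_config.json is downloaded up front like the ONNX files, and the feature extractor is then loaded from the local model cache. A minimal standalone sketch of that flow, using plain urllib in place of Frigate's ModelDownloader and an illustrative cache path:

import os
import urllib.request

from transformers import AutoFeatureExtractor

# Illustrative cache location; Frigate derives this from MODEL_CACHE_DIR
# and the model name.
cache_dir = "/config/model_cache/jinaai/jina-clip-v1"
config_url = (
    "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/preprocessor_config.json"
)
config_path = os.path.join(cache_dir, "preprocessor_config.json")

# Fetch the preprocessor config once, up front, instead of letting
# transformers pull it from the hub the first time the extractor is used.
os.makedirs(cache_dir, exist_ok=True)
if not os.path.exists(config_path):
    urllib.request.urlretrieve(config_url, config_path)

# With the config on disk, the extractor loads straight from the local
# directory (this mirrors the new _load_feature_extractor below).
feature_extractor = AutoFeatureExtractor.from_pretrained(cache_dir)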

View File

@@ -91,7 +91,7 @@ class Embeddings:
             "jinaai/jina-clip-v1-text_model_fp16.onnx",
             "jinaai/jina-clip-v1-tokenizer",
             "jinaai/jina-clip-v1-vision_model_fp16.onnx",
-            "jinaai/jina-clip-v1-feature_extractor",
+            "jinaai/jina-clip-v1-preprocessor_config.json",
         ]
 
         for model in models:
@@ -114,7 +114,7 @@ class Embeddings:
             model_file="text_model_fp16.onnx",
             tokenizer_file="tokenizer",
             download_urls={
-                "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx"
+                "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
             },
             embedding_function=jina_text_embedding_function,
             model_type="text",
@@ -124,9 +124,9 @@ class Embeddings:
         self.vision_embedding = GenericONNXEmbedding(
             model_name="jinaai/jina-clip-v1",
             model_file="vision_model_fp16.onnx",
-            tokenizer_file="feature_extractor",
             download_urls={
-                "vision_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/vision_model_fp16.onnx"
+                "vision_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/vision_model_fp16.onnx",
+                "preprocessor_config.json": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/preprocessor_config.json",
             },
             embedding_function=jina_vision_embedding_function,
             model_type="vision",

View File

@@ -2,7 +2,7 @@ import logging
 import os
 import warnings
 from io import BytesIO
-from typing import Callable, Dict, List, Union
+from typing import Callable, Dict, List, Optional, Union
 
 import numpy as np
 import onnxruntime as ort
@@ -37,11 +37,11 @@ class GenericONNXEmbedding:
         self,
         model_name: str,
         model_file: str,
-        tokenizer_file: str,
         download_urls: Dict[str, str],
         embedding_function: Callable[[List[np.ndarray]], np.ndarray],
         model_type: str,
         preferred_providers: List[str] = ["CPUExecutionProvider"],
+        tokenizer_file: Optional[str] = None,
     ):
         self.model_name = model_name
         self.model_file = model_file
@@ -59,7 +59,8 @@ class GenericONNXEmbedding:
         self.downloader = ModelDownloader(
             model_name=self.model_name,
             download_path=self.download_path,
-            file_names=[self.model_file, self.tokenizer_file],
+            file_names=list(self.download_urls.keys())
+            + ([self.tokenizer_file] if self.tokenizer_file else []),
             download_func=self._download_model,
         )
         self.downloader.ensure_model_files()
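The downloader's file list is now derived from download_urls, with the tokenizer appended only when one is configured (text model only). A tiny self-contained illustration of that expression, with placeholder URLs:

from typing import Dict, List, Optional


def expected_files(download_urls: Dict[str, str], tokenizer_file: Optional[str]) -> List[str]:
    # Mirrors the new file_names expression: every URL-backed file, plus the
    # tokenizer directory when the model has one.
    return list(download_urls.keys()) + ([tokenizer_file] if tokenizer_file else [])


# Vision model: no tokenizer, so only the URL-backed files are tracked.
print(expected_files(
    {"vision_model_fp16.onnx": "...", "preprocessor_config.json": "..."},
    tokenizer_file=None,
))  # ['vision_model_fp16.onnx', 'preprocessor_config.json']

# Text model: the tokenizer entry is appended after the URL-backed file.
print(expected_files(
    {"text_model_fp16.onnx": "..."},
    tokenizer_file="tokenizer",
))  # ['text_model_fp16.onnx', 'tokenizer']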
@@ -69,26 +70,22 @@ class GenericONNXEmbedding:
         file_name = os.path.basename(path)
         if file_name in self.download_urls:
             ModelDownloader.download_from_url(self.download_urls[file_name], path)
-        elif file_name == self.tokenizer_file:
-            if self.model_type == "text":
-                if not os.path.exists(path + "/" + self.model_name):
-                    logger.info(f"Downloading {self.model_name} tokenizer")
-                    tokenizer = AutoTokenizer.from_pretrained(
-                        self.model_name,
-                        trust_remote_code=True,
-                        cache_dir=f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer",
-                        clean_up_tokenization_spaces=True,
-                    )
-                    tokenizer.save_pretrained(path)
-            else:
-                if not os.path.exists(path + "/" + self.model_name):
-                    logger.info(f"Downloading {self.model_name} feature extractor")
-                    feature_extractor = AutoFeatureExtractor.from_pretrained(
-                        self.model_name,
-                        trust_remote_code=True,
-                        cache_dir=f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor",
-                    )
-                    feature_extractor.save_pretrained(path)
+        elif file_name == self.tokenizer_file and self.model_type == "text":
+            if not os.path.exists(path + "/" + self.model_name):
+                logger.info(f"Downloading {self.model_name} tokenizer")
+                tokenizer = AutoTokenizer.from_pretrained(
+                    self.model_name,
+                    trust_remote_code=True,
+                    cache_dir=f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer",
+                    clean_up_tokenization_spaces=True,
+                )
+                tokenizer.save_pretrained(path)
+        else:
+            if not os.path.exists(path + "/" + self.model_name):
+                logger.info(f"Downloading {self.model_name} feature extractor")
+                self.feature_extractor = AutoFeatureExtractor.from_pretrained(
+                    f"{MODEL_CACHE_DIR}/{self.model_name}",
+                )
 
         self.downloader.requestor.send_data(
             UPDATE_MODEL_STATE,
@@ -128,11 +125,8 @@ class GenericONNXEmbedding:
         )
 
     def _load_feature_extractor(self):
-        feature_extractor_path = os.path.join(
-            f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor"
-        )
         return AutoFeatureExtractor.from_pretrained(
-            self.model_name, trust_remote_code=True, cache_dir=feature_extractor_path
+            f"{MODEL_CACHE_DIR}/{self.model_name}",
         )
 
     def _load_model(self, path: str, providers: List[str]):

View File

@@ -194,7 +194,7 @@ export default function Explore() {
     "jinaai/jina-clip-v1-vision_model_fp16.onnx",
   );
   const { payload: visionFeatureExtractorState } = useModelState(
-    "jinaai/jina-clip-v1-feature_extractor",
+    "jinaai/jina-clip-v1-preprocessor_config.json",
   );
 
   const allModelsLoaded = useMemo(() => {