fix paths

This commit is contained in:
Josh Hawkins 2024-10-09 13:20:28 -05:00
parent 29fba8bdb4
commit e275c205f9
3 changed files with 25 additions and 18 deletions

View File

@@ -88,10 +88,10 @@ class Embeddings:
self._create_tables() self._create_tables()
models = [ models = [
"all-jina-clip-v1-text_model_fp16.onnx", "jinaai/jina-clip-v1-text_model_fp16.onnx",
"all-jina-clip-v1-tokenizer", "jinaai/jina-clip-v1-tokenizer",
"all-jina-clip-v1-vision_model_fp16.onnx", "jinaai/jina-clip-v1-vision_model_fp16.onnx",
"all-jina-clip-v1-preprocessor_config.json", "jinaai/jina-clip-v1-feature_extractor.json",
] ]
for model in models: for model in models:
@@ -110,7 +110,7 @@ class Embeddings:
return outputs[0] return outputs[0]
self.text_embedding = GenericONNXEmbedding( self.text_embedding = GenericONNXEmbedding(
model_name="all-jina-clip-v1", model_name="jinaai/jina-clip-v1",
model_file="text_model_fp16.onnx", model_file="text_model_fp16.onnx",
tokenizer_file="tokenizer", tokenizer_file="tokenizer",
download_urls={ download_urls={
@@ -122,9 +122,9 @@ class Embeddings:
) )
self.vision_embedding = GenericONNXEmbedding( self.vision_embedding = GenericONNXEmbedding(
model_name="all-jina-clip-v1", model_name="jinaai/jina-clip-v1",
model_file="vision_model_fp16.onnx", model_file="vision_model_fp16.onnx",
tokenizer_file="preprocessor_config.json", tokenizer_file="feature_extractor",
download_urls={ download_urls={
"vision_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/vision_model_fp16.onnx" "vision_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/vision_model_fp16.onnx"
}, },

View File

@@ -68,17 +68,22 @@ class GenericONNXEmbedding:
if file_name in self.download_urls: if file_name in self.download_urls:
ModelDownloader.download_from_url(self.download_urls[file_name], path) ModelDownloader.download_from_url(self.download_urls[file_name], path)
elif file_name == self.tokenizer_file: elif file_name == self.tokenizer_file:
logger.info( logger.info(path + "/" + self.model_name)
f"Downloading {self.model_name} tokenizer/feature extractor" if not os.path.exists(path + "/" + self.model_name):
) logger.info(
f"Downloading {self.model_name} tokenizer/feature extractor"
)
if self.model_type == "text": if self.model_type == "text":
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
self.model_name, clean_up_tokenization_spaces=True self.model_name,
cache_dir=f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer",
clean_up_tokenization_spaces=True,
) )
tokenizer.save_pretrained(path) tokenizer.save_pretrained(path)
else: else:
feature_extractor = AutoFeatureExtractor.from_pretrained( feature_extractor = AutoFeatureExtractor.from_pretrained(
self.model_name self.model_name,
cache_dir=f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor",
) )
feature_extractor.save_pretrained(path) feature_extractor.save_pretrained(path)
@@ -111,13 +116,15 @@ class GenericONNXEmbedding:
) )
def _load_tokenizer(self): def _load_tokenizer(self):
tokenizer_path = os.path.join(self.download_path, self.tokenizer_file) tokenizer_path = os.path.join(f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer")
return AutoTokenizer.from_pretrained( return AutoTokenizer.from_pretrained(
tokenizer_path, clean_up_tokenization_spaces=True tokenizer_path, clean_up_tokenization_spaces=True
) )
def _load_feature_extractor(self): def _load_feature_extractor(self):
feature_extractor_path = os.path.join(self.download_path, self.tokenizer_file) feature_extractor_path = os.path.join(
f"{MODEL_CACHE_DIR}/{self.model_name}/feature_extractor"
)
return AutoFeatureExtractor.from_pretrained(feature_extractor_path) return AutoFeatureExtractor.from_pretrained(feature_extractor_path)
def _load_model(self, path: str, providers: List[str]): def _load_model(self, path: str, providers: List[str]):

View File

@@ -185,16 +185,16 @@ export default function Explore() {
// model states // model states
const { payload: textModelState } = useModelState( const { payload: textModelState } = useModelState(
"all-jina-clip-v1-text_model_fp16.onnx", "jinaai/jina-clip-v1-text_model_fp16.onnx",
); );
const { payload: textTokenizerState } = useModelState( const { payload: textTokenizerState } = useModelState(
"all-jina-clip-v1-tokenizer", "jinaai/jina-clip-v1-tokenizer",
); );
const { payload: visionModelState } = useModelState( const { payload: visionModelState } = useModelState(
"all-jina-clip-v1-vision_model_fp16.onnx", "jinaai/jina-clip-v1-vision_model_fp16.onnx",
); );
const { payload: visionFeatureExtractorState } = useModelState( const { payload: visionFeatureExtractorState } = useModelState(
"all-jina-clip-v1-preprocessor_config.json", "jinaai/jina-clip-v1-feature_extractor",
); );
const allModelsLoaded = useMemo(() => { const allModelsLoaded = useMemo(() => {