Use enum for model types

Nicolas Mowen 2024-10-21 16:11:56 -06:00
parent 843cce2349
commit 4f78aa5f8c
2 changed files with 13 additions and 6 deletions
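The commit replaces the bare string literals "text" and "vision" with members of a new ModelTypeEnum. Because the enum mixes in str, its members still compare equal to the original string values, so call sites and comparisons can be migrated without breaking anything that still passes plain strings. A minimal, self-contained sketch of the pattern (illustrative, not part of the commit):

from enum import Enum


class ModelTypeEnum(str, Enum):
    face = "face"
    vision = "vision"
    text = "text"


# str mixin: members compare equal to their plain-string values,
# so code that still passes "text" or "vision" keeps matching.
assert ModelTypeEnum.text == "text"
assert ModelTypeEnum("vision") is ModelTypeEnum.vision  # parse a config string

# unlike a bare string, an unknown value fails loudly
try:
    ModelTypeEnum("image")
except ValueError:
    print("'image' is not a defined model type")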

View File

@@ -95,7 +95,7 @@ class Embeddings:
                 "text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
             },
             model_size=config.model_size,
-            model_type="text",
+            model_type=ModelTypeEnum.text,
             requestor=self.requestor,
             device="CPU",
         )
@@ -116,7 +116,7 @@
             model_file=model_file,
             download_urls=download_urls,
             model_size=config.model_size,
-            model_type="vision",
+            model_type=ModelTypeEnum.vision,
             requestor=self.requestor,
             device="GPU" if config.model_size == "large" else "CPU",
         )
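The constructor signature of GenericONNXEmbedding is not shown in this diff, but the practical benefit of passing ModelTypeEnum.text / ModelTypeEnum.vision instead of strings is that the parameter can be annotated with the enum, letting a static type checker catch stray literals. A hypothetical sketch (EmbeddingSpec is an illustrative stand-in, not Frigate code):

from dataclasses import dataclass
from enum import Enum


class ModelTypeEnum(str, Enum):  # same shape as the enum added in this commit
    face = "face"
    vision = "vision"
    text = "text"


@dataclass
class EmbeddingSpec:
    # illustrative stand-in for the GenericONNXEmbedding constructor arguments
    model_size: str
    model_type: ModelTypeEnum
    device: str = "CPU"


# mypy/pyright would reject model_type="image" here; the old str-typed
# argument accepted any literal silently.
spec = EmbeddingSpec(model_size="large", model_type=ModelTypeEnum.vision, device="GPU")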

View File

@@ -1,6 +1,7 @@
 import logging
 import os
 import warnings
+from enum import Enum
 from io import BytesIO
 from typing import Dict, List, Optional, Union
@@ -31,6 +32,12 @@ disable_progress_bar()
 logger = logging.getLogger(__name__)
 
 
+class ModelTypeEnum(str, Enum):
+    face = "face"
+    vision = "vision"
+    text = "text"
+
+
 class GenericONNXEmbedding:
     """Generic embedding function for ONNX models (text and vision)."""
@@ -88,7 +95,7 @@ class GenericONNXEmbedding:
         file_name = os.path.basename(path)
         if file_name in self.download_urls:
             ModelDownloader.download_from_url(self.download_urls[file_name], path)
-        elif file_name == self.tokenizer_file and self.model_type == "text":
+        elif file_name == self.tokenizer_file and self.model_type == ModelTypeEnum.text:
             if not os.path.exists(path + "/" + self.model_name):
                 logger.info(f"Downloading {self.model_name} tokenizer")
                 tokenizer = AutoTokenizer.from_pretrained(
@@ -119,7 +126,7 @@ class GenericONNXEmbedding:
         if self.runner is None:
             if self.downloader:
                 self.downloader.wait_for_download()
-            if self.model_type == "text":
+            if self.model_type == ModelTypeEnum.text:
                 self.tokenizer = self._load_tokenizer()
             else:
                 self.feature_extractor = self._load_feature_extractor()
@@ -144,7 +151,7 @@ class GenericONNXEmbedding:
         )
 
     def _preprocess_inputs(self, raw_inputs: any) -> any:
-        if self.model_type == "text":
+        if self.model_type == ModelTypeEnum.text:
             max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs)
             return [
                 self.tokenizer(
@@ -156,7 +163,7 @@ class GenericONNXEmbedding:
                 )
                 for text in raw_inputs
             ]
-        elif self.model_type == "image":
+        elif self.model_type == ModelTypeEnum.vision:
             processed_images = [self._process_image(img) for img in raw_inputs]
             return [
                 self.feature_extractor(images=image, return_tensors="np")
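Note that the last hunk does more than swap types: the old branch compared self.model_type against "image", while the call site in Embeddings passed "vision", so the vision preprocessing branch appears never to have matched. With both sides using ModelTypeEnum.vision that mismatch disappears, and unknown values can be rejected explicitly. A small runnable sketch of the dispatch shape (the placeholder bodies are not Frigate's real preprocessing):

from enum import Enum


class ModelTypeEnum(str, Enum):
    face = "face"
    vision = "vision"
    text = "text"


def preprocess(model_type: ModelTypeEnum, raw_inputs: list) -> list:
    # placeholder bodies; the real _preprocess_inputs runs a tokenizer or a
    # feature extractor, which is out of scope for this sketch
    if model_type == ModelTypeEnum.text:
        return [f"tokenized:{text}" for text in raw_inputs]
    elif model_type == ModelTypeEnum.vision:
        return [f"features:{img}" for img in raw_inputs]
    raise ValueError(f"unsupported model type: {model_type}")


print(preprocess(ModelTypeEnum.text, ["a cat on a couch"]))
print(preprocess("vision", ["frame-0001"]))  # plain strings still hit the right branch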