Use enum for model types

This commit is contained in:
Nicolas Mowen 2024-10-21 16:11:56 -06:00
parent 843cce2349
commit 4f78aa5f8c
2 changed files with 13 additions and 6 deletions

View File

@ -95,7 +95,7 @@ class Embeddings:
"text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
},
model_size=config.model_size,
model_type="text",
model_type=ModelTypeEnum.text,
requestor=self.requestor,
device="CPU",
)
@ -116,7 +116,7 @@ class Embeddings:
model_file=model_file,
download_urls=download_urls,
model_size=config.model_size,
model_type="vision",
model_type=ModelTypeEnum.vision,
requestor=self.requestor,
device="GPU" if config.model_size == "large" else "CPU",
)

View File

@ -1,6 +1,7 @@
import logging
import os
import warnings
from enum import Enum
from io import BytesIO
from typing import Dict, List, Optional, Union
@ -31,6 +32,12 @@ disable_progress_bar()
logger = logging.getLogger(__name__)
class ModelTypeEnum(str, Enum):
face = "face"
vision = "vision"
text = "text"
class GenericONNXEmbedding:
"""Generic embedding function for ONNX models (text and vision)."""
@ -88,7 +95,7 @@ class GenericONNXEmbedding:
file_name = os.path.basename(path)
if file_name in self.download_urls:
ModelDownloader.download_from_url(self.download_urls[file_name], path)
elif file_name == self.tokenizer_file and self.model_type == "text":
elif file_name == self.tokenizer_file and self.model_type == ModelTypeEnum.text:
if not os.path.exists(path + "/" + self.model_name):
logger.info(f"Downloading {self.model_name} tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
@ -119,7 +126,7 @@ class GenericONNXEmbedding:
if self.runner is None:
if self.downloader:
self.downloader.wait_for_download()
if self.model_type == "text":
if self.model_type == ModelTypeEnum.text:
self.tokenizer = self._load_tokenizer()
else:
self.feature_extractor = self._load_feature_extractor()
@ -144,7 +151,7 @@ class GenericONNXEmbedding:
)
def _preprocess_inputs(self, raw_inputs: any) -> any:
if self.model_type == "text":
if self.model_type == ModelTypeEnum.text:
max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs)
return [
self.tokenizer(
@ -156,7 +163,7 @@ class GenericONNXEmbedding:
)
for text in raw_inputs
]
elif self.model_type == "image":
elif self.model_type == ModelTypeEnum.vision:
processed_images = [self._process_image(img) for img in raw_inputs]
return [
self.feature_extractor(images=image, return_tensors="np")