Refactor preprocessing of images

Nicolas Mowen 2024-10-21 15:25:49 -06:00
parent b1285a16c1
commit dacdc1e0fe


@@ -143,6 +143,28 @@ class GenericONNXEmbedding:
                 f"{MODEL_CACHE_DIR}/{self.model_name}",
             )
 
+    def _preprocess_inputs(self, raw_inputs: any) -> any:
+        if self.model_type == "text":
+            max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs)
+            return [
+                self.tokenizer(
+                    text,
+                    padding="max_length",
+                    truncation=True,
+                    max_length=max_length,
+                    return_tensors="np",
+                )
+                for text in raw_inputs
+            ]
+        elif self.model_type == "image":
+            processed_images = [self._process_image(img) for img in raw_inputs]
+            return [
+                self.feature_extractor(images=image, return_tensors="np")
+                for image in processed_images
+            ]
+        else:
+            raise ValueError(f"Unable to preprocess inputs for {self.model_type}")
+
     def _process_image(self, image):
         if isinstance(image, str):
             if image.startswith("http"):
@@ -163,25 +185,7 @@ class GenericONNXEmbedding:
             )
             return []
 
-        if self.model_type == "text":
-            max_length = max(len(self.tokenizer.encode(text)) for text in inputs)
-            processed_inputs = [
-                self.tokenizer(
-                    text,
-                    padding="max_length",
-                    truncation=True,
-                    max_length=max_length,
-                    return_tensors="np",
-                )
-                for text in inputs
-            ]
-        else:
-            processed_images = [self._process_image(img) for img in inputs]
-            processed_inputs = [
-                self.feature_extractor(images=image, return_tensors="np")
-                for image in processed_images
-            ]
+        processed_inputs = self._preprocess_inputs(inputs)
 
         input_names = self.runner.get_input_names()
         onnx_inputs = {name: [] for name in input_names}
         input: dict[str, any]
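The visible tail of this hunk (input_names = self.runner.get_input_names(), onnx_inputs = {name: [] for name in input_names}) suggests the per-item dicts returned by _preprocess_inputs are then gathered into one batched array per ONNX input name. A hedged sketch of that gathering step follows, assuming each processed item is a dict of (1, ...) numpy arrays; collect_onnx_inputs and the concatenation strategy are illustrative, not taken from this commit, and input_names stands in for self.runner.get_input_names().

from typing import Any

import numpy as np


def collect_onnx_inputs(
    processed_inputs: list[dict[str, Any]], input_names: list[str]
) -> dict[str, np.ndarray]:
    # One bucket per ONNX input name, mirroring the onnx_inputs dict above.
    buckets: dict[str, list[np.ndarray]] = {name: [] for name in input_names}
    for item in processed_inputs:
        for name in input_names:
            if name in item:
                buckets[name].append(np.asarray(item[name]))
    # Stack the per-item (1, ...) arrays into a single (batch, ...) array,
    # skipping any input name the preprocessor did not produce.
    return {
        name: np.concatenate(arrays, axis=0)
        for name, arrays in buckets.items()
        if arrays
    }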