Refactor ONNX embeddings to handle multiple inputs by default

This commit is contained in:
Nicolas Mowen 2024-10-13 10:26:27 -06:00
parent 0fc7999780
commit a28175cd79

View File

@@ -2,7 +2,7 @@ import logging
import os
import warnings
from io import BytesIO
from typing import Callable, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union
import numpy as np
import requests
@@ -39,7 +39,6 @@ class GenericONNXEmbedding:
model_name: str,
model_file: str,
download_urls: Dict[str, str],
embedding_function: Callable[[List[np.ndarray]], np.ndarray],
model_size: str,
model_type: str,
requestor: InterProcessRequestor,
@@ -51,7 +50,6 @@ class GenericONNXEmbedding:
self.tokenizer_file = tokenizer_file
self.requestor = requestor
self.download_urls = download_urls
self.embedding_function = embedding_function
self.model_type = model_type # 'text' or 'vision'
self.model_size = model_size
self.device = device
@@ -157,7 +155,6 @@ class GenericONNXEmbedding:
self, inputs: Union[List[str], List[Image.Image], List[str]]
) -> List[np.ndarray]:
self._load_model_and_tokenizer()
if self.runner is None or (
self.tokenizer is None and self.feature_extractor is None
):
@@ -167,23 +164,27 @@ class GenericONNXEmbedding:
return []
if self.model_type == "text":
processed_inputs = self.tokenizer(
inputs, padding=True, truncation=True, return_tensors="np"
)
processed_inputs = [
self.tokenizer(text, padding=True, truncation=True, return_tensors="np")
for text in inputs
]
else:
processed_images = [self._process_image(img) for img in inputs]
processed_inputs = self.feature_extractor(
images=processed_images, return_tensors="np"
)
processed_inputs = [
self.feature_extractor(images=image, return_tensors="np")
for image in processed_images
]
input_names = self.runner.get_input_names()
onnx_inputs = {
name: processed_inputs[name]
for name in input_names
if name in processed_inputs
}
onnx_inputs = {name: [] for name in input_names}
input: dict[str, any]
for input in processed_inputs:
for key, value in input.items():
if key in input_names:
onnx_inputs[key].append(value[0])
outputs = self.runner.run(onnx_inputs)
embeddings = self.embedding_function(outputs)
for key in onnx_inputs.keys():
onnx_inputs[key] = np.array(onnx_inputs[key])
embeddings = self.runner.run(onnx_inputs)[0]
return [embedding for embedding in embeddings]