Mirror of https://github.com/blakeblackshear/frigate.git (synced 2026-02-15 07:35:27 +03:00)
Refactor onnx embeddings to handle multiple inputs by default

commit a28175cd79 (parent 0fc7999780)
@@ -2,7 +2,7 @@ import logging
 import os
 import warnings
 from io import BytesIO
-from typing import Callable, Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import numpy as np
 import requests
@@ -39,7 +39,6 @@ class GenericONNXEmbedding:
         model_name: str,
         model_file: str,
         download_urls: Dict[str, str],
-        embedding_function: Callable[[List[np.ndarray]], np.ndarray],
         model_size: str,
         model_type: str,
         requestor: InterProcessRequestor,
@@ -51,7 +50,6 @@ class GenericONNXEmbedding:
         self.tokenizer_file = tokenizer_file
         self.requestor = requestor
         self.download_urls = download_urls
-        self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
         self.model_size = model_size
         self.device = device
@@ -157,7 +155,6 @@ class GenericONNXEmbedding:
         self, inputs: Union[List[str], List[Image.Image], List[str]]
     ) -> List[np.ndarray]:
         self._load_model_and_tokenizer()
-
         if self.runner is None or (
             self.tokenizer is None and self.feature_extractor is None
         ):
@@ -167,23 +164,27 @@ class GenericONNXEmbedding:
             return []
 
         if self.model_type == "text":
-            processed_inputs = self.tokenizer(
-                inputs, padding=True, truncation=True, return_tensors="np"
-            )
+            processed_inputs = [
+                self.tokenizer(text, padding=True, truncation=True, return_tensors="np")
+                for text in inputs
+            ]
         else:
             processed_images = [self._process_image(img) for img in inputs]
-            processed_inputs = self.feature_extractor(
-                images=processed_images, return_tensors="np"
-            )
+            processed_inputs = [
+                self.feature_extractor(images=image, return_tensors="np")
+                for image in processed_images
+            ]
 
         input_names = self.runner.get_input_names()
-        onnx_inputs = {
-            name: processed_inputs[name]
-            for name in input_names
-            if name in processed_inputs
-        }
+        onnx_inputs = {name: [] for name in input_names}
+        input: dict[str, any]
+        for input in processed_inputs:
+            for key, value in input.items():
+                if key in input_names:
+                    onnx_inputs[key].append(value[0])
 
-        outputs = self.runner.run(onnx_inputs)
-        embeddings = self.embedding_function(outputs)
+        for key in onnx_inputs.keys():
+            onnx_inputs[key] = np.array(onnx_inputs[key])
 
+        embeddings = self.runner.run(onnx_inputs)[0]
         return [embedding for embedding in embeddings]
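For context when reading the hunks above: the refactor drops the external embedding_function, preprocesses each text or image individually, gathers the per-sample arrays under the model's ONNX input names, stacks each group into one numpy batch, and feeds that batch to a single runner.run call, taking its first output as the embeddings. The snippet below is a minimal, self-contained sketch of that batching pattern only; it is not Frigate code, and fake_tokenize and batch_inputs are hypothetical stand-ins for the real tokenizer and ONNX runner plumbing.

from typing import Dict, List

import numpy as np


def fake_tokenize(text: str) -> Dict[str, np.ndarray]:
    # Hypothetical stand-in for a real tokenizer: returns arrays shaped
    # (1, seq_len), so value[0] below drops the leading batch dimension,
    # mirroring the `value[0]` in the diff above.
    chars = [ord(c) % 128 for c in text[:8]]
    ids = np.array([chars + [0] * (8 - len(chars))])
    return {"input_ids": ids, "attention_mask": np.ones_like(ids)}


def batch_inputs(
    processed: List[Dict[str, np.ndarray]], input_names: List[str]
) -> Dict[str, np.ndarray]:
    # Collect per-sample arrays under each ONNX input name, then stack each
    # list into one batched array, as the refactored __call__ does.
    grouped: Dict[str, List[np.ndarray]] = {name: [] for name in input_names}
    for sample in processed:
        for key, value in sample.items():
            if key in input_names:
                grouped[key].append(value[0])
    return {key: np.array(values) for key, values in grouped.items()}


texts = ["person walking", "red car"]
batched = batch_inputs(
    [fake_tokenize(t) for t in texts], ["input_ids", "attention_mask"]
)
print({k: v.shape for k, v in batched.items()})  # {'input_ids': (2, 8), 'attention_mask': (2, 8)}

Note that stacking with np.array assumes every per-sample array shares the same shape; the sketch guarantees this by padding each fake token sequence to a fixed length of 8.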