mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 15:45:27 +03:00
Refactor preprocessing of images
This commit is contained in:
parent
b1285a16c1
commit
dacdc1e0fe
@ -143,6 +143,28 @@ class GenericONNXEmbedding:
|
|||||||
f"{MODEL_CACHE_DIR}/{self.model_name}",
|
f"{MODEL_CACHE_DIR}/{self.model_name}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _preprocess_inputs(self, raw_inputs: any) -> any:
|
||||||
|
if self.model_type == "text":
|
||||||
|
max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs)
|
||||||
|
return [
|
||||||
|
self.tokenizer(
|
||||||
|
text,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_length,
|
||||||
|
return_tensors="np",
|
||||||
|
)
|
||||||
|
for text in raw_inputs
|
||||||
|
]
|
||||||
|
elif self.model_type == "image":
|
||||||
|
processed_images = [self._process_image(img) for img in raw_inputs]
|
||||||
|
return [
|
||||||
|
self.feature_extractor(images=image, return_tensors="np")
|
||||||
|
for image in processed_images
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unable to preprocess inputs for {self.model_type}")
|
||||||
|
|
||||||
def _process_image(self, image):
|
def _process_image(self, image):
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
if image.startswith("http"):
|
if image.startswith("http"):
|
||||||
@ -163,25 +185,7 @@ class GenericONNXEmbedding:
|
|||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if self.model_type == "text":
|
processed_inputs = self._preprocess_inputs(inputs)
|
||||||
max_length = max(len(self.tokenizer.encode(text)) for text in inputs)
|
|
||||||
processed_inputs = [
|
|
||||||
self.tokenizer(
|
|
||||||
text,
|
|
||||||
padding="max_length",
|
|
||||||
truncation=True,
|
|
||||||
max_length=max_length,
|
|
||||||
return_tensors="np",
|
|
||||||
)
|
|
||||||
for text in inputs
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
processed_images = [self._process_image(img) for img in inputs]
|
|
||||||
processed_inputs = [
|
|
||||||
self.feature_extractor(images=image, return_tensors="np")
|
|
||||||
for image in processed_images
|
|
||||||
]
|
|
||||||
|
|
||||||
input_names = self.runner.get_input_names()
|
input_names = self.runner.get_input_names()
|
||||||
onnx_inputs = {name: [] for name in input_names}
|
onnx_inputs = {name: [] for name in input_names}
|
||||||
input: dict[str, any]
|
input: dict[str, any]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user