diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py
index 574822d59..c669bcf73 100644
--- a/frigate/embeddings/functions/onnx.py
+++ b/frigate/embeddings/functions/onnx.py
@@ -143,6 +143,28 @@ class GenericONNXEmbedding:
             f"{MODEL_CACHE_DIR}/{self.model_name}",
         )
 
+    def _preprocess_inputs(self, raw_inputs: any) -> any:
+        if self.model_type == "text":
+            max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs)
+            return [
+                self.tokenizer(
+                    text,
+                    padding="max_length",
+                    truncation=True,
+                    max_length=max_length,
+                    return_tensors="np",
+                )
+                for text in raw_inputs
+            ]
+        elif self.model_type == "image":
+            processed_images = [self._process_image(img) for img in raw_inputs]
+            return [
+                self.feature_extractor(images=image, return_tensors="np")
+                for image in processed_images
+            ]
+        else:
+            raise ValueError(f"Unable to preprocess inputs for {self.model_type}")
+
     def _process_image(self, image):
         if isinstance(image, str):
             if image.startswith("http"):
@@ -163,25 +185,7 @@ class GenericONNXEmbedding:
             )
             return []
 
-        if self.model_type == "text":
-            max_length = max(len(self.tokenizer.encode(text)) for text in inputs)
-            processed_inputs = [
-                self.tokenizer(
-                    text,
-                    padding="max_length",
-                    truncation=True,
-                    max_length=max_length,
-                    return_tensors="np",
-                )
-                for text in inputs
-            ]
-        else:
-            processed_images = [self._process_image(img) for img in inputs]
-            processed_inputs = [
-                self.feature_extractor(images=image, return_tensors="np")
-                for image in processed_images
-            ]
-
+        processed_inputs = self._preprocess_inputs(inputs)
         input_names = self.runner.get_input_names()
         onnx_inputs = {name: [] for name in input_names}
         input: dict[str, any]
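
For reviewers, two behavioral points are worth noting. First, the old code used a bare `else`, so any non-text model type was silently preprocessed as an image; the extracted `_preprocess_inputs` makes the dispatch explicit and raises `ValueError` for unknown types. Second, the padding strategy is easy to misread: `padding="max_length"` combined with a `max_length` computed from the batch pads each text to the longest sequence in that batch, not to the model's maximum. A minimal, self-contained sketch of both points follows; `Embedder` and `StubTokenizer` are hypothetical stand-ins for illustration, not Frigate's real tokenizer or feature extractor:

```python
class StubTokenizer:
    """Whitespace 'tokenizer' that mimics the pad-to-batch-max behavior."""

    def encode(self, text: str) -> list[str]:
        return text.split()

    def __call__(self, text: str, max_length: int) -> dict:
        tokens = self.encode(text)[:max_length]
        # padding="max_length" in the real call pads every sequence to
        # max_length, which the caller sets to the longest text in the batch.
        return {"input_ids": tokens + ["<pad>"] * (max_length - len(tokens))}


class Embedder:
    # Hypothetical stand-in for GenericONNXEmbedding's dispatch logic.
    def __init__(self, model_type: str):
        self.model_type = model_type
        self.tokenizer = StubTokenizer()

    def _preprocess_inputs(self, raw_inputs):
        if self.model_type == "text":
            # Pad to the longest sequence in this batch.
            max_length = max(len(self.tokenizer.encode(t)) for t in raw_inputs)
            return [self.tokenizer(t, max_length=max_length) for t in raw_inputs]
        elif self.model_type == "image":
            return [{"pixel_values": img} for img in raw_inputs]
        else:
            # New behavior: unknown model types fail loudly instead of being
            # silently funneled into the image branch by a bare `else`.
            raise ValueError(f"Unable to preprocess inputs for {self.model_type}")


if __name__ == "__main__":
    out = Embedder("text")._preprocess_inputs(["a b c", "d"])
    print(out)  # second entry is padded to the batch max of 3 tokens
    try:
        Embedder("audio")._preprocess_inputs([b"\x00"])
    except ValueError as err:
        print(err)  # Unable to preprocess inputs for audio
```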