diff --git a/frigate/embeddings/onnx/runner.py b/frigate/embeddings/onnx/runner.py index 34db4c7fc..6eea44c86 100644 --- a/frigate/embeddings/onnx/runner.py +++ b/frigate/embeddings/onnx/runner.py @@ -159,12 +159,10 @@ class RKNNModelRunner: self.rknn = RKNNLite(verbose=False) - # Load the RKNN model if self.rknn.load_rknn(self.model_path) != 0: logger.error(f"Failed to load RKNN model: {self.model_path}") raise RuntimeError("Failed to load RKNN model") - # Initialize runtime if self.rknn.init_runtime() != 0: logger.error("Failed to initialize RKNN runtime") raise RuntimeError("Failed to initialize RKNN runtime") @@ -180,23 +178,18 @@ class RKNNModelRunner: def get_input_names(self) -> list[str]: """Get input names for the model.""" - # RKNN models typically have standard input names - # For CLIP models, these are usually "input_ids" and "pixel_values" + if self.model_type and "jina-clip" in self.model_type: if "text" in self.model_type: return ["input_ids"] elif "vision" in self.model_type: return ["pixel_values"] - # Default fallback return ["input"] def get_input_width(self) -> int: """Get the input width of the model.""" - # For CLIP vision models, this is typically 224 or 512 - if self.model_type and "jina-clip-v1-vision" in self.model_type: - return 224 # CLIP V1 uses 224x224 - return -1 + return 224 # CLIP V1 uses 224x224 def run(self, inputs: dict[str, Any]) -> Any: """Run inference with the RKNN model.""" @@ -204,7 +197,6 @@ class RKNNModelRunner: raise RuntimeError("RKNN model not loaded") try: - # Convert inputs to the format expected by RKNN rknn_inputs = [] input_names = self.get_input_names() @@ -213,7 +205,6 @@ class RKNNModelRunner: rknn_inputs.append(inputs[name]) else: logger.warning(f"Input '{name}' not found in inputs") - # Create a dummy input with appropriate shape if name == "input_ids": rknn_inputs.append(inputs.get("input_ids", [[0]])) elif name == "pixel_values": @@ -221,7 +212,6 @@ class RKNNModelRunner: else: rknn_inputs.append([[0]]) - # Run inference outputs = self.rknn.inference(inputs=rknn_inputs) return outputs