mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-16 20:12:11 +03:00
Properly provide input to RKNN
This commit is contained in:
parent
2283b32575
commit
b32e8681a4
@ -4,6 +4,7 @@ import logging
|
||||
import os.path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from frigate.const import MODEL_CACHE_DIR
|
||||
@ -121,7 +122,7 @@ class ONNXModelRunner:
|
||||
return -1
|
||||
return -1
|
||||
|
||||
def run(self, input: dict[str, Any]) -> Any:
|
||||
def run(self, input: dict[str, Any]) -> Any | None:
|
||||
if self.type == "rknn":
|
||||
return self.rknn.run(input)
|
||||
elif self.type == "ov":
|
||||
@ -178,18 +179,27 @@ class RKNNModelRunner:
|
||||
|
||||
def get_input_names(self) -> list[str]:
|
||||
"""Get input names for the model."""
|
||||
|
||||
if self.model_type and "jina-clip" in self.model_type:
|
||||
if "text" in self.model_type:
|
||||
return ["input_ids"]
|
||||
elif "vision" in self.model_type:
|
||||
return ["pixel_values"]
|
||||
|
||||
return ["input"]
|
||||
# For CLIP models, we need to determine the model type from the path
|
||||
model_name = os.path.basename(self.model_path).lower()
|
||||
|
||||
if "vision" in model_name:
|
||||
return ["pixel_values"]
|
||||
else:
|
||||
# Default fallback - try to infer from model type
|
||||
if self.model_type and "jina-clip" in self.model_type:
|
||||
if "vision" in self.model_type:
|
||||
return ["pixel_values"]
|
||||
|
||||
# Generic fallback
|
||||
return ["input"]
|
||||
|
||||
def get_input_width(self) -> int:
|
||||
"""Get the input width of the model."""
|
||||
return 224 # CLIP V1 uses 224x224
|
||||
# For CLIP vision models, this is typically 224
|
||||
model_name = os.path.basename(self.model_path).lower()
|
||||
if "vision" in model_name:
|
||||
return 224 # CLIP V1 uses 224x224
|
||||
return -1
|
||||
|
||||
def run(self, inputs: dict[str, Any]) -> Any:
|
||||
"""Run inference with the RKNN model."""
|
||||
@ -197,20 +207,41 @@ class RKNNModelRunner:
|
||||
raise RuntimeError("RKNN model not loaded")
|
||||
|
||||
try:
|
||||
rknn_inputs = []
|
||||
input_names = self.get_input_names()
|
||||
|
||||
rknn_inputs = []
|
||||
|
||||
for name in input_names:
|
||||
if name in inputs:
|
||||
rknn_inputs.append(inputs[name])
|
||||
else:
|
||||
logger.warning(f"Input '{name}' not found in inputs")
|
||||
if name == "input_ids":
|
||||
rknn_inputs.append(inputs.get("input_ids", [[0]]))
|
||||
elif name == "pixel_values":
|
||||
rknn_inputs.append(inputs.get("pixel_values", [[[[0]]]]))
|
||||
if name == "pixel_values":
|
||||
# RKNN expects NHWC format, but ONNX typically provides NCHW
|
||||
# Transpose from [batch, channels, height, width] to [batch, height, width, channels]
|
||||
pixel_data = inputs[name]
|
||||
if len(pixel_data.shape) == 4 and pixel_data.shape[1] == 3:
|
||||
# Transpose from NCHW to NHWC
|
||||
pixel_data = np.transpose(pixel_data, (0, 2, 3, 1))
|
||||
rknn_inputs.append(pixel_data)
|
||||
else:
|
||||
rknn_inputs.append([[0]])
|
||||
rknn_inputs.append(inputs[name])
|
||||
else:
|
||||
logger.warning(f"Input '{name}' not found in inputs, using default")
|
||||
|
||||
if name == "pixel_values":
|
||||
batch_size = 1
|
||||
if inputs:
|
||||
for val in inputs.values():
|
||||
if hasattr(val, 'shape') and len(val.shape) > 0:
|
||||
batch_size = val.shape[0]
|
||||
break
|
||||
# Create default in NHWC format as expected by RKNN
|
||||
rknn_inputs.append(np.zeros((batch_size, 224, 224, 3), dtype=np.float32))
|
||||
else:
|
||||
batch_size = 1
|
||||
if inputs:
|
||||
for val in inputs.values():
|
||||
if hasattr(val, 'shape') and len(val.shape) > 0:
|
||||
batch_size = val.shape[0]
|
||||
break
|
||||
rknn_inputs.append(np.zeros((batch_size, 1), dtype=np.float32))
|
||||
|
||||
outputs = self.rknn.inference(inputs=rknn_inputs)
|
||||
return outputs
|
||||
|
||||
Loading…
Reference in New Issue
Block a user