Formatting

This commit is contained in:
Nicolas Mowen 2025-08-20 16:15:18 -06:00
parent 0a622ca065
commit 4ac77d1257

View File

@ -159,12 +159,10 @@ class RKNNModelRunner:
self.rknn = RKNNLite(verbose=False)
# Load the RKNN model
if self.rknn.load_rknn(self.model_path) != 0:
logger.error(f"Failed to load RKNN model: {self.model_path}")
raise RuntimeError("Failed to load RKNN model")
# Initialize runtime
if self.rknn.init_runtime() != 0:
logger.error("Failed to initialize RKNN runtime")
raise RuntimeError("Failed to initialize RKNN runtime")
@ -180,23 +178,18 @@ class RKNNModelRunner:
def get_input_names(self) -> list[str]:
"""Get input names for the model."""
# RKNN models typically have standard input names
# For CLIP models, these are usually "input_ids" and "pixel_values"
if self.model_type and "jina-clip" in self.model_type:
if "text" in self.model_type:
return ["input_ids"]
elif "vision" in self.model_type:
return ["pixel_values"]
# Default fallback
return ["input"]
def get_input_width(self) -> int:
"""Get the input width of the model."""
# For CLIP vision models, this is typically 224 or 512
if self.model_type and "jina-clip-v1-vision" in self.model_type:
return 224 # CLIP V1 uses 224x224
return -1
return 224 # CLIP V1 uses 224x224
def run(self, inputs: dict[str, Any]) -> Any:
"""Run inference with the RKNN model."""
@ -204,7 +197,6 @@ class RKNNModelRunner:
raise RuntimeError("RKNN model not loaded")
try:
# Convert inputs to the format expected by RKNN
rknn_inputs = []
input_names = self.get_input_names()
@ -213,7 +205,6 @@ class RKNNModelRunner:
rknn_inputs.append(inputs[name])
else:
logger.warning(f"Input '{name}' not found in inputs")
# Create a dummy input with appropriate shape
if name == "input_ids":
rknn_inputs.append(inputs.get("input_ids", [[0]]))
elif name == "pixel_values":
@ -221,7 +212,6 @@ class RKNNModelRunner:
else:
rknn_inputs.append([[0]])
# Run inference
outputs = self.rknn.inference(inputs=rknn_inputs)
return outputs