Formatting

This commit is contained in:
Nicolas Mowen 2025-08-20 16:15:18 -06:00
parent 0a622ca065
commit 4ac77d1257

View File

@ -159,12 +159,10 @@ class RKNNModelRunner:
self.rknn = RKNNLite(verbose=False) self.rknn = RKNNLite(verbose=False)
# Load the RKNN model
if self.rknn.load_rknn(self.model_path) != 0: if self.rknn.load_rknn(self.model_path) != 0:
logger.error(f"Failed to load RKNN model: {self.model_path}") logger.error(f"Failed to load RKNN model: {self.model_path}")
raise RuntimeError("Failed to load RKNN model") raise RuntimeError("Failed to load RKNN model")
# Initialize runtime
if self.rknn.init_runtime() != 0: if self.rknn.init_runtime() != 0:
logger.error("Failed to initialize RKNN runtime") logger.error("Failed to initialize RKNN runtime")
raise RuntimeError("Failed to initialize RKNN runtime") raise RuntimeError("Failed to initialize RKNN runtime")
@ -180,23 +178,18 @@ class RKNNModelRunner:
def get_input_names(self) -> list[str]: def get_input_names(self) -> list[str]:
"""Get input names for the model.""" """Get input names for the model."""
# RKNN models typically have standard input names
# For CLIP models, these are usually "input_ids" and "pixel_values"
if self.model_type and "jina-clip" in self.model_type: if self.model_type and "jina-clip" in self.model_type:
if "text" in self.model_type: if "text" in self.model_type:
return ["input_ids"] return ["input_ids"]
elif "vision" in self.model_type: elif "vision" in self.model_type:
return ["pixel_values"] return ["pixel_values"]
# Default fallback
return ["input"] return ["input"]
def get_input_width(self) -> int: def get_input_width(self) -> int:
"""Get the input width of the model.""" """Get the input width of the model."""
# For CLIP vision models, this is typically 224 or 512 return 224 # CLIP V1 uses 224x224
if self.model_type and "jina-clip-v1-vision" in self.model_type:
return 224 # CLIP V1 uses 224x224
return -1
def run(self, inputs: dict[str, Any]) -> Any: def run(self, inputs: dict[str, Any]) -> Any:
"""Run inference with the RKNN model.""" """Run inference with the RKNN model."""
@ -204,7 +197,6 @@ class RKNNModelRunner:
raise RuntimeError("RKNN model not loaded") raise RuntimeError("RKNN model not loaded")
try: try:
# Convert inputs to the format expected by RKNN
rknn_inputs = [] rknn_inputs = []
input_names = self.get_input_names() input_names = self.get_input_names()
@ -213,7 +205,6 @@ class RKNNModelRunner:
rknn_inputs.append(inputs[name]) rknn_inputs.append(inputs[name])
else: else:
logger.warning(f"Input '{name}' not found in inputs") logger.warning(f"Input '{name}' not found in inputs")
# Create a dummy input with appropriate shape
if name == "input_ids": if name == "input_ids":
rknn_inputs.append(inputs.get("input_ids", [[0]])) rknn_inputs.append(inputs.get("input_ids", [[0]]))
elif name == "pixel_values": elif name == "pixel_values":
@ -221,7 +212,6 @@ class RKNNModelRunner:
else: else:
rknn_inputs.append([[0]]) rknn_inputs.append([[0]])
# Run inference
outputs = self.rknn.inference(inputs=rknn_inputs) outputs = self.rknn.inference(inputs=rknn_inputs)
return outputs return outputs