mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-25 16:17:41 +03:00
Formatting
This commit is contained in:
parent
0a622ca065
commit
4ac77d1257
@ -159,12 +159,10 @@ class RKNNModelRunner:
|
|||||||
|
|
||||||
self.rknn = RKNNLite(verbose=False)
|
self.rknn = RKNNLite(verbose=False)
|
||||||
|
|
||||||
# Load the RKNN model
|
|
||||||
if self.rknn.load_rknn(self.model_path) != 0:
|
if self.rknn.load_rknn(self.model_path) != 0:
|
||||||
logger.error(f"Failed to load RKNN model: {self.model_path}")
|
logger.error(f"Failed to load RKNN model: {self.model_path}")
|
||||||
raise RuntimeError("Failed to load RKNN model")
|
raise RuntimeError("Failed to load RKNN model")
|
||||||
|
|
||||||
# Initialize runtime
|
|
||||||
if self.rknn.init_runtime() != 0:
|
if self.rknn.init_runtime() != 0:
|
||||||
logger.error("Failed to initialize RKNN runtime")
|
logger.error("Failed to initialize RKNN runtime")
|
||||||
raise RuntimeError("Failed to initialize RKNN runtime")
|
raise RuntimeError("Failed to initialize RKNN runtime")
|
||||||
@ -180,23 +178,18 @@ class RKNNModelRunner:
|
|||||||
|
|
||||||
def get_input_names(self) -> list[str]:
|
def get_input_names(self) -> list[str]:
|
||||||
"""Get input names for the model."""
|
"""Get input names for the model."""
|
||||||
# RKNN models typically have standard input names
|
|
||||||
# For CLIP models, these are usually "input_ids" and "pixel_values"
|
|
||||||
if self.model_type and "jina-clip" in self.model_type:
|
if self.model_type and "jina-clip" in self.model_type:
|
||||||
if "text" in self.model_type:
|
if "text" in self.model_type:
|
||||||
return ["input_ids"]
|
return ["input_ids"]
|
||||||
elif "vision" in self.model_type:
|
elif "vision" in self.model_type:
|
||||||
return ["pixel_values"]
|
return ["pixel_values"]
|
||||||
|
|
||||||
# Default fallback
|
|
||||||
return ["input"]
|
return ["input"]
|
||||||
|
|
||||||
def get_input_width(self) -> int:
|
def get_input_width(self) -> int:
|
||||||
"""Get the input width of the model."""
|
"""Get the input width of the model."""
|
||||||
# For CLIP vision models, this is typically 224 or 512
|
return 224 # CLIP V1 uses 224x224
|
||||||
if self.model_type and "jina-clip-v1-vision" in self.model_type:
|
|
||||||
return 224 # CLIP V1 uses 224x224
|
|
||||||
return -1
|
|
||||||
|
|
||||||
def run(self, inputs: dict[str, Any]) -> Any:
|
def run(self, inputs: dict[str, Any]) -> Any:
|
||||||
"""Run inference with the RKNN model."""
|
"""Run inference with the RKNN model."""
|
||||||
@ -204,7 +197,6 @@ class RKNNModelRunner:
|
|||||||
raise RuntimeError("RKNN model not loaded")
|
raise RuntimeError("RKNN model not loaded")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Convert inputs to the format expected by RKNN
|
|
||||||
rknn_inputs = []
|
rknn_inputs = []
|
||||||
input_names = self.get_input_names()
|
input_names = self.get_input_names()
|
||||||
|
|
||||||
@ -213,7 +205,6 @@ class RKNNModelRunner:
|
|||||||
rknn_inputs.append(inputs[name])
|
rknn_inputs.append(inputs[name])
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Input '{name}' not found in inputs")
|
logger.warning(f"Input '{name}' not found in inputs")
|
||||||
# Create a dummy input with appropriate shape
|
|
||||||
if name == "input_ids":
|
if name == "input_ids":
|
||||||
rknn_inputs.append(inputs.get("input_ids", [[0]]))
|
rknn_inputs.append(inputs.get("input_ids", [[0]]))
|
||||||
elif name == "pixel_values":
|
elif name == "pixel_values":
|
||||||
@ -221,7 +212,6 @@ class RKNNModelRunner:
|
|||||||
else:
|
else:
|
||||||
rknn_inputs.append([[0]])
|
rknn_inputs.append([[0]])
|
||||||
|
|
||||||
# Run inference
|
|
||||||
outputs = self.rknn.inference(inputs=rknn_inputs)
|
outputs = self.rknn.inference(inputs=rknn_inputs)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user