mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-16 20:12:11 +03:00
Create RKNN model runner and use it for jina v1 clip
This commit is contained in:
parent
2236ecf23f
commit
0a622ca065
@ -8,6 +8,7 @@ import onnxruntime as ort
|
||||
|
||||
from frigate.const import MODEL_CACHE_DIR
|
||||
from frigate.util.model import get_ort_providers
|
||||
from frigate.util.rknn_converter import is_rknn_compatible, auto_convert_model
|
||||
|
||||
try:
|
||||
import openvino as ov
|
||||
@ -25,7 +26,33 @@ class ONNXModelRunner:
|
||||
self.model_path = model_path
|
||||
self.ort: ort.InferenceSession = None
|
||||
self.ov: ov.Core = None
|
||||
providers, options = get_ort_providers(device == "CPU", device, requires_fp16)
|
||||
self.rknn = None
|
||||
self.type = "ort"
|
||||
|
||||
try:
|
||||
if is_rknn_compatible(model_path):
|
||||
# Try to auto-convert to RKNN format
|
||||
rknn_path = auto_convert_model(model_path)
|
||||
if rknn_path:
|
||||
try:
|
||||
self.rknn = RKNNModelRunner(rknn_path, device)
|
||||
self.type = "rknn"
|
||||
logger.info(f"Using RKNN model: {rknn_path}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Failed to load RKNN model, falling back to ONNX: {e}"
|
||||
)
|
||||
self.rknn = None
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Fall back to standard ONNX providers
|
||||
providers, options = get_ort_providers(
|
||||
device == "CPU",
|
||||
device,
|
||||
requires_fp16,
|
||||
)
|
||||
self.interpreter = None
|
||||
|
||||
if "OpenVINOExecutionProvider" in providers:
|
||||
@ -55,7 +82,9 @@ class ONNXModelRunner:
|
||||
)
|
||||
|
||||
def get_input_names(self) -> list[str]:
|
||||
if self.type == "ov":
|
||||
if self.type == "rknn":
|
||||
return self.rknn.get_input_names()
|
||||
elif self.type == "ov":
|
||||
input_names = []
|
||||
|
||||
for input in self.interpreter.inputs:
|
||||
@ -67,7 +96,9 @@ class ONNXModelRunner:
|
||||
|
||||
def get_input_width(self):
|
||||
"""Get the input width of the model regardless of backend."""
|
||||
if self.type == "ort":
|
||||
if self.type == "rknn":
|
||||
return self.rknn.get_input_width()
|
||||
elif self.type == "ort":
|
||||
return self.ort.get_inputs()[0].shape[3]
|
||||
elif self.type == "ov":
|
||||
input_info = self.interpreter.inputs
|
||||
@ -91,7 +122,9 @@ class ONNXModelRunner:
|
||||
return -1
|
||||
|
||||
def run(self, input: dict[str, Any]) -> Any:
|
||||
if self.type == "ov":
|
||||
if self.type == "rknn":
|
||||
return self.rknn.run(input)
|
||||
elif self.type == "ov":
|
||||
infer_request = self.interpreter.create_infer_request()
|
||||
|
||||
try:
|
||||
@ -107,3 +140,99 @@ class ONNXModelRunner:
|
||||
return outputs
|
||||
elif self.type == "ort":
|
||||
return self.ort.run(None, input)
|
||||
|
||||
|
||||
class RKNNModelRunner:
    """Run RKNN models for embeddings.

    The model is loaded and the RKNN Lite runtime is initialized when the
    runner is constructed; the runtime is released when the runner is
    destroyed.
    """

    def __init__(self, model_path: str, device: str = "AUTO", model_type: str = None):
        """Create the runner and immediately load the model.

        Args:
            model_path: Path handed to ``RKNNLite.load_rknn``.
            device: Stored on the instance; not consulted by this class.
            model_type: Optional model identifier (for example
                ``"jina-clip-v1-vision"``) used to choose the input names
                and the input width.

        Raises:
            ImportError: If ``rknnlite`` cannot be imported.
            RuntimeError: If loading the model or initializing the runtime
                reports a non-zero status.
        """
        self.model_path = model_path
        self.device = device
        self.model_type = model_type
        # Assigned before loading so __del__ always finds the attribute,
        # even when _load_model() raises.
        self.rknn = None
        self._load_model()

    def _load_model(self):
        """Load the RKNN model and initialize the runtime.

        Raises:
            ImportError: If ``rknnlite`` is not installed.
            RuntimeError: If ``load_rknn`` or ``init_runtime`` returns a
                non-zero status.
        """
        # Keep the import in its own try so only a genuinely missing
        # package is reported as "not available".
        try:
            from rknnlite.api import RKNNLite
        except ImportError as err:
            logger.error("RKNN Lite not available")
            raise ImportError("RKNN Lite not available") from err

        try:
            self.rknn = RKNNLite(verbose=False)

            # Load the RKNN model
            if self.rknn.load_rknn(self.model_path) != 0:
                logger.error(f"Failed to load RKNN model: {self.model_path}")
                raise RuntimeError("Failed to load RKNN model")

            # Initialize runtime
            if self.rknn.init_runtime() != 0:
                logger.error("Failed to initialize RKNN runtime")
                raise RuntimeError("Failed to initialize RKNN runtime")

            logger.info(f"Successfully loaded RKNN model: {self.model_path}")
        except Exception as e:
            logger.error(f"Error loading RKNN model: {e}")
            raise

    def get_input_names(self) -> list[str]:
        """Return the input names for the model, derived from ``model_type``.

        Returns:
            ``["input_ids"]`` for a jina-clip text model,
            ``["pixel_values"]`` for a jina-clip vision model, and
            ``["input"]`` for anything else (including no model type).
        """
        model_type = self.model_type

        if model_type and "jina-clip" in model_type:
            if "text" in model_type:
                return ["input_ids"]

            if "vision" in model_type:
                return ["pixel_values"]

        # Default fallback
        return ["input"]

    def get_input_width(self) -> int:
        """Return the input width of the model, or -1 when it is unknown."""
        if self.model_type and "jina-clip-v1-vision" in self.model_type:
            return 224  # CLIP V1 uses 224x224

        return -1

    def run(self, inputs: dict[str, Any]) -> Any:
        """Run inference with the RKNN model.

        Args:
            inputs: Mapping of input name to input data. The values are
                passed to the runtime in the order given by
                ``get_input_names()``.

        Returns:
            Whatever ``RKNNLite.inference`` returns.

        Raises:
            RuntimeError: If no model is loaded.
            Exception: Any error raised during inference is logged and
                re-raised unchanged.
        """
        if self.rknn is None:
            raise RuntimeError("RKNN model not loaded")

        try:
            # Convert inputs to the positional list expected by RKNN
            rknn_inputs = []

            for name in self.get_input_names():
                if name in inputs:
                    rknn_inputs.append(inputs[name])
                    continue

                logger.warning(f"Input '{name}' not found in inputs")
                # The name is known to be absent, so substitute a minimal
                # placeholder: 4-D for pixel values, 2-D for everything else.
                rknn_inputs.append([[[[0]]]] if name == "pixel_values" else [[0]])

            return self.rknn.inference(inputs=rknn_inputs)
        except Exception as e:
            logger.error(f"Error during RKNN inference: {e}")
            raise

    def __del__(self):
        """Release the RKNN runtime when the runner is destroyed."""
        # getattr guards against instances whose __init__ never ran far
        # enough to assign the attribute.
        rknn = getattr(self, "rknn", None)

        if rknn is None:
            return

        # Drop the reference first so a repeated call cannot release twice.
        self.rknn = None

        try:
            rknn.release()
        except Exception:
            # Best-effort cleanup: a finalizer must never propagate errors,
            # and logging may already be torn down at interpreter exit.
            pass
|
||||
|
||||
@ -27,9 +27,50 @@ MODEL_TYPE_CONFIGS = {
|
||||
"std_values": [[255, 255, 255]],
|
||||
"target_platform": None, # Will be set dynamically
|
||||
},
|
||||
"jina-clip-v1-vision": {
|
||||
"mean_values": [
|
||||
[0.48145466, 0.4578275, 0.40821073]
|
||||
], # CLIP standard normalization
|
||||
"std_values": [
|
||||
[0.26862954, 0.26130258, 0.27577711]
|
||||
], # CLIP standard normalization
|
||||
"target_platform": None, # Will be set dynamically
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def is_rknn_compatible(model_path: str, model_type: str | None = None) -> bool:
    """
    Decide whether a model can be converted to RKNN.

    The RKNN toolkit must be available and the SoC type must be known. If
    so, the model qualifies when its type has an entry in
    MODEL_TYPE_CONFIGS, or when its file name contains one of the known
    model-family keywords.

    Args:
        model_path: Path to the model file
        model_type: Type of the model (if known)

    Returns:
        True if the model is RKNN-compatible, False otherwise
    """
    # Short-circuit keeps the original order: the SoC is only queried once
    # the toolkit is confirmed to be available.
    if not ensure_rknn_toolkit() or get_soc_type() is None:
        return False

    if model_type and model_type in MODEL_TYPE_CONFIGS:
        return True

    # Fall back to matching on the lower-cased file name.
    known_keywords = ("jina", "clip", "vision", "yolo", "yolox", "yolonas")
    file_name = os.path.basename(model_path).lower()

    return any(keyword in file_name for keyword in known_keywords)
|
||||
|
||||
|
||||
def ensure_torch_dependencies() -> bool:
|
||||
"""Dynamically install torch dependencies if not available."""
|
||||
try:
|
||||
@ -109,11 +150,11 @@ def convert_onnx_to_rknn(
|
||||
True if conversion successful, False otherwise
|
||||
"""
|
||||
if not ensure_torch_dependencies():
|
||||
logger.error("PyTorch dependencies not available")
|
||||
logger.debug("PyTorch dependencies not available")
|
||||
return False
|
||||
|
||||
if not ensure_rknn_toolkit():
|
||||
logger.error("RKNN toolkit not available")
|
||||
logger.debug("RKNN toolkit not available")
|
||||
return False
|
||||
|
||||
# Get SoC type if not provided
|
||||
@ -125,7 +166,7 @@ def convert_onnx_to_rknn(
|
||||
|
||||
# Get model config for the specified type
|
||||
if model_type not in MODEL_TYPE_CONFIGS:
|
||||
logger.error(f"Unsupported model type: {model_type}")
|
||||
logger.debug(f"Unsupported model type: {model_type}")
|
||||
return False
|
||||
|
||||
config = MODEL_TYPE_CONFIGS[model_type].copy()
|
||||
@ -265,7 +306,7 @@ def is_lock_stale(lock_file_path: Path, max_age: int = 600) -> bool:
|
||||
|
||||
|
||||
def wait_for_conversion_completion(
|
||||
rknn_path: Path, lock_file_path: Path, timeout: int = 300
|
||||
model_type: str, rknn_path: Path, lock_file_path: Path, timeout: int = 300
|
||||
) -> bool:
|
||||
"""
|
||||
Wait for another process to complete the conversion.
|
||||
@ -320,7 +361,7 @@ def wait_for_conversion_completion(
|
||||
|
||||
if onnx_path.exists():
|
||||
if convert_onnx_to_rknn(
|
||||
str(onnx_path), str(rknn_path), "yolo-generic", False
|
||||
str(onnx_path), str(rknn_path), model_type, False
|
||||
):
|
||||
return str(rknn_path)
|
||||
|
||||
@ -392,7 +433,7 @@ def auto_convert_model(
|
||||
f"Another process is converting {model_path}, waiting for completion..."
|
||||
)
|
||||
|
||||
if wait_for_conversion_completion(rknn_path, lock_file_path):
|
||||
if wait_for_conversion_completion(model_type, rknn_path, lock_file_path):
|
||||
return str(rknn_path)
|
||||
else:
|
||||
logger.error(f"Timeout waiting for conversion of {model_path}")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user