From 2283b325751de65922a467ba8788bac7865b2355 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Wed, 20 Aug 2025 16:40:28 -0600 Subject: [PATCH] Handle model type inference --- frigate/embeddings/onnx/runner.py | 2 +- frigate/util/rknn_converter.py | 28 +++++++++++++++++----------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/frigate/embeddings/onnx/runner.py b/frigate/embeddings/onnx/runner.py index 6eea44c86..0340ba1fe 100644 --- a/frigate/embeddings/onnx/runner.py +++ b/frigate/embeddings/onnx/runner.py @@ -30,7 +30,7 @@ class ONNXModelRunner: self.type = "ort" try: - if is_rknn_compatible(model_path): + if device != "CPU" and is_rknn_compatible(model_path): # Try to auto-convert to RKNN format rknn_path = auto_convert_model(model_path) if rknn_path: diff --git a/frigate/util/rknn_converter.py b/frigate/util/rknn_converter.py index 25801f920..5b4c61c18 100644 --- a/frigate/util/rknn_converter.py +++ b/frigate/util/rknn_converter.py @@ -38,6 +38,16 @@ MODEL_TYPE_CONFIGS = { }, } +def get_rknn_model_type(model_path: str) -> str | None: + if all(keyword in model_path for keyword in ["jina-clip-v1", "vision"]): + return "jina-clip-v1-vision" + + model_name = os.path.basename(model_path).lower() + + if any(keyword in model_name for keyword in ["yolo", "yolox", "yolonas"]): + return model_name + + return None def is_rknn_compatible(model_path: str, model_type: str | None = None) -> bool: """ @@ -57,17 +67,12 @@ def is_rknn_compatible(model_path: str, model_type: str | None = None) -> bool: if soc is None: return False + if not model_type: + model_type = get_rknn_model_type(model_path) + if model_type and model_type in MODEL_TYPE_CONFIGS: return True - model_name = os.path.basename(model_path).lower() - - if any(keyword in model_name for keyword in ["jina", "clip", "vision"]): - return True - - if any(keyword in model_name for keyword in ["yolo", "yolox", "yolonas"]): - return True - return False @@ -108,9 +113,7 @@ def ensure_torch_dependencies() -> bool: def ensure_rknn_toolkit() -> bool: """Ensure RKNN toolkit is available.""" try: - import rknn # type: ignore # noqa: F401 from rknn.api import RKNN # type: ignore # noqa: F401 - logger.debug("RKNN toolkit is already available") return True except ImportError: @@ -379,7 +382,7 @@ def wait_for_conversion_completion( def auto_convert_model( - model_path: str, model_type: str, quantization: bool = False + model_path: str, model_type: str | None = None, quantization: bool = False ) -> Optional[str]: """ Automatically convert a model to RKNN format if needed. @@ -418,6 +421,9 @@ def auto_convert_model( logger.info(f"Converting {model_path} to RKNN format...") rknn_path.parent.mkdir(parents=True, exist_ok=True) + if not model_type: + model_type = get_rknn_model_type(base_path) + if convert_onnx_to_rknn( str(base_path), str(rknn_path), model_type, quantization ):