diff --git a/frigate/detectors/detection_runners.py b/frigate/detectors/detection_runners.py index 0b4f319a8..9dadb16fa 100644 --- a/frigate/detectors/detection_runners.py +++ b/frigate/detectors/detection_runners.py @@ -21,21 +21,25 @@ def is_arm64_platform() -> bool: return machine in ("aarch64", "arm64", "armv8", "armv7l") -def get_ort_session_options() -> ort.SessionOptions | None: +def get_ort_session_options( + is_complex_model: bool = False, +) -> ort.SessionOptions | None: """Get ONNX Runtime session options with appropriate settings. - On ARM/RKNN platforms, use basic optimizations to avoid graph fusion issues - that can break certain models. On amd64, use default optimizations for better performance. - """ - sess_options = None + Args: + is_complex_model: Whether the model needs basic optimization to avoid graph fusion issues. - if is_arm64_platform(): + Returns: + SessionOptions with appropriate optimization level, or None for default settings. + """ + if is_complex_model: sess_options = ort.SessionOptions() sess_options.graph_optimization_level = ( ort.GraphOptimizationLevel.ORT_ENABLE_BASIC ) + return sess_options - return sess_options + return None # Import OpenVINO only when needed to avoid circular dependencies @@ -103,6 +107,21 @@ class BaseModelRunner(ABC): class ONNXModelRunner(BaseModelRunner): """Run ONNX models using ONNX Runtime.""" + @staticmethod + def is_cpu_complex_model(model_type: str) -> bool: + """Check if model needs basic optimization level to avoid graph fusion issues. + + Some models (like Jina-CLIP) have issues with aggressive optimizations like + SimplifiedLayerNormFusion that create or expect nodes that don't exist. + """ + # Import here to avoid circular imports + from frigate.embeddings.types import EnrichmentModelTypeEnum + + return model_type in [ + EnrichmentModelTypeEnum.jina_v1.value, + EnrichmentModelTypeEnum.jina_v2.value, + ] + @staticmethod def is_migraphx_complex_model(model_type: str) -> bool: # Import here to avoid circular imports @@ -496,7 +515,9 @@ def get_optimized_runner( return ONNXModelRunner( ort.InferenceSession( model_path, - sess_options=get_ort_session_options(), + sess_options=get_ort_session_options( + ONNXModelRunner.is_cpu_complex_model(model_type) + ), providers=providers, provider_options=options, )