Fix model complexity causing crash

This commit is contained in:
Nicolas Mowen 2025-10-29 08:11:53 -06:00
parent 4405070b50
commit baf6df5593

View File

@ -21,22 +21,26 @@ def is_arm64_platform() -> bool:
return machine in ("aarch64", "arm64", "armv8", "armv7l") return machine in ("aarch64", "arm64", "armv8", "armv7l")
def get_ort_session_options(
    is_complex_model: bool = False,
) -> ort.SessionOptions | None:
    """Build the ONNX Runtime session options for a model.

    Args:
        is_complex_model: True when the model needs the basic optimization
            level to avoid graph fusion issues.

    Returns:
        A SessionOptions object pinned to ORT_ENABLE_BASIC when
        ``is_complex_model`` is set, otherwise None so that the default
        settings are used.
    """
    # Guard clause: ordinary models get no custom options at all.
    if not is_complex_model:
        return None

    basic_options = ort.SessionOptions()
    basic_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    return basic_options
# Import OpenVINO only when needed to avoid circular dependencies # Import OpenVINO only when needed to avoid circular dependencies
try: try:
@ -103,6 +107,21 @@ class BaseModelRunner(ABC):
class ONNXModelRunner(BaseModelRunner): class ONNXModelRunner(BaseModelRunner):
"""Run ONNX models using ONNX Runtime.""" """Run ONNX models using ONNX Runtime."""
@staticmethod
def is_cpu_complex_model(model_type: str) -> bool:
    """Report whether ``model_type`` requires the basic optimization level.

    Some models (such as Jina-CLIP) have issues with aggressive optimizations
    like SimplifiedLayerNormFusion, which create or expect nodes that don't
    exist; those models must avoid graph fusion.

    Args:
        model_type: Model type string, compared against the enum values of
            the Jina enrichment models.

    Returns:
        True when ``model_type`` equals the jina_v1 or jina_v2 enum value,
        False otherwise.
    """
    # Deferred to call time to avoid circular imports.
    from frigate.embeddings.types import EnrichmentModelTypeEnum

    fusion_sensitive_types = [
        EnrichmentModelTypeEnum.jina_v1.value,
        EnrichmentModelTypeEnum.jina_v2.value,
    ]
    return model_type in fusion_sensitive_types
@staticmethod @staticmethod
def is_migraphx_complex_model(model_type: str) -> bool: def is_migraphx_complex_model(model_type: str) -> bool:
# Import here to avoid circular imports # Import here to avoid circular imports
@ -496,7 +515,9 @@ def get_optimized_runner(
return ONNXModelRunner( return ONNXModelRunner(
ort.InferenceSession( ort.InferenceSession(
model_path, model_path,
sess_options=get_ort_session_options(), sess_options=get_ort_session_options(
ONNXModelRunner.is_cpu_complex_model(model_type)
),
providers=providers, providers=providers,
provider_options=options, provider_options=options,
) )