Handle complex migraphx models

This commit is contained in:
Nicolas Mowen 2025-09-27 06:51:29 -06:00
parent 399d319bde
commit aaeab73505
2 changed files with 25 additions and 5 deletions

View File

@ -78,6 +78,21 @@ class BaseModelRunner(ABC):
class ONNXModelRunner(BaseModelRunner):
"""Run ONNX models using ONNX Runtime."""
@staticmethod
def is_migraphx_complex_model(model_type: str) -> bool:
    """Return True if *model_type* is considered too complex for MIGraphX.

    Callers use this to drop the MIGraphX execution provider for model
    types it does not support and fall back to the next provider.
    """
    # Imported lazily to avoid a circular import at module load time.
    from frigate.detectors.detector_config import ModelTypeEnum
    from frigate.embeddings.types import EnrichmentModelTypeEnum

    unsupported_types = {
        EnrichmentModelTypeEnum.paddleocr.value,
        EnrichmentModelTypeEnum.jina_v1.value,
        EnrichmentModelTypeEnum.jina_v2.value,
        EnrichmentModelTypeEnum.facenet.value,
        ModelTypeEnum.rfdetr.value,
        ModelTypeEnum.dfine.value,
    }
    return model_type in unsupported_types
def __init__(self, ort: ort.InferenceSession):
    """Store the prepared ONNX Runtime inference session for later runs."""
    # NOTE(review): the parameter name `ort` shadows the `ort` module used in
    # its own annotation; the annotation still resolves because it is
    # evaluated at definition time, but a rename (e.g. `session`) would be
    # clearer — confirm no callers pass it by keyword before renaming.
    self.ort = ort
@ -443,6 +458,15 @@ def get_optimized_runner(
options[0]["device_id"],
)
if providers[
0
] == "MIGraphXExecutionProvider" and ONNXModelRunner.is_migraphx_complex_model(
model_type
):
# Don't use MIGraphX for models that are not supported
providers.pop(0)
options.pop(0)
return ONNXModelRunner(
ort.InferenceSession(
model_path,

View File

@ -373,11 +373,7 @@ def get_ort_providers(
os.environ["ORT_MIGRAPHX_SAVE_COMPILED_PATH"] = compiled_model_path
providers.append(provider)
options.append(
{
"migraphx_fp16_enable": 0,
}
)
options.append({})
else:
providers.append(provider)
options.append({})