mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-16 03:52:09 +03:00
Handle complex migraphx models
This commit is contained in:
parent
399d319bde
commit
aaeab73505
@ -78,6 +78,21 @@ class BaseModelRunner(ABC):
|
||||
class ONNXModelRunner(BaseModelRunner):
|
||||
"""Run ONNX models using ONNX Runtime."""
|
||||
|
||||
@staticmethod
|
||||
def is_migraphx_complex_model(model_type: str) -> bool:
|
||||
# Import here to avoid circular imports
|
||||
from frigate.detectors.detector_config import ModelTypeEnum
|
||||
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||
|
||||
return model_type in [
|
||||
EnrichmentModelTypeEnum.paddleocr.value,
|
||||
EnrichmentModelTypeEnum.jina_v1.value,
|
||||
EnrichmentModelTypeEnum.jina_v2.value,
|
||||
EnrichmentModelTypeEnum.facenet.value,
|
||||
ModelTypeEnum.rfdetr.value,
|
||||
ModelTypeEnum.dfine.value,
|
||||
]
|
||||
|
||||
def __init__(self, ort: ort.InferenceSession):
|
||||
self.ort = ort
|
||||
|
||||
@ -443,6 +458,15 @@ def get_optimized_runner(
|
||||
options[0]["device_id"],
|
||||
)
|
||||
|
||||
if providers[
|
||||
0
|
||||
] == "MIGraphXExecutionProvider" and ONNXModelRunner.is_migraphx_complex_model(
|
||||
model_type
|
||||
):
|
||||
# Don't use MIGraphX for models that are not supported
|
||||
providers.pop(0)
|
||||
options.pop(0)
|
||||
|
||||
return ONNXModelRunner(
|
||||
ort.InferenceSession(
|
||||
model_path,
|
||||
|
||||
@ -373,11 +373,7 @@ def get_ort_providers(
|
||||
os.environ["ORT_MIGRAPHX_SAVE_COMPILED_PATH"] = compiled_model_path
|
||||
|
||||
providers.append(provider)
|
||||
options.append(
|
||||
{
|
||||
"migraphx_fp16_enable": 0,
|
||||
}
|
||||
)
|
||||
options.append({})
|
||||
else:
|
||||
providers.append(provider)
|
||||
options.append({})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user