mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-21 06:21:53 +03:00
Handle complex migraphx models
This commit is contained in:
parent
399d319bde
commit
aaeab73505
@ -78,6 +78,21 @@ class BaseModelRunner(ABC):
|
|||||||
class ONNXModelRunner(BaseModelRunner):
|
class ONNXModelRunner(BaseModelRunner):
|
||||||
"""Run ONNX models using ONNX Runtime."""
|
"""Run ONNX models using ONNX Runtime."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_migraphx_complex_model(model_type: str) -> bool:
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from frigate.detectors.detector_config import ModelTypeEnum
|
||||||
|
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||||
|
|
||||||
|
return model_type in [
|
||||||
|
EnrichmentModelTypeEnum.paddleocr.value,
|
||||||
|
EnrichmentModelTypeEnum.jina_v1.value,
|
||||||
|
EnrichmentModelTypeEnum.jina_v2.value,
|
||||||
|
EnrichmentModelTypeEnum.facenet.value,
|
||||||
|
ModelTypeEnum.rfdetr.value,
|
||||||
|
ModelTypeEnum.dfine.value,
|
||||||
|
]
|
||||||
|
|
||||||
def __init__(self, ort: ort.InferenceSession):
|
def __init__(self, ort: ort.InferenceSession):
|
||||||
self.ort = ort
|
self.ort = ort
|
||||||
|
|
||||||
@ -443,6 +458,15 @@ def get_optimized_runner(
|
|||||||
options[0]["device_id"],
|
options[0]["device_id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if providers[
|
||||||
|
0
|
||||||
|
] == "MIGraphXExecutionProvider" and ONNXModelRunner.is_migraphx_complex_model(
|
||||||
|
model_type
|
||||||
|
):
|
||||||
|
# Don't use MIGraphX for models that are not supported
|
||||||
|
providers.pop(0)
|
||||||
|
options.pop(0)
|
||||||
|
|
||||||
return ONNXModelRunner(
|
return ONNXModelRunner(
|
||||||
ort.InferenceSession(
|
ort.InferenceSession(
|
||||||
model_path,
|
model_path,
|
||||||
|
|||||||
@ -373,11 +373,7 @@ def get_ort_providers(
|
|||||||
os.environ["ORT_MIGRAPHX_SAVE_COMPILED_PATH"] = compiled_model_path
|
os.environ["ORT_MIGRAPHX_SAVE_COMPILED_PATH"] = compiled_model_path
|
||||||
|
|
||||||
providers.append(provider)
|
providers.append(provider)
|
||||||
options.append(
|
options.append({})
|
||||||
{
|
|
||||||
"migraphx_fp16_enable": 0,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
providers.append(provider)
|
providers.append(provider)
|
||||||
options.append({})
|
options.append({})
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user