Rework to create util for onnx initialization

This commit is contained in:
Nicolas Mowen 2024-09-26 19:28:28 -06:00
parent c0bd3b362c
commit d4cc2d777a
2 changed files with 42 additions and 31 deletions

View File

@ -1,5 +1,4 @@
import logging import logging
import os
import numpy as np import numpy as np
from pydantic import Field from pydantic import Field
@ -10,6 +9,7 @@ from frigate.detectors.detector_config import (
BaseDetectorConfig, BaseDetectorConfig,
ModelTypeEnum, ModelTypeEnum,
) )
from frigate.util.model import get_ort_providers
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,37 +38,9 @@ class ONNXDetector(DetectionApi):
path = detector_config.model.path path = detector_config.model.path
logger.info(f"ONNX: loading {detector_config.model.path}") logger.info(f"ONNX: loading {detector_config.model.path}")
providers = ( providers, options = get_ort_providers(
["CPUExecutionProvider"] detector_config.device == "CPU", detector_config.device
if detector_config.device == "CPU"
else ort.get_available_providers()
) )
options = []
for provider in providers:
if provider == "TensorrtExecutionProvider":
os.makedirs(
"/config/model_cache/tensorrt/ort/trt-engines", exist_ok=True
)
options.append(
{
"trt_timing_cache_enable": True,
"trt_engine_cache_enable": True,
"trt_timing_cache_path": "/config/model_cache/tensorrt/ort",
"trt_engine_cache_path": "/config/model_cache/tensorrt/ort/trt-engines",
}
)
elif provider == "OpenVINOExecutionProvider":
os.makedirs("/config/model_cache/openvino/ort", exist_ok=True)
options.append(
{
"cache_dir": "/config/model_cache/openvino/ort",
"device_type": detector_config.device,
}
)
else:
options.append({})
self.model = ort.InferenceSession( self.model = ort.InferenceSession(
path, providers=providers, provider_options=options path, providers=providers, provider_options=options
) )

39
frigate/util/model.py Normal file
View File

@ -0,0 +1,39 @@
"""Model Utils"""
import os
from typing import Any

import onnxruntime as ort
def get_ort_providers(
    force_cpu: bool = False, openvino_device: str = "AUTO"
) -> tuple[list[str], list[dict[str, Any]]]:
    """Return ONNX Runtime execution providers with matching provider options.

    Args:
        force_cpu: When True, skip hardware providers entirely and return
            only the CPU provider with empty options.
        openvino_device: Device string passed through as ``device_type`` to
            the OpenVINO execution provider (e.g. "AUTO", "GPU").

    Returns:
        A ``(providers, options)`` pair suitable for passing to
        ``ort.InferenceSession(..., providers=providers,
        provider_options=options)``; ``options[i]`` configures
        ``providers[i]``.
    """
    if force_cpu:
        return (["CPUExecutionProvider"], [{}])

    providers = ort.get_available_providers()
    options = []

    for provider in providers:
        if provider == "TensorrtExecutionProvider":
            # Persist TensorRT engine/timing caches so subsequent startups
            # can skip the expensive engine build step.
            os.makedirs("/config/model_cache/tensorrt/ort/trt-engines", exist_ok=True)
            options.append(
                {
                    "trt_timing_cache_enable": True,
                    "trt_engine_cache_enable": True,
                    "trt_timing_cache_path": "/config/model_cache/tensorrt/ort",
                    "trt_engine_cache_path": "/config/model_cache/tensorrt/ort/trt-engines",
                }
            )
        elif provider == "OpenVINOExecutionProvider":
            # Cache compiled OpenVINO models between runs.
            os.makedirs("/config/model_cache/openvino/ort", exist_ok=True)
            options.append(
                {
                    "cache_dir": "/config/model_cache/openvino/ort",
                    "device_type": openvino_device,
                }
            )
        else:
            # Provider needs no extra configuration.
            options.append({})

    return (providers, options)