mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-16 03:52:09 +03:00
Cleanup onnx detector
This commit is contained in:
parent
c05e260ae9
commit
85bff61776
@ -1,18 +1,16 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from pydantic import Field
|
||||
from typing_extensions import Literal
|
||||
|
||||
from frigate.detectors.detection_api import DetectionApi
|
||||
from frigate.detectors.detection_runners import CudaGraphRunner
|
||||
from frigate.detectors.detection_runners import get_optimized_runner
|
||||
from frigate.detectors.detector_config import (
|
||||
BaseDetectorConfig,
|
||||
ModelTypeEnum,
|
||||
)
|
||||
from frigate.util.model import (
|
||||
get_ort_providers,
|
||||
post_process_dfine,
|
||||
post_process_rfdetr,
|
||||
post_process_yolo,
|
||||
@ -38,80 +36,35 @@ class ONNXDetector(DetectionApi):
|
||||
path = detector_config.model.path
|
||||
logger.info(f"ONNX: loading {detector_config.model.path}")
|
||||
|
||||
providers, options = get_ort_providers(
|
||||
detector_config.device == "CPU", detector_config.device
|
||||
)
|
||||
|
||||
# Enable CUDA Graphs only for supported models when using CUDA EP
|
||||
if "CUDAExecutionProvider" in providers:
|
||||
cuda_idx = providers.index("CUDAExecutionProvider")
|
||||
# mutate only this call's provider options
|
||||
options[cuda_idx] = {
|
||||
**options[cuda_idx],
|
||||
"enable_cuda_graph": True,
|
||||
}
|
||||
|
||||
sess_options = None
|
||||
|
||||
if providers[0] == "ROCMExecutionProvider":
|
||||
# avoid AMD GPU kernel crashes
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.graph_optimization_level = (
|
||||
ort.GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
)
|
||||
|
||||
self.model = ort.InferenceSession(
|
||||
path, providers=providers, provider_options=options
|
||||
self.runner = get_optimized_runner(
|
||||
path,
|
||||
detector_config.device,
|
||||
complex_model=False,
|
||||
)
|
||||
|
||||
self.onnx_model_type = detector_config.model.model_type
|
||||
self.onnx_model_px = detector_config.model.input_pixel_format
|
||||
self.onnx_model_shape = detector_config.model.input_tensor
|
||||
path = detector_config.model.path
|
||||
|
||||
if self.onnx_model_type == ModelTypeEnum.yolox:
|
||||
self.calculate_grids_strides()
|
||||
|
||||
self._cuda_device_id = 0
|
||||
self._cg_runner: CudaGraphRunner | None = None
|
||||
|
||||
try:
|
||||
if "CUDAExecutionProvider" in providers:
|
||||
self._cuda_device_id = options[cuda_idx].get("device_id", 0)
|
||||
|
||||
if options[cuda_idx].get("enable_cuda_graph"):
|
||||
self._cg_runner = CudaGraphRunner(self.model, self._cuda_device_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(f"ONNX: {path} loaded")
|
||||
|
||||
def detect_raw(self, tensor_input: np.ndarray):
|
||||
if self.onnx_model_type == ModelTypeEnum.dfine:
|
||||
tensor_output = self.model.run(
|
||||
None,
|
||||
tensor_output = self.runner.run(
|
||||
{
|
||||
"images": tensor_input,
|
||||
"orig_target_sizes": np.array(
|
||||
[[self.height, self.width]], dtype=np.int64
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
return post_process_dfine(tensor_output, self.width, self.height)
|
||||
|
||||
model_input_name = self.model.get_inputs()[0].name
|
||||
|
||||
if self._cg_runner is not None:
|
||||
try:
|
||||
# Run using CUDA graphs if available
|
||||
tensor_output = self._cg_runner.run({model_input_name: tensor_input})
|
||||
except Exception as e:
|
||||
logger.warning(f"CUDA Graphs failed, falling back to regular run: {e}")
|
||||
self._cg_runner = None
|
||||
tensor_output = self.model.run(None, {model_input_name: tensor_input})
|
||||
else:
|
||||
# Use regular run if CUDA graphs are not available
|
||||
tensor_output = self.model.run(None, {model_input_name: tensor_input})
|
||||
model_input_name = self.runner.get_input_names()[0]
|
||||
tensor_output = self.runner.run({model_input_name: tensor_input})
|
||||
|
||||
if self.onnx_model_type == ModelTypeEnum.rfdetr:
|
||||
return post_process_rfdetr(tensor_output)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user