diff --git a/frigate/detectors/plugins/onnx.py b/frigate/detectors/plugins/onnx.py index c1771ef15..527de7e11 100644 --- a/frigate/detectors/plugins/onnx.py +++ b/frigate/detectors/plugins/onnx.py @@ -29,15 +29,15 @@ class CudaGraphRunner: This runner assumes a single tensor input and binds all model outputs. """ - def __init__(self, session, cuda_device_id: int): + def __init__(self, session: ort.InferenceSession, cuda_device_id: int): self._session = session self._cuda_device_id = cuda_device_id self._captured = False - self._io_binding = None - self._input_name = None - self._output_names = None - self._input_ortvalue = None - self._output_ortvalues = None + self._io_binding: ort.IOBinding | None = None + self._input_name: str | None = None + self._output_names: list[str] | None = None + self._input_ortvalue: ort.OrtValue | None = None + self._output_ortvalues: ort.OrtValue | None = None def run(self, input_name: str, tensor_input: np.ndarray): tensor_input = np.ascontiguousarray(tensor_input) @@ -109,15 +109,14 @@ class ONNXDetector(DetectionApi): if self.onnx_model_type == ModelTypeEnum.yolox: self.calculate_grids_strides() - # Internal CUDA Graphs state self._cuda_device_id = 0 - self._cg_runner = None + self._cg_runner: CudaGraphRunner | None = None try: if "CUDAExecutionProvider" in providers: cuda_idx = providers.index("CUDAExecutionProvider") self._cuda_device_id = options[cuda_idx].get("device_id", 0) - # If we enabled CUDA graphs above for supported models, set flag + if options[cuda_idx].get("enable_cuda_graph"): self._cg_runner = CudaGraphRunner(self.model, self._cuda_device_id) except Exception: @@ -140,7 +139,6 @@ class ONNXDetector(DetectionApi): model_input_name = self.model.get_inputs()[0].name - if self._cg_runner is not None: try: # Run using CUDA graphs if available