Cleanup comments and typing

This commit is contained in:
Nicolas Mowen 2025-09-11 09:11:26 -06:00
parent 6e5c1f49d9
commit f306d761f0

View File

@@ -29,15 +29,15 @@ class CudaGraphRunner:
This runner assumes a single tensor input and binds all model outputs.
"""
def __init__(self, session, cuda_device_id: int):
def __init__(self, session: ort.InferenceSession, cuda_device_id: int):
self._session = session
self._cuda_device_id = cuda_device_id
self._captured = False
self._io_binding = None
self._input_name = None
self._output_names = None
self._input_ortvalue = None
self._output_ortvalues = None
self._io_binding: ort.IOBinding | None = None
self._input_name: str | None = None
self._output_names: list[str] | None = None
self._input_ortvalue: ort.OrtValue | None = None
self._output_ortvalues: ort.OrtValue | None = None
def run(self, input_name: str, tensor_input: np.ndarray):
tensor_input = np.ascontiguousarray(tensor_input)
@@ -109,15 +109,14 @@ class ONNXDetector(DetectionApi):
if self.onnx_model_type == ModelTypeEnum.yolox:
self.calculate_grids_strides()
# Internal CUDA Graphs state
self._cuda_device_id = 0
self._cg_runner = None
self._cg_runner: CudaGraphRunner | None = None
try:
if "CUDAExecutionProvider" in providers:
cuda_idx = providers.index("CUDAExecutionProvider")
self._cuda_device_id = options[cuda_idx].get("device_id", 0)
# If we enabled CUDA graphs above for supported models, set flag
if options[cuda_idx].get("enable_cuda_graph"):
self._cg_runner = CudaGraphRunner(self.model, self._cuda_device_id)
except Exception:
@@ -140,7 +139,6 @@ class ONNXDetector(DetectionApi):
model_input_name = self.model.get_inputs()[0].name
if self._cg_runner is not None:
try:
# Run using CUDA graphs if available