Correctly handle CUDA input

This commit is contained in:
Nicolas Mowen 2025-09-13 20:34:56 -06:00
parent 43412f6390
commit 226b370cad
2 changed files with 5 additions and 2 deletions

View File

@ -87,7 +87,10 @@ class CudaGraphRunner(BaseModelRunner):
"""Get the input width of the model.""" """Get the input width of the model."""
return self._session.get_inputs()[0].shape[3] return self._session.get_inputs()[0].shape[3]
def run(self, input_name: str, tensor_input: np.ndarray): def run(self, input: dict[str, Any]):
# Extract the single tensor input (assuming one input)
input_name = list(input.keys())[0]
tensor_input = input[input_name]
tensor_input = np.ascontiguousarray(tensor_input) tensor_input = np.ascontiguousarray(tensor_input)
if not self._captured: if not self._captured:

View File

@ -95,7 +95,7 @@ class ONNXDetector(DetectionApi):
if self._cg_runner is not None: if self._cg_runner is not None:
try: try:
# Run using CUDA graphs if available # Run using CUDA graphs if available
tensor_output = self._cg_runner.run(model_input_name, tensor_input) tensor_output = self._cg_runner.run({model_input_name: tensor_input})
except Exception as e: except Exception as e:
logger.warning(f"CUDA Graphs failed, falling back to regular run: {e}") logger.warning(f"CUDA Graphs failed, falling back to regular run: {e}")
self._cg_runner = None self._cg_runner = None