fix(cuda): gracefully fallback when CUDA graph capture is unsupported

When enable_cuda_graph is True and the model contains operations that
cannot be fully partitioned to CUDA (e.g. Memcpy nodes from reshape
or concat ops), ONNX Runtime fails session creation fatally. This
prevents models like YOLOv8 from running on GPU with device: cuda.

Wrap the session creation in try/except and fall back to CUDA execution
without graph capture. Both paths still use CudaGraphRunner since it
works as a regular IOBinding runner even without graph capture.
This commit is contained in:
Nabheet S. Sandhu 2026-06-07 21:42:13 -06:00
parent ad968efd3e
commit ef59366c2c

View File

@ -601,16 +601,31 @@ def get_optimized_runner(
CudaGraphRunner.is_model_supported(model_type)
and providers[0] == "CUDAExecutionProvider"
):
options[0] = {
# Try to enable CUDA graph capture for maximum performance.
# If the model has ops that can't be fully partitioned to CUDA
# (e.g. Memcpy nodes), fall back gracefully without graph capture.
graph_options = {
**options[0],
"enable_cuda_graph": True,
}
return CudaGraphRunner(
ort.InferenceSession(
try:
session = ort.InferenceSession(
model_path,
providers=providers,
provider_options=[graph_options] + options[1:],
)
except Exception:
logger.warning(
"CUDA graph capture not supported for this model, "
"falling back to CUDA execution without graph capture"
)
session = ort.InferenceSession(
model_path,
providers=providers,
provider_options=options,
),
)
return CudaGraphRunner(
session,
options[0]["device_id"],
)