Correctly enable CUDA graphs

This commit is contained in:
Nicolas Mowen 2025-09-18 22:23:29 -06:00
parent 61d3b370b1
commit ae3427158c

View File

@@ -420,16 +420,22 @@ def get_optimized_runner(
if device != "CPU" and is_openvino_gpu_npu_available():
return OpenVINOModelRunner(model_path, device, model_type, **kwargs)
ortSession = ort.InferenceSession(
model_path,
providers=providers,
provider_options=options,
)
if (
not CudaGraphRunner.is_complex_model(model_type)
and providers[0] == "CUDAExecutionProvider"
):
return CudaGraphRunner(ortSession, options[0]["device_id"])
options[0] = {
**options[0],
"enable_cuda_graph": True,
}
return CudaGraphRunner(ort.InferenceSession(
model_path,
providers=providers,
provider_options=options,
), options[0]["device_id"])
return ONNXModelRunner(ortSession)
return ONNXModelRunner(ort.InferenceSession(
model_path,
providers=providers,
provider_options=options,
))