Refactor to reduce code duplication

This commit is contained in:
Rémi Bédard-Couture 2024-05-07 21:15:57 -04:00
parent 408f295416
commit 6c0abe4833

View File

@ -89,38 +89,45 @@ class TensorRtDetector(DetectionApi):
with open(model_path, "rb") as f, trt.Runtime(self.trt_logger) as runtime:
return runtime.deserialize_cuda_engine(f.read())
def _binding_is_input(self, binding):
    """Return True if *binding* is the model's input binding/tensor.

    Must return a boolean rather than assert: the buffer-allocation loop
    calls this for every binding (inputs AND outputs) to decide which list
    to append to, so asserting would raise on output bindings (and asserts
    are stripped entirely under ``python -O``).

    Args:
        binding: binding index/name (TRT < 10) or tensor name (TRT >= 10).
    """
    if TRT_VERSION < 10:
        # Legacy binding-based API.
        return self.engine.binding_is_input(binding)
    else:
        # TRT >= 10 exposes named tensors; the input tensor is named "input".
        return binding == "input"
def _get_binding_dims(self, binding):
    """Return the shape of *binding*, dispatching on the TensorRT API version."""
    if TRT_VERSION >= 10:
        # Named-tensor API (TRT 10+).
        return self.engine.get_tensor_shape(binding)
    # Legacy binding-based API.
    return self.engine.get_binding_shape(binding)
def _get_binding_dtype(self, binding):
    """Return the dtype of *binding*, dispatching on the TensorRT API version.

    The result is passed to ``trt.nptype(...)`` and ``.itemsize`` by callers,
    so it must be a TensorRT DataType, never a shape.
    """
    if TRT_VERSION < 10:
        return self.engine.get_binding_dtype(binding)
    else:
        # BUG FIX: was get_tensor_shape(), which returns a shape, not a
        # dtype — callers (trt.nptype, .itemsize) require the DataType,
        # as the pre-refactor code obtained via get_tensor_dtype().
        return self.engine.get_tensor_dtype(binding)
def _execute(self):
    """Run inference using whichever execution API this TensorRT version supports.

    Returns the truthy/falsy success flag from the underlying execute call.
    """
    if TRT_VERSION >= 10:
        # Synchronous execution; execute_async_v2 was removed in TRT 10.
        return self.context.execute_v2(self.bindings)
    return self.context.execute_async_v2(
        bindings=self.bindings, stream_handle=self.stream
    )
def _get_input_shape(self):
"""Get input shape of the TensorRT YOLO engine."""
binding = self.engine[0]
if TRT_VERSION < 10:
assert self.engine.binding_is_input(binding)
binding_dims = self.engine.get_binding_shape(binding)
assert self._binding_is_input(binding)
binding_dims = self._get_binding_dims(binding)
if len(binding_dims) == 4:
return (
tuple(binding_dims[2:]),
trt.nptype(self.engine.get_binding_dtype(binding)),
trt.nptype(self._get_binding_dtype(binding)),
)
elif len(binding_dims) == 3:
return (
tuple(binding_dims[1:]),
trt.nptype(self.engine.get_binding_dtype(binding)),
)
else:
raise ValueError(
"bad dims of binding %s: %s" % (binding, str(binding_dims))
)
else:
assert binding == "input"
binding_dims = self.engine.get_tensor_shape("input")
if len(binding_dims) == 4:
return (
tuple(binding_dims[2:]),
trt.nptype(self.engine.get_tensor_dtype(binding)),
)
elif len(binding_dims) == 3:
return (
tuple(binding_dims[1:]),
trt.nptype(self.engine.get_tensor_dtype(binding)),
trt.nptype(self._get_binding_dtype(binding)),
)
else:
raise ValueError(
@ -134,8 +141,7 @@ class TensorRtDetector(DetectionApi):
bindings = []
output_idx = 0
for binding in self.engine:
if TRT_VERSION < 10:
binding_dims = self.engine.get_binding_shape(binding)
binding_dims = self._get_binding_dims(binding)
if len(binding_dims) == 4:
# explicit batch case (TensorRT 7+)
size = trt.volume(binding_dims)
@ -146,54 +152,21 @@ class TensorRtDetector(DetectionApi):
raise ValueError(
"bad dims of binding %s: %s" % (binding, str(binding_dims))
)
nbytes = size * self.engine.get_binding_dtype(binding).itemsize
nbytes = size * self._get_binding_dtype(binding).itemsize
# Allocate host and device buffers
err, host_mem = cuda.cuMemHostAlloc(
nbytes, Flags=cuda.CU_MEMHOSTALLOC_DEVICEMAP
)
assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAllocHost returned {err}"
logger.debug(
f"Allocated Tensor Binding {binding} Memory {nbytes} Bytes ({size} * {self.engine.get_binding_dtype(binding)})"
f"Allocated Tensor Binding {binding} Memory {nbytes} Bytes ({size} * {self._get_binding_dtype(binding)})"
)
err, device_mem = cuda.cuMemAlloc(nbytes)
assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAlloc returned {err}"
# Append the device buffer to device bindings.
bindings.append(int(device_mem))
# Append to the appropriate list.
if self.engine.binding_is_input(binding):
logger.debug(f"Input has Shape {binding_dims}")
inputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
else:
# each grid has 3 anchors, each anchor generates a detection
# output of 7 float32 values
assert size % 7 == 0, f"output size was {size}"
logger.debug(f"Output has Shape {binding_dims}")
outputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
output_idx += 1
else:
binding_dims = self.engine.get_tensor_shape(binding)
if len(binding_dims) == 4:
# explicit batch case (TensorRT 7+)
size = trt.volume(binding_dims)
else:
raise ValueError(
"bad dims of binding %s: %s" % (binding, str(binding_dims))
)
nbytes = size * self.engine.get_tensor_dtype(binding).itemsize
# Allocate host and device buffers
err, host_mem = cuda.cuMemHostAlloc(
nbytes, Flags=cuda.CU_MEMHOSTALLOC_DEVICEMAP
)
assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAllocHost returned {err}"
logger.debug(
f"Allocated Tensor Binding {binding} Memory {nbytes} Bytes ({size} * {self.engine.get_tensor_dtype(binding)})"
)
err, device_mem = cuda.cuMemAlloc(nbytes)
assert err is cuda.CUresult.CUDA_SUCCESS, f"cuMemAlloc returned {err}"
# Append the device buffer to device bindings.
bindings.append(int(device_mem))
# Append to the appropriate list.
if binding == "input":
if self._binding_is_input(binding):
logger.debug(f"Input has Shape {binding_dims}")
inputs.append(HostDeviceMem(host_mem, device_mem, nbytes, size))
else:
@ -223,15 +196,7 @@ class TensorRtDetector(DetectionApi):
]
# Run inference.
if TRT_VERSION < 10:
if not self.context.execute_async_v2(
bindings=self.bindings, stream_handle=self.stream
):
logger.warn("Execute returned false")
else:
if not self.context.execute_v2(
self.bindings
):
if not self._execute():
logger.warn("Execute returned false")
# Transfer predictions back from the GPU.