use tensor api

This commit is contained in:
alexyao2015 2024-02-21 01:36:54 -06:00
parent 4ea2e16ecd
commit 13f9fe3b48

View File

@ -106,8 +106,8 @@ class TensorRtDetector(DetectionApi):
for i in range(self.engine.num_bindings):
name = self.engine.get_tensor_name(i)
shape = (
tuple(self.engine.get_binding_shape(name)),
trt.nptype(self.engine.get_binding_dtype(name)),
tuple(self.engine.get_tensor_shape(name)),
trt.nptype(self.engine.get_tensor_dtype(name)),
)
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
input_shape = shape