use tensor api

This commit is contained in:
alexyao2015 2024-02-21 01:36:54 -06:00
parent 4ea2e16ecd
commit 13f9fe3b48

View File

@ -106,8 +106,8 @@ class TensorRtDetector(DetectionApi):
for i in range(self.engine.num_bindings): for i in range(self.engine.num_bindings):
name = self.engine.get_tensor_name(i) name = self.engine.get_tensor_name(i)
shape = ( shape = (
tuple(self.engine.get_binding_shape(name)), tuple(self.engine.get_tensor_shape(name)),
trt.nptype(self.engine.get_binding_dtype(name)), trt.nptype(self.engine.get_tensor_dtype(name)),
) )
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
input_shape = shape input_shape = shape