Fix wrong function name in new _get_binding_dtype function and only return input check (not assertion) in new _binding_is_input function

This commit is contained in:
Rémi Bédard-Couture 2024-05-07 22:38:07 -04:00
parent 6c0abe4833
commit 485f307574

View File

@ -91,10 +91,9 @@ class TensorRtDetector(DetectionApi):
def _binding_is_input(self, binding):
if TRT_VERSION < 10:
assert self.engine.binding_is_input(binding)
return self.engine.binding_is_input(binding)
else:
assert binding == "input"
return True
return binding == "input"
def _get_binding_dims(self, binding):
if TRT_VERSION < 10:
@ -106,7 +105,7 @@ class TensorRtDetector(DetectionApi):
if TRT_VERSION < 10:
return self.engine.get_binding_dtype(binding)
else:
return self.engine.get_tensor_shape(binding)
return self.engine.get_tensor_dtype(binding)
def _execute(self):
if TRT_VERSION < 10: