Fix ROCm input name

This commit is contained in:
Nicolas Mowen 2024-09-26 11:19:31 -06:00
parent a5595189ed
commit fb07319831

View File

@ -125,8 +125,9 @@ class ROCmDetector(DetectionApi):
def detect_raw(self, tensor_input): def detect_raw(self, tensor_input):
model_input_name = self.model.get_parameter_names()[0] model_input_name = self.model.get_parameter_names()[0]
model_input_name = self.model.get_inputs()[0].name model_input_shape = tuple(
model_input_shape = self.model.get_inputs()[0].shape self.model.get_parameter_shapes()[model_input_name].lens()
)
tensor_input = cv2.dnn.blobFromImage( tensor_input = cv2.dnn.blobFromImage(
tensor_input[0], tensor_input[0],