handle multiple trt outputs

This commit is contained in:
alexyao2015 2024-02-21 01:59:19 -06:00
parent 13f9fe3b48
commit adfaa04760
2 changed files with 15 additions and 9 deletions

View File

@ -296,9 +296,19 @@ class TensorRtDetector(DetectionApi):
) )
trt_outputs = self._do_inference() trt_outputs = self._do_inference()
if self.model_type == ModelTypeEnum.yolov8: if self.model_type == ModelTypeEnum.yolov8:
return yolov8_postprocess( detections = []
self.input_shape[0], trt_outputs[0].reshape(self.output_shape[0]) for o in trt_outputs:
) detections.append(
yolov8_postprocess(
self.input_shape[0], o.reshape(self.output_shape[0])
),
)
detections = np.concatenate(detections)
# sort detections by confidence
detections = detections[detections[:, 1].argsort()[::-1]]
# trim to top 20
detections = detections[:20]
return np.resize(detections, (20, 6))
raw_detections = self._postprocess_yolo(trt_outputs, self.conf_th) raw_detections = self._postprocess_yolo(trt_outputs, self.conf_th)

View File

@ -73,12 +73,8 @@ def yolov8_postprocess(
boxes = np.stack((cx - w / 2, cy - h / 2, w, h), axis=1) boxes = np.stack((cx - w / 2, cy - h / 2, w, h), axis=1)
indexes = cv2.dnn.NMSBoxes(boxes, confidences, score_threshold, nms_threshold) indexes = cv2.dnn.NMSBoxes(boxes, confidences, score_threshold, nms_threshold)
detections = detections[indexes] detections = detections[indexes]
# if still too many, trim the rest by confidence
if detections.shape[0] > box_count:
detections = detections[
np.argpartition(detections[:, 1], -box_count)[-box_count:]
]
detections = detections.copy()
# sort detections by confidence # sort detections by confidence
detections = detections[detections[:, 1].argsort()[::-1]] detections = detections[detections[:, 1].argsort()[::-1]]
# trim to box_count
detections = detections[:box_count]
return np.resize(detections, (box_count, 6)) return np.resize(detections, (box_count, 6))