From adfaa0476041ee95d0356ddc814fdfc1d770c171 Mon Sep 17 00:00:00 2001 From: alexyao2015 <33379584+alexyao2015@users.noreply.github.com> Date: Wed, 21 Feb 2024 01:59:19 -0600 Subject: [PATCH] handle multiple trt outputs --- frigate/detectors/plugins/tensorrt.py | 16 +++++++++++++--- frigate/detectors/util.py | 8 ++------ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/frigate/detectors/plugins/tensorrt.py b/frigate/detectors/plugins/tensorrt.py index c0967e724..50f6e036c 100644 --- a/frigate/detectors/plugins/tensorrt.py +++ b/frigate/detectors/plugins/tensorrt.py @@ -296,9 +296,19 @@ class TensorRtDetector(DetectionApi): ) trt_outputs = self._do_inference() if self.model_type == ModelTypeEnum.yolov8: - return yolov8_postprocess( - self.input_shape[0], trt_outputs[0].reshape(self.output_shape[0]) - ) + detections = [] + for o in trt_outputs: + detections.append( + yolov8_postprocess( + self.input_shape[0], o.reshape(self.output_shape[0]) + ), + ) + detections = np.concatenate(detections) + # sort detections by confidence + detections = detections[detections[:, 1].argsort()[::-1]] + # trim to top 20 + detections = detections[:20] + return np.resize(detections, (20, 6)) raw_detections = self._postprocess_yolo(trt_outputs, self.conf_th) diff --git a/frigate/detectors/util.py b/frigate/detectors/util.py index 5240f4a0b..82c764a2c 100644 --- a/frigate/detectors/util.py +++ b/frigate/detectors/util.py @@ -73,12 +73,8 @@ def yolov8_postprocess( boxes = np.stack((cx - w / 2, cy - h / 2, w, h), axis=1) indexes = cv2.dnn.NMSBoxes(boxes, confidences, score_threshold, nms_threshold) detections = detections[indexes] - # if still too many, trim the rest by confidence - if detections.shape[0] > box_count: - detections = detections[ - np.argpartition(detections[:, 1], -box_count)[-box_count:] - ] - detections = detections.copy() # sort detections by confidence detections = detections[detections[:, 1].argsort()[::-1]] + # trim to box_count + detections = detections[:box_count] return np.resize(detections, (box_count, 6))