mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-09 04:35:25 +03:00
handle multiple trt outputs
This commit is contained in:
parent
13f9fe3b48
commit
adfaa04760
@ -296,9 +296,19 @@ class TensorRtDetector(DetectionApi):
|
||||
)
|
||||
trt_outputs = self._do_inference()
|
||||
if self.model_type == ModelTypeEnum.yolov8:
|
||||
return yolov8_postprocess(
|
||||
self.input_shape[0], trt_outputs[0].reshape(self.output_shape[0])
|
||||
)
|
||||
detections = []
|
||||
for o in trt_outputs:
|
||||
detections.append(
|
||||
yolov8_postprocess(
|
||||
self.input_shape[0], o.reshape(self.output_shape[0])
|
||||
),
|
||||
)
|
||||
detections = np.concatenate(detections)
|
||||
# sort detections by confidence
|
||||
detections = detections[detections[:, 1].argsort()[::-1]]
|
||||
# trim to top 20
|
||||
detections = detections[:20]
|
||||
return np.resize(detections, (20, 6))
|
||||
|
||||
raw_detections = self._postprocess_yolo(trt_outputs, self.conf_th)
|
||||
|
||||
|
||||
@ -73,12 +73,8 @@ def yolov8_postprocess(
|
||||
boxes = np.stack((cx - w / 2, cy - h / 2, w, h), axis=1)
|
||||
indexes = cv2.dnn.NMSBoxes(boxes, confidences, score_threshold, nms_threshold)
|
||||
detections = detections[indexes]
|
||||
# if still too many, trim the rest by confidence
|
||||
if detections.shape[0] > box_count:
|
||||
detections = detections[
|
||||
np.argpartition(detections[:, 1], -box_count)[-box_count:]
|
||||
]
|
||||
detections = detections.copy()
|
||||
# sort detections by confidence
|
||||
detections = detections[detections[:, 1].argsort()[::-1]]
|
||||
# trim to box_count
|
||||
detections = detections[:box_count]
|
||||
return np.resize(detections, (box_count, 6))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user