mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-09 12:45:25 +03:00
handle multiple trt outputs
This commit is contained in:
parent
13f9fe3b48
commit
adfaa04760
@ -296,9 +296,19 @@ class TensorRtDetector(DetectionApi):
|
|||||||
)
|
)
|
||||||
trt_outputs = self._do_inference()
|
trt_outputs = self._do_inference()
|
||||||
if self.model_type == ModelTypeEnum.yolov8:
|
if self.model_type == ModelTypeEnum.yolov8:
|
||||||
return yolov8_postprocess(
|
detections = []
|
||||||
self.input_shape[0], trt_outputs[0].reshape(self.output_shape[0])
|
for o in trt_outputs:
|
||||||
)
|
detections.append(
|
||||||
|
yolov8_postprocess(
|
||||||
|
self.input_shape[0], o.reshape(self.output_shape[0])
|
||||||
|
),
|
||||||
|
)
|
||||||
|
detections = np.concatenate(detections)
|
||||||
|
# sort detections by confidence
|
||||||
|
detections = detections[detections[:, 1].argsort()[::-1]]
|
||||||
|
# trim to top 20
|
||||||
|
detections = detections[:20]
|
||||||
|
return np.resize(detections, (20, 6))
|
||||||
|
|
||||||
raw_detections = self._postprocess_yolo(trt_outputs, self.conf_th)
|
raw_detections = self._postprocess_yolo(trt_outputs, self.conf_th)
|
||||||
|
|
||||||
|
|||||||
@ -73,12 +73,8 @@ def yolov8_postprocess(
|
|||||||
boxes = np.stack((cx - w / 2, cy - h / 2, w, h), axis=1)
|
boxes = np.stack((cx - w / 2, cy - h / 2, w, h), axis=1)
|
||||||
indexes = cv2.dnn.NMSBoxes(boxes, confidences, score_threshold, nms_threshold)
|
indexes = cv2.dnn.NMSBoxes(boxes, confidences, score_threshold, nms_threshold)
|
||||||
detections = detections[indexes]
|
detections = detections[indexes]
|
||||||
# if still too many, trim the rest by confidence
|
|
||||||
if detections.shape[0] > box_count:
|
|
||||||
detections = detections[
|
|
||||||
np.argpartition(detections[:, 1], -box_count)[-box_count:]
|
|
||||||
]
|
|
||||||
detections = detections.copy()
|
|
||||||
# sort detections by confidence
|
# sort detections by confidence
|
||||||
detections = detections[detections[:, 1].argsort()[::-1]]
|
detections = detections[detections[:, 1].argsort()[::-1]]
|
||||||
|
# trim to box_count
|
||||||
|
detections = detections[:box_count]
|
||||||
return np.resize(detections, (box_count, 6))
|
return np.resize(detections, (box_count, 6))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user