Improve batched yolo NMS

This commit is contained in:
Nicolas Mowen 2025-04-18 07:21:59 -06:00
parent 6b2246cf08
commit 321eed5b27

View File

@ -148,27 +148,17 @@ def __post_process_multipart_yolo(
bw = ((dw * 2.0) ** 2) * anchor_w bw = ((dw * 2.0) ** 2) * anchor_w
bh = ((dh * 2.0) ** 2) * anchor_h bh = ((dh * 2.0) ** 2) * anchor_h
x1 = max(0, bx - bw / 2) / width x1 = max(0, bx - bw / 2)
y1 = max(0, by - bh / 2) / height y1 = max(0, by - bh / 2)
x2 = min(width, bx + bw / 2) / width x2 = min(width, bx + bw / 2)
y2 = min(height, by + bh / 2) / height y2 = min(height, by + bh / 2)
all_boxes.append([x1, y1, x2, y2]) all_boxes.append([x1, y1, x2, y2])
all_scores.append(conf) all_scores.append(conf)
all_class_ids.append(class_id) all_class_ids.append(class_id)
formatted_boxes = [
[
int(x1 * width),
int(y1 * height),
int((x2 - x1) * width),
int((y2 - y1) * height),
]
for x1, y1, x2, y2 in all_boxes
]
indices = cv2.dnn.NMSBoxes( indices = cv2.dnn.NMSBoxes(
bboxes=formatted_boxes, bboxes=all_boxes,
scores=all_scores, scores=all_scores,
score_threshold=0.4, score_threshold=0.4,
nms_threshold=0.4, nms_threshold=0.4,
@ -181,7 +171,14 @@ def __post_process_multipart_yolo(
class_id = all_class_ids[idx] class_id = all_class_ids[idx]
conf = all_scores[idx] conf = all_scores[idx]
x1, y1, x2, y2 = all_boxes[idx] x1, y1, x2, y2 = all_boxes[idx]
results[i] = [class_id, conf, y1, x1, y2, x2] results[i] = [
class_id,
conf,
y1 / height,
x1 / width,
y2 / height,
x2 / width,
]
return np.array(results, dtype=np.float32) return np.array(results, dtype=np.float32)