Don't always transpose

This commit is contained in:
Nicolas Mowen 2025-04-17 14:10:06 -06:00
parent 1abd3c68ec
commit e45737754a

View File

@ -187,7 +187,12 @@ def __post_process_multipart_yolo(
def __post_process_nms_yolo(predictions: np.ndarray, width, height) -> np.ndarray:
predictions = np.squeeze(predictions).T
predictions = np.squeeze(predictions)
# transpose the output so it has order (inferences, class_ids)
if predictions.shape[0] < predictions.shape[1]:
predictions = predictions.T
scores = np.max(predictions[:, 4:], axis=1)
predictions = predictions[scores > 0.4, :]
scores = scores[scores > 0.4]