diff --git a/frigate/util/model.py b/frigate/util/model.py index a4ff9bd75..119274b0b 100644 --- a/frigate/util/model.py +++ b/frigate/util/model.py @@ -187,7 +187,12 @@ def __post_process_multipart_yolo( def __post_process_nms_yolo(predictions: np.ndarray, width, height) -> np.ndarray: - predictions = np.squeeze(predictions).T + predictions = np.squeeze(predictions) + + # transpose the output so it has order (inferences, class_ids) + if predictions.shape[0] < predictions.shape[1]: + predictions = predictions.T + scores = np.max(predictions[:, 4:], axis=1) predictions = predictions[scores > 0.4, :] scores = scores[scores > 0.4]