diff --git a/frigate/util/model.py b/frigate/util/model.py index 119274b0b..6ef83aa51 100644 --- a/frigate/util/model.py +++ b/frigate/util/model.py @@ -230,6 +230,66 @@ def post_process_yolo(output: list[np.ndarray], width: int, height: int) -> np.n return __post_process_nms_yolo(output[0], width, height) +def post_process_yolox(predictions: np.ndarray, width: int, height: int) -> np.ndarray: + grids = [] + expanded_strides = [] + + # decode and orient predictions + strides = [8, 16, 32] + hsizes = [height // stride for stride in strides] + wsizes = [width // stride for stride in strides] + + for hsize, wsize, stride in zip(hsizes, wsizes, strides): + xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) + grid = np.stack((xv, yv), 2).reshape(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + expanded_strides.append(np.full((*shape, 1), stride)) + + grids = np.concatenate(grids, 1) + expanded_strides = np.concatenate(expanded_strides, 1) + predictions[..., :2] = (predictions[..., :2] + grids) * expanded_strides + predictions[..., 2:4] = np.exp(predictions[..., 2:4]) * expanded_strides + + # process organized predictions + predictions = predictions[0] + boxes = predictions[:, :4] + scores = predictions[:, 4:5] * predictions[:, 5:] + + boxes_xyxy = np.ones_like(boxes) + boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.0 + boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.0 + boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.0 + boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.0 + + cls_inds = scores.argmax(1) + scores = scores[np.arange(len(cls_inds)), cls_inds] + + indices = cv2.dnn.NMSBoxes( + boxes_xyxy, scores, score_threshold=0.4, nms_threshold=0.4 + ) + + detections = np.zeros((20, 6), np.float32) + for i, (bbox, confidence, class_id) in enumerate( + zip(boxes_xyxy[indices], scores[indices], cls_inds[indices]) + ): + if i == 20: + break + + detections[i] = [ + class_id, + 0.75, + bbox[1] / height, + bbox[0] / width, + bbox[3] / height, + bbox[2] / width, + ] + + # print(f"raw det is {detections[i]}") + + return detections + + ### ONNX Utilities