Checkpoint

This commit is contained in:
Nicolas Mowen 2025-04-17 16:03:55 -06:00
parent e45737754a
commit 722fe21287

View File

@ -230,6 +230,66 @@ def post_process_yolo(output: list[np.ndarray], width: int, height: int) -> np.n
return __post_process_nms_yolo(output[0], width, height)
def post_process_yolox(predictions: np.ndarray, width: int, height: int) -> np.ndarray:
grids = []
expanded_strides = []
# decode and orient predictions
strides = [8, 16, 32]
hsizes = [height // stride for stride in strides]
wsizes = [width // stride for stride in strides]
for hsize, wsize, stride in zip(hsizes, wsizes, strides):
xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
grids.append(grid)
shape = grid.shape[:2]
expanded_strides.append(np.full((*shape, 1), stride))
grids = np.concatenate(grids, 1)
expanded_strides = np.concatenate(expanded_strides, 1)
predictions[..., :2] = (predictions[..., :2] + grids) * expanded_strides
predictions[..., 2:4] = np.exp(predictions[..., 2:4]) * expanded_strides
# process organized predictions
predictions = predictions[0]
boxes = predictions[:, :4]
scores = predictions[:, 4:5] * predictions[:, 5:]
boxes_xyxy = np.ones_like(boxes)
boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.0
boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.0
boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.0
boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.0
cls_inds = scores.argmax(1)
scores = scores[np.arange(len(cls_inds)), cls_inds]
indices = cv2.dnn.NMSBoxes(
boxes_xyxy, scores, score_threshold=0.4, nms_threshold=0.4
)
detections = np.zeros((20, 6), np.float32)
for i, (bbox, confidence, class_id) in enumerate(
zip(boxes_xyxy[indices], scores[indices], cls_inds[indices])
):
if i == 20:
break
detections[i] = [
class_id,
0.75,
bbox[1] / height,
bbox[0] / width,
bbox[3] / height,
bbox[2] / width,
]
# print(f"raw det is {detections[i]}")
return detections
### ONNX Utilities