Support yolox models

This commit is contained in:
Nicolas Mowen 2025-04-17 17:12:07 -06:00
parent cc807f49a0
commit 97a78af7f9
2 changed files with 35 additions and 31 deletions

View File

@ -14,6 +14,7 @@ from frigate.util.model import (
post_process_dfine, post_process_dfine,
post_process_rfdetr, post_process_rfdetr,
post_process_yolo, post_process_yolo,
post_process_yolox,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -58,6 +59,25 @@ class ONNXDetector(DetectionApi):
self.onnx_model_shape = detector_config.model.input_tensor self.onnx_model_shape = detector_config.model.input_tensor
path = detector_config.model.path path = detector_config.model.path
if self.onnx_model_type == ModelTypeEnum.yolox:
grids = []
expanded_strides = []
# decode and orient predictions
strides = [8, 16, 32]
hsizes = [self.h // stride for stride in strides]
wsizes = [self.w // stride for stride in strides]
for hsize, wsize, stride in zip(hsizes, wsizes, strides):
xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
grids.append(grid)
shape = grid.shape[:2]
expanded_strides.append(np.full((*shape, 1), stride))
self.grids = np.concatenate(grids, 1)
self.expanded_strides = np.concatenate(expanded_strides, 1)
logger.info(f"ONNX: {path} loaded") logger.info(f"ONNX: {path} loaded")
def detect_raw(self, tensor_input: np.ndarray): def detect_raw(self, tensor_input: np.ndarray):
@ -99,6 +119,10 @@ class ONNXDetector(DetectionApi):
return detections return detections
elif self.onnx_model_type == ModelTypeEnum.yologeneric: elif self.onnx_model_type == ModelTypeEnum.yologeneric:
return post_process_yolo(tensor_output, self.w, self.h) return post_process_yolo(tensor_output, self.w, self.h)
elif self.onnx_model_type == ModelTypeEnum.yolox:
return post_process_yolox(
tensor_output[0], self.w, self.h, self.grids, self.expanded_strides
)
else: else:
raise Exception( raise Exception(
f"{self.onnx_model_type} is currently not supported for onnx. See the docs for more info on supported models." f"{self.onnx_model_type} is currently not supported for onnx. See the docs for more info on supported models."

View File

@ -230,24 +230,13 @@ def post_process_yolo(output: list[np.ndarray], width: int, height: int) -> np.n
return __post_process_nms_yolo(output[0], width, height) return __post_process_nms_yolo(output[0], width, height)
def post_process_yolox(predictions: np.ndarray, width: int, height: int) -> np.ndarray: def post_process_yolox(
grids = [] predictions: np.ndarray,
expanded_strides = [] width: int,
height: int,
# decode and orient predictions grids: np.ndarray,
strides = [8, 16, 32] expanded_strides: np.ndarray,
hsizes = [height // stride for stride in strides] ) -> np.ndarray:
wsizes = [width // stride for stride in strides]
for hsize, wsize, stride in zip(hsizes, wsizes, strides):
xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
grids.append(grid)
shape = grid.shape[:2]
expanded_strides.append(np.full((*shape, 1), stride))
grids = np.concatenate(grids, 1)
expanded_strides = np.concatenate(expanded_strides, 1)
predictions[..., :2] = (predictions[..., :2] + grids) * expanded_strides predictions[..., :2] = (predictions[..., :2] + grids) * expanded_strides
predictions[..., 2:4] = np.exp(predictions[..., 2:4]) * expanded_strides predictions[..., 2:4] = np.exp(predictions[..., 2:4]) * expanded_strides
@ -269,15 +258,6 @@ def post_process_yolox(predictions: np.ndarray, width: int, height: int) -> np.n
boxes_xyxy, scores, score_threshold=0.4, nms_threshold=0.4 boxes_xyxy, scores, score_threshold=0.4, nms_threshold=0.4
) )
final_boxes = boxes_xyxy[indices]
final_scores = scores[indices]
final_cls_inds = cls_inds[indices]
print(f"frig boxes: {final_boxes}")
print(f"frig cls: {final_cls_inds}")
print(f"frig scores: {final_scores}")
detections = np.zeros((20, 6), np.float32) detections = np.zeros((20, 6), np.float32)
for i, (bbox, confidence, class_id) in enumerate( for i, (bbox, confidence, class_id) in enumerate(
zip(boxes_xyxy[indices], scores[indices], cls_inds[indices]) zip(boxes_xyxy[indices], scores[indices], cls_inds[indices])
@ -288,10 +268,10 @@ def post_process_yolox(predictions: np.ndarray, width: int, height: int) -> np.n
detections[i] = [ detections[i] = [
class_id, class_id,
confidence, confidence,
bbox[1], bbox[1] / height,
bbox[0], bbox[0] / width,
bbox[3], bbox[3] / height,
bbox[2], bbox[2] / width,
] ]
print(f"got {detections[i]}") print(f"got {detections[i]}")