mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-12-06 21:44:13 +03:00
Update optimized YOLOv9 post-processing
This commit is contained in:
parent
213a1fbd00
commit
e496e75243
@ -962,7 +962,6 @@ model:
|
||||
# path: /config/yolov9.zip
|
||||
# The .zip file must contain:
|
||||
# ├── yolov9.dfp (a file ending with .dfp)
|
||||
# └── yolov9_post.onnx (optional; only if the model includes a cropped post-processing network)
|
||||
```
|
||||
|
||||
#### YOLOX
|
||||
|
||||
@ -178,13 +178,6 @@ class MemryXDetector(DetectionApi):
|
||||
logger.error(f"Failed to initialize MemryX model: {e}")
|
||||
raise
|
||||
|
||||
def load_yolo_constants(self):
    """Load the YOLOv9 post-processing constant arrays from the model cache.

    Reads three ``.npy`` files from ``<cache_dir>/<model_folder>`` and stores
    them on the instance as ``const_A``, ``const_B`` and ``const_C``.
    """
    model_dir = f"{self.cache_dir}/{self.model_folder}"

    # Attribute name -> file exported alongside the model. Order matters only
    # in that a missing file fails at the same point as before.
    constant_files = (
        ("const_A", "_model_22_Constant_9_output_0.npy"),
        ("const_B", "_model_22_Constant_10_output_0.npy"),
        ("const_C", "_model_22_Constant_12_output_0.npy"),
    )
    for attribute, filename in constant_files:
        setattr(self, attribute, np.load(f"{model_dir}/{filename}"))
|
||||
|
||||
def check_and_prepare_model(self):
|
||||
if not os.path.exists(self.cache_dir):
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
@ -236,7 +229,6 @@ class MemryXDetector(DetectionApi):
|
||||
|
||||
# Handle post model requirements by model type
|
||||
if self.memx_model_type in [
|
||||
ModelTypeEnum.yologeneric,
|
||||
ModelTypeEnum.yolonas,
|
||||
ModelTypeEnum.ssd,
|
||||
]:
|
||||
@ -245,7 +237,10 @@ class MemryXDetector(DetectionApi):
|
||||
f"No *_post.onnx file found in custom model zip for {self.memx_model_type.name}."
|
||||
)
|
||||
self.memx_post_model = post_candidates[0]
|
||||
elif self.memx_model_type == ModelTypeEnum.yolox:
|
||||
elif self.memx_model_type in [
|
||||
ModelTypeEnum.yolox,
|
||||
ModelTypeEnum.yologeneric,
|
||||
]:
|
||||
# Explicitly ignore any post model even if present
|
||||
self.memx_post_model = None
|
||||
else:
|
||||
@ -273,8 +268,6 @@ class MemryXDetector(DetectionApi):
|
||||
logger.info("Using cached models.")
|
||||
self.memx_model_path = dfp_path
|
||||
self.memx_post_model = post_path
|
||||
if self.memx_model_type == ModelTypeEnum.yologeneric:
|
||||
self.load_yolo_constants()
|
||||
return
|
||||
|
||||
# ---------- CASE 3: download MemryX model (no cache) ----------
|
||||
@ -303,9 +296,6 @@ class MemryXDetector(DetectionApi):
|
||||
else None
|
||||
)
|
||||
|
||||
if self.memx_model_type == ModelTypeEnum.yologeneric:
|
||||
self.load_yolo_constants()
|
||||
|
||||
finally:
|
||||
if os.path.exists(zip_path):
|
||||
try:
|
||||
@ -600,127 +590,232 @@ class MemryXDetector(DetectionApi):
|
||||
|
||||
self.output_queue.put(final_detections)
|
||||
|
||||
def onnx_reshape_with_allowzero(
|
||||
self, data: np.ndarray, shape: np.ndarray, allowzero: int = 0
|
||||
def _generate_anchors(self, sizes=(80, 40, 20)):
    """Generate anchor points (grid-cell centres) for YOLOv9-style decoding.

    Args:
        sizes: Iterable of grid side lengths, one per detection level. Each
            level is treated as a square ``s x s`` grid.

    Returns:
        Float array of shape ``(sum(s * s), 2)`` holding ``[x, y]`` cell
        centres in grid units (``index + 0.5``). Levels are concatenated in
        the order given; within a level ``x`` varies fastest.

    Raises:
        ValueError: If ``sizes`` is empty (nothing to concatenate).
    """
    # The default is a tuple rather than a list so it cannot be mutated
    # across calls.
    levels = []
    for side in sizes:
        centres = np.arange(side) + 0.5
        x_coords = np.tile(centres, side)  # 0.5, 1.5, ..., repeated per row
        y_coords = np.repeat(centres, side)  # each row value held for a row
        levels.append(np.stack([x_coords, y_coords], axis=1))

    return np.concatenate(levels, axis=0)
|
||||
|
||||
def _generate_scales(self, sizes=(80, 40, 20), factors=(8, 16, 32)):
    """Generate the per-anchor scaling factor for each detection level.

    Args:
        sizes: Iterable of grid side lengths, one per detection level.
        factors: Scale applied to every cell of the matching level. Paired
            with ``sizes`` positionally; extra entries on either side are
            ignored, as with ``zip``.

    Returns:
        Float64 column vector of shape ``(sum(s * s), 1)`` where the first
        ``sizes[0] ** 2`` rows hold ``factors[0]``, and so on.

    Raises:
        ValueError: If there are no (size, factor) pairs to concatenate.
    """
    # Defaults are tuples so they cannot be mutated across calls.
    per_level = [
        np.full(int(side * side), float(factor))
        for side, factor in zip(sizes, factors)
    ]
    return np.concatenate(per_level)[:, None]
|
||||
|
||||
@staticmethod
def _softmax(x: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable softmax along ``axis``.

    Args:
        x: Input logits. Not modified.
        axis: Axis along which the outputs sum to one.

    Returns:
        A new array of the same shape as ``x``. Floating inputs keep their
        dtype; any other dtype is promoted to float64.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        # The in-place exp below cannot write float results into an integer
        # buffer, so promote first.
        x = x.astype(np.float64)

    # Subtracting the max allocates a fresh array, so the in-place steps
    # below never touch the caller's data.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    np.exp(shifted, out=shifted)
    shifted /= np.sum(shifted, axis=axis, keepdims=True)
    return shifted
|
||||
|
||||
def dfl(self, x: np.ndarray) -> np.ndarray:
    """Decode Distribution Focal Loss box outputs (YOLOv9 style).

    The input is viewed as rows of 4 box sides with 16 bins each. Every
    side's bins go through a softmax and the result is the expected bin
    index (sum of ``probability * index`` over bins 0..15).

    Args:
        x: Array whose total size is a multiple of 64.

    Returns:
        Array of shape ``(N, 4)`` with one decoded distance per box side.
    """
    bin_index = np.arange(16, dtype=np.float32)
    probabilities = self._softmax(x.reshape(-1, 4, 16), axis=2)
    # bin_index broadcasts over the trailing (bin) axis.
    return (probabilities * bin_index).sum(axis=2)
|
||||
|
||||
def dist2bbox(
|
||||
self, x: np.ndarray, anchors: np.ndarray, scales: np.ndarray
|
||||
) -> np.ndarray:
|
||||
shape = shape.astype(int)
|
||||
input_shape = data.shape
|
||||
output_shape = []
|
||||
"""Convert distances to bounding boxes - YOLOv9 style"""
|
||||
lt = x[:, :2]
|
||||
rb = x[:, 2:]
|
||||
|
||||
for i, dim in enumerate(shape):
|
||||
if dim == 0 and allowzero == 0:
|
||||
output_shape.append(input_shape[i]) # Copy dimension from input
|
||||
else:
|
||||
output_shape.append(dim)
|
||||
x1y1 = anchors - lt
|
||||
x2y2 = anchors + rb
|
||||
|
||||
# Now let NumPy infer any -1 if needed
|
||||
reshaped = np.reshape(data, output_shape)
|
||||
wh = x2y2 - x1y1
|
||||
c_xy = (x1y1 + x2y2) / 2
|
||||
|
||||
return reshaped
|
||||
out = np.concatenate([c_xy, wh], axis=1)
|
||||
out = out * scales
|
||||
return out
|
||||
|
||||
def post_process_yolo_optimized(self, outputs):
    """
    Custom YOLOv9 post-processing optimized for MemryX ONNX outputs.

    Applies sigmoid to the class outputs, filters candidates by confidence,
    DFL-decodes and converts the surviving boxes, clips them to the model
    input, and runs a single (class-agnostic) ``cv2.dnn.NMSBoxes`` pass.

    Args:
        outputs: Sequence of exactly six arrays ordered as
            (lbox, lcls, mbox, mcls, sbox, scls). Axis 1 of each array is
            moved to the last position before flattening; box arrays must
            flatten to rows of 64 values (4 sides * 16 DFL bins).

    Returns:
        A ``(20, 6)`` float32 array. Each used row is
        ``[class_id, confidence, y_min, x_min, y_max, x_max]`` with the
        coordinates divided by the model height/width; unused rows are zero.

    Raises:
        ValueError: If ``outputs`` does not unpack into six arrays or the
            three class outputs disagree on the number of classes.
    """
    max_detections = 20
    confidence_thres = 0.4
    iou_thres = 0.6

    # Move the channel axis last so each spatial location becomes one row.
    lbox, lcls, mbox, mcls, sbox, scls = (
        np.moveaxis(out, 1, -1) for out in outputs
    )

    # Three detection heads with strides 8, 16 and 32.
    # NOTE(review): grid sizes are derived from the height only, so this
    # assumes a square model input - confirm before using non-square models.
    sizes = [self.memx_model_height // stride for stride in (8, 16, 32)]

    # Anchors/scales depend only on the model size, so build them once.
    if not hasattr(self, "anchors"):
        self.anchors = self._generate_anchors(sizes)
        self.scales = self._generate_scales(sizes)

    # The class count is taken from the output itself rather than hard-coded.
    num_classes = lcls.shape[-1]
    if not (mcls.shape[-1] == num_classes and scls.shape[-1] == num_classes):
        raise ValueError(
            f"Class output shapes mismatch: lcls={lcls.shape}, mcls={mcls.shape}, scls={scls.shape}"
        )

    # 64 = 4 bbox sides * 16 DFL bins
    boxes = np.concatenate(
        [level.reshape(-1, 64) for level in (lbox, mbox, sbox)], axis=0
    )
    classes = self.sigmoid(
        np.concatenate(
            [level.reshape(-1, num_classes) for level in (lcls, mcls, scls)],
            axis=0,
        )
    )

    max_scores = np.max(classes, axis=1)  # best class score per candidate
    class_ids = np.argmax(classes, axis=1)  # index of that class

    final_detections = np.zeros((max_detections, 6), np.float32)

    valid_indices = np.where(max_scores >= confidence_thres)[0]
    if valid_indices.size == 0:
        return final_detections

    # DFL decoding is row-independent, so only decode the candidates that
    # survived the confidence filter.
    decoded = self.dist2bbox(
        self.dfl(boxes[valid_indices]),
        self.anchors[valid_indices],
        self.scales[valid_indices],
    )

    # (x_center, y_center, w, h) -> corners, clipped to the model input.
    x_center, y_center = decoded[:, 0], decoded[:, 1]
    half_w, half_h = decoded[:, 2] / 2, decoded[:, 3] / 2
    x_min = np.maximum(x_center - half_w, 0)
    y_min = np.maximum(y_center - half_h, 0)
    x_max = np.minimum(x_center + half_w, self.memx_model_width)
    y_max = np.minimum(y_center + half_h, self.memx_model_height)

    box_w = x_max - x_min
    box_h = y_max - y_min
    keep = (box_w > 0) & (box_h > 0)
    if not keep.any():
        return final_detections

    # Filter boxes, scores and class ids with the SAME mask so the indices
    # returned by NMS address all three consistently. (Indexing the
    # unfiltered score/class arrays with NMS indices mislabels detections
    # whenever a degenerate box was dropped.)
    nms_boxes = np.stack([x_min, y_min, box_w, box_h], axis=1)[keep]
    nms_scores = max_scores[valid_indices][keep]
    nms_class_ids = class_ids[valid_indices][keep]

    # cv2.dnn.NMSBoxes expects boxes as [x, y, width, height].
    indices = cv2.dnn.NMSBoxes(
        nms_boxes.tolist(),
        nms_scores.astype(float).tolist(),
        confidence_thres,
        iou_thres,
    )

    # Depending on the OpenCV version the result may be flat, Nx1, or an
    # empty tuple; normalise to a flat integer array and cap the count.
    kept = np.asarray(indices, dtype=np.int64).reshape(-1)[:max_detections]
    count = kept.size

    # Frigate format: [class_id, confidence, y_min, x_min, y_max, x_max],
    # coordinates normalised by the model input size.
    selected = nms_boxes[kept]
    final_detections[:count, 0] = nms_class_ids[kept]
    final_detections[:count, 1] = nms_scores[kept]
    final_detections[:count, 2] = selected[:, 1] / self.memx_model_height
    final_detections[:count, 3] = selected[:, 0] / self.memx_model_width
    final_detections[:count, 4] = (
        selected[:, 1] + selected[:, 3]
    ) / self.memx_model_height
    final_detections[:count, 5] = (
        selected[:, 0] + selected[:, 2]
    ) / self.memx_model_width

    return final_detections
|
||||
|
||||
def process_output(self, *outputs):
|
||||
"""Output callback function -- receives frames from the MX3 and triggers post-processing"""
|
||||
if self.memx_model_type == ModelTypeEnum.yologeneric:
|
||||
if not self.memx_post_model:
|
||||
conv_out1 = outputs[0]
|
||||
conv_out2 = outputs[1]
|
||||
conv_out3 = outputs[2]
|
||||
conv_out4 = outputs[3]
|
||||
conv_out5 = outputs[4]
|
||||
conv_out6 = outputs[5]
|
||||
# Use complete YOLOv9-style postprocessing (includes NMS)
|
||||
final_detections = self.post_process_yolo_optimized(outputs)
|
||||
|
||||
concat_1 = self.onnx_concat([conv_out1, conv_out2], axis=1)
|
||||
concat_2 = self.onnx_concat([conv_out3, conv_out4], axis=1)
|
||||
concat_3 = self.onnx_concat([conv_out5, conv_out6], axis=1)
|
||||
|
||||
shape = np.array([1, 144, -1], dtype=np.int64)
|
||||
|
||||
reshaped_1 = self.onnx_reshape_with_allowzero(
|
||||
concat_1, shape, allowzero=0
|
||||
)
|
||||
reshaped_2 = self.onnx_reshape_with_allowzero(
|
||||
concat_2, shape, allowzero=0
|
||||
)
|
||||
reshaped_3 = self.onnx_reshape_with_allowzero(
|
||||
concat_3, shape, allowzero=0
|
||||
)
|
||||
|
||||
concat_4 = self.onnx_concat([reshaped_1, reshaped_2, reshaped_3], 2)
|
||||
|
||||
axis = 1
|
||||
split_sizes = [64, 80]
|
||||
|
||||
# Calculate indices at which to split
|
||||
indices = np.cumsum(split_sizes)[
|
||||
:-1
|
||||
] # [64] — split before the second chunk
|
||||
|
||||
# Perform split along axis 1
|
||||
split_0, split_1 = np.split(concat_4, indices, axis=axis)
|
||||
|
||||
num_boxes = 2100 if self.memx_model_height == 320 else 8400
|
||||
shape1 = np.array([1, 4, 16, num_boxes])
|
||||
reshape_4 = self.onnx_reshape_with_allowzero(
|
||||
split_0, shape1, allowzero=0
|
||||
)
|
||||
|
||||
transpose_1 = reshape_4.transpose(0, 2, 1, 3)
|
||||
|
||||
axis = 1 # As per ONNX softmax node
|
||||
|
||||
# Subtract max for numerical stability
|
||||
x_max = np.max(transpose_1, axis=axis, keepdims=True)
|
||||
x_exp = np.exp(transpose_1 - x_max)
|
||||
x_sum = np.sum(x_exp, axis=axis, keepdims=True)
|
||||
softmax_output = x_exp / x_sum
|
||||
|
||||
# Weight W from the ONNX initializer (1, 16, 1, 1) with values 0 to 15
|
||||
W = np.arange(16, dtype=np.float32).reshape(
|
||||
1, 16, 1, 1
|
||||
) # (1, 16, 1, 1)
|
||||
|
||||
# Apply 1x1 convolution: this is a weighted sum over channels
|
||||
conv_output = np.sum(
|
||||
softmax_output * W, axis=1, keepdims=True
|
||||
) # shape: (1, 1, 4, 8400)
|
||||
|
||||
shape2 = np.array([1, 4, num_boxes])
|
||||
reshape_5 = self.onnx_reshape_with_allowzero(
|
||||
conv_output, shape2, allowzero=0
|
||||
)
|
||||
|
||||
# ONNX Slice — get first 2 channels: [0:2] along axis 1
|
||||
slice_output1 = reshape_5[:, 0:2, :] # Result: (1, 2, 8400)
|
||||
|
||||
# Slice channels 2 to 4 → axis = 1
|
||||
slice_output2 = reshape_5[:, 2:4, :]
|
||||
|
||||
# Perform Subtraction
|
||||
sub_output = self.const_A - slice_output1 # Equivalent to ONNX Sub
|
||||
|
||||
# Perform the ONNX-style Add
|
||||
add_output = self.const_B + slice_output2
|
||||
|
||||
sub1 = add_output - sub_output
|
||||
|
||||
add1 = sub_output + add_output
|
||||
|
||||
div_output = add1 / 2.0
|
||||
|
||||
concat_5 = self.onnx_concat([div_output, sub1], axis=1)
|
||||
|
||||
# Expand B to (1, 1, 8400) so it can broadcast across axis=1 (4 channels)
|
||||
const_C_expanded = self.const_C[:, np.newaxis, :] # Shape: (1, 1, 8400)
|
||||
|
||||
# Perform ONNX-style element-wise multiplication
|
||||
mul_output = concat_5 * const_C_expanded # Result: (1, 4, 8400)
|
||||
|
||||
sigmoid_output = self.sigmoid(split_1)
|
||||
outputs = self.onnx_concat([mul_output, sigmoid_output], axis=1)
|
||||
|
||||
final_detections = post_process_yolo(
|
||||
outputs, self.memx_model_width, self.memx_model_height
|
||||
)
|
||||
self.output_queue.put(final_detections)
|
||||
|
||||
elif self.memx_model_type == ModelTypeEnum.yolonas:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user