mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-18 17:14:26 +03:00
Implement post processing for yolov9
This commit is contained in:
parent
a831fc2bef
commit
75ef26fa11
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
@ -54,6 +55,15 @@ class ONNXDetector(DetectionApi):
|
|||||||
|
|
||||||
logger.info(f"ONNX: {path} loaded")
|
logger.info(f"ONNX: {path} loaded")
|
||||||
|
|
||||||
|
def xywh2xyxy(self, x):
|
||||||
|
# Convert bounding box (x, y, w, h) to bounding box (x1, y1, x2, y2)
|
||||||
|
y = np.copy(x)
|
||||||
|
y[..., 0] = x[..., 0] - x[..., 2] / 2
|
||||||
|
y[..., 1] = x[..., 1] - x[..., 3] / 2
|
||||||
|
y[..., 2] = x[..., 0] + x[..., 2] / 2
|
||||||
|
y[..., 3] = x[..., 1] + x[..., 3] / 2
|
||||||
|
return y
|
||||||
|
|
||||||
def detect_raw(self, tensor_input: np.ndarray):
|
def detect_raw(self, tensor_input: np.ndarray):
|
||||||
model_input_name = self.model.get_inputs()[0].name
|
model_input_name = self.model.get_inputs()[0].name
|
||||||
tensor_output = self.model.run(None, {model_input_name: tensor_input})
|
tensor_output = self.model.run(None, {model_input_name: tensor_input})
|
||||||
@ -81,8 +91,40 @@ class ONNXDetector(DetectionApi):
|
|||||||
return detections
|
return detections
|
||||||
elif self.onnx_model_type == ModelTypeEnum.yolov9:
|
elif self.onnx_model_type == ModelTypeEnum.yolov9:
|
||||||
# see https://github.com/MultimediaTechLab/YOLO/blob/main/yolo/utils/bounding_box_utils.py#L338
|
# see https://github.com/MultimediaTechLab/YOLO/blob/main/yolo/utils/bounding_box_utils.py#L338
|
||||||
logger.info(f"the output shape is {tensor_output[0][0].shape} which has {tensor_output[0][0][0]}")
|
predictions: np.ndarray = tensor_output[0]
|
||||||
return np.zeros((20, 6), np.float32)
|
predictions = np.squeeze(predictions).T
|
||||||
|
scores = np.max(predictions[:, 4:], axis=1)
|
||||||
|
predictions = predictions[scores > 0.4, :]
|
||||||
|
scores = scores[scores > 0.4]
|
||||||
|
class_ids = np.argmax(predictions[:, 4:], axis=1)
|
||||||
|
|
||||||
|
# Rescale box
|
||||||
|
boxes = predictions[:, :4]
|
||||||
|
|
||||||
|
input_shape = np.array([self.w, self.h, self.w, self.h])
|
||||||
|
boxes = np.divide(boxes, input_shape, dtype=np.float32)
|
||||||
|
boxes *= np.array([self.w, self.h, self.w, self.h])
|
||||||
|
boxes = boxes.astype(np.int32)
|
||||||
|
indices = cv2.dnn.NMSBoxes(
|
||||||
|
boxes, scores, score_threshold=0.4, nms_threshold=0.4
|
||||||
|
)
|
||||||
|
detections = np.zeros((20, 6), np.float32)
|
||||||
|
for i, (bbox, confidence, class_id) in enumerate(
|
||||||
|
zip(self.xywh2xyxy(boxes[indices]), scores[indices], class_ids[indices])
|
||||||
|
):
|
||||||
|
if i == 20:
|
||||||
|
break
|
||||||
|
|
||||||
|
detections[i] = [
|
||||||
|
class_id,
|
||||||
|
confidence,
|
||||||
|
bbox[0],
|
||||||
|
bbox[1],
|
||||||
|
bbox[2],
|
||||||
|
bbox[3],
|
||||||
|
]
|
||||||
|
|
||||||
|
return detections
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"{self.onnx_model_type} is currently not supported for rocm. See the docs for more info on supported models."
|
f"{self.onnx_model_type} is currently not supported for rocm. See the docs for more info on supported models."
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user