mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-19 01:17:06 +03:00
Move post processing to separate utility
This commit is contained in:
parent
77b5182e7b
commit
4685ffb41b
@ -1,6 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
@ -10,7 +9,7 @@ from frigate.detectors.detector_config import (
|
|||||||
BaseDetectorConfig,
|
BaseDetectorConfig,
|
||||||
ModelTypeEnum,
|
ModelTypeEnum,
|
||||||
)
|
)
|
||||||
from frigate.util.model import get_ort_providers
|
from frigate.util.model import get_ort_providers, post_process_yolov9
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -81,39 +80,8 @@ class ONNXDetector(DetectionApi):
|
|||||||
]
|
]
|
||||||
return detections
|
return detections
|
||||||
elif self.onnx_model_type == ModelTypeEnum.yolov9:
|
elif self.onnx_model_type == ModelTypeEnum.yolov9:
|
||||||
# see https://github.com/MultimediaTechLab/YOLO/blob/main/yolo/utils/bounding_box_utils.py#L338
|
|
||||||
predictions: np.ndarray = tensor_output[0]
|
predictions: np.ndarray = tensor_output[0]
|
||||||
predictions = np.squeeze(predictions).T
|
return post_process_yolov9(predictions, self.w, self.h)
|
||||||
scores = np.max(predictions[:, 4:], axis=1)
|
|
||||||
predictions = predictions[scores > 0.4, :]
|
|
||||||
scores = scores[scores > 0.4]
|
|
||||||
class_ids = np.argmax(predictions[:, 4:], axis=1)
|
|
||||||
|
|
||||||
# Rescale box
|
|
||||||
boxes = predictions[:, :4]
|
|
||||||
|
|
||||||
input_shape = np.array([self.w, self.h, self.w, self.h])
|
|
||||||
boxes = np.divide(boxes, input_shape, dtype=np.float32)
|
|
||||||
indices = cv2.dnn.NMSBoxes(
|
|
||||||
boxes, scores, score_threshold=0.4, nms_threshold=0.4
|
|
||||||
)
|
|
||||||
detections = np.zeros((20, 6), np.float32)
|
|
||||||
for i, (bbox, confidence, class_id) in enumerate(
|
|
||||||
zip(boxes[indices], scores[indices], class_ids[indices])
|
|
||||||
):
|
|
||||||
if i == 20:
|
|
||||||
break
|
|
||||||
|
|
||||||
detections[i] = [
|
|
||||||
class_id,
|
|
||||||
confidence,
|
|
||||||
bbox[1] - bbox[3] / 2,
|
|
||||||
bbox[0] - bbox[2] / 2,
|
|
||||||
bbox[1] + bbox[3] / 2,
|
|
||||||
bbox[0] + bbox[2] / 2,
|
|
||||||
]
|
|
||||||
|
|
||||||
return detections
|
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"{self.onnx_model_type} is currently not supported for rocm. See the docs for more info on supported models."
|
f"{self.onnx_model_type} is currently not supported for rocm. See the docs for more info on supported models."
|
||||||
|
|||||||
@ -4,6 +4,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -14,6 +16,43 @@ except ImportError:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
### Post Processing
|
||||||
|
|
||||||
|
|
||||||
|
def post_process_yolov9(predictions: np.ndarray, width, height) -> np.ndarray:
|
||||||
|
predictions = np.squeeze(predictions).T
|
||||||
|
scores = np.max(predictions[:, 4:], axis=1)
|
||||||
|
predictions = predictions[scores > 0.4, :]
|
||||||
|
scores = scores[scores > 0.4]
|
||||||
|
class_ids = np.argmax(predictions[:, 4:], axis=1)
|
||||||
|
|
||||||
|
# Rescale box
|
||||||
|
boxes = predictions[:, :4]
|
||||||
|
|
||||||
|
input_shape = np.array([width, height, width, height])
|
||||||
|
boxes = np.divide(boxes, input_shape, dtype=np.float32)
|
||||||
|
indices = cv2.dnn.NMSBoxes(boxes, scores, score_threshold=0.4, nms_threshold=0.4)
|
||||||
|
detections = np.zeros((20, 6), np.float32)
|
||||||
|
for i, (bbox, confidence, class_id) in enumerate(
|
||||||
|
zip(boxes[indices], scores[indices], class_ids[indices])
|
||||||
|
):
|
||||||
|
if i == 20:
|
||||||
|
break
|
||||||
|
|
||||||
|
detections[i] = [
|
||||||
|
class_id,
|
||||||
|
confidence,
|
||||||
|
bbox[1] - bbox[3] / 2,
|
||||||
|
bbox[0] - bbox[2] / 2,
|
||||||
|
bbox[1] + bbox[3] / 2,
|
||||||
|
bbox[0] + bbox[2] / 2,
|
||||||
|
]
|
||||||
|
|
||||||
|
return detections
|
||||||
|
|
||||||
|
|
||||||
|
### ONNX Utilities
|
||||||
|
|
||||||
|
|
||||||
def get_ort_providers(
|
def get_ort_providers(
|
||||||
force_cpu: bool = False, device: str = "AUTO", requires_fp16: bool = False
|
force_cpu: bool = False, device: str = "AUTO", requires_fp16: bool = False
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user