mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-18 17:14:26 +03:00
Move post processing to separate utility
This commit is contained in:
parent
77b5182e7b
commit
4685ffb41b
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from typing_extensions import Literal
|
||||
@ -10,7 +9,7 @@ from frigate.detectors.detector_config import (
|
||||
BaseDetectorConfig,
|
||||
ModelTypeEnum,
|
||||
)
|
||||
from frigate.util.model import get_ort_providers
|
||||
from frigate.util.model import get_ort_providers, post_process_yolov9
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -81,39 +80,8 @@ class ONNXDetector(DetectionApi):
|
||||
]
|
||||
return detections
|
||||
elif self.onnx_model_type == ModelTypeEnum.yolov9:
|
||||
# see https://github.com/MultimediaTechLab/YOLO/blob/main/yolo/utils/bounding_box_utils.py#L338
|
||||
predictions: np.ndarray = tensor_output[0]
|
||||
predictions = np.squeeze(predictions).T
|
||||
scores = np.max(predictions[:, 4:], axis=1)
|
||||
predictions = predictions[scores > 0.4, :]
|
||||
scores = scores[scores > 0.4]
|
||||
class_ids = np.argmax(predictions[:, 4:], axis=1)
|
||||
|
||||
# Rescale box
|
||||
boxes = predictions[:, :4]
|
||||
|
||||
input_shape = np.array([self.w, self.h, self.w, self.h])
|
||||
boxes = np.divide(boxes, input_shape, dtype=np.float32)
|
||||
indices = cv2.dnn.NMSBoxes(
|
||||
boxes, scores, score_threshold=0.4, nms_threshold=0.4
|
||||
)
|
||||
detections = np.zeros((20, 6), np.float32)
|
||||
for i, (bbox, confidence, class_id) in enumerate(
|
||||
zip(boxes[indices], scores[indices], class_ids[indices])
|
||||
):
|
||||
if i == 20:
|
||||
break
|
||||
|
||||
detections[i] = [
|
||||
class_id,
|
||||
confidence,
|
||||
bbox[1] - bbox[3] / 2,
|
||||
bbox[0] - bbox[2] / 2,
|
||||
bbox[1] + bbox[3] / 2,
|
||||
bbox[0] + bbox[2] / 2,
|
||||
]
|
||||
|
||||
return detections
|
||||
return post_process_yolov9(predictions, self.w, self.h)
|
||||
else:
|
||||
raise Exception(
|
||||
f"{self.onnx_model_type} is currently not supported for rocm. See the docs for more info on supported models."
|
||||
|
||||
@ -4,6 +4,8 @@ import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
try:
|
||||
@ -14,6 +16,43 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
### Post Processing
|
||||
|
||||
|
||||
def post_process_yolov9(predictions: np.ndarray, width, height) -> np.ndarray:
|
||||
predictions = np.squeeze(predictions).T
|
||||
scores = np.max(predictions[:, 4:], axis=1)
|
||||
predictions = predictions[scores > 0.4, :]
|
||||
scores = scores[scores > 0.4]
|
||||
class_ids = np.argmax(predictions[:, 4:], axis=1)
|
||||
|
||||
# Rescale box
|
||||
boxes = predictions[:, :4]
|
||||
|
||||
input_shape = np.array([width, height, width, height])
|
||||
boxes = np.divide(boxes, input_shape, dtype=np.float32)
|
||||
indices = cv2.dnn.NMSBoxes(boxes, scores, score_threshold=0.4, nms_threshold=0.4)
|
||||
detections = np.zeros((20, 6), np.float32)
|
||||
for i, (bbox, confidence, class_id) in enumerate(
|
||||
zip(boxes[indices], scores[indices], class_ids[indices])
|
||||
):
|
||||
if i == 20:
|
||||
break
|
||||
|
||||
detections[i] = [
|
||||
class_id,
|
||||
confidence,
|
||||
bbox[1] - bbox[3] / 2,
|
||||
bbox[0] - bbox[2] / 2,
|
||||
bbox[1] + bbox[3] / 2,
|
||||
bbox[0] + bbox[2] / 2,
|
||||
]
|
||||
|
||||
return detections
|
||||
|
||||
|
||||
### ONNX Utilities
|
||||
|
||||
|
||||
def get_ort_providers(
|
||||
force_cpu: bool = False, device: str = "AUTO", requires_fp16: bool = False
|
||||
|
||||
Loading…
Reference in New Issue
Block a user