From 41102d82d87521c6b458fdc7663a46caa9a7cd20 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Mon, 28 Oct 2024 09:48:28 -0600 Subject: [PATCH] Add input type for dtype --- frigate/detectors/detector_config.py | 8 ++++++++ frigate/detectors/plugins/onnx.py | 2 +- frigate/object_detection.py | 17 ++++++++++++++--- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/frigate/detectors/detector_config.py b/frigate/detectors/detector_config.py index c40ef65bf..45875e2e6 100644 --- a/frigate/detectors/detector_config.py +++ b/frigate/detectors/detector_config.py @@ -27,6 +27,11 @@ class InputTensorEnum(str, Enum): nhwc = "nhwc" +class InputDTypeEnum(str, Enum): + float = "float" + int = "int" + + class ModelTypeEnum(str, Enum): ssd = "ssd" yolox = "yolox" @@ -53,6 +58,9 @@ class ModelConfig(BaseModel): input_pixel_format: PixelFormatEnum = Field( default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format" ) + input_dtype: InputDTypeEnum = Field( + default=InputDTypeEnum.int, title="Model Input D Type" + ) model_type: ModelTypeEnum = Field( default=ModelTypeEnum.ssd, title="Object Detection Model Type" ) diff --git a/frigate/detectors/plugins/onnx.py b/frigate/detectors/plugins/onnx.py index 3e58df72a..7004f28fa 100644 --- a/frigate/detectors/plugins/onnx.py +++ b/frigate/detectors/plugins/onnx.py @@ -54,7 +54,7 @@ class ONNXDetector(DetectionApi): logger.info(f"ONNX: {path} loaded") - def detect_raw(self, tensor_input): + def detect_raw(self, tensor_input: np.ndarray): model_input_name = self.model.get_inputs()[0].name tensor_output = self.model.run(None, {model_input_name: tensor_input}) diff --git a/frigate/object_detection.py b/frigate/object_detection.py index eaa3b4e04..0af32034e 100644 --- a/frigate/object_detection.py +++ b/frigate/object_detection.py @@ -12,7 +12,11 @@ from setproctitle import setproctitle import frigate.util as util from frigate.detectors import create_detector -from frigate.detectors.detector_config import BaseDetectorConfig, InputTensorEnum +from frigate.detectors.detector_config import ( + BaseDetectorConfig, + InputDTypeEnum, + InputTensorEnum, +) from frigate.detectors.plugins.rocm import DETECTOR_KEY as ROCM_DETECTOR_KEY from frigate.util.builtin import EventsPerSecond, load_labels from frigate.util.image import SharedMemoryFrameManager @@ -55,12 +59,15 @@ class LocalObjectDetector(ObjectDetector): self.input_transform = tensor_transform( detector_config.model.input_tensor ) + + self.dtype = detector_config.model.input_dtype else: self.input_transform = None + self.dtype = InputDTypeEnum.int self.detect_api = create_detector(detector_config) - def detect(self, tensor_input, threshold=0.4): + def detect(self, tensor_input: np.ndarray, threshold=0.4): detections = [] raw_detections = self.detect_raw(tensor_input) @@ -77,9 +84,13 @@ class LocalObjectDetector(ObjectDetector): self.fps.update() return detections - def detect_raw(self, tensor_input): + def detect_raw(self, tensor_input: np.ndarray): if self.input_transform: tensor_input = np.transpose(tensor_input, self.input_transform) + + if self.dtype == InputDTypeEnum.float: + tensor_input = tensor_input.astype(np.float32) + return self.detect_api.detect_raw(tensor_input=tensor_input)