From 46df2a67345f9c63244698581b2426f33dd9e61d Mon Sep 17 00:00:00 2001 From: Nate Meyer Date: Wed, 31 Aug 2022 01:58:23 -0400 Subject: [PATCH] Add input tensor transpose to LocalObjectDetector --- frigate/object_detection.py | 19 ++++++++++++++ frigate/test/test_object_detector.py | 39 +++++++++++++++++++++++++--- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/frigate/object_detection.py b/frigate/object_detection.py index 06944f64d..7dcc3088f 100644 --- a/frigate/object_detection.py +++ b/frigate/object_detection.py @@ -25,6 +25,18 @@ class ObjectDetector(ABC): pass +def tensor_transform(desired_shape): + # Currently this function only supports BHWC permutations + if desired_shape == ["B", "H", "W", "C"]: + return None + else: + transform = [0] * 4 + transform[desired_shape.index("H")] = 1 + transform[desired_shape.index("W")] = 2 + transform[desired_shape.index("C")] = 3 + return tuple(transform) + + class LocalObjectDetector(ObjectDetector): def __init__( self, @@ -40,6 +52,11 @@ class LocalObjectDetector(ObjectDetector): else: self.labels = load_labels(labels) + if model_config: + self.input_transform = tensor_transform(model_config.input_tensor) + else: + self.input_transform = None + if det_type == DetectorTypeEnum.edgetpu: self.detect_api = EdgeTpuTfl( det_device=det_device, model_config=model_config @@ -65,6 +82,8 @@ class LocalObjectDetector(ObjectDetector): return detections def detect_raw(self, tensor_input): + if self.input_transform: + tensor_input = np.transpose(tensor_input, self.input_transform) return self.detect_api.detect_raw(tensor_input=tensor_input) diff --git a/frigate/test/test_object_detector.py b/frigate/test/test_object_detector.py index a7ef17377..4e85ebda7 100644 --- a/frigate/test/test_object_detector.py +++ b/frigate/test/test_object_detector.py @@ -58,12 +58,39 @@ class TestLocalObjectDetector(unittest.TestCase): mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA) assert test_result is mock_det_api.detect_raw.return_value + @patch("frigate.object_detection.CpuTfl") + def test_detect_raw_given_tensor_input_should_call_api_detect_raw_with_transposed_tensor( + self, mock_cputfl + ): + TEST_DATA = np.zeros((1, 32, 32, 3), np.uint8) + TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32]) + + test_cfg = ModelConfig() + test_cfg.input_tensor = ["B", "C", "H", "W"] + + test_obj_detect = frigate.object_detection.LocalObjectDetector( + det_device=DetectorTypeEnum.cpu, model_config=test_cfg + ) + + mock_det_api = mock_cputfl.return_value + mock_det_api.detect_raw.return_value = TEST_DETECT_RESULT + + test_result = test_obj_detect.detect_raw(TEST_DATA) + + mock_det_api.detect_raw.assert_called_once() + assert ( + mock_det_api.detect_raw.call_args.kwargs["tensor_input"].shape + == np.zeros((1, 3, 32, 32)).shape + ) + + assert test_result is mock_det_api.detect_raw.return_value + @patch("frigate.object_detection.CpuTfl") @patch("frigate.object_detection.load_labels") def test_detect_given_tensor_input_should_return_lfiltered_detections( self, mock_load_labels, mock_cputfl ): - TEST_DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + TEST_DATA = np.zeros((1, 32, 32, 3), np.uint8) TEST_DETECT_RAW = [ [2, 0.9, 5, 4, 3, 2], [1, 0.5, 8, 7, 6, 5], @@ -83,7 +110,9 @@ class TestLocalObjectDetector(unittest.TestCase): ] test_obj_detect = frigate.object_detection.LocalObjectDetector( - det_device=DetectorTypeEnum.cpu, labels=TEST_LABEL_FILE + det_device=DetectorTypeEnum.cpu, + model_config=ModelConfig(), + labels=TEST_LABEL_FILE, ) mock_load_labels.assert_called_once_with(TEST_LABEL_FILE) @@ -93,5 +122,9 @@ class TestLocalObjectDetector(unittest.TestCase): test_result = test_obj_detect.detect(tensor_input=TEST_DATA, threshold=0.5) - mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA) + mock_det_api.detect_raw.assert_called_once() + assert ( + mock_det_api.detect_raw.call_args.kwargs["tensor_input"].shape + == np.zeros((1, 32, 32, 3)).shape + ) assert test_result == TEST_DETECT_RESULT