Add input tensor transpose to LocalObjectDetector

This commit is contained in:
Nate Meyer 2022-08-31 01:58:23 -04:00
parent c09aee85e6
commit 46df2a6734
2 changed files with 55 additions and 3 deletions

View File

@ -25,6 +25,18 @@ class ObjectDetector(ABC):
pass
def tensor_transform(desired_shape):
# Currently this function only supports BHWC permutations
if desired_shape == ["B", "H", "W", "C"]:
return None
else:
transform = [0] * 4
transform[desired_shape.index("H")] = 1
transform[desired_shape.index("W")] = 2
transform[desired_shape.index("C")] = 3
return tuple(transform)
class LocalObjectDetector(ObjectDetector):
def __init__(
self,
@ -40,6 +52,11 @@ class LocalObjectDetector(ObjectDetector):
else:
self.labels = load_labels(labels)
if model_config:
self.input_transform = tensor_transform(model_config.input_tensor)
else:
self.input_transform = None
if det_type == DetectorTypeEnum.edgetpu:
self.detect_api = EdgeTpuTfl(
det_device=det_device, model_config=model_config
@ -65,6 +82,8 @@ class LocalObjectDetector(ObjectDetector):
return detections
def detect_raw(self, tensor_input):
if self.input_transform:
tensor_input = np.transpose(tensor_input, self.input_transform)
return self.detect_api.detect_raw(tensor_input=tensor_input)

View File

@ -58,12 +58,39 @@ class TestLocalObjectDetector(unittest.TestCase):
mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA)
assert test_result is mock_det_api.detect_raw.return_value
@patch("frigate.object_detection.CpuTfl")
def test_detect_raw_given_tensor_input_should_call_api_detect_raw_with_transposed_tensor(
self, mock_cputfl
):
TEST_DATA = np.zeros((1, 32, 32, 3), np.uint8)
TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32])
test_cfg = ModelConfig()
test_cfg.input_tensor = ["B", "C", "H", "W"]
test_obj_detect = frigate.object_detection.LocalObjectDetector(
det_device=DetectorTypeEnum.cpu, model_config=test_cfg
)
mock_det_api = mock_cputfl.return_value
mock_det_api.detect_raw.return_value = TEST_DETECT_RESULT
test_result = test_obj_detect.detect_raw(TEST_DATA)
mock_det_api.detect_raw.assert_called_once()
assert (
mock_det_api.detect_raw.call_args.kwargs["tensor_input"].shape
== np.zeros((1, 3, 32, 32)).shape
)
assert test_result is mock_det_api.detect_raw.return_value
@patch("frigate.object_detection.CpuTfl")
@patch("frigate.object_detection.load_labels")
def test_detect_given_tensor_input_should_return_lfiltered_detections(
self, mock_load_labels, mock_cputfl
):
TEST_DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
TEST_DATA = np.zeros((1, 32, 32, 3), np.uint8)
TEST_DETECT_RAW = [
[2, 0.9, 5, 4, 3, 2],
[1, 0.5, 8, 7, 6, 5],
@ -83,7 +110,9 @@ class TestLocalObjectDetector(unittest.TestCase):
]
test_obj_detect = frigate.object_detection.LocalObjectDetector(
det_device=DetectorTypeEnum.cpu, labels=TEST_LABEL_FILE
det_device=DetectorTypeEnum.cpu,
model_config=ModelConfig(),
labels=TEST_LABEL_FILE,
)
mock_load_labels.assert_called_once_with(TEST_LABEL_FILE)
@ -93,5 +122,9 @@ class TestLocalObjectDetector(unittest.TestCase):
test_result = test_obj_detect.detect(tensor_input=TEST_DATA, threshold=0.5)
mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA)
mock_det_api.detect_raw.assert_called_once()
assert (
mock_det_api.detect_raw.call_args.kwargs["tensor_input"].shape
== np.zeros((1, 32, 32, 3)).shape
)
assert test_result == TEST_DETECT_RESULT