mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-03 09:45:22 +03:00
Add input tensor transpose to LocalObjectDetector
This commit is contained in:
parent
c09aee85e6
commit
46df2a6734
@ -25,6 +25,18 @@ class ObjectDetector(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_transform(desired_shape):
|
||||||
|
# Currently this function only supports BHWC permutations
|
||||||
|
if desired_shape == ["B", "H", "W", "C"]:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
transform = [0] * 4
|
||||||
|
transform[desired_shape.index("H")] = 1
|
||||||
|
transform[desired_shape.index("W")] = 2
|
||||||
|
transform[desired_shape.index("C")] = 3
|
||||||
|
return tuple(transform)
|
||||||
|
|
||||||
|
|
||||||
class LocalObjectDetector(ObjectDetector):
|
class LocalObjectDetector(ObjectDetector):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -40,6 +52,11 @@ class LocalObjectDetector(ObjectDetector):
|
|||||||
else:
|
else:
|
||||||
self.labels = load_labels(labels)
|
self.labels = load_labels(labels)
|
||||||
|
|
||||||
|
if model_config:
|
||||||
|
self.input_transform = tensor_transform(model_config.input_tensor)
|
||||||
|
else:
|
||||||
|
self.input_transform = None
|
||||||
|
|
||||||
if det_type == DetectorTypeEnum.edgetpu:
|
if det_type == DetectorTypeEnum.edgetpu:
|
||||||
self.detect_api = EdgeTpuTfl(
|
self.detect_api = EdgeTpuTfl(
|
||||||
det_device=det_device, model_config=model_config
|
det_device=det_device, model_config=model_config
|
||||||
@ -65,6 +82,8 @@ class LocalObjectDetector(ObjectDetector):
|
|||||||
return detections
|
return detections
|
||||||
|
|
||||||
def detect_raw(self, tensor_input):
|
def detect_raw(self, tensor_input):
|
||||||
|
if self.input_transform:
|
||||||
|
tensor_input = np.transpose(tensor_input, self.input_transform)
|
||||||
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -58,12 +58,39 @@ class TestLocalObjectDetector(unittest.TestCase):
|
|||||||
mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA)
|
mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA)
|
||||||
assert test_result is mock_det_api.detect_raw.return_value
|
assert test_result is mock_det_api.detect_raw.return_value
|
||||||
|
|
||||||
|
@patch("frigate.object_detection.CpuTfl")
|
||||||
|
def test_detect_raw_given_tensor_input_should_call_api_detect_raw_with_transposed_tensor(
|
||||||
|
self, mock_cputfl
|
||||||
|
):
|
||||||
|
TEST_DATA = np.zeros((1, 32, 32, 3), np.uint8)
|
||||||
|
TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32])
|
||||||
|
|
||||||
|
test_cfg = ModelConfig()
|
||||||
|
test_cfg.input_tensor = ["B", "C", "H", "W"]
|
||||||
|
|
||||||
|
test_obj_detect = frigate.object_detection.LocalObjectDetector(
|
||||||
|
det_device=DetectorTypeEnum.cpu, model_config=test_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_det_api = mock_cputfl.return_value
|
||||||
|
mock_det_api.detect_raw.return_value = TEST_DETECT_RESULT
|
||||||
|
|
||||||
|
test_result = test_obj_detect.detect_raw(TEST_DATA)
|
||||||
|
|
||||||
|
mock_det_api.detect_raw.assert_called_once()
|
||||||
|
assert (
|
||||||
|
mock_det_api.detect_raw.call_args.kwargs["tensor_input"].shape
|
||||||
|
== np.zeros((1, 3, 32, 32)).shape
|
||||||
|
)
|
||||||
|
|
||||||
|
assert test_result is mock_det_api.detect_raw.return_value
|
||||||
|
|
||||||
@patch("frigate.object_detection.CpuTfl")
|
@patch("frigate.object_detection.CpuTfl")
|
||||||
@patch("frigate.object_detection.load_labels")
|
@patch("frigate.object_detection.load_labels")
|
||||||
def test_detect_given_tensor_input_should_return_lfiltered_detections(
|
def test_detect_given_tensor_input_should_return_lfiltered_detections(
|
||||||
self, mock_load_labels, mock_cputfl
|
self, mock_load_labels, mock_cputfl
|
||||||
):
|
):
|
||||||
TEST_DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
TEST_DATA = np.zeros((1, 32, 32, 3), np.uint8)
|
||||||
TEST_DETECT_RAW = [
|
TEST_DETECT_RAW = [
|
||||||
[2, 0.9, 5, 4, 3, 2],
|
[2, 0.9, 5, 4, 3, 2],
|
||||||
[1, 0.5, 8, 7, 6, 5],
|
[1, 0.5, 8, 7, 6, 5],
|
||||||
@ -83,7 +110,9 @@ class TestLocalObjectDetector(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
test_obj_detect = frigate.object_detection.LocalObjectDetector(
|
test_obj_detect = frigate.object_detection.LocalObjectDetector(
|
||||||
det_device=DetectorTypeEnum.cpu, labels=TEST_LABEL_FILE
|
det_device=DetectorTypeEnum.cpu,
|
||||||
|
model_config=ModelConfig(),
|
||||||
|
labels=TEST_LABEL_FILE,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_load_labels.assert_called_once_with(TEST_LABEL_FILE)
|
mock_load_labels.assert_called_once_with(TEST_LABEL_FILE)
|
||||||
@ -93,5 +122,9 @@ class TestLocalObjectDetector(unittest.TestCase):
|
|||||||
|
|
||||||
test_result = test_obj_detect.detect(tensor_input=TEST_DATA, threshold=0.5)
|
test_result = test_obj_detect.detect(tensor_input=TEST_DATA, threshold=0.5)
|
||||||
|
|
||||||
mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA)
|
mock_det_api.detect_raw.assert_called_once()
|
||||||
|
assert (
|
||||||
|
mock_det_api.detect_raw.call_args.kwargs["tensor_input"].shape
|
||||||
|
== np.zeros((1, 32, 32, 3)).shape
|
||||||
|
)
|
||||||
assert test_result == TEST_DETECT_RESULT
|
assert test_result == TEST_DETECT_RESULT
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user