mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-03 09:45:22 +03:00
Add input tensor transpose to LocalObjectDetector
This commit is contained in:
parent
c09aee85e6
commit
46df2a6734
@ -25,6 +25,18 @@ class ObjectDetector(ABC):
|
||||
pass
|
||||
|
||||
|
||||
def tensor_transform(desired_shape):
|
||||
# Currently this function only supports BHWC permutations
|
||||
if desired_shape == ["B", "H", "W", "C"]:
|
||||
return None
|
||||
else:
|
||||
transform = [0] * 4
|
||||
transform[desired_shape.index("H")] = 1
|
||||
transform[desired_shape.index("W")] = 2
|
||||
transform[desired_shape.index("C")] = 3
|
||||
return tuple(transform)
|
||||
|
||||
|
||||
class LocalObjectDetector(ObjectDetector):
|
||||
def __init__(
|
||||
self,
|
||||
@ -40,6 +52,11 @@ class LocalObjectDetector(ObjectDetector):
|
||||
else:
|
||||
self.labels = load_labels(labels)
|
||||
|
||||
if model_config:
|
||||
self.input_transform = tensor_transform(model_config.input_tensor)
|
||||
else:
|
||||
self.input_transform = None
|
||||
|
||||
if det_type == DetectorTypeEnum.edgetpu:
|
||||
self.detect_api = EdgeTpuTfl(
|
||||
det_device=det_device, model_config=model_config
|
||||
@ -65,6 +82,8 @@ class LocalObjectDetector(ObjectDetector):
|
||||
return detections
|
||||
|
||||
def detect_raw(self, tensor_input):
|
||||
if self.input_transform:
|
||||
tensor_input = np.transpose(tensor_input, self.input_transform)
|
||||
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
||||
|
||||
|
||||
|
||||
@ -58,12 +58,39 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA)
|
||||
assert test_result is mock_det_api.detect_raw.return_value
|
||||
|
||||
@patch("frigate.object_detection.CpuTfl")
|
||||
def test_detect_raw_given_tensor_input_should_call_api_detect_raw_with_transposed_tensor(
|
||||
self, mock_cputfl
|
||||
):
|
||||
TEST_DATA = np.zeros((1, 32, 32, 3), np.uint8)
|
||||
TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32])
|
||||
|
||||
test_cfg = ModelConfig()
|
||||
test_cfg.input_tensor = ["B", "C", "H", "W"]
|
||||
|
||||
test_obj_detect = frigate.object_detection.LocalObjectDetector(
|
||||
det_device=DetectorTypeEnum.cpu, model_config=test_cfg
|
||||
)
|
||||
|
||||
mock_det_api = mock_cputfl.return_value
|
||||
mock_det_api.detect_raw.return_value = TEST_DETECT_RESULT
|
||||
|
||||
test_result = test_obj_detect.detect_raw(TEST_DATA)
|
||||
|
||||
mock_det_api.detect_raw.assert_called_once()
|
||||
assert (
|
||||
mock_det_api.detect_raw.call_args.kwargs["tensor_input"].shape
|
||||
== np.zeros((1, 3, 32, 32)).shape
|
||||
)
|
||||
|
||||
assert test_result is mock_det_api.detect_raw.return_value
|
||||
|
||||
@patch("frigate.object_detection.CpuTfl")
|
||||
@patch("frigate.object_detection.load_labels")
|
||||
def test_detect_given_tensor_input_should_return_lfiltered_detections(
|
||||
self, mock_load_labels, mock_cputfl
|
||||
):
|
||||
TEST_DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
TEST_DATA = np.zeros((1, 32, 32, 3), np.uint8)
|
||||
TEST_DETECT_RAW = [
|
||||
[2, 0.9, 5, 4, 3, 2],
|
||||
[1, 0.5, 8, 7, 6, 5],
|
||||
@ -83,7 +110,9 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
]
|
||||
|
||||
test_obj_detect = frigate.object_detection.LocalObjectDetector(
|
||||
det_device=DetectorTypeEnum.cpu, labels=TEST_LABEL_FILE
|
||||
det_device=DetectorTypeEnum.cpu,
|
||||
model_config=ModelConfig(),
|
||||
labels=TEST_LABEL_FILE,
|
||||
)
|
||||
|
||||
mock_load_labels.assert_called_once_with(TEST_LABEL_FILE)
|
||||
@ -93,5 +122,9 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
|
||||
test_result = test_obj_detect.detect(tensor_input=TEST_DATA, threshold=0.5)
|
||||
|
||||
mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA)
|
||||
mock_det_api.detect_raw.assert_called_once()
|
||||
assert (
|
||||
mock_det_api.detect_raw.call_args.kwargs["tensor_input"].shape
|
||||
== np.zeros((1, 32, 32, 3)).shape
|
||||
)
|
||||
assert test_result == TEST_DETECT_RESULT
|
||||
|
||||
Loading…
Reference in New Issue
Block a user