From 1f8a8ffd3d58749804c0408131c0f2ecd4d2cbc6 Mon Sep 17 00:00:00 2001
From: Nate Meyer
Date: Wed, 31 Aug 2022 00:48:40 -0400
Subject: [PATCH] Add configuration for model inputs

Support transforming detection regions to RGB or BGR.

Support specifying the input tensor shape. Detection regions are handed
to the detector in the standard ["B", "H", "W", "C"] (batch, height,
width, channel) layout; the detector can then reorder the axes to match
the model's expected shape using the model's input_tensor config.
---
 frigate/app.py                       |  8 ++----
 frigate/config.py                    | 12 +++++++++
 frigate/detectors/cpu_tfl.py         |  4 +--
 frigate/detectors/detection_api.py   |  2 +-
 frigate/detectors/edgetpu_tfl.py     |  4 +--
 frigate/object_detection.py          | 30 ++++++++++-----------
 frigate/test/test_object_detector.py | 17 +++++++-----
 frigate/util.py                      | 10 +++++++
 frigate/video.py                     | 39 +++++++++++++++++-----------
 process_clip.py                      |  3 +--
 10 files changed, 79 insertions(+), 50 deletions(-)

diff --git a/frigate/app.py b/frigate/app.py
index 3071c751b..958d4fe3c 100644
--- a/frigate/app.py
+++ b/frigate/app.py
@@ -173,8 +173,6 @@ class FrigateApp:
         self.mqtt_relay.start()
 
     def start_detectors(self) -> None:
-        model_path = self.config.model.path
-        model_shape = (self.config.model.height, self.config.model.width)
         for name in self.config.cameras.keys():
             self.detection_out_events[name] = mp.Event()
@@ -202,8 +200,7 @@
                 name,
                 self.detection_queue,
                 self.detection_out_events,
-                model_path,
-                model_shape,
+                self.config.model,
                 detector.type,
                 detector.device,
                 detector.num_threads,
@@ -238,7 +235,6 @@
         logger.info(f"Output process started: {output_processor.pid}")
 
     def start_camera_processors(self) -> None:
-        model_shape = (self.config.model.height, self.config.model.width)
         for name, config in self.config.cameras.items():
             camera_process = mp.Process(
                 target=track_camera,
@@ -246,7 +242,7 @@
                 args=(
                     name,
                     config,
-                    model_shape,
+                    self.config.model,
                     self.config.model.merged_labelmap,
                     self.detection_queue,
                     self.detection_out_events[name],
diff --git a/frigate/config.py b/frigate/config.py
index 1370dcf88..005d0dfcf 100644
--- a/frigate/config.py
+++ b/frigate/config.py
@@ -687,6 +687,12 @@ class DatabaseConfig(FrigateBaseModel):
     )
 
 
+class PixelFormatEnum(str, Enum):
+    rgb = "rgb"
+    bgr = "bgr"
+    yuv = "yuv"
+
+
 class ModelConfig(FrigateBaseModel):
     path: Optional[str] = Field(title="Custom Object detection model path.")
     labelmap_path: Optional[str] = Field(title="Label map for custom object detector.")
@@ -695,6 +701,12 @@ class ModelConfig(FrigateBaseModel):
     labelmap: Dict[int, str] = Field(
         default_factory=dict, title="Labelmap customization."
     )
+    input_tensor: List[str] = Field(
+        default=["B", "H", "W", "C"], title="Model Input Tensor Shape"
+    )
+    input_pixel_format: PixelFormatEnum = Field(
+        default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
+    )
     _merged_labelmap: Optional[Dict[int, str]] = PrivateAttr()
     _colormap: Dict[int, Tuple[int, int, int]] = PrivateAttr()
diff --git a/frigate/detectors/cpu_tfl.py b/frigate/detectors/cpu_tfl.py
index b0c160783..ea1e4ddc2 100644
--- a/frigate/detectors/cpu_tfl.py
+++ b/frigate/detectors/cpu_tfl.py
@@ -8,9 +8,9 @@
 logger = logging.getLogger(__name__)
 
 
 class CpuTfl(DetectionApi):
-    def __init__(self, det_device=None, model_path=None, num_threads=3):
+    def __init__(self, det_device=None, model_config=None, num_threads=3):
         self.interpreter = tflite.Interpreter(
-            model_path=model_path or "/cpu_model.tflite", num_threads=num_threads
+            model_path=model_config.path or "/cpu_model.tflite", num_threads=num_threads
         )
         self.interpreter.allocate_tensors()
diff --git a/frigate/detectors/detection_api.py b/frigate/detectors/detection_api.py
index b8830334b..244195d46 100644
--- a/frigate/detectors/detection_api.py
+++ b/frigate/detectors/detection_api.py
@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
 
 class DetectionApi(ABC):
     @abstractmethod
-    def __init__(self, det_device=None, model_path=None):
+    def __init__(self, det_device=None, model_config=None):
         pass
 
     @abstractmethod
diff --git a/frigate/detectors/edgetpu_tfl.py b/frigate/detectors/edgetpu_tfl.py
index f917f95e0..aa3abf70c 100644
--- a/frigate/detectors/edgetpu_tfl.py
+++ b/frigate/detectors/edgetpu_tfl.py
@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
 
 
 class EdgeTpuTfl(DetectionApi):
-    def __init__(self, det_device=None, model_path=None):
+    def __init__(self, det_device=None, model_config=None):
         device_config = {"device": "usb"}
         if not det_device is None:
             device_config = {"device": det_device}
@@ -21,7 +21,7 @@ class EdgeTpuTfl(DetectionApi):
             edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config)
             logger.info("TPU found")
             self.interpreter = tflite.Interpreter(
-                model_path=model_path or "/edgetpu_model.tflite",
+                model_path=model_config.path or "/edgetpu_model.tflite",
                 experimental_delegates=[edge_tpu_delegate],
             )
         except ValueError:
diff --git a/frigate/object_detection.py b/frigate/object_detection.py
index d8374644f..06944f64d 100644
--- a/frigate/object_detection.py
+++ b/frigate/object_detection.py
@@ -30,7 +30,7 @@ class LocalObjectDetector(ObjectDetector):
         self,
         det_type=DetectorTypeEnum.cpu,
         det_device=None,
-        model_path=None,
+        model_config=None,
         num_threads=3,
         labels=None,
     ):
@@ -41,12 +41,14 @@ class LocalObjectDetector(ObjectDetector):
             self.labels = load_labels(labels)
 
         if det_type == DetectorTypeEnum.edgetpu:
-            self.detect_api = EdgeTpuTfl(det_device=det_device, model_path=model_path)
+            self.detect_api = EdgeTpuTfl(
+                det_device=det_device, model_config=model_config
+            )
         else:
             logger.warning(
                 "CPU detectors are not recommended and should only be used for testing or for trial purposes."
             )
-            self.detect_api = CpuTfl(model_path=model_path, num_threads=num_threads)
+            self.detect_api = CpuTfl(model_config=model_config, num_threads=num_threads)
 
     def detect(self, tensor_input, threshold=0.4):
         detections = []
@@ -72,8 +74,7 @@ def run_detector(
     out_events: dict[str, mp.Event],
     avg_speed,
     start,
-    model_path,
-    model_shape,
+    model_config,
    det_type,
     det_device,
     num_threads,
@@ -96,7 +97,7 @@ def run_detector(
     object_detector = LocalObjectDetector(
         det_type=det_type,
         det_device=det_device,
-        model_path=model_path,
+        model_config=model_config,
         num_threads=num_threads,
     )
 
@@ -112,7 +113,7 @@ def run_detector(
         except queue.Empty:
             continue
         input_frame = frame_manager.get(
-            connection_id, (1, model_shape[0], model_shape[1], 3)
+            connection_id, (1, model_config.height, model_config.width, 3)
         )
 
         if input_frame is None:
@@ -135,8 +136,7 @@ class ObjectDetectProcess:
         name,
         detection_queue,
         out_events,
-        model_path,
-        model_shape,
+        model_config,
         det_type=None,
         det_device=None,
         num_threads=3,
@@ -147,8 +147,7 @@ class ObjectDetectProcess:
         self.avg_inference_speed = mp.Value("d", 0.01)
         self.detection_start = mp.Value("d", 0.0)
         self.detect_process = None
-        self.model_path = model_path
-        self.model_shape = model_shape
+        self.model_config = model_config
         self.det_type = det_type
         self.det_device = det_device
         self.num_threads = num_threads
@@ -176,8 +175,7 @@ class ObjectDetectProcess:
                 self.out_events,
                 self.avg_inference_speed,
                 self.detection_start,
-                self.model_path,
-                self.model_shape,
+                self.model_config,
                 self.det_type,
                 self.det_device,
                 self.num_threads,
@@ -188,7 +186,7 @@
 
 
 class RemoteObjectDetector:
-    def __init__(self, name, labels, detection_queue, event, model_shape):
+    def __init__(self, name, labels, detection_queue, event, model_config):
         self.labels = labels
         self.name = name
         self.fps = EventsPerSecond()
@@ -196,7 +194,9 @@
         self.event = event
         self.shm = mp.shared_memory.SharedMemory(name=self.name, create=False)
         self.np_shm = np.ndarray(
-            (1, model_shape[0], model_shape[1], 3), dtype=np.uint8, buffer=self.shm.buf
+            (1, model_config.height, model_config.width, 3),
+            dtype=np.uint8,
+            buffer=self.shm.buf,
         )
         self.out_shm = mp.shared_memory.SharedMemory(
             name=f"out-{self.name}", create=False
diff --git a/frigate/test/test_object_detector.py b/frigate/test/test_object_detector.py
index 0a3137273..a7ef17377 100644
--- a/frigate/test/test_object_detector.py
+++ b/frigate/test/test_object_detector.py
@@ -2,7 +2,7 @@ import unittest
 from unittest.mock import patch
 
 import numpy as np
-from frigate.config import DetectorTypeEnum
+from frigate.config import DetectorTypeEnum, ModelConfig
 
 import frigate.object_detection
 
@@ -12,30 +12,33 @@ class TestLocalObjectDetector(unittest.TestCase):
     def test_localdetectorprocess_given_type_cpu_should_call_cputfl_init(
         self, mock_cputfl, mock_edgetputfl
     ):
+        test_cfg = ModelConfig()
+        test_cfg.path = "/test/modelpath"
         test_obj = frigate.object_detection.LocalObjectDetector(
-            det_type=DetectorTypeEnum.cpu, model_path="/test/modelpath", num_threads=6
+            det_type=DetectorTypeEnum.cpu, model_config=test_cfg, num_threads=6
         )
 
         assert test_obj is not None
 
         mock_edgetputfl.assert_not_called()
-        mock_cputfl.assert_called_once_with(model_path="/test/modelpath", num_threads=6)
+        mock_cputfl.assert_called_once_with(model_config=test_cfg, num_threads=6)
 
     @patch("frigate.object_detection.EdgeTpuTfl")
     @patch("frigate.object_detection.CpuTfl")
     def test_localdetectorprocess_given_type_edgtpu_should_call_edgtpu_init(
         self,
         mock_cputfl,
         mock_edgetputfl,
     ):
+        test_cfg = ModelConfig()
+        test_cfg.path = "/test/modelpath"
+
         test_obj = frigate.object_detection.LocalObjectDetector(
             det_type=DetectorTypeEnum.edgetpu,
             det_device="usb",
-            model_path="/test/modelpath",
+            model_config=test_cfg,
         )
 
         assert test_obj is not None
 
         mock_cputfl.assert_not_called()
-        mock_edgetputfl.assert_called_once_with(
-            det_device="usb", model_path="/test/modelpath"
-        )
+        mock_edgetputfl.assert_called_once_with(det_device="usb", model_config=test_cfg)
 
     @patch("frigate.object_detection.CpuTfl")
     def test_detect_raw_given_tensor_input_should_return_api_detect_raw_result(
diff --git a/frigate/util.py b/frigate/util.py
index 49bda8620..53b7bd1e0 100755
--- a/frigate/util.py
+++ b/frigate/util.py
@@ -479,6 +479,16 @@ def yuv_region_2_rgb(frame, region):
         raise
 
 
+def yuv_region_2_bgr(frame, region):
+    try:
+        yuv_cropped_frame = yuv_crop_and_resize(frame, region)
+        return cv2.cvtColor(yuv_cropped_frame, cv2.COLOR_YUV2BGR_I420)
+    except:
+        print(f"frame.shape: {frame.shape}")
+        print(f"region: {region}")
+        raise
+
+
 def intersection(box_a, box_b):
     return (
         max(box_a[0], box_b[0]),
diff --git a/frigate/video.py b/frigate/video.py
index dd1ea18dd..06daa8ea2 100755
--- a/frigate/video.py
+++ b/frigate/video.py
@@ -14,7 +14,7 @@ import numpy as np
 import cv2
 from setproctitle import setproctitle
 
-from frigate.config import CameraConfig, DetectConfig
+from frigate.config import CameraConfig, DetectConfig, PixelFormatEnum
 from frigate.object_detection import RemoteObjectDetector
 from frigate.log import LogPipe
 from frigate.motion import MotionDetector
@@ -29,7 +29,9 @@ from frigate.util import (
     intersection,
     intersection_over_union,
     listen,
+    yuv_crop_and_resize,
     yuv_region_2_rgb,
+    yuv_region_2_bgr,
 )
 
 logger = logging.getLogger(__name__)
@@ -89,13 +91,20 @@ def filtered(obj, objects_to_track, object_filters):
     return False
 
 
-def create_tensor_input(frame, model_shape, region):
-    cropped_frame = yuv_region_2_rgb(frame, region)
+def create_tensor_input(frame, model_config, region):
+    if model_config.input_pixel_format == PixelFormatEnum.rgb:
+        cropped_frame = yuv_region_2_rgb(frame, region)
+    elif model_config.input_pixel_format == PixelFormatEnum.bgr:
+        cropped_frame = yuv_region_2_bgr(frame, region)
+    else:
+        cropped_frame = yuv_crop_and_resize(frame, region)
 
-    # Resize to 300x300 if needed
-    if cropped_frame.shape != (model_shape[0], model_shape[1], 3):
+    # Resize if needed; cv2.resize expects dsize as (width, height)
+    if cropped_frame.shape != (model_config.height, model_config.width, 3):
         cropped_frame = cv2.resize(
-            cropped_frame, dsize=model_shape, interpolation=cv2.INTER_LINEAR
+            cropped_frame,
+            dsize=(model_config.width, model_config.height),
+            interpolation=cv2.INTER_LINEAR,
         )
 
     # Expand dimensions since the model expects images to have shape: [1, height, width, 3]
@@ -340,7 +349,7 @@ def capture_camera(name, config: CameraConfig, process_info):
 def track_camera(
     name,
     config: CameraConfig,
-    model_shape,
+    model_config,
     labelmap,
     detection_queue,
     result_connection,
@@ -378,7 +387,7 @@ def track_camera(
         motion_contour_area,
     )
     object_detector = RemoteObjectDetector(
-        name, labelmap, detection_queue, result_connection, model_shape
+        name, labelmap, detection_queue, result_connection, model_config
     )
 
     object_tracker = ObjectTracker(config.detect)
@@ -389,7 +398,7 @@ def track_camera(
         name,
         frame_queue,
         frame_shape,
-        model_shape,
+        model_config,
         config.detect,
         frame_manager,
         motion_detector,
@@ -443,12 +452,12 @@ def detect(
     detect_config: DetectConfig,
     object_detector,
     frame,
-    model_shape,
+    model_config,
     region,
     objects_to_track,
     object_filters,
 ):
-    tensor_input = create_tensor_input(frame, model_shape, region)
+    tensor_input = create_tensor_input(frame, model_config, region)
 
     detections = []
     region_detections = object_detector.detect(tensor_input)
@@ -487,7 +496,7 @@ def process_frames(
     camera_name: str,
     frame_queue: mp.Queue,
     frame_shape,
-    model_shape,
+    model_config,
     detect_config: DetectConfig,
     frame_manager: FrameManager,
     motion_detector: MotionDetector,
@@ -571,7 +580,7 @@ def process_frames(
             # combine motion boxes with known locations of existing objects
             combined_boxes = reduce_boxes(motion_boxes + tracked_object_boxes)
 
-            region_min_size = max(model_shape[0], model_shape[1])
+            region_min_size = max(model_config.height, model_config.width)
             # compute regions
             regions = [
                 calculate_region(
@@ -634,7 +643,7 @@ def process_frames(
                         detect_config,
                         object_detector,
                         frame,
-                        model_shape,
+                        model_config,
                         region,
                         objects_to_track,
                         object_filters,
@@ -694,7 +703,7 @@ def process_frames(
                             detect_config,
                             object_detector,
                             frame,
-                            model_shape,
+                            model_config,
                             region,
                             objects_to_track,
                             object_filters,
diff --git a/process_clip.py b/process_clip.py
index e932241a0..d8dabbedd 100644
--- a/process_clip.py
+++ b/process_clip.py
@@ -117,13 +117,12 @@ class ProcessClip:
         detection_enabled = mp.Value("d", 1)
         motion_enabled = mp.Value("d", True)
         stop_event = mp.Event()
-        model_shape = (self.config.model.height, self.config.model.width)
 
         process_frames(
             self.camera_name,
             self.frame_queue,
             self.frame_shape,
-            model_shape,
+            self.config.model,
            self.camera_config.detect,
             self.frame_manager,
             motion_detector,
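
Note on how the two new keys compose: create_tensor_input() always emits a
[1, height, width, 3] array in the configured input_pixel_format, and
input_tensor names the axis order a model expects, so a detector can permute
the standard ["B", "H", "W", "C"] layout into the model's. A minimal NumPy
sketch of that permutation follows; the transpose_to_model_layout helper and
its placement inside a detector are illustrative assumptions only, since this
patch adds the config plumbing rather than the transpose itself:

    import numpy as np

    # Layout produced by create_tensor_input()
    STANDARD_LAYOUT = ["B", "H", "W", "C"]

    def transpose_to_model_layout(tensor_input, input_tensor):
        # Hypothetical helper: map each axis name in the model's layout to
        # its position in the standard layout, then transpose to that order.
        order = [STANDARD_LAYOUT.index(axis) for axis in input_tensor]
        return np.transpose(tensor_input, order)

    # Example: a channels-first model configured with
    # input_tensor: ["B", "C", "H", "W"]
    bhwc = np.zeros((1, 320, 320, 3), dtype=np.uint8)
    bchw = transpose_to_model_layout(bhwc, ["B", "C", "H", "W"])
    assert bchw.shape == (1, 3, 320, 320)

Naming the axes in config keeps the YAML readable and lets future layouts be
expressed without further code changes to the frame pipeline.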