Add configuration for model inputs

Support transforming detection regions to RGB or BGR.
Support specifying the input tensor shape. The tensor is handed to the detector in the standard ["B", "H", "W", "C"] layout, and the detector can transpose it to match the model's expected shape using the model's input_tensor config.
Nate Meyer 2022-08-31 00:48:40 -04:00
parent d702c3d690
commit 1f8a8ffd3d
10 changed files with 79 additions and 50 deletions
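The input_tensor list names the axis order the detector receives, so a detector that needs a different layout can derive a transpose order from it. Below is a minimal sketch of that idea, not code from this commit; the helper name transpose_to_model_layout and the "BCHW" target layout are hypothetical.

import numpy as np

def transpose_to_model_layout(tensor_input, input_tensor, model_layout):
    # Hypothetical helper: reorder axes from the standard ["B", "H", "W", "C"]
    # layout to whatever layout the model expects.
    order = [input_tensor.index(axis) for axis in model_layout]
    return np.transpose(tensor_input, order)

# e.g. a model that wants channels-first ("BCHW") input:
frame = np.zeros((1, 320, 320, 3), dtype=np.uint8)
chw = transpose_to_model_layout(frame, ["B", "H", "W", "C"], ["B", "C", "H", "W"])
assert chw.shape == (1, 3, 320, 320)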


@@ -173,8 +173,6 @@ class FrigateApp:
         self.mqtt_relay.start()

     def start_detectors(self) -> None:
-        model_path = self.config.model.path
-        model_shape = (self.config.model.height, self.config.model.width)
         for name in self.config.cameras.keys():
             self.detection_out_events[name] = mp.Event()
@@ -202,8 +200,7 @@ class FrigateApp:
                 name,
                 self.detection_queue,
                 self.detection_out_events,
-                model_path,
-                model_shape,
+                self.config.model,
                 detector.type,
                 detector.device,
                 detector.num_threads,
@@ -238,7 +235,6 @@ class FrigateApp:
         logger.info(f"Output process started: {output_processor.pid}")

     def start_camera_processors(self) -> None:
-        model_shape = (self.config.model.height, self.config.model.width)
         for name, config in self.config.cameras.items():
             camera_process = mp.Process(
                 target=track_camera,
@@ -246,7 +242,7 @@ class FrigateApp:
                 args=(
                     name,
                     config,
-                    model_shape,
+                    self.config.model,
                     self.config.model.merged_labelmap,
                     self.detection_queue,
                     self.detection_out_events[name],


@@ -687,6 +687,12 @@ class DatabaseConfig(FrigateBaseModel):
     )


+class PixelFormatEnum(str, Enum):
+    rgb = "rgb"
+    bgr = "bgr"
+    yuv = "yuv"
+
+
 class ModelConfig(FrigateBaseModel):
     path: Optional[str] = Field(title="Custom Object detection model path.")
     labelmap_path: Optional[str] = Field(title="Label map for custom object detector.")
@@ -695,6 +701,12 @@ class ModelConfig(FrigateBaseModel):
     labelmap: Dict[int, str] = Field(
         default_factory=dict, title="Labelmap customization."
     )
+    input_tensor: List[str] = Field(
+        default=["B", "H", "W", "C"], title="Model Input Tensor Shape"
+    )
+    input_pixel_format: PixelFormatEnum = Field(
+        default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
+    )
     _merged_labelmap: Optional[Dict[int, str]] = PrivateAttr()
     _colormap: Dict[int, Tuple[int, int, int]] = PrivateAttr()
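A minimal sketch of how the new fields might be set, assuming the defaults shown above and pydantic's usual enum coercion; the values are illustrative, not from this commit:

from frigate.config import ModelConfig, PixelFormatEnum

# Defaults match the common TFLite case: BHWC tensor, RGB pixels.
default_cfg = ModelConfig()
assert default_cfg.input_tensor == ["B", "H", "W", "C"]
assert default_cfg.input_pixel_format == PixelFormatEnum.rgb

# A hypothetical model that expects BGR input instead:
bgr_cfg = ModelConfig(input_pixel_format="bgr")
assert bgr_cfg.input_pixel_format == PixelFormatEnum.bgr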


@@ -8,9 +8,9 @@ logger = logging.getLogger(__name__)

 class CpuTfl(DetectionApi):
-    def __init__(self, det_device=None, model_path=None, num_threads=3):
+    def __init__(self, det_device=None, model_config=None, num_threads=3):
         self.interpreter = tflite.Interpreter(
-            model_path=model_path or "/cpu_model.tflite", num_threads=num_threads
+            model_path=model_config.path or "/cpu_model.tflite", num_threads=num_threads
         )
         self.interpreter.allocate_tensors()


@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)

 class DetectionApi(ABC):
     @abstractmethod
-    def __init__(self, det_device=None, model_path=None):
+    def __init__(self, det_device=None, model_config=None):
         pass

     @abstractmethod


@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)

 class EdgeTpuTfl(DetectionApi):
-    def __init__(self, det_device=None, model_path=None):
+    def __init__(self, det_device=None, model_config=None):
         device_config = {"device": "usb"}
         if not det_device is None:
             device_config = {"device": det_device}
@@ -21,7 +21,7 @@ class EdgeTpuTfl(DetectionApi):
             edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config)
             logger.info("TPU found")
             self.interpreter = tflite.Interpreter(
-                model_path=model_path or "/edgetpu_model.tflite",
+                model_path=model_config.path or "/edgetpu_model.tflite",
                 experimental_delegates=[edge_tpu_delegate],
             )
         except ValueError:


@@ -30,7 +30,7 @@ class LocalObjectDetector(ObjectDetector):
         self,
         det_type=DetectorTypeEnum.cpu,
         det_device=None,
-        model_path=None,
+        model_config=None,
         num_threads=3,
         labels=None,
     ):
@@ -41,12 +41,14 @@ class LocalObjectDetector(ObjectDetector):
             self.labels = load_labels(labels)

         if det_type == DetectorTypeEnum.edgetpu:
-            self.detect_api = EdgeTpuTfl(det_device=det_device, model_path=model_path)
+            self.detect_api = EdgeTpuTfl(
+                det_device=det_device, model_config=model_config
+            )
         else:
             logger.warning(
                 "CPU detectors are not recommended and should only be used for testing or for trial purposes."
             )
-            self.detect_api = CpuTfl(model_path=model_path, num_threads=num_threads)
+            self.detect_api = CpuTfl(model_config=model_config, num_threads=num_threads)

     def detect(self, tensor_input, threshold=0.4):
         detections = []
@@ -72,8 +74,7 @@ def run_detector(
     out_events: dict[str, mp.Event],
     avg_speed,
     start,
-    model_path,
-    model_shape,
+    model_config,
     det_type,
     det_device,
     num_threads,
@@ -96,7 +97,7 @@ def run_detector(
     object_detector = LocalObjectDetector(
         det_type=det_type,
         det_device=det_device,
-        model_path=model_path,
+        model_config=model_config,
         num_threads=num_threads,
     )
@@ -112,7 +113,7 @@ def run_detector(
         except queue.Empty:
             continue
         input_frame = frame_manager.get(
-            connection_id, (1, model_shape[0], model_shape[1], 3)
+            connection_id, (1, model_config.height, model_config.width, 3)
         )

         if input_frame is None:
@@ -135,8 +136,7 @@ class ObjectDetectProcess:
         name,
         detection_queue,
         out_events,
-        model_path,
-        model_shape,
+        model_config,
         det_type=None,
         det_device=None,
         num_threads=3,
@@ -147,8 +147,7 @@ class ObjectDetectProcess:
         self.avg_inference_speed = mp.Value("d", 0.01)
         self.detection_start = mp.Value("d", 0.0)
         self.detect_process = None
-        self.model_path = model_path
-        self.model_shape = model_shape
+        self.model_config = model_config
         self.det_type = det_type
         self.det_device = det_device
         self.num_threads = num_threads
@@ -176,8 +175,7 @@ class ObjectDetectProcess:
             self.out_events,
             self.avg_inference_speed,
             self.detection_start,
-            self.model_path,
-            self.model_shape,
+            self.model_config,
             self.det_type,
             self.det_device,
             self.num_threads,
@@ -188,7 +186,7 @@ class ObjectDetectProcess:

 class RemoteObjectDetector:
-    def __init__(self, name, labels, detection_queue, event, model_shape):
+    def __init__(self, name, labels, detection_queue, event, model_config):
         self.labels = labels
         self.name = name
         self.fps = EventsPerSecond()
@@ -196,7 +194,9 @@ class RemoteObjectDetector:
         self.event = event
         self.shm = mp.shared_memory.SharedMemory(name=self.name, create=False)
         self.np_shm = np.ndarray(
-            (1, model_shape[0], model_shape[1], 3), dtype=np.uint8, buffer=self.shm.buf
+            (1, model_config.height, model_config.width, 3),
+            dtype=np.uint8,
+            buffer=self.shm.buf,
         )
         self.out_shm = mp.shared_memory.SharedMemory(
             name=f"out-{self.name}", create=False


@@ -2,7 +2,7 @@ import unittest
 from unittest.mock import patch

 import numpy as np

-from frigate.config import DetectorTypeEnum
+from frigate.config import DetectorTypeEnum, ModelConfig
 import frigate.object_detection
@@ -12,30 +12,33 @@ class TestLocalObjectDetector(unittest.TestCase):
     def test_localdetectorprocess_given_type_cpu_should_call_cputfl_init(
         self, mock_cputfl, mock_edgetputfl
     ):
+        test_cfg = ModelConfig()
+        test_cfg.path = "/test/modelpath"
         test_obj = frigate.object_detection.LocalObjectDetector(
-            det_type=DetectorTypeEnum.cpu, model_path="/test/modelpath", num_threads=6
+            det_type=DetectorTypeEnum.cpu, model_config=test_cfg, num_threads=6
         )

         assert test_obj is not None
         mock_edgetputfl.assert_not_called()
-        mock_cputfl.assert_called_once_with(model_path="/test/modelpath", num_threads=6)
+        mock_cputfl.assert_called_once_with(model_config=test_cfg, num_threads=6)

     @patch("frigate.object_detection.EdgeTpuTfl")
     @patch("frigate.object_detection.CpuTfl")
     def test_localdetectorprocess_given_type_edgtpu_should_call_edgtpu_init(
         self, mock_cputfl, mock_edgetputfl
     ):
+        test_cfg = ModelConfig()
+        test_cfg.path = "/test/modelpath"
         test_obj = frigate.object_detection.LocalObjectDetector(
             det_type=DetectorTypeEnum.edgetpu,
             det_device="usb",
-            model_path="/test/modelpath",
+            model_config=test_cfg,
         )

         assert test_obj is not None
         mock_cputfl.assert_not_called()
-        mock_edgetputfl.assert_called_once_with(
-            det_device="usb", model_path="/test/modelpath"
-        )
+        mock_edgetputfl.assert_called_once_with(det_device="usb", model_config=test_cfg)

     @patch("frigate.object_detection.CpuTfl")
     def test_detect_raw_given_tensor_input_should_return_api_detect_raw_result(


@@ -479,6 +479,16 @@ def yuv_region_2_rgb(frame, region):
         raise


+def yuv_region_2_bgr(frame, region):
+    try:
+        yuv_cropped_frame = yuv_crop_and_resize(frame, region)
+        return cv2.cvtColor(yuv_cropped_frame, cv2.COLOR_YUV2BGR_I420)
+    except:
+        print(f"frame.shape: {frame.shape}")
+        print(f"region: {region}")
+        raise
+
+
 def intersection(box_a, box_b):
     return (
         max(box_a[0], box_b[0]),
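yuv_region_2_bgr mirrors yuv_region_2_rgb with a different cvtColor target, so for the same region the two outputs carry the same pixels with the channel order reversed. A quick illustrative check on a synthetic I420 frame, not part of the commit:

import cv2
import numpy as np

# Synthetic I420 frame: height * 3/2 rows of luma+chroma for a 320x320 image.
yuv = np.random.randint(0, 255, (480, 320), dtype=np.uint8)

rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_I420)
bgr = cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR_I420)

# Same pixels, opposite channel order.
assert np.array_equal(rgb, bgr[:, :, ::-1])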


@@ -14,7 +14,7 @@ import numpy as np
 import cv2
 from setproctitle import setproctitle

-from frigate.config import CameraConfig, DetectConfig
+from frigate.config import CameraConfig, DetectConfig, PixelFormatEnum
 from frigate.object_detection import RemoteObjectDetector
 from frigate.log import LogPipe
 from frigate.motion import MotionDetector
@@ -29,7 +29,9 @@ from frigate.util import (
     intersection,
     intersection_over_union,
     listen,
+    yuv_crop_and_resize,
     yuv_region_2_rgb,
+    yuv_region_2_bgr,
 )

 logger = logging.getLogger(__name__)
@@ -89,13 +91,20 @@ def filtered(obj, objects_to_track, object_filters):
     return False


-def create_tensor_input(frame, model_shape, region):
-    cropped_frame = yuv_region_2_rgb(frame, region)
+def create_tensor_input(frame, model_config, region):
+    if model_config.input_pixel_format == PixelFormatEnum.rgb:
+        cropped_frame = yuv_region_2_rgb(frame, region)
+    elif model_config.input_pixel_format == PixelFormatEnum.bgr:
+        cropped_frame = yuv_region_2_bgr(frame, region)
+    else:
+        cropped_frame = yuv_crop_and_resize(frame, region)

-    # Resize to 300x300 if needed
-    if cropped_frame.shape != (model_shape[0], model_shape[1], 3):
+    # Resize if needed
+    if cropped_frame.shape != (model_config.height, model_config.width, 3):
         cropped_frame = cv2.resize(
-            cropped_frame, dsize=model_shape, interpolation=cv2.INTER_LINEAR
+            cropped_frame,
+            dsize=(model_config.height, model_config.width),
+            interpolation=cv2.INTER_LINEAR,
         )

     # Expand dimensions since the model expects images to have shape: [1, height, width, 3]
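With the pixel format driven by config, the same call site now yields RGB, BGR, or raw YUV input tensors. A hedged usage sketch, assuming the default 320x320 ModelConfig; the frame and region values are placeholders:

import numpy as np
from frigate.config import ModelConfig
from frigate.video import create_tensor_input

frame = np.zeros((480, 320), dtype=np.uint8)  # I420 layout: (height * 3/2, width)
region = (0, 0, 320, 320)  # (x_min, y_min, x_max, y_max)

cfg = ModelConfig(input_pixel_format="bgr")
tensor = create_tensor_input(frame, cfg, region)
assert tensor.shape == (1, 320, 320, 3)  # BHWC, as handed to the detector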
@@ -340,7 +349,7 @@ def capture_camera(name, config: CameraConfig, process_info):
 def track_camera(
     name,
     config: CameraConfig,
-    model_shape,
+    model_config,
     labelmap,
     detection_queue,
     result_connection,
@@ -378,7 +387,7 @@ def track_camera(
         motion_contour_area,
     )
     object_detector = RemoteObjectDetector(
-        name, labelmap, detection_queue, result_connection, model_shape
+        name, labelmap, detection_queue, result_connection, model_config
     )

     object_tracker = ObjectTracker(config.detect)
@@ -389,7 +398,7 @@ def track_camera(
         name,
         frame_queue,
         frame_shape,
-        model_shape,
+        model_config,
         config.detect,
         frame_manager,
         motion_detector,
@@ -443,12 +452,12 @@ def detect(
     detect_config: DetectConfig,
     object_detector,
     frame,
-    model_shape,
+    model_config,
     region,
     objects_to_track,
     object_filters,
 ):
-    tensor_input = create_tensor_input(frame, model_shape, region)
+    tensor_input = create_tensor_input(frame, model_config, region)

     detections = []
     region_detections = object_detector.detect(tensor_input)
@@ -487,7 +496,7 @@ def process_frames(
     camera_name: str,
     frame_queue: mp.Queue,
     frame_shape,
-    model_shape,
+    model_config,
     detect_config: DetectConfig,
     frame_manager: FrameManager,
     motion_detector: MotionDetector,
@@ -571,7 +580,7 @@ def process_frames(
     # combine motion boxes with known locations of existing objects
     combined_boxes = reduce_boxes(motion_boxes + tracked_object_boxes)

-    region_min_size = max(model_shape[0], model_shape[1])
+    region_min_size = max(model_config.height, model_config.width)
     # compute regions
     regions = [
         calculate_region(
@@ -634,7 +643,7 @@ def process_frames(
             detect_config,
             object_detector,
             frame,
-            model_shape,
+            model_config,
             region,
             objects_to_track,
             object_filters,
@@ -694,7 +703,7 @@ def process_frames(
             detect_config,
             object_detector,
             frame,
-            model_shape,
+            model_config,
             region,
             objects_to_track,
             object_filters,


@@ -117,13 +117,12 @@ class ProcessClip:
         detection_enabled = mp.Value("d", 1)
         motion_enabled = mp.Value("d", True)
         stop_event = mp.Event()
-        model_shape = (self.config.model.height, self.config.model.width)

         process_frames(
             self.camera_name,
             self.frame_queue,
             self.frame_shape,
-            model_shape,
+            self.config.model,
             self.camera_config.detect,
             self.frame_manager,
             motion_detector,