Add configuration for model inputs
Support transforming detection regions to RGB or BGR, and support specifying the model's input tensor shape. The tensor is handed to the detector in the standard ["B", "H", "W", "C"] (BHWC) layout, but a detector can transpose it to match its model's expected shape using the model's input_tensor config.
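For illustration only (not part of this commit), the two new fields might be set on a model config like this; everything here besides the field names and defaults visible in the diff below is a hypothetical example:

from frigate.config import ModelConfig, PixelFormatEnum

# Hypothetical example values; only the input_tensor and
# input_pixel_format fields are introduced by this commit.
model_config = ModelConfig(
    path="/config/model.tflite",
    width=320,
    height=320,
    input_tensor=["B", "H", "W", "C"],       # layout handed to the detector
    input_pixel_format=PixelFormatEnum.bgr,  # detection regions converted to BGR
)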
commit 1f8a8ffd3d (parent d702c3d690)
@@ -173,8 +173,6 @@ class FrigateApp:
         self.mqtt_relay.start()

     def start_detectors(self) -> None:
-        model_path = self.config.model.path
-        model_shape = (self.config.model.height, self.config.model.width)
         for name in self.config.cameras.keys():
             self.detection_out_events[name] = mp.Event()

@@ -202,8 +200,7 @@ class FrigateApp:
                 name,
                 self.detection_queue,
                 self.detection_out_events,
-                model_path,
-                model_shape,
+                self.config.model,
                 detector.type,
                 detector.device,
                 detector.num_threads,
@@ -238,7 +235,6 @@ class FrigateApp:
         logger.info(f"Output process started: {output_processor.pid}")

     def start_camera_processors(self) -> None:
-        model_shape = (self.config.model.height, self.config.model.width)
         for name, config in self.config.cameras.items():
             camera_process = mp.Process(
                 target=track_camera,
@@ -246,7 +242,7 @@ class FrigateApp:
                 args=(
                     name,
                     config,
-                    model_shape,
+                    self.config.model,
                     self.config.model.merged_labelmap,
                     self.detection_queue,
                     self.detection_out_events[name],
@@ -687,6 +687,12 @@ class DatabaseConfig(FrigateBaseModel):
     )


+class PixelFormatEnum(str, Enum):
+    rgb = "rgb"
+    bgr = "bgr"
+    yuv = "yuv"
+
+
 class ModelConfig(FrigateBaseModel):
     path: Optional[str] = Field(title="Custom Object detection model path.")
     labelmap_path: Optional[str] = Field(title="Label map for custom object detector.")
@@ -695,6 +701,12 @@ class ModelConfig(FrigateBaseModel):
     labelmap: Dict[int, str] = Field(
         default_factory=dict, title="Labelmap customization."
     )
+    input_tensor: List[str] = Field(
+        default=["B", "H", "W", "C"], title="Model Input Tensor Shape"
+    )
+    input_pixel_format: PixelFormatEnum = Field(
+        default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
+    )
     _merged_labelmap: Optional[Dict[int, str]] = PrivateAttr()
     _colormap: Dict[int, Tuple[int, int, int]] = PrivateAttr()

@@ -8,9 +8,9 @@ logger = logging.getLogger(__name__)


 class CpuTfl(DetectionApi):
-    def __init__(self, det_device=None, model_path=None, num_threads=3):
+    def __init__(self, det_device=None, model_config=None, num_threads=3):
         self.interpreter = tflite.Interpreter(
-            model_path=model_path or "/cpu_model.tflite", num_threads=num_threads
+            model_path=model_config.path or "/cpu_model.tflite", num_threads=num_threads
         )

         self.interpreter.allocate_tensors()
@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)

 class DetectionApi(ABC):
     @abstractmethod
-    def __init__(self, det_device=None, model_path=None):
+    def __init__(self, det_device=None, model_config=None):
         pass

     @abstractmethod
@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)


 class EdgeTpuTfl(DetectionApi):
-    def __init__(self, det_device=None, model_path=None):
+    def __init__(self, det_device=None, model_config=None):
         device_config = {"device": "usb"}
         if not det_device is None:
             device_config = {"device": det_device}
@@ -21,7 +21,7 @@ class EdgeTpuTfl(DetectionApi):
             edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config)
             logger.info("TPU found")
             self.interpreter = tflite.Interpreter(
-                model_path=model_path or "/edgetpu_model.tflite",
+                model_path=model_config.path or "/edgetpu_model.tflite",
                 experimental_delegates=[edge_tpu_delegate],
             )
         except ValueError:
@@ -30,7 +30,7 @@ class LocalObjectDetector(ObjectDetector):
         self,
         det_type=DetectorTypeEnum.cpu,
         det_device=None,
-        model_path=None,
+        model_config=None,
         num_threads=3,
         labels=None,
     ):
@@ -41,12 +41,14 @@ class LocalObjectDetector(ObjectDetector):
             self.labels = load_labels(labels)

         if det_type == DetectorTypeEnum.edgetpu:
-            self.detect_api = EdgeTpuTfl(det_device=det_device, model_path=model_path)
+            self.detect_api = EdgeTpuTfl(
+                det_device=det_device, model_config=model_config
+            )
         else:
             logger.warning(
                 "CPU detectors are not recommended and should only be used for testing or for trial purposes."
             )
-            self.detect_api = CpuTfl(model_path=model_path, num_threads=num_threads)
+            self.detect_api = CpuTfl(model_config=model_config, num_threads=num_threads)

     def detect(self, tensor_input, threshold=0.4):
         detections = []
@@ -72,8 +74,7 @@ def run_detector(
     out_events: dict[str, mp.Event],
     avg_speed,
     start,
-    model_path,
-    model_shape,
+    model_config,
     det_type,
     det_device,
     num_threads,
@@ -96,7 +97,7 @@ def run_detector(
     object_detector = LocalObjectDetector(
         det_type=det_type,
         det_device=det_device,
-        model_path=model_path,
+        model_config=model_config,
         num_threads=num_threads,
     )

@@ -112,7 +113,7 @@ def run_detector(
         except queue.Empty:
             continue
         input_frame = frame_manager.get(
-            connection_id, (1, model_shape[0], model_shape[1], 3)
+            connection_id, (1, model_config.height, model_config.width, 3)
         )

         if input_frame is None:
@@ -135,8 +136,7 @@ class ObjectDetectProcess:
         name,
         detection_queue,
         out_events,
-        model_path,
-        model_shape,
+        model_config,
         det_type=None,
         det_device=None,
         num_threads=3,
@@ -147,8 +147,7 @@ class ObjectDetectProcess:
         self.avg_inference_speed = mp.Value("d", 0.01)
         self.detection_start = mp.Value("d", 0.0)
         self.detect_process = None
-        self.model_path = model_path
-        self.model_shape = model_shape
+        self.model_config = model_config
         self.det_type = det_type
         self.det_device = det_device
         self.num_threads = num_threads
@@ -176,8 +175,7 @@ class ObjectDetectProcess:
                 self.out_events,
                 self.avg_inference_speed,
                 self.detection_start,
-                self.model_path,
-                self.model_shape,
+                self.model_config,
                 self.det_type,
                 self.det_device,
                 self.num_threads,
@@ -188,7 +186,7 @@ class ObjectDetectProcess:


 class RemoteObjectDetector:
-    def __init__(self, name, labels, detection_queue, event, model_shape):
+    def __init__(self, name, labels, detection_queue, event, model_config):
         self.labels = labels
         self.name = name
         self.fps = EventsPerSecond()
@@ -196,7 +194,9 @@ class RemoteObjectDetector:
         self.event = event
         self.shm = mp.shared_memory.SharedMemory(name=self.name, create=False)
         self.np_shm = np.ndarray(
-            (1, model_shape[0], model_shape[1], 3), dtype=np.uint8, buffer=self.shm.buf
+            (1, model_config.height, model_config.width, 3),
+            dtype=np.uint8,
+            buffer=self.shm.buf,
         )
         self.out_shm = mp.shared_memory.SharedMemory(
             name=f"out-{self.name}", create=False
@@ -2,7 +2,7 @@ import unittest
 from unittest.mock import patch

 import numpy as np
-from frigate.config import DetectorTypeEnum
+from frigate.config import DetectorTypeEnum, ModelConfig
 import frigate.object_detection


@@ -12,30 +12,33 @@ class TestLocalObjectDetector(unittest.TestCase):
     def test_localdetectorprocess_given_type_cpu_should_call_cputfl_init(
         self, mock_cputfl, mock_edgetputfl
     ):
+        test_cfg = ModelConfig()
+        test_cfg.path = "/test/modelpath"
         test_obj = frigate.object_detection.LocalObjectDetector(
-            det_type=DetectorTypeEnum.cpu, model_path="/test/modelpath", num_threads=6
+            det_type=DetectorTypeEnum.cpu, model_config=test_cfg, num_threads=6
         )

         assert test_obj is not None
         mock_edgetputfl.assert_not_called()
-        mock_cputfl.assert_called_once_with(model_path="/test/modelpath", num_threads=6)
+        mock_cputfl.assert_called_once_with(model_config=test_cfg, num_threads=6)

     @patch("frigate.object_detection.EdgeTpuTfl")
     @patch("frigate.object_detection.CpuTfl")
     def test_localdetectorprocess_given_type_edgtpu_should_call_edgtpu_init(
         self, mock_cputfl, mock_edgetputfl
     ):
+        test_cfg = ModelConfig()
+        test_cfg.path = "/test/modelpath"
+
         test_obj = frigate.object_detection.LocalObjectDetector(
             det_type=DetectorTypeEnum.edgetpu,
             det_device="usb",
-            model_path="/test/modelpath",
+            model_config=test_cfg,
         )

         assert test_obj is not None
         mock_cputfl.assert_not_called()
-        mock_edgetputfl.assert_called_once_with(
-            det_device="usb", model_path="/test/modelpath"
-        )
+        mock_edgetputfl.assert_called_once_with(det_device="usb", model_config=test_cfg)

     @patch("frigate.object_detection.CpuTfl")
     def test_detect_raw_given_tensor_input_should_return_api_detect_raw_result(
@@ -479,6 +479,16 @@ def yuv_region_2_rgb(frame, region):
         raise


+def yuv_region_2_bgr(frame, region):
+    try:
+        yuv_cropped_frame = yuv_crop_and_resize(frame, region)
+        return cv2.cvtColor(yuv_cropped_frame, cv2.COLOR_YUV2BGR_I420)
+    except:
+        print(f"frame.shape: {frame.shape}")
+        print(f"region: {region}")
+        raise
+
+
 def intersection(box_a, box_b):
     return (
         max(box_a[0], box_b[0]),
@@ -14,7 +14,7 @@ import numpy as np
 import cv2
 from setproctitle import setproctitle

-from frigate.config import CameraConfig, DetectConfig
+from frigate.config import CameraConfig, DetectConfig, PixelFormatEnum
 from frigate.object_detection import RemoteObjectDetector
 from frigate.log import LogPipe
 from frigate.motion import MotionDetector
@@ -29,7 +29,9 @@ from frigate.util import (
     intersection,
     intersection_over_union,
     listen,
+    yuv_crop_and_resize,
     yuv_region_2_rgb,
+    yuv_region_2_bgr,
 )

 logger = logging.getLogger(__name__)
@@ -89,13 +91,20 @@ def filtered(obj, objects_to_track, object_filters):
     return False


-def create_tensor_input(frame, model_shape, region):
-    cropped_frame = yuv_region_2_rgb(frame, region)
+def create_tensor_input(frame, model_config, region):
+    if model_config.input_pixel_format == PixelFormatEnum.rgb:
+        cropped_frame = yuv_region_2_rgb(frame, region)
+    elif model_config.input_pixel_format == PixelFormatEnum.bgr:
+        cropped_frame = yuv_region_2_bgr(frame, region)
+    else:
+        cropped_frame = yuv_crop_and_resize(frame, region)

-    # Resize to 300x300 if needed
-    if cropped_frame.shape != (model_shape[0], model_shape[1], 3):
+    # Resize if needed
+    if cropped_frame.shape != (model_config.height, model_config.width, 3):
         cropped_frame = cv2.resize(
-            cropped_frame, dsize=model_shape, interpolation=cv2.INTER_LINEAR
+            cropped_frame,
+            dsize=(model_config.height, model_config.width),
+            interpolation=cv2.INTER_LINEAR,
         )

     # Expand dimensions since the model expects images to have shape: [1, height, width, 3]
@@ -340,7 +349,7 @@ def capture_camera(name, config: CameraConfig, process_info):
 def track_camera(
     name,
     config: CameraConfig,
-    model_shape,
+    model_config,
     labelmap,
     detection_queue,
     result_connection,
@@ -378,7 +387,7 @@ def track_camera(
         motion_contour_area,
     )
     object_detector = RemoteObjectDetector(
-        name, labelmap, detection_queue, result_connection, model_shape
+        name, labelmap, detection_queue, result_connection, model_config
     )

     object_tracker = ObjectTracker(config.detect)
@@ -389,7 +398,7 @@ def track_camera(
         name,
         frame_queue,
         frame_shape,
-        model_shape,
+        model_config,
         config.detect,
         frame_manager,
         motion_detector,
@@ -443,12 +452,12 @@ def detect(
     detect_config: DetectConfig,
     object_detector,
     frame,
-    model_shape,
+    model_config,
     region,
     objects_to_track,
     object_filters,
 ):
-    tensor_input = create_tensor_input(frame, model_shape, region)
+    tensor_input = create_tensor_input(frame, model_config, region)

     detections = []
     region_detections = object_detector.detect(tensor_input)
@@ -487,7 +496,7 @@ def process_frames(
     camera_name: str,
     frame_queue: mp.Queue,
     frame_shape,
-    model_shape,
+    model_config,
     detect_config: DetectConfig,
     frame_manager: FrameManager,
     motion_detector: MotionDetector,
@@ -571,7 +580,7 @@ def process_frames(
         # combine motion boxes with known locations of existing objects
         combined_boxes = reduce_boxes(motion_boxes + tracked_object_boxes)

-        region_min_size = max(model_shape[0], model_shape[1])
+        region_min_size = max(model_config.height, model_config.width)
         # compute regions
         regions = [
             calculate_region(
@@ -634,7 +643,7 @@ def process_frames(
                     detect_config,
                     object_detector,
                     frame,
-                    model_shape,
+                    model_config,
                     region,
                     objects_to_track,
                     object_filters,
@@ -694,7 +703,7 @@ def process_frames(
                 detect_config,
                 object_detector,
                 frame,
-                model_shape,
+                model_config,
                 region,
                 objects_to_track,
                 object_filters,
@@ -117,13 +117,12 @@ class ProcessClip:
         detection_enabled = mp.Value("d", 1)
         motion_enabled = mp.Value("d", True)
         stop_event = mp.Event()
-        model_shape = (self.config.model.height, self.config.model.width)

         process_frames(
             self.camera_name,
             self.frame_queue,
             self.frame_shape,
-            model_shape,
+            self.config.model,
             self.camera_config.detect,
             self.frame_manager,
             motion_detector,
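Note that the detector-side transpose described in the commit message is not among the hunks above. As a minimal sketch of that idea (reordering the standard BHWC tensor into a model's expected layout), with the helper name and the target order being hypothetical:

import numpy as np

def transpose_to_model_order(tensor_input, input_tensor, model_order):
    # Map each axis name in the model's expected order (e.g. ["B", "C", "H", "W"])
    # to its position in the standard input layout (e.g. ["B", "H", "W", "C"]).
    axes = [input_tensor.index(dim) for dim in model_order]
    return np.transpose(tensor_input, axes)

# Example: BHWC -> BCHW for a model that expects channels first.
bhwc = np.zeros((1, 320, 320, 3), dtype=np.uint8)
bchw = transpose_to_model_order(bhwc, ["B", "H", "W", "C"], ["B", "C", "H", "W"])
assert bchw.shape == (1, 3, 320, 320)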