add: multi batch setting

mjq2020 2025-04-06 19:44:23 +01:00
parent ff9e4460b2
commit 814f2a2935
4 changed files with 121 additions and 59 deletions

View File

@@ -4,6 +4,7 @@ import multiprocessing as mp
 import os
 import secrets
 import shutil
+import threading
 from multiprocessing import Queue
 from multiprocessing.synchronize import Event as MpEvent
 from typing import Optional
@@ -101,7 +102,9 @@ class FrigateApp:
         self.processes: dict[str, int] = {}
         self.embeddings: Optional[EmbeddingsContext] = None
         self.region_grids: dict[str, list[list[dict[str, int]]]] = {}
-        self.frame_manager = SharedMemoryFrameManager()
+        self.frame_manager = SharedMemoryFrameManager(
+            frame_shape=(config.model.height, config.model.width, 3)
+        )
         self.config = config

     def ensure_dirs(self) -> None:
@@ -359,9 +362,12 @@ class FrigateApp:
         try:
             largest_frame = max(
                 [
-                    det.model.height * det.model.width * 3
-                    if det.model is not None
-                    else 320
+                    (
+                        det.model.height * det.model.width * 3 * det.model.max_batch
+                        + 8
+                        if det.model is not None
+                        else 320
+                    )
                     for det in self.config.detectors.values()
                 ]
             )
@@ -375,7 +381,9 @@ class FrigateApp:
         try:
             shm_out = UntrackedSharedMemory(
-                name=f"out-{name}", create=True, size=20 * 6 * 4
+                name=f"out-{name}",
+                create=True,
+                size=20 * 6 * 4 * self.config.model.max_batch + 8,
             )
         except FileExistsError:
             shm_out = UntrackedSharedMemory(name=f"out-{name}")
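
Note on the sizing above: each detector's result block holds 20 detections of 6 float32 values (480 bytes), now multiplied by the model's max_batch, plus 8 trailing bytes. A quick sanity check, with the 8 bytes assumed to carry result metadata such as the batch count (the layout itself is not shown in this commit, and the constant names below are illustrative):

    FLOAT32_BYTES = 4
    MAX_DETECTIONS = 20  # detection rows per batch item
    FIELDS = 6           # label, score, and 4 box coordinates

    def result_shm_size(max_batch: int) -> int:
        # per-item block, times batch size, plus assumed 8-byte metadata
        return MAX_DETECTIONS * FIELDS * FLOAT32_BYTES * max_batch + 8

    assert result_shm_size(4) == 1928  # 20 * 6 * 4 * 4 + 8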

View File

@@ -47,6 +47,7 @@ class ModelConfig(BaseModel):
     labelmap_path: Optional[str] = Field(
         None, title="Label map for custom object detector."
     )
+    max_batch: int = Field(default=4, title="Max batch size.")
     width: int = Field(default=320, title="Object detection model input width.")
     height: int = Field(default=320, title="Object detection model input height.")
     labelmap: Dict[int, str] = Field(
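
The new field defaults to 4, which means batching is on by default everywhere max_batch > 1 is used as the switch (see process_frames further down). A minimal sketch of reading it, assuming ModelConfig can be constructed with defaults like any other pydantic model:

    # Hypothetical standalone use; in Frigate the model config normally comes
    # from the parsed application config rather than direct construction.
    from frigate.detectors.detector_config import ModelConfig

    model = ModelConfig()
    assert model.max_batch == 4  # new default
    assert (model.height, model.width) == (320, 320)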

View File

@@ -18,7 +18,7 @@ from frigate.detectors.detector_config import (
     InputTensorEnum,
 )
 from frigate.util.builtin import EventsPerSecond, load_labels
-from frigate.util.image import SharedMemoryFrameManager, UntrackedSharedMemory
+from frigate.util.image import SharedMemoryFrameManager, SharedMemoryResultManager
 from frigate.util.services import listen

 logger = logging.getLogger(__name__)
@@ -111,23 +111,19 @@ def run_detector(
     signal.signal(signal.SIGINT, receiveSignal)

     frame_manager = SharedMemoryFrameManager()
+    result_manager = SharedMemoryResultManager()
     object_detector = LocalObjectDetector(detector_config=detector_config)

-    outputs = {}
     for name in out_events.keys():
-        out_shm = UntrackedSharedMemory(name=f"out-{name}", create=False)
-        out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
-        outputs[name] = {"shm": out_shm, "np": out_np}
+        result_manager.create(name=f"out-{name}")

     while not stop_event.is_set():
        try:
            connection_id = detection_queue.get(timeout=1)
        except queue.Empty:
            continue

-        input_frame = frame_manager.get(
-            connection_id,
-            (1, detector_config.model.height, detector_config.model.width, 3),
-        )
+        input_frame = frame_manager.get_frame(connection_id)

        if input_frame is None:
            logger.warning(f"Failed to get frame {connection_id} from SHM")
@@ -138,7 +134,8 @@ def run_detector(
         detections = object_detector.detect_raw(input_frame)
         duration = datetime.datetime.now().timestamp() - start.value
         frame_manager.close(connection_id)
-        outputs[connection_id]["np"][:] = detections[:]
+
+        result_manager.write_result(f"out-{connection_id}", detections)
         out_events[connection_id].set()

         start.value = 0.0
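
For reference, the detections array round-tripped through write_result and get_result keeps the established 20x6 row layout. Judging from how d[0] through d[5] are unpacked in RemoteObjectDetector below, each row is [label_id, score, y_min, x_min, y_max, x_max] with normalized box coordinates:

    # Decoding one result row (values are illustrative, not from the commit).
    import numpy as np

    row = np.array([17.0, 0.91, 0.10, 0.20, 0.60, 0.80], dtype=np.float32)
    label_id, score = int(row[0]), float(row[1])
    y_min, x_min, y_max, x_max = (float(v) for v in row[2:6])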
@@ -198,21 +195,27 @@ class ObjectDetectProcess:
 class RemoteObjectDetector:
-    def __init__(self, name, labels, detection_queue, event, model_config, stop_event):
+    def __init__(
+        self,
+        name,
+        labels,
+        detection_queue,
+        event,
+        model_config,
+        stop_event,
+        frame_manager: SharedMemoryFrameManager,
+    ):
         self.labels = labels
         self.name = name
         self.fps = EventsPerSecond()
         self.detection_queue = detection_queue
         self.event = event
         self.stop_event = stop_event
-        self.shm = UntrackedSharedMemory(name=self.name, create=False)
-        self.np_shm = np.ndarray(
-            (1, model_config.height, model_config.width, 3),
-            dtype=np.uint8,
-            buffer=self.shm.buf,
-        )
-        self.out_shm = UntrackedSharedMemory(name=f"out-{self.name}", create=False)
-        self.out_np_shm = np.ndarray((20, 6), dtype=np.float32, buffer=self.out_shm.buf)
+        self.frame_manager = frame_manager
+
+        self.result_manager = SharedMemoryResultManager(
+            max_frame=model_config.max_batch
+        )

     def detect(self, tensor_input, threshold=0.4):
         detections = []
@@ -221,7 +224,8 @@ class RemoteObjectDetector:
             return detections

         # copy input to shared memory
-        self.np_shm[:] = tensor_input[:]
+        self.frame_manager.write_frame(self.name, tensor_input)
+
         self.event.clear()
         self.detection_queue.put(self.name)
         result = self.event.wait(timeout=5.0)
@@ -230,15 +234,23 @@ class RemoteObjectDetector:
         if result is None:
             return detections

-        for d in self.out_np_shm:
-            if d[1] < threshold:
-                break
-            detections.append(
-                (self.labels[int(d[0])], float(d[1]), (d[2], d[3], d[4], d[5]))
-            )
+        batch_result_np = self.result_manager.get_result(f"out-{self.name}")
+
+        if not isinstance(batch_result_np, np.ndarray):
+            return detections
+
+        for result_np in batch_result_np:
+            tmp_detections = []
+            for d in result_np:
+                if d[1] < threshold:
+                    break
+                tmp_detections.append(
+                    (self.labels[int(d[0])], float(d[1]), (d[2], d[3], d[4], d[5]))
+                )
+            detections.append(tmp_detections)
         self.fps.update()
         return detections

     def cleanup(self):
-        self.shm.unlink()
-        self.out_shm.unlink()
+        self.result_manager.cleanup()
+        self.frame_manager.cleanup()
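
Two things worth noting here. RemoteObjectDetector.detect now returns a list of per-region detection lists rather than one flat list; the zip in video.py's detect below consumes that nested shape. And SharedMemoryResultManager itself is not among the four files changed in this commit, so its API can only be inferred from the call sites (create, write_result, get_result, cleanup, and a max_frame argument). A minimal sketch of what such a manager could look like, built on multiprocessing.shared_memory and the sizing used in FrigateApp above; every detail below is an assumption, not the committed implementation:

    import numpy as np
    from multiprocessing import shared_memory

    class ResultManagerSketch:
        # Hypothetical stand-in for frigate.util.image.SharedMemoryResultManager.

        def __init__(self, max_frame: int = 4):
            self.max_frame = max_frame
            self.shms: dict[str, shared_memory.SharedMemory] = {}

        def _attach(self, name: str, create: bool = False) -> shared_memory.SharedMemory:
            if name not in self.shms:
                size = 20 * 6 * 4 * self.max_frame + 8
                try:
                    self.shms[name] = shared_memory.SharedMemory(
                        name=name, create=create, size=size
                    )
                except FileExistsError:
                    self.shms[name] = shared_memory.SharedMemory(name=name)
            return self.shms[name]

        def create(self, name: str) -> None:
            self._attach(name, create=True)

        def write_result(self, name: str, detections: np.ndarray) -> None:
            shm = self._attach(name)
            batch = np.asarray(detections, dtype=np.float32).reshape(-1, 20, 6)
            # assumed layout: 8-byte batch count, then (batch, 20, 6) float32
            shm.buf[:8] = batch.shape[0].to_bytes(8, "little")
            out = np.ndarray(batch.shape, dtype=np.float32, buffer=shm.buf[8:])
            out[:] = batch

        def get_result(self, name: str) -> np.ndarray:
            shm = self._attach(name)
            count = int.from_bytes(shm.buf[:8], "little")
            view = np.ndarray((count, 20, 6), dtype=np.float32, buffer=shm.buf[8:])
            return view.copy()

        def cleanup(self) -> None:
            for shm in self.shms.values():
                shm.close()
                shm.unlink()
            self.shms.clear()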

View File

@@ -9,6 +9,7 @@ import threading
 import time

 import cv2
+import numpy as np
 from setproctitle import setproctitle

 from frigate.camera import CameraMetrics, PTZMetrics
@@ -503,14 +504,20 @@ def track_camera(
         name=config.name,
         ptz_metrics=ptz_metrics,
     )
+    frame_manager = SharedMemoryFrameManager(max_frame=model_config.max_batch)
     object_detector = RemoteObjectDetector(
-        name, labelmap, detection_queue, result_connection, model_config, stop_event
+        name,
+        labelmap,
+        detection_queue,
+        result_connection,
+        model_config,
+        stop_event,
+        frame_manager,
     )

     object_tracker = NorfairTracker(config, ptz_metrics)

-    frame_manager = SharedMemoryFrameManager()
-
     # create communication for region grid updates
     requestor = InterProcessRequestor()
@@ -549,35 +556,53 @@ def detect(
     object_detector,
     frame,
     model_config: ModelConfig,
-    region,
+    regions,
     objects_to_track,
     object_filters,
+    multi_batch: bool = False,
 ):
-    tensor_input = create_tensor_input(frame, model_config, region)
+    if multi_batch:
+        tensor_list = []
+        for i, region in enumerate(regions):
+            if i > model_config.max_batch:
+                logger.info(f"batch is too large, skipping")
+                break
+            tensor_list.append(create_tensor_input(frame, model_config, region))
+        tensor_input = np.concatenate(tensor_list, axis=0)
+    else:
+        tensor_input = create_tensor_input(frame, model_config, regions)
+
+    region_detections_list = object_detector.detect(tensor_input)

     detections = []
-    region_detections = object_detector.detect(tensor_input)

-    for d in region_detections:
-        box = d[2]
-        size = region[2] - region[0]
-        x_min = int(max(0, (box[1] * size) + region[0]))
-        y_min = int(max(0, (box[0] * size) + region[1]))
-        x_max = int(min(detect_config.width - 1, (box[3] * size) + region[0]))
-        y_max = int(min(detect_config.height - 1, (box[2] * size) + region[1]))
+    if not multi_batch:
+        region_detections_list = [region_detections_list]
+        regions = [regions]

-        # ignore objects that were detected outside the frame
-        if (x_min >= detect_config.width - 1) or (y_min >= detect_config.height - 1):
-            continue
+    for region_detections, region in zip(region_detections_list, regions):
+        for d in region_detections:
+            box = d[2]
+            size = region[2] - region[0]
+            x_min = int(max(0, (box[1] * size) + region[0]))
+            y_min = int(max(0, (box[0] * size) + region[1]))
+            x_max = int(min(detect_config.width - 1, (box[3] * size) + region[0]))
+            y_max = int(min(detect_config.height - 1, (box[2] * size) + region[1]))

-        width = x_max - x_min
-        height = y_max - y_min
-        area = width * height
-        ratio = width / max(1, height)
-        det = (d[0], d[1], (x_min, y_min, x_max, y_max), area, ratio, region)
-        # apply object filters
-        if is_object_filtered(det, objects_to_track, object_filters):
-            continue
-        detections.append(det)
+            # ignore objects that were detected outside the frame
+            if (x_min >= detect_config.width - 1) or (
+                y_min >= detect_config.height - 1
+            ):
+                continue
+
+            width = x_max - x_min
+            height = y_max - y_min
+            area = width * height
+            ratio = width / max(1, height)
+            det = (d[0], d[1], (x_min, y_min, x_max, y_max), area, ratio, region)
+            # apply object filters
+            if is_object_filtered(det, objects_to_track, object_filters):
+                continue
+            detections.append(det)

     return detections
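
Two details in the batched branch above are worth flagging. First, create_tensor_input yields one (1, height, width, 3) array per region, so concatenating along axis 0 produces the (batch, height, width, 3) tensor the detector now expects:

    # Shape check for the batching above (illustrative sizes).
    import numpy as np

    h = w = 320
    tensors = [np.zeros((1, h, w, 3), dtype=np.uint8) for _ in range(3)]
    batch = np.concatenate(tensors, axis=0)
    assert batch.shape == (3, h, w, 3)

Second, the guard "if i > model_config.max_batch" breaks only once the index exceeds max_batch, so up to max_batch + 1 regions can land in one batch, and any remaining regions are silently dropped for that pass rather than detected individually; ">=" looks like the intended comparison.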
@@ -618,6 +643,8 @@ def process_frames(
     attributes_map = model_config.attributes_map
     all_attributes = model_config.all_attributes

+    multi_batch = model_config.max_batch > 1
+
     # remove license_plate from attributes if this camera is a dedicated LPR cam
     if camera_config.type == CameraTypeEnum.lpr:
         modified_attributes_map = model_config.attributes_map.copy()
@@ -813,18 +840,32 @@ def process_frames(
                     if obj["id"] in stationary_object_ids
                 ]

-                for region in regions:
+                if multi_batch and len(regions) > 0:
                     detections.extend(
                         detect(
                             detect_config,
                             object_detector,
                             frame,
                             model_config,
-                            region,
+                            regions,
                             objects_to_track,
                             object_filters,
+                            multi_batch,
                         )
                     )
+                else:
+                    for region in regions:
+                        detections.extend(
+                            detect(
+                                detect_config,
+                                object_detector,
+                                frame,
+                                model_config,
+                                region,
+                                objects_to_track,
+                                object_filters,
+                            )
+                        )

                 consolidated_detections = reduce_detections(frame_shape, detections)
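
The dispatch above keeps the two paths distinct: with max_batch > 1 and at least one region, every region goes to detect() in a single batched call; otherwise the pre-existing per-region loop runs unchanged. A tiny illustration of the call planning (a hypothetical helper, not Frigate code):

    def plan_detect_calls(regions: list, max_batch: int) -> list[list]:
        # Mirrors the branch in process_frames: one batched call,
        # or one call per region when batching is disabled.
        if max_batch > 1 and len(regions) > 0:
            return [regions]
        return [[region] for region in regions]

    assert plan_detect_calls(["r1", "r2"], max_batch=4) == [["r1", "r2"]]
    assert plan_detect_calls(["r1", "r2"], max_batch=1) == [["r1"], ["r2"]]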