mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-03 09:45:22 +03:00
remote detection using zeromq
This commit is contained in:
parent
6e519e0071
commit
9d21d71282
@ -5,6 +5,7 @@ from multiprocessing.synchronize import Event as MpEvent
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import Optional
|
||||
from types import FrameType
|
||||
|
||||
@ -18,7 +19,11 @@ from frigate.comms.mqtt import MqttClient
|
||||
from frigate.comms.ws import WebSocketClient
|
||||
from frigate.config import FrigateConfig
|
||||
from frigate.const import CACHE_DIR, CLIPS_DIR, RECORD_DIR
|
||||
from frigate.object_detection import ObjectDetectProcess
|
||||
from frigate.detectors import (
|
||||
ObjectDetectProcess,
|
||||
ObjectDetectionBroker,
|
||||
DetectionServerModeEnum,
|
||||
)
|
||||
from frigate.events import EventCleanup, EventProcessor
|
||||
from frigate.http import create_app
|
||||
from frigate.log import log_process, root_configurer
|
||||
@ -41,10 +46,9 @@ logger = logging.getLogger(__name__)
|
||||
class FrigateApp:
|
||||
def __init__(self) -> None:
|
||||
self.stop_event: MpEvent = mp.Event()
|
||||
self.detection_queue: Queue = mp.Queue()
|
||||
self.detectors: dict[str, ObjectDetectProcess] = {}
|
||||
self.detection_out_events: dict[str, MpEvent] = {}
|
||||
self.detection_shms: list[mp.shared_memory.SharedMemory] = []
|
||||
self.detection_shms: dict[str, mp.shared_memory.SharedMemory] = {}
|
||||
self.log_queue: Queue = mp.Queue()
|
||||
self.plus_api = PlusApi()
|
||||
self.camera_metrics: dict[str, CameraMetricsTypes] = {}
|
||||
@ -80,6 +84,9 @@ class FrigateApp:
|
||||
user_config = FrigateConfig.parse_file(config_file)
|
||||
self.config = user_config.runtime_config
|
||||
|
||||
if self.config.server.mode == DetectionServerModeEnum.DetectionOnly:
|
||||
return
|
||||
|
||||
for camera_name in self.config.cameras.keys():
|
||||
# create camera_metrics
|
||||
self.camera_metrics[camera_name] = {
|
||||
@ -181,10 +188,15 @@ class FrigateApp:
|
||||
comms.append(self.ws_client)
|
||||
self.dispatcher = Dispatcher(self.config, self.camera_metrics, comms)
|
||||
|
||||
def start_detection_broker(self) -> None:
|
||||
bind_urls = [self.config.broker.ipc] + self.config.broker.addresses
|
||||
self.detection_broker = ObjectDetectionBroker(
|
||||
bind=bind_urls, shms=self.detection_shms
|
||||
)
|
||||
self.detection_broker.start()
|
||||
|
||||
def start_detectors(self) -> None:
|
||||
for name in self.config.cameras.keys():
|
||||
self.detection_out_events[name] = mp.Event()
|
||||
|
||||
try:
|
||||
largest_frame = max(
|
||||
[
|
||||
@ -207,14 +219,12 @@ class FrigateApp:
|
||||
except FileExistsError:
|
||||
shm_out = mp.shared_memory.SharedMemory(name=f"out-{name}")
|
||||
|
||||
self.detection_shms.append(shm_in)
|
||||
self.detection_shms.append(shm_out)
|
||||
self.detection_shms[name] = shm_in
|
||||
self.detection_shms[f"out-{name}"] = shm_out
|
||||
|
||||
for name, detector_config in self.config.detectors.items():
|
||||
self.detectors[name] = ObjectDetectProcess(
|
||||
name,
|
||||
self.detection_queue,
|
||||
self.detection_out_events,
|
||||
detector_config,
|
||||
)
|
||||
|
||||
@ -246,29 +256,31 @@ class FrigateApp:
|
||||
logger.info(f"Output process started: {output_processor.pid}")
|
||||
|
||||
def start_camera_processors(self) -> None:
|
||||
for name, config in self.config.cameras.items():
|
||||
if not self.config.cameras[name].enabled:
|
||||
logger.info(f"Camera processor not started for disabled camera {name}")
|
||||
for camera_name, config in self.config.cameras.items():
|
||||
if not self.config.cameras[camera_name].enabled:
|
||||
logger.info(
|
||||
f"Camera processor not started for disabled camera {camera_name}"
|
||||
)
|
||||
continue
|
||||
|
||||
camera_process = mp.Process(
|
||||
target=track_camera,
|
||||
name=f"camera_processor:{name}",
|
||||
name=f"camera_processor:{camera_name}",
|
||||
args=(
|
||||
name,
|
||||
camera_name,
|
||||
config,
|
||||
self.config.model,
|
||||
self.config.model.merged_labelmap,
|
||||
self.detection_queue,
|
||||
self.detection_out_events[name],
|
||||
self.detected_frames_queue,
|
||||
self.camera_metrics[name],
|
||||
self.camera_metrics[camera_name],
|
||||
),
|
||||
)
|
||||
camera_process.daemon = True
|
||||
self.camera_metrics[name]["process"] = camera_process
|
||||
self.camera_metrics[camera_name]["process"] = camera_process
|
||||
camera_process.start()
|
||||
logger.info(f"Camera processor started for {name}: {camera_process.pid}")
|
||||
logger.info(
|
||||
f"Camera processor started for {camera_name}: {camera_process.pid}"
|
||||
)
|
||||
|
||||
def start_camera_capture_processes(self) -> None:
|
||||
for name, config in self.config.cameras.items():
|
||||
@ -330,6 +342,13 @@ class FrigateApp:
|
||||
def start(self) -> None:
|
||||
self.init_logger()
|
||||
logger.info(f"Starting Frigate ({VERSION})")
|
||||
|
||||
def receiveSignal(signalNumber: int, frame: Optional[FrameType]) -> None:
|
||||
self.stop()
|
||||
sys.exit()
|
||||
|
||||
signal.signal(signal.SIGTERM, receiveSignal)
|
||||
|
||||
try:
|
||||
try:
|
||||
self.init_config()
|
||||
@ -352,6 +371,13 @@ class FrigateApp:
|
||||
sys.exit(1)
|
||||
self.set_environment_vars()
|
||||
self.ensure_dirs()
|
||||
|
||||
if self.config.server.mode == DetectionServerModeEnum.DetectionOnly:
|
||||
self.start_detectors()
|
||||
self.start_watchdog()
|
||||
self.stop_event.wait()
|
||||
sys.exit()
|
||||
|
||||
self.set_log_levels()
|
||||
self.init_queues()
|
||||
self.init_database()
|
||||
@ -360,8 +386,10 @@ class FrigateApp:
|
||||
print(e)
|
||||
self.log_process.terminate()
|
||||
sys.exit(1)
|
||||
|
||||
self.init_restream()
|
||||
self.start_detectors()
|
||||
self.start_detection_broker()
|
||||
self.start_video_output_processor()
|
||||
self.start_detected_frames_processor()
|
||||
self.start_camera_processors()
|
||||
@ -377,12 +405,6 @@ class FrigateApp:
|
||||
self.start_watchdog()
|
||||
# self.zeroconf = broadcast_zeroconf(self.config.mqtt.client_id)
|
||||
|
||||
def receiveSignal(signalNumber: int, frame: Optional[FrameType]) -> None:
|
||||
self.stop()
|
||||
sys.exit()
|
||||
|
||||
signal.signal(signal.SIGTERM, receiveSignal)
|
||||
|
||||
try:
|
||||
self.flask_app.run(host="127.0.0.1", port=5001, debug=False)
|
||||
except KeyboardInterrupt:
|
||||
@ -394,20 +416,22 @@ class FrigateApp:
|
||||
logger.info(f"Stopping...")
|
||||
self.stop_event.set()
|
||||
|
||||
self.ws_client.stop()
|
||||
self.detected_frames_processor.join()
|
||||
self.event_processor.join()
|
||||
self.event_cleanup.join()
|
||||
self.recording_maintainer.join()
|
||||
self.recording_cleanup.join()
|
||||
self.stats_emitter.join()
|
||||
self.frigate_watchdog.join()
|
||||
self.db.stop()
|
||||
if self.config.server.mode != DetectionServerModeEnum.DetectionOnly:
|
||||
self.ws_client.stop()
|
||||
self.detected_frames_processor.join()
|
||||
self.event_processor.join()
|
||||
self.event_cleanup.join()
|
||||
self.recording_maintainer.join()
|
||||
self.recording_cleanup.join()
|
||||
self.stats_emitter.join()
|
||||
self.frigate_watchdog.join()
|
||||
self.db.stop()
|
||||
self.detection_broker.stop()
|
||||
|
||||
for detector in self.detectors.values():
|
||||
detector.stop()
|
||||
|
||||
while len(self.detection_shms) > 0:
|
||||
shm = self.detection_shms.pop()
|
||||
shm.close()
|
||||
shm.unlink()
|
||||
shm = self.detection_shms.popitem()
|
||||
shm[1].close()
|
||||
shm[1].unlink()
|
||||
|
||||
@ -4,12 +4,20 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union, Annotated
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import yaml
|
||||
from pydantic import BaseModel, Extra, Field, validator, parse_obj_as
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
Field,
|
||||
validator,
|
||||
root_validator,
|
||||
ValidationError,
|
||||
parse_obj_as,
|
||||
)
|
||||
from pydantic.fields import PrivateAttr
|
||||
|
||||
from frigate.const import (
|
||||
@ -35,8 +43,10 @@ from frigate.ffmpeg_presets import (
|
||||
from frigate.detectors import (
|
||||
PixelFormatEnum,
|
||||
InputTensorEnum,
|
||||
DetectionServerConfig,
|
||||
ModelConfig,
|
||||
DetectorConfig,
|
||||
BaseDetectorConfig,
|
||||
DetectionServerModeEnum,
|
||||
)
|
||||
from frigate.version import VERSION
|
||||
|
||||
@ -325,6 +335,12 @@ class ObjectConfig(FrigateBaseModel):
|
||||
mask: Union[str, List[str]] = Field(default="", title="Object mask.")
|
||||
|
||||
|
||||
DetectorConfig = Annotated[
|
||||
Union[tuple(BaseDetectorConfig.__subclasses__())],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class BirdseyeModeEnum(str, Enum):
|
||||
objects = "objects"
|
||||
motion = "motion"
|
||||
@ -797,7 +813,10 @@ def verify_zone_objects_are_tracked(camera_config: CameraConfig) -> None:
|
||||
|
||||
|
||||
class FrigateConfig(FrigateBaseModel):
|
||||
mqtt: MqttConfig = Field(title="MQTT Configuration.")
|
||||
server: DetectionServerConfig = Field(
|
||||
default_factory=DetectionServerConfig, title="Server configuration"
|
||||
)
|
||||
mqtt: Optional[MqttConfig] = Field(title="MQTT Configuration.")
|
||||
database: DatabaseConfig = Field(
|
||||
default_factory=DatabaseConfig, title="Database configuration."
|
||||
)
|
||||
@ -842,7 +861,7 @@ class FrigateConfig(FrigateBaseModel):
|
||||
detect: DetectConfig = Field(
|
||||
default_factory=DetectConfig, title="Global object tracking configuration."
|
||||
)
|
||||
cameras: Dict[str, CameraConfig] = Field(title="Camera configuration.")
|
||||
cameras: Optional[Dict[str, CameraConfig]] = Field(title="Camera configuration.")
|
||||
timestamp_style: TimestampStyleConfig = Field(
|
||||
default_factory=TimestampStyleConfig,
|
||||
title="Global timestamp style configuration.",
|
||||
@ -853,6 +872,50 @@ class FrigateConfig(FrigateBaseModel):
|
||||
"""Merge camera config with globals."""
|
||||
config = self.copy(deep=True)
|
||||
|
||||
for key, detector in config.detectors.items():
|
||||
detector_config: BaseDetectorConfig = parse_obj_as(DetectorConfig, detector)
|
||||
|
||||
if detector_config.cameras is None:
|
||||
detector_config.cameras = (
|
||||
list(config.cameras.keys()) if config.cameras is not None else []
|
||||
)
|
||||
|
||||
if detector_config.address is None:
|
||||
detector_config.address = config.server.ipc
|
||||
|
||||
if detector_config.shared_memory is None:
|
||||
detector_config.shared_memory = (
|
||||
detector_config.address == config.server.ipc
|
||||
)
|
||||
|
||||
if detector_config.model is None:
|
||||
detector_config.model = config.model
|
||||
else:
|
||||
model = detector_config.model
|
||||
schema = ModelConfig.schema()["properties"]
|
||||
if (
|
||||
model.width != schema["width"]["default"]
|
||||
or model.height != schema["height"]["default"]
|
||||
or model.labelmap_path is not None
|
||||
or model.labelmap is not {}
|
||||
or model.input_tensor != schema["input_tensor"]["default"]
|
||||
or model.input_pixel_format
|
||||
!= schema["input_pixel_format"]["default"]
|
||||
):
|
||||
logger.warning(
|
||||
"Customizing more than a detector model path is unsupported."
|
||||
)
|
||||
merged_model = deep_merge(
|
||||
detector_config.model.dict(exclude_unset=True),
|
||||
config.model.dict(exclude_unset=True),
|
||||
)
|
||||
detector_config.model = ModelConfig.parse_obj(merged_model)
|
||||
|
||||
config.detectors[key] = detector_config
|
||||
|
||||
if config.server.mode == DetectionServerModeEnum.DetectionOnly:
|
||||
return config
|
||||
|
||||
# MQTT password substitution
|
||||
if config.mqtt.password:
|
||||
config.mqtt.password = config.mqtt.password.format(**FRIGATE_ENV_VARS)
|
||||
@ -952,32 +1015,6 @@ class FrigateConfig(FrigateBaseModel):
|
||||
camera_config.create_ffmpeg_cmds()
|
||||
config.cameras[name] = camera_config
|
||||
|
||||
for key, detector in config.detectors.items():
|
||||
detector_config: DetectorConfig = parse_obj_as(DetectorConfig, detector)
|
||||
if detector_config.model is None:
|
||||
detector_config.model = config.model
|
||||
else:
|
||||
model = detector_config.model
|
||||
schema = ModelConfig.schema()["properties"]
|
||||
if (
|
||||
model.width != schema["width"]["default"]
|
||||
or model.height != schema["height"]["default"]
|
||||
or model.labelmap_path is not None
|
||||
or model.labelmap is not {}
|
||||
or model.input_tensor != schema["input_tensor"]["default"]
|
||||
or model.input_pixel_format
|
||||
!= schema["input_pixel_format"]["default"]
|
||||
):
|
||||
logger.warning(
|
||||
"Customizing more than a detector model path is unsupported."
|
||||
)
|
||||
merged_model = deep_merge(
|
||||
detector_config.model.dict(exclude_unset=True),
|
||||
config.model.dict(exclude_unset=True),
|
||||
)
|
||||
detector_config.model = ModelConfig.parse_obj(merged_model)
|
||||
config.detectors[key] = detector_config
|
||||
|
||||
return config
|
||||
|
||||
@validator("cameras")
|
||||
@ -988,6 +1025,33 @@ class FrigateConfig(FrigateBaseModel):
|
||||
raise ValueError("Zones cannot share names with cameras")
|
||||
return v
|
||||
|
||||
@root_validator(pre=True)
|
||||
def ensure_cameras_mqtt_defined(cls, values):
|
||||
server_config = values.get("server", None)
|
||||
if (
|
||||
server_config is not None
|
||||
and server_config.get("mode", DetectionServerModeEnum.Full)
|
||||
== DetectionServerModeEnum.DetectionOnly
|
||||
):
|
||||
return values
|
||||
|
||||
if values.get("cameras", None) is None:
|
||||
raise ValueError("cameras: field required")
|
||||
if values.get("mqtt", None) is None:
|
||||
raise ValueError("mqtt: field required")
|
||||
return values
|
||||
|
||||
@validator("detectors")
|
||||
def ensure_detectors_have_cameras(cls, v: Dict[str, BaseDetectorConfig], values):
|
||||
for detector in v.values():
|
||||
if values.get("cameras", None) is None and (
|
||||
detector.cameras is None or len(detector.cameras) == 0
|
||||
):
|
||||
raise ValueError(
|
||||
"Detectors must specify at least one camera name to process"
|
||||
)
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def parse_file(cls, config_file):
|
||||
with open(config_file) as f:
|
||||
|
||||
@ -1,24 +1,13 @@
|
||||
import logging
|
||||
|
||||
from .detection_api import DetectionApi
|
||||
from .detection_broker import ObjectDetectionBroker
|
||||
from .detector_config import (
|
||||
PixelFormatEnum,
|
||||
InputTensorEnum,
|
||||
ModelConfig,
|
||||
BaseDetectorConfig,
|
||||
DetectionServerConfig,
|
||||
DetectionServerModeEnum,
|
||||
)
|
||||
from .detector_types import DetectorTypeEnum, api_types, DetectorConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_detector(detector_config):
|
||||
if detector_config.type == DetectorTypeEnum.cpu:
|
||||
logger.warning(
|
||||
"CPU detectors are not recommended and should only be used for testing or for trial purposes."
|
||||
)
|
||||
|
||||
api = api_types.get(detector_config.type)
|
||||
if not api:
|
||||
raise ValueError(detector_config.type)
|
||||
return api(detector_config)
|
||||
from .detection_client import ObjectDetectionClient
|
||||
from .detector_types import DetectorTypeEnum, api_types, create_detector
|
||||
from .detection_worker import ObjectDetectionWorker, ObjectDetectProcess
|
||||
|
||||
89
frigate/detectors/detection_broker.py
Normal file
89
frigate/detectors/detection_broker.py
Normal file
@ -0,0 +1,89 @@
|
||||
import signal
|
||||
import threading
|
||||
import zmq
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from typing import Union, List
|
||||
from majortomo import error, protocol
|
||||
from majortomo.config import DEFAULT_BIND_URL
|
||||
from majortomo.broker import Broker
|
||||
|
||||
|
||||
READY_SHM = b"\007"
|
||||
|
||||
|
||||
class ObjectDetectionBroker(Broker):
|
||||
def __init__(
|
||||
self,
|
||||
bind: Union[str, List[str]] = DEFAULT_BIND_URL,
|
||||
shms: dict[str, SharedMemory] = {},
|
||||
heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
|
||||
heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
|
||||
busy_worker_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT,
|
||||
zmq_context=None,
|
||||
):
|
||||
protocol.Message.ALLOWED_COMMANDS[protocol.WORKER_HEADER].add(READY_SHM)
|
||||
|
||||
super().__init__(
|
||||
self,
|
||||
heartbeat_interval=heartbeat_interval,
|
||||
heartbeat_timeout=heartbeat_timeout,
|
||||
busy_worker_timeout=busy_worker_timeout,
|
||||
zmq_context=zmq_context,
|
||||
)
|
||||
self.shm_workers = set()
|
||||
self.shms = shms
|
||||
self._bind_urls = [bind] if not isinstance(bind, list) else bind
|
||||
self.broker_thread: threading.Thread = None
|
||||
|
||||
def bind(self):
|
||||
"""Bind the ZMQ socket"""
|
||||
if self._socket:
|
||||
raise error.StateError("Socket is already bound")
|
||||
|
||||
self._socket = self._context.socket(zmq.ROUTER)
|
||||
self._socket.rcvtimeo = int(self._heartbeat_interval * 1000)
|
||||
for bind_url in self._bind_urls:
|
||||
self._socket.bind(bind_url)
|
||||
self._log.info("Broker listening on %s", bind_url)
|
||||
|
||||
def close(self):
|
||||
if self._socket is None:
|
||||
return
|
||||
for bind_url in self._bind_urls:
|
||||
self._socket.disconnect(bind_url)
|
||||
self._socket.close()
|
||||
self._socket = None
|
||||
self._log.info("Broker socket closing")
|
||||
|
||||
def _handle_worker_message(self, message):
|
||||
if message.command == READY_SHM:
|
||||
self.shm_workers.add(message.client)
|
||||
self._handle_worker_ready(message.client, message.message[0])
|
||||
else:
|
||||
super()._handle_worker_message(message)
|
||||
|
||||
def _purge_expired_workers(self):
|
||||
self.shm_workers.intersection_update(self._services._workers.keys())
|
||||
super()._purge_expired_workers()
|
||||
|
||||
def _send_to_worker(self, worker_id, command, body=None):
|
||||
if (
|
||||
worker_id not in self.shm_workers
|
||||
and command == protocol.REQUEST
|
||||
and body is not None
|
||||
):
|
||||
service_name = body[2]
|
||||
in_shm = self.shms[str(service_name, "ascii")]
|
||||
tensor_input = in_shm.buf
|
||||
body = body[0:2] + [tensor_input]
|
||||
super()._send_to_worker(worker_id, command, body)
|
||||
|
||||
def start(self):
|
||||
self.broker_thread = threading.Thread(target=self.run)
|
||||
self.broker_thread.start()
|
||||
|
||||
def stop(self):
|
||||
super().stop()
|
||||
if self.broker_thread is not None:
|
||||
self.broker_thread.join()
|
||||
self.broker_thread = None
|
||||
67
frigate/detectors/detection_client.py
Normal file
67
frigate/detectors/detection_client.py
Normal file
@ -0,0 +1,67 @@
|
||||
import numpy as np
|
||||
import multiprocessing as mp
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from majortomo import Client
|
||||
from frigate.util import EventsPerSecond
|
||||
from .detector_config import ModelConfig, DetectionServerConfig
|
||||
|
||||
|
||||
class ObjectDetectionClient:
|
||||
def __init__(
|
||||
self,
|
||||
camera_name: str,
|
||||
labels,
|
||||
model_config: ModelConfig,
|
||||
server_config: DetectionServerConfig,
|
||||
timeout=None,
|
||||
):
|
||||
self.camera_name = camera_name
|
||||
self.labels = labels
|
||||
self.model_config = model_config
|
||||
self.fps = EventsPerSecond()
|
||||
self.in_shm = SharedMemory(name=self.camera_name, create=False)
|
||||
self.in_np_shm = np.ndarray(
|
||||
(1, model_config.height, model_config.width, 3),
|
||||
dtype=np.uint8,
|
||||
buffer=self.in_shm.buf,
|
||||
)
|
||||
self.out_shm = SharedMemory(name=f"out-{self.camera_name}", create=False)
|
||||
self.out_np_shm = np.ndarray((20, 6), dtype=np.float32, buffer=self.out_shm.buf)
|
||||
|
||||
self.timeout = timeout
|
||||
self.detection_client = Client(server_config.ipc)
|
||||
self.detection_client.connect()
|
||||
|
||||
def detect(self, tensor_input, threshold=0.4):
|
||||
detections = []
|
||||
# copy input to shared memory
|
||||
self.in_np_shm[:] = tensor_input[:]
|
||||
|
||||
try:
|
||||
self.detection_client.send(
|
||||
f"{self.camera_name}".encode("ascii"),
|
||||
f"{self.camera_name}".encode("ascii"),
|
||||
self.model_config.height.to_bytes(4, "little"),
|
||||
self.model_config.width.to_bytes(4, "little"),
|
||||
)
|
||||
result = self.detection_client.recv_all_as_list(timeout=self.timeout)
|
||||
if len(result) == 1:
|
||||
# output came back in the reply rather than direct to SHM
|
||||
output = np.frombuffer(result[0], dtype=np.float32)
|
||||
self.out_np_shm[:] = np.reshape(output, newshape=(20, 6))[:]
|
||||
except TimeoutError:
|
||||
return detections
|
||||
|
||||
for d in self.out_np_shm:
|
||||
if d[1] < threshold:
|
||||
break
|
||||
detections.append(
|
||||
(self.labels[int(d[0])], float(d[1]), (d[2], d[3], d[4], d[5]))
|
||||
)
|
||||
self.fps.update()
|
||||
return detections
|
||||
|
||||
def cleanup(self):
|
||||
self.detection_client.close()
|
||||
self.in_shm.close()
|
||||
self.out_shm.close()
|
||||
244
frigate/detectors/detection_worker.py
Normal file
244
frigate/detectors/detection_worker.py
Normal file
@ -0,0 +1,244 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
import numpy as np
|
||||
import multiprocessing as mp
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from majortomo import Worker, WorkerRequestsIterator, error, protocol
|
||||
from majortomo.util import TextOrBytes, text_to_ascii_bytes
|
||||
from typing import List
|
||||
|
||||
from frigate.util import listen, EventsPerSecond, load_labels
|
||||
from .detector_config import InputTensorEnum, BaseDetectorConfig
|
||||
from .detector_types import create_detector
|
||||
|
||||
from setproctitle import setproctitle
|
||||
|
||||
DEFAULT_ZMQ_LINGER = 2500
|
||||
READY_SHM = b"\007"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ObjectDetectionWorker(Worker):
|
||||
def __init__(
|
||||
self,
|
||||
detector_name: str,
|
||||
detector_config: BaseDetectorConfig,
|
||||
avg_inference_speed: mp.Value = mp.Value("d", 0.01),
|
||||
detection_start: mp.Value = mp.Value("d", 0.00),
|
||||
labels=None,
|
||||
heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
|
||||
heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
|
||||
zmq_context=None,
|
||||
zmq_linger=DEFAULT_ZMQ_LINGER,
|
||||
):
|
||||
self.broker_url = detector_config.address
|
||||
self.service_names = [
|
||||
text_to_ascii_bytes(service_name)
|
||||
for service_name in detector_config.cameras
|
||||
]
|
||||
super().__init__(
|
||||
self.broker_url,
|
||||
b"",
|
||||
heartbeat_interval,
|
||||
heartbeat_timeout,
|
||||
zmq_context,
|
||||
zmq_linger,
|
||||
)
|
||||
self.detector_name = detector_name
|
||||
self.detector_config = detector_config
|
||||
self.avg_inference_speed = avg_inference_speed
|
||||
self.detection_start = detection_start
|
||||
self.detection_shms: dict[str, SharedMemory] = {}
|
||||
self.detection_outputs = {}
|
||||
|
||||
self.fps = EventsPerSecond()
|
||||
self.shm_shape = (
|
||||
1,
|
||||
self.detector_config.model.height,
|
||||
self.detector_config.model.width,
|
||||
3,
|
||||
)
|
||||
self.out_shm = None
|
||||
self.out_np = None
|
||||
|
||||
self.labels = labels
|
||||
if self.labels is None:
|
||||
self.labels = {}
|
||||
else:
|
||||
self.labels = load_labels(self.labels)
|
||||
|
||||
if self.detector_config:
|
||||
self.input_transform = self.tensor_transform(
|
||||
self.detector_config.model.input_tensor
|
||||
)
|
||||
else:
|
||||
self.input_transform = None
|
||||
|
||||
self.detect_api = create_detector(self.detector_config)
|
||||
|
||||
def _send_ready(self):
|
||||
command = READY_SHM if self.detector_config.shared_memory else protocol.READY
|
||||
for service_name in self.service_names:
|
||||
self._send(command, service_name)
|
||||
|
||||
def handle_request(self, request):
|
||||
self.detection_start.value = datetime.datetime.now().timestamp()
|
||||
|
||||
# expected request format:
|
||||
# if SHM: [camera_name, model.height, model.width]
|
||||
# else: [tensor_input]
|
||||
# detect and send the output
|
||||
detections = None
|
||||
frames = []
|
||||
if len(request) == 1:
|
||||
detections = self.detect_raw(request[0])
|
||||
if detections is None:
|
||||
self.detection_start.value = 0.0
|
||||
return frames
|
||||
frames.append(detections.tobytes())
|
||||
elif len(request) == 3:
|
||||
camera_name = request[0].decode("ascii")
|
||||
shm_shape = (
|
||||
1,
|
||||
int.from_bytes(request[1], byteorder="little"),
|
||||
int.from_bytes(request[2], byteorder="little"),
|
||||
3,
|
||||
)
|
||||
detections = self.detect_shm(camera_name, shm_shape)
|
||||
out_np = self.detection_outputs.get(camera_name, None)
|
||||
if out_np is None:
|
||||
out_shm = self.detection_shms.get(f"out-{camera_name}", None)
|
||||
if out_shm is None:
|
||||
out_shm = SharedMemory(name=f"out-{camera_name}", create=False)
|
||||
out_np = self.detection_outputs[camera_name] = np.ndarray(
|
||||
(20, 6), dtype=np.float32, buffer=out_shm.buf
|
||||
)
|
||||
out_np[:] = detections[:]
|
||||
|
||||
duration = datetime.datetime.now().timestamp() - self.detection_start.value
|
||||
self.detection_start.value = 0.0
|
||||
self.avg_inference_speed.value = (
|
||||
self.avg_inference_speed.value * 9 + duration
|
||||
) / 10
|
||||
return frames
|
||||
|
||||
def detect(self, tensor_input, threshold=0.4):
|
||||
detections = []
|
||||
|
||||
raw_detections = self.detect_raw(tensor_input)
|
||||
|
||||
for d in raw_detections:
|
||||
if d[1] < threshold:
|
||||
break
|
||||
detections.append(
|
||||
(self.labels[int(d[0])], float(d[1]), (d[2], d[3], d[4], d[5]))
|
||||
)
|
||||
self.fps.update()
|
||||
return detections
|
||||
|
||||
def detect_raw(self, tensor_input):
|
||||
if self.input_transform:
|
||||
tensor_input = np.transpose(tensor_input, self.input_transform)
|
||||
detections = self.detect_api.detect_raw(tensor_input=tensor_input)
|
||||
return detections
|
||||
|
||||
def detect_shm(self, camera_name, shm_shape):
|
||||
in_shm = self.detection_shms.get(camera_name)
|
||||
if in_shm is None:
|
||||
in_shm = self.detection_shms[camera_name] = SharedMemory(camera_name)
|
||||
|
||||
tensor_input = np.ndarray(shm_shape, dtype=np.uint8, buffer=in_shm.buf)
|
||||
detections = self.detect_raw(tensor_input=tensor_input)
|
||||
return detections
|
||||
|
||||
def tensor_transform(self, desired_shape):
|
||||
# Currently this function only supports BHWC permutations
|
||||
if desired_shape == InputTensorEnum.nhwc:
|
||||
return None
|
||||
elif desired_shape == InputTensorEnum.nchw:
|
||||
return (0, 3, 1, 2)
|
||||
|
||||
def close(self):
|
||||
super().close()
|
||||
self.detection_outputs = {}
|
||||
while len(self.detection_shms) > 0:
|
||||
shm = self.detection_shms.popitem()
|
||||
shm[1].close()
|
||||
|
||||
|
||||
def run_detector(
|
||||
detector_name, detector_config, avg_inference_speed, detection_start, labels
|
||||
):
|
||||
threading.current_thread().name = f"detector:{detector_name}"
|
||||
logger = logging.getLogger(f"detector.{detector_name}")
|
||||
logger.info(f"Starting detection process: {os.getpid()}")
|
||||
setproctitle(f"frigate.detector.{detector_name}")
|
||||
listen()
|
||||
|
||||
worker = ObjectDetectionWorker(
|
||||
detector_name,
|
||||
detector_config,
|
||||
avg_inference_speed,
|
||||
detection_start,
|
||||
labels,
|
||||
)
|
||||
|
||||
def receiveSignal(signalNumber, frame):
|
||||
worker.close()
|
||||
|
||||
signal.signal(signal.SIGTERM, receiveSignal)
|
||||
signal.signal(signal.SIGINT, receiveSignal)
|
||||
|
||||
worker_iter = WorkerRequestsIterator(worker)
|
||||
for request in worker_iter:
|
||||
reply = worker.handle_request(request)
|
||||
worker_iter.send_reply_final(reply)
|
||||
|
||||
|
||||
class ObjectDetectProcess:
|
||||
def __init__(
|
||||
self,
|
||||
detector_name: str,
|
||||
detector_config: BaseDetectorConfig,
|
||||
labels=None,
|
||||
):
|
||||
self.detector_name = detector_name
|
||||
self.detector_config = detector_config
|
||||
self.labels = labels
|
||||
|
||||
self.avg_inference_speed = mp.Value("d", 0.01)
|
||||
self.detection_start = mp.Value("d", 0.0)
|
||||
self.detect_process = None
|
||||
|
||||
self.start_or_restart()
|
||||
|
||||
def stop(self):
|
||||
self.detect_process.terminate()
|
||||
logging.info("Waiting for detection process to exit gracefully...")
|
||||
self.detect_process.join(timeout=30)
|
||||
if self.detect_process.exitcode is None:
|
||||
logging.info("Detection process didn't exit. Force killing...")
|
||||
self.detect_process.kill()
|
||||
self.detect_process.join()
|
||||
|
||||
def start_or_restart(self):
|
||||
self.detection_start.value = 0.0
|
||||
if (not self.detect_process is None) and self.detect_process.is_alive():
|
||||
self.stop()
|
||||
self.detect_process = mp.Process(
|
||||
target=run_detector,
|
||||
name=f"detector:{self.detector_name}",
|
||||
args=(
|
||||
self.detector_name,
|
||||
self.detector_config,
|
||||
self.avg_inference_speed,
|
||||
self.detection_start,
|
||||
self.labels,
|
||||
),
|
||||
)
|
||||
self.detect_process.daemon = True
|
||||
self.detect_process.start()
|
||||
@ -69,6 +69,9 @@ class ModelConfig(BaseModel):
|
||||
class BaseDetectorConfig(BaseModel):
|
||||
# the type field must be defined in all subclasses
|
||||
type: str = Field(default="cpu", title="Detector Type")
|
||||
cameras: List[str] = Field(default=None, title="Cameras to track")
|
||||
address: str = Field(default=None, title="Frigate Detection Queue Server Address")
|
||||
shared_memory: Union[bool, None] = Field(default=None, title="Use Shared Memory")
|
||||
model: ModelConfig = Field(
|
||||
default=None, title="Detector specific model configuration."
|
||||
)
|
||||
@ -76,3 +79,18 @@ class BaseDetectorConfig(BaseModel):
|
||||
class Config:
|
||||
extra = Extra.allow
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class DetectionServerModeEnum(str, Enum):
|
||||
Full = "full"
|
||||
DetectionOnly = "detection_only"
|
||||
|
||||
|
||||
class DetectionServerConfig(BaseModel):
|
||||
mode: DetectionServerModeEnum = Field(
|
||||
default=DetectionServerModeEnum.Full, title="Server mode"
|
||||
)
|
||||
ipc: str = Field(default="ipc://detection_broker.ipc", title="Broker IPC path")
|
||||
addresses: List[str] = Field(
|
||||
default=["tcp://127.0.0.1:5555"], title="Broker TCP addresses"
|
||||
)
|
||||
|
||||
@ -1,18 +1,19 @@
|
||||
import logging
|
||||
import importlib
|
||||
import pkgutil
|
||||
from typing import Union
|
||||
from typing_extensions import Annotated
|
||||
from enum import Enum
|
||||
from pydantic import Field
|
||||
|
||||
from . import plugins
|
||||
from .detection_api import DetectionApi
|
||||
from .detector_config import BaseDetectorConfig
|
||||
from . import plugins
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StrEnum(str, Enum):
|
||||
pass
|
||||
|
||||
|
||||
plugin_modules = [
|
||||
importlib.import_module(name)
|
||||
for finder, name, ispkg in pkgutil.iter_modules(
|
||||
@ -21,15 +22,16 @@ plugin_modules = [
|
||||
]
|
||||
|
||||
api_types = {det.type_key: det for det in DetectionApi.__subclasses__()}
|
||||
|
||||
|
||||
class StrEnum(str, Enum):
|
||||
pass
|
||||
|
||||
|
||||
DetectorTypeEnum = StrEnum("DetectorTypeEnum", {k: k for k in api_types})
|
||||
|
||||
DetectorConfig = Annotated[
|
||||
Union[tuple(BaseDetectorConfig.__subclasses__())],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
def create_detector(detector_config):
|
||||
if detector_config.type == DetectorTypeEnum.cpu:
|
||||
logger.warning(
|
||||
"CPU detectors are not recommended and should only be used for testing or for trial purposes."
|
||||
)
|
||||
|
||||
api = api_types.get(detector_config.type)
|
||||
if not api:
|
||||
raise ValueError(detector_config.type)
|
||||
return api(detector_config)
|
||||
|
||||
@ -1,217 +0,0 @@
|
||||
import datetime
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import queue
|
||||
import signal
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
from setproctitle import setproctitle
|
||||
|
||||
from frigate.config import InputTensorEnum
|
||||
from frigate.detectors import create_detector
|
||||
|
||||
from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen, load_labels
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ObjectDetector(ABC):
|
||||
@abstractmethod
|
||||
def detect(self, tensor_input, threshold=0.4):
|
||||
pass
|
||||
|
||||
|
||||
def tensor_transform(desired_shape):
|
||||
# Currently this function only supports BHWC permutations
|
||||
if desired_shape == InputTensorEnum.nhwc:
|
||||
return None
|
||||
elif desired_shape == InputTensorEnum.nchw:
|
||||
return (0, 3, 1, 2)
|
||||
|
||||
|
||||
class LocalObjectDetector(ObjectDetector):
|
||||
def __init__(
|
||||
self,
|
||||
detector_config=None,
|
||||
labels=None,
|
||||
):
|
||||
self.fps = EventsPerSecond()
|
||||
if labels is None:
|
||||
self.labels = {}
|
||||
else:
|
||||
self.labels = load_labels(labels)
|
||||
|
||||
if detector_config:
|
||||
self.input_transform = tensor_transform(detector_config.model.input_tensor)
|
||||
else:
|
||||
self.input_transform = None
|
||||
|
||||
self.detect_api = create_detector(detector_config)
|
||||
|
||||
def detect(self, tensor_input, threshold=0.4):
|
||||
detections = []
|
||||
|
||||
raw_detections = self.detect_raw(tensor_input)
|
||||
|
||||
for d in raw_detections:
|
||||
if d[1] < threshold:
|
||||
break
|
||||
detections.append(
|
||||
(self.labels[int(d[0])], float(d[1]), (d[2], d[3], d[4], d[5]))
|
||||
)
|
||||
self.fps.update()
|
||||
return detections
|
||||
|
||||
def detect_raw(self, tensor_input):
|
||||
if self.input_transform:
|
||||
tensor_input = np.transpose(tensor_input, self.input_transform)
|
||||
return self.detect_api.detect_raw(tensor_input=tensor_input)
|
||||
|
||||
|
||||
def run_detector(
|
||||
name: str,
|
||||
detection_queue: mp.Queue,
|
||||
out_events: dict[str, mp.Event],
|
||||
avg_speed,
|
||||
start,
|
||||
detector_config,
|
||||
):
|
||||
threading.current_thread().name = f"detector:{name}"
|
||||
logger = logging.getLogger(f"detector.{name}")
|
||||
logger.info(f"Starting detection process: {os.getpid()}")
|
||||
setproctitle(f"frigate.detector.{name}")
|
||||
listen()
|
||||
|
||||
stop_event = mp.Event()
|
||||
|
||||
def receiveSignal(signalNumber, frame):
|
||||
stop_event.set()
|
||||
|
||||
signal.signal(signal.SIGTERM, receiveSignal)
|
||||
signal.signal(signal.SIGINT, receiveSignal)
|
||||
|
||||
frame_manager = SharedMemoryFrameManager()
|
||||
object_detector = LocalObjectDetector(detector_config=detector_config)
|
||||
|
||||
outputs = {}
|
||||
for name in out_events.keys():
|
||||
out_shm = mp.shared_memory.SharedMemory(name=f"out-{name}", create=False)
|
||||
out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
|
||||
outputs[name] = {"shm": out_shm, "np": out_np}
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
connection_id = detection_queue.get(timeout=5)
|
||||
except queue.Empty:
|
||||
continue
|
||||
input_frame = frame_manager.get(
|
||||
connection_id,
|
||||
(1, detector_config.model.height, detector_config.model.width, 3),
|
||||
)
|
||||
|
||||
if input_frame is None:
|
||||
continue
|
||||
|
||||
# detect and send the output
|
||||
start.value = datetime.datetime.now().timestamp()
|
||||
detections = object_detector.detect_raw(input_frame)
|
||||
duration = datetime.datetime.now().timestamp() - start.value
|
||||
outputs[connection_id]["np"][:] = detections[:]
|
||||
out_events[connection_id].set()
|
||||
start.value = 0.0
|
||||
|
||||
avg_speed.value = (avg_speed.value * 9 + duration) / 10
|
||||
|
||||
|
||||
class ObjectDetectProcess:
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
detection_queue,
|
||||
out_events,
|
||||
detector_config,
|
||||
):
|
||||
self.name = name
|
||||
self.out_events = out_events
|
||||
self.detection_queue = detection_queue
|
||||
self.avg_inference_speed = mp.Value("d", 0.01)
|
||||
self.detection_start = mp.Value("d", 0.0)
|
||||
self.detect_process = None
|
||||
self.detector_config = detector_config
|
||||
self.start_or_restart()
|
||||
|
||||
def stop(self):
|
||||
self.detect_process.terminate()
|
||||
logging.info("Waiting for detection process to exit gracefully...")
|
||||
self.detect_process.join(timeout=30)
|
||||
if self.detect_process.exitcode is None:
|
||||
logging.info("Detection process didnt exit. Force killing...")
|
||||
self.detect_process.kill()
|
||||
self.detect_process.join()
|
||||
|
||||
def start_or_restart(self):
|
||||
self.detection_start.value = 0.0
|
||||
if (not self.detect_process is None) and self.detect_process.is_alive():
|
||||
self.stop()
|
||||
self.detect_process = mp.Process(
|
||||
target=run_detector,
|
||||
name=f"detector:{self.name}",
|
||||
args=(
|
||||
self.name,
|
||||
self.detection_queue,
|
||||
self.out_events,
|
||||
self.avg_inference_speed,
|
||||
self.detection_start,
|
||||
self.detector_config,
|
||||
),
|
||||
)
|
||||
self.detect_process.daemon = True
|
||||
self.detect_process.start()
|
||||
|
||||
|
||||
class RemoteObjectDetector:
|
||||
def __init__(self, name, labels, detection_queue, event, model_config):
|
||||
self.labels = labels
|
||||
self.name = name
|
||||
self.fps = EventsPerSecond()
|
||||
self.detection_queue = detection_queue
|
||||
self.event = event
|
||||
self.shm = mp.shared_memory.SharedMemory(name=self.name, create=False)
|
||||
self.np_shm = np.ndarray(
|
||||
(1, model_config.height, model_config.width, 3),
|
||||
dtype=np.uint8,
|
||||
buffer=self.shm.buf,
|
||||
)
|
||||
self.out_shm = mp.shared_memory.SharedMemory(
|
||||
name=f"out-{self.name}", create=False
|
||||
)
|
||||
self.out_np_shm = np.ndarray((20, 6), dtype=np.float32, buffer=self.out_shm.buf)
|
||||
|
||||
def detect(self, tensor_input, threshold=0.4):
|
||||
detections = []
|
||||
|
||||
# copy input to shared memory
|
||||
self.np_shm[:] = tensor_input[:]
|
||||
self.event.clear()
|
||||
self.detection_queue.put(self.name)
|
||||
result = self.event.wait(timeout=10.0)
|
||||
|
||||
# if it timed out
|
||||
if result is None:
|
||||
return detections
|
||||
|
||||
for d in self.out_np_shm:
|
||||
if d[1] < threshold:
|
||||
break
|
||||
detections.append(
|
||||
(self.labels[int(d[0])], float(d[1]), (d[2], d[3], d[4], d[5]))
|
||||
)
|
||||
self.fps.update()
|
||||
return detections
|
||||
|
||||
def cleanup(self):
|
||||
self.shm.unlink()
|
||||
self.out_shm.unlink()
|
||||
@ -17,7 +17,7 @@ from frigate.types import StatsTrackingTypes, CameraMetricsTypes
|
||||
from frigate.util import get_amd_gpu_stats, get_intel_gpu_stats, get_nvidia_gpu_stats
|
||||
from frigate.version import VERSION
|
||||
from frigate.util import get_cpu_stats
|
||||
from frigate.object_detection import ObjectDetectProcess
|
||||
from frigate.detectors import ObjectDetectProcess
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -45,14 +45,14 @@ class TestConfig(unittest.TestCase):
|
||||
"cpu": {
|
||||
"type": "cpu",
|
||||
"model": {"path": "/cpu_model.tflite"},
|
||||
"cameras": ["test"],
|
||||
},
|
||||
"edgetpu": {
|
||||
"type": "edgetpu",
|
||||
"model": {"path": "/edgetpu_model.tflite", "width": 160},
|
||||
"cameras": ["test"],
|
||||
},
|
||||
"openvino": {
|
||||
"type": "openvino",
|
||||
},
|
||||
"openvino": {"type": "openvino", "cameras": ["test"]},
|
||||
},
|
||||
"model": {"path": "/default.tflite", "width": 512},
|
||||
}
|
||||
|
||||
@ -1,29 +1,207 @@
|
||||
import functools
|
||||
import threading
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import multiprocessing as mp
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
import numpy as np
|
||||
from pydantic import parse_obj_as
|
||||
|
||||
from frigate.config import DetectorConfig, InputTensorEnum, ModelConfig
|
||||
from frigate.detectors import DetectorTypeEnum
|
||||
import frigate.detectors as detectors
|
||||
import frigate.object_detection
|
||||
from frigate.config import FrigateConfig, DetectorConfig, InputTensorEnum, ModelConfig
|
||||
from frigate.detectors import (
|
||||
DetectorTypeEnum,
|
||||
ObjectDetectionBroker,
|
||||
ObjectDetectionClient,
|
||||
ObjectDetectionWorker,
|
||||
)
|
||||
from frigate.util import deep_merge
|
||||
import frigate.detectors.detector_types as detectors
|
||||
|
||||
|
||||
test_tensor_input = np.random.randint(
|
||||
np.iinfo(np.uint8).min,
|
||||
np.iinfo(np.uint8).max,
|
||||
(1, 320, 320, 3),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
test_detection_output = np.random.rand(20, 6).astype("f")
|
||||
|
||||
|
||||
def create_detector(det_type):
|
||||
api = Mock()
|
||||
api.return_value.detect_raw = Mock(return_value=test_detection_output)
|
||||
return api
|
||||
|
||||
|
||||
class TestLocalObjectDetector(unittest.TestCase):
|
||||
@patch.dict(
|
||||
"frigate.detectors.detector_types.api_types",
|
||||
{det_type: create_detector(det_type) for det_type in DetectorTypeEnum},
|
||||
)
|
||||
def test_socket_client_broker_worker(self):
|
||||
detector_name = "cpu"
|
||||
ipc_address = "ipc://detection_broker.ipc"
|
||||
tcp_address = "tcp://127.0.0.1:5555"
|
||||
|
||||
detector = {"type": "cpu"}
|
||||
test_cases = {
|
||||
"ipc_shm": {"cameras": ["ipc_shm"]},
|
||||
"ipc_no_shm": {"shared_memory": False, "cameras": ["ipc_no_shm"]},
|
||||
"tcp_shm": {
|
||||
"address": tcp_address,
|
||||
"shared_memory": True,
|
||||
"cameras": ["tcp_shm"],
|
||||
},
|
||||
"tcp_no_shm": {"address": tcp_address, "cameras": ["tcp_no_shm"]},
|
||||
}
|
||||
|
||||
class ClientTestThread(threading.Thread):
|
||||
def __init__(
|
||||
self,
|
||||
camera_name,
|
||||
labelmap,
|
||||
model_config,
|
||||
server_config,
|
||||
tensor_input,
|
||||
timeout,
|
||||
):
|
||||
super().__init__()
|
||||
self.camera_name = camera_name
|
||||
self.labelmap = labelmap
|
||||
self.model_config = model_config
|
||||
self.server_config = server_config
|
||||
self.tensor_input = tensor_input
|
||||
self.timeout = timeout
|
||||
|
||||
def run(self):
|
||||
object_detector = ObjectDetectionClient(
|
||||
self.camera_name,
|
||||
self.labelmap,
|
||||
self.model_config,
|
||||
self.server_config,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
try:
|
||||
object_detector.detect(self.tensor_input)
|
||||
finally:
|
||||
object_detector.cleanup()
|
||||
|
||||
try:
|
||||
detection_shms: dict[str, SharedMemory] = {}
|
||||
for camera_name in test_cases.keys():
|
||||
shm_name = camera_name
|
||||
out_shm_name = f"out-{camera_name}"
|
||||
try:
|
||||
shm = SharedMemory(name=shm_name, size=512 * 512 * 3, create=True)
|
||||
except FileExistsError:
|
||||
shm = SharedMemory(name=shm_name)
|
||||
detection_shms[shm_name] = shm
|
||||
try:
|
||||
out_shm = SharedMemory(
|
||||
name=out_shm_name, size=20 * 6 * 4, create=True
|
||||
)
|
||||
except FileExistsError:
|
||||
out_shm = SharedMemory(name=out_shm_name)
|
||||
detection_shms[out_shm_name] = out_shm
|
||||
|
||||
self.detection_broker = ObjectDetectionBroker(
|
||||
bind=[ipc_address, tcp_address],
|
||||
shms=detection_shms,
|
||||
)
|
||||
self.detection_broker.start()
|
||||
|
||||
for test_case in test_cases.keys():
|
||||
with self.subTest(test_case=test_case):
|
||||
camera_name = test_case
|
||||
shm_name = camera_name
|
||||
shm = detection_shms[shm_name]
|
||||
out_shm_name = f"out-{camera_name}"
|
||||
out_shm = detection_shms[out_shm_name]
|
||||
|
||||
test_cfg = FrigateConfig.parse_obj(
|
||||
{
|
||||
"server": {
|
||||
"mode": "detection_only",
|
||||
"ipc": ipc_address,
|
||||
"addresses": [tcp_address],
|
||||
},
|
||||
"detectors": {
|
||||
detector_name: deep_merge(
|
||||
detector, test_cases[test_case]
|
||||
)
|
||||
},
|
||||
}
|
||||
)
|
||||
config = test_cfg.runtime_config
|
||||
detector_config = config.detectors[detector_name]
|
||||
model_config = detector_config.model
|
||||
|
||||
tensor_input = np.ndarray(
|
||||
(1, config.model.height, config.model.width, 3),
|
||||
dtype=np.uint8,
|
||||
buffer=shm.buf,
|
||||
)
|
||||
tensor_input[:] = test_tensor_input[:]
|
||||
out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
|
||||
|
||||
try:
|
||||
worker = ObjectDetectionWorker(
|
||||
detector_name,
|
||||
detector_config,
|
||||
mp.Value("d", 0.01),
|
||||
mp.Value("d", 0.0),
|
||||
None,
|
||||
)
|
||||
worker.connect()
|
||||
|
||||
client = ClientTestThread(
|
||||
camera_name,
|
||||
test_cfg.model.merged_labelmap,
|
||||
model_config,
|
||||
config.server,
|
||||
tensor_input,
|
||||
timeout=10,
|
||||
)
|
||||
client.start()
|
||||
|
||||
client_id, request = worker.wait_for_request()
|
||||
reply = worker.handle_request(request)
|
||||
worker.send_reply_final(client_id, reply)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
finally:
|
||||
client.join()
|
||||
worker.close()
|
||||
|
||||
self.assertIsNone(
|
||||
np.testing.assert_array_almost_equal(
|
||||
out_np, test_detection_output
|
||||
)
|
||||
)
|
||||
finally:
|
||||
self.detection_broker.stop()
|
||||
for shm in detection_shms.values():
|
||||
shm.close()
|
||||
shm.unlink()
|
||||
|
||||
def test_localdetectorprocess_should_only_create_specified_detector_type(self):
|
||||
for det_type in detectors.api_types:
|
||||
with self.subTest(det_type=det_type):
|
||||
with patch.dict(
|
||||
"frigate.detectors.api_types",
|
||||
{det_type: Mock() for det_type in DetectorTypeEnum},
|
||||
"frigate.detectors.detector_types.api_types",
|
||||
{
|
||||
det_type: create_detector(det_type)
|
||||
for det_type in DetectorTypeEnum
|
||||
},
|
||||
):
|
||||
test_cfg = parse_obj_as(
|
||||
DetectorConfig, ({"type": det_type, "model": {}})
|
||||
DetectorConfig,
|
||||
({"type": det_type, "model": {}, "cameras": ["test"]}),
|
||||
)
|
||||
test_cfg.model.path = "/test/modelpath"
|
||||
test_obj = frigate.object_detection.LocalObjectDetector(
|
||||
detector_config=test_cfg
|
||||
test_obj = ObjectDetectionWorker(
|
||||
detector_name="test", detector_config=test_cfg
|
||||
)
|
||||
|
||||
assert test_obj is not None
|
||||
@ -34,7 +212,7 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
mock_detector.assert_not_called()
|
||||
|
||||
@patch.dict(
|
||||
"frigate.detectors.api_types",
|
||||
"frigate.detectors.detector_types.api_types",
|
||||
{det_type: Mock() for det_type in DetectorTypeEnum},
|
||||
)
|
||||
def test_detect_raw_given_tensor_input_should_return_api_detect_raw_result(self):
|
||||
@ -42,8 +220,11 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
|
||||
TEST_DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32])
|
||||
test_obj_detect = frigate.object_detection.LocalObjectDetector(
|
||||
detector_config=parse_obj_as(DetectorConfig, {"type": "cpu", "model": {}})
|
||||
test_obj_detect = ObjectDetectionWorker(
|
||||
detector_name="test",
|
||||
detector_config=parse_obj_as(
|
||||
DetectorConfig, {"type": "cpu", "model": {}, "cameras": ["test"]}
|
||||
),
|
||||
)
|
||||
|
||||
mock_det_api = mock_cputfl.return_value
|
||||
@ -55,7 +236,7 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
assert test_result is mock_det_api.detect_raw.return_value
|
||||
|
||||
@patch.dict(
|
||||
"frigate.detectors.api_types",
|
||||
"frigate.detectors.detector_types.api_types",
|
||||
{det_type: Mock() for det_type in DetectorTypeEnum},
|
||||
)
|
||||
def test_detect_raw_given_tensor_input_should_call_api_detect_raw_with_transposed_tensor(
|
||||
@ -66,11 +247,13 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
TEST_DATA = np.zeros((1, 32, 32, 3), np.uint8)
|
||||
TEST_DETECT_RESULT = np.ndarray([1, 2, 4, 8, 16, 32])
|
||||
|
||||
test_cfg = parse_obj_as(DetectorConfig, {"type": "cpu", "model": {}})
|
||||
test_cfg = parse_obj_as(
|
||||
DetectorConfig, {"type": "cpu", "model": {}, "cameras": ["test"]}
|
||||
)
|
||||
test_cfg.model.input_tensor = InputTensorEnum.nchw
|
||||
|
||||
test_obj_detect = frigate.object_detection.LocalObjectDetector(
|
||||
detector_config=test_cfg
|
||||
test_obj_detect = ObjectDetectionWorker(
|
||||
detector_name="test", detector_config=test_cfg
|
||||
)
|
||||
|
||||
mock_det_api = mock_cputfl.return_value
|
||||
@ -87,10 +270,10 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
assert test_result is mock_det_api.detect_raw.return_value
|
||||
|
||||
@patch.dict(
|
||||
"frigate.detectors.api_types",
|
||||
"frigate.detectors.detector_types.api_types",
|
||||
{det_type: Mock() for det_type in DetectorTypeEnum},
|
||||
)
|
||||
@patch("frigate.object_detection.load_labels")
|
||||
@patch("frigate.detectors.detection_worker.load_labels")
|
||||
def test_detect_given_tensor_input_should_return_lfiltered_detections(
|
||||
self, mock_load_labels
|
||||
):
|
||||
@ -115,9 +298,12 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
"label-5",
|
||||
]
|
||||
|
||||
test_cfg = parse_obj_as(DetectorConfig, {"type": "cpu", "model": {}})
|
||||
test_cfg = parse_obj_as(
|
||||
DetectorConfig, {"type": "cpu", "model": {}, "cameras": ["test"]}
|
||||
)
|
||||
test_cfg.model = ModelConfig()
|
||||
test_obj_detect = frigate.object_detection.LocalObjectDetector(
|
||||
test_obj_detect = ObjectDetectionWorker(
|
||||
detector_name="test",
|
||||
detector_config=test_cfg,
|
||||
labels=TEST_LABEL_FILE,
|
||||
)
|
||||
|
||||
@ -3,7 +3,7 @@ from multiprocessing.queues import Queue
|
||||
from multiprocessing.sharedctypes import Synchronized
|
||||
from multiprocessing.context import Process
|
||||
|
||||
from frigate.object_detection import ObjectDetectProcess
|
||||
from frigate.detectors import ObjectDetectProcess
|
||||
|
||||
|
||||
class CameraMetricsTypes(TypedDict):
|
||||
|
||||
@ -14,9 +14,14 @@ import numpy as np
|
||||
import cv2
|
||||
from setproctitle import setproctitle
|
||||
|
||||
from frigate.config import CameraConfig, DetectConfig, PixelFormatEnum
|
||||
from frigate.config import (
|
||||
CameraConfig,
|
||||
DetectConfig,
|
||||
PixelFormatEnum,
|
||||
DetectionServerConfig,
|
||||
)
|
||||
from frigate.const import CACHE_DIR
|
||||
from frigate.object_detection import RemoteObjectDetector
|
||||
from frigate.detectors import ObjectDetectionClient
|
||||
from frigate.log import LogPipe
|
||||
from frigate.motion import MotionDetector
|
||||
from frigate.objects import ObjectTracker
|
||||
@ -405,12 +410,11 @@ def capture_camera(name, config: CameraConfig, process_info):
|
||||
|
||||
|
||||
def track_camera(
|
||||
name,
|
||||
camera_name,
|
||||
config: CameraConfig,
|
||||
model_config,
|
||||
server_config: DetectionServerConfig,
|
||||
labelmap,
|
||||
detection_queue,
|
||||
result_connection,
|
||||
detected_objects_queue,
|
||||
process_info,
|
||||
):
|
||||
@ -422,8 +426,8 @@ def track_camera(
|
||||
signal.signal(signal.SIGTERM, receiveSignal)
|
||||
signal.signal(signal.SIGINT, receiveSignal)
|
||||
|
||||
threading.current_thread().name = f"process:{name}"
|
||||
setproctitle(f"frigate.process:{name}")
|
||||
threading.current_thread().name = f"process:{camera_name}"
|
||||
setproctitle(f"frigate.process:{camera_name}")
|
||||
listen()
|
||||
|
||||
frame_queue = process_info["frame_queue"]
|
||||
@ -444,8 +448,8 @@ def track_camera(
|
||||
motion_threshold,
|
||||
motion_contour_area,
|
||||
)
|
||||
object_detector = RemoteObjectDetector(
|
||||
name, labelmap, detection_queue, result_connection, model_config
|
||||
object_detector = ObjectDetectionClient(
|
||||
camera_name, labelmap, model_config, server_config
|
||||
)
|
||||
|
||||
object_tracker = ObjectTracker(config.detect)
|
||||
@ -453,7 +457,7 @@ def track_camera(
|
||||
frame_manager = SharedMemoryFrameManager()
|
||||
|
||||
process_frames(
|
||||
name,
|
||||
camera_name,
|
||||
frame_queue,
|
||||
frame_shape,
|
||||
model_config,
|
||||
@ -471,7 +475,9 @@ def track_camera(
|
||||
stop_event,
|
||||
)
|
||||
|
||||
logger.info(f"{name}: exiting subprocess")
|
||||
object_detector.cleanup()
|
||||
|
||||
logger.info(f"{camera_name}: exiting subprocess")
|
||||
|
||||
|
||||
def box_overlaps(b1, b2):
|
||||
@ -558,7 +564,7 @@ def process_frames(
|
||||
detect_config: DetectConfig,
|
||||
frame_manager: FrameManager,
|
||||
motion_detector: MotionDetector,
|
||||
object_detector: RemoteObjectDetector,
|
||||
object_detector: ObjectDetectionClient,
|
||||
object_tracker: ObjectTracker,
|
||||
detected_objects_queue: mp.Queue,
|
||||
process_info: dict,
|
||||
|
||||
@ -2,10 +2,8 @@ import datetime
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import os
|
||||
import signal
|
||||
|
||||
from frigate.object_detection import ObjectDetectProcess
|
||||
from frigate.detectors import ObjectDetectProcess
|
||||
from frigate.util import restart_frigate
|
||||
from multiprocessing.synchronize import Event as MpEvent
|
||||
|
||||
|
||||
@ -12,6 +12,8 @@ psutil == 5.9.*
|
||||
pydantic == 1.10.*
|
||||
PyYAML == 6.0
|
||||
pytz == 2022.6
|
||||
pyzmq == 24.0.1
|
||||
majortomo == 0.2.0
|
||||
tzlocal == 4.2
|
||||
types-PyYAML == 6.0.*
|
||||
requests == 2.28.*
|
||||
|
||||
Loading…
Reference in New Issue
Block a user