diff --git a/frigate/app.py b/frigate/app.py index b015260c4..3417bea0f 100644 --- a/frigate/app.py +++ b/frigate/app.py @@ -17,16 +17,13 @@ from playhouse.sqliteq import SqliteQueueDatabase from frigate.comms.dispatcher import Communicator, Dispatcher from frigate.comms.mqtt import MqttClient from frigate.comms.ws import WebSocketClient -from frigate.config import FrigateConfig +from frigate.config import FrigateConfig, ServerModeEnum from frigate.const import CACHE_DIR, CLIPS_DIR, RECORD_DIR -from frigate.detectors import ( - ObjectDetectProcess, - ObjectDetectionBroker, - DetectionServerModeEnum, -) +from frigate.detectors import ObjectDetectProcess from frigate.events import EventCleanup, EventProcessor from frigate.http import create_app from frigate.log import log_process, root_configurer +from frigate.majordomo import QueueBroker from frigate.models import Event, Recordings from frigate.object_processing import TrackedObjectProcessor from frigate.output import output_frames @@ -84,7 +81,7 @@ class FrigateApp: user_config = FrigateConfig.parse_file(config_file) self.config = user_config.runtime_config - if self.config.server.mode == DetectionServerModeEnum.DetectionOnly: + if self.config.server.mode == ServerModeEnum.DetectionOnly: return for camera_name in self.config.cameras.keys(): @@ -188,12 +185,17 @@ class FrigateApp: comms.append(self.ws_client) self.dispatcher = Dispatcher(self.config, self.camera_metrics, comms) - def start_detection_broker(self) -> None: + def start_queue_broker(self) -> None: + def detect_no_shm(worker, service_name, body): + in_shm = self.detection_shms[str(service_name, "ascii")] + tensor_input = in_shm.buf + body = body[0:2] + [tensor_input] + return body + bind_urls = [self.config.broker.ipc] + self.config.broker.addresses - self.detection_broker = ObjectDetectionBroker( - bind=bind_urls, shms=self.detection_shms - ) - self.detection_broker.start() + self.queue_broker = QueueBroker(bind=bind_urls) + self.queue_broker.register_request_handler("DETECT_NO_SHM", detect_no_shm) + self.queue_broker.start() def start_detectors(self) -> None: for name in self.config.cameras.keys(): @@ -372,7 +374,7 @@ class FrigateApp: self.set_environment_vars() self.ensure_dirs() - if self.config.server.mode == DetectionServerModeEnum.DetectionOnly: + if self.config.server.mode == ServerModeEnum.DetectionOnly: self.start_detectors() self.start_watchdog() self.stop_event.wait() @@ -389,7 +391,7 @@ class FrigateApp: self.init_restream() self.start_detectors() - self.start_detection_broker() + self.start_queue_broker() self.start_video_output_processor() self.start_detected_frames_processor() self.start_camera_processors() @@ -416,7 +418,7 @@ class FrigateApp: logger.info(f"Stopping...") self.stop_event.set() - if self.config.server.mode != DetectionServerModeEnum.DetectionOnly: + if self.config.server.mode != ServerModeEnum.DetectionOnly: self.ws_client.stop() self.detected_frames_processor.join() self.event_processor.join() @@ -426,7 +428,7 @@ class FrigateApp: self.stats_emitter.join() self.frigate_watchdog.join() self.db.stop() - self.detection_broker.stop() + self.queue_broker.stop() for detector in self.detectors.values(): detector.stop() diff --git a/frigate/config.py b/frigate/config.py index 578bc0e41..56d5cead5 100644 --- a/frigate/config.py +++ b/frigate/config.py @@ -43,10 +43,8 @@ from frigate.ffmpeg_presets import ( from frigate.detectors import ( PixelFormatEnum, InputTensorEnum, - DetectionServerConfig, ModelConfig, BaseDetectorConfig, - DetectionServerModeEnum, ) from frigate.version import VERSION @@ -812,9 +810,22 @@ def verify_zone_objects_are_tracked(camera_config: CameraConfig) -> None: ) +class ServerModeEnum(str, Enum): + Full = "full" + DetectionOnly = "detection_only" + + +class ServerConfig(BaseModel): + mode: ServerModeEnum = Field(default=ServerModeEnum.Full, title="Server mode") + ipc: str = Field(default="ipc://queue_broker.ipc", title="Broker IPC path") + addresses: List[str] = Field( + default=["tcp://127.0.0.1:5555"], title="Broker TCP addresses" + ) + + class FrigateConfig(FrigateBaseModel): - server: DetectionServerConfig = Field( - default_factory=DetectionServerConfig, title="Server configuration" + server: ServerConfig = Field( + default_factory=ServerConfig, title="Server configuration" ) mqtt: Optional[MqttConfig] = Field(title="MQTT Configuration.") database: DatabaseConfig = Field( @@ -913,7 +924,7 @@ class FrigateConfig(FrigateBaseModel): config.detectors[key] = detector_config - if config.server.mode == DetectionServerModeEnum.DetectionOnly: + if config.server.mode == ServerModeEnum.DetectionOnly: return config # MQTT password substitution @@ -1030,8 +1041,8 @@ class FrigateConfig(FrigateBaseModel): server_config = values.get("server", None) if ( server_config is not None - and server_config.get("mode", DetectionServerModeEnum.Full) - == DetectionServerModeEnum.DetectionOnly + and server_config.get("mode", ServerModeEnum.Full) + == ServerModeEnum.DetectionOnly ): return values diff --git a/frigate/detectors/__init__.py b/frigate/detectors/__init__.py index b7277bf4a..0015a8eab 100644 --- a/frigate/detectors/__init__.py +++ b/frigate/detectors/__init__.py @@ -1,12 +1,9 @@ from .detection_api import DetectionApi -from .detection_broker import ObjectDetectionBroker from .detector_config import ( PixelFormatEnum, InputTensorEnum, ModelConfig, BaseDetectorConfig, - DetectionServerConfig, - DetectionServerModeEnum, ) from .detection_client import ObjectDetectionClient from .detector_types import DetectorTypeEnum, api_types, create_detector diff --git a/frigate/detectors/detection_broker.py b/frigate/detectors/detection_broker.py deleted file mode 100644 index bbd79dc16..000000000 --- a/frigate/detectors/detection_broker.py +++ /dev/null @@ -1,89 +0,0 @@ -import signal -import threading -import zmq -from multiprocessing.shared_memory import SharedMemory -from typing import Union, List -from majortomo import error, protocol -from majortomo.config import DEFAULT_BIND_URL -from majortomo.broker import Broker - - -READY_SHM = b"\007" - - -class ObjectDetectionBroker(Broker): - def __init__( - self, - bind: Union[str, List[str]] = DEFAULT_BIND_URL, - shms: dict[str, SharedMemory] = {}, - heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL, - heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT, - busy_worker_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT, - zmq_context=None, - ): - protocol.Message.ALLOWED_COMMANDS[protocol.WORKER_HEADER].add(READY_SHM) - - super().__init__( - self, - heartbeat_interval=heartbeat_interval, - heartbeat_timeout=heartbeat_timeout, - busy_worker_timeout=busy_worker_timeout, - zmq_context=zmq_context, - ) - self.shm_workers = set() - self.shms = shms - self._bind_urls = [bind] if not isinstance(bind, list) else bind - self.broker_thread: threading.Thread = None - - def bind(self): - """Bind the ZMQ socket""" - if self._socket: - raise error.StateError("Socket is already bound") - - self._socket = self._context.socket(zmq.ROUTER) - self._socket.rcvtimeo = int(self._heartbeat_interval * 1000) - for bind_url in self._bind_urls: - self._socket.bind(bind_url) - self._log.info("Broker listening on %s", bind_url) - - def close(self): - if self._socket is None: - return - for bind_url in self._bind_urls: - self._socket.disconnect(bind_url) - self._socket.close() - self._socket = None - self._log.info("Broker socket closing") - - def _handle_worker_message(self, message): - if message.command == READY_SHM: - self.shm_workers.add(message.client) - self._handle_worker_ready(message.client, message.message[0]) - else: - super()._handle_worker_message(message) - - def _purge_expired_workers(self): - self.shm_workers.intersection_update(self._services._workers.keys()) - super()._purge_expired_workers() - - def _send_to_worker(self, worker_id, command, body=None): - if ( - worker_id not in self.shm_workers - and command == protocol.REQUEST - and body is not None - ): - service_name = body[2] - in_shm = self.shms[str(service_name, "ascii")] - tensor_input = in_shm.buf - body = body[0:2] + [tensor_input] - super()._send_to_worker(worker_id, command, body) - - def start(self): - self.broker_thread = threading.Thread(target=self.run) - self.broker_thread.start() - - def stop(self): - super().stop() - if self.broker_thread is not None: - self.broker_thread.join() - self.broker_thread = None diff --git a/frigate/detectors/detection_client.py b/frigate/detectors/detection_client.py index fa6ca2902..73d909e81 100644 --- a/frigate/detectors/detection_client.py +++ b/frigate/detectors/detection_client.py @@ -3,7 +3,7 @@ import multiprocessing as mp from multiprocessing.shared_memory import SharedMemory from majortomo import Client from frigate.util import EventsPerSecond -from .detector_config import ModelConfig, DetectionServerConfig +from .detector_config import ModelConfig class ObjectDetectionClient: @@ -12,7 +12,7 @@ class ObjectDetectionClient: camera_name: str, labels, model_config: ModelConfig, - server_config: DetectionServerConfig, + server_address: str, timeout=None, ): self.camera_name = camera_name @@ -29,7 +29,7 @@ class ObjectDetectionClient: self.out_np_shm = np.ndarray((20, 6), dtype=np.float32, buffer=self.out_shm.buf) self.timeout = timeout - self.detection_client = Client(server_config.ipc) + self.detection_client = Client(server_address) self.detection_client.connect() def detect(self, tensor_input, threshold=0.4): diff --git a/frigate/detectors/detection_worker.py b/frigate/detectors/detection_worker.py index 4b767fc64..d806cca89 100644 --- a/frigate/detectors/detection_worker.py +++ b/frigate/detectors/detection_worker.py @@ -1,28 +1,23 @@ import datetime import logging import os -import signal import threading import numpy as np import multiprocessing as mp from multiprocessing.shared_memory import SharedMemory -from majortomo import Worker, WorkerRequestsIterator, error, protocol -from majortomo.util import TextOrBytes, text_to_ascii_bytes from typing import List +from frigate.majordomo import QueueWorker from frigate.util import listen, EventsPerSecond, load_labels from .detector_config import InputTensorEnum, BaseDetectorConfig from .detector_types import create_detector from setproctitle import setproctitle -DEFAULT_ZMQ_LINGER = 2500 -READY_SHM = b"\007" - logger = logging.getLogger(__name__) -class ObjectDetectionWorker(Worker): +class ObjectDetectionWorker(QueueWorker): def __init__( self, detector_name: str, @@ -30,23 +25,13 @@ class ObjectDetectionWorker(Worker): avg_inference_speed: mp.Value = mp.Value("d", 0.01), detection_start: mp.Value = mp.Value("d", 0.00), labels=None, - heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL, - heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT, - zmq_context=None, - zmq_linger=DEFAULT_ZMQ_LINGER, + stop_event: mp.Event = None, ): - self.broker_url = detector_config.address - self.service_names = [ - text_to_ascii_bytes(service_name) - for service_name in detector_config.cameras - ] super().__init__( - self.broker_url, - b"", - heartbeat_interval, - heartbeat_timeout, - zmq_context, - zmq_linger, + broker_url=detector_config.address, + service_names=detector_config.cameras, + handler_name="DETECT_NO_SHM" if not detector_config.shared_memory else None, + stop_event=stop_event, ) self.detector_name = detector_name self.detector_config = detector_config @@ -80,12 +65,7 @@ class ObjectDetectionWorker(Worker): self.detect_api = create_detector(self.detector_config) - def _send_ready(self): - command = READY_SHM if self.detector_config.shared_memory else protocol.READY - for service_name in self.service_names: - self._send(command, service_name) - - def handle_request(self, request): + def handle_request(self, client_id: bytes, request: List[bytes]): self.detection_start.value = datetime.datetime.now().timestamp() # expected request format: @@ -100,6 +80,7 @@ class ObjectDetectionWorker(Worker): self.detection_start.value = 0.0 return frames frames.append(detections.tobytes()) + elif len(request) == 3: camera_name = request[0].decode("ascii") shm_shape = ( @@ -186,17 +167,7 @@ def run_detector( detection_start, labels, ) - - def receiveSignal(signalNumber, frame): - worker.close() - - signal.signal(signal.SIGTERM, receiveSignal) - signal.signal(signal.SIGINT, receiveSignal) - - worker_iter = WorkerRequestsIterator(worker) - for request in worker_iter: - reply = worker.handle_request(request) - worker_iter.send_reply_final(reply) + worker.start() class ObjectDetectProcess: diff --git a/frigate/detectors/detector_config.py b/frigate/detectors/detector_config.py index 1630fc057..f22a10a45 100644 --- a/frigate/detectors/detector_config.py +++ b/frigate/detectors/detector_config.py @@ -79,18 +79,3 @@ class BaseDetectorConfig(BaseModel): class Config: extra = Extra.allow arbitrary_types_allowed = True - - -class DetectionServerModeEnum(str, Enum): - Full = "full" - DetectionOnly = "detection_only" - - -class DetectionServerConfig(BaseModel): - mode: DetectionServerModeEnum = Field( - default=DetectionServerModeEnum.Full, title="Server mode" - ) - ipc: str = Field(default="ipc://detection_broker.ipc", title="Broker IPC path") - addresses: List[str] = Field( - default=["tcp://127.0.0.1:5555"], title="Broker TCP addresses" - ) diff --git a/frigate/majordomo.py b/frigate/majordomo.py new file mode 100644 index 000000000..b842f28ac --- /dev/null +++ b/frigate/majordomo.py @@ -0,0 +1,323 @@ +"""Majordomo / `MDP/0.2 `_ broker implementation. + +Extends the implementation of https://github.com/shoppimon/majortomo +""" + +import signal +import threading +import time +import zmq +import multiprocessing as mp +from multiprocessing.shared_memory import SharedMemory +from typing import ( + DefaultDict, + Dict, + Generator, + List, + Optional, + Tuple, + Union, +) # noqa: F401 +from majortomo import error, protocol +from majortomo.config import DEFAULT_BIND_URL +from majortomo.broker import ( + Broker, + Worker as BrokerWorker, + ServicesContainer, + id_to_int, +) +from majortomo.util import TextOrBytes, text_to_ascii_bytes +from majortomo.worker import DEFAULT_ZMQ_LINGER, Worker as MdpWorker + + +class QueueServicesContainer(ServicesContainer): + def __init__(self, busy_workers_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT): + super().__init__(busy_workers_timeout) + + def dequeue_pending(self): + # type: () -> Generator[Tuple[List[bytes], BrokerWorker, bytes], None, None] + """Pop ready message-worker pairs from all service queues so they can be dispatched""" + for service_name, service in self._services.items(): + for message, worker, client in service.dequeue_pending(): + yield message, worker, client, service_name + + def add_worker( + self, worker_id: bytes, service: bytes, expire_at: float, next_heartbeat: float + ): + """Add a worker to the list of available workers""" + if worker_id in self._workers: + worker = self._workers[worker_id] + if service in worker.service: + raise error.StateError( + f"Worker '{id_to_int(worker_id)}' has already sent READY message for service '{service.decode('ascii')}'" + ) + else: + worker.service.add(service) + else: + worker = BrokerWorker(worker_id, set([service]), expire_at, next_heartbeat) + self._workers[worker.id] = worker + self._services[service].add_worker(worker) + return worker + + def set_worker_available(self, worker_id, expire_at, next_heartbeat): + # type: (bytes, float, float) -> BrokerWorker + """Mark a worker that was busy processing a request as back and available again""" + if worker_id not in self._busy_workers: + raise error.StateError( + "Worker id {} is not previously known or has expired".format( + id_to_int(worker_id) + ) + ) + worker = self._busy_workers.pop(worker_id) + worker.is_busy = False + worker.expire_at = expire_at + worker.next_heartbeat = next_heartbeat + self._workers[worker_id] = worker + for service in worker.service: + self._services[service].add_worker(worker) + return worker + + def remove_worker(self, worker_id): + # type: (bytes) -> None + """Remove a worker from the list of known workers""" + try: + worker = self._workers[worker_id] + for service_name in worker.service: + service = self._services[service_name] + if worker_id in service._workers: + service.remove_worker(worker_id) + del self._workers[worker_id] + except KeyError: + try: + del self._busy_workers[worker_id] + except KeyError: + raise error.StateError(f"Worker id {id_to_int(worker_id)} is not known") + + +class QueueBroker(Broker): + def __init__( + self, + bind: Union[str, List[str]] = [DEFAULT_BIND_URL], + heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL, + heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT, + busy_worker_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT, + zmq_context=None, + ): + super().__init__( + heartbeat_interval=heartbeat_interval, + heartbeat_timeout=heartbeat_timeout, + busy_worker_timeout=busy_worker_timeout, + zmq_context=zmq_context, + ) + self._services = QueueServicesContainer(busy_worker_timeout) + self._bind_urls = [bind] if not isinstance(bind, list) else bind + self.broker_thread: threading.Thread = None + self.request_handlers: Dict[str, object] = {} + + def bind(self): + """Bind the ZMQ socket""" + if self._socket: + raise error.StateError("Socket is already bound") + + self._socket = self._context.socket(zmq.ROUTER) + self._socket.rcvtimeo = int(self._heartbeat_interval * 1000) + for bind_url in self._bind_urls: + self._socket.bind(bind_url) + self._log.info("Broker listening on %s", bind_url) + + def close(self): + if self._socket is None: + return + for bind_url in self._bind_urls: + self._socket.disconnect(bind_url) + self._socket.close() + self._socket = None + self._log.info("Broker socket closing") + + def start(self): + self.broker_thread = threading.Thread(target=self.run) + self.broker_thread.start() + + def stop(self): + super().stop() + if self.broker_thread is not None: + self.broker_thread.join() + self.broker_thread = None + + def _handle_worker_message(self, message): + if message.command == protocol.READY: + self._handle_worker_ready(message) + else: + super()._handle_worker_message(message) + + def _handle_worker_ready(self, message): + # type: (protocol.Message) -> BrokerWorker + worker_id = message.client + service = message.message[0] + self._log.info("Got READY from worker: %d", id_to_int(worker_id)) + now = time.time() + expire_at = now + self._heartbeat_timeout + next_heartbeat = now + self._heartbeat_interval + worker = self._services.add_worker( + worker_id, service, expire_at, next_heartbeat + ) + worker.request_handler = None + worker.request_params = [] + if len(message.message) > 1: + worker.request_handler = str(message.message[1], "ascii") + if len(message.message) > 2: + worker.request_params = [str(m, "ascii") for m in message.message[2:]] + return worker + + def _dispatch_queued_messages(self): + """Dispatch all queued messages to available workers""" + expire_at = time.time() + self._busy_worker_timeout + for message, worker, client, service_name in self._services.dequeue_pending(): + body = [client, b""] + message + body = self.on_worker_request(worker, service_name, body) + self._send_to_worker(worker.id, protocol.REQUEST, body) + self._services.set_worker_busy(worker.id, expire_at=expire_at) + + def register_request_handler(self, handler_name: str, handler): + self.request_handlers[handler_name] = handler + + def on_worker_request( + self, worker: BrokerWorker, service_name: bytes, body: List[bytes] + ): + if worker.request_handler in self.request_handlers: + handler = self.request_handlers[worker.request_handler] + body = handler(worker, service_name, body) + return body + + +class MultiBindableBroker: + def __init__( + self, + bind: Union[str, List[str]] = [DEFAULT_BIND_URL], + shms: dict[str, SharedMemory] = {}, + ): + super().__init__(bind) + self.shms = shms + + def on_worker_request( + self, worker: BrokerWorker, service_name: bytes, body: List[bytes] + ) -> List[bytes]: + if "DETECT_NO_SHM" == worker.request_handler: + in_shm = self.shms[str(service_name, "ascii")] + tensor_input = in_shm.buf + body = body[0:2] + [tensor_input] + return body + + +class QueueWorker(MdpWorker): + def __init__( + self, + broker_url: str, + service_names: List[TextOrBytes], + heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL, + heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT, + zmq_context=None, + zmq_linger=DEFAULT_ZMQ_LINGER, + handler_name: TextOrBytes = None, + handler_params: List[TextOrBytes] = [], + stop_event: mp.Event = None, + ): + super().__init__( + broker_url, + b"", + heartbeat_interval, + heartbeat_timeout, + zmq_context, + zmq_linger, + ) + + self.service_names = [ + text_to_ascii_bytes(service_name) for service_name in service_names + ] + self.ready_params = [ + text_to_ascii_bytes(rp) + for rp in ( + ([handler_name] if handler_name is not None else []) + handler_params + ) + ] + self.stop_event = stop_event or mp.Event() + + def _send_ready(self): + for service_name in self.service_names: + self._send(protocol.READY, service_name, *self.ready_params) + + def handle_request(client_id: bytes, request: List[bytes]) -> List[bytes]: + return request + + def start(self): + if not self.is_connected: + self.connect() + + def signal_handler(sig_num, _): + self.stop() + + for sig_num in (signal.SIGINT, signal.SIGTERM): + signal.signal(sig_num, signal_handler) + + while not self.stop_event.is_set(): + try: + client_id, request = self.wait_for_request() + reply = self.handle_request(client_id, request) + self.send_reply_final(reply) + except error.ProtocolError as e: + self._log.warning("Protocol error: %s, dropping request", str(e)) + continue + except error.Disconnected: + self._log.info("Worker disconnected") + break + + self.close() + + def stop(self): + self.stop_event.set() + + def _receive(self): + # type: () -> Tuple[bytes, List[bytes]] + """Poll on the socket until a command is received + + Will handle timeouts and heartbeats internally without returning + """ + while not self.stop_event.is_set(): + if self._socket is None: + raise error.Disconnected("Worker is disconnected") + + self._check_send_heartbeat() + poll_timeout = self._get_poll_timeout() + + try: + socks = dict(self._poller.poll(timeout=poll_timeout)) + except zmq.error.ZMQError: + # Probably connection was explicitly closed + if self._socket is None: + continue + raise + + if socks.get(self._socket) == zmq.POLLIN: + message = self._socket.recv_multipart() + self._log.debug("Got message of %d frames", len(message)) + else: + self._log.debug("Receive timed out after %d ms", poll_timeout) + if (time.time() - self._last_broker_hb) > self._heartbeat_timeout: + # We're not connected anymore? + self._log.info( + "Got no heartbeat in %d sec, disconnecting and reconnecting socket", + self._heartbeat_timeout, + ) + self.connect(reconnect=True) + continue + + command, frames = self._verify_message(message) + self._last_broker_hb = time.time() + + if command == protocol.HEARTBEAT: + self._log.debug("Got heartbeat message from broker") + continue + + return command, frames + + return protocol.DISCONNECT, [] diff --git a/frigate/test/test_object_detector.py b/frigate/test/test_object_detector.py index 4a4be93c2..274fad2df 100644 --- a/frigate/test/test_object_detector.py +++ b/frigate/test/test_object_detector.py @@ -10,10 +10,10 @@ from pydantic import parse_obj_as from frigate.config import FrigateConfig, DetectorConfig, InputTensorEnum, ModelConfig from frigate.detectors import ( DetectorTypeEnum, - ObjectDetectionBroker, ObjectDetectionClient, ObjectDetectionWorker, ) +from frigate.majordomo import QueueBroker from frigate.util import deep_merge import frigate.detectors.detector_types as detectors @@ -34,6 +34,60 @@ def create_detector(det_type): return api +def start_broker(ipc_address, tcp_address, camera_names): + detection_shms: dict[str, SharedMemory] = {} + + def detect_no_shm(worker, service_name, body): + in_shm = detection_shms[str(service_name, "ascii")] + tensor_input = in_shm.buf + body = body[0:2] + [tensor_input] + return body + + queue_broker = QueueBroker(bind=[ipc_address, tcp_address]) + queue_broker.register_request_handler("DETECT_NO_SHM", detect_no_shm) + queue_broker.start() + + for camera_name in camera_names: + shm_name = camera_name + out_shm_name = f"out-{camera_name}" + try: + shm = SharedMemory(name=shm_name, size=512 * 512 * 3, create=True) + except FileExistsError: + shm = SharedMemory(name=shm_name, create=False) + detection_shms[shm_name] = shm + try: + out_shm = SharedMemory(name=out_shm_name, size=20 * 6 * 4, create=True) + except FileExistsError: + out_shm = SharedMemory(name=out_shm_name, create=False) + detection_shms[out_shm_name] = out_shm + + return queue_broker, detection_shms + + +class WorkerTestThread(threading.Thread): + def __init__(self, detector_name, detector_config, stop_event): + super().__init__() + self.detector_name = detector_name + self.detector_config = detector_config + self.stop_event = stop_event + + def run(self): + worker = ObjectDetectionWorker( + self.detector_name, + self.detector_config, + mp.Value("d", 0.01), + mp.Value("d", 0.0), + None, + self.stop_event, + ) + worker.connect() + if not self.stop_event.is_set(): + client_id, request = worker.wait_for_request() + reply = worker.handle_request(client_id, request) + worker.send_reply_final(client_id, reply) + worker.close() + + class TestLocalObjectDetector(unittest.TestCase): @patch.dict( "frigate.detectors.detector_types.api_types", @@ -41,7 +95,7 @@ class TestLocalObjectDetector(unittest.TestCase): ) def test_socket_client_broker_worker(self): detector_name = "cpu" - ipc_address = "ipc://detection_broker.ipc" + ipc_address = "ipc://queue_broker.ipc" tcp_address = "tcp://127.0.0.1:5555" detector = {"type": "cpu"} @@ -56,60 +110,11 @@ class TestLocalObjectDetector(unittest.TestCase): "tcp_no_shm": {"address": tcp_address, "cameras": ["tcp_no_shm"]}, } - class ClientTestThread(threading.Thread): - def __init__( - self, - camera_name, - labelmap, - model_config, - server_config, - tensor_input, - timeout, - ): - super().__init__() - self.camera_name = camera_name - self.labelmap = labelmap - self.model_config = model_config - self.server_config = server_config - self.tensor_input = tensor_input - self.timeout = timeout - - def run(self): - object_detector = ObjectDetectionClient( - self.camera_name, - self.labelmap, - self.model_config, - self.server_config, - timeout=self.timeout, - ) - try: - object_detector.detect(self.tensor_input) - finally: - object_detector.cleanup() - try: - detection_shms: dict[str, SharedMemory] = {} - for camera_name in test_cases.keys(): - shm_name = camera_name - out_shm_name = f"out-{camera_name}" - try: - shm = SharedMemory(name=shm_name, size=512 * 512 * 3, create=True) - except FileExistsError: - shm = SharedMemory(name=shm_name) - detection_shms[shm_name] = shm - try: - out_shm = SharedMemory( - name=out_shm_name, size=20 * 6 * 4, create=True - ) - except FileExistsError: - out_shm = SharedMemory(name=out_shm_name) - detection_shms[out_shm_name] = out_shm - - self.detection_broker = ObjectDetectionBroker( - bind=[ipc_address, tcp_address], - shms=detection_shms, + queue_broker, detection_shms = None, None + queue_broker, detection_shms = start_broker( + ipc_address, tcp_address, test_cases.keys() ) - self.detection_broker.start() for test_case in test_cases.keys(): with self.subTest(test_case=test_case): @@ -136,6 +141,7 @@ class TestLocalObjectDetector(unittest.TestCase): config = test_cfg.runtime_config detector_config = config.detectors[detector_name] model_config = detector_config.model + stop_event = mp.Event() tensor_input = np.ndarray( (1, config.model.height, config.model.width, 3), @@ -146,33 +152,26 @@ class TestLocalObjectDetector(unittest.TestCase): out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf) try: - worker = ObjectDetectionWorker( - detector_name, - detector_config, - mp.Value("d", 0.01), - mp.Value("d", 0.0), - None, + client = None + worker = WorkerTestThread( + detector_name, detector_config, stop_event ) - worker.connect() + worker.start() - client = ClientTestThread( + client = ObjectDetectionClient( camera_name, test_cfg.model.merged_labelmap, model_config, - config.server, - tensor_input, + config.server.ipc, timeout=10, ) - client.start() - - client_id, request = worker.wait_for_request() - reply = worker.handle_request(request) - worker.send_reply_final(client_id, reply) - except Exception as ex: - print(ex) + client.detect(tensor_input) finally: - client.join() - worker.close() + stop_event.set() + if client is not None: + client.cleanup() + if worker is not None: + worker.join() self.assertIsNone( np.testing.assert_array_almost_equal( @@ -180,10 +179,11 @@ class TestLocalObjectDetector(unittest.TestCase): ) ) finally: - self.detection_broker.stop() - for shm in detection_shms.values(): - shm.close() - shm.unlink() + if queue_broker is not None: + queue_broker.stop() + for shm in detection_shms.values(): + shm.close() + shm.unlink() def test_localdetectorprocess_should_only_create_specified_detector_type(self): for det_type in detectors.api_types: diff --git a/frigate/video.py b/frigate/video.py index 2b5c919d2..67845b5b5 100755 --- a/frigate/video.py +++ b/frigate/video.py @@ -14,12 +14,7 @@ import numpy as np import cv2 from setproctitle import setproctitle -from frigate.config import ( - CameraConfig, - DetectConfig, - PixelFormatEnum, - DetectionServerConfig, -) +from frigate.config import CameraConfig, DetectConfig, PixelFormatEnum, ServerConfig from frigate.const import CACHE_DIR from frigate.detectors import ObjectDetectionClient from frigate.log import LogPipe @@ -413,7 +408,7 @@ def track_camera( camera_name, config: CameraConfig, model_config, - server_config: DetectionServerConfig, + server_config: ServerConfig, labelmap, detected_objects_queue, process_info, @@ -449,7 +444,7 @@ def track_camera( motion_contour_area, ) object_detector = ObjectDetectionClient( - camera_name, labelmap, model_config, server_config + camera_name, labelmap, model_config, server_config.ipc ) object_tracker = ObjectTracker(config.detect)