added generic majordomo broker

Dennis George 2022-12-27 15:57:23 -06:00
parent 9d21d71282
commit 03140c4687
10 changed files with 452 additions and 257 deletions

View File

@@ -17,16 +17,13 @@ from playhouse.sqliteq import SqliteQueueDatabase
 from frigate.comms.dispatcher import Communicator, Dispatcher
 from frigate.comms.mqtt import MqttClient
 from frigate.comms.ws import WebSocketClient
-from frigate.config import FrigateConfig
+from frigate.config import FrigateConfig, ServerModeEnum
 from frigate.const import CACHE_DIR, CLIPS_DIR, RECORD_DIR
-from frigate.detectors import (
-    ObjectDetectProcess,
-    ObjectDetectionBroker,
-    DetectionServerModeEnum,
-)
+from frigate.detectors import ObjectDetectProcess
 from frigate.events import EventCleanup, EventProcessor
 from frigate.http import create_app
 from frigate.log import log_process, root_configurer
+from frigate.majordomo import QueueBroker
 from frigate.models import Event, Recordings
 from frigate.object_processing import TrackedObjectProcessor
 from frigate.output import output_frames
@@ -84,7 +81,7 @@ class FrigateApp:
         user_config = FrigateConfig.parse_file(config_file)
         self.config = user_config.runtime_config

-        if self.config.server.mode == DetectionServerModeEnum.DetectionOnly:
+        if self.config.server.mode == ServerModeEnum.DetectionOnly:
             return

         for camera_name in self.config.cameras.keys():
@@ -188,12 +185,17 @@ class FrigateApp:
             comms.append(self.ws_client)
         self.dispatcher = Dispatcher(self.config, self.camera_metrics, comms)

-    def start_detection_broker(self) -> None:
+    def start_queue_broker(self) -> None:
+        def detect_no_shm(worker, service_name, body):
+            in_shm = self.detection_shms[str(service_name, "ascii")]
+            tensor_input = in_shm.buf
+            body = body[0:2] + [tensor_input]
+            return body
+
         bind_urls = [self.config.broker.ipc] + self.config.broker.addresses
-        self.detection_broker = ObjectDetectionBroker(
-            bind=bind_urls, shms=self.detection_shms
-        )
-        self.detection_broker.start()
+        self.queue_broker = QueueBroker(bind=bind_urls)
+        self.queue_broker.register_request_handler("DETECT_NO_SHM", detect_no_shm)
+        self.queue_broker.start()

     def start_detectors(self) -> None:
         for name in self.config.cameras.keys():
@@ -372,7 +374,7 @@ class FrigateApp:
         self.set_environment_vars()
         self.ensure_dirs()

-        if self.config.server.mode == DetectionServerModeEnum.DetectionOnly:
+        if self.config.server.mode == ServerModeEnum.DetectionOnly:
            self.start_detectors()
            self.start_watchdog()
            self.stop_event.wait()
@@ -389,7 +391,7 @@ class FrigateApp:
        self.init_restream()
        self.start_detectors()
-       self.start_detection_broker()
+       self.start_queue_broker()
        self.start_video_output_processor()
        self.start_detected_frames_processor()
        self.start_camera_processors()
@@ -416,7 +418,7 @@ class FrigateApp:
        logger.info(f"Stopping...")
        self.stop_event.set()

-       if self.config.server.mode != DetectionServerModeEnum.DetectionOnly:
+       if self.config.server.mode != ServerModeEnum.DetectionOnly:
            self.ws_client.stop()
            self.detected_frames_processor.join()
            self.event_processor.join()
@@ -426,7 +428,7 @@ class FrigateApp:
            self.stats_emitter.join()
            self.frigate_watchdog.join()
            self.db.stop()

-       self.detection_broker.stop()
+       self.queue_broker.stop()

        for detector in self.detectors.values():
            detector.stop()

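For orientation, a minimal sketch of what the DETECT_NO_SHM handler registered in start_queue_broker above is doing, assuming the QueueBroker API introduced in frigate/majordomo.py further down; the shms mapping stands in for FrigateApp.detection_shms and the bind URLs are illustrative:

from multiprocessing.shared_memory import SharedMemory
from typing import Dict, List

from frigate.majordomo import QueueBroker

shms: Dict[str, SharedMemory] = {}  # per-camera input segments, keyed by service name

def detect_no_shm(worker, service_name: bytes, body: List[bytes]) -> List[bytes]:
    # body is [client_id, b"", *request_frames]. Workers that announced the
    # DETECT_NO_SHM handler in their READY message (e.g. remote TCP workers that
    # cannot attach to the broker host's shared memory) get the raw tensor bytes
    # inlined in place of the original request frames.
    in_shm = shms[str(service_name, "ascii")]
    return body[0:2] + [in_shm.buf]

broker = QueueBroker(bind=["ipc://queue_broker.ipc", "tcp://127.0.0.1:5555"])
broker.register_request_handler("DETECT_NO_SHM", detect_no_shm)
broker.start()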
View File

@@ -43,10 +43,8 @@ from frigate.ffmpeg_presets import (
 from frigate.detectors import (
     PixelFormatEnum,
     InputTensorEnum,
-    DetectionServerConfig,
     ModelConfig,
     BaseDetectorConfig,
-    DetectionServerModeEnum,
 )
 from frigate.version import VERSION
@@ -812,9 +810,22 @@ def verify_zone_objects_are_tracked(camera_config: CameraConfig) -> None:
        )


+class ServerModeEnum(str, Enum):
+    Full = "full"
+    DetectionOnly = "detection_only"
+
+
+class ServerConfig(BaseModel):
+    mode: ServerModeEnum = Field(default=ServerModeEnum.Full, title="Server mode")
+    ipc: str = Field(default="ipc://queue_broker.ipc", title="Broker IPC path")
+    addresses: List[str] = Field(
+        default=["tcp://127.0.0.1:5555"], title="Broker TCP addresses"
+    )
+
+
 class FrigateConfig(FrigateBaseModel):
-    server: DetectionServerConfig = Field(
-        default_factory=DetectionServerConfig, title="Server configuration"
+    server: ServerConfig = Field(
+        default_factory=ServerConfig, title="Server configuration"
     )
     mqtt: Optional[MqttConfig] = Field(title="MQTT Configuration.")
     database: DatabaseConfig = Field(
@@ -913,7 +924,7 @@ class FrigateConfig(FrigateBaseModel):
            config.detectors[key] = detector_config

-        if config.server.mode == DetectionServerModeEnum.DetectionOnly:
+        if config.server.mode == ServerModeEnum.DetectionOnly:
            return config

        # MQTT password substitution
@@ -1030,8 +1041,8 @@ class FrigateConfig(FrigateBaseModel):
        server_config = values.get("server", None)
        if (
            server_config is not None
-           and server_config.get("mode", DetectionServerModeEnum.Full)
-           == DetectionServerModeEnum.DetectionOnly
+           and server_config.get("mode", ServerModeEnum.Full)
+           == ServerModeEnum.DetectionOnly
        ):
            return values

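A small sketch of how the relocated server options behave; ServerConfig and ServerModeEnum are the classes added above, and the detection_only case mirrors what a remote detection-only node would set (in YAML: server: mode: detection_only):

from frigate.config import ServerConfig, ServerModeEnum

# Defaults: full mode, local IPC endpoint plus one TCP endpoint for remote workers.
server = ServerConfig()
assert server.mode == ServerModeEnum.Full
assert server.ipc == "ipc://queue_broker.ipc"
assert server.addresses == ["tcp://127.0.0.1:5555"]

# A detection-only node; pydantic coerces the string to the enum member.
remote = ServerConfig(mode="detection_only", addresses=["tcp://0.0.0.0:5555"])
assert remote.mode == ServerModeEnum.DetectionOnly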
View File

@@ -1,12 +1,9 @@
 from .detection_api import DetectionApi
-from .detection_broker import ObjectDetectionBroker
 from .detector_config import (
     PixelFormatEnum,
     InputTensorEnum,
     ModelConfig,
     BaseDetectorConfig,
-    DetectionServerConfig,
-    DetectionServerModeEnum,
 )
 from .detection_client import ObjectDetectionClient
 from .detector_types import DetectorTypeEnum, api_types, create_detector

View File

@@ -1,89 +0,0 @@
-import signal
-import threading
-
-import zmq
-
-from multiprocessing.shared_memory import SharedMemory
-from typing import Union, List
-
-from majortomo import error, protocol
-from majortomo.config import DEFAULT_BIND_URL
-from majortomo.broker import Broker
-
-READY_SHM = b"\007"
-
-
-class ObjectDetectionBroker(Broker):
-    def __init__(
-        self,
-        bind: Union[str, List[str]] = DEFAULT_BIND_URL,
-        shms: dict[str, SharedMemory] = {},
-        heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
-        heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
-        busy_worker_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT,
-        zmq_context=None,
-    ):
-        protocol.Message.ALLOWED_COMMANDS[protocol.WORKER_HEADER].add(READY_SHM)
-        super().__init__(
-            self,
-            heartbeat_interval=heartbeat_interval,
-            heartbeat_timeout=heartbeat_timeout,
-            busy_worker_timeout=busy_worker_timeout,
-            zmq_context=zmq_context,
-        )
-        self.shm_workers = set()
-        self.shms = shms
-        self._bind_urls = [bind] if not isinstance(bind, list) else bind
-        self.broker_thread: threading.Thread = None
-
-    def bind(self):
-        """Bind the ZMQ socket"""
-        if self._socket:
-            raise error.StateError("Socket is already bound")
-        self._socket = self._context.socket(zmq.ROUTER)
-        self._socket.rcvtimeo = int(self._heartbeat_interval * 1000)
-        for bind_url in self._bind_urls:
-            self._socket.bind(bind_url)
-            self._log.info("Broker listening on %s", bind_url)
-
-    def close(self):
-        if self._socket is None:
-            return
-        for bind_url in self._bind_urls:
-            self._socket.disconnect(bind_url)
-        self._socket.close()
-        self._socket = None
-        self._log.info("Broker socket closing")
-
-    def _handle_worker_message(self, message):
-        if message.command == READY_SHM:
-            self.shm_workers.add(message.client)
-            self._handle_worker_ready(message.client, message.message[0])
-        else:
-            super()._handle_worker_message(message)
-
-    def _purge_expired_workers(self):
-        self.shm_workers.intersection_update(self._services._workers.keys())
-        super()._purge_expired_workers()
-
-    def _send_to_worker(self, worker_id, command, body=None):
-        if (
-            worker_id not in self.shm_workers
-            and command == protocol.REQUEST
-            and body is not None
-        ):
-            service_name = body[2]
-            in_shm = self.shms[str(service_name, "ascii")]
-            tensor_input = in_shm.buf
-            body = body[0:2] + [tensor_input]
-        super()._send_to_worker(worker_id, command, body)
-
-    def start(self):
-        self.broker_thread = threading.Thread(target=self.run)
-        self.broker_thread.start()
-
-    def stop(self):
-        super().stop()
-        if self.broker_thread is not None:
-            self.broker_thread.join()
-            self.broker_thread = None

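For readers comparing the removed broker above with the generic one added below: the old _send_to_worker and the new DETECT_NO_SHM handler perform the same frame rewrite. A worked illustration, with plain bytes standing in for the shared-memory buffer and an illustrative camera name:

# Frames as the broker sees them just before dispatching to a worker:
body = [b"client-id", b"", b"front_door", b"extra", b"frames"]
tensor_input = b"\x00" * 16            # stands in for self.shms["front_door"].buf
body = body[0:2] + [tensor_input]      # keep the routing envelope, replace the payload
assert body == [b"client-id", b"", tensor_input]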
View File

@@ -3,7 +3,7 @@ import multiprocessing as mp
 from multiprocessing.shared_memory import SharedMemory
 from majortomo import Client
 from frigate.util import EventsPerSecond
-from .detector_config import ModelConfig, DetectionServerConfig
+from .detector_config import ModelConfig


 class ObjectDetectionClient:
@@ -12,7 +12,7 @@ class ObjectDetectionClient:
         camera_name: str,
         labels,
         model_config: ModelConfig,
-        server_config: DetectionServerConfig,
+        server_address: str,
         timeout=None,
     ):
         self.camera_name = camera_name
@@ -29,7 +29,7 @@ class ObjectDetectionClient:
        self.out_np_shm = np.ndarray((20, 6), dtype=np.float32, buffer=self.out_shm.buf)

        self.timeout = timeout
-       self.detection_client = Client(server_config.ipc)
+       self.detection_client = Client(server_address)
        self.detection_client.connect()

    def detect(self, tensor_input, threshold=0.4):

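A hedged usage sketch of the narrowed constructor: the fourth argument is now just a broker address string rather than the whole server config. It assumes a broker is already listening on that address and has created the per-camera shared-memory segments; the label map and model config are placeholders for values that normally come from the runtime config:

import numpy as np

from frigate.detectors import ObjectDetectionClient

labels = {0: "person"}        # placeholder label map
model_config = ...            # placeholder: config.detectors["cpu"].model in practice

client = ObjectDetectionClient(
    "front_door",              # camera / MDP service name
    labels,
    model_config,
    "ipc://queue_broker.ipc",  # was server_config: DetectionServerConfig; now a plain address
    timeout=10,
)
try:
    frame = np.zeros((1, 320, 320, 3), dtype=np.uint8)
    detections = client.detect(frame, threshold=0.5)
finally:
    client.cleanup()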
View File

@@ -1,28 +1,23 @@
 import datetime
 import logging
 import os
-import signal
 import threading

 import numpy as np

 import multiprocessing as mp
 from multiprocessing.shared_memory import SharedMemory
-from majortomo import Worker, WorkerRequestsIterator, error, protocol
-from majortomo.util import TextOrBytes, text_to_ascii_bytes
 from typing import List

+from frigate.majordomo import QueueWorker
 from frigate.util import listen, EventsPerSecond, load_labels
 from .detector_config import InputTensorEnum, BaseDetectorConfig
 from .detector_types import create_detector

 from setproctitle import setproctitle

-DEFAULT_ZMQ_LINGER = 2500
-READY_SHM = b"\007"
-
 logger = logging.getLogger(__name__)


-class ObjectDetectionWorker(Worker):
+class ObjectDetectionWorker(QueueWorker):
     def __init__(
         self,
         detector_name: str,
@@ -30,23 +25,13 @@ class ObjectDetectionWorker(Worker):
         avg_inference_speed: mp.Value = mp.Value("d", 0.01),
         detection_start: mp.Value = mp.Value("d", 0.00),
         labels=None,
-        heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
-        heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
-        zmq_context=None,
-        zmq_linger=DEFAULT_ZMQ_LINGER,
+        stop_event: mp.Event = None,
     ):
-        self.broker_url = detector_config.address
-        self.service_names = [
-            text_to_ascii_bytes(service_name)
-            for service_name in detector_config.cameras
-        ]
         super().__init__(
-            self.broker_url,
-            b"",
-            heartbeat_interval,
-            heartbeat_timeout,
-            zmq_context,
-            zmq_linger,
+            broker_url=detector_config.address,
+            service_names=detector_config.cameras,
+            handler_name="DETECT_NO_SHM" if not detector_config.shared_memory else None,
+            stop_event=stop_event,
         )
         self.detector_name = detector_name
         self.detector_config = detector_config
@@ -80,12 +65,7 @@
        self.detect_api = create_detector(self.detector_config)

-    def _send_ready(self):
-        command = READY_SHM if self.detector_config.shared_memory else protocol.READY
-        for service_name in self.service_names:
-            self._send(command, service_name)
-
-    def handle_request(self, request):
+    def handle_request(self, client_id: bytes, request: List[bytes]):
        self.detection_start.value = datetime.datetime.now().timestamp()

        # expected request format:
@@ -100,6 +80,7 @@
                self.detection_start.value = 0.0
                return frames
            frames.append(detections.tobytes())
+
        elif len(request) == 3:
            camera_name = request[0].decode("ascii")
            shm_shape = (
@@ -186,17 +167,7 @@
        detection_start,
        labels,
    )
-
-    def receiveSignal(signalNumber, frame):
-        worker.close()
-
-    signal.signal(signal.SIGTERM, receiveSignal)
-    signal.signal(signal.SIGINT, receiveSignal)
-
-    worker_iter = WorkerRequestsIterator(worker)
-    for request in worker_iter:
-        reply = worker.handle_request(request)
-        worker_iter.send_reply_final(reply)
+    worker.start()


 class ObjectDetectProcess:

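With the MDP plumbing moved into QueueWorker, the run_detector lifecycle collapses to construct-and-start. A hedged sketch of that lifecycle; the detector_config placeholder is a detector entry from the runtime config, carrying the address, cameras and shared_memory fields the worker reads:

import multiprocessing as mp

from frigate.detectors import ObjectDetectionWorker

detector_config = ...            # placeholder: config.detectors["cpu"] in practice
stop_event = mp.Event()

worker = ObjectDetectionWorker(
    "cpu",                       # detector_name
    detector_config,
    mp.Value("d", 0.01),         # avg_inference_speed
    mp.Value("d", 0.0),          # detection_start
    labels=None,
    stop_event=stop_event,
)

# start() connects, announces READY for each camera service (adding the
# DETECT_NO_SHM handler name when the detector has no shared-memory access),
# installs SIGINT/SIGTERM handlers, and serves requests until stop_event is set.
worker.start()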
View File

@@ -79,18 +79,3 @@ class BaseDetectorConfig(BaseModel):
     class Config:
         extra = Extra.allow
         arbitrary_types_allowed = True
-
-
-class DetectionServerModeEnum(str, Enum):
-    Full = "full"
-    DetectionOnly = "detection_only"
-
-
-class DetectionServerConfig(BaseModel):
-    mode: DetectionServerModeEnum = Field(
-        default=DetectionServerModeEnum.Full, title="Server mode"
-    )
-    ipc: str = Field(default="ipc://detection_broker.ipc", title="Broker IPC path")
-    addresses: List[str] = Field(
-        default=["tcp://127.0.0.1:5555"], title="Broker TCP addresses"
-    )

frigate/majordomo.py (new file, 323 additions)
View File

@@ -0,0 +1,323 @@
+"""Majordomo / `MDP/0.2 <https://rfc.zeromq.org/spec:18/MDP/>`_ broker implementation.
+
+Extends the implementation of https://github.com/shoppimon/majortomo
+"""
+import signal
+import threading
+import time
+
+import zmq
+
+import multiprocessing as mp
+from multiprocessing.shared_memory import SharedMemory
+from typing import (
+    DefaultDict,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)  # noqa: F401
+
+from majortomo import error, protocol
+from majortomo.config import DEFAULT_BIND_URL
+from majortomo.broker import (
+    Broker,
+    Worker as BrokerWorker,
+    ServicesContainer,
+    id_to_int,
+)
+from majortomo.util import TextOrBytes, text_to_ascii_bytes
+from majortomo.worker import DEFAULT_ZMQ_LINGER, Worker as MdpWorker
+
+
+class QueueServicesContainer(ServicesContainer):
+    def __init__(self, busy_workers_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT):
+        super().__init__(busy_workers_timeout)
+
+    def dequeue_pending(self):
+        # type: () -> Generator[Tuple[List[bytes], BrokerWorker, bytes], None, None]
+        """Pop ready message-worker pairs from all service queues so they can be dispatched"""
+        for service_name, service in self._services.items():
+            for message, worker, client in service.dequeue_pending():
+                yield message, worker, client, service_name
+
+    def add_worker(
+        self, worker_id: bytes, service: bytes, expire_at: float, next_heartbeat: float
+    ):
+        """Add a worker to the list of available workers"""
+        if worker_id in self._workers:
+            worker = self._workers[worker_id]
+            if service in worker.service:
+                raise error.StateError(
+                    f"Worker '{id_to_int(worker_id)}' has already sent READY message for service '{service.decode('ascii')}'"
+                )
+            else:
+                worker.service.add(service)
+        else:
+            worker = BrokerWorker(worker_id, set([service]), expire_at, next_heartbeat)
+            self._workers[worker.id] = worker
+        self._services[service].add_worker(worker)
+        return worker
+
+    def set_worker_available(self, worker_id, expire_at, next_heartbeat):
+        # type: (bytes, float, float) -> BrokerWorker
+        """Mark a worker that was busy processing a request as back and available again"""
+        if worker_id not in self._busy_workers:
+            raise error.StateError(
+                "Worker id {} is not previously known or has expired".format(
+                    id_to_int(worker_id)
+                )
+            )
+        worker = self._busy_workers.pop(worker_id)
+        worker.is_busy = False
+        worker.expire_at = expire_at
+        worker.next_heartbeat = next_heartbeat
+        self._workers[worker_id] = worker
+        for service in worker.service:
+            self._services[service].add_worker(worker)
+        return worker
+
+    def remove_worker(self, worker_id):
+        # type: (bytes) -> None
+        """Remove a worker from the list of known workers"""
+        try:
+            worker = self._workers[worker_id]
+            for service_name in worker.service:
+                service = self._services[service_name]
+                if worker_id in service._workers:
+                    service.remove_worker(worker_id)
+            del self._workers[worker_id]
+        except KeyError:
+            try:
+                del self._busy_workers[worker_id]
+            except KeyError:
+                raise error.StateError(f"Worker id {id_to_int(worker_id)} is not known")
+
+
+class QueueBroker(Broker):
+    def __init__(
+        self,
+        bind: Union[str, List[str]] = [DEFAULT_BIND_URL],
+        heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
+        heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
+        busy_worker_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT,
+        zmq_context=None,
+    ):
+        super().__init__(
+            heartbeat_interval=heartbeat_interval,
+            heartbeat_timeout=heartbeat_timeout,
+            busy_worker_timeout=busy_worker_timeout,
+            zmq_context=zmq_context,
+        )
+        self._services = QueueServicesContainer(busy_worker_timeout)
+        self._bind_urls = [bind] if not isinstance(bind, list) else bind
+        self.broker_thread: threading.Thread = None
+        self.request_handlers: Dict[str, object] = {}
+
+    def bind(self):
+        """Bind the ZMQ socket"""
+        if self._socket:
+            raise error.StateError("Socket is already bound")
+        self._socket = self._context.socket(zmq.ROUTER)
+        self._socket.rcvtimeo = int(self._heartbeat_interval * 1000)
+        for bind_url in self._bind_urls:
+            self._socket.bind(bind_url)
+            self._log.info("Broker listening on %s", bind_url)
+
+    def close(self):
+        if self._socket is None:
+            return
+        for bind_url in self._bind_urls:
+            self._socket.disconnect(bind_url)
+        self._socket.close()
+        self._socket = None
+        self._log.info("Broker socket closing")
+
+    def start(self):
+        self.broker_thread = threading.Thread(target=self.run)
+        self.broker_thread.start()
+
+    def stop(self):
+        super().stop()
+        if self.broker_thread is not None:
+            self.broker_thread.join()
+            self.broker_thread = None
+
+    def _handle_worker_message(self, message):
+        if message.command == protocol.READY:
+            self._handle_worker_ready(message)
+        else:
+            super()._handle_worker_message(message)
+
+    def _handle_worker_ready(self, message):
+        # type: (protocol.Message) -> BrokerWorker
+        worker_id = message.client
+        service = message.message[0]
+        self._log.info("Got READY from worker: %d", id_to_int(worker_id))
+        now = time.time()
+        expire_at = now + self._heartbeat_timeout
+        next_heartbeat = now + self._heartbeat_interval
+        worker = self._services.add_worker(
+            worker_id, service, expire_at, next_heartbeat
+        )
+
+        worker.request_handler = None
+        worker.request_params = []
+        if len(message.message) > 1:
+            worker.request_handler = str(message.message[1], "ascii")
+        if len(message.message) > 2:
+            worker.request_params = [str(m, "ascii") for m in message.message[2:]]
+
+        return worker
+
+    def _dispatch_queued_messages(self):
+        """Dispatch all queued messages to available workers"""
+        expire_at = time.time() + self._busy_worker_timeout
+        for message, worker, client, service_name in self._services.dequeue_pending():
+            body = [client, b""] + message
+            body = self.on_worker_request(worker, service_name, body)
+            self._send_to_worker(worker.id, protocol.REQUEST, body)
+            self._services.set_worker_busy(worker.id, expire_at=expire_at)
+
+    def register_request_handler(self, handler_name: str, handler):
+        self.request_handlers[handler_name] = handler
+
+    def on_worker_request(
+        self, worker: BrokerWorker, service_name: bytes, body: List[bytes]
+    ):
+        if worker.request_handler in self.request_handlers:
+            handler = self.request_handlers[worker.request_handler]
+            body = handler(worker, service_name, body)
+        return body
+
+
+class MultiBindableBroker(QueueBroker):
+    def __init__(
+        self,
+        bind: Union[str, List[str]] = [DEFAULT_BIND_URL],
+        shms: dict[str, SharedMemory] = {},
+    ):
+        super().__init__(bind)
+        self.shms = shms
+
+    def on_worker_request(
+        self, worker: BrokerWorker, service_name: bytes, body: List[bytes]
+    ) -> List[bytes]:
+        if "DETECT_NO_SHM" == worker.request_handler:
+            in_shm = self.shms[str(service_name, "ascii")]
+            tensor_input = in_shm.buf
+            body = body[0:2] + [tensor_input]
+        return body
+
+
+class QueueWorker(MdpWorker):
+    def __init__(
+        self,
+        broker_url: str,
+        service_names: List[TextOrBytes],
+        heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
+        heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
+        zmq_context=None,
+        zmq_linger=DEFAULT_ZMQ_LINGER,
+        handler_name: TextOrBytes = None,
+        handler_params: List[TextOrBytes] = [],
+        stop_event: mp.Event = None,
+    ):
+        super().__init__(
+            broker_url,
+            b"",
+            heartbeat_interval,
+            heartbeat_timeout,
+            zmq_context,
+            zmq_linger,
+        )
+        self.service_names = [
+            text_to_ascii_bytes(service_name) for service_name in service_names
+        ]
+        self.ready_params = [
+            text_to_ascii_bytes(rp)
+            for rp in (
+                ([handler_name] if handler_name is not None else []) + handler_params
+            )
+        ]
+        self.stop_event = stop_event or mp.Event()
+
+    def _send_ready(self):
+        for service_name in self.service_names:
+            self._send(protocol.READY, service_name, *self.ready_params)
+
+    def handle_request(self, client_id: bytes, request: List[bytes]) -> List[bytes]:
+        return request
+
+    def start(self):
+        if not self.is_connected:
+            self.connect()
+
+        def signal_handler(sig_num, _):
+            self.stop()
+
+        for sig_num in (signal.SIGINT, signal.SIGTERM):
+            signal.signal(sig_num, signal_handler)
+
+        while not self.stop_event.is_set():
+            try:
+                client_id, request = self.wait_for_request()
+                reply = self.handle_request(client_id, request)
+                self.send_reply_final(client_id, reply)
+            except error.ProtocolError as e:
+                self._log.warning("Protocol error: %s, dropping request", str(e))
+                continue
+            except error.Disconnected:
+                self._log.info("Worker disconnected")
+                break
+
+        self.close()
+
+    def stop(self):
+        self.stop_event.set()
+
+    def _receive(self):
+        # type: () -> Tuple[bytes, List[bytes]]
+        """Poll on the socket until a command is received
+
+        Will handle timeouts and heartbeats internally without returning
+        """
+        while not self.stop_event.is_set():
+            if self._socket is None:
+                raise error.Disconnected("Worker is disconnected")
+            self._check_send_heartbeat()
+            poll_timeout = self._get_poll_timeout()
+            try:
+                socks = dict(self._poller.poll(timeout=poll_timeout))
+            except zmq.error.ZMQError:
+                # Probably connection was explicitly closed
+                if self._socket is None:
+                    continue
+                raise
+
+            if socks.get(self._socket) == zmq.POLLIN:
+                message = self._socket.recv_multipart()
+                self._log.debug("Got message of %d frames", len(message))
+            else:
+                self._log.debug("Receive timed out after %d ms", poll_timeout)
+                if (time.time() - self._last_broker_hb) > self._heartbeat_timeout:
+                    # We're not connected anymore?
+                    self._log.info(
+                        "Got no heartbeat in %d sec, disconnecting and reconnecting socket",
+                        self._heartbeat_timeout,
+                    )
+                    self.connect(reconnect=True)
+                continue
+
+            command, frames = self._verify_message(message)
+            self._last_broker_hb = time.time()
+            if command == protocol.HEARTBEAT:
+                self._log.debug("Got heartbeat message from broker")
+                continue
+            return command, frames
+
+        return protocol.DISCONNECT, []

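To make the pieces above concrete, a hedged end-to-end sketch of one broker / worker / client round trip. The echo service and endpoint are illustrative, the worker runs in a child process because QueueWorker.start installs signal handlers (which only works in a process's main thread), and the majortomo Client calls (send / recv_all_as_list / close) are written from memory of its API, so check them against the installed version:

import multiprocessing as mp

from majortomo import Client

from frigate.majordomo import QueueBroker, QueueWorker


class EchoWorker(QueueWorker):
    def handle_request(self, client_id, request):
        # Called by QueueWorker.start() for every dequeued request.
        return [b"echo:"] + request


def run_echo_worker(stop_event):
    worker = EchoWorker("ipc://queue_broker_demo.ipc", ["demo"], stop_event=stop_event)
    worker.start()            # blocks until stop_event is set or the broker goes away


if __name__ == "__main__":
    broker = QueueBroker(bind=["ipc://queue_broker_demo.ipc"])
    broker.start()            # runs Broker.run() on a background thread

    stop_event = mp.Event()
    worker_proc = mp.Process(target=run_echo_worker, args=(stop_event,))
    worker_proc.start()

    client = Client("ipc://queue_broker_demo.ipc")
    client.connect()
    client.send(b"demo", b"hello")
    reply = client.recv_all_as_list(timeout=5)   # assumed majortomo Client API
    print(reply)                                 # expected: [b"echo:", b"hello"]

    client.close()
    stop_event.set()
    worker_proc.join(timeout=10)
    broker.stop()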
View File

@@ -10,10 +10,10 @@ from pydantic import parse_obj_as
 from frigate.config import FrigateConfig, DetectorConfig, InputTensorEnum, ModelConfig
 from frigate.detectors import (
     DetectorTypeEnum,
-    ObjectDetectionBroker,
     ObjectDetectionClient,
     ObjectDetectionWorker,
 )
+from frigate.majordomo import QueueBroker
 from frigate.util import deep_merge
 import frigate.detectors.detector_types as detectors
@@ -34,6 +34,60 @@ def create_detector(det_type):
     return api


+def start_broker(ipc_address, tcp_address, camera_names):
+    detection_shms: dict[str, SharedMemory] = {}
+
+    def detect_no_shm(worker, service_name, body):
+        in_shm = detection_shms[str(service_name, "ascii")]
+        tensor_input = in_shm.buf
+        body = body[0:2] + [tensor_input]
+        return body
+
+    queue_broker = QueueBroker(bind=[ipc_address, tcp_address])
+    queue_broker.register_request_handler("DETECT_NO_SHM", detect_no_shm)
+    queue_broker.start()
+
+    for camera_name in camera_names:
+        shm_name = camera_name
+        out_shm_name = f"out-{camera_name}"
+
+        try:
+            shm = SharedMemory(name=shm_name, size=512 * 512 * 3, create=True)
+        except FileExistsError:
+            shm = SharedMemory(name=shm_name, create=False)
+        detection_shms[shm_name] = shm
+
+        try:
+            out_shm = SharedMemory(name=out_shm_name, size=20 * 6 * 4, create=True)
+        except FileExistsError:
+            out_shm = SharedMemory(name=out_shm_name, create=False)
+        detection_shms[out_shm_name] = out_shm
+
+    return queue_broker, detection_shms
+
+
+class WorkerTestThread(threading.Thread):
+    def __init__(self, detector_name, detector_config, stop_event):
+        super().__init__()
+        self.detector_name = detector_name
+        self.detector_config = detector_config
+        self.stop_event = stop_event
+
+    def run(self):
+        worker = ObjectDetectionWorker(
+            self.detector_name,
+            self.detector_config,
+            mp.Value("d", 0.01),
+            mp.Value("d", 0.0),
+            None,
+            self.stop_event,
+        )
+        worker.connect()
+
+        if not self.stop_event.is_set():
+            client_id, request = worker.wait_for_request()
+            reply = worker.handle_request(client_id, request)
+            worker.send_reply_final(client_id, reply)
+
+        worker.close()
+
+
 class TestLocalObjectDetector(unittest.TestCase):
     @patch.dict(
         "frigate.detectors.detector_types.api_types",
@@ -41,7 +95,7 @@
     )
     def test_socket_client_broker_worker(self):
         detector_name = "cpu"
-        ipc_address = "ipc://detection_broker.ipc"
+        ipc_address = "ipc://queue_broker.ipc"
         tcp_address = "tcp://127.0.0.1:5555"

         detector = {"type": "cpu"}
@@ -56,60 +110,11 @@
             "tcp_no_shm": {"address": tcp_address, "cameras": ["tcp_no_shm"]},
         }

-        class ClientTestThread(threading.Thread):
-            def __init__(
-                self,
-                camera_name,
-                labelmap,
-                model_config,
-                server_config,
-                tensor_input,
-                timeout,
-            ):
-                super().__init__()
-                self.camera_name = camera_name
-                self.labelmap = labelmap
-                self.model_config = model_config
-                self.server_config = server_config
-                self.tensor_input = tensor_input
-                self.timeout = timeout
-
-            def run(self):
-                object_detector = ObjectDetectionClient(
-                    self.camera_name,
-                    self.labelmap,
-                    self.model_config,
-                    self.server_config,
-                    timeout=self.timeout,
-                )
-                try:
-                    object_detector.detect(self.tensor_input)
-                finally:
-                    object_detector.cleanup()
-
-        try:
-            detection_shms: dict[str, SharedMemory] = {}
-            for camera_name in test_cases.keys():
-                shm_name = camera_name
-                out_shm_name = f"out-{camera_name}"
-                try:
-                    shm = SharedMemory(name=shm_name, size=512 * 512 * 3, create=True)
-                except FileExistsError:
-                    shm = SharedMemory(name=shm_name)
-                detection_shms[shm_name] = shm
-                try:
-                    out_shm = SharedMemory(
-                        name=out_shm_name, size=20 * 6 * 4, create=True
-                    )
-                except FileExistsError:
-                    out_shm = SharedMemory(name=out_shm_name)
-                detection_shms[out_shm_name] = out_shm
-
-            self.detection_broker = ObjectDetectionBroker(
-                bind=[ipc_address, tcp_address],
-                shms=detection_shms,
-            )
-            self.detection_broker.start()
-
+        try:
+            queue_broker, detection_shms = None, None
+            queue_broker, detection_shms = start_broker(
+                ipc_address, tcp_address, test_cases.keys()
+            )
+
             for test_case in test_cases.keys():
                 with self.subTest(test_case=test_case):
@@ -136,6 +141,7 @@
                     config = test_cfg.runtime_config
                     detector_config = config.detectors[detector_name]
                     model_config = detector_config.model
+                    stop_event = mp.Event()

                     tensor_input = np.ndarray(
                         (1, config.model.height, config.model.width, 3),
@@ -146,33 +152,26 @@
                     out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)

                     try:
-                        worker = ObjectDetectionWorker(
-                            detector_name,
-                            detector_config,
-                            mp.Value("d", 0.01),
-                            mp.Value("d", 0.0),
-                            None,
+                        client = None
+                        worker = WorkerTestThread(
+                            detector_name, detector_config, stop_event
                         )
-                        worker.connect()
-                        client = ClientTestThread(
+                        worker.start()
+                        client = ObjectDetectionClient(
                             camera_name,
                             test_cfg.model.merged_labelmap,
                             model_config,
-                            config.server,
-                            tensor_input,
+                            config.server.ipc,
                             timeout=10,
                         )
-                        client.start()
-                        client_id, request = worker.wait_for_request()
-                        reply = worker.handle_request(request)
-                        worker.send_reply_final(client_id, reply)
-                    except Exception as ex:
-                        print(ex)
+                        client.detect(tensor_input)
                     finally:
-                        client.join()
-                        worker.close()
+                        stop_event.set()
+                        if client is not None:
+                            client.cleanup()
+                        if worker is not None:
+                            worker.join()

                     self.assertIsNone(
                         np.testing.assert_array_almost_equal(
@@ -180,7 +179,8 @@
                         )
                     )
         finally:
-            self.detection_broker.stop()
+            if queue_broker is not None:
+                queue_broker.stop()
             for shm in detection_shms.values():
                 shm.close()
                 shm.unlink()

View File

@@ -14,12 +14,7 @@ import numpy as np
 import cv2
 from setproctitle import setproctitle

-from frigate.config import (
-    CameraConfig,
-    DetectConfig,
-    PixelFormatEnum,
-    DetectionServerConfig,
-)
+from frigate.config import CameraConfig, DetectConfig, PixelFormatEnum, ServerConfig
 from frigate.const import CACHE_DIR
 from frigate.detectors import ObjectDetectionClient
 from frigate.log import LogPipe
@@ -413,7 +408,7 @@ def track_camera(
     camera_name,
     config: CameraConfig,
     model_config,
-    server_config: DetectionServerConfig,
+    server_config: ServerConfig,
     labelmap,
     detected_objects_queue,
     process_info,
@@ -449,7 +444,7 @@ def track_camera(
        motion_contour_area,
    )
    object_detector = ObjectDetectionClient(
-       camera_name, labelmap, model_config, server_config
+       camera_name, labelmap, model_config, server_config.ipc
    )
    object_tracker = ObjectTracker(config.detect)