Use separate zmq proxy for object detection

This commit is contained in:
Nicolas Mowen 2025-06-11 15:12:19 -06:00
parent 16ad0e5477
commit d3e34ca4fd
3 changed files with 83 additions and 10 deletions

View File

@ -23,6 +23,7 @@ from frigate.comms.dispatcher import Dispatcher
from frigate.comms.event_metadata_updater import EventMetadataPublisher
from frigate.comms.inter_process import InterProcessCommunicator
from frigate.comms.mqtt import MqttClient
from frigate.comms.object_detector_signaler import DetectorProxy
from frigate.comms.webpush import WebPushClient
from frigate.comms.ws import WebSocketClient
from frigate.comms.zmq_proxy import ZmqProxy
@ -330,6 +331,7 @@ class FrigateApp:
self.inter_config_updater = CameraConfigUpdatePublisher()
self.event_metadata_updater = EventMetadataPublisher()
self.inter_zmq_proxy = ZmqProxy()
self.detection_proxy = DetectorProxy()
def init_onvif(self) -> None:
self.onvif_controller = OnvifController(self.config, self.ptz_metrics)
@ -661,6 +663,7 @@ class FrigateApp:
self.inter_config_updater.stop()
self.event_metadata_updater.stop()
self.inter_zmq_proxy.stop()
self.detection_proxy.stop()
while len(self.detection_shms) > 0:
shm = self.detection_shms.pop()

View File

@ -1,21 +1,92 @@
"""Facilitates communication between processes for object detection signals."""
from .zmq_proxy import Publisher, Subscriber
import threading
import zmq
SOCKET_PUB = "ipc:///tmp/cache/detector_pub"
SOCKET_SUB = "ipc:///tmp/cache/detector_sub"
class ObjectDetectorPublisher(Publisher):
class ZmqProxyRunner(threading.Thread):
def __init__(self, context: zmq.Context[zmq.Socket]) -> None:
super().__init__(name="detector_proxy")
self.context = context
def run(self) -> None:
"""Run the proxy."""
incoming = self.context.socket(zmq.XSUB)
incoming.bind(SOCKET_PUB)
outgoing = self.context.socket(zmq.XPUB)
outgoing.bind(SOCKET_SUB)
# Blocking: This will unblock (via exception) when we destroy the context
# The incoming and outgoing sockets will be closed automatically
# when the context is destroyed as well.
try:
zmq.proxy(incoming, outgoing)
except zmq.ZMQError:
pass
class DetectorProxy:
"""Proxies object detection signals."""
def __init__(self) -> None:
self.context = zmq.Context()
self.runner = ZmqProxyRunner(self.context)
self.runner.start()
def stop(self) -> None:
# destroying the context will tell the proxy to stop
self.context.destroy()
self.runner.join()
class ObjectDetectorPublisher:
"""Publishes signal for object detection to different processes."""
topic_base = "object_detector/"
def __init__(self, topic: str = "") -> None:
self.topic = f"{self.topic_base}{topic}"
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUB)
self.socket.connect(SOCKET_PUB)
class ObjectDetectorSubscriber(Subscriber):
def publish(self, sub_topic: str = "") -> None:
"""Publish message."""
self.socket.send_string(f"{self.topic}{sub_topic}/")
def stop(self) -> None:
self.socket.close()
self.context.destroy()
class ObjectDetectorSubscriber:
"""Simplifies receiving a signal for object detection."""
topic_base = "object_detector/"
def __init__(self, topic: str) -> None:
super().__init__(topic)
def __init__(self, topic: str = "") -> None:
self.topic = f"{self.topic_base}{topic}/"
self.context = zmq.Context()
self.socket = self.context.socket(zmq.SUB)
self.socket.setsockopt_string(zmq.SUBSCRIBE, self.topic)
self.socket.connect(SOCKET_SUB)
def check_for_update(self):
return super().check_for_update(timeout=5)
def check_for_update(self, timeout: float = 5) -> str | None:
"""Returns message or None if no update."""
try:
has_update, _, _ = zmq.select([self.socket], [], [], timeout)
if has_update:
return self.socket.recv_string(flags=zmq.NOBLOCK)
except zmq.ZMQError:
pass
return None
def stop(self) -> None:
self.socket.close()
self.context.destroy()

View File

@ -149,8 +149,7 @@ def run_detector(
create_output_shm(connection_id)
outputs[connection_id]["np"][:] = detections[:]
signal_id = f"{connection_id}/update"
detector_publisher.publish(signal_id, signal_id)
detector_publisher.publish(connection_id)
start.value = 0.0
avg_speed.value = (avg_speed.value * 9 + duration) / 10
@ -231,7 +230,7 @@ class RemoteObjectDetector:
)
self.out_shm = UntrackedSharedMemory(name=f"out-{self.name}", create=False)
self.out_np_shm = np.ndarray((20, 6), dtype=np.float32, buffer=self.out_shm.buf)
self.detector_subscriber = ObjectDetectorSubscriber(f"{name}/update")
self.detector_subscriber = ObjectDetectorSubscriber(name)
def detect(self, tensor_input, threshold=0.4):
detections = []