mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-03 09:45:22 +03:00
added generic majordomo broker
This commit is contained in:
parent
9d21d71282
commit
03140c4687
@ -17,16 +17,13 @@ from playhouse.sqliteq import SqliteQueueDatabase
|
||||
from frigate.comms.dispatcher import Communicator, Dispatcher
|
||||
from frigate.comms.mqtt import MqttClient
|
||||
from frigate.comms.ws import WebSocketClient
|
||||
from frigate.config import FrigateConfig
|
||||
from frigate.config import FrigateConfig, ServerModeEnum
|
||||
from frigate.const import CACHE_DIR, CLIPS_DIR, RECORD_DIR
|
||||
from frigate.detectors import (
|
||||
ObjectDetectProcess,
|
||||
ObjectDetectionBroker,
|
||||
DetectionServerModeEnum,
|
||||
)
|
||||
from frigate.detectors import ObjectDetectProcess
|
||||
from frigate.events import EventCleanup, EventProcessor
|
||||
from frigate.http import create_app
|
||||
from frigate.log import log_process, root_configurer
|
||||
from frigate.majordomo import QueueBroker
|
||||
from frigate.models import Event, Recordings
|
||||
from frigate.object_processing import TrackedObjectProcessor
|
||||
from frigate.output import output_frames
|
||||
@ -84,7 +81,7 @@ class FrigateApp:
|
||||
user_config = FrigateConfig.parse_file(config_file)
|
||||
self.config = user_config.runtime_config
|
||||
|
||||
if self.config.server.mode == DetectionServerModeEnum.DetectionOnly:
|
||||
if self.config.server.mode == ServerModeEnum.DetectionOnly:
|
||||
return
|
||||
|
||||
for camera_name in self.config.cameras.keys():
|
||||
@ -188,12 +185,17 @@ class FrigateApp:
|
||||
comms.append(self.ws_client)
|
||||
self.dispatcher = Dispatcher(self.config, self.camera_metrics, comms)
|
||||
|
||||
def start_detection_broker(self) -> None:
|
||||
def start_queue_broker(self) -> None:
|
||||
def detect_no_shm(worker, service_name, body):
|
||||
in_shm = self.detection_shms[str(service_name, "ascii")]
|
||||
tensor_input = in_shm.buf
|
||||
body = body[0:2] + [tensor_input]
|
||||
return body
|
||||
|
||||
bind_urls = [self.config.broker.ipc] + self.config.broker.addresses
|
||||
self.detection_broker = ObjectDetectionBroker(
|
||||
bind=bind_urls, shms=self.detection_shms
|
||||
)
|
||||
self.detection_broker.start()
|
||||
self.queue_broker = QueueBroker(bind=bind_urls)
|
||||
self.queue_broker.register_request_handler("DETECT_NO_SHM", detect_no_shm)
|
||||
self.queue_broker.start()
|
||||
|
||||
def start_detectors(self) -> None:
|
||||
for name in self.config.cameras.keys():
|
||||
@ -372,7 +374,7 @@ class FrigateApp:
|
||||
self.set_environment_vars()
|
||||
self.ensure_dirs()
|
||||
|
||||
if self.config.server.mode == DetectionServerModeEnum.DetectionOnly:
|
||||
if self.config.server.mode == ServerModeEnum.DetectionOnly:
|
||||
self.start_detectors()
|
||||
self.start_watchdog()
|
||||
self.stop_event.wait()
|
||||
@ -389,7 +391,7 @@ class FrigateApp:
|
||||
|
||||
self.init_restream()
|
||||
self.start_detectors()
|
||||
self.start_detection_broker()
|
||||
self.start_queue_broker()
|
||||
self.start_video_output_processor()
|
||||
self.start_detected_frames_processor()
|
||||
self.start_camera_processors()
|
||||
@ -416,7 +418,7 @@ class FrigateApp:
|
||||
logger.info(f"Stopping...")
|
||||
self.stop_event.set()
|
||||
|
||||
if self.config.server.mode != DetectionServerModeEnum.DetectionOnly:
|
||||
if self.config.server.mode != ServerModeEnum.DetectionOnly:
|
||||
self.ws_client.stop()
|
||||
self.detected_frames_processor.join()
|
||||
self.event_processor.join()
|
||||
@ -426,7 +428,7 @@ class FrigateApp:
|
||||
self.stats_emitter.join()
|
||||
self.frigate_watchdog.join()
|
||||
self.db.stop()
|
||||
self.detection_broker.stop()
|
||||
self.queue_broker.stop()
|
||||
|
||||
for detector in self.detectors.values():
|
||||
detector.stop()
|
||||
|
||||
@ -43,10 +43,8 @@ from frigate.ffmpeg_presets import (
|
||||
from frigate.detectors import (
|
||||
PixelFormatEnum,
|
||||
InputTensorEnum,
|
||||
DetectionServerConfig,
|
||||
ModelConfig,
|
||||
BaseDetectorConfig,
|
||||
DetectionServerModeEnum,
|
||||
)
|
||||
from frigate.version import VERSION
|
||||
|
||||
@ -812,9 +810,22 @@ def verify_zone_objects_are_tracked(camera_config: CameraConfig) -> None:
|
||||
)
|
||||
|
||||
|
||||
class ServerModeEnum(str, Enum):
|
||||
Full = "full"
|
||||
DetectionOnly = "detection_only"
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
mode: ServerModeEnum = Field(default=ServerModeEnum.Full, title="Server mode")
|
||||
ipc: str = Field(default="ipc://queue_broker.ipc", title="Broker IPC path")
|
||||
addresses: List[str] = Field(
|
||||
default=["tcp://127.0.0.1:5555"], title="Broker TCP addresses"
|
||||
)
|
||||
|
||||
|
||||
class FrigateConfig(FrigateBaseModel):
|
||||
server: DetectionServerConfig = Field(
|
||||
default_factory=DetectionServerConfig, title="Server configuration"
|
||||
server: ServerConfig = Field(
|
||||
default_factory=ServerConfig, title="Server configuration"
|
||||
)
|
||||
mqtt: Optional[MqttConfig] = Field(title="MQTT Configuration.")
|
||||
database: DatabaseConfig = Field(
|
||||
@ -913,7 +924,7 @@ class FrigateConfig(FrigateBaseModel):
|
||||
|
||||
config.detectors[key] = detector_config
|
||||
|
||||
if config.server.mode == DetectionServerModeEnum.DetectionOnly:
|
||||
if config.server.mode == ServerModeEnum.DetectionOnly:
|
||||
return config
|
||||
|
||||
# MQTT password substitution
|
||||
@ -1030,8 +1041,8 @@ class FrigateConfig(FrigateBaseModel):
|
||||
server_config = values.get("server", None)
|
||||
if (
|
||||
server_config is not None
|
||||
and server_config.get("mode", DetectionServerModeEnum.Full)
|
||||
== DetectionServerModeEnum.DetectionOnly
|
||||
and server_config.get("mode", ServerModeEnum.Full)
|
||||
== ServerModeEnum.DetectionOnly
|
||||
):
|
||||
return values
|
||||
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
from .detection_api import DetectionApi
|
||||
from .detection_broker import ObjectDetectionBroker
|
||||
from .detector_config import (
|
||||
PixelFormatEnum,
|
||||
InputTensorEnum,
|
||||
ModelConfig,
|
||||
BaseDetectorConfig,
|
||||
DetectionServerConfig,
|
||||
DetectionServerModeEnum,
|
||||
)
|
||||
from .detection_client import ObjectDetectionClient
|
||||
from .detector_types import DetectorTypeEnum, api_types, create_detector
|
||||
|
||||
@ -1,89 +0,0 @@
|
||||
import signal
|
||||
import threading
|
||||
import zmq
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from typing import Union, List
|
||||
from majortomo import error, protocol
|
||||
from majortomo.config import DEFAULT_BIND_URL
|
||||
from majortomo.broker import Broker
|
||||
|
||||
|
||||
READY_SHM = b"\007"
|
||||
|
||||
|
||||
class ObjectDetectionBroker(Broker):
|
||||
def __init__(
|
||||
self,
|
||||
bind: Union[str, List[str]] = DEFAULT_BIND_URL,
|
||||
shms: dict[str, SharedMemory] = {},
|
||||
heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
|
||||
heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
|
||||
busy_worker_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT,
|
||||
zmq_context=None,
|
||||
):
|
||||
protocol.Message.ALLOWED_COMMANDS[protocol.WORKER_HEADER].add(READY_SHM)
|
||||
|
||||
super().__init__(
|
||||
self,
|
||||
heartbeat_interval=heartbeat_interval,
|
||||
heartbeat_timeout=heartbeat_timeout,
|
||||
busy_worker_timeout=busy_worker_timeout,
|
||||
zmq_context=zmq_context,
|
||||
)
|
||||
self.shm_workers = set()
|
||||
self.shms = shms
|
||||
self._bind_urls = [bind] if not isinstance(bind, list) else bind
|
||||
self.broker_thread: threading.Thread = None
|
||||
|
||||
def bind(self):
|
||||
"""Bind the ZMQ socket"""
|
||||
if self._socket:
|
||||
raise error.StateError("Socket is already bound")
|
||||
|
||||
self._socket = self._context.socket(zmq.ROUTER)
|
||||
self._socket.rcvtimeo = int(self._heartbeat_interval * 1000)
|
||||
for bind_url in self._bind_urls:
|
||||
self._socket.bind(bind_url)
|
||||
self._log.info("Broker listening on %s", bind_url)
|
||||
|
||||
def close(self):
|
||||
if self._socket is None:
|
||||
return
|
||||
for bind_url in self._bind_urls:
|
||||
self._socket.disconnect(bind_url)
|
||||
self._socket.close()
|
||||
self._socket = None
|
||||
self._log.info("Broker socket closing")
|
||||
|
||||
def _handle_worker_message(self, message):
|
||||
if message.command == READY_SHM:
|
||||
self.shm_workers.add(message.client)
|
||||
self._handle_worker_ready(message.client, message.message[0])
|
||||
else:
|
||||
super()._handle_worker_message(message)
|
||||
|
||||
def _purge_expired_workers(self):
|
||||
self.shm_workers.intersection_update(self._services._workers.keys())
|
||||
super()._purge_expired_workers()
|
||||
|
||||
def _send_to_worker(self, worker_id, command, body=None):
|
||||
if (
|
||||
worker_id not in self.shm_workers
|
||||
and command == protocol.REQUEST
|
||||
and body is not None
|
||||
):
|
||||
service_name = body[2]
|
||||
in_shm = self.shms[str(service_name, "ascii")]
|
||||
tensor_input = in_shm.buf
|
||||
body = body[0:2] + [tensor_input]
|
||||
super()._send_to_worker(worker_id, command, body)
|
||||
|
||||
def start(self):
|
||||
self.broker_thread = threading.Thread(target=self.run)
|
||||
self.broker_thread.start()
|
||||
|
||||
def stop(self):
|
||||
super().stop()
|
||||
if self.broker_thread is not None:
|
||||
self.broker_thread.join()
|
||||
self.broker_thread = None
|
||||
@ -3,7 +3,7 @@ import multiprocessing as mp
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from majortomo import Client
|
||||
from frigate.util import EventsPerSecond
|
||||
from .detector_config import ModelConfig, DetectionServerConfig
|
||||
from .detector_config import ModelConfig
|
||||
|
||||
|
||||
class ObjectDetectionClient:
|
||||
@ -12,7 +12,7 @@ class ObjectDetectionClient:
|
||||
camera_name: str,
|
||||
labels,
|
||||
model_config: ModelConfig,
|
||||
server_config: DetectionServerConfig,
|
||||
server_address: str,
|
||||
timeout=None,
|
||||
):
|
||||
self.camera_name = camera_name
|
||||
@ -29,7 +29,7 @@ class ObjectDetectionClient:
|
||||
self.out_np_shm = np.ndarray((20, 6), dtype=np.float32, buffer=self.out_shm.buf)
|
||||
|
||||
self.timeout = timeout
|
||||
self.detection_client = Client(server_config.ipc)
|
||||
self.detection_client = Client(server_address)
|
||||
self.detection_client.connect()
|
||||
|
||||
def detect(self, tensor_input, threshold=0.4):
|
||||
|
||||
@ -1,28 +1,23 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
import numpy as np
|
||||
import multiprocessing as mp
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from majortomo import Worker, WorkerRequestsIterator, error, protocol
|
||||
from majortomo.util import TextOrBytes, text_to_ascii_bytes
|
||||
from typing import List
|
||||
|
||||
from frigate.majordomo import QueueWorker
|
||||
from frigate.util import listen, EventsPerSecond, load_labels
|
||||
from .detector_config import InputTensorEnum, BaseDetectorConfig
|
||||
from .detector_types import create_detector
|
||||
|
||||
from setproctitle import setproctitle
|
||||
|
||||
DEFAULT_ZMQ_LINGER = 2500
|
||||
READY_SHM = b"\007"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ObjectDetectionWorker(Worker):
|
||||
class ObjectDetectionWorker(QueueWorker):
|
||||
def __init__(
|
||||
self,
|
||||
detector_name: str,
|
||||
@ -30,23 +25,13 @@ class ObjectDetectionWorker(Worker):
|
||||
avg_inference_speed: mp.Value = mp.Value("d", 0.01),
|
||||
detection_start: mp.Value = mp.Value("d", 0.00),
|
||||
labels=None,
|
||||
heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
|
||||
heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
|
||||
zmq_context=None,
|
||||
zmq_linger=DEFAULT_ZMQ_LINGER,
|
||||
stop_event: mp.Event = None,
|
||||
):
|
||||
self.broker_url = detector_config.address
|
||||
self.service_names = [
|
||||
text_to_ascii_bytes(service_name)
|
||||
for service_name in detector_config.cameras
|
||||
]
|
||||
super().__init__(
|
||||
self.broker_url,
|
||||
b"",
|
||||
heartbeat_interval,
|
||||
heartbeat_timeout,
|
||||
zmq_context,
|
||||
zmq_linger,
|
||||
broker_url=detector_config.address,
|
||||
service_names=detector_config.cameras,
|
||||
handler_name="DETECT_NO_SHM" if not detector_config.shared_memory else None,
|
||||
stop_event=stop_event,
|
||||
)
|
||||
self.detector_name = detector_name
|
||||
self.detector_config = detector_config
|
||||
@ -80,12 +65,7 @@ class ObjectDetectionWorker(Worker):
|
||||
|
||||
self.detect_api = create_detector(self.detector_config)
|
||||
|
||||
def _send_ready(self):
|
||||
command = READY_SHM if self.detector_config.shared_memory else protocol.READY
|
||||
for service_name in self.service_names:
|
||||
self._send(command, service_name)
|
||||
|
||||
def handle_request(self, request):
|
||||
def handle_request(self, client_id: bytes, request: List[bytes]):
|
||||
self.detection_start.value = datetime.datetime.now().timestamp()
|
||||
|
||||
# expected request format:
|
||||
@ -100,6 +80,7 @@ class ObjectDetectionWorker(Worker):
|
||||
self.detection_start.value = 0.0
|
||||
return frames
|
||||
frames.append(detections.tobytes())
|
||||
|
||||
elif len(request) == 3:
|
||||
camera_name = request[0].decode("ascii")
|
||||
shm_shape = (
|
||||
@ -186,17 +167,7 @@ def run_detector(
|
||||
detection_start,
|
||||
labels,
|
||||
)
|
||||
|
||||
def receiveSignal(signalNumber, frame):
|
||||
worker.close()
|
||||
|
||||
signal.signal(signal.SIGTERM, receiveSignal)
|
||||
signal.signal(signal.SIGINT, receiveSignal)
|
||||
|
||||
worker_iter = WorkerRequestsIterator(worker)
|
||||
for request in worker_iter:
|
||||
reply = worker.handle_request(request)
|
||||
worker_iter.send_reply_final(reply)
|
||||
worker.start()
|
||||
|
||||
|
||||
class ObjectDetectProcess:
|
||||
|
||||
@ -79,18 +79,3 @@ class BaseDetectorConfig(BaseModel):
|
||||
class Config:
|
||||
extra = Extra.allow
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class DetectionServerModeEnum(str, Enum):
|
||||
Full = "full"
|
||||
DetectionOnly = "detection_only"
|
||||
|
||||
|
||||
class DetectionServerConfig(BaseModel):
|
||||
mode: DetectionServerModeEnum = Field(
|
||||
default=DetectionServerModeEnum.Full, title="Server mode"
|
||||
)
|
||||
ipc: str = Field(default="ipc://detection_broker.ipc", title="Broker IPC path")
|
||||
addresses: List[str] = Field(
|
||||
default=["tcp://127.0.0.1:5555"], title="Broker TCP addresses"
|
||||
)
|
||||
|
||||
323
frigate/majordomo.py
Normal file
323
frigate/majordomo.py
Normal file
@ -0,0 +1,323 @@
|
||||
"""Majordomo / `MDP/0.2 <https://rfc.zeromq.org/spec:18/MDP/>`_ broker implementation.
|
||||
|
||||
Extends the implementation of https://github.com/shoppimon/majortomo
|
||||
"""
|
||||
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
import zmq
|
||||
import multiprocessing as mp
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from typing import (
|
||||
DefaultDict,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
) # noqa: F401
|
||||
from majortomo import error, protocol
|
||||
from majortomo.config import DEFAULT_BIND_URL
|
||||
from majortomo.broker import (
|
||||
Broker,
|
||||
Worker as BrokerWorker,
|
||||
ServicesContainer,
|
||||
id_to_int,
|
||||
)
|
||||
from majortomo.util import TextOrBytes, text_to_ascii_bytes
|
||||
from majortomo.worker import DEFAULT_ZMQ_LINGER, Worker as MdpWorker
|
||||
|
||||
|
||||
class QueueServicesContainer(ServicesContainer):
|
||||
def __init__(self, busy_workers_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT):
|
||||
super().__init__(busy_workers_timeout)
|
||||
|
||||
def dequeue_pending(self):
|
||||
# type: () -> Generator[Tuple[List[bytes], BrokerWorker, bytes], None, None]
|
||||
"""Pop ready message-worker pairs from all service queues so they can be dispatched"""
|
||||
for service_name, service in self._services.items():
|
||||
for message, worker, client in service.dequeue_pending():
|
||||
yield message, worker, client, service_name
|
||||
|
||||
def add_worker(
|
||||
self, worker_id: bytes, service: bytes, expire_at: float, next_heartbeat: float
|
||||
):
|
||||
"""Add a worker to the list of available workers"""
|
||||
if worker_id in self._workers:
|
||||
worker = self._workers[worker_id]
|
||||
if service in worker.service:
|
||||
raise error.StateError(
|
||||
f"Worker '{id_to_int(worker_id)}' has already sent READY message for service '{service.decode('ascii')}'"
|
||||
)
|
||||
else:
|
||||
worker.service.add(service)
|
||||
else:
|
||||
worker = BrokerWorker(worker_id, set([service]), expire_at, next_heartbeat)
|
||||
self._workers[worker.id] = worker
|
||||
self._services[service].add_worker(worker)
|
||||
return worker
|
||||
|
||||
def set_worker_available(self, worker_id, expire_at, next_heartbeat):
|
||||
# type: (bytes, float, float) -> BrokerWorker
|
||||
"""Mark a worker that was busy processing a request as back and available again"""
|
||||
if worker_id not in self._busy_workers:
|
||||
raise error.StateError(
|
||||
"Worker id {} is not previously known or has expired".format(
|
||||
id_to_int(worker_id)
|
||||
)
|
||||
)
|
||||
worker = self._busy_workers.pop(worker_id)
|
||||
worker.is_busy = False
|
||||
worker.expire_at = expire_at
|
||||
worker.next_heartbeat = next_heartbeat
|
||||
self._workers[worker_id] = worker
|
||||
for service in worker.service:
|
||||
self._services[service].add_worker(worker)
|
||||
return worker
|
||||
|
||||
def remove_worker(self, worker_id):
|
||||
# type: (bytes) -> None
|
||||
"""Remove a worker from the list of known workers"""
|
||||
try:
|
||||
worker = self._workers[worker_id]
|
||||
for service_name in worker.service:
|
||||
service = self._services[service_name]
|
||||
if worker_id in service._workers:
|
||||
service.remove_worker(worker_id)
|
||||
del self._workers[worker_id]
|
||||
except KeyError:
|
||||
try:
|
||||
del self._busy_workers[worker_id]
|
||||
except KeyError:
|
||||
raise error.StateError(f"Worker id {id_to_int(worker_id)} is not known")
|
||||
|
||||
|
||||
class QueueBroker(Broker):
|
||||
def __init__(
|
||||
self,
|
||||
bind: Union[str, List[str]] = [DEFAULT_BIND_URL],
|
||||
heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
|
||||
heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
|
||||
busy_worker_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT,
|
||||
zmq_context=None,
|
||||
):
|
||||
super().__init__(
|
||||
heartbeat_interval=heartbeat_interval,
|
||||
heartbeat_timeout=heartbeat_timeout,
|
||||
busy_worker_timeout=busy_worker_timeout,
|
||||
zmq_context=zmq_context,
|
||||
)
|
||||
self._services = QueueServicesContainer(busy_worker_timeout)
|
||||
self._bind_urls = [bind] if not isinstance(bind, list) else bind
|
||||
self.broker_thread: threading.Thread = None
|
||||
self.request_handlers: Dict[str, object] = {}
|
||||
|
||||
def bind(self):
|
||||
"""Bind the ZMQ socket"""
|
||||
if self._socket:
|
||||
raise error.StateError("Socket is already bound")
|
||||
|
||||
self._socket = self._context.socket(zmq.ROUTER)
|
||||
self._socket.rcvtimeo = int(self._heartbeat_interval * 1000)
|
||||
for bind_url in self._bind_urls:
|
||||
self._socket.bind(bind_url)
|
||||
self._log.info("Broker listening on %s", bind_url)
|
||||
|
||||
def close(self):
|
||||
if self._socket is None:
|
||||
return
|
||||
for bind_url in self._bind_urls:
|
||||
self._socket.disconnect(bind_url)
|
||||
self._socket.close()
|
||||
self._socket = None
|
||||
self._log.info("Broker socket closing")
|
||||
|
||||
def start(self):
|
||||
self.broker_thread = threading.Thread(target=self.run)
|
||||
self.broker_thread.start()
|
||||
|
||||
def stop(self):
|
||||
super().stop()
|
||||
if self.broker_thread is not None:
|
||||
self.broker_thread.join()
|
||||
self.broker_thread = None
|
||||
|
||||
def _handle_worker_message(self, message):
|
||||
if message.command == protocol.READY:
|
||||
self._handle_worker_ready(message)
|
||||
else:
|
||||
super()._handle_worker_message(message)
|
||||
|
||||
def _handle_worker_ready(self, message):
|
||||
# type: (protocol.Message) -> BrokerWorker
|
||||
worker_id = message.client
|
||||
service = message.message[0]
|
||||
self._log.info("Got READY from worker: %d", id_to_int(worker_id))
|
||||
now = time.time()
|
||||
expire_at = now + self._heartbeat_timeout
|
||||
next_heartbeat = now + self._heartbeat_interval
|
||||
worker = self._services.add_worker(
|
||||
worker_id, service, expire_at, next_heartbeat
|
||||
)
|
||||
worker.request_handler = None
|
||||
worker.request_params = []
|
||||
if len(message.message) > 1:
|
||||
worker.request_handler = str(message.message[1], "ascii")
|
||||
if len(message.message) > 2:
|
||||
worker.request_params = [str(m, "ascii") for m in message.message[2:]]
|
||||
return worker
|
||||
|
||||
def _dispatch_queued_messages(self):
|
||||
"""Dispatch all queued messages to available workers"""
|
||||
expire_at = time.time() + self._busy_worker_timeout
|
||||
for message, worker, client, service_name in self._services.dequeue_pending():
|
||||
body = [client, b""] + message
|
||||
body = self.on_worker_request(worker, service_name, body)
|
||||
self._send_to_worker(worker.id, protocol.REQUEST, body)
|
||||
self._services.set_worker_busy(worker.id, expire_at=expire_at)
|
||||
|
||||
def register_request_handler(self, handler_name: str, handler):
|
||||
self.request_handlers[handler_name] = handler
|
||||
|
||||
def on_worker_request(
|
||||
self, worker: BrokerWorker, service_name: bytes, body: List[bytes]
|
||||
):
|
||||
if worker.request_handler in self.request_handlers:
|
||||
handler = self.request_handlers[worker.request_handler]
|
||||
body = handler(worker, service_name, body)
|
||||
return body
|
||||
|
||||
|
||||
class MultiBindableBroker:
|
||||
def __init__(
|
||||
self,
|
||||
bind: Union[str, List[str]] = [DEFAULT_BIND_URL],
|
||||
shms: dict[str, SharedMemory] = {},
|
||||
):
|
||||
super().__init__(bind)
|
||||
self.shms = shms
|
||||
|
||||
def on_worker_request(
|
||||
self, worker: BrokerWorker, service_name: bytes, body: List[bytes]
|
||||
) -> List[bytes]:
|
||||
if "DETECT_NO_SHM" == worker.request_handler:
|
||||
in_shm = self.shms[str(service_name, "ascii")]
|
||||
tensor_input = in_shm.buf
|
||||
body = body[0:2] + [tensor_input]
|
||||
return body
|
||||
|
||||
|
||||
class QueueWorker(MdpWorker):
|
||||
def __init__(
|
||||
self,
|
||||
broker_url: str,
|
||||
service_names: List[TextOrBytes],
|
||||
heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
|
||||
heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
|
||||
zmq_context=None,
|
||||
zmq_linger=DEFAULT_ZMQ_LINGER,
|
||||
handler_name: TextOrBytes = None,
|
||||
handler_params: List[TextOrBytes] = [],
|
||||
stop_event: mp.Event = None,
|
||||
):
|
||||
super().__init__(
|
||||
broker_url,
|
||||
b"",
|
||||
heartbeat_interval,
|
||||
heartbeat_timeout,
|
||||
zmq_context,
|
||||
zmq_linger,
|
||||
)
|
||||
|
||||
self.service_names = [
|
||||
text_to_ascii_bytes(service_name) for service_name in service_names
|
||||
]
|
||||
self.ready_params = [
|
||||
text_to_ascii_bytes(rp)
|
||||
for rp in (
|
||||
([handler_name] if handler_name is not None else []) + handler_params
|
||||
)
|
||||
]
|
||||
self.stop_event = stop_event or mp.Event()
|
||||
|
||||
def _send_ready(self):
|
||||
for service_name in self.service_names:
|
||||
self._send(protocol.READY, service_name, *self.ready_params)
|
||||
|
||||
def handle_request(client_id: bytes, request: List[bytes]) -> List[bytes]:
|
||||
return request
|
||||
|
||||
def start(self):
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
def signal_handler(sig_num, _):
|
||||
self.stop()
|
||||
|
||||
for sig_num in (signal.SIGINT, signal.SIGTERM):
|
||||
signal.signal(sig_num, signal_handler)
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
client_id, request = self.wait_for_request()
|
||||
reply = self.handle_request(client_id, request)
|
||||
self.send_reply_final(reply)
|
||||
except error.ProtocolError as e:
|
||||
self._log.warning("Protocol error: %s, dropping request", str(e))
|
||||
continue
|
||||
except error.Disconnected:
|
||||
self._log.info("Worker disconnected")
|
||||
break
|
||||
|
||||
self.close()
|
||||
|
||||
def stop(self):
|
||||
self.stop_event.set()
|
||||
|
||||
def _receive(self):
|
||||
# type: () -> Tuple[bytes, List[bytes]]
|
||||
"""Poll on the socket until a command is received
|
||||
|
||||
Will handle timeouts and heartbeats internally without returning
|
||||
"""
|
||||
while not self.stop_event.is_set():
|
||||
if self._socket is None:
|
||||
raise error.Disconnected("Worker is disconnected")
|
||||
|
||||
self._check_send_heartbeat()
|
||||
poll_timeout = self._get_poll_timeout()
|
||||
|
||||
try:
|
||||
socks = dict(self._poller.poll(timeout=poll_timeout))
|
||||
except zmq.error.ZMQError:
|
||||
# Probably connection was explicitly closed
|
||||
if self._socket is None:
|
||||
continue
|
||||
raise
|
||||
|
||||
if socks.get(self._socket) == zmq.POLLIN:
|
||||
message = self._socket.recv_multipart()
|
||||
self._log.debug("Got message of %d frames", len(message))
|
||||
else:
|
||||
self._log.debug("Receive timed out after %d ms", poll_timeout)
|
||||
if (time.time() - self._last_broker_hb) > self._heartbeat_timeout:
|
||||
# We're not connected anymore?
|
||||
self._log.info(
|
||||
"Got no heartbeat in %d sec, disconnecting and reconnecting socket",
|
||||
self._heartbeat_timeout,
|
||||
)
|
||||
self.connect(reconnect=True)
|
||||
continue
|
||||
|
||||
command, frames = self._verify_message(message)
|
||||
self._last_broker_hb = time.time()
|
||||
|
||||
if command == protocol.HEARTBEAT:
|
||||
self._log.debug("Got heartbeat message from broker")
|
||||
continue
|
||||
|
||||
return command, frames
|
||||
|
||||
return protocol.DISCONNECT, []
|
||||
@ -10,10 +10,10 @@ from pydantic import parse_obj_as
|
||||
from frigate.config import FrigateConfig, DetectorConfig, InputTensorEnum, ModelConfig
|
||||
from frigate.detectors import (
|
||||
DetectorTypeEnum,
|
||||
ObjectDetectionBroker,
|
||||
ObjectDetectionClient,
|
||||
ObjectDetectionWorker,
|
||||
)
|
||||
from frigate.majordomo import QueueBroker
|
||||
from frigate.util import deep_merge
|
||||
import frigate.detectors.detector_types as detectors
|
||||
|
||||
@ -34,6 +34,60 @@ def create_detector(det_type):
|
||||
return api
|
||||
|
||||
|
||||
def start_broker(ipc_address, tcp_address, camera_names):
|
||||
detection_shms: dict[str, SharedMemory] = {}
|
||||
|
||||
def detect_no_shm(worker, service_name, body):
|
||||
in_shm = detection_shms[str(service_name, "ascii")]
|
||||
tensor_input = in_shm.buf
|
||||
body = body[0:2] + [tensor_input]
|
||||
return body
|
||||
|
||||
queue_broker = QueueBroker(bind=[ipc_address, tcp_address])
|
||||
queue_broker.register_request_handler("DETECT_NO_SHM", detect_no_shm)
|
||||
queue_broker.start()
|
||||
|
||||
for camera_name in camera_names:
|
||||
shm_name = camera_name
|
||||
out_shm_name = f"out-{camera_name}"
|
||||
try:
|
||||
shm = SharedMemory(name=shm_name, size=512 * 512 * 3, create=True)
|
||||
except FileExistsError:
|
||||
shm = SharedMemory(name=shm_name, create=False)
|
||||
detection_shms[shm_name] = shm
|
||||
try:
|
||||
out_shm = SharedMemory(name=out_shm_name, size=20 * 6 * 4, create=True)
|
||||
except FileExistsError:
|
||||
out_shm = SharedMemory(name=out_shm_name, create=False)
|
||||
detection_shms[out_shm_name] = out_shm
|
||||
|
||||
return queue_broker, detection_shms
|
||||
|
||||
|
||||
class WorkerTestThread(threading.Thread):
|
||||
def __init__(self, detector_name, detector_config, stop_event):
|
||||
super().__init__()
|
||||
self.detector_name = detector_name
|
||||
self.detector_config = detector_config
|
||||
self.stop_event = stop_event
|
||||
|
||||
def run(self):
|
||||
worker = ObjectDetectionWorker(
|
||||
self.detector_name,
|
||||
self.detector_config,
|
||||
mp.Value("d", 0.01),
|
||||
mp.Value("d", 0.0),
|
||||
None,
|
||||
self.stop_event,
|
||||
)
|
||||
worker.connect()
|
||||
if not self.stop_event.is_set():
|
||||
client_id, request = worker.wait_for_request()
|
||||
reply = worker.handle_request(client_id, request)
|
||||
worker.send_reply_final(client_id, reply)
|
||||
worker.close()
|
||||
|
||||
|
||||
class TestLocalObjectDetector(unittest.TestCase):
|
||||
@patch.dict(
|
||||
"frigate.detectors.detector_types.api_types",
|
||||
@ -41,7 +95,7 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
)
|
||||
def test_socket_client_broker_worker(self):
|
||||
detector_name = "cpu"
|
||||
ipc_address = "ipc://detection_broker.ipc"
|
||||
ipc_address = "ipc://queue_broker.ipc"
|
||||
tcp_address = "tcp://127.0.0.1:5555"
|
||||
|
||||
detector = {"type": "cpu"}
|
||||
@ -56,60 +110,11 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
"tcp_no_shm": {"address": tcp_address, "cameras": ["tcp_no_shm"]},
|
||||
}
|
||||
|
||||
class ClientTestThread(threading.Thread):
|
||||
def __init__(
|
||||
self,
|
||||
camera_name,
|
||||
labelmap,
|
||||
model_config,
|
||||
server_config,
|
||||
tensor_input,
|
||||
timeout,
|
||||
):
|
||||
super().__init__()
|
||||
self.camera_name = camera_name
|
||||
self.labelmap = labelmap
|
||||
self.model_config = model_config
|
||||
self.server_config = server_config
|
||||
self.tensor_input = tensor_input
|
||||
self.timeout = timeout
|
||||
|
||||
def run(self):
|
||||
object_detector = ObjectDetectionClient(
|
||||
self.camera_name,
|
||||
self.labelmap,
|
||||
self.model_config,
|
||||
self.server_config,
|
||||
timeout=self.timeout,
|
||||
try:
|
||||
queue_broker, detection_shms = None, None
|
||||
queue_broker, detection_shms = start_broker(
|
||||
ipc_address, tcp_address, test_cases.keys()
|
||||
)
|
||||
try:
|
||||
object_detector.detect(self.tensor_input)
|
||||
finally:
|
||||
object_detector.cleanup()
|
||||
|
||||
try:
|
||||
detection_shms: dict[str, SharedMemory] = {}
|
||||
for camera_name in test_cases.keys():
|
||||
shm_name = camera_name
|
||||
out_shm_name = f"out-{camera_name}"
|
||||
try:
|
||||
shm = SharedMemory(name=shm_name, size=512 * 512 * 3, create=True)
|
||||
except FileExistsError:
|
||||
shm = SharedMemory(name=shm_name)
|
||||
detection_shms[shm_name] = shm
|
||||
try:
|
||||
out_shm = SharedMemory(
|
||||
name=out_shm_name, size=20 * 6 * 4, create=True
|
||||
)
|
||||
except FileExistsError:
|
||||
out_shm = SharedMemory(name=out_shm_name)
|
||||
detection_shms[out_shm_name] = out_shm
|
||||
|
||||
self.detection_broker = ObjectDetectionBroker(
|
||||
bind=[ipc_address, tcp_address],
|
||||
shms=detection_shms,
|
||||
)
|
||||
self.detection_broker.start()
|
||||
|
||||
for test_case in test_cases.keys():
|
||||
with self.subTest(test_case=test_case):
|
||||
@ -136,6 +141,7 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
config = test_cfg.runtime_config
|
||||
detector_config = config.detectors[detector_name]
|
||||
model_config = detector_config.model
|
||||
stop_event = mp.Event()
|
||||
|
||||
tensor_input = np.ndarray(
|
||||
(1, config.model.height, config.model.width, 3),
|
||||
@ -146,33 +152,26 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
|
||||
|
||||
try:
|
||||
worker = ObjectDetectionWorker(
|
||||
detector_name,
|
||||
detector_config,
|
||||
mp.Value("d", 0.01),
|
||||
mp.Value("d", 0.0),
|
||||
None,
|
||||
client = None
|
||||
worker = WorkerTestThread(
|
||||
detector_name, detector_config, stop_event
|
||||
)
|
||||
worker.connect()
|
||||
worker.start()
|
||||
|
||||
client = ClientTestThread(
|
||||
client = ObjectDetectionClient(
|
||||
camera_name,
|
||||
test_cfg.model.merged_labelmap,
|
||||
model_config,
|
||||
config.server,
|
||||
tensor_input,
|
||||
config.server.ipc,
|
||||
timeout=10,
|
||||
)
|
||||
client.start()
|
||||
|
||||
client_id, request = worker.wait_for_request()
|
||||
reply = worker.handle_request(request)
|
||||
worker.send_reply_final(client_id, reply)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
client.detect(tensor_input)
|
||||
finally:
|
||||
client.join()
|
||||
worker.close()
|
||||
stop_event.set()
|
||||
if client is not None:
|
||||
client.cleanup()
|
||||
if worker is not None:
|
||||
worker.join()
|
||||
|
||||
self.assertIsNone(
|
||||
np.testing.assert_array_almost_equal(
|
||||
@ -180,7 +179,8 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
finally:
|
||||
self.detection_broker.stop()
|
||||
if queue_broker is not None:
|
||||
queue_broker.stop()
|
||||
for shm in detection_shms.values():
|
||||
shm.close()
|
||||
shm.unlink()
|
||||
|
||||
@ -14,12 +14,7 @@ import numpy as np
|
||||
import cv2
|
||||
from setproctitle import setproctitle
|
||||
|
||||
from frigate.config import (
|
||||
CameraConfig,
|
||||
DetectConfig,
|
||||
PixelFormatEnum,
|
||||
DetectionServerConfig,
|
||||
)
|
||||
from frigate.config import CameraConfig, DetectConfig, PixelFormatEnum, ServerConfig
|
||||
from frigate.const import CACHE_DIR
|
||||
from frigate.detectors import ObjectDetectionClient
|
||||
from frigate.log import LogPipe
|
||||
@ -413,7 +408,7 @@ def track_camera(
|
||||
camera_name,
|
||||
config: CameraConfig,
|
||||
model_config,
|
||||
server_config: DetectionServerConfig,
|
||||
server_config: ServerConfig,
|
||||
labelmap,
|
||||
detected_objects_queue,
|
||||
process_info,
|
||||
@ -449,7 +444,7 @@ def track_camera(
|
||||
motion_contour_area,
|
||||
)
|
||||
object_detector = ObjectDetectionClient(
|
||||
camera_name, labelmap, model_config, server_config
|
||||
camera_name, labelmap, model_config, server_config.ipc
|
||||
)
|
||||
|
||||
object_tracker = ObjectTracker(config.detect)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user