added generic majordomo broker

This commit is contained in:
Dennis George 2022-12-27 15:57:23 -06:00
parent 9d21d71282
commit 03140c4687
10 changed files with 452 additions and 257 deletions

View File

@ -17,16 +17,13 @@ from playhouse.sqliteq import SqliteQueueDatabase
from frigate.comms.dispatcher import Communicator, Dispatcher
from frigate.comms.mqtt import MqttClient
from frigate.comms.ws import WebSocketClient
from frigate.config import FrigateConfig
from frigate.config import FrigateConfig, ServerModeEnum
from frigate.const import CACHE_DIR, CLIPS_DIR, RECORD_DIR
from frigate.detectors import (
ObjectDetectProcess,
ObjectDetectionBroker,
DetectionServerModeEnum,
)
from frigate.detectors import ObjectDetectProcess
from frigate.events import EventCleanup, EventProcessor
from frigate.http import create_app
from frigate.log import log_process, root_configurer
from frigate.majordomo import QueueBroker
from frigate.models import Event, Recordings
from frigate.object_processing import TrackedObjectProcessor
from frigate.output import output_frames
@ -84,7 +81,7 @@ class FrigateApp:
user_config = FrigateConfig.parse_file(config_file)
self.config = user_config.runtime_config
if self.config.server.mode == DetectionServerModeEnum.DetectionOnly:
if self.config.server.mode == ServerModeEnum.DetectionOnly:
return
for camera_name in self.config.cameras.keys():
@ -188,12 +185,17 @@ class FrigateApp:
comms.append(self.ws_client)
self.dispatcher = Dispatcher(self.config, self.camera_metrics, comms)
def start_detection_broker(self) -> None:
def start_queue_broker(self) -> None:
def detect_no_shm(worker, service_name, body):
in_shm = self.detection_shms[str(service_name, "ascii")]
tensor_input = in_shm.buf
body = body[0:2] + [tensor_input]
return body
bind_urls = [self.config.broker.ipc] + self.config.broker.addresses
self.detection_broker = ObjectDetectionBroker(
bind=bind_urls, shms=self.detection_shms
)
self.detection_broker.start()
self.queue_broker = QueueBroker(bind=bind_urls)
self.queue_broker.register_request_handler("DETECT_NO_SHM", detect_no_shm)
self.queue_broker.start()
def start_detectors(self) -> None:
for name in self.config.cameras.keys():
@ -372,7 +374,7 @@ class FrigateApp:
self.set_environment_vars()
self.ensure_dirs()
if self.config.server.mode == DetectionServerModeEnum.DetectionOnly:
if self.config.server.mode == ServerModeEnum.DetectionOnly:
self.start_detectors()
self.start_watchdog()
self.stop_event.wait()
@ -389,7 +391,7 @@ class FrigateApp:
self.init_restream()
self.start_detectors()
self.start_detection_broker()
self.start_queue_broker()
self.start_video_output_processor()
self.start_detected_frames_processor()
self.start_camera_processors()
@ -416,7 +418,7 @@ class FrigateApp:
logger.info(f"Stopping...")
self.stop_event.set()
if self.config.server.mode != DetectionServerModeEnum.DetectionOnly:
if self.config.server.mode != ServerModeEnum.DetectionOnly:
self.ws_client.stop()
self.detected_frames_processor.join()
self.event_processor.join()
@ -426,7 +428,7 @@ class FrigateApp:
self.stats_emitter.join()
self.frigate_watchdog.join()
self.db.stop()
self.detection_broker.stop()
self.queue_broker.stop()
for detector in self.detectors.values():
detector.stop()

View File

@ -43,10 +43,8 @@ from frigate.ffmpeg_presets import (
from frigate.detectors import (
PixelFormatEnum,
InputTensorEnum,
DetectionServerConfig,
ModelConfig,
BaseDetectorConfig,
DetectionServerModeEnum,
)
from frigate.version import VERSION
@ -812,9 +810,22 @@ def verify_zone_objects_are_tracked(camera_config: CameraConfig) -> None:
)
class ServerModeEnum(str, Enum):
Full = "full"
DetectionOnly = "detection_only"
class ServerConfig(BaseModel):
mode: ServerModeEnum = Field(default=ServerModeEnum.Full, title="Server mode")
ipc: str = Field(default="ipc://queue_broker.ipc", title="Broker IPC path")
addresses: List[str] = Field(
default=["tcp://127.0.0.1:5555"], title="Broker TCP addresses"
)
class FrigateConfig(FrigateBaseModel):
server: DetectionServerConfig = Field(
default_factory=DetectionServerConfig, title="Server configuration"
server: ServerConfig = Field(
default_factory=ServerConfig, title="Server configuration"
)
mqtt: Optional[MqttConfig] = Field(title="MQTT Configuration.")
database: DatabaseConfig = Field(
@ -913,7 +924,7 @@ class FrigateConfig(FrigateBaseModel):
config.detectors[key] = detector_config
if config.server.mode == DetectionServerModeEnum.DetectionOnly:
if config.server.mode == ServerModeEnum.DetectionOnly:
return config
# MQTT password substitution
@ -1030,8 +1041,8 @@ class FrigateConfig(FrigateBaseModel):
server_config = values.get("server", None)
if (
server_config is not None
and server_config.get("mode", DetectionServerModeEnum.Full)
== DetectionServerModeEnum.DetectionOnly
and server_config.get("mode", ServerModeEnum.Full)
== ServerModeEnum.DetectionOnly
):
return values

View File

@ -1,12 +1,9 @@
from .detection_api import DetectionApi
from .detection_broker import ObjectDetectionBroker
from .detector_config import (
PixelFormatEnum,
InputTensorEnum,
ModelConfig,
BaseDetectorConfig,
DetectionServerConfig,
DetectionServerModeEnum,
)
from .detection_client import ObjectDetectionClient
from .detector_types import DetectorTypeEnum, api_types, create_detector

View File

@ -1,89 +0,0 @@
import signal
import threading
import zmq
from multiprocessing.shared_memory import SharedMemory
from typing import Union, List
from majortomo import error, protocol
from majortomo.config import DEFAULT_BIND_URL
from majortomo.broker import Broker
READY_SHM = b"\007"
class ObjectDetectionBroker(Broker):
def __init__(
self,
bind: Union[str, List[str]] = DEFAULT_BIND_URL,
shms: dict[str, SharedMemory] = {},
heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
busy_worker_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT,
zmq_context=None,
):
protocol.Message.ALLOWED_COMMANDS[protocol.WORKER_HEADER].add(READY_SHM)
super().__init__(
self,
heartbeat_interval=heartbeat_interval,
heartbeat_timeout=heartbeat_timeout,
busy_worker_timeout=busy_worker_timeout,
zmq_context=zmq_context,
)
self.shm_workers = set()
self.shms = shms
self._bind_urls = [bind] if not isinstance(bind, list) else bind
self.broker_thread: threading.Thread = None
def bind(self):
"""Bind the ZMQ socket"""
if self._socket:
raise error.StateError("Socket is already bound")
self._socket = self._context.socket(zmq.ROUTER)
self._socket.rcvtimeo = int(self._heartbeat_interval * 1000)
for bind_url in self._bind_urls:
self._socket.bind(bind_url)
self._log.info("Broker listening on %s", bind_url)
def close(self):
if self._socket is None:
return
for bind_url in self._bind_urls:
self._socket.disconnect(bind_url)
self._socket.close()
self._socket = None
self._log.info("Broker socket closing")
def _handle_worker_message(self, message):
if message.command == READY_SHM:
self.shm_workers.add(message.client)
self._handle_worker_ready(message.client, message.message[0])
else:
super()._handle_worker_message(message)
def _purge_expired_workers(self):
self.shm_workers.intersection_update(self._services._workers.keys())
super()._purge_expired_workers()
def _send_to_worker(self, worker_id, command, body=None):
if (
worker_id not in self.shm_workers
and command == protocol.REQUEST
and body is not None
):
service_name = body[2]
in_shm = self.shms[str(service_name, "ascii")]
tensor_input = in_shm.buf
body = body[0:2] + [tensor_input]
super()._send_to_worker(worker_id, command, body)
def start(self):
self.broker_thread = threading.Thread(target=self.run)
self.broker_thread.start()
def stop(self):
super().stop()
if self.broker_thread is not None:
self.broker_thread.join()
self.broker_thread = None

View File

@ -3,7 +3,7 @@ import multiprocessing as mp
from multiprocessing.shared_memory import SharedMemory
from majortomo import Client
from frigate.util import EventsPerSecond
from .detector_config import ModelConfig, DetectionServerConfig
from .detector_config import ModelConfig
class ObjectDetectionClient:
@ -12,7 +12,7 @@ class ObjectDetectionClient:
camera_name: str,
labels,
model_config: ModelConfig,
server_config: DetectionServerConfig,
server_address: str,
timeout=None,
):
self.camera_name = camera_name
@ -29,7 +29,7 @@ class ObjectDetectionClient:
self.out_np_shm = np.ndarray((20, 6), dtype=np.float32, buffer=self.out_shm.buf)
self.timeout = timeout
self.detection_client = Client(server_config.ipc)
self.detection_client = Client(server_address)
self.detection_client.connect()
def detect(self, tensor_input, threshold=0.4):

View File

@ -1,28 +1,23 @@
import datetime
import logging
import os
import signal
import threading
import numpy as np
import multiprocessing as mp
from multiprocessing.shared_memory import SharedMemory
from majortomo import Worker, WorkerRequestsIterator, error, protocol
from majortomo.util import TextOrBytes, text_to_ascii_bytes
from typing import List
from frigate.majordomo import QueueWorker
from frigate.util import listen, EventsPerSecond, load_labels
from .detector_config import InputTensorEnum, BaseDetectorConfig
from .detector_types import create_detector
from setproctitle import setproctitle
DEFAULT_ZMQ_LINGER = 2500
READY_SHM = b"\007"
logger = logging.getLogger(__name__)
class ObjectDetectionWorker(Worker):
class ObjectDetectionWorker(QueueWorker):
def __init__(
self,
detector_name: str,
@ -30,23 +25,13 @@ class ObjectDetectionWorker(Worker):
avg_inference_speed: mp.Value = mp.Value("d", 0.01),
detection_start: mp.Value = mp.Value("d", 0.00),
labels=None,
heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
zmq_context=None,
zmq_linger=DEFAULT_ZMQ_LINGER,
stop_event: mp.Event = None,
):
self.broker_url = detector_config.address
self.service_names = [
text_to_ascii_bytes(service_name)
for service_name in detector_config.cameras
]
super().__init__(
self.broker_url,
b"",
heartbeat_interval,
heartbeat_timeout,
zmq_context,
zmq_linger,
broker_url=detector_config.address,
service_names=detector_config.cameras,
handler_name="DETECT_NO_SHM" if not detector_config.shared_memory else None,
stop_event=stop_event,
)
self.detector_name = detector_name
self.detector_config = detector_config
@ -80,12 +65,7 @@ class ObjectDetectionWorker(Worker):
self.detect_api = create_detector(self.detector_config)
def _send_ready(self):
command = READY_SHM if self.detector_config.shared_memory else protocol.READY
for service_name in self.service_names:
self._send(command, service_name)
def handle_request(self, request):
def handle_request(self, client_id: bytes, request: List[bytes]):
self.detection_start.value = datetime.datetime.now().timestamp()
# expected request format:
@ -100,6 +80,7 @@ class ObjectDetectionWorker(Worker):
self.detection_start.value = 0.0
return frames
frames.append(detections.tobytes())
elif len(request) == 3:
camera_name = request[0].decode("ascii")
shm_shape = (
@ -186,17 +167,7 @@ def run_detector(
detection_start,
labels,
)
def receiveSignal(signalNumber, frame):
worker.close()
signal.signal(signal.SIGTERM, receiveSignal)
signal.signal(signal.SIGINT, receiveSignal)
worker_iter = WorkerRequestsIterator(worker)
for request in worker_iter:
reply = worker.handle_request(request)
worker_iter.send_reply_final(reply)
worker.start()
class ObjectDetectProcess:

View File

@ -79,18 +79,3 @@ class BaseDetectorConfig(BaseModel):
class Config:
extra = Extra.allow
arbitrary_types_allowed = True
class DetectionServerModeEnum(str, Enum):
Full = "full"
DetectionOnly = "detection_only"
class DetectionServerConfig(BaseModel):
mode: DetectionServerModeEnum = Field(
default=DetectionServerModeEnum.Full, title="Server mode"
)
ipc: str = Field(default="ipc://detection_broker.ipc", title="Broker IPC path")
addresses: List[str] = Field(
default=["tcp://127.0.0.1:5555"], title="Broker TCP addresses"
)

323
frigate/majordomo.py Normal file
View File

@ -0,0 +1,323 @@
"""Majordomo / `MDP/0.2 <https://rfc.zeromq.org/spec:18/MDP/>`_ broker implementation.
Extends the implementation of https://github.com/shoppimon/majortomo
"""
import signal
import threading
import time
import zmq
import multiprocessing as mp
from multiprocessing.shared_memory import SharedMemory
from typing import (
DefaultDict,
Dict,
Generator,
List,
Optional,
Tuple,
Union,
) # noqa: F401
from majortomo import error, protocol
from majortomo.config import DEFAULT_BIND_URL
from majortomo.broker import (
Broker,
Worker as BrokerWorker,
ServicesContainer,
id_to_int,
)
from majortomo.util import TextOrBytes, text_to_ascii_bytes
from majortomo.worker import DEFAULT_ZMQ_LINGER, Worker as MdpWorker
class QueueServicesContainer(ServicesContainer):
def __init__(self, busy_workers_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT):
super().__init__(busy_workers_timeout)
def dequeue_pending(self):
# type: () -> Generator[Tuple[List[bytes], BrokerWorker, bytes], None, None]
"""Pop ready message-worker pairs from all service queues so they can be dispatched"""
for service_name, service in self._services.items():
for message, worker, client in service.dequeue_pending():
yield message, worker, client, service_name
def add_worker(
self, worker_id: bytes, service: bytes, expire_at: float, next_heartbeat: float
):
"""Add a worker to the list of available workers"""
if worker_id in self._workers:
worker = self._workers[worker_id]
if service in worker.service:
raise error.StateError(
f"Worker '{id_to_int(worker_id)}' has already sent READY message for service '{service.decode('ascii')}'"
)
else:
worker.service.add(service)
else:
worker = BrokerWorker(worker_id, set([service]), expire_at, next_heartbeat)
self._workers[worker.id] = worker
self._services[service].add_worker(worker)
return worker
def set_worker_available(self, worker_id, expire_at, next_heartbeat):
# type: (bytes, float, float) -> BrokerWorker
"""Mark a worker that was busy processing a request as back and available again"""
if worker_id not in self._busy_workers:
raise error.StateError(
"Worker id {} is not previously known or has expired".format(
id_to_int(worker_id)
)
)
worker = self._busy_workers.pop(worker_id)
worker.is_busy = False
worker.expire_at = expire_at
worker.next_heartbeat = next_heartbeat
self._workers[worker_id] = worker
for service in worker.service:
self._services[service].add_worker(worker)
return worker
def remove_worker(self, worker_id):
# type: (bytes) -> None
"""Remove a worker from the list of known workers"""
try:
worker = self._workers[worker_id]
for service_name in worker.service:
service = self._services[service_name]
if worker_id in service._workers:
service.remove_worker(worker_id)
del self._workers[worker_id]
except KeyError:
try:
del self._busy_workers[worker_id]
except KeyError:
raise error.StateError(f"Worker id {id_to_int(worker_id)} is not known")
class QueueBroker(Broker):
def __init__(
self,
bind: Union[str, List[str]] = [DEFAULT_BIND_URL],
heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
busy_worker_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT,
zmq_context=None,
):
super().__init__(
heartbeat_interval=heartbeat_interval,
heartbeat_timeout=heartbeat_timeout,
busy_worker_timeout=busy_worker_timeout,
zmq_context=zmq_context,
)
self._services = QueueServicesContainer(busy_worker_timeout)
self._bind_urls = [bind] if not isinstance(bind, list) else bind
self.broker_thread: threading.Thread = None
self.request_handlers: Dict[str, object] = {}
def bind(self):
"""Bind the ZMQ socket"""
if self._socket:
raise error.StateError("Socket is already bound")
self._socket = self._context.socket(zmq.ROUTER)
self._socket.rcvtimeo = int(self._heartbeat_interval * 1000)
for bind_url in self._bind_urls:
self._socket.bind(bind_url)
self._log.info("Broker listening on %s", bind_url)
def close(self):
if self._socket is None:
return
for bind_url in self._bind_urls:
self._socket.disconnect(bind_url)
self._socket.close()
self._socket = None
self._log.info("Broker socket closing")
def start(self):
self.broker_thread = threading.Thread(target=self.run)
self.broker_thread.start()
def stop(self):
super().stop()
if self.broker_thread is not None:
self.broker_thread.join()
self.broker_thread = None
def _handle_worker_message(self, message):
if message.command == protocol.READY:
self._handle_worker_ready(message)
else:
super()._handle_worker_message(message)
def _handle_worker_ready(self, message):
# type: (protocol.Message) -> BrokerWorker
worker_id = message.client
service = message.message[0]
self._log.info("Got READY from worker: %d", id_to_int(worker_id))
now = time.time()
expire_at = now + self._heartbeat_timeout
next_heartbeat = now + self._heartbeat_interval
worker = self._services.add_worker(
worker_id, service, expire_at, next_heartbeat
)
worker.request_handler = None
worker.request_params = []
if len(message.message) > 1:
worker.request_handler = str(message.message[1], "ascii")
if len(message.message) > 2:
worker.request_params = [str(m, "ascii") for m in message.message[2:]]
return worker
def _dispatch_queued_messages(self):
"""Dispatch all queued messages to available workers"""
expire_at = time.time() + self._busy_worker_timeout
for message, worker, client, service_name in self._services.dequeue_pending():
body = [client, b""] + message
body = self.on_worker_request(worker, service_name, body)
self._send_to_worker(worker.id, protocol.REQUEST, body)
self._services.set_worker_busy(worker.id, expire_at=expire_at)
def register_request_handler(self, handler_name: str, handler):
self.request_handlers[handler_name] = handler
def on_worker_request(
self, worker: BrokerWorker, service_name: bytes, body: List[bytes]
):
if worker.request_handler in self.request_handlers:
handler = self.request_handlers[worker.request_handler]
body = handler(worker, service_name, body)
return body
class MultiBindableBroker:
def __init__(
self,
bind: Union[str, List[str]] = [DEFAULT_BIND_URL],
shms: dict[str, SharedMemory] = {},
):
super().__init__(bind)
self.shms = shms
def on_worker_request(
self, worker: BrokerWorker, service_name: bytes, body: List[bytes]
) -> List[bytes]:
if "DETECT_NO_SHM" == worker.request_handler:
in_shm = self.shms[str(service_name, "ascii")]
tensor_input = in_shm.buf
body = body[0:2] + [tensor_input]
return body
class QueueWorker(MdpWorker):
def __init__(
self,
broker_url: str,
service_names: List[TextOrBytes],
heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
zmq_context=None,
zmq_linger=DEFAULT_ZMQ_LINGER,
handler_name: TextOrBytes = None,
handler_params: List[TextOrBytes] = [],
stop_event: mp.Event = None,
):
super().__init__(
broker_url,
b"",
heartbeat_interval,
heartbeat_timeout,
zmq_context,
zmq_linger,
)
self.service_names = [
text_to_ascii_bytes(service_name) for service_name in service_names
]
self.ready_params = [
text_to_ascii_bytes(rp)
for rp in (
([handler_name] if handler_name is not None else []) + handler_params
)
]
self.stop_event = stop_event or mp.Event()
def _send_ready(self):
for service_name in self.service_names:
self._send(protocol.READY, service_name, *self.ready_params)
def handle_request(client_id: bytes, request: List[bytes]) -> List[bytes]:
return request
def start(self):
if not self.is_connected:
self.connect()
def signal_handler(sig_num, _):
self.stop()
for sig_num in (signal.SIGINT, signal.SIGTERM):
signal.signal(sig_num, signal_handler)
while not self.stop_event.is_set():
try:
client_id, request = self.wait_for_request()
reply = self.handle_request(client_id, request)
self.send_reply_final(reply)
except error.ProtocolError as e:
self._log.warning("Protocol error: %s, dropping request", str(e))
continue
except error.Disconnected:
self._log.info("Worker disconnected")
break
self.close()
def stop(self):
self.stop_event.set()
def _receive(self):
# type: () -> Tuple[bytes, List[bytes]]
"""Poll on the socket until a command is received
Will handle timeouts and heartbeats internally without returning
"""
while not self.stop_event.is_set():
if self._socket is None:
raise error.Disconnected("Worker is disconnected")
self._check_send_heartbeat()
poll_timeout = self._get_poll_timeout()
try:
socks = dict(self._poller.poll(timeout=poll_timeout))
except zmq.error.ZMQError:
# Probably connection was explicitly closed
if self._socket is None:
continue
raise
if socks.get(self._socket) == zmq.POLLIN:
message = self._socket.recv_multipart()
self._log.debug("Got message of %d frames", len(message))
else:
self._log.debug("Receive timed out after %d ms", poll_timeout)
if (time.time() - self._last_broker_hb) > self._heartbeat_timeout:
# We're not connected anymore?
self._log.info(
"Got no heartbeat in %d sec, disconnecting and reconnecting socket",
self._heartbeat_timeout,
)
self.connect(reconnect=True)
continue
command, frames = self._verify_message(message)
self._last_broker_hb = time.time()
if command == protocol.HEARTBEAT:
self._log.debug("Got heartbeat message from broker")
continue
return command, frames
return protocol.DISCONNECT, []

View File

@ -10,10 +10,10 @@ from pydantic import parse_obj_as
from frigate.config import FrigateConfig, DetectorConfig, InputTensorEnum, ModelConfig
from frigate.detectors import (
DetectorTypeEnum,
ObjectDetectionBroker,
ObjectDetectionClient,
ObjectDetectionWorker,
)
from frigate.majordomo import QueueBroker
from frigate.util import deep_merge
import frigate.detectors.detector_types as detectors
@ -34,6 +34,60 @@ def create_detector(det_type):
return api
def start_broker(ipc_address, tcp_address, camera_names):
detection_shms: dict[str, SharedMemory] = {}
def detect_no_shm(worker, service_name, body):
in_shm = detection_shms[str(service_name, "ascii")]
tensor_input = in_shm.buf
body = body[0:2] + [tensor_input]
return body
queue_broker = QueueBroker(bind=[ipc_address, tcp_address])
queue_broker.register_request_handler("DETECT_NO_SHM", detect_no_shm)
queue_broker.start()
for camera_name in camera_names:
shm_name = camera_name
out_shm_name = f"out-{camera_name}"
try:
shm = SharedMemory(name=shm_name, size=512 * 512 * 3, create=True)
except FileExistsError:
shm = SharedMemory(name=shm_name, create=False)
detection_shms[shm_name] = shm
try:
out_shm = SharedMemory(name=out_shm_name, size=20 * 6 * 4, create=True)
except FileExistsError:
out_shm = SharedMemory(name=out_shm_name, create=False)
detection_shms[out_shm_name] = out_shm
return queue_broker, detection_shms
class WorkerTestThread(threading.Thread):
def __init__(self, detector_name, detector_config, stop_event):
super().__init__()
self.detector_name = detector_name
self.detector_config = detector_config
self.stop_event = stop_event
def run(self):
worker = ObjectDetectionWorker(
self.detector_name,
self.detector_config,
mp.Value("d", 0.01),
mp.Value("d", 0.0),
None,
self.stop_event,
)
worker.connect()
if not self.stop_event.is_set():
client_id, request = worker.wait_for_request()
reply = worker.handle_request(client_id, request)
worker.send_reply_final(client_id, reply)
worker.close()
class TestLocalObjectDetector(unittest.TestCase):
@patch.dict(
"frigate.detectors.detector_types.api_types",
@ -41,7 +95,7 @@ class TestLocalObjectDetector(unittest.TestCase):
)
def test_socket_client_broker_worker(self):
detector_name = "cpu"
ipc_address = "ipc://detection_broker.ipc"
ipc_address = "ipc://queue_broker.ipc"
tcp_address = "tcp://127.0.0.1:5555"
detector = {"type": "cpu"}
@ -56,60 +110,11 @@ class TestLocalObjectDetector(unittest.TestCase):
"tcp_no_shm": {"address": tcp_address, "cameras": ["tcp_no_shm"]},
}
class ClientTestThread(threading.Thread):
def __init__(
self,
camera_name,
labelmap,
model_config,
server_config,
tensor_input,
timeout,
):
super().__init__()
self.camera_name = camera_name
self.labelmap = labelmap
self.model_config = model_config
self.server_config = server_config
self.tensor_input = tensor_input
self.timeout = timeout
def run(self):
object_detector = ObjectDetectionClient(
self.camera_name,
self.labelmap,
self.model_config,
self.server_config,
timeout=self.timeout,
try:
queue_broker, detection_shms = None, None
queue_broker, detection_shms = start_broker(
ipc_address, tcp_address, test_cases.keys()
)
try:
object_detector.detect(self.tensor_input)
finally:
object_detector.cleanup()
try:
detection_shms: dict[str, SharedMemory] = {}
for camera_name in test_cases.keys():
shm_name = camera_name
out_shm_name = f"out-{camera_name}"
try:
shm = SharedMemory(name=shm_name, size=512 * 512 * 3, create=True)
except FileExistsError:
shm = SharedMemory(name=shm_name)
detection_shms[shm_name] = shm
try:
out_shm = SharedMemory(
name=out_shm_name, size=20 * 6 * 4, create=True
)
except FileExistsError:
out_shm = SharedMemory(name=out_shm_name)
detection_shms[out_shm_name] = out_shm
self.detection_broker = ObjectDetectionBroker(
bind=[ipc_address, tcp_address],
shms=detection_shms,
)
self.detection_broker.start()
for test_case in test_cases.keys():
with self.subTest(test_case=test_case):
@ -136,6 +141,7 @@ class TestLocalObjectDetector(unittest.TestCase):
config = test_cfg.runtime_config
detector_config = config.detectors[detector_name]
model_config = detector_config.model
stop_event = mp.Event()
tensor_input = np.ndarray(
(1, config.model.height, config.model.width, 3),
@ -146,33 +152,26 @@ class TestLocalObjectDetector(unittest.TestCase):
out_np = np.ndarray((20, 6), dtype=np.float32, buffer=out_shm.buf)
try:
worker = ObjectDetectionWorker(
detector_name,
detector_config,
mp.Value("d", 0.01),
mp.Value("d", 0.0),
None,
client = None
worker = WorkerTestThread(
detector_name, detector_config, stop_event
)
worker.connect()
worker.start()
client = ClientTestThread(
client = ObjectDetectionClient(
camera_name,
test_cfg.model.merged_labelmap,
model_config,
config.server,
tensor_input,
config.server.ipc,
timeout=10,
)
client.start()
client_id, request = worker.wait_for_request()
reply = worker.handle_request(request)
worker.send_reply_final(client_id, reply)
except Exception as ex:
print(ex)
client.detect(tensor_input)
finally:
client.join()
worker.close()
stop_event.set()
if client is not None:
client.cleanup()
if worker is not None:
worker.join()
self.assertIsNone(
np.testing.assert_array_almost_equal(
@ -180,7 +179,8 @@ class TestLocalObjectDetector(unittest.TestCase):
)
)
finally:
self.detection_broker.stop()
if queue_broker is not None:
queue_broker.stop()
for shm in detection_shms.values():
shm.close()
shm.unlink()

View File

@ -14,12 +14,7 @@ import numpy as np
import cv2
from setproctitle import setproctitle
from frigate.config import (
CameraConfig,
DetectConfig,
PixelFormatEnum,
DetectionServerConfig,
)
from frigate.config import CameraConfig, DetectConfig, PixelFormatEnum, ServerConfig
from frigate.const import CACHE_DIR
from frigate.detectors import ObjectDetectionClient
from frigate.log import LogPipe
@ -413,7 +408,7 @@ def track_camera(
camera_name,
config: CameraConfig,
model_config,
server_config: DetectionServerConfig,
server_config: ServerConfig,
labelmap,
detected_objects_queue,
process_info,
@ -449,7 +444,7 @@ def track_camera(
motion_contour_area,
)
object_detector = ObjectDetectionClient(
camera_name, labelmap, model_config, server_config
camera_name, labelmap, model_config, server_config.ipc
)
object_tracker = ObjectTracker(config.detect)