diff --git a/docker/rootfs/etc/cont-init.d/prepare-logs.sh b/docker/rootfs/etc/cont-init.d/prepare-logs.sh index 0d8d73ce2..20186d01b 100755 --- a/docker/rootfs/etc/cont-init.d/prepare-logs.sh +++ b/docker/rootfs/etc/cont-init.d/prepare-logs.sh @@ -4,7 +4,7 @@ set -o errexit -o nounset -o pipefail -dirs=(/dev/shm/logs/frigate /dev/shm/logs/go2rtc /dev/shm/logs/nginx) +dirs=("/dev/shm/$HOSTNAME/logs/frigate" /dev/shm/$HOSTNAME/logs/go2rtc /dev/shm/$HOSTNAME/logs/nginx) mkdir -p "${dirs[@]}" chown nobody:nogroup "${dirs[@]}" diff --git a/docker/rootfs/etc/services.d/frigate/log/run b/docker/rootfs/etc/services.d/frigate/log/run index c10284862..387dd96f2 100755 --- a/docker/rootfs/etc/services.d/frigate/log/run +++ b/docker/rootfs/etc/services.d/frigate/log/run @@ -1,4 +1,4 @@ #!/command/with-contenv bash # shellcheck shell=bash -exec logutil-service /dev/shm/logs/frigate +exec logutil-service /dev/shm/$HOSTNAME/logs/frigate diff --git a/docker/rootfs/etc/services.d/go2rtc/log/run b/docker/rootfs/etc/services.d/go2rtc/log/run index 96a204b9d..9a787361f 100755 --- a/docker/rootfs/etc/services.d/go2rtc/log/run +++ b/docker/rootfs/etc/services.d/go2rtc/log/run @@ -1,4 +1,4 @@ #!/command/with-contenv bash # shellcheck shell=bash -exec logutil-service /dev/shm/logs/go2rtc +exec logutil-service /dev/shm/$HOSTNAME/logs/go2rtc diff --git a/docker/rootfs/etc/services.d/nginx/log/run b/docker/rootfs/etc/services.d/nginx/log/run index 50057d1d7..5270c639d 100755 --- a/docker/rootfs/etc/services.d/nginx/log/run +++ b/docker/rootfs/etc/services.d/nginx/log/run @@ -1,4 +1,4 @@ #!/command/with-contenv bash # shellcheck shell=bash -exec logutil-service /dev/shm/logs/nginx +exec logutil-service /dev/shm/$HOSTNAME/logs/nginx diff --git a/frigate/app.py b/frigate/app.py index 8a2ad1c56..02f2c181c 100644 --- a/frigate/app.py +++ b/frigate/app.py @@ -1,12 +1,12 @@ import logging import multiprocessing as mp from multiprocessing.queues import Queue +from multiprocessing.shared_memory 
import SharedMemory from multiprocessing.synchronize import Event as MpEvent import os import signal import sys -import threading -from typing import Optional, Any +from typing import Optional from types import FrameType import traceback @@ -23,7 +23,7 @@ from frigate.detectors import ObjectDetectProcess from frigate.events import EventCleanup, EventProcessor from frigate.http import create_app from frigate.log import log_process, root_configurer -from frigate.majordomo import QueueBroker, BrokerWorker +from frigate.majortomo import Broker, ServiceWorker from frigate.models import Event, Recordings from frigate.object_processing import TrackedObjectProcessor from frigate.output import output_frames @@ -45,7 +45,7 @@ class FrigateApp: self.stop_event: MpEvent = mp.Event() self.detectors: dict[str, ObjectDetectProcess] = {} self.detection_out_events: dict[str, MpEvent] = {} - self.detection_shms: dict[str, mp.shared_memory.SharedMemory] = {} + self.detection_shms: dict[str, SharedMemory] = {} self.log_queue: Queue = mp.Queue() self.plus_api = PlusApi() self.camera_metrics: dict[str, CameraMetricsTypes] = {} @@ -187,15 +187,15 @@ class FrigateApp: def start_queue_broker(self) -> None: def detect_no_shm( - worker: BrokerWorker, service_name: bytes, body: list[bytes] + worker: ServiceWorker, service_name: bytes, body: list[bytes] ) -> list[bytes]: in_shm = self.detection_shms[str(service_name, "ascii")] tensor_input = in_shm.buf body = body[0:2] + [tensor_input] return body - bind_urls = [self.config.broker.ipc] + self.config.broker.addresses - self.queue_broker = QueueBroker(bind=bind_urls) + bind_urls = [self.config.server.ipc] + self.config.server.addresses + self.queue_broker = Broker(bind=bind_urls) self.queue_broker.register_request_handler("DETECT_NO_SHM", detect_no_shm) self.queue_broker.start() @@ -208,20 +208,18 @@ class FrigateApp: for (name, det) in self.config.detectors.items() ] ) - shm_in = mp.shared_memory.SharedMemory( + shm_in = SharedMemory( 
name=name, create=True, size=largest_frame, ) except FileExistsError: - shm_in = mp.shared_memory.SharedMemory(name=name) + shm_in = SharedMemory(name=name) try: - shm_out = mp.shared_memory.SharedMemory( - name=f"out-{name}", create=True, size=20 * 6 * 4 - ) + shm_out = SharedMemory(name=f"out-{name}", create=True, size=20 * 6 * 4) except FileExistsError: - shm_out = mp.shared_memory.SharedMemory(name=f"out-{name}") + shm_out = SharedMemory(name=f"out-{name}") self.detection_shms[name] = shm_in self.detection_shms[f"out-{name}"] = shm_out @@ -274,6 +272,7 @@ class FrigateApp: camera_name, config, self.config.model, + self.config.server, self.config.model.merged_labelmap, self.detected_frames_queue, self.camera_metrics[camera_name], @@ -377,6 +376,7 @@ class FrigateApp: self.ensure_dirs() if self.config.server.mode == ServerModeEnum.DetectionOnly: + logger.info("Starting server in detection only mode.") self.start_detectors() self.start_watchdog() self.stop_event.wait() diff --git a/frigate/config.py b/frigate/config.py index 0234152a0..dcb833e7e 100644 --- a/frigate/config.py +++ b/frigate/config.py @@ -819,7 +819,7 @@ class ServerConfig(BaseModel): mode: ServerModeEnum = Field(default=ServerModeEnum.Full, title="Server mode") ipc: str = Field(default="ipc://queue_broker.ipc", title="Broker IPC path") addresses: List[str] = Field( - default=["tcp://127.0.0.1:5555"], title="Broker TCP addresses" + default=["tcp://0.0.0.0:5555"], title="Broker TCP addresses" ) diff --git a/frigate/detectors/detection_client.py b/frigate/detectors/detection_client.py index 73d909e81..8c01ba88b 100644 --- a/frigate/detectors/detection_client.py +++ b/frigate/detectors/detection_client.py @@ -1,7 +1,7 @@ import numpy as np import multiprocessing as mp from multiprocessing.shared_memory import SharedMemory -from majortomo import Client +from frigate.majortomo import Client from frigate.util import EventsPerSecond from .detector_config import ModelConfig diff --git 
a/frigate/detectors/detection_worker.py b/frigate/detectors/detection_worker.py index bd3d4f5f7..e9f85e9c0 100644 --- a/frigate/detectors/detection_worker.py +++ b/frigate/detectors/detection_worker.py @@ -7,7 +7,7 @@ import multiprocessing as mp from multiprocessing.shared_memory import SharedMemory from typing import List -from frigate.majordomo import QueueWorker +from frigate.majortomo import Worker from frigate.util import listen, EventsPerSecond, load_labels from .detector_config import InputTensorEnum, BaseDetectorConfig from .detector_types import create_detector @@ -17,7 +17,7 @@ from setproctitle import setproctitle logger = logging.getLogger(__name__) -class ObjectDetectionWorker(QueueWorker): +class ObjectDetectionWorker(Worker): def __init__( self, detector_name: str, @@ -94,7 +94,9 @@ class ObjectDetectionWorker(QueueWorker): if out_np is None: out_shm = self.detection_shms.get(f"out-{camera_name}", None) if out_shm is None: - out_shm = SharedMemory(name=f"out-{camera_name}", create=False) + out_shm = self.detection_shms[f"out-{camera_name}"] = SharedMemory( + name=f"out-{camera_name}", create=False + ) out_np = self.detection_outputs[camera_name] = np.ndarray( (20, 6), dtype=np.float32, buffer=out_shm.buf ) @@ -183,7 +185,7 @@ class ObjectDetectProcess: self.avg_inference_speed = mp.Value("d", 0.01) self.detection_start = mp.Value("d", 0.0) - self.detect_process: mp.Process + self.detect_process: mp.Process = None self.start_or_restart() diff --git a/frigate/detectors/detector_types.py b/frigate/detectors/detector_types.py index b099c5a37..583bb99b4 100644 --- a/frigate/detectors/detector_types.py +++ b/frigate/detectors/detector_types.py @@ -7,7 +7,7 @@ from .detection_api import DetectionApi from . 
import plugins -logger = logging.getLogger(__name__) +logger = logging.getLogger("frigate.detectors") class StrEnum(str, Enum): diff --git a/frigate/http.py b/frigate/http.py index 901097679..4cb90c1ea 100644 --- a/frigate/http.py +++ b/frigate/http.py @@ -1233,10 +1233,11 @@ def vainfo(): @bp.route("/logs/", methods=["GET"]) def logs(service: str): + HOSTNAME = os.environ.get("HOSTNAME", "frigate") log_locations = { - "frigate": "/dev/shm/logs/frigate/current", - "go2rtc": "/dev/shm/logs/go2rtc/current", - "nginx": "/dev/shm/logs/nginx/current", + "frigate": f"/dev/shm/{HOSTNAME}/logs/frigate/current", + "go2rtc": f"/dev/shm/{HOSTNAME}/logs/go2rtc/current", + "nginx": f"/dev/shm/{HOSTNAME}/logs/nginx/current", } service_location = log_locations.get(service) diff --git a/frigate/majordomo.py b/frigate/majordomo.py deleted file mode 100644 index 813324f46..000000000 --- a/frigate/majordomo.py +++ /dev/null @@ -1,315 +0,0 @@ -"""Majordomo / `MDP/0.2 `_ broker implementation. - -Extends the implementation of https://github.com/shoppimon/majortomo -""" - -import signal -import threading -import time -import zmq -import multiprocessing as mp -from multiprocessing.shared_memory import SharedMemory -from typing import ( - DefaultDict, - Dict, - Generator, - List, - Optional, - Tuple, - Union, -) # noqa: F401 -from majortomo import error, protocol -from majortomo.config import DEFAULT_BIND_URL -from majortomo.broker import ( - Broker, - Worker as MdpBrokerWorker, - ServicesContainer, - id_to_int, -) -from majortomo.util import TextOrBytes, text_to_ascii_bytes -from majortomo.worker import DEFAULT_ZMQ_LINGER, Worker as MdpWorker - - -class BrokerWorker(MdpBrokerWorker): - """Worker objects represent a connected / known MDP worker process""" - - def __init__( - self, worker_id: bytes, service: bytes, expire_at: float, next_heartbeat: float - ): - super().__init__(worker_id, service, expire_at, next_heartbeat) - self.request_handler: str - self.request_params: list[str] = [] 
- - -class QueueServicesContainer(ServicesContainer): - def __init__(self, busy_workers_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT): - super().__init__(busy_workers_timeout) - - def dequeue_pending(self): - # type: () -> Generator[Tuple[List[bytes], BrokerWorker, bytes], None, None] - """Pop ready message-worker pairs from all service queues so they can be dispatched""" - for service_name, service in self._services.items(): - for message, worker, client in service.dequeue_pending(): - yield message, worker, client, service_name - - def add_worker( - self, worker_id: bytes, service: bytes, expire_at: float, next_heartbeat: float - ): - """Add a worker to the list of available workers""" - if worker_id in self._workers: - worker = self._workers[worker_id] - if service in worker.service: - raise error.StateError( - f"Worker '{id_to_int(worker_id)}' has already sent READY message for service '{service.decode('ascii')}'" - ) - else: - worker.service.add(service) - else: - worker = BrokerWorker(worker_id, set([service]), expire_at, next_heartbeat) - self._workers[worker.id] = worker - self._services[service].add_worker(worker) - return worker - - def set_worker_available(self, worker_id, expire_at, next_heartbeat): - # type: (bytes, float, float) -> BrokerWorker - """Mark a worker that was busy processing a request as back and available again""" - if worker_id not in self._busy_workers: - raise error.StateError( - "Worker id {} is not previously known or has expired".format( - id_to_int(worker_id) - ) - ) - worker = self._busy_workers.pop(worker_id) - worker.is_busy = False - worker.expire_at = expire_at - worker.next_heartbeat = next_heartbeat - self._workers[worker_id] = worker - for service in worker.service: - self._services[service].add_worker(worker) - return worker - - def remove_worker(self, worker_id): - # type: (bytes) -> None - """Remove a worker from the list of known workers""" - try: - worker = self._workers[worker_id] - for service_name in 
worker.service: - service = self._services[service_name] - if worker_id in service._workers: - service.remove_worker(worker_id) - del self._workers[worker_id] - except KeyError: - try: - del self._busy_workers[worker_id] - except KeyError: - raise error.StateError(f"Worker id {id_to_int(worker_id)} is not known") - - -class QueueBroker(Broker): - def __init__( - self, - bind: Union[str, List[str]] = [DEFAULT_BIND_URL], - heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL, - heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT, - busy_worker_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT, - zmq_context=None, - ): - super().__init__( - heartbeat_interval=heartbeat_interval, - heartbeat_timeout=heartbeat_timeout, - busy_worker_timeout=busy_worker_timeout, - zmq_context=zmq_context, - ) - self._services = QueueServicesContainer(busy_worker_timeout) - self._bind_urls = [bind] if not isinstance(bind, list) else bind - self.broker_thread: threading.Thread = None - self.request_handlers: Dict[str, object] = {} - - def bind(self): - """Bind the ZMQ socket""" - if self._socket: - raise error.StateError("Socket is already bound") - - self._socket = self._context.socket(zmq.ROUTER) - self._socket.rcvtimeo = int(self._heartbeat_interval * 1000) - for bind_url in self._bind_urls: - self._socket.bind(bind_url) - self._log.info("Broker listening on %s", bind_url) - - def close(self): - if self._socket is None: - return - for bind_url in self._bind_urls: - self._socket.disconnect(bind_url) - self._socket.close() - self._socket = None - self._log.info("Broker socket closing") - - def start(self): - self.broker_thread = threading.Thread(target=self.run) - self.broker_thread.start() - - def stop(self): - super().stop() - if self.broker_thread is not None: - self.broker_thread.join() - self.broker_thread = None - - def _handle_worker_message(self, message): - if message.command == protocol.READY: - self._handle_worker_ready(message) - else: - super()._handle_worker_message(message) - 
- def _handle_worker_ready(self, message): - # type: (protocol.Message) -> BrokerWorker - worker_id = message.client - service = message.message[0] - self._log.info("Got READY from worker: %d", id_to_int(worker_id)) - now = time.time() - expire_at = now + self._heartbeat_timeout - next_heartbeat = now + self._heartbeat_interval - worker = self._services.add_worker( - worker_id, service, expire_at, next_heartbeat - ) - worker.request_handler = None - worker.request_params = [] - if len(message.message) > 1: - worker.request_handler = str(message.message[1], "ascii") - if len(message.message) > 2: - worker.request_params = [str(m, "ascii") for m in message.message[2:]] - return worker - - def _dispatch_queued_messages(self): - """Dispatch all queued messages to available workers""" - expire_at = time.time() + self._busy_worker_timeout - for message, worker, client, service_name in self._services.dequeue_pending(): - body = [client, b""] + message - body = self.on_worker_request(worker, service_name, body) - self._send_to_worker(worker.id, protocol.REQUEST, body) - self._services.set_worker_busy(worker.id, expire_at=expire_at) - - def register_request_handler(self, handler_name: str, handler): - self.request_handlers[handler_name] = handler - - def on_worker_request( - self, worker: BrokerWorker, service_name: bytes, body: List[bytes] - ): - if worker.request_handler in self.request_handlers: - handler = self.request_handlers[worker.request_handler] - body = handler(worker, service_name, body) - return body - - -class QueueWorker(MdpWorker): - def __init__( - self, - broker_url: str, - service_names: List[TextOrBytes], - heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL, - heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT, - zmq_context=None, - zmq_linger=DEFAULT_ZMQ_LINGER, - handler_name: TextOrBytes = None, - handler_params: List[TextOrBytes] = [], - stop_event: mp.Event = None, - ): - super().__init__( - broker_url, - b"", - heartbeat_interval, - 
heartbeat_timeout, - zmq_context, - zmq_linger, - ) - - self.service_names = [ - text_to_ascii_bytes(service_name) for service_name in service_names - ] - self.ready_params = [ - text_to_ascii_bytes(rp) - for rp in ( - ([handler_name] if handler_name is not None else []) + handler_params - ) - ] - self.stop_event = stop_event or mp.Event() - - def _send_ready(self): - for service_name in self.service_names: - self._send(protocol.READY, service_name, *self.ready_params) - - def handle_request(client_id: bytes, request: List[bytes]) -> List[bytes]: - return request - - def start(self): - if not self.is_connected: - self.connect() - - def signal_handler(sig_num, _): - self.stop() - - for sig_num in (signal.SIGINT, signal.SIGTERM): - signal.signal(sig_num, signal_handler) - - while not self.stop_event.is_set(): - try: - client_id, request = self.wait_for_request() - reply = self.handle_request(client_id, request) - self.send_reply_final(reply) - except error.ProtocolError as e: - self._log.warning("Protocol error: %s, dropping request", str(e)) - continue - except error.Disconnected: - self._log.info("Worker disconnected") - break - - self.close() - - def stop(self): - self.stop_event.set() - - def _receive(self): - # type: () -> Tuple[bytes, List[bytes]] - """Poll on the socket until a command is received - - Will handle timeouts and heartbeats internally without returning - """ - while not self.stop_event.is_set(): - if self._socket is None: - raise error.Disconnected("Worker is disconnected") - - self._check_send_heartbeat() - poll_timeout = self._get_poll_timeout() - - try: - socks = dict(self._poller.poll(timeout=poll_timeout)) - except zmq.error.ZMQError: - # Probably connection was explicitly closed - if self._socket is None: - continue - raise - - if socks.get(self._socket) == zmq.POLLIN: - message = self._socket.recv_multipart() - self._log.debug("Got message of %d frames", len(message)) - else: - self._log.debug("Receive timed out after %d ms", poll_timeout) 
- if (time.time() - self._last_broker_hb) > self._heartbeat_timeout: - # We're not connected anymore? - self._log.info( - "Got no heartbeat in %d sec, disconnecting and reconnecting socket", - self._heartbeat_timeout, - ) - self.connect(reconnect=True) - continue - - command, frames = self._verify_message(message) - self._last_broker_hb = time.time() - - if command == protocol.HEARTBEAT: - self._log.debug("Got heartbeat message from broker") - continue - - return command, frames - - return protocol.DISCONNECT, [] diff --git a/frigate/majortomo/__init__.py b/frigate/majortomo/__init__.py new file mode 100644 index 000000000..cce804736 --- /dev/null +++ b/frigate/majortomo/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2018 Shoppimon LTD +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .client import Client +from .worker import Worker +from .broker import Broker, ServiceWorker + +__all__ = ["Client", "Worker", "Broker", "ServiceWorker"] diff --git a/frigate/majortomo/broker.py b/frigate/majortomo/broker.py new file mode 100644 index 000000000..5aa9e5d6c --- /dev/null +++ b/frigate/majortomo/broker.py @@ -0,0 +1,526 @@ +"""Majordomo / `MDP/0.2 `_ broker implementation.""" + +import logging +import logging.config +import threading +import time +from collections import OrderedDict, defaultdict, deque +from itertools import chain +from typing import ( + DefaultDict, + Dict, + Generator, + List, + Optional, + Tuple, + Union, +) # noqa: F401 + +import zmq + +from . 
import error, protocol +from .util import id_to_int + +DEFAULT_BIND_URL = "tcp://0.0.0.0:5555" + + +class Broker: + def __init__( + self, + bind: Union[str, List[str]] = [DEFAULT_BIND_URL], + heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL, + heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT, + busy_worker_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT, + zmq_context=None, + ): + self._log = logging.getLogger(__name__) + self._bind_urls = [bind] if not isinstance(bind, list) else bind + self._heartbeat_interval = heartbeat_interval + self._heartbeat_timeout = heartbeat_timeout + self._busy_worker_timeout = busy_worker_timeout + + self._context = zmq_context if zmq_context else zmq.Context.instance() + self._socket = None # type: zmq.Socket + self._services = ServicesContainer(busy_worker_timeout) + self._stop = False + + self.broker_thread: threading.Thread = None + self.request_handlers: Dict[str, object] = {} + + def run(self): + """Run in a main loop handling all incoming requests, until `stop()` is called""" + self._log.info("MDP Broker starting up") + self._stop = False + self.bind() + + try: + while not self._stop: + try: + message = self.receive() + if message: + self.handle_message(message) + except Exception: + self._log.exception("Message handling failed") + + finally: + self.close() + self._log.info("MDP Broker shutting down") + + def start(self): + self.broker_thread = threading.Thread(target=self.run) + self.broker_thread.name = "zmq_majordomo_broker" + self.broker_thread.start() + + def stop(self): + """Stop the broker's main loop""" + self._stop = True + if self.broker_thread is not None: + self.broker_thread.join() + self.broker_thread = None + + def bind(self): + """Bind the ZMQ socket""" + if self._socket: + raise error.StateError("Socket is already bound") + + self._socket = self._context.socket(zmq.ROUTER) + self._socket.rcvtimeo = int(self._heartbeat_interval * 1000) + for bind_url in self._bind_urls: + self._socket.bind(bind_url) + 
self._log.info("Broker listening on %s", bind_url) + + def close(self): + if self._socket is None: + return + for bind_url in self._bind_urls: + self._socket.disconnect(bind_url) + self._socket.close() + self._socket = None + self._log.info("Broker socket closing") + + def receive(self): + """Run until a message is received""" + while True: + self._purge_expired_workers() + self._send_heartbeat() + try: + frames = self._socket.recv_multipart() + except zmq.error.Again: + if self._socket is None or self._stop: + self._log.debug( + "Socket has been closed or broker has been shut down, breaking from recv loop" + ) + break + continue + + self._log.debug("Got message of %d frames", len(frames)) + try: + return self._parse_incoming_message(frames) + except error.ProtocolError as e: + self._log.warning(str(e)) + continue + + def handle_message(self, message): + # type: (protocol.Message) -> None + """Handle incoming message""" + if message.header == protocol.WORKER_HEADER: + self._handle_worker_message(message) + elif message.header == protocol.CLIENT_HEADER: + self._handle_client_message(message) + else: + raise error.ProtocolError( + "Unexpected protocol header: {}".format(message.header.decode("utf8")) + ) + + self._dispatch_queued_messages() + + def _handle_worker_message(self, message): + # type: (protocol.Message) -> None + """Handle message from a worker""" + if message.command == protocol.HEARTBEAT: + self._handle_worker_heartbeat(message.client) + elif message.command == protocol.READY: + self._handle_worker_ready(message) + elif message.command == protocol.FINAL: + self._handle_worker_final( + message.client, message.message[0], message.message[2:] + ) + elif message.command == protocol.PARTIAL: + self._handle_worker_partial( + message.client, message.message[0], message.message[2:] + ) + elif message.command == protocol.DISCONNECT: + self._handle_worker_disconnect(message.client) + else: + self._send_to_worker(message.client, protocol.DISCONNECT) + raise 
error.ProtocolError( + "Unexpected command from worker: {}".format( + message.command.decode("utf8") + ) + ) + + def _handle_worker_heartbeat(self, worker_id): + # type: (bytes) -> None + """Heartbeat from worker""" + worker = self._services.get_worker(worker_id) + if worker is None: + if not self._services.is_busy_worker(worker_id): + self._log.warning( + "Got HEARTBEAT from unknown worker: %d", id_to_int(worker_id) + ) + self._send_to_worker(worker_id, protocol.DISCONNECT) + return + + self._log.debug("Got HEARTBEAT from worker: %d", id_to_int(worker_id)) + worker.expire_at = time.time() + self._heartbeat_timeout + + def _handle_worker_ready(self, message): + # type: (protocol.Message) -> ServiceWorker + worker_id = message.client + service = message.message[0] + self._log.debug("Got READY from worker: %d", id_to_int(worker_id)) + now = time.time() + expire_at = now + self._heartbeat_timeout + next_heartbeat = now + self._heartbeat_interval + worker = self._services.add_worker( + worker_id, service, expire_at, next_heartbeat + ) + worker.request_handler = None + worker.request_params = [] + if len(message.message) > 1: + worker.request_handler = str(message.message[1], "ascii") + if len(message.message) > 2: + worker.request_params = [str(m, "ascii") for m in message.message[2:]] + return worker + + def _handle_worker_partial(self, worker_id, client_id, body): + # type: (bytes, bytes, List[bytes]) -> None + self._log.debug( + "Got PARTIAL from worker: %d to client: %d", + id_to_int(worker_id), + id_to_int(client_id), + ) + self._send_to_client(client_id, protocol.PARTIAL, body) + + def _handle_worker_final(self, worker_id, client_id, body): + # type: (bytes, bytes, List[bytes]) -> None + self._log.debug( + "Got FINAL from worker: %d to client: %d", + id_to_int(worker_id), + id_to_int(client_id), + ) + self._send_to_client(client_id, protocol.FINAL, body) + now = time.time() + self._services.set_worker_available( + worker_id, now + self._heartbeat_timeout, now + 
self._heartbeat_interval + ) + + def _handle_worker_disconnect(self, worker_id): + # type: (bytes) -> None + self._log.info("Got DISCONNECT from worker: %d", id_to_int(worker_id)) + try: + self._services.remove_worker(worker_id) + except KeyError: + self._log.info( + "Got DISCONNECT from unknown worker: %d; ignoring", id_to_int(worker_id) + ) + + def _handle_client_message(self, message): + # type: (protocol.Message) -> None + """Handle message from a client""" + assert message.command == protocol.REQUEST + if len(message.message) < 2: + raise error.ProtocolError( + "Client REQUEST message is expected to be at least 2 frames long, got {}".format( + len(message.message) + ) + ) + + service_name = message.message[0] + body = message.message[1:] + + # TODO: Plug-in MMA handling + + self._log.debug( + "Queueing client request from %d to %s", + id_to_int(message.client), + service_name.decode("ascii"), + ) + self._services.queue_client_request(message.client, service_name, body) + + def _dispatch_queued_messages(self): + """Dispatch all queued messages to available workers""" + expire_at = time.time() + self._busy_worker_timeout + for message, worker, client, service_name in self._services.dequeue_pending(): + body = [client, b""] + message + body = self.on_worker_request(worker, service_name, body) + self._send_to_worker(worker.id, protocol.REQUEST, body) + self._services.set_worker_busy(worker.id, expire_at=expire_at) + + def register_request_handler(self, handler_name: str, handler): + self.request_handlers[handler_name] = handler + + def on_worker_request(self, worker, service_name, body): + # type: (ServiceWorker, bytes, List[bytes]) -> List[bytes] + if worker.request_handler in self.request_handlers: + handler = self.request_handlers[worker.request_handler] + body = handler(worker, service_name, body) + return body + + def _purge_expired_workers(self): + """Purge expired workers from the services container""" + for worker in list( + 
self._services.expired_workers() + ): # Copying to list as we are going to mutate + self._log.debug( + "Worker %d timed out after %0.2f sec, purging", + id_to_int(worker.id), + self._busy_worker_timeout + if worker.is_busy + else self._heartbeat_timeout, + ) + self._send_to_worker(worker.id, protocol.DISCONNECT) + self._services.remove_worker(worker.id) + + def _send_heartbeat(self): + """Send heartbeat to all workers that didn't get any messages recently""" + now = time.time() + for worker in self._services.heartbeat_workers(): + self._log.debug("Sending heartbeat to worker: %d", id_to_int(worker.id)) + self._send_to_worker(worker.id, protocol.HEARTBEAT) + worker.next_heartbeat = now + self._heartbeat_interval + + def _send_to_worker(self, worker_id, command, body=None): + # type: (bytes, bytes, Optional[List[bytes]]) -> None + """Send message to worker""" + if body is None: + body = [] + self._socket.send_multipart( + [worker_id, b"", protocol.WORKER_HEADER, command] + body + ) + + def _send_to_client(self, client_id, command, body): + # type: (bytes, bytes, List[bytes]) -> None + """Send message to client""" + self._socket.send_multipart( + [client_id, b"", protocol.CLIENT_HEADER, command] + body + ) + + @staticmethod + def _parse_incoming_message(frames): + # type: (List[bytes]) -> protocol.Message + """Parse and verify incoming message""" + if len(frames) < 4: + raise error.ProtocolError( + "Unexpected message length: expecting at least 4 frames, got {}".format( + len(frames) + ) + ) + + if frames[1] != b"": + raise error.ProtocolError( + "Expecting empty frame 1, got {} bytes".format(len(frames[1])) + ) + + return protocol.Message( + client=frames[0], header=frames[2], command=frames[3], message=frames[4:] + ) + + +class ServiceWorker: + """ServiceWorker objects represent a connected / known MDP worker process""" + + def __init__( + self, + worker_id: bytes, + service: set[bytes], + expire_at: float, + next_heartbeat: float, + ): + self.id = worker_id + 
self.service = service + self.expire_at = expire_at + self.next_heartbeat = next_heartbeat + self.is_busy = False + self.request_handler: str + self.request_params: list[str] = [] + + def is_expired(self, now=None): + # type: (Optional[float]) -> bool + """Check if worker is expired""" + if now is None: + now = time.time() + return now >= self.expire_at + + def is_heartbeat(self, now=None): + # type: (Optional[float]) -> bool + """Check if worker is due for sending a heartbeat message""" + if now is None: + now = time.time() + return now >= self.next_heartbeat + + +class Service: + """Service objects manage all workers that can handle a specific service, as well as a queue of MDP Client + requests to be handled by this service + """ + + def __init__(self): + self._queue = deque() + self._workers = OrderedDict() + + def queue_request(self, client, request_body): + # type: (bytes, List[bytes]) -> None + """Queue a client request""" + self._queue.append((client, request_body)) + + def add_worker(self, worker): + # type: (ServiceWorker) -> None + """Add a ServiceWorker to the service""" + self._workers[worker.id] = worker + + def remove_worker(self, worker_id): + # type: (bytes) -> None + """Remove a worker from the service""" + if worker_id in self._workers: + del self._workers[worker_id] + + def dequeue_pending(self): + # type: () -> Generator[Tuple[List[bytes], ServiceWorker, bytes], None, None] + """Dequeue pending workers and requests to be handled by them""" + while len(self._queue) and len(self._workers): + client, message = self._queue.popleft() + _, worker = self._workers.popitem(last=False) + yield message, worker, client + + @property + def queued_requests(self): + # type: () -> int + """Number of queued requests for this service""" + return len(self._queue) + + @property + def available_workers(self): + # type: () -> int + """Number of available workers for this service""" + return len(self._workers) + + +class ServicesContainer: + """A container for all 
services managed by the broker""" + + def __init__(self, busy_workers_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT): + # type: (float) -> None + self._busy_workers_timeout = busy_workers_timeout + self._services = defaultdict(Service) # type: DefaultDict[bytes, Service] + self._workers = OrderedDict() # type: OrderedDict[bytes, ServiceWorker] + self._busy_workers = dict() # type: Dict[bytes, ServiceWorker] + + def queue_client_request(self, client, service, body): + # type: (bytes, bytes, List[bytes]) -> None + """Queue a request from a client""" + self._services[service].queue_request(client, body) + + def dequeue_pending(self): + # type: () -> Generator[Tuple[List[bytes], ServiceWorker, bytes], None, None] + """Pop ready message-worker pairs from all service queues so they can be dispatched""" + for service_name, service in self._services.items(): + for message, worker, client in service.dequeue_pending(): + for other_service in self._services.values(): + other_service.remove_worker(worker.id) + yield message, worker, client, service_name + + def get_worker(self, worker_id): + # type: (bytes) -> Optional[ServiceWorker] + """Get a worker by ID if exists""" + return self._workers.get(worker_id, None) + + def add_worker( + self, worker_id: bytes, service: bytes, expire_at: float, next_heartbeat: float + ): + """Add a worker to the list of available workers""" + if worker_id in self._workers: + worker = self._workers[worker_id] + if service in worker.service: + raise error.StateError( + f"Worker '{id_to_int(worker_id)}' has already sent READY message for service '{service.decode('ascii')}'" + ) + else: + worker.service.add(service) + else: + worker = ServiceWorker(worker_id, set([service]), expire_at, next_heartbeat) + self._workers[worker.id] = worker + self._services[service].add_worker(worker) + return worker + + def set_worker_busy(self, worker_id, expire_at): + # type: (bytes, float) -> ServiceWorker + """Mark a worker as busy - that is currently processing a 
request and not available for more work""" + if worker_id not in self._workers: + raise error.StateError( + "Worker id {} is not in the list of available workers".format( + id_to_int(worker_id) + ) + ) + worker = self._workers.pop(worker_id) + worker.is_busy = True + worker.expire_at = expire_at + self._busy_workers[worker.id] = worker + return worker + + def set_worker_available(self, worker_id, expire_at, next_heartbeat): + # type: (bytes, float, float) -> ServiceWorker + """Mark a worker that was busy processing a request as back and available again""" + if worker_id not in self._busy_workers: + raise error.StateError( + "Worker id {} is not previously known or has expired".format( + id_to_int(worker_id) + ) + ) + worker = self._busy_workers.pop(worker_id) + worker.is_busy = False + worker.expire_at = expire_at + worker.next_heartbeat = next_heartbeat + self._workers[worker_id] = worker + for service in worker.service: + self._services[service].add_worker(worker) + return worker + + def is_busy_worker(self, worker_id): + # type: (bytes) -> bool + """Return True if the given worker_id is of a known busy worker""" + return worker_id in self._busy_workers + + def remove_worker(self, worker_id): + # type: (bytes) -> None + """Remove a worker from the list of known workers""" + try: + worker = self._workers[worker_id] + for service_name in worker.service: + service = self._services[service_name] + if worker_id in service._workers: + service.remove_worker(worker_id) + del self._workers[worker_id] + except KeyError: + try: + del self._busy_workers[worker_id] + except KeyError: + raise error.StateError(f"Worker id {id_to_int(worker_id)} is not known") + + def heartbeat_workers(self): + # type: () -> Generator[ServiceWorker, None, None] + """Get iterator of workers waiting for a heartbeat""" + now = time.time() + return (w for w in self._workers.values() if w.is_heartbeat(now)) + + def expired_workers(self): + # type: () -> Generator[ServiceWorker, None, None] + """Get 
iterator of workers that have expired (no heartbeat received in too long) and busy workers that + have not returned in too long + """ + now = time.time() + return ( + w + for w in chain(self._workers.values(), self._busy_workers.values()) + if w.is_expired(now) + ) diff --git a/frigate/majortomo/client.py b/frigate/majortomo/client.py new file mode 100644 index 000000000..d44e1e5c7 --- /dev/null +++ b/frigate/majortomo/client.py @@ -0,0 +1,182 @@ +"""MDP 0.2 Client implementation""" + +# Copyright (c) 2018 Shoppimon LTD +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Iterable, List, Optional, Tuple # noqa: F401 + +import zmq + +from . import error as e +from . 
DEFAULT_ZMQ_LINGER = 2000


class Client(object):
    """MDP 0.2 Client implementation.

    Sends one REQUEST at a time to a broker over a ZMQ DEALER socket and
    collects the PARTIAL / FINAL reply parts. A new request cannot be sent
    until the FINAL part of the previous reply was received (or the client
    reconnected). Can be used as a context manager: connects on enter and
    closes on exit.

    :ivar _socket: zmq.Socket
    :type _socket: zmq.Socket
    """

    def __init__(self, broker_url, zmq_context=None, zmq_linger=DEFAULT_ZMQ_LINGER):
        # type: (str, Optional[zmq.Context], int) -> None
        """Store connection settings; no socket is opened until `connect()`.

        :param broker_url: ZMQ endpoint of the broker to connect to
        :param zmq_context: context to create the socket from; the shared
            `zmq.Context.instance()` is used when omitted
        :param zmq_linger: value applied to the socket's LINGER option on connect
        """
        self.broker_url = broker_url
        self._socket = None  # type: zmq.Socket
        self._zmq_context = zmq_context if zmq_context else zmq.Context.instance()
        self._linger = zmq_linger
        self._log = logging.getLogger(__name__)
        # True between a successful send() and the receipt of the FINAL reply part
        self._expect_reply = False

    def connect(self, reconnect=False):
        # type: (bool) -> None
        """Open the DEALER socket to the broker.

        Does nothing when already connected, unless `reconnect` is true, in
        which case the existing socket is torn down first. Any pending reply
        expectation is reset.
        """
        if self.is_connected():
            if not reconnect:
                return
            self._disconnect()

        # Set up socket
        self._socket = self._zmq_context.socket(zmq.DEALER)
        self._socket.setsockopt(zmq.LINGER, self._linger)
        self._socket.connect(self.broker_url)
        self._log.debug(
            "Connected to broker on ZMQ DEALER socket at %s", self.broker_url
        )
        self._expect_reply = False

    def close(self):
        """Disconnect from the broker if currently connected."""
        if not self.is_connected():
            return
        self._disconnect()

    def _disconnect(self):
        """Tear down the socket without waiting for unsent messages."""
        if not self.is_connected():
            return
        self._log.debug(
            "Disconnecting from broker on ZMQ DEALER socket at %s", self.broker_url
        )
        # LINGER 0: discard anything still queued instead of blocking on close
        self._socket.setsockopt(zmq.LINGER, 0)
        self._socket.disconnect(self.broker_url)
        self._socket.close()
        self._socket = None

    def is_connected(self):
        # type: () -> bool
        """Tell whether we are currently connected"""
        return self._socket is not None

    def send(self, service, *args):
        # type: (TextOrBytes, *bytes) -> None
        """Send a REQUEST command to the broker to be passed to the given service.

        Each additional argument will be sent as a request body frame.

        :raises StateError: if the reply to a previous request is still pending
        """
        if self._expect_reply:
            raise e.StateError(
                "Still expecting reply from broker, cannot send new request"
            )

        service = text_to_ascii_bytes(service)
        self._log.debug(
            "Sending REQUEST message to %s with %d frames in body", service, len(args)
        )
        self._socket.send_multipart((b"", p.CLIENT_HEADER, p.REQUEST, service) + args)
        self._expect_reply = True

    def recv_part(self, timeout=None):
        # type: (Optional[float]) -> Optional[List[bytes]]
        """Receive a single part of the reply, partial or final

        Note that a "part" is actually a list in this case, as any reply part can contain multiple frames.

        If there are no more parts to receive, will return None

        :param timeout: seconds to wait for the part; `None` waits
            indefinitely, `0` polls once without blocking
        :raises Timeout: if nothing arrived within `timeout`
        :raises ProtocolError: if the received message is malformed
        """
        if not self._expect_reply:
            return None

        # Compare against None explicitly: a timeout of 0 is a legitimate
        # "do not block" request and must not be turned into "wait forever".
        timeout_ms = int(timeout * 1000) if timeout is not None else None

        poller = zmq.Poller()
        poller.register(self._socket, zmq.POLLIN)

        try:
            socks = dict(poller.poll(timeout=timeout_ms))
            if socks.get(self._socket) == zmq.POLLIN:
                message = self._socket.recv_multipart()
                m_type, m_content = self._parse_message(message)
                if m_type == p.FINAL:
                    self._expect_reply = False
                return m_content
            else:
                raise e.Timeout("Timed out waiting for reply from broker")
        finally:
            poller.unregister(self._socket)

    def recv_all(self, timeout=None):
        # type: (Optional[float]) -> Iterable[List[bytes]]
        """Return a generator allowing to iterate over all reply parts

        Note that `timeout` applies to each part, not to the full list of parts
        """
        while True:
            part = self.recv_part(timeout)
            if part is None:
                break
            yield part

    def recv_all_as_list(self, timeout=None):
        # type: (Optional[float]) -> List[bytes]
        """Return all reply parts as a single, flat list of frames"""
        return [frame for part in self.recv_all(timeout) for frame in part]

    @staticmethod
    def _parse_message(message):
        # type: (List[bytes]) -> Tuple[bytes, List[bytes]]
        """Parse and validate an incoming message

        Expects `[b"", CLIENT_HEADER, PARTIAL|FINAL, *body]`. The leading
        empty frame is removed from `message` in place.

        :returns: the message type frame and the list of body frames
        :raises ProtocolError: on any deviation from the expected layout
        """
        if len(message) < 3:
            raise e.ProtocolError(
                "Unexpected message length, expecting at least 3 frames, got {}".format(
                    len(message)
                )
            )

        if message.pop(0) != b"":
            raise e.ProtocolError("Expecting first message frame to be empty")

        if message[0] != p.CLIENT_HEADER:
            # Decode leniently: a garbage header must surface as ProtocolError,
            # not as UnicodeDecodeError raised while building the message.
            raise e.ProtocolError(
                "Unexpected protocol header [{}], expecting [{}]".format(
                    message[0].decode("utf8", "replace"),
                    p.CLIENT_HEADER.decode("utf8"),
                )
            )

        if message[1] not in {p.PARTIAL, p.FINAL}:
            raise e.ProtocolError(
                "Unexpected message type [{}], expecting either PARTIAL or FINAL".format(
                    message[1].decode("utf8", "replace")
                )
            )

        return message[1], message[2:]

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
class Error(RuntimeError):
    """Common base class for every error raised by the MDP client / worker code."""


class ProtocolError(Error):
    """A message did not conform to the MDP 0.2 protocol."""


class Disconnected(Error):
    """The connection is no longer available."""


class StateError(Error):
    """An operation was attempted while the system was in an unexpected state."""


class Timeout(Error, TimeoutError):
    """An operation did not complete in time.

    Also derives from the built-in ``TimeoutError`` so it can be caught as either.
    """
# Protocol header frames identifying which side of MDP 0.2 sent a message
WORKER_HEADER = b"MDPW02"
CLIENT_HEADER = b"MDPC02"

# Single-byte command frames
READY = b"\001"
REQUEST = b"\002"
PARTIAL = b"\003"
FINAL = b"\004"
HEARTBEAT = b"\005"
DISCONNECT = b"\006"

# Timing defaults (seconds)
DEFAULT_HEARTBEAT_INTERVAL = 2.500
DEFAULT_HEARTBEAT_TIMEOUT = 10.000
DEFAULT_BUSY_WORKER_TIMEOUT = 900.000


class Message(object):
    """Majordomo message container

    Validates the header / command pair on construction and then simply
    exposes `client`, `header`, `command` and `message` as attributes.
    """

    ALLOWED_HEADERS = {WORKER_HEADER, CLIENT_HEADER}
    # Commands that each kind of peer is allowed to send
    ALLOWED_COMMANDS = {
        WORKER_HEADER: {READY, PARTIAL, FINAL, HEARTBEAT, DISCONNECT},
        CLIENT_HEADER: {REQUEST},
    }

    def __init__(self, client, header, command, message=None):
        # type: (bytes, bytes, bytes, Optional[List[bytes]]) -> None
        """Build a message after checking `header` and `command`.

        :param client: identity frame of the peer the message belongs to
        :param header: protocol header, one of `ALLOWED_HEADERS`
        :param command: command frame, must be allowed for `header`
        :param message: remaining body frames; defaults to a new empty list
        :raises ProtocolError: if the header is unknown or the command is not
            permitted for that header
        """
        # The offending frame is decoded with errors="replace" so that
        # arbitrary bytes still produce a ProtocolError rather than a
        # UnicodeDecodeError raised while formatting the error text.
        if header not in self.ALLOWED_HEADERS:
            raise error.ProtocolError(
                "Unexpected protocol header: {}".format(
                    header.decode("utf8", "replace")
                )
            )

        if command not in self.ALLOWED_COMMANDS[header]:
            raise error.ProtocolError(
                "Unexpected command: {}".format(command.decode("utf8", "replace"))
            )

        if message is None:
            message = []

        self.client = client
        self.header = header
        self.command = command
        self.message = message


TextOrBytes = Union[Text, bytes]


def text_to_ascii_bytes(text: TextOrBytes) -> bytes:
    """Convert a text-or-bytes value to ASCII-encoded bytes

    If the input is already `bytes`, we simply return it as is

    :raises UnicodeEncodeError: if a `str` input contains non-ASCII characters
    """
    if isinstance(text, str):
        return text.encode("ascii", "strict")
    return text


def id_to_int(id_):
    # type: (Union[bytes, str]) -> int
    """Convert a ZMQ client ID to printable integer

    The elements of `id_` are folded big-endian, eight bits per element, so
    for `bytes` input the result equals `int.from_bytes(id_, "big")`. `str`
    input is accepted too; each character contributes its code point.
    An empty ID yields 0.
    """
    i = 0
    for c in id_:
        # Iterating bytes yields ints; iterating str yields 1-char strings
        if not isinstance(c, int):
            c = ord(c)
        i = (i << 8) + c
    return i
DEFAULT_ZMQ_LINGER = 2500


class Worker(object):
    """MDP 0.2 Worker implementation

    Connects to a broker over a ZMQ DEALER socket, announces one READY per
    service name, exchanges heartbeats and hands incoming REQUESTs to
    `handle_request`, whose return value is sent back as the FINAL reply.
    Subclasses are expected to override `handle_request`; the default
    implementation echoes the request. Usable as a context manager.
    """

    def __init__(
        self,
        broker_url: str,
        service_names: List[TextOrBytes],
        heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
        heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
        zmq_context=None,
        zmq_linger=DEFAULT_ZMQ_LINGER,
        handler_name: Optional[TextOrBytes] = None,
        handler_params: Optional[List[TextOrBytes]] = None,
        stop_event: mp.Event = None,
    ):
        """Store configuration; the socket is only created by `connect()`.

        :param broker_url: ZMQ endpoint of the broker
        :param service_names: services this worker registers for (one READY each)
        :param heartbeat_interval: seconds between HEARTBEATs sent to the broker
        :param heartbeat_timeout: seconds without any broker message after
            which the worker reconnects
        :param zmq_context: context to create the socket from; defaults to
            `zmq.Context.instance()`
        :param zmq_linger: LINGER option applied to the socket
        :param handler_name: optional extra frame appended to every READY,
            placed before `handler_params`
        :param handler_params: optional extra frames appended to every READY
        :param stop_event: event that ends the receive / serve loops when set;
            a private one is created when omitted
        """
        self.broker_url = broker_url
        self.service_names = [
            text_to_ascii_bytes(service_name) for service_name in service_names
        ]

        self.heartbeat_interval = heartbeat_interval
        self._socket = None  # type: zmq.Socket
        self._poller = None  # type: zmq.Poller
        self._zmq_context = zmq_context if zmq_context else zmq.Context.instance()
        self._linger = zmq_linger
        self._log = logging.getLogger(__name__)
        self._heartbeat_timeout = heartbeat_timeout
        self._last_broker_hb = 0.0
        self._last_sent_message = 0.0

        # Extra READY frames: [handler_name, *handler_params], ASCII-encoded
        ready_frames = [handler_name] if handler_name is not None else []
        ready_frames += list(handler_params or [])
        self.ready_params = [text_to_ascii_bytes(rp) for rp in ready_frames]

        self.stop_event = stop_event or mp.Event()

    def connect(self, reconnect=False):
        # type: (bool) -> None
        """Open the socket, register it for polling and send READY.

        No-op when already connected unless `reconnect` is true, in which case
        the current socket is dropped first.
        """
        if self.is_connected():
            if not reconnect:
                return
            self._disconnect()

        # Set up socket
        self._socket = self._zmq_context.socket(zmq.DEALER)
        self._socket.setsockopt(zmq.LINGER, self._linger)
        self._socket.connect(self.broker_url)
        self._log.debug(
            "Connected to broker on ZMQ DEALER socket at %s", self.broker_url
        )

        self._poller = zmq.Poller()
        self._poller.register(self._socket, zmq.POLLIN)

        self._send_ready()
        # Start the broker-liveness clock from the moment we connected
        self._last_broker_hb = time.time()

    def wait_for_request(self):
        # type: () -> Tuple[bytes, List[bytes]]
        """Wait for a REQUEST command from the broker and return the client address and message body frames.

        Will internally handle timeouts, heartbeats and check for protocol errors and disconnect commands.

        :raises Disconnected: on DISCONNECT (also produced when `stop_event` is set)
        :raises ProtocolError: on an unexpected command or a too-short REQUEST
        """
        command, frames = self._receive()

        if command == protocol.DISCONNECT:
            self._log.debug("Got DISCONNECT from broker; Disconnecting")
            self._disconnect()
            raise error.Disconnected("Disconnected on message from broker")

        elif command != protocol.REQUEST:
            raise error.ProtocolError(
                "Unexpected message type from broker: {}".format(command.decode("utf8"))
            )

        if len(frames) < 3:
            raise error.ProtocolError(
                "Unexpected REQUEST message size, got {} frames, expecting at least 3".format(
                    len(frames)
                )
            )

        # frames = [client address, <skipped frame>, *request body]
        client_addr = frames[0]
        request = frames[2:]
        return client_addr, request

    def send_reply_final(self, client, frames):
        # type: (bytes, List[bytes]) -> None
        """Send final reply to client

        FINAL reply means the client will not expect any additional parts to the reply. This should be used
        when the entire reply is ready to be delivered.
        """
        self._send_to_client(client, protocol.FINAL, *frames)

    def send_reply_partial(self, client, frames):
        # type: (bytes, List[bytes]) -> None
        """Send the given set of frames as a partial reply to client

        PARTIAL reply means the client will expect zero or more additional PARTIAL reply messages following
        this one, with exactly one terminating FINAL reply following. This should be used if parts of the
        reply are ready to be sent, and the client is capable of processing them while the worker is still
        at work on the rest of the reply.
        """
        self._send_to_client(client, protocol.PARTIAL, *frames)

    def send_reply_from_iterable(self, client, frames_iter, final=None):
        # type: (bytes, Iterable[List[bytes]], List[bytes]) -> None
        """Send multiple partial replies from an iterator as PARTIAL replies to client.

        If `final` is provided, it will be sent as the FINAL reply after all PARTIAL replies are sent.
        """
        for part in frames_iter:
            self.send_reply_partial(client, part)
        if final:
            self.send_reply_final(client, final)

    def close(self):
        """Send DISCONNECT to the broker and close the socket, if connected."""
        if not self.is_connected():
            return
        self._send_disconnect()
        self._disconnect()

    def is_connected(self):
        """Tell whether a socket is currently open."""
        return self._socket is not None

    def _disconnect(self):
        """Close the socket without notifying the broker."""
        if not self.is_connected():
            return
        self._socket.disconnect(self.broker_url)
        self._socket.close()
        self._socket = None
        # NOTE(review): back-dates the last-sent timestamp by one heartbeat
        # interval; the reason is not evident from this module - confirm.
        self._last_sent_message -= self.heartbeat_interval

    def _receive(self):
        # type: () -> Tuple[bytes, List[bytes]]
        """Poll on the socket until a command is received

        Will handle timeouts and heartbeats internally without returning

        :returns: `(command, frames)` for the first non-HEARTBEAT message, or
            `(DISCONNECT, [])` once `stop_event` is set
        :raises Disconnected: if the socket is closed while waiting
        """
        while not self.stop_event.is_set():
            if self._socket is None:
                raise error.Disconnected("Worker is disconnected")

            self._check_send_heartbeat()
            poll_timeout = self._get_poll_timeout()

            try:
                socks = dict(self._poller.poll(timeout=poll_timeout))
            except zmq.error.ZMQError:
                # Probably connection was explicitly closed
                if self._socket is None:
                    continue
                raise

            if socks.get(self._socket) == zmq.POLLIN:
                message = self._socket.recv_multipart()
                self._log.debug("Got message of %d frames", len(message))
            else:
                self._log.debug("Receive timed out after %d ms", poll_timeout)
                if (time.time() - self._last_broker_hb) > self._heartbeat_timeout:
                    # We're not connected anymore?
                    self._log.info(
                        "Got no heartbeat in %d sec, disconnecting and reconnecting socket",
                        self._heartbeat_timeout,
                    )
                    self.connect(reconnect=True)
                continue

            command, frames = self._verify_message(message)
            # Any valid message from the broker counts as a sign of life
            self._last_broker_hb = time.time()

            if command == protocol.HEARTBEAT:
                self._log.debug("Got heartbeat message from broker")
                continue

            return command, frames

        # NOTE(review): a local stop is reported as DISCONNECT, so
        # wait_for_request() closes the socket without sending DISCONNECT to
        # the broker - confirm this is intended.
        return protocol.DISCONNECT, []

    def _send_ready(self):
        """Announce every configured service to the broker."""
        for service_name in self.service_names:
            self._send(protocol.READY, service_name, *self.ready_params)

    def handle_request(self, client_id: bytes, request: List[bytes]) -> List[bytes]:
        """Produce the reply frames for one request.

        Default implementation echoes the request back; override in
        subclasses. `start()` sends the returned frames as the FINAL reply.
        """
        return request

    def _send_disconnect(self):
        self._send(protocol.DISCONNECT)

    def _check_send_heartbeat(self):
        """Send a HEARTBEAT if nothing was sent for a full heartbeat interval."""
        if time.time() - self._last_sent_message >= self.heartbeat_interval:
            self._log.debug("Sending HEARTBEAT to broker")
            self._send(protocol.HEARTBEAT)

    def _send_to_client(self, client, message_type, *frames):
        self._send(message_type, client, b"", *frames)

    def _send(self, message_type, *args):
        # type: (bytes, *bytes) -> None
        """Send one worker message and record the time it was sent."""
        self._socket.send_multipart((b"", protocol.WORKER_HEADER, message_type) + args)
        self._last_sent_message = time.time()

    def _get_poll_timeout(self):
        # type: () -> int
        """Return the poll timeout for the current iteration in milliseconds

        This is the time left until the next heartbeat is due
        (`last sent + heartbeat_interval - now`), never negative, so the
        receive loop wakes up in time to send it.
        """
        remaining = self._last_sent_message + self.heartbeat_interval - time.time()
        return max(0, int(remaining * 1000))

    @staticmethod
    def _verify_message(message):
        # type: (List[bytes]) -> Tuple[bytes, List[bytes]]
        """Validate a message received from the broker.

        Expects `[b"", WORKER_HEADER, HEARTBEAT|REQUEST|DISCONNECT, *frames]`.
        The leading empty frame is removed from `message` in place.

        :returns: the command frame and the remaining frames
        :raises ProtocolError: on any deviation from the expected layout
        """
        if len(message) < 3:
            raise error.ProtocolError(
                "Unexpected message length, expecting at least 3 frames, got {}".format(
                    len(message)
                )
            )

        if message.pop(0) != b"":
            raise error.ProtocolError("Expecting first message frame to be empty")

        if message[0] != protocol.WORKER_HEADER:
            # Lenient decode: arbitrary bytes must still yield a ProtocolError
            raise error.ProtocolError(
                "Unexpected protocol header [{}], expecting [{}]".format(
                    message[0].decode("utf8", "replace"),
                    protocol.WORKER_HEADER.decode("utf8"),
                )
            )

        if message[1] not in {
            protocol.DISCONNECT,
            protocol.HEARTBEAT,
            protocol.REQUEST,
        }:
            raise error.ProtocolError(
                "Unexpected message type [{}], expecting either HEARTBEAT, REQUEST or "
                "DISCONNECT".format(message[1].decode("utf8", "replace"))
            )

        return message[1], message[2:]

    def __enter__(self):
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def start(self):
        """Serve requests until stopped or disconnected.

        Connects if needed, installs SIGINT / SIGTERM handlers that call
        `stop()`, then loops: wait for a request, pass it to
        `handle_request`, send the result as the FINAL reply. Protocol errors
        are logged and the offending message dropped; a disconnect ends the
        loop. The worker is closed before returning.
        """
        if not self.is_connected():
            self.connect()

        def signal_handler(sig_num, _):
            self.stop()

        for sig_num in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig_num, signal_handler)

        while not self.stop_event.is_set():
            try:
                client_id, request = self.wait_for_request()
                reply = self.handle_request(client_id, request)
                self.send_reply_final(client_id, reply)
            except error.ProtocolError as e:
                self._log.warning("Protocol error: %s, dropping request", str(e))
                continue
            except error.Disconnected:
                self._log.info("Worker disconnected")
                break

        self.close()

    def stop(self):
        """Ask the receive / serve loops to finish by setting `stop_event`."""
        self.stop_event.set()
[tensor_input] return body - queue_broker = QueueBroker(bind=[ipc_address, tcp_address]) + queue_broker = Broker(bind=[ipc_address, tcp_address]) queue_broker.register_request_handler("DETECT_NO_SHM", detect_no_shm) queue_broker.start() diff --git a/requirements-wheels.txt b/requirements-wheels.txt index ba81ffd13..232100f3f 100644 --- a/requirements-wheels.txt +++ b/requirements-wheels.txt @@ -13,7 +13,6 @@ pydantic == 1.10.* PyYAML == 6.0 pytz == 2022.6 pyzmq == 24.0.1 -majortomo == 0.2.0 tzlocal == 4.2 types-PyYAML == 6.0.* requests == 2.28.*