mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-03 09:45:22 +03:00
fixes for shm between containers
This commit is contained in:
parent
75f9dc09cb
commit
800105fbff
@ -4,7 +4,7 @@
|
||||
|
||||
set -o errexit -o nounset -o pipefail
|
||||
|
||||
dirs=(/dev/shm/logs/frigate /dev/shm/logs/go2rtc /dev/shm/logs/nginx)
|
||||
dirs=("/dev/shm/$HOSTNAME/logs/frigate" /dev/shm/$HOSTNAME/logs/go2rtc /dev/shm/$HOSTNAME/logs/nginx)
|
||||
|
||||
mkdir -p "${dirs[@]}"
|
||||
chown nobody:nogroup "${dirs[@]}"
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
#!/command/with-contenv bash
|
||||
# shellcheck shell=bash
|
||||
|
||||
exec logutil-service /dev/shm/logs/frigate
|
||||
exec logutil-service /dev/shm/$HOSTNAME/logs/frigate
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
#!/command/with-contenv bash
|
||||
# shellcheck shell=bash
|
||||
|
||||
exec logutil-service /dev/shm/logs/go2rtc
|
||||
exec logutil-service /dev/shm/$HOSTNAME/logs/go2rtc
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
#!/command/with-contenv bash
|
||||
# shellcheck shell=bash
|
||||
|
||||
exec logutil-service /dev/shm/logs/nginx
|
||||
exec logutil-service /dev/shm/$HOSTNAME/logs/nginx
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
from multiprocessing.queues import Queue
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from multiprocessing.synchronize import Event as MpEvent
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from typing import Optional, Any
|
||||
from typing import Optional
|
||||
from types import FrameType
|
||||
|
||||
import traceback
|
||||
@ -23,7 +23,7 @@ from frigate.detectors import ObjectDetectProcess
|
||||
from frigate.events import EventCleanup, EventProcessor
|
||||
from frigate.http import create_app
|
||||
from frigate.log import log_process, root_configurer
|
||||
from frigate.majordomo import QueueBroker, BrokerWorker
|
||||
from frigate.majortomo import Broker, ServiceWorker
|
||||
from frigate.models import Event, Recordings
|
||||
from frigate.object_processing import TrackedObjectProcessor
|
||||
from frigate.output import output_frames
|
||||
@ -45,7 +45,7 @@ class FrigateApp:
|
||||
self.stop_event: MpEvent = mp.Event()
|
||||
self.detectors: dict[str, ObjectDetectProcess] = {}
|
||||
self.detection_out_events: dict[str, MpEvent] = {}
|
||||
self.detection_shms: dict[str, mp.shared_memory.SharedMemory] = {}
|
||||
self.detection_shms: dict[str, SharedMemory] = {}
|
||||
self.log_queue: Queue = mp.Queue()
|
||||
self.plus_api = PlusApi()
|
||||
self.camera_metrics: dict[str, CameraMetricsTypes] = {}
|
||||
@ -187,15 +187,15 @@ class FrigateApp:
|
||||
|
||||
def start_queue_broker(self) -> None:
|
||||
def detect_no_shm(
|
||||
worker: BrokerWorker, service_name: bytes, body: list[bytes]
|
||||
worker: ServiceWorker, service_name: bytes, body: list[bytes]
|
||||
) -> list[bytes]:
|
||||
in_shm = self.detection_shms[str(service_name, "ascii")]
|
||||
tensor_input = in_shm.buf
|
||||
body = body[0:2] + [tensor_input]
|
||||
return body
|
||||
|
||||
bind_urls = [self.config.broker.ipc] + self.config.broker.addresses
|
||||
self.queue_broker = QueueBroker(bind=bind_urls)
|
||||
bind_urls = [self.config.server.ipc] + self.config.server.addresses
|
||||
self.queue_broker = Broker(bind=bind_urls)
|
||||
self.queue_broker.register_request_handler("DETECT_NO_SHM", detect_no_shm)
|
||||
self.queue_broker.start()
|
||||
|
||||
@ -208,20 +208,18 @@ class FrigateApp:
|
||||
for (name, det) in self.config.detectors.items()
|
||||
]
|
||||
)
|
||||
shm_in = mp.shared_memory.SharedMemory(
|
||||
shm_in = SharedMemory(
|
||||
name=name,
|
||||
create=True,
|
||||
size=largest_frame,
|
||||
)
|
||||
except FileExistsError:
|
||||
shm_in = mp.shared_memory.SharedMemory(name=name)
|
||||
shm_in = SharedMemory(name=name)
|
||||
|
||||
try:
|
||||
shm_out = mp.shared_memory.SharedMemory(
|
||||
name=f"out-{name}", create=True, size=20 * 6 * 4
|
||||
)
|
||||
shm_out = SharedMemory(name=f"out-{name}", create=True, size=20 * 6 * 4)
|
||||
except FileExistsError:
|
||||
shm_out = mp.shared_memory.SharedMemory(name=f"out-{name}")
|
||||
shm_out = SharedMemory(name=f"out-{name}")
|
||||
|
||||
self.detection_shms[name] = shm_in
|
||||
self.detection_shms[f"out-{name}"] = shm_out
|
||||
@ -274,6 +272,7 @@ class FrigateApp:
|
||||
camera_name,
|
||||
config,
|
||||
self.config.model,
|
||||
self.config.server,
|
||||
self.config.model.merged_labelmap,
|
||||
self.detected_frames_queue,
|
||||
self.camera_metrics[camera_name],
|
||||
@ -377,6 +376,7 @@ class FrigateApp:
|
||||
self.ensure_dirs()
|
||||
|
||||
if self.config.server.mode == ServerModeEnum.DetectionOnly:
|
||||
logger.info("Starting server in detection only mode.")
|
||||
self.start_detectors()
|
||||
self.start_watchdog()
|
||||
self.stop_event.wait()
|
||||
|
||||
@ -819,7 +819,7 @@ class ServerConfig(BaseModel):
|
||||
mode: ServerModeEnum = Field(default=ServerModeEnum.Full, title="Server mode")
|
||||
ipc: str = Field(default="ipc://queue_broker.ipc", title="Broker IPC path")
|
||||
addresses: List[str] = Field(
|
||||
default=["tcp://127.0.0.1:5555"], title="Broker TCP addresses"
|
||||
default=["tcp://0.0.0.0:5555"], title="Broker TCP addresses"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
import multiprocessing as mp
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from majortomo import Client
|
||||
from frigate.majortomo import Client
|
||||
from frigate.util import EventsPerSecond
|
||||
from .detector_config import ModelConfig
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ import multiprocessing as mp
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from typing import List
|
||||
|
||||
from frigate.majordomo import QueueWorker
|
||||
from frigate.majortomo import Worker
|
||||
from frigate.util import listen, EventsPerSecond, load_labels
|
||||
from .detector_config import InputTensorEnum, BaseDetectorConfig
|
||||
from .detector_types import create_detector
|
||||
@ -17,7 +17,7 @@ from setproctitle import setproctitle
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ObjectDetectionWorker(QueueWorker):
|
||||
class ObjectDetectionWorker(Worker):
|
||||
def __init__(
|
||||
self,
|
||||
detector_name: str,
|
||||
@ -94,7 +94,9 @@ class ObjectDetectionWorker(QueueWorker):
|
||||
if out_np is None:
|
||||
out_shm = self.detection_shms.get(f"out-{camera_name}", None)
|
||||
if out_shm is None:
|
||||
out_shm = SharedMemory(name=f"out-{camera_name}", create=False)
|
||||
out_shm = self.detection_shms[f"out-{camera_name}"] = SharedMemory(
|
||||
name=f"out-{camera_name}", create=False
|
||||
)
|
||||
out_np = self.detection_outputs[camera_name] = np.ndarray(
|
||||
(20, 6), dtype=np.float32, buffer=out_shm.buf
|
||||
)
|
||||
@ -183,7 +185,7 @@ class ObjectDetectProcess:
|
||||
|
||||
self.avg_inference_speed = mp.Value("d", 0.01)
|
||||
self.detection_start = mp.Value("d", 0.0)
|
||||
self.detect_process: mp.Process
|
||||
self.detect_process: mp.Process = None
|
||||
|
||||
self.start_or_restart()
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ from .detection_api import DetectionApi
|
||||
from . import plugins
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger("frigate.detectors")
|
||||
|
||||
|
||||
class StrEnum(str, Enum):
|
||||
|
||||
@ -1233,10 +1233,11 @@ def vainfo():
|
||||
|
||||
@bp.route("/logs/<service>", methods=["GET"])
|
||||
def logs(service: str):
|
||||
HOSTNAME = os.environ.get("HOSTNAME", "frigate")
|
||||
log_locations = {
|
||||
"frigate": "/dev/shm/logs/frigate/current",
|
||||
"go2rtc": "/dev/shm/logs/go2rtc/current",
|
||||
"nginx": "/dev/shm/logs/nginx/current",
|
||||
"frigate": f"/dev/shm/{HOSTNAME}/logs/frigate/current",
|
||||
"go2rtc": f"/dev/shm/{HOSTNAME}/logs/go2rtc/current",
|
||||
"nginx": f"/dev/shm/{HOSTNAME}/logs/nginx/current",
|
||||
}
|
||||
service_location = log_locations.get(service)
|
||||
|
||||
|
||||
@ -1,315 +0,0 @@
|
||||
"""Majordomo / `MDP/0.2 <https://rfc.zeromq.org/spec:18/MDP/>`_ broker implementation.
|
||||
|
||||
Extends the implementation of https://github.com/shoppimon/majortomo
|
||||
"""
|
||||
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
import zmq
|
||||
import multiprocessing as mp
|
||||
from multiprocessing.shared_memory import SharedMemory
|
||||
from typing import (
|
||||
DefaultDict,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
) # noqa: F401
|
||||
from majortomo import error, protocol
|
||||
from majortomo.config import DEFAULT_BIND_URL
|
||||
from majortomo.broker import (
|
||||
Broker,
|
||||
Worker as MdpBrokerWorker,
|
||||
ServicesContainer,
|
||||
id_to_int,
|
||||
)
|
||||
from majortomo.util import TextOrBytes, text_to_ascii_bytes
|
||||
from majortomo.worker import DEFAULT_ZMQ_LINGER, Worker as MdpWorker
|
||||
|
||||
|
||||
class BrokerWorker(MdpBrokerWorker):
|
||||
"""Worker objects represent a connected / known MDP worker process"""
|
||||
|
||||
def __init__(
|
||||
self, worker_id: bytes, service: bytes, expire_at: float, next_heartbeat: float
|
||||
):
|
||||
super().__init__(worker_id, service, expire_at, next_heartbeat)
|
||||
self.request_handler: str
|
||||
self.request_params: list[str] = []
|
||||
|
||||
|
||||
class QueueServicesContainer(ServicesContainer):
|
||||
def __init__(self, busy_workers_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT):
|
||||
super().__init__(busy_workers_timeout)
|
||||
|
||||
def dequeue_pending(self):
|
||||
# type: () -> Generator[Tuple[List[bytes], BrokerWorker, bytes], None, None]
|
||||
"""Pop ready message-worker pairs from all service queues so they can be dispatched"""
|
||||
for service_name, service in self._services.items():
|
||||
for message, worker, client in service.dequeue_pending():
|
||||
yield message, worker, client, service_name
|
||||
|
||||
def add_worker(
|
||||
self, worker_id: bytes, service: bytes, expire_at: float, next_heartbeat: float
|
||||
):
|
||||
"""Add a worker to the list of available workers"""
|
||||
if worker_id in self._workers:
|
||||
worker = self._workers[worker_id]
|
||||
if service in worker.service:
|
||||
raise error.StateError(
|
||||
f"Worker '{id_to_int(worker_id)}' has already sent READY message for service '{service.decode('ascii')}'"
|
||||
)
|
||||
else:
|
||||
worker.service.add(service)
|
||||
else:
|
||||
worker = BrokerWorker(worker_id, set([service]), expire_at, next_heartbeat)
|
||||
self._workers[worker.id] = worker
|
||||
self._services[service].add_worker(worker)
|
||||
return worker
|
||||
|
||||
def set_worker_available(self, worker_id, expire_at, next_heartbeat):
|
||||
# type: (bytes, float, float) -> BrokerWorker
|
||||
"""Mark a worker that was busy processing a request as back and available again"""
|
||||
if worker_id not in self._busy_workers:
|
||||
raise error.StateError(
|
||||
"Worker id {} is not previously known or has expired".format(
|
||||
id_to_int(worker_id)
|
||||
)
|
||||
)
|
||||
worker = self._busy_workers.pop(worker_id)
|
||||
worker.is_busy = False
|
||||
worker.expire_at = expire_at
|
||||
worker.next_heartbeat = next_heartbeat
|
||||
self._workers[worker_id] = worker
|
||||
for service in worker.service:
|
||||
self._services[service].add_worker(worker)
|
||||
return worker
|
||||
|
||||
def remove_worker(self, worker_id):
|
||||
# type: (bytes) -> None
|
||||
"""Remove a worker from the list of known workers"""
|
||||
try:
|
||||
worker = self._workers[worker_id]
|
||||
for service_name in worker.service:
|
||||
service = self._services[service_name]
|
||||
if worker_id in service._workers:
|
||||
service.remove_worker(worker_id)
|
||||
del self._workers[worker_id]
|
||||
except KeyError:
|
||||
try:
|
||||
del self._busy_workers[worker_id]
|
||||
except KeyError:
|
||||
raise error.StateError(f"Worker id {id_to_int(worker_id)} is not known")
|
||||
|
||||
|
||||
class QueueBroker(Broker):
|
||||
def __init__(
|
||||
self,
|
||||
bind: Union[str, List[str]] = [DEFAULT_BIND_URL],
|
||||
heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
|
||||
heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
|
||||
busy_worker_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT,
|
||||
zmq_context=None,
|
||||
):
|
||||
super().__init__(
|
||||
heartbeat_interval=heartbeat_interval,
|
||||
heartbeat_timeout=heartbeat_timeout,
|
||||
busy_worker_timeout=busy_worker_timeout,
|
||||
zmq_context=zmq_context,
|
||||
)
|
||||
self._services = QueueServicesContainer(busy_worker_timeout)
|
||||
self._bind_urls = [bind] if not isinstance(bind, list) else bind
|
||||
self.broker_thread: threading.Thread = None
|
||||
self.request_handlers: Dict[str, object] = {}
|
||||
|
||||
def bind(self):
|
||||
"""Bind the ZMQ socket"""
|
||||
if self._socket:
|
||||
raise error.StateError("Socket is already bound")
|
||||
|
||||
self._socket = self._context.socket(zmq.ROUTER)
|
||||
self._socket.rcvtimeo = int(self._heartbeat_interval * 1000)
|
||||
for bind_url in self._bind_urls:
|
||||
self._socket.bind(bind_url)
|
||||
self._log.info("Broker listening on %s", bind_url)
|
||||
|
||||
def close(self):
|
||||
if self._socket is None:
|
||||
return
|
||||
for bind_url in self._bind_urls:
|
||||
self._socket.disconnect(bind_url)
|
||||
self._socket.close()
|
||||
self._socket = None
|
||||
self._log.info("Broker socket closing")
|
||||
|
||||
def start(self):
|
||||
self.broker_thread = threading.Thread(target=self.run)
|
||||
self.broker_thread.start()
|
||||
|
||||
def stop(self):
|
||||
super().stop()
|
||||
if self.broker_thread is not None:
|
||||
self.broker_thread.join()
|
||||
self.broker_thread = None
|
||||
|
||||
def _handle_worker_message(self, message):
|
||||
if message.command == protocol.READY:
|
||||
self._handle_worker_ready(message)
|
||||
else:
|
||||
super()._handle_worker_message(message)
|
||||
|
||||
def _handle_worker_ready(self, message):
|
||||
# type: (protocol.Message) -> BrokerWorker
|
||||
worker_id = message.client
|
||||
service = message.message[0]
|
||||
self._log.info("Got READY from worker: %d", id_to_int(worker_id))
|
||||
now = time.time()
|
||||
expire_at = now + self._heartbeat_timeout
|
||||
next_heartbeat = now + self._heartbeat_interval
|
||||
worker = self._services.add_worker(
|
||||
worker_id, service, expire_at, next_heartbeat
|
||||
)
|
||||
worker.request_handler = None
|
||||
worker.request_params = []
|
||||
if len(message.message) > 1:
|
||||
worker.request_handler = str(message.message[1], "ascii")
|
||||
if len(message.message) > 2:
|
||||
worker.request_params = [str(m, "ascii") for m in message.message[2:]]
|
||||
return worker
|
||||
|
||||
def _dispatch_queued_messages(self):
|
||||
"""Dispatch all queued messages to available workers"""
|
||||
expire_at = time.time() + self._busy_worker_timeout
|
||||
for message, worker, client, service_name in self._services.dequeue_pending():
|
||||
body = [client, b""] + message
|
||||
body = self.on_worker_request(worker, service_name, body)
|
||||
self._send_to_worker(worker.id, protocol.REQUEST, body)
|
||||
self._services.set_worker_busy(worker.id, expire_at=expire_at)
|
||||
|
||||
def register_request_handler(self, handler_name: str, handler):
|
||||
self.request_handlers[handler_name] = handler
|
||||
|
||||
def on_worker_request(
|
||||
self, worker: BrokerWorker, service_name: bytes, body: List[bytes]
|
||||
):
|
||||
if worker.request_handler in self.request_handlers:
|
||||
handler = self.request_handlers[worker.request_handler]
|
||||
body = handler(worker, service_name, body)
|
||||
return body
|
||||
|
||||
|
||||
class QueueWorker(MdpWorker):
|
||||
def __init__(
|
||||
self,
|
||||
broker_url: str,
|
||||
service_names: List[TextOrBytes],
|
||||
heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
|
||||
heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
|
||||
zmq_context=None,
|
||||
zmq_linger=DEFAULT_ZMQ_LINGER,
|
||||
handler_name: TextOrBytes = None,
|
||||
handler_params: List[TextOrBytes] = [],
|
||||
stop_event: mp.Event = None,
|
||||
):
|
||||
super().__init__(
|
||||
broker_url,
|
||||
b"",
|
||||
heartbeat_interval,
|
||||
heartbeat_timeout,
|
||||
zmq_context,
|
||||
zmq_linger,
|
||||
)
|
||||
|
||||
self.service_names = [
|
||||
text_to_ascii_bytes(service_name) for service_name in service_names
|
||||
]
|
||||
self.ready_params = [
|
||||
text_to_ascii_bytes(rp)
|
||||
for rp in (
|
||||
([handler_name] if handler_name is not None else []) + handler_params
|
||||
)
|
||||
]
|
||||
self.stop_event = stop_event or mp.Event()
|
||||
|
||||
def _send_ready(self):
|
||||
for service_name in self.service_names:
|
||||
self._send(protocol.READY, service_name, *self.ready_params)
|
||||
|
||||
def handle_request(client_id: bytes, request: List[bytes]) -> List[bytes]:
|
||||
return request
|
||||
|
||||
def start(self):
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
def signal_handler(sig_num, _):
|
||||
self.stop()
|
||||
|
||||
for sig_num in (signal.SIGINT, signal.SIGTERM):
|
||||
signal.signal(sig_num, signal_handler)
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
client_id, request = self.wait_for_request()
|
||||
reply = self.handle_request(client_id, request)
|
||||
self.send_reply_final(reply)
|
||||
except error.ProtocolError as e:
|
||||
self._log.warning("Protocol error: %s, dropping request", str(e))
|
||||
continue
|
||||
except error.Disconnected:
|
||||
self._log.info("Worker disconnected")
|
||||
break
|
||||
|
||||
self.close()
|
||||
|
||||
def stop(self):
|
||||
self.stop_event.set()
|
||||
|
||||
def _receive(self):
|
||||
# type: () -> Tuple[bytes, List[bytes]]
|
||||
"""Poll on the socket until a command is received
|
||||
|
||||
Will handle timeouts and heartbeats internally without returning
|
||||
"""
|
||||
while not self.stop_event.is_set():
|
||||
if self._socket is None:
|
||||
raise error.Disconnected("Worker is disconnected")
|
||||
|
||||
self._check_send_heartbeat()
|
||||
poll_timeout = self._get_poll_timeout()
|
||||
|
||||
try:
|
||||
socks = dict(self._poller.poll(timeout=poll_timeout))
|
||||
except zmq.error.ZMQError:
|
||||
# Probably connection was explicitly closed
|
||||
if self._socket is None:
|
||||
continue
|
||||
raise
|
||||
|
||||
if socks.get(self._socket) == zmq.POLLIN:
|
||||
message = self._socket.recv_multipart()
|
||||
self._log.debug("Got message of %d frames", len(message))
|
||||
else:
|
||||
self._log.debug("Receive timed out after %d ms", poll_timeout)
|
||||
if (time.time() - self._last_broker_hb) > self._heartbeat_timeout:
|
||||
# We're not connected anymore?
|
||||
self._log.info(
|
||||
"Got no heartbeat in %d sec, disconnecting and reconnecting socket",
|
||||
self._heartbeat_timeout,
|
||||
)
|
||||
self.connect(reconnect=True)
|
||||
continue
|
||||
|
||||
command, frames = self._verify_message(message)
|
||||
self._last_broker_hb = time.time()
|
||||
|
||||
if command == protocol.HEARTBEAT:
|
||||
self._log.debug("Got heartbeat message from broker")
|
||||
continue
|
||||
|
||||
return command, frames
|
||||
|
||||
return protocol.DISCONNECT, []
|
||||
19
frigate/majortomo/__init__.py
Normal file
19
frigate/majortomo/__init__.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Copyright (c) 2018 Shoppimon LTD
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .client import Client
|
||||
from .worker import Worker
|
||||
from .broker import Broker, ServiceWorker
|
||||
|
||||
__all__ = ["Client", "Worker", "Broker", "ServiceWorker"]
|
||||
526
frigate/majortomo/broker.py
Normal file
526
frigate/majortomo/broker.py
Normal file
@ -0,0 +1,526 @@
|
||||
"""Majordomo / `MDP/0.2 <https://rfc.zeromq.org/spec:18/MDP/>`_ broker implementation."""
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict, defaultdict, deque
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
DefaultDict,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
) # noqa: F401
|
||||
|
||||
import zmq
|
||||
|
||||
from . import error, protocol
|
||||
from .util import id_to_int
|
||||
|
||||
DEFAULT_BIND_URL = "tcp://0.0.0.0:5555"
|
||||
|
||||
|
||||
class Broker:
|
||||
def __init__(
|
||||
self,
|
||||
bind: Union[str, List[str]] = [DEFAULT_BIND_URL],
|
||||
heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
|
||||
heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
|
||||
busy_worker_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT,
|
||||
zmq_context=None,
|
||||
):
|
||||
self._log = logging.getLogger(__name__)
|
||||
self._bind_urls = [bind] if not isinstance(bind, list) else bind
|
||||
self._heartbeat_interval = heartbeat_interval
|
||||
self._heartbeat_timeout = heartbeat_timeout
|
||||
self._busy_worker_timeout = busy_worker_timeout
|
||||
|
||||
self._context = zmq_context if zmq_context else zmq.Context.instance()
|
||||
self._socket = None # type: zmq.Socket
|
||||
self._services = ServicesContainer(busy_worker_timeout)
|
||||
self._stop = False
|
||||
|
||||
self.broker_thread: threading.Thread = None
|
||||
self.request_handlers: Dict[str, object] = {}
|
||||
|
||||
def run(self):
|
||||
"""Run in a main loop handling all incoming requests, until `stop()` is called"""
|
||||
self._log.info("MDP Broker starting up")
|
||||
self._stop = False
|
||||
self.bind()
|
||||
|
||||
try:
|
||||
while not self._stop:
|
||||
try:
|
||||
message = self.receive()
|
||||
if message:
|
||||
self.handle_message(message)
|
||||
except Exception:
|
||||
self._log.exception("Message handling failed")
|
||||
|
||||
finally:
|
||||
self.close()
|
||||
self._log.info("MDP Broker shutting down")
|
||||
|
||||
def start(self):
|
||||
self.broker_thread = threading.Thread(target=self.run)
|
||||
self.broker_thread.name = "zmq_majordomo_broker"
|
||||
self.broker_thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""Stop the broker's main loop"""
|
||||
self._stop = True
|
||||
if self.broker_thread is not None:
|
||||
self.broker_thread.join()
|
||||
self.broker_thread = None
|
||||
|
||||
def bind(self):
|
||||
"""Bind the ZMQ socket"""
|
||||
if self._socket:
|
||||
raise error.StateError("Socket is already bound")
|
||||
|
||||
self._socket = self._context.socket(zmq.ROUTER)
|
||||
self._socket.rcvtimeo = int(self._heartbeat_interval * 1000)
|
||||
for bind_url in self._bind_urls:
|
||||
self._socket.bind(bind_url)
|
||||
self._log.info("Broker listening on %s", bind_url)
|
||||
|
||||
def close(self):
|
||||
if self._socket is None:
|
||||
return
|
||||
for bind_url in self._bind_urls:
|
||||
self._socket.disconnect(bind_url)
|
||||
self._socket.close()
|
||||
self._socket = None
|
||||
self._log.info("Broker socket closing")
|
||||
|
||||
def receive(self):
|
||||
"""Run until a message is received"""
|
||||
while True:
|
||||
self._purge_expired_workers()
|
||||
self._send_heartbeat()
|
||||
try:
|
||||
frames = self._socket.recv_multipart()
|
||||
except zmq.error.Again:
|
||||
if self._socket is None or self._stop:
|
||||
self._log.debug(
|
||||
"Socket has been closed or broker has been shut down, breaking from recv loop"
|
||||
)
|
||||
break
|
||||
continue
|
||||
|
||||
self._log.debug("Got message of %d frames", len(frames))
|
||||
try:
|
||||
return self._parse_incoming_message(frames)
|
||||
except error.ProtocolError as e:
|
||||
self._log.warning(str(e))
|
||||
continue
|
||||
|
||||
def handle_message(self, message):
|
||||
# type: (protocol.Message) -> None
|
||||
"""Handle incoming message"""
|
||||
if message.header == protocol.WORKER_HEADER:
|
||||
self._handle_worker_message(message)
|
||||
elif message.header == protocol.CLIENT_HEADER:
|
||||
self._handle_client_message(message)
|
||||
else:
|
||||
raise error.ProtocolError(
|
||||
"Unexpected protocol header: {}".format(message.header.decode("utf8"))
|
||||
)
|
||||
|
||||
self._dispatch_queued_messages()
|
||||
|
||||
def _handle_worker_message(self, message):
|
||||
# type: (protocol.Message) -> None
|
||||
"""Handle message from a worker"""
|
||||
if message.command == protocol.HEARTBEAT:
|
||||
self._handle_worker_heartbeat(message.client)
|
||||
elif message.command == protocol.READY:
|
||||
self._handle_worker_ready(message)
|
||||
elif message.command == protocol.FINAL:
|
||||
self._handle_worker_final(
|
||||
message.client, message.message[0], message.message[2:]
|
||||
)
|
||||
elif message.command == protocol.PARTIAL:
|
||||
self._handle_worker_partial(
|
||||
message.client, message.message[0], message.message[2:]
|
||||
)
|
||||
elif message.command == protocol.DISCONNECT:
|
||||
self._handle_worker_disconnect(message.client)
|
||||
else:
|
||||
self._send_to_worker(message.client, protocol.DISCONNECT)
|
||||
raise error.ProtocolError(
|
||||
"Unexpected command from worker: {}".format(
|
||||
message.command.decode("utf8")
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_worker_heartbeat(self, worker_id):
|
||||
# type: (bytes) -> None
|
||||
"""Heartbeat from worker"""
|
||||
worker = self._services.get_worker(worker_id)
|
||||
if worker is None:
|
||||
if not self._services.is_busy_worker(worker_id):
|
||||
self._log.warning(
|
||||
"Got HEARTBEAT from unknown worker: %d", id_to_int(worker_id)
|
||||
)
|
||||
self._send_to_worker(worker_id, protocol.DISCONNECT)
|
||||
return
|
||||
|
||||
self._log.debug("Got HEARTBEAT from worker: %d", id_to_int(worker_id))
|
||||
worker.expire_at = time.time() + self._heartbeat_timeout
|
||||
|
||||
def _handle_worker_ready(self, message):
|
||||
# type: (protocol.Message) -> ServiceWorker
|
||||
worker_id = message.client
|
||||
service = message.message[0]
|
||||
self._log.debug("Got READY from worker: %d", id_to_int(worker_id))
|
||||
now = time.time()
|
||||
expire_at = now + self._heartbeat_timeout
|
||||
next_heartbeat = now + self._heartbeat_interval
|
||||
worker = self._services.add_worker(
|
||||
worker_id, service, expire_at, next_heartbeat
|
||||
)
|
||||
worker.request_handler = None
|
||||
worker.request_params = []
|
||||
if len(message.message) > 1:
|
||||
worker.request_handler = str(message.message[1], "ascii")
|
||||
if len(message.message) > 2:
|
||||
worker.request_params = [str(m, "ascii") for m in message.message[2:]]
|
||||
return worker
|
||||
|
||||
def _handle_worker_partial(self, worker_id, client_id, body):
|
||||
# type: (bytes, bytes, List[bytes]) -> None
|
||||
self._log.debug(
|
||||
"Got PARTIAL from worker: %d to client: %d",
|
||||
id_to_int(worker_id),
|
||||
id_to_int(client_id),
|
||||
)
|
||||
self._send_to_client(client_id, protocol.PARTIAL, body)
|
||||
|
||||
def _handle_worker_final(self, worker_id, client_id, body):
|
||||
# type: (bytes, bytes, List[bytes]) -> None
|
||||
self._log.debug(
|
||||
"Got FINAL from worker: %d to client: %d",
|
||||
id_to_int(worker_id),
|
||||
id_to_int(client_id),
|
||||
)
|
||||
self._send_to_client(client_id, protocol.FINAL, body)
|
||||
now = time.time()
|
||||
self._services.set_worker_available(
|
||||
worker_id, now + self._heartbeat_timeout, now + self._heartbeat_interval
|
||||
)
|
||||
|
||||
def _handle_worker_disconnect(self, worker_id):
|
||||
# type: (bytes) -> None
|
||||
self._log.info("Got DISCONNECT from worker: %d", id_to_int(worker_id))
|
||||
try:
|
||||
self._services.remove_worker(worker_id)
|
||||
except KeyError:
|
||||
self._log.info(
|
||||
"Got DISCONNECT from unknown worker: %d; ignoring", id_to_int(worker_id)
|
||||
)
|
||||
|
||||
def _handle_client_message(self, message):
|
||||
# type: (protocol.Message) -> None
|
||||
"""Handle message from a client"""
|
||||
assert message.command == protocol.REQUEST
|
||||
if len(message.message) < 2:
|
||||
raise error.ProtocolError(
|
||||
"Client REQUEST message is expected to be at least 2 frames long, got {}".format(
|
||||
len(message.message)
|
||||
)
|
||||
)
|
||||
|
||||
service_name = message.message[0]
|
||||
body = message.message[1:]
|
||||
|
||||
# TODO: Plug-in MMA handling
|
||||
|
||||
self._log.debug(
|
||||
"Queueing client request from %d to %s",
|
||||
id_to_int(message.client),
|
||||
service_name.decode("ascii"),
|
||||
)
|
||||
self._services.queue_client_request(message.client, service_name, body)
|
||||
|
||||
def _dispatch_queued_messages(self):
|
||||
"""Dispatch all queued messages to available workers"""
|
||||
expire_at = time.time() + self._busy_worker_timeout
|
||||
for message, worker, client, service_name in self._services.dequeue_pending():
|
||||
body = [client, b""] + message
|
||||
body = self.on_worker_request(worker, service_name, body)
|
||||
self._send_to_worker(worker.id, protocol.REQUEST, body)
|
||||
self._services.set_worker_busy(worker.id, expire_at=expire_at)
|
||||
|
||||
def register_request_handler(self, handler_name: str, handler):
|
||||
self.request_handlers[handler_name] = handler
|
||||
|
||||
def on_worker_request(self, worker, service_name, body):
|
||||
# type: (ServiceWorker, bytes, List[bytes]) -> List[bytes]
|
||||
if worker.request_handler in self.request_handlers:
|
||||
handler = self.request_handlers[worker.request_handler]
|
||||
body = handler(worker, service_name, body)
|
||||
return body
|
||||
|
||||
def _purge_expired_workers(self):
|
||||
"""Purge expired workers from the services container"""
|
||||
for worker in list(
|
||||
self._services.expired_workers()
|
||||
): # Copying to list as we are going to mutate
|
||||
self._log.debug(
|
||||
"Worker %d timed out after %0.2f sec, purging",
|
||||
id_to_int(worker.id),
|
||||
self._busy_worker_timeout
|
||||
if worker.is_busy
|
||||
else self._heartbeat_timeout,
|
||||
)
|
||||
self._send_to_worker(worker.id, protocol.DISCONNECT)
|
||||
self._services.remove_worker(worker.id)
|
||||
|
||||
def _send_heartbeat(self):
|
||||
"""Send heartbeat to all workers that didn't get any messages recently"""
|
||||
now = time.time()
|
||||
for worker in self._services.heartbeat_workers():
|
||||
self._log.debug("Sending heartbeat to worker: %d", id_to_int(worker.id))
|
||||
self._send_to_worker(worker.id, protocol.HEARTBEAT)
|
||||
worker.next_heartbeat = now + self._heartbeat_interval
|
||||
|
||||
def _send_to_worker(self, worker_id, command, body=None):
|
||||
# type: (bytes, bytes, Optional[List[bytes]]) -> None
|
||||
"""Send message to worker"""
|
||||
if body is None:
|
||||
body = []
|
||||
self._socket.send_multipart(
|
||||
[worker_id, b"", protocol.WORKER_HEADER, command] + body
|
||||
)
|
||||
|
||||
def _send_to_client(self, client_id, command, body):
|
||||
# type: (bytes, bytes, List[bytes]) -> None
|
||||
"""Send message to client"""
|
||||
self._socket.send_multipart(
|
||||
[client_id, b"", protocol.CLIENT_HEADER, command] + body
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_incoming_message(frames):
|
||||
# type: (List[bytes]) -> protocol.Message
|
||||
"""Parse and verify incoming message"""
|
||||
if len(frames) < 4:
|
||||
raise error.ProtocolError(
|
||||
"Unexpected message length: expecting at least 4 frames, got {}".format(
|
||||
len(frames)
|
||||
)
|
||||
)
|
||||
|
||||
if frames[1] != b"":
|
||||
raise error.ProtocolError(
|
||||
"Expecting empty frame 1, got {} bytes".format(len(frames[1]))
|
||||
)
|
||||
|
||||
return protocol.Message(
|
||||
client=frames[0], header=frames[2], command=frames[3], message=frames[4:]
|
||||
)
|
||||
|
||||
|
||||
class ServiceWorker:
|
||||
"""ServiceWorker objects represent a connected / known MDP worker process"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker_id: bytes,
|
||||
service: set[bytes],
|
||||
expire_at: float,
|
||||
next_heartbeat: float,
|
||||
):
|
||||
self.id = worker_id
|
||||
self.service = service
|
||||
self.expire_at = expire_at
|
||||
self.next_heartbeat = next_heartbeat
|
||||
self.is_busy = False
|
||||
self.request_handler: str
|
||||
self.request_params: list[str] = []
|
||||
|
||||
def is_expired(self, now=None):
|
||||
# type: (Optional[float]) -> bool
|
||||
"""Check if worker is expired"""
|
||||
if now is None:
|
||||
now = time.time()
|
||||
return now >= self.expire_at
|
||||
|
||||
def is_heartbeat(self, now=None):
|
||||
# type: (Optional[float]) -> bool
|
||||
"""Check if worker is due for sending a heartbeat message"""
|
||||
if now is None:
|
||||
now = time.time()
|
||||
return now >= self.next_heartbeat
|
||||
|
||||
|
||||
class Service:
|
||||
"""Service objects manage all workers that can handle a specific service, as well as a queue of MDP Client
|
||||
requests to be handled by this service
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._queue = deque()
|
||||
self._workers = OrderedDict()
|
||||
|
||||
def queue_request(self, client, request_body):
|
||||
# type: (bytes, List[bytes]) -> None
|
||||
"""Queue a client request"""
|
||||
self._queue.append((client, request_body))
|
||||
|
||||
def add_worker(self, worker):
|
||||
# type: (ServiceWorker) -> None
|
||||
"""Add a ServiceWorker to the service"""
|
||||
self._workers[worker.id] = worker
|
||||
|
||||
def remove_worker(self, worker_id):
|
||||
# type: (bytes) -> None
|
||||
"""Remove a worker from the service"""
|
||||
if worker_id in self._workers:
|
||||
del self._workers[worker_id]
|
||||
|
||||
def dequeue_pending(self):
|
||||
# type: () -> Generator[Tuple[List[bytes], ServiceWorker, bytes], None, None]
|
||||
"""Dequeue pending workers and requests to be handled by them"""
|
||||
while len(self._queue) and len(self._workers):
|
||||
client, message = self._queue.popleft()
|
||||
_, worker = self._workers.popitem(last=False)
|
||||
yield message, worker, client
|
||||
|
||||
@property
|
||||
def queued_requests(self):
|
||||
# type: () -> int
|
||||
"""Number of queued requests for this service"""
|
||||
return len(self._queue)
|
||||
|
||||
@property
|
||||
def available_workers(self):
|
||||
# type: () -> int
|
||||
"""Number of available workers for this service"""
|
||||
return len(self._workers)
|
||||
|
||||
|
||||
class ServicesContainer:
|
||||
"""A container for all services managed by the broker"""
|
||||
|
||||
def __init__(self, busy_workers_timeout=protocol.DEFAULT_BUSY_WORKER_TIMEOUT):
|
||||
# type: (float) -> None
|
||||
self._busy_workers_timeout = busy_workers_timeout
|
||||
self._services = defaultdict(Service) # type: DefaultDict[bytes, Service]
|
||||
self._workers = OrderedDict() # type: OrderedDict[bytes, ServiceWorker]
|
||||
self._busy_workers = dict() # type: Dict[bytes, ServiceWorker]
|
||||
|
||||
def queue_client_request(self, client, service, body):
|
||||
# type: (bytes, bytes, List[bytes]) -> None
|
||||
"""Queue a request from a client"""
|
||||
self._services[service].queue_request(client, body)
|
||||
|
||||
def dequeue_pending(self):
|
||||
# type: () -> Generator[Tuple[List[bytes], ServiceWorker, bytes], None, None]
|
||||
"""Pop ready message-worker pairs from all service queues so they can be dispatched"""
|
||||
for service_name, service in self._services.items():
|
||||
for message, worker, client in service.dequeue_pending():
|
||||
for other_service in self._services.values():
|
||||
other_service.remove_worker(worker.id)
|
||||
yield message, worker, client, service_name
|
||||
|
||||
def get_worker(self, worker_id):
|
||||
# type: (bytes) -> Optional[ServiceWorker]
|
||||
"""Get a worker by ID if exists"""
|
||||
return self._workers.get(worker_id, None)
|
||||
|
||||
def add_worker(
|
||||
self, worker_id: bytes, service: bytes, expire_at: float, next_heartbeat: float
|
||||
):
|
||||
"""Add a worker to the list of available workers"""
|
||||
if worker_id in self._workers:
|
||||
worker = self._workers[worker_id]
|
||||
if service in worker.service:
|
||||
raise error.StateError(
|
||||
f"Worker '{id_to_int(worker_id)}' has already sent READY message for service '{service.decode('ascii')}'"
|
||||
)
|
||||
else:
|
||||
worker.service.add(service)
|
||||
else:
|
||||
worker = ServiceWorker(worker_id, set([service]), expire_at, next_heartbeat)
|
||||
self._workers[worker.id] = worker
|
||||
self._services[service].add_worker(worker)
|
||||
return worker
|
||||
|
||||
def set_worker_busy(self, worker_id, expire_at):
|
||||
# type: (bytes, float) -> ServiceWorker
|
||||
"""Mark a worker as busy - that is currently processing a request and not available for more work"""
|
||||
if worker_id not in self._workers:
|
||||
raise error.StateError(
|
||||
"Worker id {} is not in the list of available workers".format(
|
||||
id_to_int(worker_id)
|
||||
)
|
||||
)
|
||||
worker = self._workers.pop(worker_id)
|
||||
worker.is_busy = True
|
||||
worker.expire_at = expire_at
|
||||
self._busy_workers[worker.id] = worker
|
||||
return worker
|
||||
|
||||
def set_worker_available(self, worker_id, expire_at, next_heartbeat):
|
||||
# type: (bytes, float, float) -> ServiceWorker
|
||||
"""Mark a worker that was busy processing a request as back and available again"""
|
||||
if worker_id not in self._busy_workers:
|
||||
raise error.StateError(
|
||||
"Worker id {} is not previously known or has expired".format(
|
||||
id_to_int(worker_id)
|
||||
)
|
||||
)
|
||||
worker = self._busy_workers.pop(worker_id)
|
||||
worker.is_busy = False
|
||||
worker.expire_at = expire_at
|
||||
worker.next_heartbeat = next_heartbeat
|
||||
self._workers[worker_id] = worker
|
||||
for service in worker.service:
|
||||
self._services[service].add_worker(worker)
|
||||
return worker
|
||||
|
||||
def is_busy_worker(self, worker_id):
|
||||
# type: (bytes) -> bool
|
||||
"""Return True if the given worker_id is of a known busy worker"""
|
||||
return worker_id in self._busy_workers
|
||||
|
||||
def remove_worker(self, worker_id):
|
||||
# type: (bytes) -> None
|
||||
"""Remove a worker from the list of known workers"""
|
||||
try:
|
||||
worker = self._workers[worker_id]
|
||||
for service_name in worker.service:
|
||||
service = self._services[service_name]
|
||||
if worker_id in service._workers:
|
||||
service.remove_worker(worker_id)
|
||||
del self._workers[worker_id]
|
||||
except KeyError:
|
||||
try:
|
||||
del self._busy_workers[worker_id]
|
||||
except KeyError:
|
||||
raise error.StateError(f"Worker id {id_to_int(worker_id)} is not known")
|
||||
|
||||
def heartbeat_workers(self):
|
||||
# type: () -> Generator[ServiceWorker, None, None]
|
||||
"""Get iterator of workers waiting for a heartbeat"""
|
||||
now = time.time()
|
||||
return (w for w in self._workers.values() if w.is_heartbeat(now))
|
||||
|
||||
def expired_workers(self):
|
||||
# type: () -> Generator[ServiceWorker, None, None]
|
||||
"""Get iterator of workers that have expired (no heartbeat received in too long) and busy workers that
|
||||
have not returned in too long
|
||||
"""
|
||||
now = time.time()
|
||||
return (
|
||||
w
|
||||
for w in chain(self._workers.values(), self._busy_workers.values())
|
||||
if w.is_expired(now)
|
||||
)
|
||||
182
frigate/majortomo/client.py
Normal file
182
frigate/majortomo/client.py
Normal file
@ -0,0 +1,182 @@
|
||||
"""MDP 0.2 Client implementation"""
|
||||
|
||||
# Copyright (c) 2018 Shoppimon LTD
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Iterable, List, Optional, Tuple # noqa: F401
|
||||
|
||||
import zmq
|
||||
|
||||
from . import error as e
|
||||
from . import protocol as p
|
||||
from .util import TextOrBytes, text_to_ascii_bytes
|
||||
|
||||
DEFAULT_ZMQ_LINGER = 2000
|
||||
|
||||
|
||||
class Client(object):
|
||||
"""MDP 0.2 Client implementation
|
||||
|
||||
:ivar _socket: zmq.Socket
|
||||
:type _socket: zmq.Socket
|
||||
"""
|
||||
|
||||
def __init__(self, broker_url, zmq_context=None, zmq_linger=DEFAULT_ZMQ_LINGER):
|
||||
# type: (str, Optional[zmq.Context], int) -> None
|
||||
self.broker_url = broker_url
|
||||
self._socket = None # type: zmq.Socket
|
||||
self._zmq_context = zmq_context if zmq_context else zmq.Context.instance()
|
||||
self._linger = zmq_linger
|
||||
self._log = logging.getLogger(__name__)
|
||||
self._expect_reply = False
|
||||
|
||||
def connect(self, reconnect=False):
|
||||
# type: (bool) -> None
|
||||
if self.is_connected():
|
||||
if not reconnect:
|
||||
return
|
||||
self._disconnect()
|
||||
|
||||
# Set up socket
|
||||
self._socket = self._zmq_context.socket(zmq.DEALER)
|
||||
self._socket.setsockopt(zmq.LINGER, self._linger)
|
||||
self._socket.connect(self.broker_url)
|
||||
self._log.debug(
|
||||
"Connected to broker on ZMQ DEALER socket at %s", self.broker_url
|
||||
)
|
||||
self._expect_reply = False
|
||||
|
||||
def close(self):
|
||||
if not self.is_connected():
|
||||
return
|
||||
self._disconnect()
|
||||
|
||||
def _disconnect(self):
|
||||
if not self.is_connected():
|
||||
return
|
||||
self._log.debug(
|
||||
"Disconnecting from broker on ZMQ DEALER socket at %s", self.broker_url
|
||||
)
|
||||
self._socket.setsockopt(zmq.LINGER, 0)
|
||||
self._socket.disconnect(self.broker_url)
|
||||
self._socket.close()
|
||||
self._socket = None
|
||||
|
||||
def is_connected(self):
|
||||
# type: () -> bool
|
||||
"""Tell whether we are currently connected"""
|
||||
return self._socket is not None
|
||||
|
||||
def send(self, service, *args):
|
||||
# type: (TextOrBytes, *bytes) -> None
|
||||
"""Send a REQUEST command to the broker to be passed to the given service.
|
||||
|
||||
Each additional argument will be sent as a request body frame.
|
||||
"""
|
||||
if self._expect_reply:
|
||||
raise e.StateError(
|
||||
"Still expecting reply from broker, cannot send new request"
|
||||
)
|
||||
|
||||
service = text_to_ascii_bytes(service)
|
||||
self._log.debug(
|
||||
"Sending REQUEST message to %s with %d frames in body", service, len(args)
|
||||
)
|
||||
self._socket.send_multipart((b"", p.CLIENT_HEADER, p.REQUEST, service) + args)
|
||||
self._expect_reply = True
|
||||
|
||||
def recv_part(self, timeout=None):
|
||||
# type: (Optional[float]) -> Optional[List[bytes]]
|
||||
"""Receive a single part of the reply, partial or final
|
||||
|
||||
Note that a "part" is actually a list in this case, as any reply part can contain multiple frames.
|
||||
|
||||
If there are no more parts to receive, will return None
|
||||
"""
|
||||
if not self._expect_reply:
|
||||
return None
|
||||
|
||||
timeout = int(timeout * 1000) if timeout else None
|
||||
|
||||
poller = zmq.Poller()
|
||||
poller.register(self._socket, zmq.POLLIN)
|
||||
|
||||
try:
|
||||
socks = dict(poller.poll(timeout=timeout))
|
||||
if socks.get(self._socket) == zmq.POLLIN:
|
||||
message = self._socket.recv_multipart()
|
||||
m_type, m_content = self._parse_message(message)
|
||||
if m_type == p.FINAL:
|
||||
self._expect_reply = False
|
||||
return m_content
|
||||
else:
|
||||
raise e.Timeout("Timed out waiting for reply from broker")
|
||||
finally:
|
||||
poller.unregister(self._socket)
|
||||
|
||||
def recv_all(self, timeout=None):
|
||||
# type: (Optional[float]) -> Iterable[List[bytes]]
|
||||
"""Return a generator allowing to iterate over all reply parts
|
||||
|
||||
Note that `timeout` applies to each part, not to the full list of parts
|
||||
"""
|
||||
while True:
|
||||
part = self.recv_part(timeout)
|
||||
if part is None:
|
||||
break
|
||||
yield part
|
||||
|
||||
def recv_all_as_list(self, timeout=None):
|
||||
# type: (Optional[float]) -> List[bytes]
|
||||
"""Return all reply parts as a single, flat list of frames"""
|
||||
return [frame for part in self.recv_all(timeout) for frame in part]
|
||||
|
||||
@staticmethod
|
||||
def _parse_message(message):
|
||||
# type: (List[bytes]) -> Tuple[bytes, List[bytes]]
|
||||
"""Parse and validate an incoming message"""
|
||||
if len(message) < 3:
|
||||
raise e.ProtocolError(
|
||||
"Unexpected message length, expecting at least 3 frames, got {}".format(
|
||||
len(message)
|
||||
)
|
||||
)
|
||||
|
||||
if message.pop(0) != b"":
|
||||
raise e.ProtocolError("Expecting first message frame to be empty")
|
||||
|
||||
if message[0] != p.CLIENT_HEADER:
|
||||
print(message)
|
||||
raise e.ProtocolError(
|
||||
"Unexpected protocol header [{}], expecting [{}]".format(
|
||||
message[0].decode("utf8"), p.WORKER_HEADER.decode("utf8")
|
||||
)
|
||||
)
|
||||
|
||||
if message[1] not in {p.PARTIAL, p.FINAL}:
|
||||
raise e.ProtocolError(
|
||||
"Unexpected message type [{}], expecting either PARTIAL or FINAL".format(
|
||||
message[1].decode("utf8")
|
||||
)
|
||||
)
|
||||
|
||||
return message[1], message[2:]
|
||||
|
||||
def __enter__(self):
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
45
frigate/majortomo/error.py
Normal file
45
frigate/majortomo/error.py
Normal file
@ -0,0 +1,45 @@
|
||||
"""ZeroMQ MDP Client / Worker Errors"""
|
||||
|
||||
# Copyright (c) 2018 Shoppimon LTD
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class Error(RuntimeError):
|
||||
"""Parent exception for all zmq_mdp errors"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ProtocolError(Error):
|
||||
"""MDP 0.2 Protocol Mismatch"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Disconnected(Error):
|
||||
"""We are no longer connected"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class StateError(Error):
|
||||
"""System is in an unexpected state"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Timeout(Error, TimeoutError):
|
||||
"""Operation timed out"""
|
||||
|
||||
pass
|
||||
64
frigate/majortomo/protocol.py
Normal file
64
frigate/majortomo/protocol.py
Normal file
@ -0,0 +1,64 @@
|
||||
"""ZeroMQ MDP 0.2 Protocol Constants common for Worker and Client
|
||||
"""
|
||||
|
||||
# Copyright (c) 2018 Shoppimon LTD
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional # noqa: F401
|
||||
|
||||
from . import error
|
||||
|
||||
WORKER_HEADER = b"MDPW02"
|
||||
CLIENT_HEADER = b"MDPC02"
|
||||
|
||||
READY = b"\001"
|
||||
REQUEST = b"\002"
|
||||
PARTIAL = b"\003"
|
||||
FINAL = b"\004"
|
||||
HEARTBEAT = b"\005"
|
||||
DISCONNECT = b"\006"
|
||||
|
||||
DEFAULT_HEARTBEAT_INTERVAL = 2.500
|
||||
DEFAULT_HEARTBEAT_TIMEOUT = 10.000
|
||||
DEFAULT_BUSY_WORKER_TIMEOUT = 900.000
|
||||
|
||||
|
||||
class Message(object):
|
||||
"""Majordomo message container"""
|
||||
|
||||
ALLOWED_HEADERS = {WORKER_HEADER, CLIENT_HEADER}
|
||||
ALLOWED_COMMANDS = {
|
||||
WORKER_HEADER: {READY, PARTIAL, FINAL, HEARTBEAT, DISCONNECT},
|
||||
CLIENT_HEADER: {REQUEST},
|
||||
}
|
||||
|
||||
def __init__(self, client, header, command, message=None):
|
||||
# type: (bytes, bytes, bytes, Optional[List[bytes]]) -> None
|
||||
if header not in self.ALLOWED_HEADERS:
|
||||
raise error.ProtocolError(
|
||||
"Unexpected protocol header: {}".format(header.decode("utf8"))
|
||||
)
|
||||
|
||||
if command not in self.ALLOWED_COMMANDS[header]:
|
||||
raise error.ProtocolError(
|
||||
"Unexpected command: {}".format(command.decode("utf8"))
|
||||
)
|
||||
|
||||
if message is None:
|
||||
message = []
|
||||
|
||||
self.client = client
|
||||
self.header = header
|
||||
self.command = command
|
||||
self.message = message
|
||||
29
frigate/majortomo/util.py
Normal file
29
frigate/majortomo/util.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""Utilities and helpers useful in other modules"""
|
||||
from typing import Text, Union
|
||||
|
||||
TextOrBytes = Union[Text, bytes]
|
||||
|
||||
|
||||
def text_to_ascii_bytes(text: TextOrBytes) -> bytes:
|
||||
"""Convert a text-or-bytes value to ASCII-encoded bytes
|
||||
|
||||
If the input is already `bytes`, we simply return it as is
|
||||
"""
|
||||
if isinstance(text, str):
|
||||
return text.encode("ascii", "strict")
|
||||
return text
|
||||
|
||||
|
||||
def id_to_int(id_):
|
||||
# type: (Union[bytes, str]) -> int
|
||||
"""Convert a ZMQ client ID to printable integer
|
||||
|
||||
This is needed to log client IDs while maintaining Python cross-version compatibility (so we can't use bytes.hex()
|
||||
for example)
|
||||
"""
|
||||
i = 0
|
||||
for c in id_:
|
||||
if not isinstance(c, int):
|
||||
c = ord(c)
|
||||
i = (i << 8) + c
|
||||
return i
|
||||
312
frigate/majortomo/worker.py
Normal file
312
frigate/majortomo/worker.py
Normal file
@ -0,0 +1,312 @@
|
||||
"""MDP 0.2 Worker implementation"""
|
||||
|
||||
# Copyright (c) 2018 Shoppimon LTD
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import signal
|
||||
import time
|
||||
import multiprocessing as mp
|
||||
from typing import Generator, Iterable, List, Optional, Tuple # noqa: F401
|
||||
|
||||
import zmq
|
||||
|
||||
from . import error, protocol
|
||||
from .util import TextOrBytes, text_to_ascii_bytes
|
||||
|
||||
DEFAULT_ZMQ_LINGER = 2500
|
||||
|
||||
|
||||
class Worker(object):
|
||||
"""MDP 0.2 Worker implementation"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
broker_url: str,
|
||||
service_names: List[TextOrBytes],
|
||||
heartbeat_interval=protocol.DEFAULT_HEARTBEAT_INTERVAL,
|
||||
heartbeat_timeout=protocol.DEFAULT_HEARTBEAT_TIMEOUT,
|
||||
zmq_context=None,
|
||||
zmq_linger=DEFAULT_ZMQ_LINGER,
|
||||
handler_name: TextOrBytes = None,
|
||||
handler_params: List[TextOrBytes] = [],
|
||||
stop_event: mp.Event = None,
|
||||
):
|
||||
self.broker_url = broker_url
|
||||
self.service_names = [
|
||||
text_to_ascii_bytes(service_name) for service_name in service_names
|
||||
]
|
||||
|
||||
self.heartbeat_interval = heartbeat_interval
|
||||
self._socket = None # type: zmq.Socket
|
||||
self._poller = None # type: zmq.Poller
|
||||
self._zmq_context = zmq_context if zmq_context else zmq.Context.instance()
|
||||
self._linger = zmq_linger
|
||||
self._log = logging.getLogger(__name__)
|
||||
self._heartbeat_timeout = heartbeat_timeout
|
||||
self._last_broker_hb = 0.0
|
||||
self._last_sent_message = 0.0
|
||||
|
||||
self.ready_params = [
|
||||
text_to_ascii_bytes(rp)
|
||||
for rp in (
|
||||
([handler_name] if handler_name is not None else []) + handler_params
|
||||
)
|
||||
]
|
||||
|
||||
self.stop_event = stop_event or mp.Event()
|
||||
|
||||
def connect(self, reconnect=False):
|
||||
# type: (bool) -> None
|
||||
if self.is_connected():
|
||||
if not reconnect:
|
||||
return
|
||||
self._disconnect()
|
||||
|
||||
# Set up socket
|
||||
self._socket = self._zmq_context.socket(zmq.DEALER)
|
||||
self._socket.setsockopt(zmq.LINGER, self._linger)
|
||||
self._socket.connect(self.broker_url)
|
||||
self._log.debug(
|
||||
"Connected to broker on ZMQ DEALER socket at %s", self.broker_url
|
||||
)
|
||||
|
||||
self._poller = zmq.Poller()
|
||||
self._poller.register(self._socket, zmq.POLLIN)
|
||||
|
||||
self._send_ready()
|
||||
self._last_broker_hb = time.time()
|
||||
|
||||
def wait_for_request(self):
|
||||
# type: () -> Tuple[bytes, List[bytes]]
|
||||
"""Wait for a REQUEST command from the broker and return the client address and message body frames.
|
||||
|
||||
Will internally handle timeouts, heartbeats and check for protocol errors and disconnect commands.
|
||||
"""
|
||||
command, frames = self._receive()
|
||||
|
||||
if command == protocol.DISCONNECT:
|
||||
self._log.debug("Got DISCONNECT from broker; Disconnecting")
|
||||
self._disconnect()
|
||||
raise error.Disconnected("Disconnected on message from broker")
|
||||
|
||||
elif command != protocol.REQUEST:
|
||||
raise error.ProtocolError(
|
||||
"Unexpected message type from broker: {}".format(command.decode("utf8"))
|
||||
)
|
||||
|
||||
if len(frames) < 3:
|
||||
raise error.ProtocolError(
|
||||
"Unexpected REQUEST message size, got {} frames, expecting at least 3".format(
|
||||
len(frames)
|
||||
)
|
||||
)
|
||||
|
||||
client_addr = frames[0]
|
||||
request = frames[2:]
|
||||
return client_addr, request
|
||||
|
||||
def send_reply_final(self, client, frames):
|
||||
# type: (bytes, List[bytes]) -> None
|
||||
"""Send final reply to client
|
||||
|
||||
FINAL reply means the client will not expect any additional parts to the reply. This should be used
|
||||
when the entire reply is ready to be delivered.
|
||||
"""
|
||||
self._send_to_client(client, protocol.FINAL, *frames)
|
||||
|
||||
def send_reply_partial(self, client, frames):
|
||||
# type: (bytes, List[bytes]) -> None
|
||||
"""Send the given set of frames as a partial reply to client
|
||||
|
||||
PARTIAL reply means the client will expect zero or more additional PARTIAL reply messages following
|
||||
this one, with exactly one terminating FINAL reply following. This should be used if parts of the
|
||||
reply are ready to be sent, and the client is capable of processing them while the worker is still
|
||||
at work on the rest of the reply.
|
||||
"""
|
||||
self._send_to_client(client, protocol.PARTIAL, *frames)
|
||||
|
||||
def send_reply_from_iterable(self, client, frames_iter, final=None):
|
||||
# type: (bytes, Iterable[List[bytes]], List[bytes]) -> None
|
||||
"""Send multiple partial replies from an iterator as PARTIAL replies to client.
|
||||
|
||||
If `final` is provided, it will be sent as the FINAL reply after all PARTIAL replies are sent.
|
||||
"""
|
||||
for part in frames_iter:
|
||||
self.send_reply_partial(client, part)
|
||||
if final:
|
||||
self.send_reply_final(client, final)
|
||||
|
||||
def close(self):
|
||||
if not self.is_connected():
|
||||
return
|
||||
self._send_disconnect()
|
||||
self._disconnect()
|
||||
|
||||
def is_connected(self):
|
||||
return self._socket is not None
|
||||
|
||||
def _disconnect(self):
|
||||
if not self.is_connected():
|
||||
return
|
||||
self._socket.disconnect(self.broker_url)
|
||||
self._socket.close()
|
||||
self._socket = None
|
||||
self._last_sent_message -= self.heartbeat_interval
|
||||
|
||||
def _receive(self):
|
||||
# type: () -> Tuple[bytes, List[bytes]]
|
||||
"""Poll on the socket until a command is received
|
||||
|
||||
Will handle timeouts and heartbeats internally without returning
|
||||
"""
|
||||
while not self.stop_event.is_set():
|
||||
if self._socket is None:
|
||||
raise error.Disconnected("Worker is disconnected")
|
||||
|
||||
self._check_send_heartbeat()
|
||||
poll_timeout = self._get_poll_timeout()
|
||||
|
||||
try:
|
||||
socks = dict(self._poller.poll(timeout=poll_timeout))
|
||||
except zmq.error.ZMQError:
|
||||
# Probably connection was explicitly closed
|
||||
if self._socket is None:
|
||||
continue
|
||||
raise
|
||||
|
||||
if socks.get(self._socket) == zmq.POLLIN:
|
||||
message = self._socket.recv_multipart()
|
||||
self._log.debug("Got message of %d frames", len(message))
|
||||
else:
|
||||
self._log.debug("Receive timed out after %d ms", poll_timeout)
|
||||
if (time.time() - self._last_broker_hb) > self._heartbeat_timeout:
|
||||
# We're not connected anymore?
|
||||
self._log.info(
|
||||
"Got no heartbeat in %d sec, disconnecting and reconnecting socket",
|
||||
self._heartbeat_timeout,
|
||||
)
|
||||
self.connect(reconnect=True)
|
||||
continue
|
||||
|
||||
command, frames = self._verify_message(message)
|
||||
self._last_broker_hb = time.time()
|
||||
|
||||
if command == protocol.HEARTBEAT:
|
||||
self._log.debug("Got heartbeat message from broker")
|
||||
continue
|
||||
|
||||
return command, frames
|
||||
|
||||
return protocol.DISCONNECT, []
|
||||
|
||||
def _send_ready(self):
|
||||
for service_name in self.service_names:
|
||||
self._send(protocol.READY, service_name, *self.ready_params)
|
||||
|
||||
def handle_request(client_id: bytes, request: List[bytes]) -> List[bytes]:
|
||||
return request
|
||||
|
||||
def _send_disconnect(self):
|
||||
self._send(protocol.DISCONNECT)
|
||||
|
||||
def _check_send_heartbeat(self):
|
||||
if time.time() - self._last_sent_message >= self.heartbeat_interval:
|
||||
self._log.debug("Sending HEARTBEAT to broker")
|
||||
self._send(protocol.HEARTBEAT)
|
||||
|
||||
def _send_to_client(self, client, message_type, *frames):
|
||||
self._send(message_type, client, b"", *frames)
|
||||
|
||||
def _send(self, message_type, *args):
|
||||
# type: (bytes, *bytes) -> None
|
||||
self._socket.send_multipart((b"", protocol.WORKER_HEADER, message_type) + args)
|
||||
self._last_sent_message = time.time()
|
||||
|
||||
def _get_poll_timeout(self):
|
||||
# type: () -> int
|
||||
"""Return the poll timeout for the current iteration in milliseconds"""
|
||||
return max(
|
||||
0,
|
||||
int(
|
||||
(time.time() - self._last_sent_message + self.heartbeat_interval) * 1000
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _verify_message(message):
|
||||
# type: (List[bytes]) -> Tuple[bytes, List[bytes]]
|
||||
if len(message) < 3:
|
||||
raise error.ProtocolError(
|
||||
"Unexpected message length, expecting at least 3 frames, got {}".format(
|
||||
len(message)
|
||||
)
|
||||
)
|
||||
|
||||
if message.pop(0) != b"":
|
||||
raise error.ProtocolError("Expecting first message frame to be empty")
|
||||
|
||||
if message[0] != protocol.WORKER_HEADER:
|
||||
print(message)
|
||||
raise error.ProtocolError(
|
||||
"Unexpected protocol header [{}], expecting [{}]".format(
|
||||
message[0].decode("utf8"), protocol.WORKER_HEADER.decode("utf8")
|
||||
)
|
||||
)
|
||||
|
||||
if message[1] not in {
|
||||
protocol.DISCONNECT,
|
||||
protocol.HEARTBEAT,
|
||||
protocol.REQUEST,
|
||||
}:
|
||||
raise error.ProtocolError(
|
||||
"Unexpected message type [{}], expecting either HEARTBEAT, REQUEST or "
|
||||
"DISCONNECT".format(message[1].decode("utf8"))
|
||||
)
|
||||
|
||||
return message[1], message[2:]
|
||||
|
||||
def __enter__(self):
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def start(self):
|
||||
if not self.is_connected():
|
||||
self.connect()
|
||||
|
||||
def signal_handler(sig_num, _):
|
||||
self.stop()
|
||||
|
||||
for sig_num in (signal.SIGINT, signal.SIGTERM):
|
||||
signal.signal(sig_num, signal_handler)
|
||||
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
client_id, request = self.wait_for_request()
|
||||
reply = self.handle_request(client_id, request)
|
||||
self.send_reply_final(client_id, reply)
|
||||
except error.ProtocolError as e:
|
||||
self._log.warning("Protocol error: %s, dropping request", str(e))
|
||||
continue
|
||||
except error.Disconnected:
|
||||
self._log.info("Worker disconnected")
|
||||
break
|
||||
|
||||
self.close()
|
||||
|
||||
def stop(self):
|
||||
self.stop_event.set()
|
||||
@ -13,7 +13,7 @@ from frigate.detectors import (
|
||||
ObjectDetectionClient,
|
||||
ObjectDetectionWorker,
|
||||
)
|
||||
from frigate.majordomo import QueueBroker
|
||||
from frigate.majortomo import Broker
|
||||
from frigate.util import deep_merge
|
||||
import frigate.detectors.detector_types as detectors
|
||||
|
||||
@ -43,7 +43,7 @@ def start_broker(ipc_address, tcp_address, camera_names):
|
||||
body = body[0:2] + [tensor_input]
|
||||
return body
|
||||
|
||||
queue_broker = QueueBroker(bind=[ipc_address, tcp_address])
|
||||
queue_broker = Broker(bind=[ipc_address, tcp_address])
|
||||
queue_broker.register_request_handler("DETECT_NO_SHM", detect_no_shm)
|
||||
queue_broker.start()
|
||||
|
||||
|
||||
@ -13,7 +13,6 @@ pydantic == 1.10.*
|
||||
PyYAML == 6.0
|
||||
pytz == 2022.6
|
||||
pyzmq == 24.0.1
|
||||
majortomo == 0.2.0
|
||||
tzlocal == 4.2
|
||||
types-PyYAML == 6.0.*
|
||||
requests == 2.28.*
|
||||
|
||||
Loading…
Reference in New Issue
Block a user