frigate/tools/zmq_embedding_server.py
Josh Casada a2c43ad8bb feat: ZMQ embedding runner for offloading ONNX inference to native host
Extends the ZMQ split-detector pattern (apple-silicon-detector) to cover
ONNX embedding models — ArcFace face recognition and Jina semantic search.

On macOS, Docker has no access to CoreML or the Apple Neural Engine, so
embedding inference is forced to CPU (~200ms/face for ArcFace). This adds
a ZmqEmbeddingRunner that sends preprocessed tensors to a native host
process over ZMQ TCP and receives embeddings back, enabling CoreML/ANE
acceleration outside the container.

Files changed:
- frigate/detectors/detection_runners.py: add ZmqEmbeddingRunner class
  and hook into get_optimized_runner() via "zmq://" device prefix
- tools/zmq_embedding_server.py: new host-side server script

Tested on Mac Mini M4, 24h soak test, ~5000 object reindex.
2026-02-21 12:44:42 -05:00


"""ZMQ Embedding Server — native Mac (Apple Silicon) inference service.
Runs ONNX models using hardware acceleration unavailable inside Docker on macOS,
specifically CoreML and the Apple Neural Engine. Frigate's Docker container
connects to this server over ZMQ TCP, sends preprocessed tensors, and receives
embedding vectors back.
Supported models:
- ArcFace (face recognition, 512-dim output)
- FaceNet (face recognition, 128-dim output)
- Jina V1/V2 vision (semantic search, 768-dim output)
Requirements (install outside Docker, on the Mac host):
pip install onnxruntime pyzmq numpy
Usage:
# ArcFace face recognition (port 5556):
python tools/zmq_embedding_server.py \\
--model /config/model_cache/facedet/arcface.onnx \\
--model-type arcface \\
--port 5556
# Jina V1 vision semantic search (port 5557):
python tools/zmq_embedding_server.py \\
--model /config/model_cache/jinaai/jina-clip-v1/vision_model_quantized.onnx \\
--model-type jina_v1 \\
--port 5557
Frigate config (docker-compose / config.yaml):
face_recognition:
enabled: true
model_size: large
device: "zmq://host.docker.internal:5556"
semantic_search:
enabled: true
model_size: small
device: "zmq://host.docker.internal:5557"
Protocol (REQ/REP):
Request: multipart [ header_json_bytes, tensor_bytes ]
header = {
"shape": [batch, channels, height, width], # e.g. [1, 3, 112, 112]
"dtype": "float32",
"model_type": "arcface",
}
tensor_bytes = raw C-order numpy bytes
Response: multipart [ header_json_bytes, embedding_bytes ]
header = {
"shape": [batch, embedding_dim], # e.g. [1, 512]
"dtype": "float32",
}
embedding_bytes = raw C-order numpy bytes
"""

import argparse
import json
import logging
import os
import signal
import sys
import time

import numpy as np
import zmq

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("zmq_embedding_server")

# Models that require ORT_ENABLE_BASIC optimization to avoid graph fusion issues
# (e.g. SimplifiedLayerNormFusion creates nodes that some providers can't handle).
_COMPLEX_MODELS = {"jina_v1", "jina_v2"}


# ---------------------------------------------------------------------------
# ONNX Runtime session (CoreML preferred on Apple Silicon)
# ---------------------------------------------------------------------------
def build_ort_session(model_path: str, model_type: str = ""):
    """Create an ONNX Runtime InferenceSession, preferring CoreML on macOS.

    Jina V1/V2 models use ORT_ENABLE_BASIC graph optimization to avoid
    fusion passes (e.g. SimplifiedLayerNormFusion) that produce unsupported
    nodes. All other models use the default ORT_ENABLE_ALL.
    """
    import onnxruntime as ort

    available = ort.get_available_providers()
    logger.info(f"Available ORT providers: {available}")

    # Prefer CoreMLExecutionProvider on Apple Silicon for ANE/GPU acceleration.
    # Falls back automatically to CPUExecutionProvider if CoreML is unavailable.
    preferred = []
    if "CoreMLExecutionProvider" in available:
        preferred.append("CoreMLExecutionProvider")
        logger.info("Using CoreMLExecutionProvider (Apple Neural Engine / GPU)")
    else:
        logger.warning(
            "CoreMLExecutionProvider not available — falling back to CPU. "
            "Install onnxruntime-silicon or a CoreML-enabled onnxruntime build."
        )
    # CPU is always appended last so the session can fall back per the comment above.
    preferred.append("CPUExecutionProvider")

    sess_options = None
    if model_type in _COMPLEX_MODELS:
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = (
            ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
        )
        logger.info(f"Using ORT_ENABLE_BASIC optimization for {model_type}")

    session = ort.InferenceSession(
        model_path, sess_options=sess_options, providers=preferred
    )
    input_names = [inp.name for inp in session.get_inputs()]
    output_names = [out.name for out in session.get_outputs()]
    logger.info(f"Model loaded: inputs={input_names}, outputs={output_names}")
    return session
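
# Note: even with CoreMLExecutionProvider registered, ORT may assign individual
# graph nodes back to the CPU provider. To confirm which providers a session
# actually resolved, check session.get_providers(); a quick sketch (the model
# path here is hypothetical):
#
#     session = build_ort_session("/path/to/arcface.onnx", "arcface")
#     print(session.get_providers())
#     # e.g. ['CoreMLExecutionProvider', 'CPUExecutionProvider']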


# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------
def run_arcface(session, tensor: np.ndarray) -> np.ndarray:
    """Run ArcFace — input (1, 3, 112, 112) float32, output (1, 512) float32."""
    outputs = session.run(None, {"data": tensor})
    return outputs[0]  # shape (1, 512)


def run_generic(session, tensor: np.ndarray) -> np.ndarray:
    """Generic single-input ONNX model runner."""
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: tensor})
    return outputs[0]


_RUNNERS = {
    "arcface": run_arcface,
    "facenet": run_generic,
    "jina_v1": run_generic,
    "jina_v2": run_generic,
}

# Model type → input shape for warmup inference (triggers CoreML JIT compilation
# before the first real request arrives, avoiding a ZMQ timeout on cold start).
_WARMUP_SHAPES = {
    "arcface": (1, 3, 112, 112),
    "facenet": (1, 3, 160, 160),
    "jina_v1": (1, 3, 224, 224),
    "jina_v2": (1, 3, 224, 224),
}
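
# Model type → embedding dimension (matches the supported-models list in the
# module docstring); used to size the zero-embedding fallback on inference errors.
_EMBEDDING_DIMS = {
    "arcface": 512,
    "facenet": 128,
    "jina_v1": 768,
    "jina_v2": 768,
}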


def warmup(session, model_type: str) -> None:
    """Run a dummy inference to trigger CoreML JIT compilation."""
    shape = _WARMUP_SHAPES.get(model_type)
    if shape is None:
        return
    logger.info(f"Warming up CoreML model ({model_type})…")
    dummy = np.zeros(shape, dtype=np.float32)
    try:
        runner = _RUNNERS.get(model_type, run_generic)
        runner(session, dummy)
        logger.info("Warmup complete")
    except Exception as exc:
        logger.warning(f"Warmup failed (non-fatal): {exc}")


# ---------------------------------------------------------------------------
# ZMQ server loop
# ---------------------------------------------------------------------------
def serve(session, port: int, model_type: str) -> None:
    context = zmq.Context()
    socket = context.socket(zmq.REP)
    socket.bind(f"tcp://0.0.0.0:{port}")
    logger.info(f"Listening on tcp://0.0.0.0:{port} (model_type={model_type})")
    runner = _RUNNERS.get(model_type, run_generic)
    # Dimension for the zero-embedding fallback returned on inference errors.
    error_dim = _EMBEDDING_DIMS.get(model_type, 512)

    def _shutdown(sig, frame):
        logger.info("Shutting down…")
        socket.close(linger=0)
        context.term()
        sys.exit(0)

    signal.signal(signal.SIGINT, _shutdown)
    signal.signal(signal.SIGTERM, _shutdown)

    while True:
        try:
            frames = socket.recv_multipart()
        except zmq.ZMQError as exc:
            logger.error(f"recv error: {exc}")
            continue

        if len(frames) < 2:
            logger.warning(f"Received unexpected frame count: {len(frames)}, ignoring")
            # Keep the documented two-frame response shape even on errors.
            socket.send_multipart([b"{}", b""])
            continue

        try:
            header = json.loads(frames[0].decode("utf-8"))
            shape = tuple(header["shape"])
            dtype = np.dtype(header.get("dtype", "float32"))
            tensor = np.frombuffer(frames[1], dtype=dtype).reshape(shape)
        except Exception as exc:
            logger.error(f"Failed to decode request: {exc}")
            socket.send_multipart([b"{}", b""])
            continue

        try:
            t0 = time.monotonic()
            embedding = runner(session, tensor)
            elapsed_ms = (time.monotonic() - t0) * 1000
            if elapsed_ms > 2000:
                logger.warning(f"slow inference {elapsed_ms:.1f}ms shape={shape}")
            resp_header = json.dumps(
                {"shape": list(embedding.shape), "dtype": str(embedding.dtype.name)}
            ).encode("utf-8")
            socket.send_multipart(
                [resp_header, np.ascontiguousarray(embedding).tobytes()]
            )
        except Exception as exc:
            logger.error(f"Inference error: {exc}")
            # Return a zero embedding of the model's output width so the client
            # can degrade gracefully instead of timing out.
            zero = np.zeros((1, error_dim), dtype=np.float32)
            resp_header = json.dumps(
                {"shape": list(zero.shape), "dtype": "float32"}
            ).encode("utf-8")
            socket.send_multipart([resp_header, zero.tobytes()])


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main():
    parser = argparse.ArgumentParser(description="ZMQ Embedding Server for Frigate")
    parser.add_argument(
        "--model",
        required=True,
        help="Path to the ONNX model file (e.g. /config/model_cache/facedet/arcface.onnx)",
    )
    parser.add_argument(
        "--model-type",
        default="arcface",
        choices=list(_RUNNERS.keys()),
        help="Model type key (default: arcface)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=5556,
        help="TCP port to listen on (default: 5556)",
    )
    args = parser.parse_args()

    if not os.path.exists(args.model):
        logger.error(f"Model file not found: {args.model}")
        sys.exit(1)

    logger.info(f"Loading model: {args.model}")
    session = build_ort_session(args.model, model_type=args.model_type)
    warmup(session, model_type=args.model_type)
    serve(session, port=args.port, model_type=args.model_type)


if __name__ == "__main__":
    main()