mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-03-10 10:33:11 +03:00
Extends the ZMQ split-detector pattern (apple-silicon-detector) to cover ONNX embedding models — ArcFace face recognition and Jina semantic search. On macOS, Docker has no access to CoreML or the Apple Neural Engine, so embedding inference is forced to CPU (~200ms/face for ArcFace). This adds a ZmqEmbeddingRunner that sends preprocessed tensors to a native host process over ZMQ TCP and receives embeddings back, enabling CoreML/ANE acceleration outside the container. Files changed: - frigate/detectors/detection_runners.py: add ZmqEmbeddingRunner class and hook into get_optimized_runner() via "zmq://" device prefix - tools/zmq_embedding_server.py: new host-side server script Tested on Mac Mini M4, 24h soak test, ~5000 object reindex.
276 lines
9.3 KiB
Python
"""ZMQ Embedding Server — native Mac (Apple Silicon) inference service.
|
|
|
|
Runs ONNX models using hardware acceleration unavailable inside Docker on macOS,
|
|
specifically CoreML and the Apple Neural Engine. Frigate's Docker container
|
|
connects to this server over ZMQ TCP, sends preprocessed tensors, and receives
|
|
embedding vectors back.
|
|
|
|
Supported models:
|
|
- ArcFace (face recognition, 512-dim output)
|
|
- FaceNet (face recognition, 128-dim output)
|
|
- Jina V1/V2 vision (semantic search, 768-dim output)
|
|
|
|
Requirements (install outside Docker, on the Mac host):
|
|
pip install onnxruntime pyzmq numpy
|
|
|
|
Usage:
|
|
# ArcFace face recognition (port 5556):
|
|
python tools/zmq_embedding_server.py \\
|
|
--model /config/model_cache/facedet/arcface.onnx \\
|
|
--model-type arcface \\
|
|
--port 5556
|
|
|
|
# Jina V1 vision semantic search (port 5557):
|
|
python tools/zmq_embedding_server.py \\
|
|
--model /config/model_cache/jinaai/jina-clip-v1/vision_model_quantized.onnx \\
|
|
--model-type jina_v1 \\
|
|
--port 5557
|
|
|
|
Frigate config (docker-compose / config.yaml):
|
|
face_recognition:
|
|
enabled: true
|
|
model_size: large
|
|
device: "zmq://host.docker.internal:5556"
|
|
|
|
semantic_search:
|
|
enabled: true
|
|
model_size: small
|
|
device: "zmq://host.docker.internal:5557"
|
|
|
|
Protocol (REQ/REP):
|
|
Request: multipart [ header_json_bytes, tensor_bytes ]
|
|
header = {
|
|
"shape": [batch, channels, height, width], # e.g. [1, 3, 112, 112]
|
|
"dtype": "float32",
|
|
"model_type": "arcface",
|
|
}
|
|
tensor_bytes = raw C-order numpy bytes
|
|
|
|
Response: multipart [ header_json_bytes, embedding_bytes ]
|
|
header = {
|
|
"shape": [batch, embedding_dim], # e.g. [1, 512]
|
|
"dtype": "float32",
|
|
}
|
|
embedding_bytes = raw C-order numpy bytes
|
|
"""
|
|
|
|
# Standard library
import argparse
import json
import logging
import os
import signal
import sys
import time

# Third-party (see module docstring): numpy for tensor (de)serialization,
# pyzmq for the REQ/REP transport. onnxruntime is imported lazily inside
# build_ort_session() so `--help` works without it installed.
import numpy as np
import zmq

# Foreground host process: log timestamped lines to stderr at INFO level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("zmq_embedding_server")


# Models that require ORT_ENABLE_BASIC optimization to avoid graph fusion issues
# (e.g. SimplifiedLayerNormFusion creates nodes that some providers can't handle).
_COMPLEX_MODELS = {"jina_v1", "jina_v2"}
# ---------------------------------------------------------------------------
# ONNX Runtime session (CoreML preferred on Apple Silicon)
# ---------------------------------------------------------------------------

def build_ort_session(model_path: str, model_type: str = ""):
    """Build an ONNX Runtime InferenceSession for *model_path*.

    Provider order is CoreML first (when available — Apple Neural Engine /
    GPU), then CPU as the guaranteed fallback. Jina V1/V2 models are loaded
    with ORT_ENABLE_BASIC graph optimization to sidestep fusion passes
    (e.g. SimplifiedLayerNormFusion) that emit nodes some providers cannot
    execute; every other model type keeps ORT's default ORT_ENABLE_ALL.

    Args:
        model_path: filesystem path to the .onnx model.
        model_type: model key ("arcface", "jina_v1", ...); controls the
            graph-optimization level via _COMPLEX_MODELS.

    Returns:
        An onnxruntime.InferenceSession ready for inference.
    """
    # Imported lazily so the CLI can at least parse args without onnxruntime.
    import onnxruntime as ort

    providers = ort.get_available_providers()
    logger.info(f"Available ORT providers: {providers}")

    coreml_ok = "CoreMLExecutionProvider" in providers
    if coreml_ok:
        logger.info("Using CoreMLExecutionProvider (Apple Neural Engine / GPU)")
    else:
        logger.warning(
            "CoreMLExecutionProvider not available — falling back to CPU. "
            "Install onnxruntime-silicon or a CoreML-enabled onnxruntime build."
        )

    # CPU is always appended last so ORT can fall back per-node if needed.
    selected = (["CoreMLExecutionProvider"] if coreml_ok else []) + [
        "CPUExecutionProvider"
    ]

    options = None
    if model_type in _COMPLEX_MODELS:
        options = ort.SessionOptions()
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
        logger.info(f"Using ORT_ENABLE_BASIC optimization for {model_type}")

    session = ort.InferenceSession(
        model_path, sess_options=options, providers=selected
    )

    input_names = [inp.name for inp in session.get_inputs()]
    output_names = [out.name for out in session.get_outputs()]
    logger.info(f"Model loaded: inputs={input_names}, outputs={output_names}")
    return session
# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------

def run_arcface(session, tensor: np.ndarray) -> np.ndarray:
    """ArcFace inference: (1, 3, 112, 112) float32 in → (1, 512) float32 out."""
    # ArcFace ONNX exports name their single input "data".
    result = session.run(None, {"data": tensor})
    return result[0]  # (1, 512) embedding
def run_generic(session, tensor: np.ndarray) -> np.ndarray:
    """Feed *tensor* to a single-input ONNX model and return its first output."""
    feed_name = session.get_inputs()[0].name
    first, *_rest = session.run(None, {feed_name: tensor})
    return first
# Dispatch table: model-type key → inference helper. Call sites (warmup() and
# serve()) fall back to run_generic for unknown keys.
_RUNNERS = {
    "arcface": run_arcface,
    "facenet": run_generic,
    "jina_v1": run_generic,
    "jina_v2": run_generic,
}

# Model type → input shape for warmup inference (triggers CoreML JIT compilation
# before the first real request arrives, avoiding a ZMQ timeout on cold start).
# Shapes are NCHW; they match the request shapes documented in the module
# docstring (e.g. ArcFace takes 112x112 crops).
_WARMUP_SHAPES = {
    "arcface": (1, 3, 112, 112),
    "facenet": (1, 3, 160, 160),
    "jina_v1": (1, 3, 224, 224),
    "jina_v2": (1, 3, 224, 224),
}
def warmup(session, model_type: str) -> None:
    """Run one dummy inference so CoreML compiles the graph before real traffic.

    A no-op for model types without a registered warmup shape. Failures are
    logged and swallowed — warmup is best-effort.
    """
    shape = _WARMUP_SHAPES.get(model_type)
    if shape is None:
        return

    logger.info(f"Warming up CoreML model ({model_type})…")
    blank = np.zeros(shape, dtype=np.float32)
    infer = _RUNNERS.get(model_type, run_generic)
    try:
        infer(session, blank)
        logger.info("Warmup complete")
    except Exception as exc:
        # Non-fatal: the first real request simply pays the JIT cost instead.
        logger.warning(f"Warmup failed (non-fatal): {exc}")
# ---------------------------------------------------------------------------
# ZMQ server loop
# ---------------------------------------------------------------------------

# Embedding dimensionality per model type, used to size the zero-vector
# fallback sent when inference fails. Dims match the module docstring:
# ArcFace 512, FaceNet 128, Jina V1/V2 768. Unknown types default to 512.
_FALLBACK_DIMS = {"arcface": 512, "facenet": 128, "jina_v1": 768, "jina_v2": 768}


def serve(session, port: int, model_type: str) -> None:
    """Answer embedding requests over a ZMQ REP socket, forever.

    Binds tcp://0.0.0.0:<port> and serves the multipart protocol described in
    the module docstring: [header_json, tensor_bytes] in,
    [header_json, embedding_bytes] out. Installs SIGINT/SIGTERM handlers that
    close the socket and exit the process — this function never returns
    normally.

    Args:
        session: ONNX Runtime session from build_ort_session().
        port: TCP port to bind on all interfaces.
        model_type: key selecting the inference helper from _RUNNERS
            (unknown keys use run_generic).
    """
    context = zmq.Context()
    socket = context.socket(zmq.REP)
    socket.bind(f"tcp://0.0.0.0:{port}")
    logger.info(f"Listening on tcp://0.0.0.0:{port} (model_type={model_type})")

    runner = _RUNNERS.get(model_type, run_generic)
    # Bug fix: the failure fallback previously always returned a (1, 512)
    # zero vector, which is the wrong shape for facenet (128-dim) and
    # jina_v1/v2 (768-dim) clients. Size it from the model type instead.
    fallback_dim = _FALLBACK_DIMS.get(model_type, 512)

    def _shutdown(sig, frame):
        logger.info("Shutting down…")
        socket.close(linger=0)
        context.term()
        sys.exit(0)

    signal.signal(signal.SIGINT, _shutdown)
    signal.signal(signal.SIGTERM, _shutdown)

    while True:
        try:
            frames = socket.recv_multipart()
        except zmq.ZMQError as exc:
            logger.error(f"recv error: {exc}")
            continue

        # A REP socket must send exactly one reply per request, even for bad
        # input, or it enters an invalid send/recv state — hence the empty
        # "{}" replies below instead of silently dropping the request.
        if len(frames) < 2:
            logger.warning(f"Received unexpected frame count: {len(frames)}, ignoring")
            socket.send_multipart([b"{}"])
            continue

        try:
            header = json.loads(frames[0].decode("utf-8"))
            shape = tuple(header["shape"])
            dtype = np.dtype(header.get("dtype", "float32"))
            # Zero-copy view over the request frame, reshaped to the declared
            # shape; raises if the byte count and shape disagree.
            tensor = np.frombuffer(frames[1], dtype=dtype).reshape(shape)
        except Exception as exc:
            logger.error(f"Failed to decode request: {exc}")
            socket.send_multipart([b"{}"])
            continue

        try:
            t0 = time.monotonic()
            embedding = runner(session, tensor)
            elapsed_ms = (time.monotonic() - t0) * 1000
            if elapsed_ms > 2000:
                logger.warning(f"slow inference {elapsed_ms:.1f}ms shape={shape}")
            resp_header = json.dumps(
                {"shape": list(embedding.shape), "dtype": str(embedding.dtype.name)}
            ).encode("utf-8")
            # ascontiguousarray guarantees C-order bytes as promised by the
            # protocol, regardless of what layout the runner returned.
            resp_payload = memoryview(np.ascontiguousarray(embedding).tobytes())
            socket.send_multipart([resp_header, resp_payload])
        except Exception as exc:
            logger.error(f"Inference error: {exc}")
            # Return a zero embedding so the client can degrade gracefully
            zero = np.zeros((1, fallback_dim), dtype=np.float32)
            resp_header = json.dumps(
                {"shape": list(zero.shape), "dtype": "float32"}
            ).encode("utf-8")
            socket.send_multipart([resp_header, memoryview(zero.tobytes())])
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main():
    """Parse CLI arguments, load the ONNX model, warm it up, and serve forever."""
    cli = argparse.ArgumentParser(description="ZMQ Embedding Server for Frigate")
    cli.add_argument(
        "--model",
        required=True,
        help="Path to the ONNX model file (e.g. /config/model_cache/facedet/arcface.onnx)",
    )
    cli.add_argument(
        "--model-type",
        default="arcface",
        choices=list(_RUNNERS),
        help="Model type key (default: arcface)",
    )
    cli.add_argument(
        "--port",
        type=int,
        default=5556,
        help="TCP port to listen on (default: 5556)",
    )
    opts = cli.parse_args()

    # Fail fast on a bad path rather than letting onnxruntime raise later.
    if not os.path.exists(opts.model):
        logger.error(f"Model file not found: {opts.model}")
        sys.exit(1)

    logger.info(f"Loading model: {opts.model}")
    ort_session = build_ort_session(opts.model, model_type=opts.model_type)
    warmup(ort_session, model_type=opts.model_type)
    serve(ort_session, port=opts.port, model_type=opts.model_type)


if __name__ == "__main__":
    main()