This commit is contained in:
JoshADC 2026-03-06 10:09:01 +08:00 committed by GitHub
commit ca4485754e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 482 additions and 0 deletions

View File

@ -1,5 +1,6 @@
"""Base runner implementation for ONNX models."""
import json
import logging
import os
import platform
@ -10,6 +11,11 @@ from typing import Any
import numpy as np
import onnxruntime as ort
try:
import zmq as _zmq
except ImportError:
_zmq = None
from frigate.util.model import get_ort_providers
from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible
@ -546,12 +552,213 @@ class RKNNModelRunner(BaseModelRunner):
pass
class ZmqEmbeddingRunner(BaseModelRunner):
"""Send preprocessed embedding tensors over ZMQ to an external inference service.
This enables offloading ONNX embedding inference (e.g. ArcFace face recognition,
Jina semantic search) to a native host process that has access to hardware
acceleration unavailable inside Docker, such as CoreML/ANE on Apple Silicon.
Protocol:
- Request is a multipart message: [ header_json_bytes, tensor_bytes ]
where header is:
{
"shape": List[int], # e.g. [1, 3, 112, 112]
"dtype": str, # numpy dtype, e.g. "float32"
"model_type": str, # e.g. "arcface"
}
tensor_bytes are the raw C-order bytes of the input tensor.
- Response is either:
a) Multipart [ header_json_bytes, embedding_bytes ] with header specifying
shape and dtype of the returned embedding; or
b) Single frame of raw float32 bytes (embedding vector, batch-first).
On timeout or error, a zero embedding is returned so the caller can degrade
gracefully (the face will simply not be recognized for that frame).
Configuration example (face_recognition.device):
face_recognition:
enabled: true
model_size: large
device: "zmq://host.docker.internal:5556"
"""
# Model type → primary input name (used to answer get_input_names())
_INPUT_NAMES: dict[str, list[str]] = {}
# Model type → model input spatial width
_INPUT_WIDTHS: dict[str, int] = {}
# Model type → embedding output dimensionality (used for zero-fallback shape)
_OUTPUT_DIMS: dict[str, int] = {}
@classmethod
def _init_model_maps(cls) -> None:
"""Populate the model maps lazily to avoid circular imports at module load."""
if cls._INPUT_NAMES:
return
from frigate.embeddings.types import EnrichmentModelTypeEnum
cls._INPUT_NAMES = {
EnrichmentModelTypeEnum.arcface.value: ["data"],
EnrichmentModelTypeEnum.facenet.value: ["data"],
EnrichmentModelTypeEnum.jina_v1.value: ["pixel_values"],
EnrichmentModelTypeEnum.jina_v2.value: ["pixel_values"],
}
cls._INPUT_WIDTHS = {
EnrichmentModelTypeEnum.arcface.value: 112,
EnrichmentModelTypeEnum.facenet.value: 160,
EnrichmentModelTypeEnum.jina_v1.value: 224,
EnrichmentModelTypeEnum.jina_v2.value: 224,
}
cls._OUTPUT_DIMS = {
EnrichmentModelTypeEnum.arcface.value: 512,
EnrichmentModelTypeEnum.facenet.value: 128,
EnrichmentModelTypeEnum.jina_v1.value: 768,
EnrichmentModelTypeEnum.jina_v2.value: 768,
}
def __init__(
self,
endpoint: str,
model_type: str,
request_timeout_ms: int = 60000,
linger_ms: int = 0,
):
if _zmq is None:
raise ImportError(
"pyzmq is required for ZmqEmbeddingRunner. Install it with: pip install pyzmq"
)
self._init_model_maps()
# "zmq://host:port" is the Frigate config sentinel; ZMQ sockets need "tcp://host:port"
self._endpoint = endpoint.replace("zmq://", "tcp://", 1)
self._model_type = model_type
self._request_timeout_ms = request_timeout_ms
self._linger_ms = linger_ms
self._context = _zmq.Context()
self._socket = None
self._needs_reset = False
self._lock = threading.Lock()
self._create_socket()
logger.info(
f"ZmqEmbeddingRunner({model_type}): connected to {endpoint}"
)
def _create_socket(self) -> None:
if self._socket is not None:
try:
self._socket.close(linger=self._linger_ms)
except Exception:
pass
self._socket = self._context.socket(_zmq.REQ)
self._socket.setsockopt(_zmq.RCVTIMEO, self._request_timeout_ms)
self._socket.setsockopt(_zmq.SNDTIMEO, self._request_timeout_ms)
self._socket.setsockopt(_zmq.LINGER, self._linger_ms)
self._socket.connect(self._endpoint)
def get_input_names(self) -> list[str]:
return self._INPUT_NAMES.get(self._model_type, ["data"])
def get_input_width(self) -> int:
return self._INPUT_WIDTHS.get(self._model_type, -1)
def run(self, inputs: dict[str, Any]) -> list[np.ndarray]:
"""Send the primary input tensor over ZMQ and return the embedding.
For single-input models (ArcFace, FaceNet) the entire inputs dict maps to
one tensor. For multi-input models only the first tensor is sent; those
models are not yet supported for ZMQ offload.
"""
tensor_input = np.ascontiguousarray(next(iter(inputs.values())))
batch_size = tensor_input.shape[0]
with self._lock:
# Lazy reset: if a previous call errored, reset the socket now — before any
# ZMQ operations — so we don't manipulate sockets inside an error handler where
# Frigate's own ZMQ threads may be polling and could hit a libzmq assertion.
# The lock ensures only one thread touches the socket at a time (ZMQ REQ
# sockets are not thread-safe; concurrent calls from the reindex thread and
# the normal embedding maintainer thread would corrupt the socket state).
if self._needs_reset:
self._reset_socket()
self._needs_reset = False
try:
header = {
"shape": list(tensor_input.shape),
"dtype": str(tensor_input.dtype.name),
"model_type": self._model_type,
}
header_bytes = json.dumps(header).encode("utf-8")
payload_bytes = memoryview(tensor_input.tobytes(order="C"))
self._socket.send_multipart([header_bytes, payload_bytes])
reply_frames = self._socket.recv_multipart()
return self._decode_response(reply_frames)
except _zmq.Again:
logger.warning(
f"ZmqEmbeddingRunner({self._model_type}): request timed out, will reset socket before next call"
)
self._needs_reset = True
return [np.zeros((batch_size, self._get_output_dim()), dtype=np.float32)]
except _zmq.ZMQError as exc:
logger.error(f"ZmqEmbeddingRunner({self._model_type}) ZMQError: {exc}, will reset socket before next call")
self._needs_reset = True
return [np.zeros((batch_size, self._get_output_dim()), dtype=np.float32)]
except Exception as exc:
logger.error(f"ZmqEmbeddingRunner({self._model_type}) unexpected error: {exc}")
return [np.zeros((batch_size, self._get_output_dim()), dtype=np.float32)]
def _reset_socket(self) -> None:
try:
self._create_socket()
except Exception:
pass
def _decode_response(self, frames: list[bytes]) -> list[np.ndarray]:
try:
if len(frames) >= 2:
header = json.loads(frames[0].decode("utf-8"))
shape = tuple(header.get("shape", []))
dtype = np.dtype(header.get("dtype", "float32"))
return [np.frombuffer(frames[1], dtype=dtype).reshape(shape)]
elif len(frames) == 1:
# Raw float32 bytes — reshape to (1, embedding_dim)
arr = np.frombuffer(frames[0], dtype=np.float32)
return [arr.reshape((1, -1))]
else:
logger.warning(f"ZmqEmbeddingRunner({self._model_type}): empty reply")
return [np.zeros((1, self._get_output_dim()), dtype=np.float32)]
except Exception as exc:
logger.error(
f"ZmqEmbeddingRunner({self._model_type}): failed to decode response: {exc}"
)
return [np.zeros((1, self._get_output_dim()), dtype=np.float32)]
def _get_output_dim(self) -> int:
return self._OUTPUT_DIMS.get(self._model_type, 512)
def __del__(self) -> None:
try:
if self._socket is not None:
self._socket.close(linger=self._linger_ms)
except Exception:
pass
def get_optimized_runner(
model_path: str, device: str | None, model_type: str, **kwargs
) -> BaseModelRunner:
"""Get an optimized runner for the hardware."""
device = device or "AUTO"
# ZMQ embedding runner — offloads ONNX inference to a native host process.
# Triggered when device is a ZMQ endpoint, e.g. "zmq://host.docker.internal:5556".
if device.startswith("zmq://"):
return ZmqEmbeddingRunner(endpoint=device, model_type=model_type)
if device != "CPU" and is_rknn_compatible(model_path):
rknn_path = auto_convert_model(model_path)

View File

@ -0,0 +1,275 @@
"""ZMQ Embedding Server — native Mac (Apple Silicon) inference service.
Runs ONNX models using hardware acceleration unavailable inside Docker on macOS,
specifically CoreML and the Apple Neural Engine. Frigate's Docker container
connects to this server over ZMQ TCP, sends preprocessed tensors, and receives
embedding vectors back.
Supported models:
- ArcFace (face recognition, 512-dim output)
- FaceNet (face recognition, 128-dim output)
- Jina V1/V2 vision (semantic search, 768-dim output)
Requirements (install outside Docker, on the Mac host):
pip install onnxruntime pyzmq numpy
Usage:
# ArcFace face recognition (port 5556):
python tools/zmq_embedding_server.py \\
--model /config/model_cache/facedet/arcface.onnx \\
--model-type arcface \\
--port 5556
# Jina V1 vision semantic search (port 5557):
python tools/zmq_embedding_server.py \\
--model /config/model_cache/jinaai/jina-clip-v1/vision_model_quantized.onnx \\
--model-type jina_v1 \\
--port 5557
Frigate config (docker-compose / config.yaml):
face_recognition:
enabled: true
model_size: large
device: "zmq://host.docker.internal:5556"
semantic_search:
enabled: true
model_size: small
device: "zmq://host.docker.internal:5557"
Protocol (REQ/REP):
Request: multipart [ header_json_bytes, tensor_bytes ]
header = {
"shape": [batch, channels, height, width], # e.g. [1, 3, 112, 112]
"dtype": "float32",
"model_type": "arcface",
}
tensor_bytes = raw C-order numpy bytes
Response: multipart [ header_json_bytes, embedding_bytes ]
header = {
"shape": [batch, embedding_dim], # e.g. [1, 512]
"dtype": "float32",
}
embedding_bytes = raw C-order numpy bytes
"""
import argparse
import json
import logging
import os
import signal
import sys
import time
import numpy as np
import zmq
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("zmq_embedding_server")
# Models that require ORT_ENABLE_BASIC optimization to avoid graph fusion issues
# (e.g. SimplifiedLayerNormFusion creates nodes that some providers can't handle).
_COMPLEX_MODELS = {"jina_v1", "jina_v2"}
# ---------------------------------------------------------------------------
# ONNX Runtime session (CoreML preferred on Apple Silicon)
# ---------------------------------------------------------------------------
def build_ort_session(model_path: str, model_type: str = ""):
"""Create an ONNX Runtime InferenceSession, preferring CoreML on macOS.
Jina V1/V2 models use ORT_ENABLE_BASIC graph optimization to avoid
fusion passes (e.g. SimplifiedLayerNormFusion) that produce unsupported
nodes. All other models use the default ORT_ENABLE_ALL.
"""
import onnxruntime as ort
available = ort.get_available_providers()
logger.info(f"Available ORT providers: {available}")
# Prefer CoreMLExecutionProvider on Apple Silicon for ANE/GPU acceleration.
# Falls back automatically to CPUExecutionProvider if CoreML is unavailable.
preferred = []
if "CoreMLExecutionProvider" in available:
preferred.append("CoreMLExecutionProvider")
logger.info("Using CoreMLExecutionProvider (Apple Neural Engine / GPU)")
else:
logger.warning(
"CoreMLExecutionProvider not available — falling back to CPU. "
"Install onnxruntime-silicon or a CoreML-enabled onnxruntime build."
)
preferred.append("CPUExecutionProvider")
sess_options = None
if model_type in _COMPLEX_MODELS:
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = (
ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
)
logger.info(f"Using ORT_ENABLE_BASIC optimization for {model_type}")
session = ort.InferenceSession(model_path, sess_options=sess_options, providers=preferred)
input_names = [inp.name for inp in session.get_inputs()]
output_names = [out.name for out in session.get_outputs()]
logger.info(f"Model loaded: inputs={input_names}, outputs={output_names}")
return session
# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------
def run_arcface(session, tensor: np.ndarray) -> np.ndarray:
"""Run ArcFace — input (1, 3, 112, 112) float32, output (1, 512) float32."""
outputs = session.run(None, {"data": tensor})
return outputs[0] # shape (1, 512)
def run_generic(session, tensor: np.ndarray) -> np.ndarray:
"""Generic single-input ONNX model runner."""
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: tensor})
return outputs[0]
_RUNNERS = {
"arcface": run_arcface,
"facenet": run_generic,
"jina_v1": run_generic,
"jina_v2": run_generic,
}
# Model type → input shape for warmup inference (triggers CoreML JIT compilation
# before the first real request arrives, avoiding a ZMQ timeout on cold start).
_WARMUP_SHAPES = {
"arcface": (1, 3, 112, 112),
"facenet": (1, 3, 160, 160),
"jina_v1": (1, 3, 224, 224),
"jina_v2": (1, 3, 224, 224),
}
def warmup(session, model_type: str) -> None:
"""Run a dummy inference to trigger CoreML JIT compilation."""
shape = _WARMUP_SHAPES.get(model_type)
if shape is None:
return
logger.info(f"Warming up CoreML model ({model_type})…")
dummy = np.zeros(shape, dtype=np.float32)
try:
runner = _RUNNERS.get(model_type, run_generic)
runner(session, dummy)
logger.info("Warmup complete")
except Exception as exc:
logger.warning(f"Warmup failed (non-fatal): {exc}")
# ---------------------------------------------------------------------------
# ZMQ server loop
# ---------------------------------------------------------------------------
def serve(session, port: int, model_type: str) -> None:
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://0.0.0.0:{port}")
logger.info(f"Listening on tcp://0.0.0.0:{port} (model_type={model_type})")
runner = _RUNNERS.get(model_type, run_generic)
def _shutdown(sig, frame):
logger.info("Shutting down…")
socket.close(linger=0)
context.term()
sys.exit(0)
signal.signal(signal.SIGINT, _shutdown)
signal.signal(signal.SIGTERM, _shutdown)
while True:
try:
frames = socket.recv_multipart()
except zmq.ZMQError as exc:
logger.error(f"recv error: {exc}")
continue
if len(frames) < 2:
logger.warning(f"Received unexpected frame count: {len(frames)}, ignoring")
socket.send_multipart([b"{}"])
continue
try:
header = json.loads(frames[0].decode("utf-8"))
shape = tuple(header["shape"])
dtype = np.dtype(header.get("dtype", "float32"))
tensor = np.frombuffer(frames[1], dtype=dtype).reshape(shape)
except Exception as exc:
logger.error(f"Failed to decode request: {exc}")
socket.send_multipart([b"{}"])
continue
try:
t0 = time.monotonic()
embedding = runner(session, tensor)
elapsed_ms = (time.monotonic() - t0) * 1000
if elapsed_ms > 2000:
logger.warning(f"slow inference {elapsed_ms:.1f}ms shape={shape}")
resp_header = json.dumps(
{"shape": list(embedding.shape), "dtype": str(embedding.dtype.name)}
).encode("utf-8")
resp_payload = memoryview(np.ascontiguousarray(embedding).tobytes())
socket.send_multipart([resp_header, resp_payload])
except Exception as exc:
logger.error(f"Inference error: {exc}")
# Return a zero embedding so the client can degrade gracefully
zero = np.zeros((1, 512), dtype=np.float32)
resp_header = json.dumps(
{"shape": list(zero.shape), "dtype": "float32"}
).encode("utf-8")
socket.send_multipart([resp_header, memoryview(zero.tobytes())])
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main():
parser = argparse.ArgumentParser(description="ZMQ Embedding Server for Frigate")
parser.add_argument(
"--model",
required=True,
help="Path to the ONNX model file (e.g. /config/model_cache/facedet/arcface.onnx)",
)
parser.add_argument(
"--model-type",
default="arcface",
choices=list(_RUNNERS.keys()),
help="Model type key (default: arcface)",
)
parser.add_argument(
"--port",
type=int,
default=5556,
help="TCP port to listen on (default: 5556)",
)
args = parser.parse_args()
if not os.path.exists(args.model):
logger.error(f"Model file not found: {args.model}")
sys.exit(1)
logger.info(f"Loading model: {args.model}")
session = build_ort_session(args.model, model_type=args.model_type)
warmup(session, model_type=args.model_type)
serve(session, port=args.port, model_type=args.model_type)
if __name__ == "__main__":
main()