mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-03-10 02:29:19 +03:00
feat: ZMQ embedding runner for offloading ONNX inference to native host
Extends the ZMQ split-detector pattern (apple-silicon-detector) to cover ONNX embedding models — ArcFace face recognition and Jina semantic search. On macOS, Docker has no access to CoreML or the Apple Neural Engine, so embedding inference is forced to CPU (~200ms/face for ArcFace). This adds a ZmqEmbeddingRunner that sends preprocessed tensors to a native host process over ZMQ TCP and receives embeddings back, enabling CoreML/ANE acceleration outside the container. Files changed: - frigate/detectors/detection_runners.py: add ZmqEmbeddingRunner class and hook into get_optimized_runner() via "zmq://" device prefix - tools/zmq_embedding_server.py: new host-side server script Tested on Mac Mini M4, 24h soak test, ~5000 object reindex.
This commit is contained in:
parent
4dcd2968b3
commit
a2c43ad8bb
@ -1,5 +1,6 @@
|
|||||||
"""Base runner implementation for ONNX models."""
|
"""Base runner implementation for ONNX models."""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
@ -10,6 +11,11 @@ from typing import Any
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
try:
|
||||||
|
import zmq as _zmq
|
||||||
|
except ImportError:
|
||||||
|
_zmq = None
|
||||||
|
|
||||||
from frigate.util.model import get_ort_providers
|
from frigate.util.model import get_ort_providers
|
||||||
from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible
|
from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible
|
||||||
|
|
||||||
@ -548,12 +554,213 @@ class RKNNModelRunner(BaseModelRunner):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ZmqEmbeddingRunner(BaseModelRunner):
|
||||||
|
"""Send preprocessed embedding tensors over ZMQ to an external inference service.
|
||||||
|
|
||||||
|
This enables offloading ONNX embedding inference (e.g. ArcFace face recognition,
|
||||||
|
Jina semantic search) to a native host process that has access to hardware
|
||||||
|
acceleration unavailable inside Docker, such as CoreML/ANE on Apple Silicon.
|
||||||
|
|
||||||
|
Protocol:
|
||||||
|
- Request is a multipart message: [ header_json_bytes, tensor_bytes ]
|
||||||
|
where header is:
|
||||||
|
{
|
||||||
|
"shape": List[int], # e.g. [1, 3, 112, 112]
|
||||||
|
"dtype": str, # numpy dtype, e.g. "float32"
|
||||||
|
"model_type": str, # e.g. "arcface"
|
||||||
|
}
|
||||||
|
tensor_bytes are the raw C-order bytes of the input tensor.
|
||||||
|
|
||||||
|
- Response is either:
|
||||||
|
a) Multipart [ header_json_bytes, embedding_bytes ] with header specifying
|
||||||
|
shape and dtype of the returned embedding; or
|
||||||
|
b) Single frame of raw float32 bytes (embedding vector, batch-first).
|
||||||
|
|
||||||
|
On timeout or error, a zero embedding is returned so the caller can degrade
|
||||||
|
gracefully (the face will simply not be recognized for that frame).
|
||||||
|
|
||||||
|
Configuration example (face_recognition.device):
|
||||||
|
face_recognition:
|
||||||
|
enabled: true
|
||||||
|
model_size: large
|
||||||
|
device: "zmq://host.docker.internal:5556"
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Model type → primary input name (used to answer get_input_names())
|
||||||
|
_INPUT_NAMES: dict[str, list[str]] = {}
|
||||||
|
|
||||||
|
# Model type → model input spatial width
|
||||||
|
_INPUT_WIDTHS: dict[str, int] = {}
|
||||||
|
|
||||||
|
# Model type → embedding output dimensionality (used for zero-fallback shape)
|
||||||
|
_OUTPUT_DIMS: dict[str, int] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _init_model_maps(cls) -> None:
|
||||||
|
"""Populate the model maps lazily to avoid circular imports at module load."""
|
||||||
|
if cls._INPUT_NAMES:
|
||||||
|
return
|
||||||
|
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||||
|
|
||||||
|
cls._INPUT_NAMES = {
|
||||||
|
EnrichmentModelTypeEnum.arcface.value: ["data"],
|
||||||
|
EnrichmentModelTypeEnum.facenet.value: ["data"],
|
||||||
|
EnrichmentModelTypeEnum.jina_v1.value: ["pixel_values"],
|
||||||
|
EnrichmentModelTypeEnum.jina_v2.value: ["pixel_values"],
|
||||||
|
}
|
||||||
|
cls._INPUT_WIDTHS = {
|
||||||
|
EnrichmentModelTypeEnum.arcface.value: 112,
|
||||||
|
EnrichmentModelTypeEnum.facenet.value: 160,
|
||||||
|
EnrichmentModelTypeEnum.jina_v1.value: 224,
|
||||||
|
EnrichmentModelTypeEnum.jina_v2.value: 224,
|
||||||
|
}
|
||||||
|
cls._OUTPUT_DIMS = {
|
||||||
|
EnrichmentModelTypeEnum.arcface.value: 512,
|
||||||
|
EnrichmentModelTypeEnum.facenet.value: 128,
|
||||||
|
EnrichmentModelTypeEnum.jina_v1.value: 768,
|
||||||
|
EnrichmentModelTypeEnum.jina_v2.value: 768,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
endpoint: str,
|
||||||
|
model_type: str,
|
||||||
|
request_timeout_ms: int = 60000,
|
||||||
|
linger_ms: int = 0,
|
||||||
|
):
|
||||||
|
if _zmq is None:
|
||||||
|
raise ImportError(
|
||||||
|
"pyzmq is required for ZmqEmbeddingRunner. Install it with: pip install pyzmq"
|
||||||
|
)
|
||||||
|
self._init_model_maps()
|
||||||
|
# "zmq://host:port" is the Frigate config sentinel; ZMQ sockets need "tcp://host:port"
|
||||||
|
self._endpoint = endpoint.replace("zmq://", "tcp://", 1)
|
||||||
|
self._model_type = model_type
|
||||||
|
self._request_timeout_ms = request_timeout_ms
|
||||||
|
self._linger_ms = linger_ms
|
||||||
|
self._context = _zmq.Context()
|
||||||
|
self._socket = None
|
||||||
|
self._needs_reset = False
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._create_socket()
|
||||||
|
logger.info(
|
||||||
|
f"ZmqEmbeddingRunner({model_type}): connected to {endpoint}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_socket(self) -> None:
|
||||||
|
if self._socket is not None:
|
||||||
|
try:
|
||||||
|
self._socket.close(linger=self._linger_ms)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._socket = self._context.socket(_zmq.REQ)
|
||||||
|
self._socket.setsockopt(_zmq.RCVTIMEO, self._request_timeout_ms)
|
||||||
|
self._socket.setsockopt(_zmq.SNDTIMEO, self._request_timeout_ms)
|
||||||
|
self._socket.setsockopt(_zmq.LINGER, self._linger_ms)
|
||||||
|
self._socket.connect(self._endpoint)
|
||||||
|
|
||||||
|
def get_input_names(self) -> list[str]:
|
||||||
|
return self._INPUT_NAMES.get(self._model_type, ["data"])
|
||||||
|
|
||||||
|
def get_input_width(self) -> int:
|
||||||
|
return self._INPUT_WIDTHS.get(self._model_type, -1)
|
||||||
|
|
||||||
|
def run(self, inputs: dict[str, Any]) -> list[np.ndarray]:
|
||||||
|
"""Send the primary input tensor over ZMQ and return the embedding.
|
||||||
|
|
||||||
|
For single-input models (ArcFace, FaceNet) the entire inputs dict maps to
|
||||||
|
one tensor. For multi-input models only the first tensor is sent; those
|
||||||
|
models are not yet supported for ZMQ offload.
|
||||||
|
"""
|
||||||
|
tensor_input = np.ascontiguousarray(next(iter(inputs.values())))
|
||||||
|
batch_size = tensor_input.shape[0]
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
# Lazy reset: if a previous call errored, reset the socket now — before any
|
||||||
|
# ZMQ operations — so we don't manipulate sockets inside an error handler where
|
||||||
|
# Frigate's own ZMQ threads may be polling and could hit a libzmq assertion.
|
||||||
|
# The lock ensures only one thread touches the socket at a time (ZMQ REQ
|
||||||
|
# sockets are not thread-safe; concurrent calls from the reindex thread and
|
||||||
|
# the normal embedding maintainer thread would corrupt the socket state).
|
||||||
|
if self._needs_reset:
|
||||||
|
self._reset_socket()
|
||||||
|
self._needs_reset = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
header = {
|
||||||
|
"shape": list(tensor_input.shape),
|
||||||
|
"dtype": str(tensor_input.dtype.name),
|
||||||
|
"model_type": self._model_type,
|
||||||
|
}
|
||||||
|
header_bytes = json.dumps(header).encode("utf-8")
|
||||||
|
payload_bytes = memoryview(tensor_input.tobytes(order="C"))
|
||||||
|
|
||||||
|
self._socket.send_multipart([header_bytes, payload_bytes])
|
||||||
|
reply_frames = self._socket.recv_multipart()
|
||||||
|
return self._decode_response(reply_frames)
|
||||||
|
|
||||||
|
except _zmq.Again:
|
||||||
|
logger.warning(
|
||||||
|
f"ZmqEmbeddingRunner({self._model_type}): request timed out, will reset socket before next call"
|
||||||
|
)
|
||||||
|
self._needs_reset = True
|
||||||
|
return [np.zeros((batch_size, self._get_output_dim()), dtype=np.float32)]
|
||||||
|
except _zmq.ZMQError as exc:
|
||||||
|
logger.error(f"ZmqEmbeddingRunner({self._model_type}) ZMQError: {exc}, will reset socket before next call")
|
||||||
|
self._needs_reset = True
|
||||||
|
return [np.zeros((batch_size, self._get_output_dim()), dtype=np.float32)]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"ZmqEmbeddingRunner({self._model_type}) unexpected error: {exc}")
|
||||||
|
return [np.zeros((batch_size, self._get_output_dim()), dtype=np.float32)]
|
||||||
|
|
||||||
|
def _reset_socket(self) -> None:
|
||||||
|
try:
|
||||||
|
self._create_socket()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _decode_response(self, frames: list[bytes]) -> list[np.ndarray]:
|
||||||
|
try:
|
||||||
|
if len(frames) >= 2:
|
||||||
|
header = json.loads(frames[0].decode("utf-8"))
|
||||||
|
shape = tuple(header.get("shape", []))
|
||||||
|
dtype = np.dtype(header.get("dtype", "float32"))
|
||||||
|
return [np.frombuffer(frames[1], dtype=dtype).reshape(shape)]
|
||||||
|
elif len(frames) == 1:
|
||||||
|
# Raw float32 bytes — reshape to (1, embedding_dim)
|
||||||
|
arr = np.frombuffer(frames[0], dtype=np.float32)
|
||||||
|
return [arr.reshape((1, -1))]
|
||||||
|
else:
|
||||||
|
logger.warning(f"ZmqEmbeddingRunner({self._model_type}): empty reply")
|
||||||
|
return [np.zeros((1, self._get_output_dim()), dtype=np.float32)]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
f"ZmqEmbeddingRunner({self._model_type}): failed to decode response: {exc}"
|
||||||
|
)
|
||||||
|
return [np.zeros((1, self._get_output_dim()), dtype=np.float32)]
|
||||||
|
|
||||||
|
def _get_output_dim(self) -> int:
|
||||||
|
return self._OUTPUT_DIMS.get(self._model_type, 512)
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
try:
|
||||||
|
if self._socket is not None:
|
||||||
|
self._socket.close(linger=self._linger_ms)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_optimized_runner(
|
def get_optimized_runner(
|
||||||
model_path: str, device: str | None, model_type: str, **kwargs
|
model_path: str, device: str | None, model_type: str, **kwargs
|
||||||
) -> BaseModelRunner:
|
) -> BaseModelRunner:
|
||||||
"""Get an optimized runner for the hardware."""
|
"""Get an optimized runner for the hardware."""
|
||||||
device = device or "AUTO"
|
device = device or "AUTO"
|
||||||
|
|
||||||
|
# ZMQ embedding runner — offloads ONNX inference to a native host process.
|
||||||
|
# Triggered when device is a ZMQ endpoint, e.g. "zmq://host.docker.internal:5556".
|
||||||
|
if device.startswith("zmq://"):
|
||||||
|
return ZmqEmbeddingRunner(endpoint=device, model_type=model_type)
|
||||||
|
|
||||||
if device != "CPU" and is_rknn_compatible(model_path):
|
if device != "CPU" and is_rknn_compatible(model_path):
|
||||||
rknn_path = auto_convert_model(model_path)
|
rknn_path = auto_convert_model(model_path)
|
||||||
|
|
||||||
|
|||||||
275
tools/zmq_embedding_server.py
Normal file
275
tools/zmq_embedding_server.py
Normal file
@ -0,0 +1,275 @@
|
|||||||
|
"""ZMQ Embedding Server — native Mac (Apple Silicon) inference service.
|
||||||
|
|
||||||
|
Runs ONNX models using hardware acceleration unavailable inside Docker on macOS,
|
||||||
|
specifically CoreML and the Apple Neural Engine. Frigate's Docker container
|
||||||
|
connects to this server over ZMQ TCP, sends preprocessed tensors, and receives
|
||||||
|
embedding vectors back.
|
||||||
|
|
||||||
|
Supported models:
|
||||||
|
- ArcFace (face recognition, 512-dim output)
|
||||||
|
- FaceNet (face recognition, 128-dim output)
|
||||||
|
- Jina V1/V2 vision (semantic search, 768-dim output)
|
||||||
|
|
||||||
|
Requirements (install outside Docker, on the Mac host):
|
||||||
|
pip install onnxruntime pyzmq numpy
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# ArcFace face recognition (port 5556):
|
||||||
|
python tools/zmq_embedding_server.py \\
|
||||||
|
--model /config/model_cache/facedet/arcface.onnx \\
|
||||||
|
--model-type arcface \\
|
||||||
|
--port 5556
|
||||||
|
|
||||||
|
# Jina V1 vision semantic search (port 5557):
|
||||||
|
python tools/zmq_embedding_server.py \\
|
||||||
|
--model /config/model_cache/jinaai/jina-clip-v1/vision_model_quantized.onnx \\
|
||||||
|
--model-type jina_v1 \\
|
||||||
|
--port 5557
|
||||||
|
|
||||||
|
Frigate config (docker-compose / config.yaml):
|
||||||
|
face_recognition:
|
||||||
|
enabled: true
|
||||||
|
model_size: large
|
||||||
|
device: "zmq://host.docker.internal:5556"
|
||||||
|
|
||||||
|
semantic_search:
|
||||||
|
enabled: true
|
||||||
|
model_size: small
|
||||||
|
device: "zmq://host.docker.internal:5557"
|
||||||
|
|
||||||
|
Protocol (REQ/REP):
|
||||||
|
Request: multipart [ header_json_bytes, tensor_bytes ]
|
||||||
|
header = {
|
||||||
|
"shape": [batch, channels, height, width], # e.g. [1, 3, 112, 112]
|
||||||
|
"dtype": "float32",
|
||||||
|
"model_type": "arcface",
|
||||||
|
}
|
||||||
|
tensor_bytes = raw C-order numpy bytes
|
||||||
|
|
||||||
|
Response: multipart [ header_json_bytes, embedding_bytes ]
|
||||||
|
header = {
|
||||||
|
"shape": [batch, embedding_dim], # e.g. [1, 512]
|
||||||
|
"dtype": "float32",
|
||||||
|
}
|
||||||
|
embedding_bytes = raw C-order numpy bytes
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
logger = logging.getLogger("zmq_embedding_server")
|
||||||
|
|
||||||
|
|
||||||
|
# Models that require ORT_ENABLE_BASIC optimization to avoid graph fusion issues
|
||||||
|
# (e.g. SimplifiedLayerNormFusion creates nodes that some providers can't handle).
|
||||||
|
_COMPLEX_MODELS = {"jina_v1", "jina_v2"}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ONNX Runtime session (CoreML preferred on Apple Silicon)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def build_ort_session(model_path: str, model_type: str = ""):
|
||||||
|
"""Create an ONNX Runtime InferenceSession, preferring CoreML on macOS.
|
||||||
|
|
||||||
|
Jina V1/V2 models use ORT_ENABLE_BASIC graph optimization to avoid
|
||||||
|
fusion passes (e.g. SimplifiedLayerNormFusion) that produce unsupported
|
||||||
|
nodes. All other models use the default ORT_ENABLE_ALL.
|
||||||
|
"""
|
||||||
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
available = ort.get_available_providers()
|
||||||
|
logger.info(f"Available ORT providers: {available}")
|
||||||
|
|
||||||
|
# Prefer CoreMLExecutionProvider on Apple Silicon for ANE/GPU acceleration.
|
||||||
|
# Falls back automatically to CPUExecutionProvider if CoreML is unavailable.
|
||||||
|
preferred = []
|
||||||
|
if "CoreMLExecutionProvider" in available:
|
||||||
|
preferred.append("CoreMLExecutionProvider")
|
||||||
|
logger.info("Using CoreMLExecutionProvider (Apple Neural Engine / GPU)")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"CoreMLExecutionProvider not available — falling back to CPU. "
|
||||||
|
"Install onnxruntime-silicon or a CoreML-enabled onnxruntime build."
|
||||||
|
)
|
||||||
|
|
||||||
|
preferred.append("CPUExecutionProvider")
|
||||||
|
|
||||||
|
sess_options = None
|
||||||
|
if model_type in _COMPLEX_MODELS:
|
||||||
|
sess_options = ort.SessionOptions()
|
||||||
|
sess_options.graph_optimization_level = (
|
||||||
|
ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||||
|
)
|
||||||
|
logger.info(f"Using ORT_ENABLE_BASIC optimization for {model_type}")
|
||||||
|
|
||||||
|
session = ort.InferenceSession(model_path, sess_options=sess_options, providers=preferred)
|
||||||
|
|
||||||
|
input_names = [inp.name for inp in session.get_inputs()]
|
||||||
|
output_names = [out.name for out in session.get_outputs()]
|
||||||
|
logger.info(f"Model loaded: inputs={input_names}, outputs={output_names}")
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Inference helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def run_arcface(session, tensor: np.ndarray) -> np.ndarray:
|
||||||
|
"""Run ArcFace — input (1, 3, 112, 112) float32, output (1, 512) float32."""
|
||||||
|
outputs = session.run(None, {"data": tensor})
|
||||||
|
return outputs[0] # shape (1, 512)
|
||||||
|
|
||||||
|
|
||||||
|
def run_generic(session, tensor: np.ndarray) -> np.ndarray:
|
||||||
|
"""Generic single-input ONNX model runner."""
|
||||||
|
input_name = session.get_inputs()[0].name
|
||||||
|
outputs = session.run(None, {input_name: tensor})
|
||||||
|
return outputs[0]
|
||||||
|
|
||||||
|
|
||||||
|
_RUNNERS = {
|
||||||
|
"arcface": run_arcface,
|
||||||
|
"facenet": run_generic,
|
||||||
|
"jina_v1": run_generic,
|
||||||
|
"jina_v2": run_generic,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Model type → input shape for warmup inference (triggers CoreML JIT compilation
|
||||||
|
# before the first real request arrives, avoiding a ZMQ timeout on cold start).
|
||||||
|
_WARMUP_SHAPES = {
|
||||||
|
"arcface": (1, 3, 112, 112),
|
||||||
|
"facenet": (1, 3, 160, 160),
|
||||||
|
"jina_v1": (1, 3, 224, 224),
|
||||||
|
"jina_v2": (1, 3, 224, 224),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def warmup(session, model_type: str) -> None:
|
||||||
|
"""Run a dummy inference to trigger CoreML JIT compilation."""
|
||||||
|
shape = _WARMUP_SHAPES.get(model_type)
|
||||||
|
if shape is None:
|
||||||
|
return
|
||||||
|
logger.info(f"Warming up CoreML model ({model_type})…")
|
||||||
|
dummy = np.zeros(shape, dtype=np.float32)
|
||||||
|
try:
|
||||||
|
runner = _RUNNERS.get(model_type, run_generic)
|
||||||
|
runner(session, dummy)
|
||||||
|
logger.info("Warmup complete")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"Warmup failed (non-fatal): {exc}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ZMQ server loop
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def serve(session, port: int, model_type: str) -> None:
|
||||||
|
context = zmq.Context()
|
||||||
|
socket = context.socket(zmq.REP)
|
||||||
|
socket.bind(f"tcp://0.0.0.0:{port}")
|
||||||
|
logger.info(f"Listening on tcp://0.0.0.0:{port} (model_type={model_type})")
|
||||||
|
|
||||||
|
runner = _RUNNERS.get(model_type, run_generic)
|
||||||
|
|
||||||
|
def _shutdown(sig, frame):
|
||||||
|
logger.info("Shutting down…")
|
||||||
|
socket.close(linger=0)
|
||||||
|
context.term()
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, _shutdown)
|
||||||
|
signal.signal(signal.SIGTERM, _shutdown)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
frames = socket.recv_multipart()
|
||||||
|
except zmq.ZMQError as exc:
|
||||||
|
logger.error(f"recv error: {exc}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(frames) < 2:
|
||||||
|
logger.warning(f"Received unexpected frame count: {len(frames)}, ignoring")
|
||||||
|
socket.send_multipart([b"{}"])
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
header = json.loads(frames[0].decode("utf-8"))
|
||||||
|
shape = tuple(header["shape"])
|
||||||
|
dtype = np.dtype(header.get("dtype", "float32"))
|
||||||
|
tensor = np.frombuffer(frames[1], dtype=dtype).reshape(shape)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Failed to decode request: {exc}")
|
||||||
|
socket.send_multipart([b"{}"])
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
t0 = time.monotonic()
|
||||||
|
embedding = runner(session, tensor)
|
||||||
|
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||||
|
if elapsed_ms > 2000:
|
||||||
|
logger.warning(f"slow inference {elapsed_ms:.1f}ms shape={shape}")
|
||||||
|
resp_header = json.dumps(
|
||||||
|
{"shape": list(embedding.shape), "dtype": str(embedding.dtype.name)}
|
||||||
|
).encode("utf-8")
|
||||||
|
resp_payload = memoryview(np.ascontiguousarray(embedding).tobytes())
|
||||||
|
socket.send_multipart([resp_header, resp_payload])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Inference error: {exc}")
|
||||||
|
# Return a zero embedding so the client can degrade gracefully
|
||||||
|
zero = np.zeros((1, 512), dtype=np.float32)
|
||||||
|
resp_header = json.dumps(
|
||||||
|
{"shape": list(zero.shape), "dtype": "float32"}
|
||||||
|
).encode("utf-8")
|
||||||
|
socket.send_multipart([resp_header, memoryview(zero.tobytes())])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Entry point
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="ZMQ Embedding Server for Frigate")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
required=True,
|
||||||
|
help="Path to the ONNX model file (e.g. /config/model_cache/facedet/arcface.onnx)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-type",
|
||||||
|
default="arcface",
|
||||||
|
choices=list(_RUNNERS.keys()),
|
||||||
|
help="Model type key (default: arcface)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port",
|
||||||
|
type=int,
|
||||||
|
default=5556,
|
||||||
|
help="TCP port to listen on (default: 5556)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not os.path.exists(args.model):
|
||||||
|
logger.error(f"Model file not found: {args.model}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
logger.info(f"Loading model: {args.model}")
|
||||||
|
session = build_ort_session(args.model, model_type=args.model_type)
|
||||||
|
warmup(session, model_type=args.model_type)
|
||||||
|
serve(session, port=args.port, model_type=args.model_type)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Loading…
Reference in New Issue
Block a user