mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-03-17 13:48:21 +03:00
Compare commits
5 Commits
3f183b0b8f
...
e171beeb80
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e171beeb80 | ||
|
|
352d271fe4 | ||
|
|
a6e11a59d6 | ||
|
|
a7d8d13d9a | ||
|
|
a2c43ad8bb |
@ -9,4 +9,25 @@ Snapshots are accessible in the UI in the Explore pane. This allows for quick su
|
||||
|
||||
To only save snapshots for objects that enter a specific zone, [see the zone docs](./zones.md#restricting-snapshots-to-specific-zones)
|
||||
|
||||
Snapshots sent via MQTT are configured in the [config file](https://docs.frigate.video/configuration/) under `cameras -> your_camera -> mqtt`
|
||||
Snapshots sent via MQTT are configured in the [config file](/configuration) under `cameras -> your_camera -> mqtt`
|
||||
|
||||
## Frame Selection
|
||||
|
||||
Frigate does not save every frame — it picks a single "best" frame for each tracked object and uses it for both the snapshot and clean copy. As the object is tracked across frames, Frigate continuously evaluates whether the current frame is better than the previous best based on detection confidence, object size, and the presence of key attributes like faces or license plates. Frames where the object touches the edge of the frame are deprioritized. The snapshot is written to disk once tracking ends using whichever frame was determined to be the best.
|
||||
|
||||
MQTT snapshots are published more frequently — each time a better thumbnail frame is found during tracking, or when the current best image is older than `best_image_timeout` (default: 60s). These use their own annotation settings configured under `cameras -> your_camera -> mqtt`.
|
||||
|
||||
## Clean Copy
|
||||
|
||||
Frigate can produce up to two snapshot files per event, each used in different places:
|
||||
|
||||
| Version | File | Annotations | Used by |
|
||||
| --- | --- | --- | --- |
|
||||
| **Regular snapshot** | `<camera>-<id>.jpg` | Respects your `timestamp`, `bounding_box`, `crop`, and `height` settings | API (`/api/events/<id>/snapshot.jpg`), MQTT (`<camera>/<label>/snapshot`), Explore pane in the UI |
|
||||
| **Clean copy** | `<camera>-<id>-clean.webp` | Always unannotated — no bounding box, no timestamp, no crop, full resolution | API (`/api/events/<id>/snapshot-clean.webp`), [Frigate+](/plus/first_model) submissions, "Download Clean Snapshot" in the UI |
|
||||
|
||||
MQTT snapshots are configured separately under `cameras -> your_camera -> mqtt` and are unrelated to the clean copy.
|
||||
|
||||
The clean copy is required for submitting events to [Frigate+](/plus/first_model) — if you plan to use Frigate+, keep `clean_copy` enabled regardless of your other snapshot settings.
|
||||
|
||||
If you are not using Frigate+ and `timestamp`, `bounding_box`, and `crop` are all disabled, the regular snapshot is already effectively clean, so `clean_copy` provides no benefit and only uses additional disk space. You can safely set `clean_copy: False` in this case.
|
||||
|
||||
@ -16,7 +16,15 @@ See the [MQTT integration
|
||||
documentation](https://www.home-assistant.io/integrations/mqtt/) for more
|
||||
details.
|
||||
|
||||
In addition, MQTT must be enabled in your Frigate configuration file and Frigate must be connected to the same MQTT server as Home Assistant for many of the entities created by the integration to function.
|
||||
In addition, MQTT must be enabled in your Frigate configuration file and Frigate must be connected to the same MQTT server as Home Assistant for many of the entities created by the integration to function, e.g.:
|
||||
|
||||
```yaml
|
||||
mqtt:
|
||||
enabled: True
|
||||
host: mqtt.server.com # the address of your HA server that's running the MQTT integration
|
||||
user: your_mqtt_broker_username
|
||||
password: your_mqtt_broker_password
|
||||
```
|
||||
|
||||
### Integration installation
|
||||
|
||||
@ -95,12 +103,12 @@ services:
|
||||
|
||||
If you are using Home Assistant Add-on, the URL should be one of the following depending on which Add-on variant you are using. Note that if you are using the Proxy Add-on, you should NOT point the integration at the proxy URL. Just enter the same URL used to access Frigate directly from your network.
|
||||
|
||||
| Add-on Variant | URL |
|
||||
| -------------------------- | ----------------------------------------- |
|
||||
| Frigate | `http://ccab4aaf-frigate:5000` |
|
||||
| Frigate (Full Access) | `http://ccab4aaf-frigate-fa:5000` |
|
||||
| Frigate Beta | `http://ccab4aaf-frigate-beta:5000` |
|
||||
| Frigate Beta (Full Access) | `http://ccab4aaf-frigate-fa-beta:5000` |
|
||||
| Add-on Variant | URL |
|
||||
| -------------------------- | -------------------------------------- |
|
||||
| Frigate | `http://ccab4aaf-frigate:5000` |
|
||||
| Frigate (Full Access) | `http://ccab4aaf-frigate-fa:5000` |
|
||||
| Frigate Beta | `http://ccab4aaf-frigate-beta:5000` |
|
||||
| Frigate Beta (Full Access) | `http://ccab4aaf-frigate-fa-beta:5000` |
|
||||
|
||||
### Frigate running on a separate machine
|
||||
|
||||
|
||||
@ -120,7 +120,7 @@ Message published for each changed tracked object. The first message is publishe
|
||||
|
||||
### `frigate/tracked_object_update`
|
||||
|
||||
Message published for updates to tracked object metadata, for example:
|
||||
Message published for updates to tracked object metadata. All messages include an `id` field which is the tracked object's event ID, and can be used to look up the event via the API or match it to items in the UI.
|
||||
|
||||
#### Generative AI Description Update
|
||||
|
||||
@ -134,12 +134,14 @@ Message published for updates to tracked object metadata, for example:
|
||||
|
||||
#### Face Recognition Update
|
||||
|
||||
Published after each recognition attempt, regardless of whether the score meets `recognition_threshold`. See the [Face Recognition](/configuration/face_recognition) documentation for details on how scoring works.
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "face",
|
||||
"id": "1607123955.475377-mxklsc",
|
||||
"name": "John",
|
||||
"score": 0.95,
|
||||
"name": "John", // best matching person, or null if no match
|
||||
"score": 0.95, // running weighted average across all recognition attempts
|
||||
"camera": "front_door_cam",
|
||||
"timestamp": 1607123958.748393
|
||||
}
|
||||
@ -147,11 +149,13 @@ Message published for updates to tracked object metadata, for example:
|
||||
|
||||
#### License Plate Recognition Update
|
||||
|
||||
Published when a license plate is recognized on a car object. See the [License Plate Recognition](/configuration/license_plate_recognition) documentation for details.
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "lpr",
|
||||
"id": "1607123955.475377-mxklsc",
|
||||
"name": "John's Car",
|
||||
"name": "John's Car", // known name for the plate, or null
|
||||
"plate": "123ABC",
|
||||
"score": 0.95,
|
||||
"camera": "driveway_cam",
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Base runner implementation for ONNX models."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
@ -10,6 +11,11 @@ from typing import Any
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
try:
|
||||
import zmq as _zmq
|
||||
except ImportError:
|
||||
_zmq = None
|
||||
|
||||
from frigate.util.model import get_ort_providers
|
||||
from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible
|
||||
|
||||
@ -548,12 +554,213 @@ class RKNNModelRunner(BaseModelRunner):
|
||||
pass
|
||||
|
||||
|
||||
class ZmqEmbeddingRunner(BaseModelRunner):
|
||||
"""Send preprocessed embedding tensors over ZMQ to an external inference service.
|
||||
|
||||
This enables offloading ONNX embedding inference (e.g. ArcFace face recognition,
|
||||
Jina semantic search) to a native host process that has access to hardware
|
||||
acceleration unavailable inside Docker, such as CoreML/ANE on Apple Silicon.
|
||||
|
||||
Protocol:
|
||||
- Request is a multipart message: [ header_json_bytes, tensor_bytes ]
|
||||
where header is:
|
||||
{
|
||||
"shape": List[int], # e.g. [1, 3, 112, 112]
|
||||
"dtype": str, # numpy dtype, e.g. "float32"
|
||||
"model_type": str, # e.g. "arcface"
|
||||
}
|
||||
tensor_bytes are the raw C-order bytes of the input tensor.
|
||||
|
||||
- Response is either:
|
||||
a) Multipart [ header_json_bytes, embedding_bytes ] with header specifying
|
||||
shape and dtype of the returned embedding; or
|
||||
b) Single frame of raw float32 bytes (embedding vector, batch-first).
|
||||
|
||||
On timeout or error, a zero embedding is returned so the caller can degrade
|
||||
gracefully (the face will simply not be recognized for that frame).
|
||||
|
||||
Configuration example (face_recognition.device):
|
||||
face_recognition:
|
||||
enabled: true
|
||||
model_size: large
|
||||
device: "zmq://host.docker.internal:5556"
|
||||
"""
|
||||
|
||||
# Model type → primary input name (used to answer get_input_names())
|
||||
_INPUT_NAMES: dict[str, list[str]] = {}
|
||||
|
||||
# Model type → model input spatial width
|
||||
_INPUT_WIDTHS: dict[str, int] = {}
|
||||
|
||||
# Model type → embedding output dimensionality (used for zero-fallback shape)
|
||||
_OUTPUT_DIMS: dict[str, int] = {}
|
||||
|
||||
@classmethod
|
||||
def _init_model_maps(cls) -> None:
|
||||
"""Populate the model maps lazily to avoid circular imports at module load."""
|
||||
if cls._INPUT_NAMES:
|
||||
return
|
||||
from frigate.embeddings.types import EnrichmentModelTypeEnum
|
||||
|
||||
cls._INPUT_NAMES = {
|
||||
EnrichmentModelTypeEnum.arcface.value: ["data"],
|
||||
EnrichmentModelTypeEnum.facenet.value: ["data"],
|
||||
EnrichmentModelTypeEnum.jina_v1.value: ["pixel_values"],
|
||||
EnrichmentModelTypeEnum.jina_v2.value: ["pixel_values"],
|
||||
}
|
||||
cls._INPUT_WIDTHS = {
|
||||
EnrichmentModelTypeEnum.arcface.value: 112,
|
||||
EnrichmentModelTypeEnum.facenet.value: 160,
|
||||
EnrichmentModelTypeEnum.jina_v1.value: 224,
|
||||
EnrichmentModelTypeEnum.jina_v2.value: 224,
|
||||
}
|
||||
cls._OUTPUT_DIMS = {
|
||||
EnrichmentModelTypeEnum.arcface.value: 512,
|
||||
EnrichmentModelTypeEnum.facenet.value: 128,
|
||||
EnrichmentModelTypeEnum.jina_v1.value: 768,
|
||||
EnrichmentModelTypeEnum.jina_v2.value: 768,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
model_type: str,
|
||||
request_timeout_ms: int = 60000,
|
||||
linger_ms: int = 0,
|
||||
):
|
||||
if _zmq is None:
|
||||
raise ImportError(
|
||||
"pyzmq is required for ZmqEmbeddingRunner. Install it with: pip install pyzmq"
|
||||
)
|
||||
self._init_model_maps()
|
||||
# "zmq://host:port" is the Frigate config sentinel; ZMQ sockets need "tcp://host:port"
|
||||
self._endpoint = endpoint.replace("zmq://", "tcp://", 1)
|
||||
self._model_type = model_type
|
||||
self._request_timeout_ms = request_timeout_ms
|
||||
self._linger_ms = linger_ms
|
||||
self._context = _zmq.Context()
|
||||
self._socket = None
|
||||
self._needs_reset = False
|
||||
self._lock = threading.Lock()
|
||||
self._create_socket()
|
||||
logger.info(
|
||||
f"ZmqEmbeddingRunner({model_type}): connected to {endpoint}"
|
||||
)
|
||||
|
||||
def _create_socket(self) -> None:
|
||||
if self._socket is not None:
|
||||
try:
|
||||
self._socket.close(linger=self._linger_ms)
|
||||
except Exception:
|
||||
pass
|
||||
self._socket = self._context.socket(_zmq.REQ)
|
||||
self._socket.setsockopt(_zmq.RCVTIMEO, self._request_timeout_ms)
|
||||
self._socket.setsockopt(_zmq.SNDTIMEO, self._request_timeout_ms)
|
||||
self._socket.setsockopt(_zmq.LINGER, self._linger_ms)
|
||||
self._socket.connect(self._endpoint)
|
||||
|
||||
def get_input_names(self) -> list[str]:
|
||||
return self._INPUT_NAMES.get(self._model_type, ["data"])
|
||||
|
||||
def get_input_width(self) -> int:
|
||||
return self._INPUT_WIDTHS.get(self._model_type, -1)
|
||||
|
||||
def run(self, inputs: dict[str, Any]) -> list[np.ndarray]:
|
||||
"""Send the primary input tensor over ZMQ and return the embedding.
|
||||
|
||||
For single-input models (ArcFace, FaceNet) the entire inputs dict maps to
|
||||
one tensor. For multi-input models only the first tensor is sent; those
|
||||
models are not yet supported for ZMQ offload.
|
||||
"""
|
||||
tensor_input = np.ascontiguousarray(next(iter(inputs.values())))
|
||||
batch_size = tensor_input.shape[0]
|
||||
|
||||
with self._lock:
|
||||
# Lazy reset: if a previous call errored, reset the socket now — before any
|
||||
# ZMQ operations — so we don't manipulate sockets inside an error handler where
|
||||
# Frigate's own ZMQ threads may be polling and could hit a libzmq assertion.
|
||||
# The lock ensures only one thread touches the socket at a time (ZMQ REQ
|
||||
# sockets are not thread-safe; concurrent calls from the reindex thread and
|
||||
# the normal embedding maintainer thread would corrupt the socket state).
|
||||
if self._needs_reset:
|
||||
self._reset_socket()
|
||||
self._needs_reset = False
|
||||
|
||||
try:
|
||||
header = {
|
||||
"shape": list(tensor_input.shape),
|
||||
"dtype": str(tensor_input.dtype.name),
|
||||
"model_type": self._model_type,
|
||||
}
|
||||
header_bytes = json.dumps(header).encode("utf-8")
|
||||
payload_bytes = memoryview(tensor_input.tobytes(order="C"))
|
||||
|
||||
self._socket.send_multipart([header_bytes, payload_bytes])
|
||||
reply_frames = self._socket.recv_multipart()
|
||||
return self._decode_response(reply_frames)
|
||||
|
||||
except _zmq.Again:
|
||||
logger.warning(
|
||||
f"ZmqEmbeddingRunner({self._model_type}): request timed out, will reset socket before next call"
|
||||
)
|
||||
self._needs_reset = True
|
||||
return [np.zeros((batch_size, self._get_output_dim()), dtype=np.float32)]
|
||||
except _zmq.ZMQError as exc:
|
||||
logger.error(f"ZmqEmbeddingRunner({self._model_type}) ZMQError: {exc}, will reset socket before next call")
|
||||
self._needs_reset = True
|
||||
return [np.zeros((batch_size, self._get_output_dim()), dtype=np.float32)]
|
||||
except Exception as exc:
|
||||
logger.error(f"ZmqEmbeddingRunner({self._model_type}) unexpected error: {exc}")
|
||||
return [np.zeros((batch_size, self._get_output_dim()), dtype=np.float32)]
|
||||
|
||||
def _reset_socket(self) -> None:
|
||||
try:
|
||||
self._create_socket()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _decode_response(self, frames: list[bytes]) -> list[np.ndarray]:
|
||||
try:
|
||||
if len(frames) >= 2:
|
||||
header = json.loads(frames[0].decode("utf-8"))
|
||||
shape = tuple(header.get("shape", []))
|
||||
dtype = np.dtype(header.get("dtype", "float32"))
|
||||
return [np.frombuffer(frames[1], dtype=dtype).reshape(shape)]
|
||||
elif len(frames) == 1:
|
||||
# Raw float32 bytes — reshape to (1, embedding_dim)
|
||||
arr = np.frombuffer(frames[0], dtype=np.float32)
|
||||
return [arr.reshape((1, -1))]
|
||||
else:
|
||||
logger.warning(f"ZmqEmbeddingRunner({self._model_type}): empty reply")
|
||||
return [np.zeros((1, self._get_output_dim()), dtype=np.float32)]
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"ZmqEmbeddingRunner({self._model_type}): failed to decode response: {exc}"
|
||||
)
|
||||
return [np.zeros((1, self._get_output_dim()), dtype=np.float32)]
|
||||
|
||||
def _get_output_dim(self) -> int:
|
||||
return self._OUTPUT_DIMS.get(self._model_type, 512)
|
||||
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
if self._socket is not None:
|
||||
self._socket.close(linger=self._linger_ms)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_optimized_runner(
|
||||
model_path: str, device: str | None, model_type: str, **kwargs
|
||||
) -> BaseModelRunner:
|
||||
"""Get an optimized runner for the hardware."""
|
||||
device = device or "AUTO"
|
||||
|
||||
# ZMQ embedding runner — offloads ONNX inference to a native host process.
|
||||
# Triggered when device is a ZMQ endpoint, e.g. "zmq://host.docker.internal:5556".
|
||||
if device.startswith("zmq://"):
|
||||
return ZmqEmbeddingRunner(endpoint=device, model_type=model_type)
|
||||
|
||||
if device != "CPU" and is_rknn_compatible(model_path):
|
||||
rknn_path = auto_convert_model(model_path)
|
||||
|
||||
|
||||
275
tools/zmq_embedding_server.py
Normal file
275
tools/zmq_embedding_server.py
Normal file
@ -0,0 +1,275 @@
|
||||
"""ZMQ Embedding Server — native Mac (Apple Silicon) inference service.
|
||||
|
||||
Runs ONNX models using hardware acceleration unavailable inside Docker on macOS,
|
||||
specifically CoreML and the Apple Neural Engine. Frigate's Docker container
|
||||
connects to this server over ZMQ TCP, sends preprocessed tensors, and receives
|
||||
embedding vectors back.
|
||||
|
||||
Supported models:
|
||||
- ArcFace (face recognition, 512-dim output)
|
||||
- FaceNet (face recognition, 128-dim output)
|
||||
- Jina V1/V2 vision (semantic search, 768-dim output)
|
||||
|
||||
Requirements (install outside Docker, on the Mac host):
|
||||
pip install onnxruntime pyzmq numpy
|
||||
|
||||
Usage:
|
||||
# ArcFace face recognition (port 5556):
|
||||
python tools/zmq_embedding_server.py \\
|
||||
--model /config/model_cache/facedet/arcface.onnx \\
|
||||
--model-type arcface \\
|
||||
--port 5556
|
||||
|
||||
# Jina V1 vision semantic search (port 5557):
|
||||
python tools/zmq_embedding_server.py \\
|
||||
--model /config/model_cache/jinaai/jina-clip-v1/vision_model_quantized.onnx \\
|
||||
--model-type jina_v1 \\
|
||||
--port 5557
|
||||
|
||||
Frigate config (docker-compose / config.yaml):
|
||||
face_recognition:
|
||||
enabled: true
|
||||
model_size: large
|
||||
device: "zmq://host.docker.internal:5556"
|
||||
|
||||
semantic_search:
|
||||
enabled: true
|
||||
model_size: small
|
||||
device: "zmq://host.docker.internal:5557"
|
||||
|
||||
Protocol (REQ/REP):
|
||||
Request: multipart [ header_json_bytes, tensor_bytes ]
|
||||
header = {
|
||||
"shape": [batch, channels, height, width], # e.g. [1, 3, 112, 112]
|
||||
"dtype": "float32",
|
||||
"model_type": "arcface",
|
||||
}
|
||||
tensor_bytes = raw C-order numpy bytes
|
||||
|
||||
Response: multipart [ header_json_bytes, embedding_bytes ]
|
||||
header = {
|
||||
"shape": [batch, embedding_dim], # e.g. [1, 512]
|
||||
"dtype": "float32",
|
||||
}
|
||||
embedding_bytes = raw C-order numpy bytes
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
logger = logging.getLogger("zmq_embedding_server")
|
||||
|
||||
|
||||
# Models that require ORT_ENABLE_BASIC optimization to avoid graph fusion issues
|
||||
# (e.g. SimplifiedLayerNormFusion creates nodes that some providers can't handle).
|
||||
_COMPLEX_MODELS = {"jina_v1", "jina_v2"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ONNX Runtime session (CoreML preferred on Apple Silicon)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_ort_session(model_path: str, model_type: str = ""):
|
||||
"""Create an ONNX Runtime InferenceSession, preferring CoreML on macOS.
|
||||
|
||||
Jina V1/V2 models use ORT_ENABLE_BASIC graph optimization to avoid
|
||||
fusion passes (e.g. SimplifiedLayerNormFusion) that produce unsupported
|
||||
nodes. All other models use the default ORT_ENABLE_ALL.
|
||||
"""
|
||||
import onnxruntime as ort
|
||||
|
||||
available = ort.get_available_providers()
|
||||
logger.info(f"Available ORT providers: {available}")
|
||||
|
||||
# Prefer CoreMLExecutionProvider on Apple Silicon for ANE/GPU acceleration.
|
||||
# Falls back automatically to CPUExecutionProvider if CoreML is unavailable.
|
||||
preferred = []
|
||||
if "CoreMLExecutionProvider" in available:
|
||||
preferred.append("CoreMLExecutionProvider")
|
||||
logger.info("Using CoreMLExecutionProvider (Apple Neural Engine / GPU)")
|
||||
else:
|
||||
logger.warning(
|
||||
"CoreMLExecutionProvider not available — falling back to CPU. "
|
||||
"Install onnxruntime-silicon or a CoreML-enabled onnxruntime build."
|
||||
)
|
||||
|
||||
preferred.append("CPUExecutionProvider")
|
||||
|
||||
sess_options = None
|
||||
if model_type in _COMPLEX_MODELS:
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.graph_optimization_level = (
|
||||
ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||
)
|
||||
logger.info(f"Using ORT_ENABLE_BASIC optimization for {model_type}")
|
||||
|
||||
session = ort.InferenceSession(model_path, sess_options=sess_options, providers=preferred)
|
||||
|
||||
input_names = [inp.name for inp in session.get_inputs()]
|
||||
output_names = [out.name for out in session.get_outputs()]
|
||||
logger.info(f"Model loaded: inputs={input_names}, outputs={output_names}")
|
||||
return session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inference helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def run_arcface(session, tensor: np.ndarray) -> np.ndarray:
|
||||
"""Run ArcFace — input (1, 3, 112, 112) float32, output (1, 512) float32."""
|
||||
outputs = session.run(None, {"data": tensor})
|
||||
return outputs[0] # shape (1, 512)
|
||||
|
||||
|
||||
def run_generic(session, tensor: np.ndarray) -> np.ndarray:
|
||||
"""Generic single-input ONNX model runner."""
|
||||
input_name = session.get_inputs()[0].name
|
||||
outputs = session.run(None, {input_name: tensor})
|
||||
return outputs[0]
|
||||
|
||||
|
||||
_RUNNERS = {
|
||||
"arcface": run_arcface,
|
||||
"facenet": run_generic,
|
||||
"jina_v1": run_generic,
|
||||
"jina_v2": run_generic,
|
||||
}
|
||||
|
||||
# Model type → input shape for warmup inference (triggers CoreML JIT compilation
|
||||
# before the first real request arrives, avoiding a ZMQ timeout on cold start).
|
||||
_WARMUP_SHAPES = {
|
||||
"arcface": (1, 3, 112, 112),
|
||||
"facenet": (1, 3, 160, 160),
|
||||
"jina_v1": (1, 3, 224, 224),
|
||||
"jina_v2": (1, 3, 224, 224),
|
||||
}
|
||||
|
||||
|
||||
def warmup(session, model_type: str) -> None:
|
||||
"""Run a dummy inference to trigger CoreML JIT compilation."""
|
||||
shape = _WARMUP_SHAPES.get(model_type)
|
||||
if shape is None:
|
||||
return
|
||||
logger.info(f"Warming up CoreML model ({model_type})…")
|
||||
dummy = np.zeros(shape, dtype=np.float32)
|
||||
try:
|
||||
runner = _RUNNERS.get(model_type, run_generic)
|
||||
runner(session, dummy)
|
||||
logger.info("Warmup complete")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Warmup failed (non-fatal): {exc}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ZMQ server loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def serve(session, port: int, model_type: str) -> None:
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REP)
|
||||
socket.bind(f"tcp://0.0.0.0:{port}")
|
||||
logger.info(f"Listening on tcp://0.0.0.0:{port} (model_type={model_type})")
|
||||
|
||||
runner = _RUNNERS.get(model_type, run_generic)
|
||||
|
||||
def _shutdown(sig, frame):
|
||||
logger.info("Shutting down…")
|
||||
socket.close(linger=0)
|
||||
context.term()
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, _shutdown)
|
||||
signal.signal(signal.SIGTERM, _shutdown)
|
||||
|
||||
while True:
|
||||
try:
|
||||
frames = socket.recv_multipart()
|
||||
except zmq.ZMQError as exc:
|
||||
logger.error(f"recv error: {exc}")
|
||||
continue
|
||||
|
||||
if len(frames) < 2:
|
||||
logger.warning(f"Received unexpected frame count: {len(frames)}, ignoring")
|
||||
socket.send_multipart([b"{}"])
|
||||
continue
|
||||
|
||||
try:
|
||||
header = json.loads(frames[0].decode("utf-8"))
|
||||
shape = tuple(header["shape"])
|
||||
dtype = np.dtype(header.get("dtype", "float32"))
|
||||
tensor = np.frombuffer(frames[1], dtype=dtype).reshape(shape)
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to decode request: {exc}")
|
||||
socket.send_multipart([b"{}"])
|
||||
continue
|
||||
|
||||
try:
|
||||
t0 = time.monotonic()
|
||||
embedding = runner(session, tensor)
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
if elapsed_ms > 2000:
|
||||
logger.warning(f"slow inference {elapsed_ms:.1f}ms shape={shape}")
|
||||
resp_header = json.dumps(
|
||||
{"shape": list(embedding.shape), "dtype": str(embedding.dtype.name)}
|
||||
).encode("utf-8")
|
||||
resp_payload = memoryview(np.ascontiguousarray(embedding).tobytes())
|
||||
socket.send_multipart([resp_header, resp_payload])
|
||||
except Exception as exc:
|
||||
logger.error(f"Inference error: {exc}")
|
||||
# Return a zero embedding so the client can degrade gracefully
|
||||
zero = np.zeros((1, 512), dtype=np.float32)
|
||||
resp_header = json.dumps(
|
||||
{"shape": list(zero.shape), "dtype": "float32"}
|
||||
).encode("utf-8")
|
||||
socket.send_multipart([resp_header, memoryview(zero.tobytes())])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ZMQ Embedding Server for Frigate")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
required=True,
|
||||
help="Path to the ONNX model file (e.g. /config/model_cache/facedet/arcface.onnx)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
default="arcface",
|
||||
choices=list(_RUNNERS.keys()),
|
||||
help="Model type key (default: arcface)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=5556,
|
||||
help="TCP port to listen on (default: 5556)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.model):
|
||||
logger.error(f"Model file not found: {args.model}")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info(f"Loading model: {args.model}")
|
||||
session = build_ort_session(args.model, model_type=args.model_type)
|
||||
warmup(session, model_type=args.model_type)
|
||||
serve(session, port=args.port, model_type=args.model_type)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue
Block a user