Add ability to transfer model via ZMQ

This commit is contained in:
Nicolas Mowen 2025-09-21 19:47:27 -06:00
parent e4d5f1f94e
commit e116f3d579

View File

@ -1,6 +1,7 @@
import json import json
import logging import logging
from typing import Any, List import os
from typing import Any, List, Optional
import numpy as np import numpy as np
import zmq import zmq
@ -46,6 +47,11 @@ class ZmqIpcDetector(DetectionApi):
b) Single frame tensor_bytes of length 20*6*4 bytes (float32). b) Single frame tensor_bytes of length 20*6*4 bytes (float32).
On any error or timeout, this detector returns a zero array of shape (20, 6). On any error or timeout, this detector returns a zero array of shape (20, 6).
Model Management:
- On initialization, sends model request to check if model is available
- If model not available, sends model data via ZMQ
- Only starts inference after model is ready
""" """
type_key = DETECTOR_KEY type_key = DETECTOR_KEY
@ -60,6 +66,13 @@ class ZmqIpcDetector(DetectionApi):
self._socket = None self._socket = None
self._create_socket() self._create_socket()
# Model management
self._model_ready = False
self._model_name = self._get_model_name()
# Initialize model if needed
self._initialize_model()
# Preallocate zero result for error paths # Preallocate zero result for error paths
self._zero_result = np.zeros((20, 6), np.float32) self._zero_result = np.zeros((20, 6), np.float32)
@ -78,6 +91,167 @@ class ZmqIpcDetector(DetectionApi):
logger.debug(f"ZMQ detector connecting to {self._endpoint}") logger.debug(f"ZMQ detector connecting to {self._endpoint}")
self._socket.connect(self._endpoint) self._socket.connect(self._endpoint)
def _get_model_name(self) -> str:
"""Get the model filename from the detector config."""
model_path = self.detector_config.model.path
return os.path.basename(model_path)
def _initialize_model(self) -> None:
"""Initialize the model by checking availability and transferring if needed."""
try:
logger.info(f"Initializing model: {self._model_name}")
# Check if model is available and transfer if needed
if self._check_and_transfer_model():
logger.info(f"Model {self._model_name} is ready")
self._model_ready = True
else:
logger.error(f"Failed to initialize model {self._model_name}")
except Exception as e:
logger.error(f"Failed to initialize model: {e}")
def _check_and_transfer_model(self) -> bool:
"""Check if model is available and transfer if needed in one atomic operation."""
try:
# Send model availability request
header = {"model_request": True, "model_name": self._model_name}
header_bytes = json.dumps(header).encode("utf-8")
self._socket.send_multipart([header_bytes])
# Temporarily increase timeout for model operations
original_timeout = self._socket.getsockopt(zmq.RCVTIMEO)
self._socket.setsockopt(zmq.RCVTIMEO, 30000)
try:
response_frames = self._socket.recv_multipart()
finally:
self._socket.setsockopt(zmq.RCVTIMEO, original_timeout)
if len(response_frames) == 1:
try:
response = json.loads(response_frames[0].decode("utf-8"))
model_available = response.get("model_available", False)
model_loaded = response.get("model_loaded", False)
if model_available and model_loaded:
return True
elif model_available and not model_loaded:
logger.error("Model exists but failed to load")
return False
else:
return self._send_model_data()
except json.JSONDecodeError:
logger.warning(
"Received non-JSON response for model availability check"
)
return False
else:
logger.warning(
"Received unexpected response format for model availability check"
)
return False
except Exception as e:
logger.error(f"Failed to check and transfer model: {e}")
return False
def _check_model_availability(self) -> bool:
"""Check if the model is available on the detector."""
try:
# Send model availability request
header = {"model_request": True, "model_name": self._model_name}
header_bytes = json.dumps(header).encode("utf-8")
self._socket.send_multipart([header_bytes])
# Receive response
response_frames = self._socket.recv_multipart()
# Check if this is a JSON response (model management)
if len(response_frames) == 1:
try:
response = json.loads(response_frames[0].decode("utf-8"))
model_available = response.get("model_available", False)
model_loaded = response.get("model_loaded", False)
logger.debug(
f"Model availability check: available={model_available}, loaded={model_loaded}"
)
return model_available and model_loaded
except json.JSONDecodeError:
logger.warning(
"Received non-JSON response for model availability check"
)
return False
else:
logger.warning(
"Received unexpected response format for model availability check"
)
return False
except Exception as e:
logger.error(f"Failed to check model availability: {e}")
return False
def _send_model_data(self) -> bool:
"""Send model data to the detector."""
try:
model_path = self.detector_config.model.path
if not os.path.exists(model_path):
logger.error(f"Model file not found: {model_path}")
return False
logger.info(f"Transferring model to detector: {self._model_name}")
with open(model_path, "rb") as f:
model_data = f.read()
header = {"model_data": True, "model_name": self._model_name}
header_bytes = json.dumps(header).encode("utf-8")
self._socket.send_multipart([header_bytes, model_data])
# Temporarily increase timeout for model loading (can take several seconds)
original_timeout = self._socket.getsockopt(zmq.RCVTIMEO)
self._socket.setsockopt(zmq.RCVTIMEO, 30000)
try:
# Receive response
response_frames = self._socket.recv_multipart()
finally:
# Restore original timeout
self._socket.setsockopt(zmq.RCVTIMEO, original_timeout)
# Check if this is a JSON response (model management)
if len(response_frames) == 1:
try:
response = json.loads(response_frames[0].decode("utf-8"))
model_saved = response.get("model_saved", False)
model_loaded = response.get("model_loaded", False)
if model_saved and model_loaded:
logger.info(
f"Model {self._model_name} transferred and loaded successfully"
)
else:
logger.error(
f"Model transfer failed: saved={model_saved}, loaded={model_loaded}"
)
return model_saved and model_loaded
except json.JSONDecodeError:
logger.warning("Received non-JSON response for model data transfer")
return False
else:
logger.warning(
"Received unexpected response format for model data transfer"
)
return False
except Exception as e:
logger.error(f"Failed to send model data: {e}")
return False
def _build_header(self, tensor_input: np.ndarray) -> bytes: def _build_header(self, tensor_input: np.ndarray) -> bytes:
header: dict[str, Any] = { header: dict[str, Any] = {
"shape": list(tensor_input.shape), "shape": list(tensor_input.shape),
@ -111,6 +285,10 @@ class ZmqIpcDetector(DetectionApi):
return self._zero_result return self._zero_result
def detect_raw(self, tensor_input: np.ndarray) -> np.ndarray: def detect_raw(self, tensor_input: np.ndarray) -> np.ndarray:
if not self._model_ready:
logger.warning("Model not ready, returning zero detections")
return self._zero_result
try: try:
header_bytes = self._build_header(tensor_input) header_bytes = self._build_header(tensor_input)
payload_bytes = memoryview(tensor_input.tobytes(order="C")) payload_bytes = memoryview(tensor_input.tobytes(order="C"))
@ -123,13 +301,13 @@ class ZmqIpcDetector(DetectionApi):
detections = self._decode_response(reply_frames) detections = self._decode_response(reply_frames)
# Ensure output shape and dtype are exactly as expected # Ensure output shape and dtype are exactly as expected
return detections return detections
except zmq.Again: except zmq.Again:
# Timeout # Timeout
logger.debug("ZMQ detector request timed out; resetting socket") logger.debug("ZMQ detector request timed out; resetting socket")
try: try:
self._create_socket() self._create_socket()
self._initialize_model()
except Exception: except Exception:
pass pass
return self._zero_result return self._zero_result
@ -137,6 +315,7 @@ class ZmqIpcDetector(DetectionApi):
logger.error(f"ZMQ detector ZMQError: {exc}; resetting socket") logger.error(f"ZMQ detector ZMQError: {exc}; resetting socket")
try: try:
self._create_socket() self._create_socket()
self._initialize_model()
except Exception: except Exception:
pass pass
return self._zero_result return self._zero_result