[MemryX] Clean shutdown of detector process (#21035)
Some checks are pending
CI / AMD64 Build (push) Waiting to run
CI / ARM Build (push) Waiting to run
CI / Jetson Jetpack 6 (push) Waiting to run
CI / AMD64 Extra Build (push) Blocked by required conditions
CI / ARM Extra Build (push) Blocked by required conditions
CI / Synaptics Build (push) Blocked by required conditions
CI / Assemble and push default build (push) Blocked by required conditions

* update code for clean exit

* ruff format

* remove unused time import

* update stop_event handling

* remove hasattr check
This commit is contained in:
Abinila Siva 2025-11-25 12:25:07 -05:00 committed by GitHub
parent 8520ade5c4
commit fe47620153
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 78 additions and 20 deletions

View File

@ -2,7 +2,6 @@ import glob
import logging
import os
import shutil
import time
import urllib.request
import zipfile
from queue import Queue
@ -55,6 +54,9 @@ class MemryXDetector(DetectionApi):
)
return
# Initialize stop_event as None, will be set later by set_stop_event()
self.stop_event = None
model_cfg = getattr(detector_config, "model", None)
# Check if model_type was explicitly set by the user
@ -363,27 +365,44 @@ class MemryXDetector(DetectionApi):
def process_input(self):
"""Input callback function: wait for frames in the input queue, preprocess, and send to MX3 (return)"""
while True:
# Check if shutdown is requested
if self.stop_event and self.stop_event.is_set():
logger.debug("[process_input] Stop event detected, returning None")
return None
try:
# Wait for a frame from the queue (blocking call)
frame = self.capture_queue.get(
block=True
) # Blocks until data is available
# Wait for a frame from the queue with timeout to check stop_event periodically
frame = self.capture_queue.get(block=True, timeout=0.5)
return frame
except Exception as e:
logger.info(f"[process_input] Error processing input: {e}")
time.sleep(0.1) # Prevent busy waiting in case of error
# Silently handle queue.Empty timeouts (expected during normal operation)
# Log any other unexpected exceptions
if "Empty" not in str(type(e).__name__):
logger.warning(f"[process_input] Unexpected error: {e}")
# Loop continues and will check stop_event at the top
def receive_output(self):
"""Retrieve processed results from MemryX output queue + a copy of the original frame"""
connection_id = (
self.capture_id_queue.get()
try:
# Get connection ID with timeout
connection_id = self.capture_id_queue.get(
block=True, timeout=1.0
) # Get the corresponding connection ID
detections = self.output_queue.get() # Get detections from MemryX
return connection_id, detections
except Exception as e:
# On timeout or stop event, return None
if self.stop_event and self.stop_event.is_set():
logger.debug("[receive_output] Stop event detected, exiting")
# Silently handle queue.Empty timeouts, they're expected during normal operation
elif "Empty" not in str(type(e).__name__):
logger.warning(f"[receive_output] Error receiving output: {e}")
return None, None
def post_process_yolonas(self, output):
predictions = output[0]
@ -831,6 +850,19 @@ class MemryXDetector(DetectionApi):
f"{self.memx_model_type} is currently not supported for memryx. See the docs for more info on supported models."
)
def set_stop_event(self, stop_event):
"""Set the stop event for graceful shutdown."""
self.stop_event = stop_event
def shutdown(self):
"""Gracefully shutdown the MemryX accelerator"""
try:
if hasattr(self, "accl") and self.accl is not None:
self.accl.shutdown()
logger.info("MemryX accelerator shutdown complete")
except Exception as e:
logger.error(f"Error during MemryX shutdown: {e}")
def detect_raw(self, tensor_input: np.ndarray):
"""Removed synchronous detect_raw() function so that we only use async"""
return 0

View File

@ -43,6 +43,7 @@ class BaseLocalDetector(ObjectDetector):
self,
detector_config: BaseDetectorConfig = None,
labels: str = None,
stop_event: MpEvent = None,
):
self.fps = EventsPerSecond()
if labels is None:
@ -60,6 +61,10 @@ class BaseLocalDetector(ObjectDetector):
self.detect_api = create_detector(detector_config)
# If the detector supports stop_event, pass it
if hasattr(self.detect_api, "set_stop_event") and stop_event:
self.detect_api.set_stop_event(stop_event)
def _transform_input(self, tensor_input: np.ndarray) -> np.ndarray:
if self.input_transform:
tensor_input = np.transpose(tensor_input, self.input_transform)
@ -240,6 +245,10 @@ class AsyncDetectorRunner(FrigateProcess):
while not self.stop_event.is_set():
connection_id, detections = self._detector.async_receive_output()
# Handle timeout case (queue.Empty) - just continue
if connection_id is None:
continue
if not self.send_times:
# guard; shouldn't happen if send/recv are balanced
continue
@ -266,21 +275,38 @@ class AsyncDetectorRunner(FrigateProcess):
self._frame_manager = SharedMemoryFrameManager()
self._publisher = ObjectDetectorPublisher()
self._detector = AsyncLocalObjectDetector(detector_config=self.detector_config)
self._detector = AsyncLocalObjectDetector(
detector_config=self.detector_config, stop_event=self.stop_event
)
for name in self.cameras:
self.create_output_shm(name)
t_detect = threading.Thread(target=self._detect_worker, daemon=True)
t_result = threading.Thread(target=self._result_worker, daemon=True)
t_detect = threading.Thread(target=self._detect_worker, daemon=False)
t_result = threading.Thread(target=self._result_worker, daemon=False)
t_detect.start()
t_result.start()
try:
while not self.stop_event.is_set():
time.sleep(0.5)
logger.info(
"Stop event detected, waiting for detector threads to finish..."
)
# Wait for threads to finish processing
t_detect.join(timeout=5)
t_result.join(timeout=5)
# Shutdown the AsyncDetector
self._detector.detect_api.shutdown()
self._publisher.stop()
logger.info("Exited async detection process...")
except Exception as e:
logger.error(f"Error during async detector shutdown: {e}")
finally:
logger.info("Exited Async detection process...")
class ObjectDetectProcess:
@ -308,7 +334,7 @@ class ObjectDetectProcess:
# if the process has already exited on its own, just return
if self.detect_process and self.detect_process.exitcode:
return
self.detect_process.terminate()
logging.info("Waiting for detection process to exit gracefully...")
self.detect_process.join(timeout=30)
if self.detect_process.exitcode is None: