[MemryX] Clean shutdown of detector process (#21035)
Some checks are pending
CI / AMD64 Build (push) Waiting to run
CI / ARM Build (push) Waiting to run
CI / Jetson Jetpack 6 (push) Waiting to run
CI / AMD64 Extra Build (push) Blocked by required conditions
CI / ARM Extra Build (push) Blocked by required conditions
CI / Synaptics Build (push) Blocked by required conditions
CI / Assemble and push default build (push) Blocked by required conditions

* update code for clean exit

* ruff format

* remove unused time import

* update stop_event handling

* remove hasattr check
This commit is contained in:
Abinila Siva 2025-11-25 12:25:07 -05:00 committed by GitHub
parent 8520ade5c4
commit fe47620153
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 78 additions and 20 deletions

View File

@ -2,7 +2,6 @@ import glob
import logging import logging
import os import os
import shutil import shutil
import time
import urllib.request import urllib.request
import zipfile import zipfile
from queue import Queue from queue import Queue
@ -55,6 +54,9 @@ class MemryXDetector(DetectionApi):
) )
return return
# Initialize stop_event as None, will be set later by set_stop_event()
self.stop_event = None
model_cfg = getattr(detector_config, "model", None) model_cfg = getattr(detector_config, "model", None)
# Check if model_type was explicitly set by the user # Check if model_type was explicitly set by the user
@ -363,27 +365,44 @@ class MemryXDetector(DetectionApi):
def process_input(self): def process_input(self):
"""Input callback function: wait for frames in the input queue, preprocess, and send to MX3 (return)""" """Input callback function: wait for frames in the input queue, preprocess, and send to MX3 (return)"""
while True: while True:
# Check if shutdown is requested
if self.stop_event and self.stop_event.is_set():
logger.debug("[process_input] Stop event detected, returning None")
return None
try: try:
# Wait for a frame from the queue (blocking call) # Wait for a frame from the queue with timeout to check stop_event periodically
frame = self.capture_queue.get( frame = self.capture_queue.get(block=True, timeout=0.5)
block=True
) # Blocks until data is available
return frame return frame
except Exception as e: except Exception as e:
logger.info(f"[process_input] Error processing input: {e}") # Silently handle queue.Empty timeouts (expected during normal operation)
time.sleep(0.1) # Prevent busy waiting in case of error # Log any other unexpected exceptions
if "Empty" not in str(type(e).__name__):
logger.warning(f"[process_input] Unexpected error: {e}")
# Loop continues and will check stop_event at the top
def receive_output(self): def receive_output(self):
"""Retrieve processed results from MemryX output queue + a copy of the original frame""" """Retrieve processed results from MemryX output queue + a copy of the original frame"""
connection_id = ( try:
self.capture_id_queue.get() # Get connection ID with timeout
connection_id = self.capture_id_queue.get(
block=True, timeout=1.0
) # Get the corresponding connection ID ) # Get the corresponding connection ID
detections = self.output_queue.get() # Get detections from MemryX detections = self.output_queue.get() # Get detections from MemryX
return connection_id, detections return connection_id, detections
except Exception as e:
# On timeout or stop event, return None
if self.stop_event and self.stop_event.is_set():
logger.debug("[receive_output] Stop event detected, exiting")
# Silently handle queue.Empty timeouts, they're expected during normal operation
elif "Empty" not in str(type(e).__name__):
logger.warning(f"[receive_output] Error receiving output: {e}")
return None, None
def post_process_yolonas(self, output): def post_process_yolonas(self, output):
predictions = output[0] predictions = output[0]
@ -831,6 +850,19 @@ class MemryXDetector(DetectionApi):
f"{self.memx_model_type} is currently not supported for memryx. See the docs for more info on supported models." f"{self.memx_model_type} is currently not supported for memryx. See the docs for more info on supported models."
) )
def set_stop_event(self, stop_event):
    """Attach the shutdown signal used for graceful termination.

    The object is stored on ``self.stop_event``. Other methods of this
    detector only ever call ``is_set()`` on it, so any event-like object
    exposing that method is sufficient.

    Args:
        stop_event: Event-like object that becomes set when shutdown is
            requested.
    """
    self.stop_event = stop_event
def shutdown(self):
    """Gracefully shut down the MemryX accelerator.

    Calls ``shutdown()`` on ``self.accl`` when that attribute exists and is
    not ``None``; otherwise this is a no-op. An exception raised by the
    accelerator's ``shutdown()`` is caught and logged with its traceback
    rather than propagated, so the caller's teardown can continue.

    Returns:
        None
    """
    # A single lookup covers both "attribute missing" and "attribute is None".
    accl = getattr(self, "accl", None)
    if accl is None:
        return

    try:
        accl.shutdown()
    except Exception as e:
        # Best-effort cleanup: report the failure (including traceback) but
        # do not let it abort the rest of the shutdown sequence.
        logger.exception("Error during MemryX shutdown: %s", e)
    else:
        logger.info("MemryX accelerator shutdown complete")
def detect_raw(self, tensor_input: np.ndarray): def detect_raw(self, tensor_input: np.ndarray):
"""Removed synchronous detect_raw() function so that we only use async""" """Removed synchronous detect_raw() function so that we only use async"""
return 0 return 0

View File

@ -43,6 +43,7 @@ class BaseLocalDetector(ObjectDetector):
self, self,
detector_config: BaseDetectorConfig = None, detector_config: BaseDetectorConfig = None,
labels: str = None, labels: str = None,
stop_event: MpEvent = None,
): ):
self.fps = EventsPerSecond() self.fps = EventsPerSecond()
if labels is None: if labels is None:
@ -60,6 +61,10 @@ class BaseLocalDetector(ObjectDetector):
self.detect_api = create_detector(detector_config) self.detect_api = create_detector(detector_config)
# If the detector supports stop_event, pass it
if hasattr(self.detect_api, "set_stop_event") and stop_event:
self.detect_api.set_stop_event(stop_event)
def _transform_input(self, tensor_input: np.ndarray) -> np.ndarray: def _transform_input(self, tensor_input: np.ndarray) -> np.ndarray:
if self.input_transform: if self.input_transform:
tensor_input = np.transpose(tensor_input, self.input_transform) tensor_input = np.transpose(tensor_input, self.input_transform)
@ -240,6 +245,10 @@ class AsyncDetectorRunner(FrigateProcess):
while not self.stop_event.is_set(): while not self.stop_event.is_set():
connection_id, detections = self._detector.async_receive_output() connection_id, detections = self._detector.async_receive_output()
# Handle timeout case (queue.Empty) - just continue
if connection_id is None:
continue
if not self.send_times: if not self.send_times:
# guard; shouldn't happen if send/recv are balanced # guard; shouldn't happen if send/recv are balanced
continue continue
@ -266,21 +275,38 @@ class AsyncDetectorRunner(FrigateProcess):
self._frame_manager = SharedMemoryFrameManager() self._frame_manager = SharedMemoryFrameManager()
self._publisher = ObjectDetectorPublisher() self._publisher = ObjectDetectorPublisher()
self._detector = AsyncLocalObjectDetector(detector_config=self.detector_config) self._detector = AsyncLocalObjectDetector(
detector_config=self.detector_config, stop_event=self.stop_event
)
for name in self.cameras: for name in self.cameras:
self.create_output_shm(name) self.create_output_shm(name)
t_detect = threading.Thread(target=self._detect_worker, daemon=True) t_detect = threading.Thread(target=self._detect_worker, daemon=False)
t_result = threading.Thread(target=self._result_worker, daemon=True) t_result = threading.Thread(target=self._result_worker, daemon=False)
t_detect.start() t_detect.start()
t_result.start() t_result.start()
try:
while not self.stop_event.is_set(): while not self.stop_event.is_set():
time.sleep(0.5) time.sleep(0.5)
logger.info(
"Stop event detected, waiting for detector threads to finish..."
)
# Wait for threads to finish processing
t_detect.join(timeout=5)
t_result.join(timeout=5)
# Shutdown the AsyncDetector
self._detector.detect_api.shutdown()
self._publisher.stop() self._publisher.stop()
logger.info("Exited async detection process...") except Exception as e:
logger.error(f"Error during async detector shutdown: {e}")
finally:
logger.info("Exited Async detection process...")
class ObjectDetectProcess: class ObjectDetectProcess:
@ -308,7 +334,7 @@ class ObjectDetectProcess:
# if the process has already exited on its own, just return # if the process has already exited on its own, just return
if self.detect_process and self.detect_process.exitcode: if self.detect_process and self.detect_process.exitcode:
return return
self.detect_process.terminate()
logging.info("Waiting for detection process to exit gracefully...") logging.info("Waiting for detection process to exit gracefully...")
self.detect_process.join(timeout=30) self.detect_process.join(timeout=30)
if self.detect_process.exitcode is None: if self.detect_process.exitcode is None: