update code for clean exit

This commit is contained in:
Abinila Siva 2025-11-24 11:33:50 -05:00
parent 2a9c028f55
commit 2701c03bd5
2 changed files with 83 additions and 19 deletions

View File

@ -44,7 +44,7 @@ class MemryXDetector(DetectionApi):
ModelTypeEnum.yolox, ModelTypeEnum.yolox,
] ]
def __init__(self, detector_config): def __init__(self, detector_config, stop_event=None):
"""Initialize MemryX detector with the provided configuration.""" """Initialize MemryX detector with the provided configuration."""
try: try:
# Import MemryX SDK # Import MemryX SDK
@ -55,6 +55,9 @@ class MemryXDetector(DetectionApi):
) )
return return
# Get stop_event from detector_config
self.stop_event = getattr(detector_config, "_stop_event", stop_event)
model_cfg = getattr(detector_config, "model", None) model_cfg = getattr(detector_config, "model", None)
# Check if model_type was explicitly set by the user # Check if model_type was explicitly set by the user
@ -363,26 +366,50 @@ class MemryXDetector(DetectionApi):
def process_input(self): def process_input(self):
"""Input callback function: wait for frames in the input queue, preprocess, and send to MX3 (return)""" """Input callback function: wait for frames in the input queue, preprocess, and send to MX3 (return)"""
while True: while True:
# Check if shutdown is requested
if self.stop_event.is_set():
logger.debug("[process_input] Stop event detected, returning None")
return None
try: try:
# Wait for a frame from the queue (blocking call) # Wait for a frame from the queue with timeout to check stop_event periodically
frame = self.capture_queue.get( frame = self.capture_queue.get(
block=True block=True,
) # Blocks until data is available timeout=0.5
)
return frame return frame
except Exception as e: except Exception as e:
logger.info(f"[process_input] Error processing input: {e}") # Silently handle queue.Empty timeouts (expected during normal operation)
time.sleep(0.1) # Prevent busy waiting in case of error # Log any other unexpected exceptions
if "Empty" not in str(type(e).__name__):
logger.warning(f"[process_input] Unexpected error: {e}")
# Loop continues and will check stop_event at the top
def receive_output(self): def receive_output(self):
"""Retrieve processed results from MemryX output queue + a copy of the original frame""" """Retrieve processed results from MemryX output queue + a copy of the original frame"""
connection_id = ( try:
self.capture_id_queue.get() # Get connection ID with timeout
) # Get the corresponding connection ID connection_id = (
detections = self.output_queue.get() # Get detections from MemryX self.capture_id_queue.get(
block=True,
timeout=1.0
)
) # Get the corresponding connection ID
detections = self.output_queue.get() # Get detections from MemryX
return connection_id, detections
except Exception as e:
# On timeout or stop event, return None
if self.stop_event.is_set():
logger.debug("[receive_output] Stop event detected, exiting")
# Silently handle queue.Empty timeouts, they're expected during normal operation
elif "Empty" not in str(type(e).__name__):
logger.warning(f"[receive_output] Error receiving output: {e}")
return None, None
return connection_id, detections
def post_process_yolonas(self, output): def post_process_yolonas(self, output):
predictions = output[0] predictions = output[0]
@ -831,6 +858,15 @@ class MemryXDetector(DetectionApi):
f"{self.memx_model_type} is currently not supported for memryx. See the docs for more info on supported models." f"{self.memx_model_type} is currently not supported for memryx. See the docs for more info on supported models."
) )
def shutdown(self):
"""Gracefully shutdown the MemryX accelerator"""
try:
if hasattr(self, "accl") and self.accl is not None:
self.accl.shutdown()
logger.info("MemryX accelerator shutdown complete")
except Exception as e:
logger.error(f"Error during MemryX shutdown: {e}")
def detect_raw(self, tensor_input: np.ndarray): def detect_raw(self, tensor_input: np.ndarray):
"""Removed synchronous detect_raw() function so that we only use async""" """Removed synchronous detect_raw() function so that we only use async"""
return 0 return 0

View File

@ -43,6 +43,7 @@ class BaseLocalDetector(ObjectDetector):
self, self,
detector_config: BaseDetectorConfig = None, detector_config: BaseDetectorConfig = None,
labels: str = None, labels: str = None,
stop_event: MpEvent = None,
): ):
self.fps = EventsPerSecond() self.fps = EventsPerSecond()
if labels is None: if labels is None:
@ -58,6 +59,10 @@ class BaseLocalDetector(ObjectDetector):
self.input_transform = None self.input_transform = None
self.dtype = InputDTypeEnum.int self.dtype = InputDTypeEnum.int
# Attach stop_event to detector_config so detectors can access it
if detector_config and stop_event:
detector_config._stop_event = stop_event
self.detect_api = create_detector(detector_config) self.detect_api = create_detector(detector_config)
def _transform_input(self, tensor_input: np.ndarray) -> np.ndarray: def _transform_input(self, tensor_input: np.ndarray) -> np.ndarray:
@ -240,6 +245,10 @@ class AsyncDetectorRunner(FrigateProcess):
while not self.stop_event.is_set(): while not self.stop_event.is_set():
connection_id, detections = self._detector.async_receive_output() connection_id, detections = self._detector.async_receive_output()
# Handle timeout case (queue.Empty) - just continue
if connection_id is None:
continue
if not self.send_times: if not self.send_times:
# guard; shouldn't happen if send/recv are balanced # guard; shouldn't happen if send/recv are balanced
continue continue
@ -266,21 +275,40 @@ class AsyncDetectorRunner(FrigateProcess):
self._frame_manager = SharedMemoryFrameManager() self._frame_manager = SharedMemoryFrameManager()
self._publisher = ObjectDetectorPublisher() self._publisher = ObjectDetectorPublisher()
self._detector = AsyncLocalObjectDetector(detector_config=self.detector_config) self._detector = AsyncLocalObjectDetector(
detector_config=self.detector_config, stop_event=self.stop_event
)
for name in self.cameras: for name in self.cameras:
self.create_output_shm(name) self.create_output_shm(name)
t_detect = threading.Thread(target=self._detect_worker, daemon=True) t_detect = threading.Thread(target=self._detect_worker, daemon=False)
t_result = threading.Thread(target=self._result_worker, daemon=True) t_result = threading.Thread(target=self._result_worker, daemon=False)
t_detect.start() t_detect.start()
t_result.start() t_result.start()
while not self.stop_event.is_set(): try:
time.sleep(0.5) while not self.stop_event.is_set():
time.sleep(0.5)
self._publisher.stop() logger.info(
logger.info("Exited async detection process...") "Stop event detected, waiting for detector threads to finish..."
)
# Wait for threads to finish processing
t_detect.join(timeout=5)
t_result.join(timeout=5)
# Explicitly shutdown MemryX accelerator
if hasattr(self._detector.detect_api, "shutdown"):
logger.info("Calling MemryX shutdown method...")
self._detector.detect_api.shutdown()
self._publisher.stop()
except Exception as e:
logger.error(f"Error during async detector shutdown: {e}")
finally:
logger.info("Exited Async detection process...")
class ObjectDetectProcess: class ObjectDetectProcess:
@ -308,7 +336,7 @@ class ObjectDetectProcess:
# if the process has already exited on its own, just return # if the process has already exited on its own, just return
if self.detect_process and self.detect_process.exitcode: if self.detect_process and self.detect_process.exitcode:
return return
self.detect_process.terminate()
logging.info("Waiting for detection process to exit gracefully...") logging.info("Waiting for detection process to exit gracefully...")
self.detect_process.join(timeout=30) self.detect_process.join(timeout=30)
if self.detect_process.exitcode is None: if self.detect_process.exitcode is None: