Compare commits

...

3 Commits

Author SHA1 Message Date
Abinila Siva
b3f3d06f96 remove unused time import 2025-11-24 12:20:40 -05:00
Abinila Siva
1d93f6853e ruff format 2025-11-24 11:51:17 -05:00
Abinila Siva
2701c03bd5 update code for clean exit 2025-11-24 11:33:50 -05:00
2 changed files with 77 additions and 21 deletions

View File

@ -2,7 +2,6 @@ import glob
import logging import logging
import os import os
import shutil import shutil
import time
import urllib.request import urllib.request
import zipfile import zipfile
from queue import Queue from queue import Queue
@ -44,7 +43,7 @@ class MemryXDetector(DetectionApi):
ModelTypeEnum.yolox, ModelTypeEnum.yolox,
] ]
def __init__(self, detector_config): def __init__(self, detector_config, stop_event=None):
"""Initialize MemryX detector with the provided configuration.""" """Initialize MemryX detector with the provided configuration."""
try: try:
# Import MemryX SDK # Import MemryX SDK
@ -55,6 +54,9 @@ class MemryXDetector(DetectionApi):
) )
return return
# Get stop_event from detector_config
self.stop_event = getattr(detector_config, "_stop_event", stop_event)
model_cfg = getattr(detector_config, "model", None) model_cfg = getattr(detector_config, "model", None)
# Check if model_type was explicitly set by the user # Check if model_type was explicitly set by the user
@ -363,26 +365,43 @@ class MemryXDetector(DetectionApi):
def process_input(self): def process_input(self):
"""Input callback function: wait for frames in the input queue, preprocess, and send to MX3 (return)""" """Input callback function: wait for frames in the input queue, preprocess, and send to MX3 (return)"""
while True: while True:
# Check if shutdown is requested
if self.stop_event.is_set():
logger.debug("[process_input] Stop event detected, returning None")
return None
try: try:
# Wait for a frame from the queue (blocking call) # Wait for a frame from the queue with timeout to check stop_event periodically
frame = self.capture_queue.get( frame = self.capture_queue.get(block=True, timeout=0.5)
block=True
) # Blocks until data is available
return frame return frame
except Exception as e: except Exception as e:
logger.info(f"[process_input] Error processing input: {e}") # Silently handle queue.Empty timeouts (expected during normal operation)
time.sleep(0.1) # Prevent busy waiting in case of error # Log any other unexpected exceptions
if "Empty" not in str(type(e).__name__):
logger.warning(f"[process_input] Unexpected error: {e}")
# Loop continues and will check stop_event at the top
def receive_output(self): def receive_output(self):
"""Retrieve processed results from MemryX output queue + a copy of the original frame""" """Retrieve processed results from MemryX output queue + a copy of the original frame"""
connection_id = ( try:
self.capture_id_queue.get() # Get connection ID with timeout
) # Get the corresponding connection ID connection_id = self.capture_id_queue.get(
detections = self.output_queue.get() # Get detections from MemryX block=True, timeout=1.0
) # Get the corresponding connection ID
detections = self.output_queue.get() # Get detections from MemryX
return connection_id, detections return connection_id, detections
except Exception as e:
# On timeout or stop event, return None
if self.stop_event.is_set():
logger.debug("[receive_output] Stop event detected, exiting")
# Silently handle queue.Empty timeouts, they're expected during normal operation
elif "Empty" not in str(type(e).__name__):
logger.warning(f"[receive_output] Error receiving output: {e}")
return None, None
def post_process_yolonas(self, output): def post_process_yolonas(self, output):
predictions = output[0] predictions = output[0]
@ -831,6 +850,15 @@ class MemryXDetector(DetectionApi):
f"{self.memx_model_type} is currently not supported for memryx. See the docs for more info on supported models." f"{self.memx_model_type} is currently not supported for memryx. See the docs for more info on supported models."
) )
def shutdown(self):
"""Gracefully shutdown the MemryX accelerator"""
try:
if hasattr(self, "accl") and self.accl is not None:
self.accl.shutdown()
logger.info("MemryX accelerator shutdown complete")
except Exception as e:
logger.error(f"Error during MemryX shutdown: {e}")
def detect_raw(self, tensor_input: np.ndarray): def detect_raw(self, tensor_input: np.ndarray):
"""Removed synchronous detect_raw() function so that we only use async""" """Removed synchronous detect_raw() function so that we only use async"""
return 0 return 0

View File

@ -43,6 +43,7 @@ class BaseLocalDetector(ObjectDetector):
self, self,
detector_config: BaseDetectorConfig = None, detector_config: BaseDetectorConfig = None,
labels: str = None, labels: str = None,
stop_event: MpEvent = None,
): ):
self.fps = EventsPerSecond() self.fps = EventsPerSecond()
if labels is None: if labels is None:
@ -58,6 +59,10 @@ class BaseLocalDetector(ObjectDetector):
self.input_transform = None self.input_transform = None
self.dtype = InputDTypeEnum.int self.dtype = InputDTypeEnum.int
# Attach stop_event to detector_config so detectors can access it
if detector_config and stop_event:
detector_config._stop_event = stop_event
self.detect_api = create_detector(detector_config) self.detect_api = create_detector(detector_config)
def _transform_input(self, tensor_input: np.ndarray) -> np.ndarray: def _transform_input(self, tensor_input: np.ndarray) -> np.ndarray:
@ -240,6 +245,10 @@ class AsyncDetectorRunner(FrigateProcess):
while not self.stop_event.is_set(): while not self.stop_event.is_set():
connection_id, detections = self._detector.async_receive_output() connection_id, detections = self._detector.async_receive_output()
# Handle timeout case (queue.Empty) - just continue
if connection_id is None:
continue
if not self.send_times: if not self.send_times:
# guard; shouldn't happen if send/recv are balanced # guard; shouldn't happen if send/recv are balanced
continue continue
@ -266,21 +275,40 @@ class AsyncDetectorRunner(FrigateProcess):
self._frame_manager = SharedMemoryFrameManager() self._frame_manager = SharedMemoryFrameManager()
self._publisher = ObjectDetectorPublisher() self._publisher = ObjectDetectorPublisher()
self._detector = AsyncLocalObjectDetector(detector_config=self.detector_config) self._detector = AsyncLocalObjectDetector(
detector_config=self.detector_config, stop_event=self.stop_event
)
for name in self.cameras: for name in self.cameras:
self.create_output_shm(name) self.create_output_shm(name)
t_detect = threading.Thread(target=self._detect_worker, daemon=True) t_detect = threading.Thread(target=self._detect_worker, daemon=False)
t_result = threading.Thread(target=self._result_worker, daemon=True) t_result = threading.Thread(target=self._result_worker, daemon=False)
t_detect.start() t_detect.start()
t_result.start() t_result.start()
while not self.stop_event.is_set(): try:
time.sleep(0.5) while not self.stop_event.is_set():
time.sleep(0.5)
self._publisher.stop() logger.info(
logger.info("Exited async detection process...") "Stop event detected, waiting for detector threads to finish..."
)
# Wait for threads to finish processing
t_detect.join(timeout=5)
t_result.join(timeout=5)
# Explicitly shutdown MemryX accelerator
if hasattr(self._detector.detect_api, "shutdown"):
logger.info("Calling MemryX shutdown method...")
self._detector.detect_api.shutdown()
self._publisher.stop()
except Exception as e:
logger.error(f"Error during async detector shutdown: {e}")
finally:
logger.info("Exited Async detection process...")
class ObjectDetectProcess: class ObjectDetectProcess:
@ -308,7 +336,7 @@ class ObjectDetectProcess:
# if the process has already exited on its own, just return # if the process has already exited on its own, just return
if self.detect_process and self.detect_process.exitcode: if self.detect_process and self.detect_process.exitcode:
return return
self.detect_process.terminate()
logging.info("Waiting for detection process to exit gracefully...") logging.info("Waiting for detection process to exit gracefully...")
self.detect_process.join(timeout=30) self.detect_process.join(timeout=30)
if self.detect_process.exitcode is None: if self.detect_process.exitcode is None: