From 6645b158a327365b1f855a0eb9ab416f912a3080 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Tue, 15 Apr 2025 06:49:16 -0600 Subject: [PATCH] Add input store type --- frigate/detectors/plugins/hailo8l.py | 61 +++++------------------ frigate/object_detection/base.py | 11 +---- frigate/object_detection/util.py | 73 ++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 58 deletions(-) diff --git a/frigate/detectors/plugins/hailo8l.py b/frigate/detectors/plugins/hailo8l.py index ad86ca03d1..ffadf0fdbc 100755 --- a/frigate/detectors/plugins/hailo8l.py +++ b/frigate/detectors/plugins/hailo8l.py @@ -1,6 +1,5 @@ import logging import os -import queue import subprocess import threading import urllib.request @@ -28,37 +27,11 @@ from frigate.detectors.detection_api import DetectionApi from frigate.detectors.detector_config import ( BaseDetectorConfig, ) +from frigate.object_detection.util import RequestStore, ResponseStore logger = logging.getLogger(__name__) -# ----------------- ResponseStore Class ----------------- # -class ResponseStore: - """ - A thread-safe hash-based response store that maps request IDs - to their results. Threads can wait on the condition variable until - their request's result appears. - """ - - def __init__(self): - self.responses = {} # Maps request_id -> (original_input, infer_results) - self.lock = threading.Lock() - self.cond = threading.Condition(self.lock) - - def put(self, request_id, response): - with self.cond: - self.responses[request_id] = response - self.cond.notify_all() - - def get(self, request_id, timeout=None): - with self.cond: - if not self.cond.wait_for( - lambda: request_id in self.responses, timeout=timeout - ): - raise TimeoutError(f"Timeout waiting for response {request_id}") - return self.responses.pop(request_id) - - # ----------------- Utility Functions ----------------- # @@ -122,14 +95,14 @@ class HailoAsyncInference: def __init__( self, hef_path: str, - input_queue: queue.Queue, + input_store: RequestStore, output_store: ResponseStore, batch_size: int = 1, input_type: Optional[str] = None, output_type: Optional[Dict[str, str]] = None, send_original_frame: bool = False, ) -> None: - self.input_queue = input_queue + self.input_store = input_store self.output_store = output_store params = VDevice.create_params() @@ -204,9 +177,11 @@ class HailoAsyncInference: def run(self) -> None: with self.infer_model.configure() as configured_infer_model: while True: - batch_data = self.input_queue.get() + batch_data = self.input_store.get() + if batch_data is None: break + request_id, frame_data = batch_data preprocessed_batch = [frame_data] request_ids = [request_id] @@ -274,16 +249,14 @@ class HailoDetector(DetectionApi): self.working_model_path = self.check_and_prepare() self.batch_size = 1 - self.input_queue = queue.Queue() + self.input_store = RequestStore() self.response_store = ResponseStore() - self.request_counter = 0 - self.request_counter_lock = threading.Lock() try: logger.debug(f"[INIT] Loading HEF model from {self.working_model_path}") self.inference_engine = HailoAsyncInference( self.working_model_path, - self.input_queue, + self.input_store, self.response_store, self.batch_size, ) @@ -364,26 +337,16 @@ class HailoDetector(DetectionApi): raise FileNotFoundError(f"Model file not found at: {self.model_path}") return cached_model_path - def _get_request_id(self) -> int: - with self.request_counter_lock: - request_id = self.request_counter - self.request_counter += 1 - if self.request_counter > 1000000: - self.request_counter = 0 - return request_id - def detect_raw(self, tensor_input): - request_id = self._get_request_id() - tensor_input = self.preprocess(tensor_input) + if isinstance(tensor_input, np.ndarray) and len(tensor_input.shape) == 3: tensor_input = np.expand_dims(tensor_input, axis=0) - self.input_queue.put((request_id, tensor_input)) + request_id = self.input_store.put(tensor_input) + try: - original_input, infer_results = self.response_store.get( - request_id, timeout=10.0 - ) + _, infer_results = self.response_store.get(request_id, timeout=10.0) except TimeoutError: logger.error( f"Timeout waiting for inference results for request {request_id}" diff --git a/frigate/object_detection/base.py b/frigate/object_detection/base.py index 8e88ae578a..dfc39ac2de 100644 --- a/frigate/object_detection/base.py +++ b/frigate/object_detection/base.py @@ -15,12 +15,13 @@ from frigate.detectors import create_detector from frigate.detectors.detector_config import ( BaseDetectorConfig, InputDTypeEnum, - InputTensorEnum, ) from frigate.util.builtin import EventsPerSecond, load_labels from frigate.util.image import SharedMemoryFrameManager, UntrackedSharedMemory from frigate.util.services import listen +from .util import tensor_transform + logger = logging.getLogger(__name__) @@ -30,14 +31,6 @@ class ObjectDetector(ABC): pass -def tensor_transform(desired_shape: InputTensorEnum): - # Currently this function only supports BHWC permutations - if desired_shape == InputTensorEnum.nhwc: - return None - elif desired_shape == InputTensorEnum.nchw: - return (0, 3, 1, 2) - - class LocalObjectDetector(ObjectDetector): def __init__( self, diff --git a/frigate/object_detection/util.py b/frigate/object_detection/util.py index e69de29bb2..ffdb1eca45 100644 --- a/frigate/object_detection/util.py +++ b/frigate/object_detection/util.py @@ -0,0 +1,73 @@ +"""Object detection utilities.""" + +import threading +import queue + +from numpy import ndarray + +from frigate.detectors.detector_config import InputTensorEnum + + +class RequestStore: + """ + A thread-safe hash-based response store that handles creating requests. + """ + + def __init__(self): + self.request_counter = 0 + self.request_counter_lock = threading.Lock() + self.input_queue = queue.Queue() + + def __get_request_id(self) -> int: + with self.request_counter_lock: + request_id = self.request_counter + self.request_counter += 1 + if self.request_counter > 1000000: + self.request_counter = 0 + return request_id + + def put(self, tensor_input: ndarray) -> int: + request_id = self.__get_request_id() + self.input_queue.get((request_id, tensor_input)) + return request_id + + def get(self) -> tuple[int, ndarray] | None: + try: + return self.input_queue.get_nowait() + except Exception: + return None + + +class ResponseStore: + """ + A thread-safe hash-based response store that maps request IDs + to their results. Threads can wait on the condition variable until + their request's result appears. + """ + + def __init__(self): + self.responses = {} # Maps request_id -> (original_input, infer_results) + self.lock = threading.Lock() + self.cond = threading.Condition(self.lock) + + def put(self, request_id: int, response: ndarray): + with self.cond: + self.responses[request_id] = response + self.cond.notify_all() + + def get(self, request_id: int, timeout=None) -> ndarray: + with self.cond: + if not self.cond.wait_for( + lambda: request_id in self.responses, timeout=timeout + ): + raise TimeoutError(f"Timeout waiting for response {request_id}") + + return self.responses.pop(request_id) + + +def tensor_transform(desired_shape: InputTensorEnum): + # Currently this function only supports BHWC permutations + if desired_shape == InputTensorEnum.nhwc: + return None + elif desired_shape == InputTensorEnum.nchw: + return (0, 3, 1, 2)