mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-05-04 20:47:42 +03:00
Add input store type
This commit is contained in:
parent
3ee2d7086d
commit
6645b158a3
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import subprocess
|
||||
import threading
|
||||
import urllib.request
|
||||
@ -28,37 +27,11 @@ from frigate.detectors.detection_api import DetectionApi
|
||||
from frigate.detectors.detector_config import (
|
||||
BaseDetectorConfig,
|
||||
)
|
||||
from frigate.object_detection.util import RequestStore, ResponseStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ----------------- ResponseStore Class ----------------- #
|
||||
class ResponseStore:
|
||||
"""
|
||||
A thread-safe hash-based response store that maps request IDs
|
||||
to their results. Threads can wait on the condition variable until
|
||||
their request's result appears.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.responses = {} # Maps request_id -> (original_input, infer_results)
|
||||
self.lock = threading.Lock()
|
||||
self.cond = threading.Condition(self.lock)
|
||||
|
||||
def put(self, request_id, response):
|
||||
with self.cond:
|
||||
self.responses[request_id] = response
|
||||
self.cond.notify_all()
|
||||
|
||||
def get(self, request_id, timeout=None):
|
||||
with self.cond:
|
||||
if not self.cond.wait_for(
|
||||
lambda: request_id in self.responses, timeout=timeout
|
||||
):
|
||||
raise TimeoutError(f"Timeout waiting for response {request_id}")
|
||||
return self.responses.pop(request_id)
|
||||
|
||||
|
||||
# ----------------- Utility Functions ----------------- #
|
||||
|
||||
|
||||
@ -122,14 +95,14 @@ class HailoAsyncInference:
|
||||
def __init__(
|
||||
self,
|
||||
hef_path: str,
|
||||
input_queue: queue.Queue,
|
||||
input_store: RequestStore,
|
||||
output_store: ResponseStore,
|
||||
batch_size: int = 1,
|
||||
input_type: Optional[str] = None,
|
||||
output_type: Optional[Dict[str, str]] = None,
|
||||
send_original_frame: bool = False,
|
||||
) -> None:
|
||||
self.input_queue = input_queue
|
||||
self.input_store = input_store
|
||||
self.output_store = output_store
|
||||
|
||||
params = VDevice.create_params()
|
||||
@ -204,9 +177,11 @@ class HailoAsyncInference:
|
||||
def run(self) -> None:
|
||||
with self.infer_model.configure() as configured_infer_model:
|
||||
while True:
|
||||
batch_data = self.input_queue.get()
|
||||
batch_data = self.input_store.get()
|
||||
|
||||
if batch_data is None:
|
||||
break
|
||||
|
||||
request_id, frame_data = batch_data
|
||||
preprocessed_batch = [frame_data]
|
||||
request_ids = [request_id]
|
||||
@ -274,16 +249,14 @@ class HailoDetector(DetectionApi):
|
||||
self.working_model_path = self.check_and_prepare()
|
||||
|
||||
self.batch_size = 1
|
||||
self.input_queue = queue.Queue()
|
||||
self.input_store = RequestStore()
|
||||
self.response_store = ResponseStore()
|
||||
self.request_counter = 0
|
||||
self.request_counter_lock = threading.Lock()
|
||||
|
||||
try:
|
||||
logger.debug(f"[INIT] Loading HEF model from {self.working_model_path}")
|
||||
self.inference_engine = HailoAsyncInference(
|
||||
self.working_model_path,
|
||||
self.input_queue,
|
||||
self.input_store,
|
||||
self.response_store,
|
||||
self.batch_size,
|
||||
)
|
||||
@ -364,26 +337,16 @@ class HailoDetector(DetectionApi):
|
||||
raise FileNotFoundError(f"Model file not found at: {self.model_path}")
|
||||
return cached_model_path
|
||||
|
||||
def _get_request_id(self) -> int:
|
||||
with self.request_counter_lock:
|
||||
request_id = self.request_counter
|
||||
self.request_counter += 1
|
||||
if self.request_counter > 1000000:
|
||||
self.request_counter = 0
|
||||
return request_id
|
||||
|
||||
def detect_raw(self, tensor_input):
|
||||
request_id = self._get_request_id()
|
||||
|
||||
tensor_input = self.preprocess(tensor_input)
|
||||
|
||||
if isinstance(tensor_input, np.ndarray) and len(tensor_input.shape) == 3:
|
||||
tensor_input = np.expand_dims(tensor_input, axis=0)
|
||||
|
||||
self.input_queue.put((request_id, tensor_input))
|
||||
request_id = self.input_store.put(tensor_input)
|
||||
|
||||
try:
|
||||
original_input, infer_results = self.response_store.get(
|
||||
request_id, timeout=10.0
|
||||
)
|
||||
_, infer_results = self.response_store.get(request_id, timeout=10.0)
|
||||
except TimeoutError:
|
||||
logger.error(
|
||||
f"Timeout waiting for inference results for request {request_id}"
|
||||
|
||||
@ -15,12 +15,13 @@ from frigate.detectors import create_detector
|
||||
from frigate.detectors.detector_config import (
|
||||
BaseDetectorConfig,
|
||||
InputDTypeEnum,
|
||||
InputTensorEnum,
|
||||
)
|
||||
from frigate.util.builtin import EventsPerSecond, load_labels
|
||||
from frigate.util.image import SharedMemoryFrameManager, UntrackedSharedMemory
|
||||
from frigate.util.services import listen
|
||||
|
||||
from .util import tensor_transform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -30,14 +31,6 @@ class ObjectDetector(ABC):
|
||||
pass
|
||||
|
||||
|
||||
def tensor_transform(desired_shape: InputTensorEnum):
|
||||
# Currently this function only supports BHWC permutations
|
||||
if desired_shape == InputTensorEnum.nhwc:
|
||||
return None
|
||||
elif desired_shape == InputTensorEnum.nchw:
|
||||
return (0, 3, 1, 2)
|
||||
|
||||
|
||||
class LocalObjectDetector(ObjectDetector):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -0,0 +1,73 @@
|
||||
"""Object detection utilities."""
|
||||
|
||||
import threading
|
||||
import queue
|
||||
|
||||
from numpy import ndarray
|
||||
|
||||
from frigate.detectors.detector_config import InputTensorEnum
|
||||
|
||||
|
||||
class RequestStore:
|
||||
"""
|
||||
A thread-safe hash-based response store that handles creating requests.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.request_counter = 0
|
||||
self.request_counter_lock = threading.Lock()
|
||||
self.input_queue = queue.Queue()
|
||||
|
||||
def __get_request_id(self) -> int:
|
||||
with self.request_counter_lock:
|
||||
request_id = self.request_counter
|
||||
self.request_counter += 1
|
||||
if self.request_counter > 1000000:
|
||||
self.request_counter = 0
|
||||
return request_id
|
||||
|
||||
def put(self, tensor_input: ndarray) -> int:
|
||||
request_id = self.__get_request_id()
|
||||
self.input_queue.get((request_id, tensor_input))
|
||||
return request_id
|
||||
|
||||
def get(self) -> tuple[int, ndarray] | None:
|
||||
try:
|
||||
return self.input_queue.get_nowait()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
class ResponseStore:
|
||||
"""
|
||||
A thread-safe hash-based response store that maps request IDs
|
||||
to their results. Threads can wait on the condition variable until
|
||||
their request's result appears.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.responses = {} # Maps request_id -> (original_input, infer_results)
|
||||
self.lock = threading.Lock()
|
||||
self.cond = threading.Condition(self.lock)
|
||||
|
||||
def put(self, request_id: int, response: ndarray):
|
||||
with self.cond:
|
||||
self.responses[request_id] = response
|
||||
self.cond.notify_all()
|
||||
|
||||
def get(self, request_id: int, timeout=None) -> ndarray:
|
||||
with self.cond:
|
||||
if not self.cond.wait_for(
|
||||
lambda: request_id in self.responses, timeout=timeout
|
||||
):
|
||||
raise TimeoutError(f"Timeout waiting for response {request_id}")
|
||||
|
||||
return self.responses.pop(request_id)
|
||||
|
||||
|
||||
def tensor_transform(desired_shape: InputTensorEnum):
|
||||
# Currently this function only supports BHWC permutations
|
||||
if desired_shape == InputTensorEnum.nhwc:
|
||||
return None
|
||||
elif desired_shape == InputTensorEnum.nchw:
|
||||
return (0, 3, 1, 2)
|
||||
Loading…
Reference in New Issue
Block a user