optim: multi batch inference

This commit is contained in:
mjq2020 2025-04-06 18:49:42 +01:00
parent 6475e042b5
commit 2bfe2341f0

View File

@ -46,17 +46,19 @@ class ResponseStore:
self.cond = threading.Condition(self.lock) self.cond = threading.Condition(self.lock)
def put(self, request_id, response): def put(self, request_id, response):
with self.cond:
self.responses[request_id] = response self.responses[request_id] = response
self.cond.notify_all() self.cond.notify_all()
def get(self, request_id, timeout=None): def get(self, request_id, timeout=None):
with self.cond: with self.cond:
if not self.cond.wait_for( if not self.cond.wait_for(
lambda: request_id in self.responses, timeout=timeout lambda: all(ri in self.responses for ri in request_id), timeout=timeout
): ):
raise TimeoutError(f"Timeout waiting for response {request_id}") raise TimeoutError(f"Timeout waiting for response {request_id}")
return self.responses.pop(request_id) results = []
for ri in request_id:
results.append(self.responses.pop(ri))
return results
# ----------------- Utility Functions ----------------- # # ----------------- Utility Functions ----------------- #
@ -167,12 +169,15 @@ class HailoAsyncInference:
if completion_info.exception: if completion_info.exception:
logger.error(f"Inference error: {completion_info.exception}") logger.error(f"Inference error: {completion_info.exception}")
else: else:
with self.output_store.cond:
for i, bindings in enumerate(bindings_list): for i, bindings in enumerate(bindings_list):
if len(bindings._output_names) == 1: if len(bindings._output_names) == 1:
result = bindings.output().get_buffer() result = bindings.output().get_buffer()
else: else:
result = { result = {
name: np.expand_dims(bindings.output(name).get_buffer(), axis=0) name: np.expand_dims(
bindings.output(name).get_buffer(), axis=0
)
for name in bindings._output_names for name in bindings._output_names
} }
self.output_store.put(request_ids[i], (input_batch[i], result)) self.output_store.put(request_ids[i], (input_batch[i], result))
@ -207,9 +212,7 @@ class HailoAsyncInference:
batch_data = self.input_queue.get() batch_data = self.input_queue.get()
if batch_data is None: if batch_data is None:
break break
request_id, frame_data = batch_data request_ids, preprocessed_batch = batch_data
preprocessed_batch = [frame_data]
request_ids = [request_id]
input_batch = preprocessed_batch # non-send_original_frame mode input_batch = preprocessed_batch # non-send_original_frame mode
bindings_list = [] bindings_list = []
@ -364,32 +367,35 @@ class HailoDetector(DetectionApi):
raise FileNotFoundError(f"Model file not found at: {self.model_path}") raise FileNotFoundError(f"Model file not found at: {self.model_path}")
return cached_model_path return cached_model_path
def _get_request_id(self) -> int: def _get_request_id(self, batch) -> int:
with self.request_counter_lock: with self.request_counter_lock:
request_ids = []
for i in range(batch):
request_id = self.request_counter request_id = self.request_counter
self.request_counter += 1 self.request_counter += 1
if self.request_counter > 1000000: if self.request_counter > 1000000:
self.request_counter = 0 self.request_counter = 0
return request_id request_ids.append(request_id)
return request_ids
def detect_raw(self, tensor_input): def detect_raw(self, tensor_input):
request_id = self._get_request_id() request_ids = self._get_request_id(tensor_input.shape[0])
tensor_input = self.preprocess(tensor_input) # tensor_input = self.preprocess(tensor_input)
if isinstance(tensor_input, np.ndarray) and len(tensor_input.shape) == 3: if isinstance(tensor_input, np.ndarray) and len(tensor_input.shape) == 3:
tensor_input = np.expand_dims(tensor_input, axis=0) tensor_input = np.expand_dims(tensor_input, axis=0)
self.input_queue.put((request_id, tensor_input)) self.input_queue.put((request_ids, tensor_input))
try: try:
original_input, infer_results = self.response_store.get( batch_infer_results = self.response_store.get(request_ids, timeout=10.0)
request_id, timeout=10.0
)
except TimeoutError: except TimeoutError:
logger.error( logger.error(
f"Timeout waiting for inference results for request {request_id}" f"Timeout waiting for inference results for request {request_ids}"
) )
return np.zeros((20, 6), dtype=np.float32) return np.zeros((0, 20, 6), dtype=np.float32)
results = []
for original_input, infer_results in batch_infer_results:
if isinstance(infer_results, list) and len(infer_results) == 1: if isinstance(infer_results, list) and len(infer_results) == 1:
infer_results = infer_results[0] infer_results = infer_results[0]
@ -404,7 +410,9 @@ class HailoDetector(DetectionApi):
score = float(det[4]) score = float(det[4])
if score < threshold: if score < threshold:
continue continue
all_detections.append([class_id, score, det[0], det[1], det[2], det[3]]) all_detections.append(
[class_id, score, det[0], det[1], det[2], det[3]]
)
if len(all_detections) == 0: if len(all_detections) == 0:
detections_array = np.zeros((20, 6), dtype=np.float32) detections_array = np.zeros((20, 6), dtype=np.float32)
@ -413,10 +421,12 @@ class HailoDetector(DetectionApi):
if detections_array.shape[0] > 20: if detections_array.shape[0] > 20:
detections_array = detections_array[:20, :] detections_array = detections_array[:20, :]
elif detections_array.shape[0] < 20: elif detections_array.shape[0] < 20:
pad = np.zeros((20 - detections_array.shape[0], 6), dtype=np.float32) pad = np.zeros(
(20 - detections_array.shape[0], 6), dtype=np.float32
)
detections_array = np.vstack((detections_array, pad)) detections_array = np.vstack((detections_array, pad))
results.append(detections_array)
return detections_array return np.stack(results, axis=0)
def preprocess(self, image): def preprocess(self, image):
if isinstance(image, np.ndarray): if isinstance(image, np.ndarray):