fix: add multi-threading to classification

ZhaiSoul 2026-04-14 16:17:55 +08:00
parent 335229d0d4
commit 1577657a6d


@@ -4,6 +4,7 @@ import datetime
import json
import logging
import os
import threading
from typing import Any
import cv2
@@ -38,6 +39,7 @@ except ModuleNotFoundError:
logger = logging.getLogger(__name__)
MAX_OBJECT_CLASSIFICATIONS = 16
MAX_CONCURRENT_CLASSIFICATIONS_PER_OBJECT = 3
class CustomStateClassificationProcessor(RealTimeProcessorApi):
@@ -379,6 +381,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
        self.classification_history: dict[str, list[tuple[str, float, float]]] = {}
        self.labelmap: dict[int, str] = {}
        self.classifications_per_second = EventsPerSecond()
        self.active_tasks: dict[str, int] = {}
        self.interpreter_lock = threading.Lock()
        self.history_lock = threading.Lock()
        if (
            self.metrics
@@ -419,6 +424,137 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
        if self.inference_speed:
            self.inference_speed.update(duration)

    def _run_inference_thread(
        self, crop: np.ndarray, object_id: str, camera: str, timestamp: float
    ) -> None:
        """Run inference in a separate thread to avoid blocking."""
        try:
            with self.interpreter_lock:
                if self.interpreter is None:
                    return

                input = np.expand_dims(crop, axis=0)
                self.interpreter.set_tensor(
                    self.tensor_input_details[0]["index"], input
                )
                self.interpreter.invoke()
                res = self.interpreter.get_tensor(
                    self.tensor_output_details[0]["index"]
                )[0]

            probs = res / res.sum(axis=0)
            best_id = int(np.argmax(probs))
            score = round(probs[best_id], 2)
            self.__update_metrics(datetime.datetime.now().timestamp() - timestamp)

            label = self.labelmap[best_id]

            save_attempts = (
                self.model_config.save_attempts
                if self.model_config.save_attempts is not None
                else 200
            )
            threading.Thread(
                target=write_classification_attempt,
                name=f"_save_classification_{object_id}",
                daemon=True,
                args=(
                    self.train_dir,
                    cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
                    object_id,
                    timestamp,
                    label,
                    score,
                    save_attempts,
                ),
            ).start()

            if score >= self.model_config.threshold:
                # Add result to history before publishing (it will be included
                # in the consensus calculation)
                with self.history_lock:
                    if object_id not in self.classification_history:
                        self.classification_history[object_id] = []
                    self.classification_history[object_id].append(
                        (label, score, timestamp)
                    )
                logger.debug(
                    f"Added classification result for {object_id}: {label} ({score}) at {timestamp}"
                )
                self._publish_result(object_id, camera, label, score, timestamp)
        except Exception as e:
            logger.error(
                f"Error in inference thread for {object_id}: {e}", exc_info=True
            )
        finally:
            with self.interpreter_lock:
                # Default of 1 tolerates expire_object removing the counter
                # while this task was still in flight.
                self.active_tasks[object_id] = self.active_tasks.get(object_id, 1) - 1
                if self.active_tasks[object_id] <= 0:
                    del self.active_tasks[object_id]
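
The decrement in the finally block pairs with the increment performed in process_frame; the default of 1 in active_tasks.get(object_id, 1) is what tolerates the counter being deleted mid-flight. A minimal standalone sketch of this bounded-tasks pattern (class and method names are hypothetical, not from the codebase):

import threading

class BoundedTasks:
    """Cap the number of concurrent tasks per key."""

    def __init__(self, limit: int) -> None:
        self.limit = limit
        self.lock = threading.Lock()
        self.active: dict[str, int] = {}

    def try_acquire(self, key: str) -> bool:
        with self.lock:
            if self.active.get(key, 0) >= self.limit:
                return False
            self.active[key] = self.active.get(key, 0) + 1
            return True

    def release(self, key: str) -> None:
        with self.lock:
            # default of 1 tolerates the counter being removed externally
            self.active[key] = self.active.get(key, 1) - 1
            if self.active[key] <= 0:
                del self.active[key]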

    def _publish_result(
        self, object_id: str, camera: str, label: str, score: float, timestamp: float
    ) -> None:
        """Publish classification result after weighted scoring."""
        consensus_label, consensus_score = self.get_weighted_score(
            object_id, label, score, timestamp
        )
        logger.debug(
            f"{self.model_config.name}: get_weighted_score returned consensus_label={consensus_label}, consensus_score={consensus_score} for {object_id}"
        )

        if consensus_label is None:
            return

        logger.debug(
            f"{self.model_config.name}: Publishing sub_label={consensus_label} for object {object_id} on {camera}"
        )

        if (
            self.model_config.object_config.classification_type
            == ObjectClassificationType.sub_label
        ):
            self.sub_label_publisher.publish(
                (object_id, consensus_label, consensus_score),
                EventMetadataTypeEnum.sub_label,
            )
            self.requestor.send_data(
                "tracked_object_update",
                json.dumps(
                    {
                        "type": TrackedObjectUpdateTypesEnum.classification,
                        "id": object_id,
                        "camera": camera,
                        "timestamp": timestamp,
                        "model": self.model_config.name,
                        "sub_label": consensus_label,
                        "score": consensus_score,
                    }
                ),
            )
        elif (
            self.model_config.object_config.classification_type
            == ObjectClassificationType.attribute
        ):
            self.sub_label_publisher.publish(
                (
                    object_id,
                    self.model_config.name,
                    consensus_label,
                    consensus_score,
                ),
                EventMetadataTypeEnum.attribute.value,
            )
            self.requestor.send_data(
                "tracked_object_update",
                json.dumps(
                    {
                        "type": TrackedObjectUpdateTypesEnum.classification,
                        "id": object_id,
                        "camera": camera,
                        "timestamp": timestamp,
                        "model": self.model_config.name,
                        "attribute": consensus_label,
                        "score": consensus_score,
                    }
                ),
            )

    def get_weighted_score(
        self,
        object_id: str,

@@ -431,63 +567,64 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):

        Requires 60% of attempts to agree on a label before publishing.
        Returns (weighted_label, weighted_score) or (None, 0.0) if no weighted score.
        """
        with self.history_lock:
            if object_id not in self.classification_history:
                self.classification_history[object_id] = []
                logger.debug(f"Created new classification history for {object_id}")

            # The caller (_run_inference_thread) appends the current result to
            # the history before publishing, so the history is only read here;
            # appending it again would double-count every attempt.
            history = self.classification_history[object_id]

            logger.debug(
                f"History for {object_id}: {len(history)} entries, latest=({current_label}, {current_score})"
            )

            if len(history) < 3:
                logger.debug(
                    f"History for {object_id} has {len(history)} entries, need at least 3"
                )
                return None, 0.0

            label_counts: dict[str, int] = {}
            label_scores: dict[str, list[float]] = {}
            total_attempts = len(history)

            for label, score, timestamp in history:
                if label not in label_counts:
                    label_counts[label] = 0
                    label_scores[label] = []
                label_counts[label] += 1
                label_scores[label].append(score)

            best_label = max(label_counts, key=lambda k: label_counts[k])
            best_count = label_counts[best_label]
            consensus_threshold = total_attempts * 0.6

            logger.debug(
                f"Consensus calc for {object_id}: label_counts={label_counts}, "
                f"best_label={best_label}, best_count={best_count}, "
                f"total={total_attempts}, threshold={consensus_threshold}"
            )

            if best_count < consensus_threshold:
                logger.debug(
                    f"No consensus for {object_id}: {best_count} < {consensus_threshold}"
                )
                return None, 0.0

            avg_score = sum(label_scores[best_label]) / len(label_scores[best_label])

            if best_label == "none":
                logger.debug(f"Filtering 'none' label for {object_id}")
                return None, 0.0

            logger.debug(
                f"Consensus reached for {object_id}: {best_label} with avg_score={avg_score}"
            )
            return best_label, avg_score
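
To make the consensus rule concrete, a worked example with hypothetical values: five attempts at 60% agreement require at least 5 * 0.6 = 3 matching votes.

history = [
    ("cat", 0.80, 1.0),
    ("cat", 0.90, 2.0),
    ("dog", 0.70, 3.0),
    ("cat", 0.85, 4.0),
    ("cat", 0.95, 5.0),
]

label_counts: dict[str, int] = {}
label_scores: dict[str, list[float]] = {}
for label, score, _ts in history:
    label_counts[label] = label_counts.get(label, 0) + 1
    label_scores.setdefault(label, []).append(score)

best_label = max(label_counts, key=label_counts.get)  # "cat"
threshold = len(history) * 0.6  # 3.0
assert label_counts[best_label] >= threshold  # 4 >= 3.0
avg_score = sum(label_scores[best_label]) / len(label_scores[best_label])
print(best_label, avg_score)  # cat 0.875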

    def process_frame(self, obj_data: dict[str, Any], frame: np.ndarray) -> None:
        if (

@@ -521,6 +658,12 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):

        ):
            return
        with self.interpreter_lock:
            current_tasks = self.active_tasks.get(object_id, 0)
            if current_tasks >= MAX_CONCURRENT_CLASSIFICATIONS_PER_OBJECT:
                logger.debug(
                    f"Object {object_id} has {current_tasks}/{MAX_CONCURRENT_CLASSIFICATIONS_PER_OBJECT} active tasks"
                )
                return
        now = datetime.datetime.now().timestamp()
        x, y, x2, y2 = calculate_region(
            frame.shape,
@@ -566,116 +709,22 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):

            # Still track history even when model doesn't exist to respect MAX_OBJECT_CLASSIFICATIONS
            # Add an entry with "unknown" label so the history limit is enforced
            with self.history_lock:
                if object_id not in self.classification_history:
                    self.classification_history[object_id] = []
                self.classification_history[object_id].append(("unknown", 0.0, now))
            return

        with self.interpreter_lock:
            self.active_tasks[object_id] = self.active_tasks.get(object_id, 0) + 1

        threading.Thread(
            target=self._run_inference_thread,
            name=f"_classification_{object_id}",
            daemon=True,
            args=(resized_crop, object_id, obj_data["camera"], now),
        ).start()
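
Condensed shape of the handoff that process_frame now performs (a sketch with hypothetical names; unlike the commit, which checks the cap and increments under two separate lock acquisitions, this version does both under one acquisition so the cap also holds under races):

import threading
from typing import Callable

MAX_PER_KEY = 3
_lock = threading.Lock()
_active: dict[str, int] = {}

def submit(key: str, work: Callable[[], None]) -> bool:
    """Spawn work in a daemon thread unless key already has MAX_PER_KEY tasks."""
    with _lock:
        # check and increment atomically
        if _active.get(key, 0) >= MAX_PER_KEY:
            return False
        _active[key] = _active.get(key, 0) + 1

    def run() -> None:
        try:
            work()
        finally:
            with _lock:
                _active[key] = _active.get(key, 1) - 1
                if _active[key] <= 0:
                    del _active[key]

    threading.Thread(target=run, name=f"_classification_{key}", daemon=True).start()
    return True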

    def handle_request(self, topic: str, request_data: dict) -> dict | None:
        if topic == EmbeddingsRequestEnum.reload_classification_model.value:
@@ -694,8 +743,13 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
        return None

    def expire_object(self, object_id: str, camera: str) -> None:
        """Clean up tracking data when an object expires."""
        with self.history_lock:
            if object_id in self.classification_history:
                self.classification_history.pop(object_id)

        with self.interpreter_lock:
            if object_id in self.active_tasks:
                del self.active_tasks[object_id]

def write_classification_attempt(
@@ -714,7 +768,9 @@ def write_classification_attempt(
    os.makedirs(folder, exist_ok=True)
    cv2.imwrite(file, frame)

    # Delete the oldest files once the maximum is reached.
    # In a multi-threaded environment, multiple files may need to be deleted
    # to stay under the limit, since several threads can write concurrently.
    try:
        files = sorted(
            filter(lambda f: f.endswith(".webp"), os.listdir(folder)),

@@ -723,6 +779,12 @@

        )

        if len(files) > max_files:
            # Delete all files beyond the limit
            files_to_delete = files[max_files:]
            for f in files_to_delete:
                try:
                    os.unlink(os.path.join(folder, f))
                except FileNotFoundError:
                    pass  # File might have been deleted by another thread
    except (FileNotFoundError, OSError):
        pass
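
The cleanup above keeps only the newest max_files snapshots. A self-contained sketch of the same pruning idea (the sort key is assumed to be modification time here; the actual key falls outside the diff hunk):

import os

def prune_snapshots(folder: str, max_files: int) -> None:
    """Keep only the newest max_files .webp files in folder."""
    try:
        files = sorted(
            (f for f in os.listdir(folder) if f.endswith(".webp")),
            key=lambda f: os.path.getmtime(os.path.join(folder, f)),
            reverse=True,  # newest first
        )
        for name in files[max_files:]:
            try:
                os.unlink(os.path.join(folder, name))
            except FileNotFoundError:
                pass  # another thread may have pruned it already
    except OSError:
        pass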