mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-05-09 15:05:26 +03:00
fix: add multi-threading to classification
This commit is contained in:
parent
335229d0d4
commit
1577657a6d
@ -4,6 +4,7 @@ import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
@ -38,6 +39,7 @@ except ModuleNotFoundError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_OBJECT_CLASSIFICATIONS = 16
|
||||
MAX_CONCURRENT_CLASSIFICATIONS_PER_OBJECT = 3
|
||||
|
||||
|
||||
class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
@ -379,6 +381,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
self.classification_history: dict[str, list[tuple[str, float, float]]] = {}
|
||||
self.labelmap: dict[int, str] = {}
|
||||
self.classifications_per_second = EventsPerSecond()
|
||||
self.active_tasks: dict[str, int] = {}
|
||||
self.interpreter_lock = threading.Lock()
|
||||
self.history_lock = threading.Lock()
|
||||
|
||||
if (
|
||||
self.metrics
|
||||
@ -419,6 +424,137 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
if self.inference_speed:
|
||||
self.inference_speed.update(duration)
|
||||
|
||||
def _run_inference_thread(
|
||||
self, crop: np.ndarray, object_id: str, camera: str, timestamp: float
|
||||
) -> None:
|
||||
"""Run inference in a separate thread to avoid blocking."""
|
||||
try:
|
||||
with self.interpreter_lock:
|
||||
if self.interpreter is None:
|
||||
return
|
||||
|
||||
input = np.expand_dims(crop, axis=0)
|
||||
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
|
||||
self.interpreter.invoke()
|
||||
res = self.interpreter.get_tensor(self.tensor_output_details[0]["index"])[0]
|
||||
probs = res / res.sum(axis=0)
|
||||
|
||||
best_id = int(np.argmax(probs))
|
||||
score = round(probs[best_id], 2)
|
||||
self.__update_metrics(datetime.datetime.now().timestamp() - timestamp)
|
||||
|
||||
label = self.labelmap[best_id]
|
||||
|
||||
save_attempts = (
|
||||
self.model_config.save_attempts
|
||||
if self.model_config.save_attempts is not None
|
||||
else 200
|
||||
)
|
||||
|
||||
threading.Thread(
|
||||
target=write_classification_attempt,
|
||||
name=f"_save_classification_{object_id}",
|
||||
daemon=True,
|
||||
args=(
|
||||
self.train_dir,
|
||||
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
|
||||
object_id,
|
||||
timestamp,
|
||||
label,
|
||||
score,
|
||||
save_attempts,
|
||||
),
|
||||
).start()
|
||||
|
||||
if score >= self.model_config.threshold:
|
||||
# Add result to history before publishing (will be included in consensus calculation)
|
||||
with self.history_lock:
|
||||
if object_id not in self.classification_history:
|
||||
self.classification_history[object_id] = []
|
||||
|
||||
self.classification_history[object_id].append((label, score, timestamp))
|
||||
logger.debug(
|
||||
f"Added classification result for {object_id}: {label} ({score}) at {timestamp}"
|
||||
)
|
||||
|
||||
self._publish_result(object_id, camera, label, score, timestamp)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in inference thread for {object_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
with self.interpreter_lock:
|
||||
self.active_tasks[object_id] = self.active_tasks.get(object_id, 1) - 1
|
||||
if self.active_tasks[object_id] <= 0:
|
||||
del self.active_tasks[object_id]
|
||||
|
||||
def _publish_result(
|
||||
self, object_id: str, camera: str, label: str, score: float, timestamp: float
|
||||
) -> None:
|
||||
"""Publish classification result after weighted scoring."""
|
||||
consensus_label, consensus_score = self.get_weighted_score(
|
||||
object_id, label, score, timestamp
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"{self.model_config.name}: get_weighted_score returned consensus_label={consensus_label}, consensus_score={consensus_score} for {object_id}"
|
||||
)
|
||||
|
||||
if consensus_label is None:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"{self.model_config.name}: Publishing sub_label={consensus_label} for object {object_id} on {camera}"
|
||||
)
|
||||
|
||||
if (
|
||||
self.model_config.object_config.classification_type
|
||||
== ObjectClassificationType.sub_label
|
||||
):
|
||||
self.sub_label_publisher.publish(
|
||||
(object_id, consensus_label, consensus_score),
|
||||
EventMetadataTypeEnum.sub_label,
|
||||
)
|
||||
self.requestor.send_data(
|
||||
"tracked_object_update",
|
||||
json.dumps(
|
||||
{
|
||||
"type": TrackedObjectUpdateTypesEnum.classification,
|
||||
"id": object_id,
|
||||
"camera": camera,
|
||||
"timestamp": timestamp,
|
||||
"model": self.model_config.name,
|
||||
"sub_label": consensus_label,
|
||||
"score": consensus_score,
|
||||
}
|
||||
),
|
||||
)
|
||||
elif (
|
||||
self.model_config.object_config.classification_type
|
||||
== ObjectClassificationType.attribute
|
||||
):
|
||||
self.sub_label_publisher.publish(
|
||||
(
|
||||
object_id,
|
||||
self.model_config.name,
|
||||
consensus_label,
|
||||
consensus_score,
|
||||
),
|
||||
EventMetadataTypeEnum.attribute.value,
|
||||
)
|
||||
self.requestor.send_data(
|
||||
"tracked_object_update",
|
||||
json.dumps(
|
||||
{
|
||||
"type": TrackedObjectUpdateTypesEnum.classification,
|
||||
"id": object_id,
|
||||
"camera": camera,
|
||||
"timestamp": timestamp,
|
||||
"model": self.model_config.name,
|
||||
"attribute": consensus_label,
|
||||
"score": consensus_score,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
def get_weighted_score(
|
||||
self,
|
||||
object_id: str,
|
||||
@ -431,63 +567,64 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
Requires 60% of attempts to agree on a label before publishing.
|
||||
Returns (weighted_label, weighted_score) or (None, 0.0) if no weighted score.
|
||||
"""
|
||||
if object_id not in self.classification_history:
|
||||
self.classification_history[object_id] = []
|
||||
logger.debug(f"Created new classification history for {object_id}")
|
||||
with self.history_lock:
|
||||
if object_id not in self.classification_history:
|
||||
self.classification_history[object_id] = []
|
||||
logger.debug(f"Created new classification history for {object_id}")
|
||||
|
||||
self.classification_history[object_id].append(
|
||||
(current_label, current_score, current_time)
|
||||
)
|
||||
|
||||
history = self.classification_history[object_id]
|
||||
logger.debug(
|
||||
f"History for {object_id}: {len(history)} entries, latest=({current_label}, {current_score})"
|
||||
)
|
||||
|
||||
if len(history) < 3:
|
||||
logger.debug(
|
||||
f"History for {object_id} has {len(history)} entries, need at least 3"
|
||||
self.classification_history[object_id].append(
|
||||
(current_label, current_score, current_time)
|
||||
)
|
||||
return None, 0.0
|
||||
|
||||
label_counts: dict[str, int] = {}
|
||||
label_scores: dict[str, list[float]] = {}
|
||||
total_attempts = len(history)
|
||||
|
||||
for label, score, timestamp in history:
|
||||
if label not in label_counts:
|
||||
label_counts[label] = 0
|
||||
label_scores[label] = []
|
||||
|
||||
label_counts[label] += 1
|
||||
label_scores[label].append(score)
|
||||
|
||||
best_label = max(label_counts, key=lambda k: label_counts[k])
|
||||
best_count = label_counts[best_label]
|
||||
|
||||
consensus_threshold = total_attempts * 0.6
|
||||
logger.debug(
|
||||
f"Consensus calc for {object_id}: label_counts={label_counts}, "
|
||||
f"best_label={best_label}, best_count={best_count}, "
|
||||
f"total={total_attempts}, threshold={consensus_threshold}"
|
||||
)
|
||||
|
||||
if best_count < consensus_threshold:
|
||||
history = self.classification_history[object_id]
|
||||
logger.debug(
|
||||
f"No consensus for {object_id}: {best_count} < {consensus_threshold}"
|
||||
f"History for {object_id}: {len(history)} entries, latest=({current_label}, {current_score})"
|
||||
)
|
||||
return None, 0.0
|
||||
|
||||
avg_score = sum(label_scores[best_label]) / len(label_scores[best_label])
|
||||
if len(history) < 3:
|
||||
logger.debug(
|
||||
f"History for {object_id} has {len(history)} entries, need at least 3"
|
||||
)
|
||||
return None, 0.0
|
||||
|
||||
if best_label == "none":
|
||||
logger.debug(f"Filtering 'none' label for {object_id}")
|
||||
return None, 0.0
|
||||
label_counts: dict[str, int] = {}
|
||||
label_scores: dict[str, list[float]] = {}
|
||||
total_attempts = len(history)
|
||||
|
||||
logger.debug(
|
||||
f"Consensus reached for {object_id}: {best_label} with avg_score={avg_score}"
|
||||
)
|
||||
return best_label, avg_score
|
||||
for label, score, timestamp in history:
|
||||
if label not in label_counts:
|
||||
label_counts[label] = 0
|
||||
label_scores[label] = []
|
||||
|
||||
label_counts[label] += 1
|
||||
label_scores[label].append(score)
|
||||
|
||||
best_label = max(label_counts, key=lambda k: label_counts[k])
|
||||
best_count = label_counts[best_label]
|
||||
|
||||
consensus_threshold = total_attempts * 0.6
|
||||
logger.debug(
|
||||
f"Consensus calc for {object_id}: label_counts={label_counts}, "
|
||||
f"best_label={best_label}, best_count={best_count}, "
|
||||
f"total={total_attempts}, threshold={consensus_threshold}"
|
||||
)
|
||||
|
||||
if best_count < consensus_threshold:
|
||||
logger.debug(
|
||||
f"No consensus for {object_id}: {best_count} < {consensus_threshold}"
|
||||
)
|
||||
return None, 0.0
|
||||
|
||||
avg_score = sum(label_scores[best_label]) / len(label_scores[best_label])
|
||||
|
||||
if best_label == "none":
|
||||
logger.debug(f"Filtering 'none' label for {object_id}")
|
||||
return None, 0.0
|
||||
|
||||
logger.debug(
|
||||
f"Consensus reached for {object_id}: {best_label} with avg_score={avg_score}"
|
||||
)
|
||||
return best_label, avg_score
|
||||
|
||||
def process_frame(self, obj_data: dict[str, Any], frame: np.ndarray) -> None:
|
||||
if (
|
||||
@ -521,6 +658,12 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
):
|
||||
return
|
||||
|
||||
with self.interpreter_lock:
|
||||
current_tasks = self.active_tasks.get(object_id, 0)
|
||||
if current_tasks >= MAX_CONCURRENT_CLASSIFICATIONS_PER_OBJECT:
|
||||
logger.debug(f"Object {object_id} has {current_tasks}/{MAX_CONCURRENT_CLASSIFICATIONS_PER_OBJECT} active tasks")
|
||||
return
|
||||
|
||||
now = datetime.datetime.now().timestamp()
|
||||
x, y, x2, y2 = calculate_region(
|
||||
frame.shape,
|
||||
@ -566,116 +709,22 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
|
||||
# Still track history even when model doesn't exist to respect MAX_OBJECT_CLASSIFICATIONS
|
||||
# Add an entry with "unknown" label so the history limit is enforced
|
||||
if object_id not in self.classification_history:
|
||||
self.classification_history[object_id] = []
|
||||
with self.history_lock:
|
||||
if object_id not in self.classification_history:
|
||||
self.classification_history[object_id] = []
|
||||
|
||||
self.classification_history[object_id].append(("unknown", 0.0, now))
|
||||
self.classification_history[object_id].append(("unknown", 0.0, now))
|
||||
return
|
||||
|
||||
input = np.expand_dims(resized_crop, axis=0)
|
||||
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
|
||||
self.interpreter.invoke()
|
||||
res: np.ndarray = self.interpreter.get_tensor(
|
||||
self.tensor_output_details[0]["index"]
|
||||
)[0]
|
||||
probs = res / res.sum(axis=0)
|
||||
logger.debug(
|
||||
f"{self.model_config.name} Ran object classification with probabilities: {probs}"
|
||||
)
|
||||
best_id = int(np.argmax(probs))
|
||||
score = round(probs[best_id], 2)
|
||||
self.__update_metrics(datetime.datetime.now().timestamp() - now)
|
||||
with self.interpreter_lock:
|
||||
self.active_tasks[object_id] = self.active_tasks.get(object_id, 0) + 1
|
||||
|
||||
save_attempts = (
|
||||
self.model_config.save_attempts
|
||||
if self.model_config.save_attempts is not None
|
||||
else 200
|
||||
)
|
||||
write_classification_attempt(
|
||||
self.train_dir,
|
||||
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
|
||||
object_id,
|
||||
now,
|
||||
self.labelmap[best_id],
|
||||
score,
|
||||
max_files=save_attempts,
|
||||
)
|
||||
|
||||
if score < self.model_config.threshold:
|
||||
logger.debug(
|
||||
f"{self.model_config.name}: Score {score} < threshold {self.model_config.threshold} for {object_id}, skipping"
|
||||
)
|
||||
return
|
||||
|
||||
sub_label = self.labelmap[best_id]
|
||||
|
||||
logger.debug(
|
||||
f"{self.model_config.name}: Object {object_id} (label={obj_data['label']}) passed threshold with sub_label={sub_label}, score={score}"
|
||||
)
|
||||
|
||||
consensus_label, consensus_score = self.get_weighted_score(
|
||||
object_id, sub_label, score, now
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"{self.model_config.name}: get_weighted_score returned consensus_label={consensus_label}, consensus_score={consensus_score} for {object_id}"
|
||||
)
|
||||
|
||||
if consensus_label is not None:
|
||||
camera = obj_data["camera"]
|
||||
logger.debug(
|
||||
f"{self.model_config.name}: Publishing sub_label={consensus_label} for {obj_data['label']} object {object_id} on {camera}"
|
||||
)
|
||||
|
||||
if (
|
||||
self.model_config.object_config.classification_type
|
||||
== ObjectClassificationType.sub_label
|
||||
):
|
||||
self.sub_label_publisher.publish(
|
||||
(object_id, consensus_label, consensus_score),
|
||||
EventMetadataTypeEnum.sub_label,
|
||||
)
|
||||
self.requestor.send_data(
|
||||
"tracked_object_update",
|
||||
json.dumps(
|
||||
{
|
||||
"type": TrackedObjectUpdateTypesEnum.classification,
|
||||
"id": object_id,
|
||||
"camera": camera,
|
||||
"timestamp": now,
|
||||
"model": self.model_config.name,
|
||||
"sub_label": consensus_label,
|
||||
"score": consensus_score,
|
||||
}
|
||||
),
|
||||
)
|
||||
elif (
|
||||
self.model_config.object_config.classification_type
|
||||
== ObjectClassificationType.attribute
|
||||
):
|
||||
self.sub_label_publisher.publish(
|
||||
(
|
||||
object_id,
|
||||
self.model_config.name,
|
||||
consensus_label,
|
||||
consensus_score,
|
||||
),
|
||||
EventMetadataTypeEnum.attribute.value,
|
||||
)
|
||||
self.requestor.send_data(
|
||||
"tracked_object_update",
|
||||
json.dumps(
|
||||
{
|
||||
"type": TrackedObjectUpdateTypesEnum.classification,
|
||||
"id": object_id,
|
||||
"camera": camera,
|
||||
"timestamp": now,
|
||||
"model": self.model_config.name,
|
||||
"attribute": consensus_label,
|
||||
"score": consensus_score,
|
||||
}
|
||||
),
|
||||
)
|
||||
threading.Thread(
|
||||
target=self._run_inference_thread,
|
||||
name=f"_classification_{object_id}",
|
||||
daemon=True,
|
||||
args=(resized_crop, object_id, obj_data["camera"], now),
|
||||
).start()
|
||||
|
||||
def handle_request(self, topic: str, request_data: dict) -> dict | None:
|
||||
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
||||
@ -694,8 +743,13 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
return None
|
||||
|
||||
def expire_object(self, object_id: str, camera: str) -> None:
|
||||
if object_id in self.classification_history:
|
||||
self.classification_history.pop(object_id)
|
||||
"""Clean up tracking data when an object expires."""
|
||||
with self.history_lock:
|
||||
if object_id in self.classification_history:
|
||||
self.classification_history.pop(object_id)
|
||||
with self.interpreter_lock:
|
||||
if object_id in self.active_tasks:
|
||||
del self.active_tasks[object_id]
|
||||
|
||||
|
||||
def write_classification_attempt(
|
||||
@ -714,7 +768,9 @@ def write_classification_attempt(
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
cv2.imwrite(file, frame)
|
||||
|
||||
# delete oldest face image if maximum is reached
|
||||
# Delete oldest files if maximum is reached
|
||||
# In multi-threaded environment, we need to delete multiple files
|
||||
# to ensure we stay under the limit, as multiple threads may be running concurrently
|
||||
try:
|
||||
files = sorted(
|
||||
filter(lambda f: f.endswith(".webp"), os.listdir(folder)),
|
||||
@ -723,6 +779,12 @@ def write_classification_attempt(
|
||||
)
|
||||
|
||||
if len(files) > max_files:
|
||||
os.unlink(os.path.join(folder, files[-1]))
|
||||
# Delete all files beyond the limit
|
||||
files_to_delete = files[max_files:]
|
||||
for f in files_to_delete:
|
||||
try:
|
||||
os.unlink(os.path.join(folder, f))
|
||||
except FileNotFoundError:
|
||||
pass # File might have been deleted by another thread
|
||||
except (FileNotFoundError, OSError):
|
||||
pass
|
||||
|
||||
Loading…
Reference in New Issue
Block a user