mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-05-09 23:15:28 +03:00
fix: add multi-threading to classification
This commit is contained in:
parent
335229d0d4
commit
1577657a6d
@ -4,6 +4,7 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import threading
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
@ -38,6 +39,7 @@ except ModuleNotFoundError:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MAX_OBJECT_CLASSIFICATIONS = 16
|
MAX_OBJECT_CLASSIFICATIONS = 16
|
||||||
|
MAX_CONCURRENT_CLASSIFICATIONS_PER_OBJECT = 3
|
||||||
|
|
||||||
|
|
||||||
class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||||
@ -379,6 +381,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
self.classification_history: dict[str, list[tuple[str, float, float]]] = {}
|
self.classification_history: dict[str, list[tuple[str, float, float]]] = {}
|
||||||
self.labelmap: dict[int, str] = {}
|
self.labelmap: dict[int, str] = {}
|
||||||
self.classifications_per_second = EventsPerSecond()
|
self.classifications_per_second = EventsPerSecond()
|
||||||
|
self.active_tasks: dict[str, int] = {}
|
||||||
|
self.interpreter_lock = threading.Lock()
|
||||||
|
self.history_lock = threading.Lock()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.metrics
|
self.metrics
|
||||||
@ -419,6 +424,137 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
if self.inference_speed:
|
if self.inference_speed:
|
||||||
self.inference_speed.update(duration)
|
self.inference_speed.update(duration)
|
||||||
|
|
||||||
|
def _run_inference_thread(
|
||||||
|
self, crop: np.ndarray, object_id: str, camera: str, timestamp: float
|
||||||
|
) -> None:
|
||||||
|
"""Run inference in a separate thread to avoid blocking."""
|
||||||
|
try:
|
||||||
|
with self.interpreter_lock:
|
||||||
|
if self.interpreter is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
input = np.expand_dims(crop, axis=0)
|
||||||
|
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
|
||||||
|
self.interpreter.invoke()
|
||||||
|
res = self.interpreter.get_tensor(self.tensor_output_details[0]["index"])[0]
|
||||||
|
probs = res / res.sum(axis=0)
|
||||||
|
|
||||||
|
best_id = int(np.argmax(probs))
|
||||||
|
score = round(probs[best_id], 2)
|
||||||
|
self.__update_metrics(datetime.datetime.now().timestamp() - timestamp)
|
||||||
|
|
||||||
|
label = self.labelmap[best_id]
|
||||||
|
|
||||||
|
save_attempts = (
|
||||||
|
self.model_config.save_attempts
|
||||||
|
if self.model_config.save_attempts is not None
|
||||||
|
else 200
|
||||||
|
)
|
||||||
|
|
||||||
|
threading.Thread(
|
||||||
|
target=write_classification_attempt,
|
||||||
|
name=f"_save_classification_{object_id}",
|
||||||
|
daemon=True,
|
||||||
|
args=(
|
||||||
|
self.train_dir,
|
||||||
|
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
|
||||||
|
object_id,
|
||||||
|
timestamp,
|
||||||
|
label,
|
||||||
|
score,
|
||||||
|
save_attempts,
|
||||||
|
),
|
||||||
|
).start()
|
||||||
|
|
||||||
|
if score >= self.model_config.threshold:
|
||||||
|
# Add result to history before publishing (will be included in consensus calculation)
|
||||||
|
with self.history_lock:
|
||||||
|
if object_id not in self.classification_history:
|
||||||
|
self.classification_history[object_id] = []
|
||||||
|
|
||||||
|
self.classification_history[object_id].append((label, score, timestamp))
|
||||||
|
logger.debug(
|
||||||
|
f"Added classification result for {object_id}: {label} ({score}) at {timestamp}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._publish_result(object_id, camera, label, score, timestamp)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in inference thread for {object_id}: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
with self.interpreter_lock:
|
||||||
|
self.active_tasks[object_id] = self.active_tasks.get(object_id, 1) - 1
|
||||||
|
if self.active_tasks[object_id] <= 0:
|
||||||
|
del self.active_tasks[object_id]
|
||||||
|
|
||||||
|
def _publish_result(
|
||||||
|
self, object_id: str, camera: str, label: str, score: float, timestamp: float
|
||||||
|
) -> None:
|
||||||
|
"""Publish classification result after weighted scoring."""
|
||||||
|
consensus_label, consensus_score = self.get_weighted_score(
|
||||||
|
object_id, label, score, timestamp
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_config.name}: get_weighted_score returned consensus_label={consensus_label}, consensus_score={consensus_score} for {object_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if consensus_label is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"{self.model_config.name}: Publishing sub_label={consensus_label} for object {object_id} on {camera}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.model_config.object_config.classification_type
|
||||||
|
== ObjectClassificationType.sub_label
|
||||||
|
):
|
||||||
|
self.sub_label_publisher.publish(
|
||||||
|
(object_id, consensus_label, consensus_score),
|
||||||
|
EventMetadataTypeEnum.sub_label,
|
||||||
|
)
|
||||||
|
self.requestor.send_data(
|
||||||
|
"tracked_object_update",
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": TrackedObjectUpdateTypesEnum.classification,
|
||||||
|
"id": object_id,
|
||||||
|
"camera": camera,
|
||||||
|
"timestamp": timestamp,
|
||||||
|
"model": self.model_config.name,
|
||||||
|
"sub_label": consensus_label,
|
||||||
|
"score": consensus_score,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
self.model_config.object_config.classification_type
|
||||||
|
== ObjectClassificationType.attribute
|
||||||
|
):
|
||||||
|
self.sub_label_publisher.publish(
|
||||||
|
(
|
||||||
|
object_id,
|
||||||
|
self.model_config.name,
|
||||||
|
consensus_label,
|
||||||
|
consensus_score,
|
||||||
|
),
|
||||||
|
EventMetadataTypeEnum.attribute.value,
|
||||||
|
)
|
||||||
|
self.requestor.send_data(
|
||||||
|
"tracked_object_update",
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": TrackedObjectUpdateTypesEnum.classification,
|
||||||
|
"id": object_id,
|
||||||
|
"camera": camera,
|
||||||
|
"timestamp": timestamp,
|
||||||
|
"model": self.model_config.name,
|
||||||
|
"attribute": consensus_label,
|
||||||
|
"score": consensus_score,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def get_weighted_score(
|
def get_weighted_score(
|
||||||
self,
|
self,
|
||||||
object_id: str,
|
object_id: str,
|
||||||
@ -431,6 +567,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
Requires 60% of attempts to agree on a label before publishing.
|
Requires 60% of attempts to agree on a label before publishing.
|
||||||
Returns (weighted_label, weighted_score) or (None, 0.0) if no weighted score.
|
Returns (weighted_label, weighted_score) or (None, 0.0) if no weighted score.
|
||||||
"""
|
"""
|
||||||
|
with self.history_lock:
|
||||||
if object_id not in self.classification_history:
|
if object_id not in self.classification_history:
|
||||||
self.classification_history[object_id] = []
|
self.classification_history[object_id] = []
|
||||||
logger.debug(f"Created new classification history for {object_id}")
|
logger.debug(f"Created new classification history for {object_id}")
|
||||||
@ -521,6 +658,12 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
with self.interpreter_lock:
|
||||||
|
current_tasks = self.active_tasks.get(object_id, 0)
|
||||||
|
if current_tasks >= MAX_CONCURRENT_CLASSIFICATIONS_PER_OBJECT:
|
||||||
|
logger.debug(f"Object {object_id} has {current_tasks}/{MAX_CONCURRENT_CLASSIFICATIONS_PER_OBJECT} active tasks")
|
||||||
|
return
|
||||||
|
|
||||||
now = datetime.datetime.now().timestamp()
|
now = datetime.datetime.now().timestamp()
|
||||||
x, y, x2, y2 = calculate_region(
|
x, y, x2, y2 = calculate_region(
|
||||||
frame.shape,
|
frame.shape,
|
||||||
@ -566,116 +709,22 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
|
|
||||||
# Still track history even when model doesn't exist to respect MAX_OBJECT_CLASSIFICATIONS
|
# Still track history even when model doesn't exist to respect MAX_OBJECT_CLASSIFICATIONS
|
||||||
# Add an entry with "unknown" label so the history limit is enforced
|
# Add an entry with "unknown" label so the history limit is enforced
|
||||||
|
with self.history_lock:
|
||||||
if object_id not in self.classification_history:
|
if object_id not in self.classification_history:
|
||||||
self.classification_history[object_id] = []
|
self.classification_history[object_id] = []
|
||||||
|
|
||||||
self.classification_history[object_id].append(("unknown", 0.0, now))
|
self.classification_history[object_id].append(("unknown", 0.0, now))
|
||||||
return
|
return
|
||||||
|
|
||||||
input = np.expand_dims(resized_crop, axis=0)
|
with self.interpreter_lock:
|
||||||
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
|
self.active_tasks[object_id] = self.active_tasks.get(object_id, 0) + 1
|
||||||
self.interpreter.invoke()
|
|
||||||
res: np.ndarray = self.interpreter.get_tensor(
|
|
||||||
self.tensor_output_details[0]["index"]
|
|
||||||
)[0]
|
|
||||||
probs = res / res.sum(axis=0)
|
|
||||||
logger.debug(
|
|
||||||
f"{self.model_config.name} Ran object classification with probabilities: {probs}"
|
|
||||||
)
|
|
||||||
best_id = int(np.argmax(probs))
|
|
||||||
score = round(probs[best_id], 2)
|
|
||||||
self.__update_metrics(datetime.datetime.now().timestamp() - now)
|
|
||||||
|
|
||||||
save_attempts = (
|
threading.Thread(
|
||||||
self.model_config.save_attempts
|
target=self._run_inference_thread,
|
||||||
if self.model_config.save_attempts is not None
|
name=f"_classification_{object_id}",
|
||||||
else 200
|
daemon=True,
|
||||||
)
|
args=(resized_crop, object_id, obj_data["camera"], now),
|
||||||
write_classification_attempt(
|
).start()
|
||||||
self.train_dir,
|
|
||||||
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
|
|
||||||
object_id,
|
|
||||||
now,
|
|
||||||
self.labelmap[best_id],
|
|
||||||
score,
|
|
||||||
max_files=save_attempts,
|
|
||||||
)
|
|
||||||
|
|
||||||
if score < self.model_config.threshold:
|
|
||||||
logger.debug(
|
|
||||||
f"{self.model_config.name}: Score {score} < threshold {self.model_config.threshold} for {object_id}, skipping"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
sub_label = self.labelmap[best_id]
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"{self.model_config.name}: Object {object_id} (label={obj_data['label']}) passed threshold with sub_label={sub_label}, score={score}"
|
|
||||||
)
|
|
||||||
|
|
||||||
consensus_label, consensus_score = self.get_weighted_score(
|
|
||||||
object_id, sub_label, score, now
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"{self.model_config.name}: get_weighted_score returned consensus_label={consensus_label}, consensus_score={consensus_score} for {object_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if consensus_label is not None:
|
|
||||||
camera = obj_data["camera"]
|
|
||||||
logger.debug(
|
|
||||||
f"{self.model_config.name}: Publishing sub_label={consensus_label} for {obj_data['label']} object {object_id} on {camera}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.model_config.object_config.classification_type
|
|
||||||
== ObjectClassificationType.sub_label
|
|
||||||
):
|
|
||||||
self.sub_label_publisher.publish(
|
|
||||||
(object_id, consensus_label, consensus_score),
|
|
||||||
EventMetadataTypeEnum.sub_label,
|
|
||||||
)
|
|
||||||
self.requestor.send_data(
|
|
||||||
"tracked_object_update",
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"type": TrackedObjectUpdateTypesEnum.classification,
|
|
||||||
"id": object_id,
|
|
||||||
"camera": camera,
|
|
||||||
"timestamp": now,
|
|
||||||
"model": self.model_config.name,
|
|
||||||
"sub_label": consensus_label,
|
|
||||||
"score": consensus_score,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
)
|
|
||||||
elif (
|
|
||||||
self.model_config.object_config.classification_type
|
|
||||||
== ObjectClassificationType.attribute
|
|
||||||
):
|
|
||||||
self.sub_label_publisher.publish(
|
|
||||||
(
|
|
||||||
object_id,
|
|
||||||
self.model_config.name,
|
|
||||||
consensus_label,
|
|
||||||
consensus_score,
|
|
||||||
),
|
|
||||||
EventMetadataTypeEnum.attribute.value,
|
|
||||||
)
|
|
||||||
self.requestor.send_data(
|
|
||||||
"tracked_object_update",
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"type": TrackedObjectUpdateTypesEnum.classification,
|
|
||||||
"id": object_id,
|
|
||||||
"camera": camera,
|
|
||||||
"timestamp": now,
|
|
||||||
"model": self.model_config.name,
|
|
||||||
"attribute": consensus_label,
|
|
||||||
"score": consensus_score,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def handle_request(self, topic: str, request_data: dict) -> dict | None:
|
def handle_request(self, topic: str, request_data: dict) -> dict | None:
|
||||||
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
||||||
@ -694,8 +743,13 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def expire_object(self, object_id: str, camera: str) -> None:
|
def expire_object(self, object_id: str, camera: str) -> None:
|
||||||
|
"""Clean up tracking data when an object expires."""
|
||||||
|
with self.history_lock:
|
||||||
if object_id in self.classification_history:
|
if object_id in self.classification_history:
|
||||||
self.classification_history.pop(object_id)
|
self.classification_history.pop(object_id)
|
||||||
|
with self.interpreter_lock:
|
||||||
|
if object_id in self.active_tasks:
|
||||||
|
del self.active_tasks[object_id]
|
||||||
|
|
||||||
|
|
||||||
def write_classification_attempt(
|
def write_classification_attempt(
|
||||||
@ -714,7 +768,9 @@ def write_classification_attempt(
|
|||||||
os.makedirs(folder, exist_ok=True)
|
os.makedirs(folder, exist_ok=True)
|
||||||
cv2.imwrite(file, frame)
|
cv2.imwrite(file, frame)
|
||||||
|
|
||||||
# delete oldest face image if maximum is reached
|
# Delete oldest files if maximum is reached
|
||||||
|
# In multi-threaded environment, we need to delete multiple files
|
||||||
|
# to ensure we stay under the limit, as multiple threads may be running concurrently
|
||||||
try:
|
try:
|
||||||
files = sorted(
|
files = sorted(
|
||||||
filter(lambda f: f.endswith(".webp"), os.listdir(folder)),
|
filter(lambda f: f.endswith(".webp"), os.listdir(folder)),
|
||||||
@ -723,6 +779,12 @@ def write_classification_attempt(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if len(files) > max_files:
|
if len(files) > max_files:
|
||||||
os.unlink(os.path.join(folder, files[-1]))
|
# Delete all files beyond the limit
|
||||||
|
files_to_delete = files[max_files:]
|
||||||
|
for f in files_to_delete:
|
||||||
|
try:
|
||||||
|
os.unlink(os.path.join(folder, f))
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass # File might have been deleted by another thread
|
||||||
except (FileNotFoundError, OSError):
|
except (FileNotFoundError, OSError):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user