Mirror of https://github.com/blakeblackshear/frigate.git (synced 2025-12-06 13:34:13 +03:00)
Classification improvements (#20665)
Some checks are pending
CI / AMD64 Build (push) Waiting to run
CI / ARM Build (push) Waiting to run
CI / Jetson Jetpack 6 (push) Waiting to run
CI / AMD64 Extra Build (push) Blocked by required conditions
CI / ARM Extra Build (push) Blocked by required conditions
CI / Synaptics Build (push) Blocked by required conditions
CI / Assemble and push default build (push) Blocked by required conditions
* Don't classify objects that are ended
* Use weighted scoring for object classification
* Implement state verification
This commit is contained in:
parent 63042b9c08
commit 1fb21a4dac
@@ -53,6 +53,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
         self.tensor_output_details: dict[str, Any] | None = None
         self.labelmap: dict[int, str] = {}
         self.classifications_per_second = EventsPerSecond()
+        self.state_history: dict[str, dict[str, Any]] = {}

         if (
             self.metrics
@@ -94,6 +95,42 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
         if self.inference_speed:
             self.inference_speed.update(duration)

+    def verify_state_change(self, camera: str, detected_state: str) -> str | None:
+        """
+        Verify state change requires 3 consecutive identical states before publishing.
+        Returns state to publish or None if verification not complete.
+        """
+        if camera not in self.state_history:
+            self.state_history[camera] = {
+                "current_state": None,
+                "pending_state": None,
+                "consecutive_count": 0,
+            }
+
+        verification = self.state_history[camera]
+
+        if detected_state == verification["current_state"]:
+            verification["pending_state"] = None
+            verification["consecutive_count"] = 0
+            return None
+
+        if detected_state == verification["pending_state"]:
+            verification["consecutive_count"] += 1
+
+            if verification["consecutive_count"] >= 3:
+                verification["current_state"] = detected_state
+                verification["pending_state"] = None
+                verification["consecutive_count"] = 0
+                return detected_state
+        else:
+            verification["pending_state"] = detected_state
+            verification["consecutive_count"] = 1
+            logger.debug(
+                f"New state '{detected_state}' detected for {camera}, need {3 - verification['consecutive_count']} more consecutive detections"
+            )
+
+        return None
+
     def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
         if self.metrics and self.model_config.name in self.metrics.classification_cps:
             self.metrics.classification_cps[
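For reference, a minimal standalone sketch of the gate that verify_state_change adds above. StateGate and the example state labels are hypothetical names for illustration, not Frigate APIs; the logic mirrors the diff: a changed state is only reported after three consecutive identical detections, and re-seeing the already-confirmed state clears any pending change.

# Illustrative sketch only; mirrors the 3-consecutive-detection gate from the diff above.
from typing import Optional


class StateGate:
    """Confirm a new state only after 3 consecutive identical detections."""

    def __init__(self) -> None:
        self.current_state: Optional[str] = None
        self.pending_state: Optional[str] = None
        self.consecutive_count = 0

    def observe(self, detected_state: str) -> Optional[str]:
        # Same state as already confirmed: clear any pending change.
        if detected_state == self.current_state:
            self.pending_state = None
            self.consecutive_count = 0
            return None

        if detected_state == self.pending_state:
            self.consecutive_count += 1
            if self.consecutive_count >= 3:
                self.current_state = detected_state
                self.pending_state = None
                self.consecutive_count = 0
                return detected_state
        else:
            # First sighting of a different state starts a new pending run.
            self.pending_state = detected_state
            self.consecutive_count = 1
        return None


gate = StateGate()
for observation in ["closed", "open", "open", "open", "open"]:
    print(observation, "->", gate.observe(observation))
# Only the third consecutive "open" returns "open"; every other call returns None.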
@@ -131,6 +168,19 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
             self.last_run = now
             should_run = True

+        # Shortcut: always run if we have a pending state verification to complete
+        if (
+            not should_run
+            and camera in self.state_history
+            and self.state_history[camera]["pending_state"] is not None
+            and now > self.last_run + 0.5
+        ):
+            self.last_run = now
+            should_run = True
+            logger.debug(
+                f"Running verification check for pending state: {self.state_history[camera]['pending_state']} ({self.state_history[camera]['consecutive_count']}/3)"
+            )
+
         if not should_run:
             return

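The shortcut above forces another classification pass while a state verification is still pending, but no more than roughly every 0.5 seconds. A rough sketch of that throttle in isolation, using a hypothetical should_force_run helper that is not part of Frigate:

import time

# Illustrative throttle sketch: while a verification is pending, force a re-run
# at most once per RECHECK_INTERVAL seconds (0.5 s, matching the diff above).
RECHECK_INTERVAL = 0.5


def should_force_run(pending_state: str | None, now: float, last_run: float) -> bool:
    """True when a verification is pending and the re-check interval has elapsed."""
    return pending_state is not None and now > last_run + RECHECK_INTERVAL


last_run = time.monotonic()
time.sleep(0.6)
now = time.monotonic()
print(should_force_run("open", now, last_run))  # True: pending state and >0.5 s elapsed
print(should_force_run(None, now, last_run))    # False: nothing pending to verify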
@@ -188,10 +238,19 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
             score,
         )

-        if score >= self.model_config.threshold:
+        if score < self.model_config.threshold:
+            logger.debug(
+                f"Score {score} below threshold {self.model_config.threshold}, skipping verification"
+            )
+            return
+
+        detected_state = self.labelmap[best_id]
+        verified_state = self.verify_state_change(camera, detected_state)
+
+        if verified_state is not None:
             self.requestor.send_data(
                 f"{camera}/classification/{self.model_config.name}",
-                self.labelmap[best_id],
+                verified_state,
             )

     def handle_request(self, topic, request_data):
@@ -230,7 +289,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
         self.sub_label_publisher = sub_label_publisher
         self.tensor_input_details: dict[str, Any] | None = None
         self.tensor_output_details: dict[str, Any] | None = None
-        self.detected_objects: dict[str, float] = {}
+        self.classification_history: dict[str, list[tuple[str, float, float]]] = {}
         self.labelmap: dict[int, str] = {}
         self.classifications_per_second = EventsPerSecond()

@@ -272,6 +331,56 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
         if self.inference_speed:
             self.inference_speed.update(duration)

+    def get_weighted_score(
+        self,
+        object_id: str,
+        current_label: str,
+        current_score: float,
+        current_time: float,
+    ) -> tuple[str | None, float]:
+        """
+        Determine weighted score based on history to prevent false positives/negatives.
+        Requires 60% of attempts to agree on a label before publishing.
+        Returns (weighted_label, weighted_score) or (None, 0.0) if no weighted score.
+        """
+        if object_id not in self.classification_history:
+            self.classification_history[object_id] = []
+
+        self.classification_history[object_id].append(
+            (current_label, current_score, current_time)
+        )
+
+        history = self.classification_history[object_id]
+
+        if len(history) < 3:
+            return None, 0.0
+
+        label_counts = {}
+        label_scores = {}
+        total_attempts = len(history)
+
+        for label, score, timestamp in history:
+            if label not in label_counts:
+                label_counts[label] = 0
+                label_scores[label] = []
+
+            label_counts[label] += 1
+            label_scores[label].append(score)
+
+        best_label = max(label_counts, key=label_counts.get)
+        best_count = label_counts[best_label]
+
+        consensus_threshold = total_attempts * 0.6
+        if best_count < consensus_threshold:
+            return None, 0.0
+
+        avg_score = sum(label_scores[best_label]) / len(label_scores[best_label])
+
+        if best_label == "none":
+            return None, 0.0
+
+        return best_label, avg_score
+
     def process_frame(self, obj_data, frame):
         if self.metrics and self.model_config.name in self.metrics.classification_cps:
             self.metrics.classification_cps[
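A worked illustration of the 60% consensus rule in get_weighted_score above, written as a standalone function so it can be run directly. weighted_consensus and the example labels (delivery_driver, resident) are hypothetical names used only for illustration; the rules it applies (at least 3 attempts, 60% agreement, a "none" label that never publishes, averaged score for the winning label) follow the diff.

# Illustrative sketch of the 60%-consensus weighting, independent of the Frigate classes.
from collections import defaultdict


def weighted_consensus(history: list[tuple[str, float]]) -> tuple[str | None, float]:
    """Return (label, average score) once >=60% of attempts agree, else (None, 0.0)."""
    if len(history) < 3:
        return None, 0.0  # not enough attempts yet

    scores_by_label: dict[str, list[float]] = defaultdict(list)
    for label, score in history:
        scores_by_label[label].append(score)

    best_label = max(scores_by_label, key=lambda lbl: len(scores_by_label[lbl]))
    if len(scores_by_label[best_label]) < 0.6 * len(history) or best_label == "none":
        return None, 0.0

    return best_label, sum(scores_by_label[best_label]) / len(scores_by_label[best_label])


# 3 of 5 attempts agree on "delivery_driver" (60%), so it wins with the mean of its scores.
attempts = [
    ("delivery_driver", 0.82),
    ("none", 0.40),
    ("delivery_driver", 0.78),
    ("resident", 0.55),
    ("delivery_driver", 0.90),
]
print(weighted_consensus(attempts))  # ('delivery_driver', 0.8333...)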
@@ -284,6 +393,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
         if obj_data["label"] not in self.model_config.object_config.objects:
             return

+        if obj_data.get("end_time") is not None:
+            return
+
         now = datetime.datetime.now().timestamp()
         x, y, x2, y2 = calculate_region(
             frame.shape,
@@ -331,7 +443,6 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
         probs = res / res.sum(axis=0)
         best_id = np.argmax(probs)
         score = round(probs[best_id], 2)
-        previous_score = self.detected_objects.get(obj_data["id"], 0.0)
         self.__update_metrics(datetime.datetime.now().timestamp() - now)

         write_classification_attempt(
@@ -347,30 +458,34 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
             logger.debug(f"Score {score} is less than threshold.")
             return

-        if score <= previous_score:
-            logger.debug(f"Score {score} is worse than previous score {previous_score}")
-            return
-
         sub_label = self.labelmap[best_id]
-        self.detected_objects[obj_data["id"]] = score

-        if (
-            self.model_config.object_config.classification_type
-            == ObjectClassificationType.sub_label
-        ):
-            if sub_label != "none":
+        consensus_label, consensus_score = self.get_weighted_score(
+            obj_data["id"], sub_label, score, now
+        )
+
+        if consensus_label is not None:
+            if (
+                self.model_config.object_config.classification_type
+                == ObjectClassificationType.sub_label
+            ):
                 self.sub_label_publisher.publish(
-                    (obj_data["id"], sub_label, score),
+                    (obj_data["id"], consensus_label, consensus_score),
                     EventMetadataTypeEnum.sub_label,
                 )
             elif (
                 self.model_config.object_config.classification_type
                 == ObjectClassificationType.attribute
             ):
                 self.sub_label_publisher.publish(
-                    (obj_data["id"], self.model_config.name, sub_label, score),
-                    EventMetadataTypeEnum.attribute.value,
-                )
+                    (
+                        obj_data["id"],
+                        self.model_config.name,
+                        consensus_label,
+                        consensus_score,
+                    ),
+                    EventMetadataTypeEnum.attribute.value,
+                )

     def handle_request(self, topic, request_data):
         if topic == EmbeddingsRequestEnum.reload_classification_model.value:
@@ -388,8 +503,8 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
         return None

     def expire_object(self, object_id, camera):
-        if object_id in self.detected_objects:
-            self.detected_objects.pop(object_id)
+        if object_id in self.classification_history:
+            self.classification_history.pop(object_id)


     @staticmethod