From 9970e429a1fcc3eed63fe2aac4d469f8e75f69a1 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Sat, 25 Oct 2025 07:52:20 -0600 Subject: [PATCH] Use weighted scoring for object classification --- .../real_time/custom_classification.py | 99 ++++++++++++++----- 1 file changed, 76 insertions(+), 23 deletions(-) diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 76422053d..6d7a449e8 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -230,7 +230,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): self.sub_label_publisher = sub_label_publisher self.tensor_input_details: dict[str, Any] | None = None self.tensor_output_details: dict[str, Any] | None = None - self.detected_objects: dict[str, float] = {} + self.classification_history: dict[str, list[tuple[str, float, float]]] = {} self.labelmap: dict[int, str] = {} self.classifications_per_second = EventsPerSecond() @@ -272,6 +272,56 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): if self.inference_speed: self.inference_speed.update(duration) + def get_weighted_score( + self, + object_id: str, + current_label: str, + current_score: float, + current_time: float, + ) -> tuple[str | None, float]: + """ + Determine weighted score based on history to prevent false positives/negatives. + Requires 60% of attempts to agree on a label before publishing. + Returns (weighted_label, weighted_score) or (None, 0.0) if no weighted score. + """ + if object_id not in self.classification_history: + self.classification_history[object_id] = [] + + self.classification_history[object_id].append( + (current_label, current_score, current_time) + ) + + history = self.classification_history[object_id] + + if len(history) < 3: + return None, 0.0 + + label_counts = {} + label_scores = {} + total_attempts = len(history) + + for label, score, timestamp in history: + if label not in label_counts: + label_counts[label] = 0 + label_scores[label] = [] + + label_counts[label] += 1 + label_scores[label].append(score) + + best_label = max(label_counts, key=label_counts.get) + best_count = label_counts[best_label] + + consensus_threshold = total_attempts * 0.6 + if best_count < consensus_threshold: + return None, 0.0 + + avg_score = sum(label_scores[best_label]) / len(label_scores[best_label]) + + if best_label == "none": + return None, 0.0 + + return best_label, avg_score + def process_frame(self, obj_data, frame): if self.metrics and self.model_config.name in self.metrics.classification_cps: self.metrics.classification_cps[ @@ -334,7 +384,6 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): probs = res / res.sum(axis=0) best_id = np.argmax(probs) score = round(probs[best_id], 2) - previous_score = self.detected_objects.get(obj_data["id"], 0.0) self.__update_metrics(datetime.datetime.now().timestamp() - now) write_classification_attempt( @@ -350,30 +399,34 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): logger.debug(f"Score {score} is less than threshold.") return - if score <= previous_score: - logger.debug(f"Score {score} is worse than previous score {previous_score}") - return - sub_label = self.labelmap[best_id] - self.detected_objects[obj_data["id"]] = score - if ( - self.model_config.object_config.classification_type - == ObjectClassificationType.sub_label - ): - if sub_label != "none": + consensus_label, consensus_score = self.get_weighted_score( + obj_data["id"], sub_label, score, now + ) + + if consensus_label is not None: + if ( + self.model_config.object_config.classification_type + == ObjectClassificationType.sub_label + ): self.sub_label_publisher.publish( - (obj_data["id"], sub_label, score), + (obj_data["id"], consensus_label, consensus_score), EventMetadataTypeEnum.sub_label, ) - elif ( - self.model_config.object_config.classification_type - == ObjectClassificationType.attribute - ): - self.sub_label_publisher.publish( - (obj_data["id"], self.model_config.name, sub_label, score), - EventMetadataTypeEnum.attribute.value, - ) + elif ( + self.model_config.object_config.classification_type + == ObjectClassificationType.attribute + ): + self.sub_label_publisher.publish( + ( + obj_data["id"], + self.model_config.name, + consensus_label, + consensus_score, + ), + EventMetadataTypeEnum.attribute.value, + ) def handle_request(self, topic, request_data): if topic == EmbeddingsRequestEnum.reload_classification_model.value: @@ -391,8 +444,8 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): return None def expire_object(self, object_id, camera): - if object_id in self.detected_objects: - self.detected_objects.pop(object_id) + if object_id in self.classification_history: + self.classification_history.pop(object_id) @staticmethod