diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index ac6387785..46929041f 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -53,6 +53,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): self.tensor_output_details: dict[str, Any] | None = None self.labelmap: dict[int, str] = {} self.classifications_per_second = EventsPerSecond() + self.state_history: dict[str, dict[str, Any]] = {} if ( self.metrics @@ -94,6 +95,42 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): if self.inference_speed: self.inference_speed.update(duration) + def verify_state_change(self, camera: str, detected_state: str) -> str | None: + """ + Verify state change requires 3 consecutive identical states before publishing. + Returns state to publish or None if verification not complete. + """ + if camera not in self.state_history: + self.state_history[camera] = { + "current_state": None, + "pending_state": None, + "consecutive_count": 0, + } + + verification = self.state_history[camera] + + if detected_state == verification["current_state"]: + verification["pending_state"] = None + verification["consecutive_count"] = 0 + return None + + if detected_state == verification["pending_state"]: + verification["consecutive_count"] += 1 + + if verification["consecutive_count"] >= 3: + verification["current_state"] = detected_state + verification["pending_state"] = None + verification["consecutive_count"] = 0 + return detected_state + else: + verification["pending_state"] = detected_state + verification["consecutive_count"] = 1 + logger.debug( + f"New state '{detected_state}' detected for {camera}, need {3 - verification['consecutive_count']} more consecutive detections" + ) + + return None + def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray): if self.metrics and 
self.model_config.name in self.metrics.classification_cps: self.metrics.classification_cps[ @@ -131,6 +168,19 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): self.last_run = now should_run = True + # Shortcut: always run if we have a pending state verification to complete + if ( + not should_run + and camera in self.state_history + and self.state_history[camera]["pending_state"] is not None + and now > self.last_run + 0.5 + ): + self.last_run = now + should_run = True + logger.debug( + f"Running verification check for pending state: {self.state_history[camera]['pending_state']} ({self.state_history[camera]['consecutive_count']}/3)" + ) + if not should_run: return @@ -188,10 +238,19 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): score, ) - if score >= self.model_config.threshold: + if score < self.model_config.threshold: + logger.debug( + f"Score {score} below threshold {self.model_config.threshold}, skipping verification" + ) + return + + detected_state = self.labelmap[best_id] + verified_state = self.verify_state_change(camera, detected_state) + + if verified_state is not None: self.requestor.send_data( f"{camera}/classification/{self.model_config.name}", - self.labelmap[best_id], + verified_state, ) def handle_request(self, topic, request_data): @@ -230,7 +289,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): self.sub_label_publisher = sub_label_publisher self.tensor_input_details: dict[str, Any] | None = None self.tensor_output_details: dict[str, Any] | None = None - self.detected_objects: dict[str, float] = {} + self.classification_history: dict[str, list[tuple[str, float, float]]] = {} self.labelmap: dict[int, str] = {} self.classifications_per_second = EventsPerSecond() @@ -272,6 +331,56 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): if self.inference_speed: self.inference_speed.update(duration) + def get_weighted_score( + self, + object_id: str, + current_label: str, + 
current_score: float, + current_time: float, + ) -> tuple[str | None, float]: + """ + Determine weighted score based on history to prevent false positives/negatives. + Requires 60% of attempts to agree on a label before publishing. + Returns (weighted_label, weighted_score) or (None, 0.0) if no weighted score. + """ + if object_id not in self.classification_history: + self.classification_history[object_id] = [] + + self.classification_history[object_id].append( + (current_label, current_score, current_time) + ) + + history = self.classification_history[object_id] + + if len(history) < 3: + return None, 0.0 + + label_counts = {} + label_scores = {} + total_attempts = len(history) + + for label, score, timestamp in history: + if label not in label_counts: + label_counts[label] = 0 + label_scores[label] = [] + + label_counts[label] += 1 + label_scores[label].append(score) + + best_label = max(label_counts, key=label_counts.get) + best_count = label_counts[best_label] + + consensus_threshold = total_attempts * 0.6 + if best_count < consensus_threshold: + return None, 0.0 + + avg_score = sum(label_scores[best_label]) / len(label_scores[best_label]) + + if best_label == "none": + return None, 0.0 + + return best_label, avg_score + def process_frame(self, obj_data, frame): if self.metrics and self.model_config.name in self.metrics.classification_cps: self.metrics.classification_cps[ @@ -284,6 +393,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): if obj_data["label"] not in self.model_config.object_config.objects: return + if obj_data.get("end_time") is not None: + return + now = datetime.datetime.now().timestamp() x, y, x2, y2 = calculate_region( frame.shape, @@ -331,7 +443,6 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): probs = res / res.sum(axis=0) best_id = np.argmax(probs) score = round(probs[best_id], 2) - previous_score = self.detected_objects.get(obj_data["id"], 0.0) 
def expire_object(self, object_id, camera):
    """Drop the stored classification history for an object.

    Removes ``object_id`` from ``self.classification_history`` when present;
    unknown ids are ignored. ``camera`` is accepted but not used here.
    """
    # pop() with a default is a no-op for ids that were never recorded.
    self.classification_history.pop(object_id, None)