"""Real time processor that works with classification tflite models.""" import datetime import json import logging import os from typing import Any import cv2 import numpy as np from frigate.comms.embeddings_updater import EmbeddingsRequestEnum from frigate.comms.event_metadata_updater import ( EventMetadataPublisher, EventMetadataTypeEnum, ) from frigate.comms.inter_process import InterProcessRequestor from frigate.config import FrigateConfig from frigate.config.classification import ( CustomClassificationConfig, ObjectClassificationType, ) from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR from frigate.log import redirect_output_to_logger from frigate.types import TrackedObjectUpdateTypesEnum from frigate.util.builtin import EventsPerSecond, InferenceSpeed, load_labels from frigate.util.object import box_overlaps, calculate_region from ..types import DataProcessorMetrics from .api import RealTimeProcessorApi try: from tflite_runtime.interpreter import Interpreter except ModuleNotFoundError: from tensorflow.lite.python.interpreter import Interpreter logger = logging.getLogger(__name__) MAX_OBJECT_CLASSIFICATIONS = 16 class CustomStateClassificationProcessor(RealTimeProcessorApi): def __init__( self, config: FrigateConfig, model_config: CustomClassificationConfig, requestor: InterProcessRequestor, metrics: DataProcessorMetrics, ): super().__init__(config, metrics) self.model_config = model_config self.requestor = requestor self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name) self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train") self.interpreter: Interpreter | None = None self.tensor_input_details: dict[str, Any] | None = None self.tensor_output_details: dict[str, Any] | None = None self.labelmap: dict[int, str] = {} self.classifications_per_second = EventsPerSecond() self.state_history: dict[str, dict[str, Any]] = {} if ( self.metrics and self.model_config.name in self.metrics.classification_speeds ): self.inference_speed = InferenceSpeed( self.metrics.classification_speeds[self.model_config.name] ) else: self.inference_speed = None self.last_run = datetime.datetime.now().timestamp() self.__build_detector() @redirect_output_to_logger(logger, logging.DEBUG) def __build_detector(self) -> None: model_path = os.path.join(self.model_dir, "model.tflite") labelmap_path = os.path.join(self.model_dir, "labelmap.txt") if not os.path.exists(model_path) or not os.path.exists(labelmap_path): self.interpreter = None self.tensor_input_details = None self.tensor_output_details = None self.labelmap = {} return self.interpreter = Interpreter( model_path=model_path, num_threads=2, ) self.interpreter.allocate_tensors() self.tensor_input_details = self.interpreter.get_input_details() self.tensor_output_details = self.interpreter.get_output_details() self.labelmap = load_labels(labelmap_path, prefill=0) self.classifications_per_second.start() def __update_metrics(self, duration: float) -> None: self.classifications_per_second.update() if self.inference_speed: self.inference_speed.update(duration) def _should_save_image( self, camera: str, detected_state: str, score: float = 1.0 ) -> bool: """ Determine if we should save the image for training. 

        Save when:
        - State is changing or being verified (regardless of score)
        - Score is less than 100% (even if state matches, useful for training)

        Don't save when:
        - State is stable (matches current_state) AND score is 100%
        """
        if camera not in self.state_history:
            # First detection for this camera, save it
            return True

        verification = self.state_history[camera]
        current_state = verification.get("current_state")
        pending_state = verification.get("pending_state")

        # Save if there's a pending state change being verified
        if pending_state is not None:
            return True

        # Save if the detected state differs from the current verified state
        # (state is changing)
        if current_state is not None and detected_state != current_state:
            return True

        # If score is less than 100%, save even if state matches
        # (useful for training to improve confidence)
        if score < 1.0:
            return True

        # Don't save if state is stable (detected_state == current_state)
        # AND score is 100%
        return False

    def verify_state_change(self, camera: str, detected_state: str) -> str | None:
        """
        Verify a state change, requiring 3 consecutive identical states before publishing.
        Returns the state to publish, or None if verification is not complete.
        """
        if camera not in self.state_history:
            self.state_history[camera] = {
                "current_state": None,
                "pending_state": None,
                "consecutive_count": 0,
            }

        verification = self.state_history[camera]

        if detected_state == verification["current_state"]:
            verification["pending_state"] = None
            verification["consecutive_count"] = 0
            return None

        if detected_state == verification["pending_state"]:
            verification["consecutive_count"] += 1

            if verification["consecutive_count"] >= 3:
                verification["current_state"] = detected_state
                verification["pending_state"] = None
                verification["consecutive_count"] = 0
                return detected_state
        else:
            verification["pending_state"] = detected_state
            verification["consecutive_count"] = 1
            logger.debug(
                f"New state '{detected_state}' detected for {camera}, need {3 - verification['consecutive_count']} more consecutive detections"
            )

        return None
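
    # A minimal trace of verify_state_change for a hypothetical "front_door"
    # camera, starting from an empty history (illustrative only):
    #
    #   verify_state_change("front_door", "open")  -> None   (pending "open", 1/3)
    #   verify_state_change("front_door", "open")  -> None   (pending "open", 2/3)
    #   verify_state_change("front_door", "open")  -> "open" (3rd consecutive hit, verified)
    #   verify_state_change("front_door", "open")  -> None   (already current, nothing to publish)
    #
    # Any differing detection in between replaces the pending state with the
    # new value and resets the count to 1.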

    def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
        if self.metrics and self.model_config.name in self.metrics.classification_cps:
            self.metrics.classification_cps[
                self.model_config.name
            ].value = self.classifications_per_second.eps()

        camera = frame_data.get("camera")

        if camera not in self.model_config.state_config.cameras:
            return

        camera_config = self.model_config.state_config.cameras[camera]
        crop = [
            camera_config.crop[0] * self.config.cameras[camera].detect.width,
            camera_config.crop[1] * self.config.cameras[camera].detect.height,
            camera_config.crop[2] * self.config.cameras[camera].detect.width,
            camera_config.crop[3] * self.config.cameras[camera].detect.height,
        ]

        should_run = False
        now = datetime.datetime.now().timestamp()

        if (
            self.model_config.state_config.interval
            and now > self.last_run + self.model_config.state_config.interval
        ):
            self.last_run = now
            should_run = True

        if (
            not should_run
            and self.model_config.state_config.motion
            and any(box_overlaps(crop, mb) for mb in frame_data.get("motion", []))
        ):
            # classification should run at most once per second
            if now > self.last_run + 1:
                self.last_run = now
                should_run = True

        # Shortcut: always run if we have a pending state verification to complete
        if (
            not should_run
            and camera in self.state_history
            and self.state_history[camera]["pending_state"] is not None
            and now > self.last_run + 0.5
        ):
            self.last_run = now
            should_run = True
            logger.debug(
                f"Running verification check for pending state: {self.state_history[camera]['pending_state']} ({self.state_history[camera]['consecutive_count']}/3)"
            )

        if not should_run:
            return

        x, y, x2, y2 = calculate_region(
            frame.shape,
            crop[0],
            crop[1],
            crop[2],
            crop[3],
            224,
            1.0,
        )
        rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
        frame = rgb[
            y:y2,
            x:x2,
        ]

        # compare only height and width; the crop is 3-channel, so comparing
        # the full shape against (224, 224) would never match
        if frame.shape[:2] != (224, 224):
            try:
                resized_frame = cv2.resize(frame, (224, 224))
            except Exception:
                logger.warning("Failed to resize image for state classification")
                return
        else:
            resized_frame = frame

        if self.interpreter is None:
            # When the interpreter is None, always save (score is 0.0, which is < 1.0)
            if self._should_save_image(camera, "unknown", 0.0):
                write_classification_attempt(
                    self.train_dir,
                    cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
                    "none-none",
                    now,
                    "unknown",
                    0.0,
                )
            return

        input = np.expand_dims(resized_frame, axis=0)
        self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
        self.interpreter.invoke()
        res: np.ndarray = self.interpreter.get_tensor(
            self.tensor_output_details[0]["index"]
        )[0]
        probs = res / res.sum(axis=0)
        logger.debug(
            f"{self.model_config.name} Ran state classification with probabilities: {probs}"
        )
        best_id = np.argmax(probs)
        score = round(probs[best_id], 2)
        self.__update_metrics(datetime.datetime.now().timestamp() - now)

        detected_state = self.labelmap[best_id]

        if self._should_save_image(camera, detected_state, score):
            write_classification_attempt(
                self.train_dir,
                cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
                "none-none",
                now,
                detected_state,
                score,
            )

        if score < self.model_config.threshold:
            logger.debug(
                f"Score {score} below threshold {self.model_config.threshold}, skipping verification"
            )
            return

        verified_state = self.verify_state_change(camera, detected_state)

        if verified_state is not None:
            self.requestor.send_data(
                f"{camera}/classification/{self.model_config.name}",
                verified_state,
            )

    def handle_request(self, topic, request_data):
        if topic == EmbeddingsRequestEnum.reload_classification_model.value:
            if request_data.get("model_name") == self.model_config.name:
                self.__build_detector()
                logger.info(
                    f"Successfully loaded updated model for {self.model_config.name}"
                )
                return {
                    "success": True,
                    "message": f"Loaded {self.model_config.name} model.",
                }

        return None

    def expire_object(self, object_id, camera):
        pass
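
# When a state change is verified, the new state is published through the
# inter-process requestor on a topic keyed by camera and model name. As a
# hypothetical example, assuming a camera named "front_door" and a model named
# "door_state", consumers would receive:
#
#   topic:   "front_door/classification/door_state"
#   payload: "open"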


class CustomObjectClassificationProcessor(RealTimeProcessorApi):
    def __init__(
        self,
        config: FrigateConfig,
        model_config: CustomClassificationConfig,
        sub_label_publisher: EventMetadataPublisher,
        requestor: InterProcessRequestor,
        metrics: DataProcessorMetrics,
    ):
        super().__init__(config, metrics)
        self.model_config = model_config
        self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name)
        self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train")
        self.interpreter: Interpreter | None = None
        self.sub_label_publisher = sub_label_publisher
        self.requestor = requestor
        self.tensor_input_details: dict[str, Any] | None = None
        self.tensor_output_details: dict[str, Any] | None = None
        self.classification_history: dict[str, list[tuple[str, float, float]]] = {}
        self.labelmap: dict[int, str] = {}
        self.classifications_per_second = EventsPerSecond()

        if (
            self.metrics
            and self.model_config.name in self.metrics.classification_speeds
        ):
            self.inference_speed = InferenceSpeed(
                self.metrics.classification_speeds[self.model_config.name]
            )
        else:
            self.inference_speed = None

        self.__build_detector()

    @redirect_output_to_logger(logger, logging.DEBUG)
    def __build_detector(self) -> None:
        model_path = os.path.join(self.model_dir, "model.tflite")
        labelmap_path = os.path.join(self.model_dir, "labelmap.txt")

        if not os.path.exists(model_path) or not os.path.exists(labelmap_path):
            self.interpreter = None
            self.tensor_input_details = None
            self.tensor_output_details = None
            self.labelmap = {}
            return

        self.interpreter = Interpreter(
            model_path=model_path,
            num_threads=2,
        )
        self.interpreter.allocate_tensors()
        self.tensor_input_details = self.interpreter.get_input_details()
        self.tensor_output_details = self.interpreter.get_output_details()
        self.labelmap = load_labels(labelmap_path, prefill=0)

    def __update_metrics(self, duration: float) -> None:
        self.classifications_per_second.update()

        if self.inference_speed:
            self.inference_speed.update(duration)

    def get_weighted_score(
        self,
        object_id: str,
        current_label: str,
        current_score: float,
        current_time: float,
    ) -> tuple[str | None, float]:
        """
        Determine a weighted score based on history to prevent false positives/negatives.
        Requires 60% of attempts to agree on a label before publishing.
        Returns (weighted_label, weighted_score), or (None, 0.0) if there is no consensus.
        """
        if object_id not in self.classification_history:
            self.classification_history[object_id] = []

        self.classification_history[object_id].append(
            (current_label, current_score, current_time)
        )
        history = self.classification_history[object_id]

        if len(history) < 3:
            return None, 0.0

        label_counts = {}
        label_scores = {}
        total_attempts = len(history)

        for label, score, _timestamp in history:
            if label not in label_counts:
                label_counts[label] = 0
                label_scores[label] = []

            label_counts[label] += 1
            label_scores[label].append(score)

        best_label = max(label_counts, key=label_counts.get)
        best_count = label_counts[best_label]
        consensus_threshold = total_attempts * 0.6

        if best_count < consensus_threshold:
            return None, 0.0

        avg_score = sum(label_scores[best_label]) / len(label_scores[best_label])

        if best_label == "none":
            return None, 0.0

        return best_label, avg_score
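
    # Worked example for get_weighted_score (hypothetical object and labels):
    # with history = [("tabby", 0.8, t1), ("none", 0.7, t2), ("tabby", 0.9, t3),
    # ("tabby", 0.85, t4), ("none", 0.6, t5)], "tabby" holds 3 of 5 attempts,
    # meeting the 60% consensus threshold (3 >= 5 * 0.6), so the method returns
    # ("tabby", (0.8 + 0.9 + 0.85) / 3) = ("tabby", 0.85). If only 2 of 5
    # agreed, or if the winning label were "none", it would return (None, 0.0).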

    def process_frame(self, obj_data, frame):
        if self.metrics and self.model_config.name in self.metrics.classification_cps:
            self.metrics.classification_cps[
                self.model_config.name
            ].value = self.classifications_per_second.eps()

        if obj_data["false_positive"]:
            return

        if obj_data["label"] not in self.model_config.object_config.objects:
            return

        if obj_data.get("end_time") is not None:
            return

        object_id = obj_data["id"]

        if (
            object_id in self.classification_history
            and len(self.classification_history[object_id])
            >= MAX_OBJECT_CLASSIFICATIONS
        ):
            return

        now = datetime.datetime.now().timestamp()

        x, y, x2, y2 = calculate_region(
            frame.shape,
            obj_data["box"][0],
            obj_data["box"][1],
            obj_data["box"][2],
            obj_data["box"][3],
            max(
                obj_data["box"][2] - obj_data["box"][0],
                obj_data["box"][3] - obj_data["box"][1],
            ),
            1.0,
        )
        rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
        crop = rgb[
            y:y2,
            x:x2,
        ]

        # compare only height and width; the crop is 3-channel, so comparing
        # the full shape against (224, 224) would never match
        if crop.shape[:2] != (224, 224):
            try:
                resized_crop = cv2.resize(crop, (224, 224))
            except Exception:
                logger.warning("Failed to resize image for object classification")
                return
        else:
            resized_crop = crop

        if self.interpreter is None:
            write_classification_attempt(
                self.train_dir,
                cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
                object_id,
                now,
                "unknown",
                0.0,
            )
            return

        input = np.expand_dims(resized_crop, axis=0)
        self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
        self.interpreter.invoke()
        res: np.ndarray = self.interpreter.get_tensor(
            self.tensor_output_details[0]["index"]
        )[0]
        probs = res / res.sum(axis=0)
        logger.debug(
            f"{self.model_config.name} Ran object classification with probabilities: {probs}"
        )
        best_id = np.argmax(probs)
        score = round(probs[best_id], 2)
        self.__update_metrics(datetime.datetime.now().timestamp() - now)

        write_classification_attempt(
            self.train_dir,
            cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
            object_id,
            now,
            self.labelmap[best_id],
            score,
            max_files=200,
        )

        if score < self.model_config.threshold:
            logger.debug(f"Score {score} is less than the threshold.")
            return

        sub_label = self.labelmap[best_id]
        consensus_label, consensus_score = self.get_weighted_score(
            object_id, sub_label, score, now
        )

        if consensus_label is not None:
            camera = obj_data["camera"]

            if (
                self.model_config.object_config.classification_type
                == ObjectClassificationType.sub_label
            ):
                self.sub_label_publisher.publish(
                    (object_id, consensus_label, consensus_score),
                    EventMetadataTypeEnum.sub_label,
                )
                self.requestor.send_data(
                    "tracked_object_update",
                    json.dumps(
                        {
                            "type": TrackedObjectUpdateTypesEnum.classification,
                            "id": object_id,
                            "camera": camera,
                            "timestamp": now,
                            "model": self.model_config.name,
                            "sub_label": consensus_label,
                            "score": consensus_score,
                        }
                    ),
                )
            elif (
                self.model_config.object_config.classification_type
                == ObjectClassificationType.attribute
            ):
                self.sub_label_publisher.publish(
                    (
                        object_id,
                        self.model_config.name,
                        consensus_label,
                        consensus_score,
                    ),
                    EventMetadataTypeEnum.attribute.value,
                )
                self.requestor.send_data(
                    "tracked_object_update",
                    json.dumps(
                        {
                            "type": TrackedObjectUpdateTypesEnum.classification,
                            "id": object_id,
                            "camera": camera,
                            "timestamp": now,
                            "model": self.model_config.name,
                            "attribute": consensus_label,
                            "score": consensus_score,
                        }
                    ),
                )

    def handle_request(self, topic, request_data):
        if topic == EmbeddingsRequestEnum.reload_classification_model.value:
            if request_data.get("model_name") == self.model_config.name:
                # reload the model from disk before reporting success, matching
                # the state processor's behavior
                self.__build_detector()
                logger.info(
                    f"Successfully loaded updated model for {self.model_config.name}"
                )
                return {
                    "success": True,
                    "message": f"Loaded {self.model_config.name} model.",
                }

        return None

    def expire_object(self, object_id, camera):
        if object_id in self.classification_history:
            self.classification_history.pop(object_id)


# Called as a plain function by both processors above, so it is defined at
# module level rather than as a method.
def write_classification_attempt(
    folder: str,
    frame: np.ndarray,
    event_id: str,
    timestamp: float,
    label: str,
    score: float,
    max_files: int = 100,
) -> None:
    if "-" in label:
        label = label.replace("-", "_")

    file = os.path.join(folder, f"{event_id}-{timestamp}-{label}-{score}.webp")
    os.makedirs(folder, exist_ok=True)
    cv2.imwrite(file, frame)

    files = sorted(
        filter(lambda f: f.endswith(".webp"), os.listdir(folder)),
        key=lambda f: os.path.getctime(os.path.join(folder, f)),
        reverse=True,
    )

    # delete the oldest classification image if the maximum is reached
    try:
        if len(files) > max_files:
            os.unlink(os.path.join(folder, files[-1]))
    except FileNotFoundError:
        pass
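
# Example of a saved training attempt (hypothetical values): an object event
# "1718732347.12-abc123" classified as "delivery-truck" with score 0.87 around
# timestamp 1718732350.5 would be written as
#
#   {train_dir}/1718732347.12-abc123-1718732350.5-delivery_truck-0.87.webp
#
# with dashes in the label replaced by underscores. Once more than max_files
# .webp files exist in the folder, the oldest by creation time is removed.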