From 51397aeb1dffd04ebedf8a64982f0973b5967ad4 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Thu, 26 Mar 2026 11:13:47 -0600 Subject: [PATCH] Cleanup mypy for custom classification --- .../real_time/custom_classification.py | 70 +++++++++++++------ 1 file changed, 48 insertions(+), 22 deletions(-) diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 1a2512e43..1dcf59052 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -24,7 +24,8 @@ from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR from frigate.log import suppress_stderr_during from frigate.types import TrackedObjectUpdateTypesEnum from frigate.util.builtin import EventsPerSecond, InferenceSpeed, load_labels -from frigate.util.object import box_overlaps, calculate_region +from frigate.util.image import calculate_region +from frigate.util.object import box_overlaps from ..types import DataProcessorMetrics from .api import RealTimeProcessorApi @@ -49,12 +50,16 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): ): super().__init__(config, metrics) self.model_config = model_config + + if not self.model_config.name: + raise ValueError("Custom classification model name must be set.") + self.requestor = requestor self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name) self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train") - self.interpreter: Interpreter = None - self.tensor_input_details: dict[str, Any] | None = None - self.tensor_output_details: dict[str, Any] | None = None + self.interpreter: Interpreter | None = None + self.tensor_input_details: list[dict[str, Any]] | None = None + self.tensor_output_details: list[dict[str, Any]] | None = None self.labelmap: dict[int, str] = {} self.classifications_per_second = EventsPerSecond() self.state_history: dict[str, dict[str, Any]] = {} @@ -63,7 +68,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): self.metrics and self.model_config.name in self.metrics.classification_speeds ): - self.inference_speed = InferenceSpeed( + self.inference_speed: InferenceSpeed | None = InferenceSpeed( self.metrics.classification_speeds[self.model_config.name] ) else: @@ -172,12 +177,20 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): return None - def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray): + def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray) -> None: + if ( + not self.model_config.name + or not self.model_config.state_config + or not self.tensor_input_details + or not self.tensor_output_details + ): + return + if self.metrics and self.model_config.name in self.metrics.classification_cps: self.metrics.classification_cps[ self.model_config.name ].value = self.classifications_per_second.eps() - camera = frame_data.get("camera") + camera = str(frame_data.get("camera")) if camera not in self.model_config.state_config.cameras: return @@ -283,7 +296,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): logger.debug( f"{self.model_config.name} Ran state classification with probabilities: {probs}" ) - best_id = np.argmax(probs) + best_id = int(np.argmax(probs)) score = round(probs[best_id], 2) self.__update_metrics(datetime.datetime.now().timestamp() - now) @@ -319,7 +332,9 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): verified_state, ) - def handle_request(self, topic, request_data): + def handle_request( + self, topic: str, request_data: dict[str, Any] + ) -> dict[str, Any] | None: if topic == EmbeddingsRequestEnum.reload_classification_model.value: if request_data.get("model_name") == self.model_config.name: self.__build_detector() @@ -335,7 +350,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): else: return None - def expire_object(self, object_id, camera): + def expire_object(self, object_id: str, camera: str) -> None: pass @@ -350,13 +365,17 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): ): super().__init__(config, metrics) self.model_config = model_config + + if not self.model_config.name: + raise ValueError("Custom classification model name must be set.") + self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name) self.train_dir = os.path.join(CLIPS_DIR, self.model_config.name, "train") - self.interpreter: Interpreter = None + self.interpreter: Interpreter | None = None self.sub_label_publisher = sub_label_publisher self.requestor = requestor - self.tensor_input_details: dict[str, Any] | None = None - self.tensor_output_details: dict[str, Any] | None = None + self.tensor_input_details: list[dict[str, Any]] | None = None + self.tensor_output_details: list[dict[str, Any]] | None = None self.classification_history: dict[str, list[tuple[str, float, float]]] = {} self.labelmap: dict[int, str] = {} self.classifications_per_second = EventsPerSecond() @@ -365,7 +384,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): self.metrics and self.model_config.name in self.metrics.classification_speeds ): - self.inference_speed = InferenceSpeed( + self.inference_speed: InferenceSpeed | None = InferenceSpeed( self.metrics.classification_speeds[self.model_config.name] ) else: @@ -431,8 +450,8 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): ) return None, 0.0 - label_counts = {} - label_scores = {} + label_counts: dict[str, int] = {} + label_scores: dict[str, list[float]] = {} total_attempts = len(history) for label, score, timestamp in history: @@ -443,7 +462,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): label_counts[label] += 1 label_scores[label].append(score) - best_label = max(label_counts, key=label_counts.get) + best_label = max(label_counts, key=lambda k: label_counts[k]) best_count = label_counts[best_label] consensus_threshold = total_attempts * 0.6 @@ -470,7 +489,15 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): ) return best_label, avg_score - def process_frame(self, obj_data, frame): + def process_frame(self, obj_data: dict[str, Any], frame: np.ndarray) -> None: + if ( + not self.model_config.name + or not self.model_config.object_config + or not self.tensor_input_details + or not self.tensor_output_details + ): + return + if self.metrics and self.model_config.name in self.metrics.classification_cps: self.metrics.classification_cps[ self.model_config.name @@ -555,7 +582,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): logger.debug( f"{self.model_config.name} Ran object classification with probabilities: {probs}" ) - best_id = np.argmax(probs) + best_id = int(np.argmax(probs)) score = round(probs[best_id], 2) self.__update_metrics(datetime.datetime.now().timestamp() - now) @@ -650,7 +677,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): ), ) - def handle_request(self, topic, request_data): + def handle_request(self, topic: str, request_data: dict) -> dict | None: if topic == EmbeddingsRequestEnum.reload_classification_model.value: if request_data.get("model_name") == self.model_config.name: self.__build_detector() @@ -666,12 +693,11 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): else: return None - def expire_object(self, object_id, camera): + def expire_object(self, object_id: str, camera: str) -> None: if object_id in self.classification_history: self.classification_history.pop(object_id) -@staticmethod def write_classification_attempt( folder: str, frame: np.ndarray,