From a0a422df36011faef8019c8af37a57182e174eee Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Thu, 29 May 2025 10:55:52 -0600 Subject: [PATCH] Handle model structure and write attempt images --- .../real_time/custom_classification.py | 70 +++++++++++++++---- 1 file changed, 57 insertions(+), 13 deletions(-) diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index cd99508c9..8bdf64033 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -2,6 +2,7 @@ import datetime import logging +import os from typing import Any import cv2 @@ -14,6 +15,7 @@ from frigate.comms.event_metadata_updater import ( from frigate.comms.inter_process import InterProcessRequestor from frigate.config import FrigateConfig from frigate.config.classification import CustomClassificationConfig +from frigate.const import MODEL_CACHE_DIR from frigate.util.builtin import load_labels from frigate.util.object import box_overlaps, calculate_region @@ -33,14 +35,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): self, config: FrigateConfig, model_config: CustomClassificationConfig, - name: str, requestor: InterProcessRequestor, metrics: DataProcessorMetrics, ): super().__init__(config, metrics) self.model_config = model_config - self.name = name self.requestor = requestor + self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name) + self.train_dir = os.path.join(self.model_dir, "train") self.interpreter: Interpreter = None self.tensor_input_details: dict[str, Any] = None self.tensor_output_details: dict[str, Any] = None @@ -50,13 +52,16 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): def __build_detector(self) -> None: self.interpreter = Interpreter( - model_path=self.model_config.model_path, + model_path=os.path.join(self.model_dir, "model.tflite"), num_threads=2, ) 
self.interpreter.allocate_tensors() self.tensor_input_details = self.interpreter.get_input_details() self.tensor_output_details = self.interpreter.get_output_details() - self.labelmap = load_labels(self.model_config.labelmap_path, prefill=0) + self.labelmap = load_labels( + os.path.join(self.model_dir, "labelmap.txt"), + prefill=0, + ) def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray): camera = frame_data.get("camera") @@ -105,15 +110,15 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): ) rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420) - input = rgb[ + frame = rgb[ y:y2, x:x2, ] - if input.shape != (224, 224): - input = cv2.resize(input, (224, 224)) + if frame.shape != (224, 224): + frame = cv2.resize(frame, (224, 224)) - input = np.expand_dims(input, axis=0) + input = np.expand_dims(frame, axis=0) self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input) self.interpreter.invoke() res: np.ndarray = self.interpreter.get_tensor( @@ -123,9 +128,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): best_id = np.argmax(probs) score = round(probs[best_id], 2) + write_classification_attempt( + self.train_dir, frame, now, self.labelmap[best_id], score + ) + if score >= camera_config.threshold: self.requestor.send_data( - f"{camera}/classification/{self.name}", self.labelmap[best_id] + f"{camera}/classification/{self.model_config.name}", + self.labelmap[best_id], ) def handle_request(self, topic, request_data): @@ -145,6 +155,8 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): ): super().__init__(config, metrics) self.model_config = model_config + self.model_dir = os.path.join(MODEL_CACHE_DIR, self.model_config.name) + self.train_dir = os.path.join(self.model_dir, "train") self.interpreter: Interpreter = None self.sub_label_publisher = sub_label_publisher self.tensor_input_details: dict[str, Any] = None @@ -155,18 +167,22 @@ class 
CustomObjectClassificationProcessor(RealTimeProcessorApi): def __build_detector(self) -> None: self.interpreter = Interpreter( - model_path=self.model_config.model_path, + model_path=os.path.join(self.model_dir, "model.tflite"), num_threads=2, ) self.interpreter.allocate_tensors() self.tensor_input_details = self.interpreter.get_input_details() self.tensor_output_details = self.interpreter.get_output_details() - self.labelmap = load_labels(self.model_config.labelmap_path, prefill=0) + self.labelmap = load_labels( + os.path.join(self.model_dir, "labelmap.txt"), + prefill=0, + ) def process_frame(self, obj_data, frame): if obj_data["label"] not in self.model_config.object_config.objects: return + now = datetime.datetime.now().timestamp() x, y, x2, y2 = calculate_region( frame.shape, obj_data["box"][0], @@ -194,11 +210,13 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): )[0] probs = res / res.sum(axis=0) best_id = np.argmax(probs) - score = round(probs[best_id], 2) - previous_score = self.detected_objects.get(obj_data["id"], 0.0) + write_classification_attempt( + self.train_dir, frame, now, self.labelmap[best_id], score + ) + if score <= previous_score: logger.debug(f"Score {score} is worse than previous score {previous_score}") return @@ -215,3 +233,29 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): def expire_object(self, object_id, camera): if object_id in self.detected_objects: self.detected_objects.pop(object_id) + + +@staticmethod +def write_classification_attempt( + folder: str, + frame: np.ndarray, + timestamp: float, + label: str, + score: float, +) -> None: + if "-" in label: + label = label.replace("-", "_") + + file = os.path.join(folder, f"{timestamp}-{label}-{score}.webp") + os.makedirs(folder, exist_ok=True) + cv2.imwrite(file, frame) + + files = sorted( + filter(lambda f: (f.endswith(".webp")), os.listdir(folder)), + key=lambda f: os.path.getctime(os.path.join(folder, f)), + reverse=True, + ) + + # delete oldest 
classification attempt image if maximum is reached + if len(files) > 100: + os.unlink(os.path.join(folder, files[-1]))