From 99a363c0478136f1ec1a41e46285eeef554b62d8 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Sun, 9 Nov 2025 15:21:13 -0700 Subject: [PATCH] Improve classification (#20863) --- frigate/api/classification.py | 45 ++- frigate/types.py | 1 + frigate/util/classification.py | 303 +++++++++++++----- .../locales/en/views/classificationModel.json | 10 +- .../wizard/Step3ChooseExamples.tsx | 62 +++- web/src/types/ws.ts | 3 +- .../classification/ModelTrainingView.tsx | 75 ++++- 7 files changed, 380 insertions(+), 119 deletions(-) diff --git a/frigate/api/classification.py b/frigate/api/classification.py index b3194176e..87de52884 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -37,6 +37,8 @@ from frigate.models import Event from frigate.util.classification import ( collect_object_classification_examples, collect_state_classification_examples, + get_dataset_image_count, + read_training_metadata, ) from frigate.util.file import get_event_snapshot @@ -564,23 +566,54 @@ def get_classification_dataset(name: str): dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset") if not os.path.exists(dataset_dir): - return JSONResponse(status_code=200, content={}) + return JSONResponse( + status_code=200, content={"categories": {}, "training_metadata": None} + ) - for name in os.listdir(dataset_dir): - category_dir = os.path.join(dataset_dir, name) + for category_name in os.listdir(dataset_dir): + category_dir = os.path.join(dataset_dir, category_name) if not os.path.isdir(category_dir): continue - dataset_dict[name] = [] + dataset_dict[category_name] = [] for file in filter( lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))), os.listdir(category_dir), ): - dataset_dict[name].append(file) + dataset_dict[category_name].append(file) - return JSONResponse(status_code=200, content=dataset_dict) + # Get training metadata + metadata = read_training_metadata(sanitize_filename(name)) + current_image_count = get_dataset_image_count(sanitize_filename(name)) + + if metadata is None: + training_metadata = { + "has_trained": False, + "last_training_date": None, + "last_training_image_count": 0, + "current_image_count": current_image_count, + "new_images_count": current_image_count, + } + else: + last_training_count = metadata.get("last_training_image_count", 0) + new_images_count = max(0, current_image_count - last_training_count) + training_metadata = { + "has_trained": True, + "last_training_date": metadata.get("last_training_date"), + "last_training_image_count": last_training_count, + "current_image_count": current_image_count, + "new_images_count": new_images_count, + } + + return JSONResponse( + status_code=200, + content={ + "categories": dataset_dict, + "training_metadata": training_metadata, + }, + ) @router.get( diff --git a/frigate/types.py b/frigate/types.py index a9e27ba90..f342d27cd 100644 --- a/frigate/types.py +++ b/frigate/types.py @@ -23,6 +23,7 @@ class ModelStatusTypesEnum(str, Enum): error = "error" training = "training" complete = "complete" + failed = "failed" class TrackedObjectUpdateTypesEnum(str, Enum): diff --git a/frigate/util/classification.py b/frigate/util/classification.py index 43dfd7fd7..a74094c32 100644 --- a/frigate/util/classification.py +++ b/frigate/util/classification.py @@ -1,5 +1,7 @@ """Util for classification models.""" +import datetime +import json import logging import os import random @@ -27,10 +29,96 @@ from frigate.util.process import FrigateProcess BATCH_SIZE = 16 EPOCHS = 50 LEARNING_RATE = 0.001 +TRAINING_METADATA_FILE = ".training_metadata.json" logger = logging.getLogger(__name__) +def write_training_metadata(model_name: str, image_count: int) -> None: + """ + Write training metadata to a hidden file in the model's clips directory. + + Args: + model_name: Name of the classification model + image_count: Number of images used in training + """ + clips_model_dir = os.path.join(CLIPS_DIR, model_name) + os.makedirs(clips_model_dir, exist_ok=True) + + metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE) + metadata = { + "last_training_date": datetime.datetime.now().isoformat(), + "last_training_image_count": image_count, + } + + try: + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + logger.info(f"Wrote training metadata for {model_name}: {image_count} images") + except Exception as e: + logger.error(f"Failed to write training metadata for {model_name}: {e}") + + +def read_training_metadata(model_name: str) -> dict[str, any] | None: + """ + Read training metadata from the hidden file in the model's clips directory. + + Args: + model_name: Name of the classification model + + Returns: + Dictionary with last_training_date and last_training_image_count, or None if not found + """ + clips_model_dir = os.path.join(CLIPS_DIR, model_name) + metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE) + + if not os.path.exists(metadata_path): + return None + + try: + with open(metadata_path, "r") as f: + metadata = json.load(f) + return metadata + except Exception as e: + logger.error(f"Failed to read training metadata for {model_name}: {e}") + return None + + +def get_dataset_image_count(model_name: str) -> int: + """ + Count the total number of images in the model's dataset directory. + + Args: + model_name: Name of the classification model + + Returns: + Total count of images across all categories + """ + dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset") + + if not os.path.exists(dataset_dir): + return 0 + + total_count = 0 + try: + for category in os.listdir(dataset_dir): + category_dir = os.path.join(dataset_dir, category) + if not os.path.isdir(category_dir): + continue + + image_files = [ + f + for f in os.listdir(category_dir) + if f.lower().endswith((".webp", ".png", ".jpg", ".jpeg")) + ] + total_count += len(image_files) + except Exception as e: + logger.error(f"Failed to count dataset images for {model_name}: {e}") + return 0 + + return total_count + + class ClassificationTrainingProcess(FrigateProcess): def __init__(self, model_name: str) -> None: super().__init__( @@ -42,7 +130,8 @@ class ClassificationTrainingProcess(FrigateProcess): def run(self) -> None: self.pre_run_setup() - self.__train_classification_model() + success = self.__train_classification_model() + exit(0 if success else 1) def __generate_representative_dataset_factory(self, dataset_dir: str): def generate_representative_dataset(): @@ -65,85 +154,117 @@ class ClassificationTrainingProcess(FrigateProcess): @redirect_output_to_logger(logger, logging.DEBUG) def __train_classification_model(self) -> bool: """Train a classification model.""" + try: + # import in the function so that tensorflow is not initialized multiple times + import tensorflow as tf + from tensorflow.keras import layers, models, optimizers + from tensorflow.keras.applications import MobileNetV2 + from tensorflow.keras.preprocessing.image import ImageDataGenerator - # import in the function so that tensorflow is not initialized multiple times - import tensorflow as tf - from tensorflow.keras import layers, models, optimizers - from tensorflow.keras.applications import MobileNetV2 - from tensorflow.keras.preprocessing.image import ImageDataGenerator + dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset") + model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name) + os.makedirs(model_dir, exist_ok=True) - logger.info(f"Kicking off classification training for {self.model_name}.") - dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset") - model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name) - os.makedirs(model_dir, exist_ok=True) - num_classes = len( - [ - d - for d in os.listdir(dataset_dir) - if os.path.isdir(os.path.join(dataset_dir, d)) - ] - ) + num_classes = len( + [ + d + for d in os.listdir(dataset_dir) + if os.path.isdir(os.path.join(dataset_dir, d)) + ] + ) - # Start with imagenet base model with 35% of channels in each layer - base_model = MobileNetV2( - input_shape=(224, 224, 3), - include_top=False, - weights="imagenet", - alpha=0.35, - ) - base_model.trainable = False # Freeze pre-trained layers + if num_classes < 2: + logger.error( + f"Training failed for {self.model_name}: Need at least 2 classes, found {num_classes}" + ) + return False - model = models.Sequential( - [ - base_model, - layers.GlobalAveragePooling2D(), - layers.Dense(128, activation="relu"), - layers.Dropout(0.3), - layers.Dense(num_classes, activation="softmax"), - ] - ) + # Start with imagenet base model with 35% of channels in each layer + base_model = MobileNetV2( + input_shape=(224, 224, 3), + include_top=False, + weights="imagenet", + alpha=0.35, + ) + base_model.trainable = False # Freeze pre-trained layers - model.compile( - optimizer=optimizers.Adam(learning_rate=LEARNING_RATE), - loss="categorical_crossentropy", - metrics=["accuracy"], - ) + model = models.Sequential( + [ + base_model, + layers.GlobalAveragePooling2D(), + layers.Dense(128, activation="relu"), + layers.Dropout(0.3), + layers.Dense(num_classes, activation="softmax"), + ] + ) - # create training set - datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2) - train_gen = datagen.flow_from_directory( - dataset_dir, - target_size=(224, 224), - batch_size=BATCH_SIZE, - class_mode="categorical", - subset="training", - ) + model.compile( + optimizer=optimizers.Adam(learning_rate=LEARNING_RATE), + loss="categorical_crossentropy", + metrics=["accuracy"], + ) - # write labelmap - class_indices = train_gen.class_indices - index_to_class = {v: k for k, v in class_indices.items()} - sorted_classes = [index_to_class[i] for i in range(len(index_to_class))] - with open(os.path.join(model_dir, "labelmap.txt"), "w") as f: - for class_name in sorted_classes: - f.write(f"{class_name}\n") + # create training set + datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2) + train_gen = datagen.flow_from_directory( + dataset_dir, + target_size=(224, 224), + batch_size=BATCH_SIZE, + class_mode="categorical", + subset="training", + ) - # train the model - model.fit(train_gen, epochs=EPOCHS, verbose=0) + total_images = train_gen.samples + logger.debug( + f"Training {self.model_name}: {total_images} images across {num_classes} classes" + ) - # convert model to tflite - converter = tf.lite.TFLiteConverter.from_keras_model(model) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.representative_dataset = ( - self.__generate_representative_dataset_factory(dataset_dir) - ) - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.inference_input_type = tf.uint8 - converter.inference_output_type = tf.uint8 - tflite_model = converter.convert() + # write labelmap + class_indices = train_gen.class_indices + index_to_class = {v: k for k, v in class_indices.items()} + sorted_classes = [index_to_class[i] for i in range(len(index_to_class))] + with open(os.path.join(model_dir, "labelmap.txt"), "w") as f: + for class_name in sorted_classes: + f.write(f"{class_name}\n") - # write model - with open(os.path.join(model_dir, "model.tflite"), "wb") as f: - f.write(tflite_model) + # train the model + logger.debug(f"Training {self.model_name} for {EPOCHS} epochs...") + model.fit(train_gen, epochs=EPOCHS, verbose=0) + logger.debug(f"Converting {self.model_name} to TFLite...") + + # convert model to tflite + converter = tf.lite.TFLiteConverter.from_keras_model(model) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = ( + self.__generate_representative_dataset_factory(dataset_dir) + ) + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.uint8 + converter.inference_output_type = tf.uint8 + tflite_model = converter.convert() + + # write model + model_path = os.path.join(model_dir, "model.tflite") + with open(model_path, "wb") as f: + f.write(tflite_model) + + # verify model file was written successfully + if not os.path.exists(model_path) or os.path.getsize(model_path) == 0: + logger.error( + f"Training failed for {self.model_name}: Model file was not created or is empty" + ) + return False + + # write training metadata with image count + dataset_image_count = get_dataset_image_count(self.model_name) + write_training_metadata(self.model_name, dataset_image_count) + + logger.info(f"Finished training {self.model_name}") + return True + + except Exception as e: + logger.error(f"Training failed for {self.model_name}: {e}", exc_info=True) + return False def kickoff_model_training( @@ -165,18 +286,36 @@ def kickoff_model_training( training_process.start() training_process.join() - # reload model and mark training as complete - embeddingRequestor.send_data( - EmbeddingsRequestEnum.reload_classification_model.value, - {"model_name": model_name}, - ) - requestor.send_data( - UPDATE_MODEL_STATE, - { - "model": model_name, - "state": ModelStatusTypesEnum.complete, - }, - ) + # check if training succeeded by examining the exit code + training_success = training_process.exitcode == 0 + + if training_success: + # reload model and mark training as complete + embeddingRequestor.send_data( + EmbeddingsRequestEnum.reload_classification_model.value, + {"model_name": model_name}, + ) + requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": model_name, + "state": ModelStatusTypesEnum.complete, + }, + ) + else: + logger.error( + f"Training subprocess failed for {model_name} (exit code: {training_process.exitcode})" + ) + # mark training as failed so UI shows error state + # don't reload the model since it failed + requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": model_name, + "state": ModelStatusTypesEnum.failed, + }, + ) + requestor.stop() diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json index 65118f227..2bae0c0ce 100644 --- a/web/public/locales/en/views/classificationModel.json +++ b/web/public/locales/en/views/classificationModel.json @@ -13,6 +13,11 @@ "deleteModels": "Delete Models", "editModel": "Edit Model" }, + "tooltip": { + "trainingInProgress": "Model is currently training", + "noNewImages": "No new images to train. Classify more images in the dataset first.", + "modelNotReady": "Model is not ready for training" + }, "toast": { "success": { "deletedCategory": "Deleted Class", @@ -30,7 +35,8 @@ "deleteCategoryFailed": "Failed to delete class: {{errorMessage}}", "deleteModelFailed": "Failed to delete model: {{errorMessage}}", "categorizeFailed": "Failed to categorize image: {{errorMessage}}", - "trainingFailed": "Failed to start model training: {{errorMessage}}", + "trainingFailed": "Model training failed. Check Frigate logs for details.", + "trainingFailedToStart": "Failed to start model training: {{errorMessage}}", "updateModelFailed": "Failed to update model: {{errorMessage}}", "renameCategoryFailed": "Failed to rename class: {{errorMessage}}" } @@ -143,6 +149,8 @@ "step3": { "selectImagesPrompt": "Select all images with: {{className}}", "selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.", + "allImagesRequired_one": "Please classify all images. {{count}} image remaining.", + "allImagesRequired_other": "Please classify all images. {{count}} images remaining.", "generating": { "title": "Generating Sample Images", "description": "Frigate is pulling representative images from your recordings. This may take a moment..." diff --git a/web/src/components/classification/wizard/Step3ChooseExamples.tsx b/web/src/components/classification/wizard/Step3ChooseExamples.tsx index 68da03eaf..f638c01e3 100644 --- a/web/src/components/classification/wizard/Step3ChooseExamples.tsx +++ b/web/src/components/classification/wizard/Step3ChooseExamples.tsx @@ -10,6 +10,12 @@ import useSWR from "swr"; import { baseUrl } from "@/api/baseUrl"; import { isMobile } from "react-device-detect"; import { cn } from "@/lib/utils"; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/ui/tooltip"; +import { TooltipPortal } from "@radix-ui/react-tooltip"; export type Step3FormData = { examplesGenerated: boolean; @@ -317,6 +323,19 @@ export default function Step3ChooseExamples({ return unclassifiedImages.length === 0; }, [unclassifiedImages]); + // For state models on the last class, require all images to be classified + const isLastClass = currentClassIndex === allClasses.length - 1; + const canProceed = useMemo(() => { + if ( + step1Data.modelType === "state" && + isLastClass && + !allImagesClassified + ) { + return false; + } + return true; + }, [step1Data.modelType, isLastClass, allImagesClassified]); + const handleBack = useCallback(() => { if (currentClassIndex > 0) { const previousClass = allClasses[currentClassIndex - 1]; @@ -438,20 +457,35 @@ export default function Step3ChooseExamples({ - + + + + + {!canProceed && ( + + + {t("wizard.step3.allImagesRequired", { + count: unclassifiedImages.length, + })} + + + )} + )} diff --git a/web/src/types/ws.ts b/web/src/types/ws.ts index 1120aec67..1d98b7b01 100644 --- a/web/src/types/ws.ts +++ b/web/src/types/ws.ts @@ -87,7 +87,8 @@ export type ModelState = | "downloaded" | "error" | "training" - | "complete"; + | "complete" + | "failed"; export type EmbeddingsReindexProgressType = { thumbnails: number; diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx index 9e9025691..6a3e680f9 100644 --- a/web/src/views/classification/ModelTrainingView.tsx +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -102,6 +102,12 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { position: "top-center", }); setWasTraining(false); + refreshDataset(); + } else if (modelState == "failed") { + toast.error(t("toast.error.trainingFailed"), { + position: "top-center", + }); + setWasTraining(false); } // only refresh when modelState changes // eslint-disable-next-line react-hooks/exhaustive-deps @@ -112,10 +118,20 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { const { data: trainImages, mutate: refreshTrain } = useSWR( `classification/${model.name}/train`, ); - const { data: dataset, mutate: refreshDataset } = useSWR<{ - [id: string]: string[]; + const { data: datasetResponse, mutate: refreshDataset } = useSWR<{ + categories: { [id: string]: string[] }; + training_metadata: { + has_trained: boolean; + last_training_date: string | null; + last_training_image_count: number; + current_image_count: number; + new_images_count: number; + } | null; }>(`classification/${model.name}/dataset`); + const dataset = datasetResponse?.categories || {}; + const trainingMetadata = datasetResponse?.training_metadata; + const [trainFilter, setTrainFilter] = useApiFilter(); const refreshAll = useCallback(() => { @@ -177,7 +193,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { error.response?.data?.detail || "Unknown error"; - toast.error(t("toast.error.trainingFailed", { errorMessage }), { + toast.error(t("toast.error.trainingFailedToStart", { errorMessage }), { position: "top-center", }); }); @@ -421,19 +437,48 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { filterValues={{ classes: Object.keys(dataset || {}) }} onUpdateFilter={setTrainFilter} /> - + + {((trainingMetadata?.new_images_count ?? 0) === 0 || + (modelState != "complete" && modelState != "failed")) && ( + + + {modelState == "training" + ? t("tooltip.trainingInProgress") + : trainingMetadata?.new_images_count === 0 + ? t("tooltip.noNewImages") + : t("tooltip.modelNotReady")} + + )} - {isDesktop && t("button.trainModel")} - + )}