diff --git a/frigate/util/classification.py b/frigate/util/classification.py
index 3cab97805..fe8e6069c 100644
--- a/frigate/util/classification.py
+++ b/frigate/util/classification.py
@@ -130,7 +130,8 @@ class ClassificationTrainingProcess(FrigateProcess):
 
     def run(self) -> None:
         self.pre_run_setup()
-        self.__train_classification_model()
+        success = self.__train_classification_model()
+        exit(0 if success else 1)
 
     def __generate_representative_dataset_factory(self, dataset_dir: str):
         def generate_representative_dataset():
@@ -153,89 +154,117 @@ class ClassificationTrainingProcess(FrigateProcess):
 
     @redirect_output_to_logger(logger, logging.DEBUG)
     def __train_classification_model(self) -> bool:
         """Train a classification model."""
+        try:
+            # import in the function so that tensorflow is not initialized multiple times
+            import tensorflow as tf
+            from tensorflow.keras import layers, models, optimizers
+            from tensorflow.keras.applications import MobileNetV2
+            from tensorflow.keras.preprocessing.image import ImageDataGenerator
 
-        # import in the function so that tensorflow is not initialized multiple times
-        import tensorflow as tf
-        from tensorflow.keras import layers, models, optimizers
-        from tensorflow.keras.applications import MobileNetV2
-        from tensorflow.keras.preprocessing.image import ImageDataGenerator
+            dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
+            model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
+            os.makedirs(model_dir, exist_ok=True)
 
-        logger.info(f"Kicking off classification training for {self.model_name}.")
-        dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
-        model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
-        os.makedirs(model_dir, exist_ok=True)
-        num_classes = len(
-            [
-                d
-                for d in os.listdir(dataset_dir)
-                if os.path.isdir(os.path.join(dataset_dir, d))
-            ]
-        )
+            num_classes = len(
+                [
+                    d
+                    for d in os.listdir(dataset_dir)
+                    if os.path.isdir(os.path.join(dataset_dir, d))
+                ]
+            )
 
-        # Start with imagenet base model with 35% of channels in each layer
-        base_model = MobileNetV2(
-            input_shape=(224, 224, 3),
-            include_top=False,
-            weights="imagenet",
-            alpha=0.35,
-        )
-        base_model.trainable = False  # Freeze pre-trained layers
+            if num_classes < 2:
+                logger.error(
+                    f"Training failed for {self.model_name}: Need at least 2 classes, found {num_classes}"
+                )
+                return False
 
-        model = models.Sequential(
-            [
-                base_model,
-                layers.GlobalAveragePooling2D(),
-                layers.Dense(128, activation="relu"),
-                layers.Dropout(0.3),
-                layers.Dense(num_classes, activation="softmax"),
-            ]
-        )
+            # Start with imagenet base model with 35% of channels in each layer
+            base_model = MobileNetV2(
+                input_shape=(224, 224, 3),
+                include_top=False,
+                weights="imagenet",
+                alpha=0.35,
+            )
+            base_model.trainable = False  # Freeze pre-trained layers
 
-        model.compile(
-            optimizer=optimizers.Adam(learning_rate=LEARNING_RATE),
-            loss="categorical_crossentropy",
-            metrics=["accuracy"],
-        )
+            model = models.Sequential(
+                [
+                    base_model,
+                    layers.GlobalAveragePooling2D(),
+                    layers.Dense(128, activation="relu"),
+                    layers.Dropout(0.3),
+                    layers.Dense(num_classes, activation="softmax"),
+                ]
+            )
 
-        # create training set
-        datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2)
-        train_gen = datagen.flow_from_directory(
-            dataset_dir,
-            target_size=(224, 224),
-            batch_size=BATCH_SIZE,
-            class_mode="categorical",
-            subset="training",
-        )
+            model.compile(
+                optimizer=optimizers.Adam(learning_rate=LEARNING_RATE),
+                loss="categorical_crossentropy",
+                metrics=["accuracy"],
+            )
 
-        # write labelmap
-        class_indices = train_gen.class_indices
-        index_to_class = {v: k for k, v in class_indices.items()}
-        sorted_classes = [index_to_class[i] for i in range(len(index_to_class))]
-        with open(os.path.join(model_dir, "labelmap.txt"), "w") as f:
-            for class_name in sorted_classes:
-                f.write(f"{class_name}\n")
+            # create training set
+            datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2)
+            train_gen = datagen.flow_from_directory(
+                dataset_dir,
+                target_size=(224, 224),
+                batch_size=BATCH_SIZE,
+                class_mode="categorical",
+                subset="training",
+            )
 
-        # train the model
-        model.fit(train_gen, epochs=EPOCHS, verbose=0)
+            total_images = train_gen.samples
+            logger.debug(
+                f"Training {self.model_name}: {total_images} images across {num_classes} classes"
+            )
 
-        # convert model to tflite
-        converter = tf.lite.TFLiteConverter.from_keras_model(model)
-        converter.optimizations = [tf.lite.Optimize.DEFAULT]
-        converter.representative_dataset = (
-            self.__generate_representative_dataset_factory(dataset_dir)
-        )
-        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
-        converter.inference_input_type = tf.uint8
-        converter.inference_output_type = tf.uint8
-        tflite_model = converter.convert()
+            # write labelmap
+            class_indices = train_gen.class_indices
+            index_to_class = {v: k for k, v in class_indices.items()}
+            sorted_classes = [index_to_class[i] for i in range(len(index_to_class))]
+            with open(os.path.join(model_dir, "labelmap.txt"), "w") as f:
+                for class_name in sorted_classes:
+                    f.write(f"{class_name}\n")
 
-        # write model
-        with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
-            f.write(tflite_model)
+            # train the model
+            logger.debug(f"Training {self.model_name} for {EPOCHS} epochs...")
+            model.fit(train_gen, epochs=EPOCHS, verbose=0)
+            logger.debug(f"Converting {self.model_name} to TFLite...")
 
-        # write training metadata with image count
-        dataset_image_count = get_dataset_image_count(self.model_name)
-        write_training_metadata(self.model_name, dataset_image_count)
+            # convert model to tflite
+            converter = tf.lite.TFLiteConverter.from_keras_model(model)
+            converter.optimizations = [tf.lite.Optimize.DEFAULT]
+            converter.representative_dataset = (
+                self.__generate_representative_dataset_factory(dataset_dir)
+            )
+            converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+            converter.inference_input_type = tf.uint8
+            converter.inference_output_type = tf.uint8
+            tflite_model = converter.convert()
+
+            # write model
+            model_path = os.path.join(model_dir, "model.tflite")
+            with open(model_path, "wb") as f:
+                f.write(tflite_model)
+
+            # verify model file was written successfully
+            if not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
+                logger.error(
+                    f"Training failed for {self.model_name}: Model file was not created or is empty"
+                )
+                return False
+
+            # write training metadata with image count
+            dataset_image_count = get_dataset_image_count(self.model_name)
+            write_training_metadata(self.model_name, dataset_image_count)
+
+            logger.info(f"Finished training {self.model_name}")
+            return True
+
+        except Exception as e:
+            logger.error(f"Training failed for {self.model_name}: {e}", exc_info=True)
+            return False
 
 
 def kickoff_model_training(
@@ -257,18 +286,36 @@ def kickoff_model_training(
 
     training_process.start()
     training_process.join()
 
-    # reload model and mark training as complete
-    embeddingRequestor.send_data(
-        EmbeddingsRequestEnum.reload_classification_model.value,
-        {"model_name": model_name},
-    )
-    requestor.send_data(
-        UPDATE_MODEL_STATE,
-        {
-            "model": model_name,
-            "state": ModelStatusTypesEnum.complete,
-        },
-    )
+    # check if training succeeded by examining the exit code
+    training_success = training_process.exitcode == 0
+
+    if training_success:
+        # reload model and mark training as complete
+        embeddingRequestor.send_data(
+            EmbeddingsRequestEnum.reload_classification_model.value,
+            {"model_name": model_name},
+        )
+        requestor.send_data(
+            UPDATE_MODEL_STATE,
+            {
+                "model": model_name,
+                "state": ModelStatusTypesEnum.complete,
+            },
+        )
+    else:
+        logger.error(
+            f"Training subprocess failed for {model_name} (exit code: {training_process.exitcode})"
+        )
+        # mark training as complete (not failed) so UI doesn't stay in training state
+        # but don't reload the model since it failed
+        requestor.send_data(
+            UPDATE_MODEL_STATE,
+            {
+                "model": model_name,
+                "state": ModelStatusTypesEnum.complete,
+            },
+        )
+
     requestor.stop()
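
Review note: the success/failure handoff here is carried entirely by the child's exit status. A minimal sketch of the pattern, using a plain `multiprocessing.Process` as a stand-in for `FrigateProcess` (an assumption; the real class wraps logging and setup):

```python
import multiprocessing as mp
import sys


def run() -> None:
    # child side: mirror run() in the diff by turning a boolean
    # training result into an exit code (stand-in for the real call)
    success = True
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    proc = mp.Process(target=run)
    proc.start()
    proc.join()
    # parent side: exitcode is 0 on success, 1 on the failure path,
    # and negative if the child died from a signal, so the
    # `exitcode == 0` check also catches crashes and OOM kills
    print("success" if proc.exitcode == 0 else "failure")
```

A nice property of checking the exit code rather than a shared flag is that hard crashes in native TensorFlow code register as failures too, not just exceptions caught by the new `try`/`except`.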
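Review note: the body of `__generate_representative_dataset_factory` is elided above. For context, the converter calls the generator during full-integer conversion to calibrate int8 ranges, and it must yield float inputs preprocessed the same way as training (here, rescaled by 1/255). A typical shape for such a factory (a sketch under those assumptions, not this PR's implementation; `sample_count` is illustrative):

```python
import os

import numpy as np
import tensorflow as tf


def make_representative_dataset(dataset_dir: str, sample_count: int = 100):
    """Return a generator yielding a small sample of preprocessed images."""

    def generate_representative_dataset():
        # walk the class subdirectories; assumes they contain only images
        paths = [
            os.path.join(root, name)
            for root, _, names in os.walk(dataset_dir)
            for name in names
        ][:sample_count]
        for path in paths:
            img = tf.keras.preprocessing.image.load_img(path, target_size=(224, 224))
            # match the training-time rescale of 1/255
            arr = tf.keras.preprocessing.image.img_to_array(img) / 255.0
            # the converter expects a list of input tensors with a batch dim
            yield [np.expand_dims(arr.astype(np.float32), axis=0)]

    return generate_representative_dataset
```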
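Review note: the new size check on `model.tflite` catches truncated writes but not a malformed flatbuffer. If stronger verification is ever wanted, instantiating an interpreter would catch that as well; a possible extension, not part of this PR:

```python
import numpy as np
import tensorflow as tf

# post-conversion sanity check: a fully-integer model should report
# uint8 inputs and outputs, matching the converter settings in the diff
interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()

inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
assert inp["dtype"] == np.uint8, f"unexpected input dtype: {inp['dtype']}"
assert out["dtype"] == np.uint8, f"unexpected output dtype: {out['dtype']}"
print("input shape:", inp["shape"], "output shape:", out["shape"])
```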