Add error handling for training

This commit is contained in:
Nicolas Mowen 2025-11-09 13:34:39 -07:00
parent 292d024aac
commit 6c47a131e4

View File

@ -130,7 +130,8 @@ class ClassificationTrainingProcess(FrigateProcess):
def run(self) -> None:
self.pre_run_setup()
self.__train_classification_model()
success = self.__train_classification_model()
exit(0 if success else 1)
def __generate_representative_dataset_factory(self, dataset_dir: str):
def generate_representative_dataset():
@ -153,17 +154,17 @@ class ClassificationTrainingProcess(FrigateProcess):
@redirect_output_to_logger(logger, logging.DEBUG)
def __train_classification_model(self) -> bool:
"""Train a classification model."""
try:
# import in the function so that tensorflow is not initialized multiple times
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
logger.info(f"Kicking off classification training for {self.model_name}.")
dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
os.makedirs(model_dir, exist_ok=True)
num_classes = len(
[
d
@ -172,6 +173,12 @@ class ClassificationTrainingProcess(FrigateProcess):
]
)
if num_classes < 2:
logger.error(
f"Training failed for {self.model_name}: Need at least 2 classes, found {num_classes}"
)
return False
# Start with imagenet base model with 35% of channels in each layer
base_model = MobileNetV2(
input_shape=(224, 224, 3),
@ -207,6 +214,11 @@ class ClassificationTrainingProcess(FrigateProcess):
subset="training",
)
total_images = train_gen.samples
logger.debug(
f"Training {self.model_name}: {total_images} images across {num_classes} classes"
)
# write labelmap
class_indices = train_gen.class_indices
index_to_class = {v: k for k, v in class_indices.items()}
@ -216,7 +228,9 @@ class ClassificationTrainingProcess(FrigateProcess):
f.write(f"{class_name}\n")
# train the model
logger.debug(f"Training {self.model_name} for {EPOCHS} epochs...")
model.fit(train_gen, epochs=EPOCHS, verbose=0)
logger.debug(f"Converting {self.model_name} to TFLite...")
# convert model to tflite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
@ -230,13 +244,28 @@ class ClassificationTrainingProcess(FrigateProcess):
tflite_model = converter.convert()
# write model
with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
model_path = os.path.join(model_dir, "model.tflite")
with open(model_path, "wb") as f:
f.write(tflite_model)
# verify model file was written successfully
if not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
logger.error(
f"Training failed for {self.model_name}: Model file was not created or is empty"
)
return False
# write training metadata with image count
dataset_image_count = get_dataset_image_count(self.model_name)
write_training_metadata(self.model_name, dataset_image_count)
logger.info(f"Finished training {self.model_name}")
return True
except Exception as e:
logger.error(f"Training failed for {self.model_name}: {e}", exc_info=True)
return False
def kickoff_model_training(
embeddingRequestor: EmbeddingsRequestor, model_name: str
@ -257,6 +286,10 @@ def kickoff_model_training(
training_process.start()
training_process.join()
# check if training succeeded by examining the exit code
training_success = training_process.exitcode == 0
if training_success:
# reload model and mark training as complete
embeddingRequestor.send_data(
EmbeddingsRequestEnum.reload_classification_model.value,
@ -269,6 +302,20 @@ def kickoff_model_training(
"state": ModelStatusTypesEnum.complete,
},
)
else:
logger.error(
f"Training subprocess failed for {model_name} (exit code: {training_process.exitcode})"
)
# mark training as complete (not failed) so UI doesn't stay in training state
# but don't reload the model since it failed
requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": model_name,
"state": ModelStatusTypesEnum.complete,
},
)
requestor.stop()