Add error handling for training

This commit is contained in:
Nicolas Mowen 2025-11-09 13:34:39 -07:00
parent 292d024aac
commit 6c47a131e4

View File

@ -130,7 +130,8 @@ class ClassificationTrainingProcess(FrigateProcess):
def run(self) -> None:
    """Run training once and report the outcome through the exit code.

    Terminates the process with status 0 when training succeeds and 1 when
    it fails, so whoever joins this process can read the outcome from its
    exit code.
    """
    # function-scope import: ``sys.exit`` is always available, whereas the
    # ``exit`` builtin is injected by the ``site`` module and can be missing
    import sys

    self.pre_run_setup()
    success = self.__train_classification_model()
    sys.exit(0 if success else 1)
def __generate_representative_dataset_factory(self, dataset_dir: str):
def generate_representative_dataset():
@ -153,89 +154,117 @@ class ClassificationTrainingProcess(FrigateProcess):
@redirect_output_to_logger(logger, logging.DEBUG)
def __train_classification_model(self) -> bool:
    """Train a classification model and export it as a TFLite file.

    Builds a classifier on top of a frozen MobileNetV2 base from the images
    in the model's dataset directory, trains it, converts it to a TFLite
    model restricted to int8 builtin ops with uint8 input/output, and then
    writes the model file, its labelmap and the training metadata.

    Returns:
        True when the model file was written and the metadata recorded.
        False when the dataset has fewer than 2 classes, when the written
        model file is missing or empty, or when any exception is raised
        along the way (the exception is logged, not propagated).
    """
    try:
        # import in the function so that tensorflow is not initialized multiple times
        import tensorflow as tf
        from tensorflow.keras import layers, models, optimizers
        from tensorflow.keras.applications import MobileNetV2
        from tensorflow.keras.preprocessing.image import ImageDataGenerator

        logger.info(f"Kicking off classification training for {self.model_name}.")
        dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
        model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
        os.makedirs(model_dir, exist_ok=True)

        # every sub-directory of the dataset directory counts as one class
        num_classes = sum(
            1
            for entry in os.listdir(dataset_dir)
            if os.path.isdir(os.path.join(dataset_dir, entry))
        )

        if num_classes < 2:
            logger.error(
                f"Training failed for {self.model_name}: Need at least 2 classes, found {num_classes}"
            )
            return False

        # Start with imagenet base model with 35% of channels in each layer
        base_model = MobileNetV2(
            input_shape=(224, 224, 3),
            include_top=False,
            weights="imagenet",
            alpha=0.35,
        )
        base_model.trainable = False  # Freeze pre-trained layers

        model = models.Sequential(
            [
                base_model,
                layers.GlobalAveragePooling2D(),
                layers.Dense(128, activation="relu"),
                layers.Dropout(0.3),
                layers.Dense(num_classes, activation="softmax"),
            ]
        )

        model.compile(
            optimizer=optimizers.Adam(learning_rate=LEARNING_RATE),
            loss="categorical_crossentropy",
            metrics=["accuracy"],
        )

        # create training set
        datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2)
        train_gen = datagen.flow_from_directory(
            dataset_dir,
            target_size=(224, 224),
            batch_size=BATCH_SIZE,
            class_mode="categorical",
            subset="training",
        )

        total_images = train_gen.samples
        logger.debug(
            f"Training {self.model_name}: {total_images} images across {num_classes} classes"
        )

        # train the model
        logger.debug(f"Training {self.model_name} for {EPOCHS} epochs...")
        model.fit(train_gen, epochs=EPOCHS, verbose=0)
        logger.debug(f"Converting {self.model_name} to TFLite...")

        # convert model to tflite
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = (
            self.__generate_representative_dataset_factory(dataset_dir)
        )
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.uint8
        tflite_model = converter.convert()

        # write model
        model_path = os.path.join(model_dir, "model.tflite")
        with open(model_path, "wb") as f:
            f.write(tflite_model)

        # verify model file was written successfully
        if not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
            logger.error(
                f"Training failed for {self.model_name}: Model file was not created or is empty"
            )
            return False

        # write labelmap only once the model is on disk, so a failed
        # training or conversion does not leave a new labelmap next to a
        # previously written model; one class name per line, ordered by
        # the class index assigned by the generator
        class_indices = train_gen.class_indices
        sorted_classes = sorted(class_indices, key=class_indices.get)
        with open(os.path.join(model_dir, "labelmap.txt"), "w") as f:
            f.writelines(f"{class_name}\n" for class_name in sorted_classes)

        # write training metadata with image count
        dataset_image_count = get_dataset_image_count(self.model_name)
        write_training_metadata(self.model_name, dataset_image_count)

        logger.info(f"Finished training {self.model_name}")
        return True
    except Exception as e:
        # top-level boundary of the training run: log with traceback and
        # report failure through the return value instead of raising
        logger.exception(f"Training failed for {self.model_name}: {e}")
        return False
def kickoff_model_training(
@ -257,18 +286,36 @@ def kickoff_model_training(
training_process.start()
training_process.join()

# check if training succeeded by examining the exit code
training_success = training_process.exitcode == 0

if training_success:
    # reload the model only when training produced a new one
    embeddingRequestor.send_data(
        EmbeddingsRequestEnum.reload_classification_model.value,
        {"model_name": model_name},
    )
else:
    logger.error(
        f"Training subprocess failed for {model_name} (exit code: {training_process.exitcode})"
    )

# mark training as complete in both cases (not failed) so the UI doesn't
# stay in the training state; on failure the model is simply not reloaded
requestor.send_data(
    UPDATE_MODEL_STATE,
    {
        "model": model_name,
        "state": ModelStatusTypesEnum.complete,
    },
)
requestor.stop()