diff --git a/frigate/api/classification.py b/frigate/api/classification.py
index b3194176e..87de52884 100644
--- a/frigate/api/classification.py
+++ b/frigate/api/classification.py
@@ -37,6 +37,8 @@ from frigate.models import Event
from frigate.util.classification import (
collect_object_classification_examples,
collect_state_classification_examples,
+ get_dataset_image_count,
+ read_training_metadata,
)
from frigate.util.file import get_event_snapshot
@@ -564,23 +566,54 @@ def get_classification_dataset(name: str):
dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")
if not os.path.exists(dataset_dir):
- return JSONResponse(status_code=200, content={})
+ return JSONResponse(
+ status_code=200, content={"categories": {}, "training_metadata": None}
+ )
- for name in os.listdir(dataset_dir):
- category_dir = os.path.join(dataset_dir, name)
+ for category_name in os.listdir(dataset_dir):
+ category_dir = os.path.join(dataset_dir, category_name)
if not os.path.isdir(category_dir):
continue
- dataset_dict[name] = []
+ dataset_dict[category_name] = []
for file in filter(
lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
os.listdir(category_dir),
):
- dataset_dict[name].append(file)
+ dataset_dict[category_name].append(file)
- return JSONResponse(status_code=200, content=dataset_dict)
+ # Get training metadata
+ metadata = read_training_metadata(sanitize_filename(name))
+ current_image_count = get_dataset_image_count(sanitize_filename(name))
+
+ if metadata is None:
+ training_metadata = {
+ "has_trained": False,
+ "last_training_date": None,
+ "last_training_image_count": 0,
+ "current_image_count": current_image_count,
+ "new_images_count": current_image_count,
+ }
+ else:
+ last_training_count = metadata.get("last_training_image_count", 0)
+ new_images_count = max(0, current_image_count - last_training_count)
+ training_metadata = {
+ "has_trained": True,
+ "last_training_date": metadata.get("last_training_date"),
+ "last_training_image_count": last_training_count,
+ "current_image_count": current_image_count,
+ "new_images_count": new_images_count,
+ }
+
+ return JSONResponse(
+ status_code=200,
+ content={
+ "categories": dataset_dict,
+ "training_metadata": training_metadata,
+ },
+ )
@router.get(
diff --git a/frigate/types.py b/frigate/types.py
index a9e27ba90..f342d27cd 100644
--- a/frigate/types.py
+++ b/frigate/types.py
@@ -23,6 +23,7 @@ class ModelStatusTypesEnum(str, Enum):
error = "error"
training = "training"
complete = "complete"
+ failed = "failed"
class TrackedObjectUpdateTypesEnum(str, Enum):
diff --git a/frigate/util/classification.py b/frigate/util/classification.py
index 43dfd7fd7..a74094c32 100644
--- a/frigate/util/classification.py
+++ b/frigate/util/classification.py
@@ -1,5 +1,7 @@
"""Util for classification models."""
+import datetime
+import json
import logging
import os
import random
@@ -27,10 +29,96 @@ from frigate.util.process import FrigateProcess
BATCH_SIZE = 16
EPOCHS = 50
LEARNING_RATE = 0.001
+TRAINING_METADATA_FILE = ".training_metadata.json"
logger = logging.getLogger(__name__)
+def write_training_metadata(model_name: str, image_count: int) -> None:
+ """
+ Write training metadata to a hidden file in the model's clips directory.
+
+ Args:
+ model_name: Name of the classification model
+ image_count: Number of images used in training
+ """
+ clips_model_dir = os.path.join(CLIPS_DIR, model_name)
+ os.makedirs(clips_model_dir, exist_ok=True)
+
+ metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
+ metadata = {
+ "last_training_date": datetime.datetime.now().isoformat(),
+ "last_training_image_count": image_count,
+ }
+
+ try:
+ with open(metadata_path, "w") as f:
+ json.dump(metadata, f, indent=2)
+ logger.info(f"Wrote training metadata for {model_name}: {image_count} images")
+ except Exception as e:
+ logger.error(f"Failed to write training metadata for {model_name}: {e}")
+
+
+def read_training_metadata(model_name: str) -> dict[str, any] | None:
+ """
+ Read training metadata from the hidden file in the model's clips directory.
+
+ Args:
+ model_name: Name of the classification model
+
+ Returns:
+ Dictionary with last_training_date and last_training_image_count, or None if not found
+ """
+ clips_model_dir = os.path.join(CLIPS_DIR, model_name)
+ metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
+
+ if not os.path.exists(metadata_path):
+ return None
+
+ try:
+ with open(metadata_path, "r") as f:
+ metadata = json.load(f)
+ return metadata
+ except Exception as e:
+ logger.error(f"Failed to read training metadata for {model_name}: {e}")
+ return None
+
+
+def get_dataset_image_count(model_name: str) -> int:
+ """
+ Count the total number of images in the model's dataset directory.
+
+ Args:
+ model_name: Name of the classification model
+
+ Returns:
+ Total count of images across all categories
+ """
+ dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
+
+ if not os.path.exists(dataset_dir):
+ return 0
+
+ total_count = 0
+ try:
+ for category in os.listdir(dataset_dir):
+ category_dir = os.path.join(dataset_dir, category)
+ if not os.path.isdir(category_dir):
+ continue
+
+ image_files = [
+ f
+ for f in os.listdir(category_dir)
+ if f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))
+ ]
+ total_count += len(image_files)
+ except Exception as e:
+ logger.error(f"Failed to count dataset images for {model_name}: {e}")
+ return 0
+
+ return total_count
+
+
class ClassificationTrainingProcess(FrigateProcess):
def __init__(self, model_name: str) -> None:
super().__init__(
@@ -42,7 +130,8 @@ class ClassificationTrainingProcess(FrigateProcess):
def run(self) -> None:
self.pre_run_setup()
- self.__train_classification_model()
+ success = self.__train_classification_model()
+ exit(0 if success else 1)
def __generate_representative_dataset_factory(self, dataset_dir: str):
def generate_representative_dataset():
@@ -65,85 +154,117 @@ class ClassificationTrainingProcess(FrigateProcess):
@redirect_output_to_logger(logger, logging.DEBUG)
def __train_classification_model(self) -> bool:
"""Train a classification model."""
+ try:
+ # import in the function so that tensorflow is not initialized multiple times
+ import tensorflow as tf
+ from tensorflow.keras import layers, models, optimizers
+ from tensorflow.keras.applications import MobileNetV2
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
- # import in the function so that tensorflow is not initialized multiple times
- import tensorflow as tf
- from tensorflow.keras import layers, models, optimizers
- from tensorflow.keras.applications import MobileNetV2
- from tensorflow.keras.preprocessing.image import ImageDataGenerator
+ dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
+ model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
+ os.makedirs(model_dir, exist_ok=True)
- logger.info(f"Kicking off classification training for {self.model_name}.")
- dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
- model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
- os.makedirs(model_dir, exist_ok=True)
- num_classes = len(
- [
- d
- for d in os.listdir(dataset_dir)
- if os.path.isdir(os.path.join(dataset_dir, d))
- ]
- )
+ num_classes = len(
+ [
+ d
+ for d in os.listdir(dataset_dir)
+ if os.path.isdir(os.path.join(dataset_dir, d))
+ ]
+ )
- # Start with imagenet base model with 35% of channels in each layer
- base_model = MobileNetV2(
- input_shape=(224, 224, 3),
- include_top=False,
- weights="imagenet",
- alpha=0.35,
- )
- base_model.trainable = False # Freeze pre-trained layers
+ if num_classes < 2:
+ logger.error(
+ f"Training failed for {self.model_name}: Need at least 2 classes, found {num_classes}"
+ )
+ return False
- model = models.Sequential(
- [
- base_model,
- layers.GlobalAveragePooling2D(),
- layers.Dense(128, activation="relu"),
- layers.Dropout(0.3),
- layers.Dense(num_classes, activation="softmax"),
- ]
- )
+ # Start with imagenet base model with 35% of channels in each layer
+ base_model = MobileNetV2(
+ input_shape=(224, 224, 3),
+ include_top=False,
+ weights="imagenet",
+ alpha=0.35,
+ )
+ base_model.trainable = False # Freeze pre-trained layers
- model.compile(
- optimizer=optimizers.Adam(learning_rate=LEARNING_RATE),
- loss="categorical_crossentropy",
- metrics=["accuracy"],
- )
+ model = models.Sequential(
+ [
+ base_model,
+ layers.GlobalAveragePooling2D(),
+ layers.Dense(128, activation="relu"),
+ layers.Dropout(0.3),
+ layers.Dense(num_classes, activation="softmax"),
+ ]
+ )
- # create training set
- datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2)
- train_gen = datagen.flow_from_directory(
- dataset_dir,
- target_size=(224, 224),
- batch_size=BATCH_SIZE,
- class_mode="categorical",
- subset="training",
- )
+ model.compile(
+ optimizer=optimizers.Adam(learning_rate=LEARNING_RATE),
+ loss="categorical_crossentropy",
+ metrics=["accuracy"],
+ )
- # write labelmap
- class_indices = train_gen.class_indices
- index_to_class = {v: k for k, v in class_indices.items()}
- sorted_classes = [index_to_class[i] for i in range(len(index_to_class))]
- with open(os.path.join(model_dir, "labelmap.txt"), "w") as f:
- for class_name in sorted_classes:
- f.write(f"{class_name}\n")
+ # create training set
+ datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2)
+ train_gen = datagen.flow_from_directory(
+ dataset_dir,
+ target_size=(224, 224),
+ batch_size=BATCH_SIZE,
+ class_mode="categorical",
+ subset="training",
+ )
- # train the model
- model.fit(train_gen, epochs=EPOCHS, verbose=0)
+ total_images = train_gen.samples
+ logger.debug(
+ f"Training {self.model_name}: {total_images} images across {num_classes} classes"
+ )
- # convert model to tflite
- converter = tf.lite.TFLiteConverter.from_keras_model(model)
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = (
- self.__generate_representative_dataset_factory(dataset_dir)
- )
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
- converter.inference_input_type = tf.uint8
- converter.inference_output_type = tf.uint8
- tflite_model = converter.convert()
+ # write labelmap
+ class_indices = train_gen.class_indices
+ index_to_class = {v: k for k, v in class_indices.items()}
+ sorted_classes = [index_to_class[i] for i in range(len(index_to_class))]
+ with open(os.path.join(model_dir, "labelmap.txt"), "w") as f:
+ for class_name in sorted_classes:
+ f.write(f"{class_name}\n")
- # write model
- with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
- f.write(tflite_model)
+ # train the model
+ logger.debug(f"Training {self.model_name} for {EPOCHS} epochs...")
+ model.fit(train_gen, epochs=EPOCHS, verbose=0)
+ logger.debug(f"Converting {self.model_name} to TFLite...")
+
+ # convert model to tflite
+ converter = tf.lite.TFLiteConverter.from_keras_model(model)
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = (
+ self.__generate_representative_dataset_factory(dataset_dir)
+ )
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.uint8
+ converter.inference_output_type = tf.uint8
+ tflite_model = converter.convert()
+
+ # write model
+ model_path = os.path.join(model_dir, "model.tflite")
+ with open(model_path, "wb") as f:
+ f.write(tflite_model)
+
+ # verify model file was written successfully
+ if not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
+ logger.error(
+ f"Training failed for {self.model_name}: Model file was not created or is empty"
+ )
+ return False
+
+ # write training metadata with image count
+ dataset_image_count = get_dataset_image_count(self.model_name)
+ write_training_metadata(self.model_name, dataset_image_count)
+
+ logger.info(f"Finished training {self.model_name}")
+ return True
+
+ except Exception as e:
+ logger.error(f"Training failed for {self.model_name}: {e}", exc_info=True)
+ return False
def kickoff_model_training(
@@ -165,18 +286,36 @@ def kickoff_model_training(
training_process.start()
training_process.join()
- # reload model and mark training as complete
- embeddingRequestor.send_data(
- EmbeddingsRequestEnum.reload_classification_model.value,
- {"model_name": model_name},
- )
- requestor.send_data(
- UPDATE_MODEL_STATE,
- {
- "model": model_name,
- "state": ModelStatusTypesEnum.complete,
- },
- )
+ # check if training succeeded by examining the exit code
+ training_success = training_process.exitcode == 0
+
+ if training_success:
+ # reload model and mark training as complete
+ embeddingRequestor.send_data(
+ EmbeddingsRequestEnum.reload_classification_model.value,
+ {"model_name": model_name},
+ )
+ requestor.send_data(
+ UPDATE_MODEL_STATE,
+ {
+ "model": model_name,
+ "state": ModelStatusTypesEnum.complete,
+ },
+ )
+ else:
+ logger.error(
+ f"Training subprocess failed for {model_name} (exit code: {training_process.exitcode})"
+ )
+ # mark training as failed so UI shows error state
+ # don't reload the model since it failed
+ requestor.send_data(
+ UPDATE_MODEL_STATE,
+ {
+ "model": model_name,
+ "state": ModelStatusTypesEnum.failed,
+ },
+ )
+
requestor.stop()
diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json
index 65118f227..2bae0c0ce 100644
--- a/web/public/locales/en/views/classificationModel.json
+++ b/web/public/locales/en/views/classificationModel.json
@@ -13,6 +13,11 @@
"deleteModels": "Delete Models",
"editModel": "Edit Model"
},
+ "tooltip": {
+ "trainingInProgress": "Model is currently training",
+ "noNewImages": "No new images to train. Classify more images in the dataset first.",
+ "modelNotReady": "Model is not ready for training"
+ },
"toast": {
"success": {
"deletedCategory": "Deleted Class",
@@ -30,7 +35,8 @@
"deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
"deleteModelFailed": "Failed to delete model: {{errorMessage}}",
"categorizeFailed": "Failed to categorize image: {{errorMessage}}",
- "trainingFailed": "Failed to start model training: {{errorMessage}}",
+ "trainingFailed": "Model training failed. Check Frigate logs for details.",
+ "trainingFailedToStart": "Failed to start model training: {{errorMessage}}",
"updateModelFailed": "Failed to update model: {{errorMessage}}",
"renameCategoryFailed": "Failed to rename class: {{errorMessage}}"
}
@@ -143,6 +149,8 @@
"step3": {
"selectImagesPrompt": "Select all images with: {{className}}",
"selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.",
+ "allImagesRequired_one": "Please classify all images. {{count}} image remaining.",
+ "allImagesRequired_other": "Please classify all images. {{count}} images remaining.",
"generating": {
"title": "Generating Sample Images",
"description": "Frigate is pulling representative images from your recordings. This may take a moment..."
diff --git a/web/src/components/classification/wizard/Step3ChooseExamples.tsx b/web/src/components/classification/wizard/Step3ChooseExamples.tsx
index 68da03eaf..f638c01e3 100644
--- a/web/src/components/classification/wizard/Step3ChooseExamples.tsx
+++ b/web/src/components/classification/wizard/Step3ChooseExamples.tsx
@@ -10,6 +10,12 @@ import useSWR from "swr";
import { baseUrl } from "@/api/baseUrl";
import { isMobile } from "react-device-detect";
import { cn } from "@/lib/utils";
+import {
+ Tooltip,
+ TooltipContent,
+ TooltipTrigger,
+} from "@/components/ui/tooltip";
+import { TooltipPortal } from "@radix-ui/react-tooltip";
export type Step3FormData = {
examplesGenerated: boolean;
@@ -317,6 +323,19 @@ export default function Step3ChooseExamples({
return unclassifiedImages.length === 0;
}, [unclassifiedImages]);
+ // For state models on the last class, require all images to be classified
+ const isLastClass = currentClassIndex === allClasses.length - 1;
+ const canProceed = useMemo(() => {
+ if (
+ step1Data.modelType === "state" &&
+ isLastClass &&
+ !allImagesClassified
+ ) {
+ return false;
+ }
+ return true;
+ }, [step1Data.modelType, isLastClass, allImagesClassified]);
+
const handleBack = useCallback(() => {
if (currentClassIndex > 0) {
const previousClass = allClasses[currentClassIndex - 1];
@@ -438,20 +457,35 @@ export default function Step3ChooseExamples({
-
+
+
+
+
+ {!canProceed && (
+
+
+ {t("wizard.step3.allImagesRequired", {
+ count: unclassifiedImages.length,
+ })}
+
+
+ )}
+
)}
diff --git a/web/src/types/ws.ts b/web/src/types/ws.ts
index 1120aec67..1d98b7b01 100644
--- a/web/src/types/ws.ts
+++ b/web/src/types/ws.ts
@@ -87,7 +87,8 @@ export type ModelState =
| "downloaded"
| "error"
| "training"
- | "complete";
+ | "complete"
+ | "failed";
export type EmbeddingsReindexProgressType = {
thumbnails: number;
diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx
index 9e9025691..6a3e680f9 100644
--- a/web/src/views/classification/ModelTrainingView.tsx
+++ b/web/src/views/classification/ModelTrainingView.tsx
@@ -102,6 +102,12 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
position: "top-center",
});
setWasTraining(false);
+ refreshDataset();
+ } else if (modelState == "failed") {
+ toast.error(t("toast.error.trainingFailed"), {
+ position: "top-center",
+ });
+ setWasTraining(false);
}
// only refresh when modelState changes
// eslint-disable-next-line react-hooks/exhaustive-deps
@@ -112,10 +118,20 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
const { data: trainImages, mutate: refreshTrain } = useSWR(
`classification/${model.name}/train`,
);
- const { data: dataset, mutate: refreshDataset } = useSWR<{
- [id: string]: string[];
+ const { data: datasetResponse, mutate: refreshDataset } = useSWR<{
+ categories: { [id: string]: string[] };
+ training_metadata: {
+ has_trained: boolean;
+ last_training_date: string | null;
+ last_training_image_count: number;
+ current_image_count: number;
+ new_images_count: number;
+ } | null;
}>(`classification/${model.name}/dataset`);
+ const dataset = datasetResponse?.categories || {};
+ const trainingMetadata = datasetResponse?.training_metadata;
+
const [trainFilter, setTrainFilter] = useApiFilter();
const refreshAll = useCallback(() => {
@@ -177,7 +193,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
error.response?.data?.detail ||
"Unknown error";
- toast.error(t("toast.error.trainingFailed", { errorMessage }), {
+ toast.error(t("toast.error.trainingFailedToStart", { errorMessage }), {
position: "top-center",
});
});
@@ -421,19 +437,48 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
filterValues={{ classes: Object.keys(dataset || {}) }}
onUpdateFilter={setTrainFilter}
/>
-
+
)}