Improve classification (#20863)
Some checks are pending
CI / AMD64 Build (push) Waiting to run
CI / ARM Build (push) Waiting to run
CI / Jetson Jetpack 6 (push) Waiting to run
CI / AMD64 Extra Build (push) Blocked by required conditions
CI / ARM Extra Build (push) Blocked by required conditions
CI / Synaptics Build (push) Blocked by required conditions
CI / Assemble and push default build (push) Blocked by required conditions

This commit is contained in:
Nicolas Mowen 2025-11-09 15:21:13 -07:00 committed by GitHub
parent a374a60756
commit 99a363c047
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 380 additions and 119 deletions

View File

@ -37,6 +37,8 @@ from frigate.models import Event
from frigate.util.classification import ( from frigate.util.classification import (
collect_object_classification_examples, collect_object_classification_examples,
collect_state_classification_examples, collect_state_classification_examples,
get_dataset_image_count,
read_training_metadata,
) )
from frigate.util.file import get_event_snapshot from frigate.util.file import get_event_snapshot
@ -564,23 +566,54 @@ def get_classification_dataset(name: str):
dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset") dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")
if not os.path.exists(dataset_dir): if not os.path.exists(dataset_dir):
return JSONResponse(status_code=200, content={}) return JSONResponse(
status_code=200, content={"categories": {}, "training_metadata": None}
)
for name in os.listdir(dataset_dir): for category_name in os.listdir(dataset_dir):
category_dir = os.path.join(dataset_dir, name) category_dir = os.path.join(dataset_dir, category_name)
if not os.path.isdir(category_dir): if not os.path.isdir(category_dir):
continue continue
dataset_dict[name] = [] dataset_dict[category_name] = []
for file in filter( for file in filter(
lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))), lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
os.listdir(category_dir), os.listdir(category_dir),
): ):
dataset_dict[name].append(file) dataset_dict[category_name].append(file)
return JSONResponse(status_code=200, content=dataset_dict) # Get training metadata
metadata = read_training_metadata(sanitize_filename(name))
current_image_count = get_dataset_image_count(sanitize_filename(name))
if metadata is None:
training_metadata = {
"has_trained": False,
"last_training_date": None,
"last_training_image_count": 0,
"current_image_count": current_image_count,
"new_images_count": current_image_count,
}
else:
last_training_count = metadata.get("last_training_image_count", 0)
new_images_count = max(0, current_image_count - last_training_count)
training_metadata = {
"has_trained": True,
"last_training_date": metadata.get("last_training_date"),
"last_training_image_count": last_training_count,
"current_image_count": current_image_count,
"new_images_count": new_images_count,
}
return JSONResponse(
status_code=200,
content={
"categories": dataset_dict,
"training_metadata": training_metadata,
},
)
@router.get( @router.get(

View File

@ -23,6 +23,7 @@ class ModelStatusTypesEnum(str, Enum):
error = "error" error = "error"
training = "training" training = "training"
complete = "complete" complete = "complete"
failed = "failed"
class TrackedObjectUpdateTypesEnum(str, Enum): class TrackedObjectUpdateTypesEnum(str, Enum):

View File

@ -1,5 +1,7 @@
"""Util for classification models.""" """Util for classification models."""
import datetime
import json
import logging
import os
import random
from typing import Any
@ -27,10 +29,96 @@ from frigate.util.process import FrigateProcess
BATCH_SIZE = 16 BATCH_SIZE = 16
EPOCHS = 50 EPOCHS = 50
LEARNING_RATE = 0.001 LEARNING_RATE = 0.001
TRAINING_METADATA_FILE = ".training_metadata.json"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def write_training_metadata(model_name: str, image_count: int) -> None:
    """Record when a model was last trained and on how many images.

    The metadata is stored as JSON in the hidden ``TRAINING_METADATA_FILE``
    inside the model's directory under ``CLIPS_DIR``. Writing is best-effort:
    any filesystem or serialization failure is logged and swallowed so that a
    metadata problem never propagates to the caller.

    Args:
        model_name: Name of the classification model.
        image_count: Number of images used in training.
    """
    clips_model_dir = os.path.join(CLIPS_DIR, model_name)
    metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
    # Stage the payload next to the target so the final rename stays on one
    # filesystem; readers then see either the old file or the complete new one.
    temp_path = f"{metadata_path}.tmp"

    metadata = {
        "last_training_date": datetime.datetime.now().isoformat(),
        "last_training_image_count": image_count,
    }

    try:
        # Directory creation is part of the best-effort write, so it lives
        # inside the try block together with the file operations.
        os.makedirs(clips_model_dir, exist_ok=True)

        with open(temp_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        os.replace(temp_path, metadata_path)
    except (OSError, TypeError, ValueError) as e:
        logger.error(f"Failed to write training metadata for {model_name}: {e}")

        # Do not leave a half-written staging file behind.
        try:
            os.remove(temp_path)
        except OSError:
            pass

        return

    logger.info(f"Wrote training metadata for {model_name}: {image_count} images")
def read_training_metadata(model_name: str) -> dict[str, any] | None:
"""
Read training metadata from the hidden file in the model's clips directory.
Args:
model_name: Name of the classification model
Returns:
Dictionary with last_training_date and last_training_image_count, or None if not found
"""
clips_model_dir = os.path.join(CLIPS_DIR, model_name)
metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
if not os.path.exists(metadata_path):
return None
try:
with open(metadata_path, "r") as f:
metadata = json.load(f)
return metadata
except Exception as e:
logger.error(f"Failed to read training metadata for {model_name}: {e}")
return None
def get_dataset_image_count(
    model_name: str,
    *,
    extensions: tuple[str, ...] = (".webp", ".png", ".jpg", ".jpeg"),
) -> int:
    """Count the images in a model's dataset directory.

    Every immediate subdirectory of ``CLIPS_DIR/<model_name>/dataset`` is
    treated as a category; entries directly inside the dataset directory that
    are not directories are ignored.

    Args:
        model_name: Name of the classification model.
        extensions: File name suffixes that count as images. Matching is
            case-insensitive.

    Returns:
        Total number of matching files across all categories, or 0 when the
        dataset directory does not exist or cannot be listed.
    """
    dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")

    if not os.path.exists(dataset_dir):
        return 0

    # File names are lower-cased before comparison, so normalize the suffixes too.
    suffixes = tuple(ext.lower() for ext in extensions)
    total_count = 0

    try:
        for category in os.listdir(dataset_dir):
            category_dir = os.path.join(dataset_dir, category)

            if not os.path.isdir(category_dir):
                continue

            total_count += sum(
                1
                for file_name in os.listdir(category_dir)
                if file_name.lower().endswith(suffixes)
            )
    except OSError as e:
        logger.error(f"Failed to count dataset images for {model_name}: {e}")
        return 0

    return total_count
class ClassificationTrainingProcess(FrigateProcess): class ClassificationTrainingProcess(FrigateProcess):
def __init__(self, model_name: str) -> None: def __init__(self, model_name: str) -> None:
super().__init__( super().__init__(
@ -42,7 +130,8 @@ class ClassificationTrainingProcess(FrigateProcess):
def run(self) -> None: def run(self) -> None:
self.pre_run_setup() self.pre_run_setup()
self.__train_classification_model() success = self.__train_classification_model()
exit(0 if success else 1)
def __generate_representative_dataset_factory(self, dataset_dir: str): def __generate_representative_dataset_factory(self, dataset_dir: str):
def generate_representative_dataset(): def generate_representative_dataset():
@ -65,85 +154,117 @@ class ClassificationTrainingProcess(FrigateProcess):
@redirect_output_to_logger(logger, logging.DEBUG) @redirect_output_to_logger(logger, logging.DEBUG)
def __train_classification_model(self) -> bool: def __train_classification_model(self) -> bool:
"""Train a classification model.""" """Train a classification model."""
try:
# import in the function so that tensorflow is not initialized multiple times
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# import in the function so that tensorflow is not initialized multiple times dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
import tensorflow as tf model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
from tensorflow.keras import layers, models, optimizers os.makedirs(model_dir, exist_ok=True)
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
logger.info(f"Kicking off classification training for {self.model_name}.") num_classes = len(
dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset") [
model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name) d
os.makedirs(model_dir, exist_ok=True) for d in os.listdir(dataset_dir)
num_classes = len( if os.path.isdir(os.path.join(dataset_dir, d))
[ ]
d )
for d in os.listdir(dataset_dir)
if os.path.isdir(os.path.join(dataset_dir, d))
]
)
# Start with imagenet base model with 35% of channels in each layer if num_classes < 2:
base_model = MobileNetV2( logger.error(
input_shape=(224, 224, 3), f"Training failed for {self.model_name}: Need at least 2 classes, found {num_classes}"
include_top=False, )
weights="imagenet", return False
alpha=0.35,
)
base_model.trainable = False # Freeze pre-trained layers
model = models.Sequential( # Start with imagenet base model with 35% of channels in each layer
[ base_model = MobileNetV2(
base_model, input_shape=(224, 224, 3),
layers.GlobalAveragePooling2D(), include_top=False,
layers.Dense(128, activation="relu"), weights="imagenet",
layers.Dropout(0.3), alpha=0.35,
layers.Dense(num_classes, activation="softmax"), )
] base_model.trainable = False # Freeze pre-trained layers
)
model.compile( model = models.Sequential(
optimizer=optimizers.Adam(learning_rate=LEARNING_RATE), [
loss="categorical_crossentropy", base_model,
metrics=["accuracy"], layers.GlobalAveragePooling2D(),
) layers.Dense(128, activation="relu"),
layers.Dropout(0.3),
layers.Dense(num_classes, activation="softmax"),
]
)
# create training set model.compile(
datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2) optimizer=optimizers.Adam(learning_rate=LEARNING_RATE),
train_gen = datagen.flow_from_directory( loss="categorical_crossentropy",
dataset_dir, metrics=["accuracy"],
target_size=(224, 224), )
batch_size=BATCH_SIZE,
class_mode="categorical",
subset="training",
)
# write labelmap # create training set
class_indices = train_gen.class_indices datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2)
index_to_class = {v: k for k, v in class_indices.items()} train_gen = datagen.flow_from_directory(
sorted_classes = [index_to_class[i] for i in range(len(index_to_class))] dataset_dir,
with open(os.path.join(model_dir, "labelmap.txt"), "w") as f: target_size=(224, 224),
for class_name in sorted_classes: batch_size=BATCH_SIZE,
f.write(f"{class_name}\n") class_mode="categorical",
subset="training",
)
# train the model total_images = train_gen.samples
model.fit(train_gen, epochs=EPOCHS, verbose=0) logger.debug(
f"Training {self.model_name}: {total_images} images across {num_classes} classes"
)
# convert model to tflite # write labelmap
converter = tf.lite.TFLiteConverter.from_keras_model(model) class_indices = train_gen.class_indices
converter.optimizations = [tf.lite.Optimize.DEFAULT] index_to_class = {v: k for k, v in class_indices.items()}
converter.representative_dataset = ( sorted_classes = [index_to_class[i] for i in range(len(index_to_class))]
self.__generate_representative_dataset_factory(dataset_dir) with open(os.path.join(model_dir, "labelmap.txt"), "w") as f:
) for class_name in sorted_classes:
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] f.write(f"{class_name}\n")
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()
# write model # train the model
with open(os.path.join(model_dir, "model.tflite"), "wb") as f: logger.debug(f"Training {self.model_name} for {EPOCHS} epochs...")
f.write(tflite_model) model.fit(train_gen, epochs=EPOCHS, verbose=0)
logger.debug(f"Converting {self.model_name} to TFLite...")
# convert model to tflite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = (
self.__generate_representative_dataset_factory(dataset_dir)
)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()
# write model
model_path = os.path.join(model_dir, "model.tflite")
with open(model_path, "wb") as f:
f.write(tflite_model)
# verify model file was written successfully
if not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
logger.error(
f"Training failed for {self.model_name}: Model file was not created or is empty"
)
return False
# write training metadata with image count
dataset_image_count = get_dataset_image_count(self.model_name)
write_training_metadata(self.model_name, dataset_image_count)
logger.info(f"Finished training {self.model_name}")
return True
except Exception as e:
logger.error(f"Training failed for {self.model_name}: {e}", exc_info=True)
return False
def kickoff_model_training( def kickoff_model_training(
@ -165,18 +286,36 @@ def kickoff_model_training(
training_process.start() training_process.start()
training_process.join() training_process.join()
# reload model and mark training as complete # check if training succeeded by examining the exit code
embeddingRequestor.send_data( training_success = training_process.exitcode == 0
EmbeddingsRequestEnum.reload_classification_model.value,
{"model_name": model_name}, if training_success:
) # reload model and mark training as complete
requestor.send_data( embeddingRequestor.send_data(
UPDATE_MODEL_STATE, EmbeddingsRequestEnum.reload_classification_model.value,
{ {"model_name": model_name},
"model": model_name, )
"state": ModelStatusTypesEnum.complete, requestor.send_data(
}, UPDATE_MODEL_STATE,
) {
"model": model_name,
"state": ModelStatusTypesEnum.complete,
},
)
else:
logger.error(
f"Training subprocess failed for {model_name} (exit code: {training_process.exitcode})"
)
# mark training as failed so UI shows error state
# don't reload the model since it failed
requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": model_name,
"state": ModelStatusTypesEnum.failed,
},
)
requestor.stop() requestor.stop()

View File

@ -13,6 +13,11 @@
"deleteModels": "Delete Models", "deleteModels": "Delete Models",
"editModel": "Edit Model" "editModel": "Edit Model"
}, },
"tooltip": {
"trainingInProgress": "Model is currently training",
"noNewImages": "No new images to train. Classify more images in the dataset first.",
"modelNotReady": "Model is not ready for training"
},
"toast": { "toast": {
"success": { "success": {
"deletedCategory": "Deleted Class", "deletedCategory": "Deleted Class",
@ -30,7 +35,8 @@
"deleteCategoryFailed": "Failed to delete class: {{errorMessage}}", "deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
"deleteModelFailed": "Failed to delete model: {{errorMessage}}", "deleteModelFailed": "Failed to delete model: {{errorMessage}}",
"categorizeFailed": "Failed to categorize image: {{errorMessage}}", "categorizeFailed": "Failed to categorize image: {{errorMessage}}",
"trainingFailed": "Failed to start model training: {{errorMessage}}", "trainingFailed": "Model training failed. Check Frigate logs for details.",
"trainingFailedToStart": "Failed to start model training: {{errorMessage}}",
"updateModelFailed": "Failed to update model: {{errorMessage}}", "updateModelFailed": "Failed to update model: {{errorMessage}}",
"renameCategoryFailed": "Failed to rename class: {{errorMessage}}" "renameCategoryFailed": "Failed to rename class: {{errorMessage}}"
} }
@ -143,6 +149,8 @@
"step3": { "step3": {
"selectImagesPrompt": "Select all images with: {{className}}", "selectImagesPrompt": "Select all images with: {{className}}",
"selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.", "selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.",
"allImagesRequired_one": "Please classify all images. {{count}} image remaining.",
"allImagesRequired_other": "Please classify all images. {{count}} images remaining.",
"generating": { "generating": {
"title": "Generating Sample Images", "title": "Generating Sample Images",
"description": "Frigate is pulling representative images from your recordings. This may take a moment..." "description": "Frigate is pulling representative images from your recordings. This may take a moment..."

View File

@ -10,6 +10,12 @@ import useSWR from "swr";
import { baseUrl } from "@/api/baseUrl"; import { baseUrl } from "@/api/baseUrl";
import { isMobile } from "react-device-detect"; import { isMobile } from "react-device-detect";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { TooltipPortal } from "@radix-ui/react-tooltip";
export type Step3FormData = { export type Step3FormData = {
examplesGenerated: boolean; examplesGenerated: boolean;
@ -317,6 +323,19 @@ export default function Step3ChooseExamples({
return unclassifiedImages.length === 0; return unclassifiedImages.length === 0;
}, [unclassifiedImages]); }, [unclassifiedImages]);
// For state models on the last class, require all images to be classified
const isLastClass = currentClassIndex === allClasses.length - 1;
const canProceed = useMemo(() => {
if (
step1Data.modelType === "state" &&
isLastClass &&
!allImagesClassified
) {
return false;
}
return true;
}, [step1Data.modelType, isLastClass, allImagesClassified]);
const handleBack = useCallback(() => { const handleBack = useCallback(() => {
if (currentClassIndex > 0) { if (currentClassIndex > 0) {
const previousClass = allClasses[currentClassIndex - 1]; const previousClass = allClasses[currentClassIndex - 1];
@ -438,20 +457,35 @@ export default function Step3ChooseExamples({
<Button type="button" onClick={handleBack} className="sm:flex-1"> <Button type="button" onClick={handleBack} className="sm:flex-1">
{t("button.back", { ns: "common" })} {t("button.back", { ns: "common" })}
</Button> </Button>
<Button <Tooltip>
type="button" <TooltipTrigger asChild>
onClick={ <Button
allImagesClassified type="button"
? handleContinue onClick={
: handleContinueClassification allImagesClassified
} ? handleContinue
variant="select" : handleContinueClassification
className="flex items-center justify-center gap-2 sm:flex-1" }
disabled={!hasGenerated || isGenerating || isProcessing} variant="select"
> className="flex items-center justify-center gap-2 sm:flex-1"
{isProcessing && <ActivityIndicator className="size-4" />} disabled={
{t("button.continue", { ns: "common" })} !hasGenerated || isGenerating || isProcessing || !canProceed
</Button> }
>
{isProcessing && <ActivityIndicator className="size-4" />}
{t("button.continue", { ns: "common" })}
</Button>
</TooltipTrigger>
{!canProceed && (
<TooltipPortal>
<TooltipContent>
{t("wizard.step3.allImagesRequired", {
count: unclassifiedImages.length,
})}
</TooltipContent>
</TooltipPortal>
)}
</Tooltip>
</div> </div>
)} )}
</div> </div>

View File

@ -87,7 +87,8 @@ export type ModelState =
| "downloaded" | "downloaded"
| "error" | "error"
| "training" | "training"
| "complete"; | "complete"
| "failed";
export type EmbeddingsReindexProgressType = { export type EmbeddingsReindexProgressType = {
thumbnails: number; thumbnails: number;

View File

@ -102,6 +102,12 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
position: "top-center", position: "top-center",
}); });
setWasTraining(false); setWasTraining(false);
refreshDataset();
} else if (modelState == "failed") {
toast.error(t("toast.error.trainingFailed"), {
position: "top-center",
});
setWasTraining(false);
} }
// only refresh when modelState changes // only refresh when modelState changes
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
@ -112,10 +118,20 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>( const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
`classification/${model.name}/train`, `classification/${model.name}/train`,
); );
const { data: dataset, mutate: refreshDataset } = useSWR<{ const { data: datasetResponse, mutate: refreshDataset } = useSWR<{
[id: string]: string[]; categories: { [id: string]: string[] };
training_metadata: {
has_trained: boolean;
last_training_date: string | null;
last_training_image_count: number;
current_image_count: number;
new_images_count: number;
} | null;
}>(`classification/${model.name}/dataset`); }>(`classification/${model.name}/dataset`);
const dataset = datasetResponse?.categories || {};
const trainingMetadata = datasetResponse?.training_metadata;
const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>(); const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();
const refreshAll = useCallback(() => { const refreshAll = useCallback(() => {
@ -177,7 +193,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
error.response?.data?.detail || error.response?.data?.detail ||
"Unknown error"; "Unknown error";
toast.error(t("toast.error.trainingFailed", { errorMessage }), { toast.error(t("toast.error.trainingFailedToStart", { errorMessage }), {
position: "top-center", position: "top-center",
}); });
}); });
@ -421,19 +437,48 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
filterValues={{ classes: Object.keys(dataset || {}) }} filterValues={{ classes: Object.keys(dataset || {}) }}
onUpdateFilter={setTrainFilter} onUpdateFilter={setTrainFilter}
/> />
<Button <Tooltip>
className="flex justify-center gap-2" <TooltipTrigger asChild>
onClick={trainModel} <Button
variant="select" className="flex justify-center gap-2"
disabled={modelState != "complete"} onClick={trainModel}
> variant={modelState == "failed" ? "destructive" : "select"}
{modelState == "training" ? ( disabled={
<ActivityIndicator size={20} /> (modelState != "complete" && modelState != "failed") ||
) : ( (trainingMetadata?.new_images_count ?? 0) === 0
<HiSparkles className="text-white" /> }
>
{modelState == "training" ? (
<ActivityIndicator size={20} />
) : (
<HiSparkles className="text-white" />
)}
{isDesktop && (
<>
{t("button.trainModel")}
{trainingMetadata?.new_images_count !== undefined &&
trainingMetadata.new_images_count > 0 && (
<span className="text-sm text-selected-foreground">
({trainingMetadata.new_images_count})
</span>
)}
</>
)}
</Button>
</TooltipTrigger>
{((trainingMetadata?.new_images_count ?? 0) === 0 ||
(modelState != "complete" && modelState != "failed")) && (
<TooltipPortal>
<TooltipContent>
{modelState == "training"
? t("tooltip.trainingInProgress")
: trainingMetadata?.new_images_count === 0
? t("tooltip.noNewImages")
: t("tooltip.modelNotReady")}
</TooltipContent>
</TooltipPortal>
)} )}
{isDesktop && t("button.trainModel")} </Tooltip>
</Button>
</div> </div>
)} )}
</div> </div>