Improve classification (#20863)
Some checks are pending
CI / AMD64 Build (push) Waiting to run
CI / ARM Build (push) Waiting to run
CI / Jetson Jetpack 6 (push) Waiting to run
CI / AMD64 Extra Build (push) Blocked by required conditions
CI / ARM Extra Build (push) Blocked by required conditions
CI / Synaptics Build (push) Blocked by required conditions
CI / Assemble and push default build (push) Blocked by required conditions

This commit is contained in:
Nicolas Mowen 2025-11-09 15:21:13 -07:00 committed by GitHub
parent a374a60756
commit 99a363c047
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 380 additions and 119 deletions

View File

@ -37,6 +37,8 @@ from frigate.models import Event
from frigate.util.classification import (
collect_object_classification_examples,
collect_state_classification_examples,
get_dataset_image_count,
read_training_metadata,
)
from frigate.util.file import get_event_snapshot
@ -564,23 +566,54 @@ def get_classification_dataset(name: str):
dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")
if not os.path.exists(dataset_dir):
return JSONResponse(status_code=200, content={})
return JSONResponse(
status_code=200, content={"categories": {}, "training_metadata": None}
)
for name in os.listdir(dataset_dir):
category_dir = os.path.join(dataset_dir, name)
for category_name in os.listdir(dataset_dir):
category_dir = os.path.join(dataset_dir, category_name)
if not os.path.isdir(category_dir):
continue
dataset_dict[name] = []
dataset_dict[category_name] = []
for file in filter(
lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
os.listdir(category_dir),
):
dataset_dict[name].append(file)
dataset_dict[category_name].append(file)
return JSONResponse(status_code=200, content=dataset_dict)
# Get training metadata
metadata = read_training_metadata(sanitize_filename(name))
current_image_count = get_dataset_image_count(sanitize_filename(name))
if metadata is None:
training_metadata = {
"has_trained": False,
"last_training_date": None,
"last_training_image_count": 0,
"current_image_count": current_image_count,
"new_images_count": current_image_count,
}
else:
last_training_count = metadata.get("last_training_image_count", 0)
new_images_count = max(0, current_image_count - last_training_count)
training_metadata = {
"has_trained": True,
"last_training_date": metadata.get("last_training_date"),
"last_training_image_count": last_training_count,
"current_image_count": current_image_count,
"new_images_count": new_images_count,
}
return JSONResponse(
status_code=200,
content={
"categories": dataset_dict,
"training_metadata": training_metadata,
},
)
@router.get(

View File

@ -23,6 +23,7 @@ class ModelStatusTypesEnum(str, Enum):
error = "error"
training = "training"
complete = "complete"
failed = "failed"
class TrackedObjectUpdateTypesEnum(str, Enum):

View File

@ -1,5 +1,7 @@
"""Util for classification models."""
import datetime
import json
import logging
import os
import random
@ -27,10 +29,96 @@ from frigate.util.process import FrigateProcess
BATCH_SIZE = 16
EPOCHS = 50
LEARNING_RATE = 0.001
TRAINING_METADATA_FILE = ".training_metadata.json"
logger = logging.getLogger(__name__)
def write_training_metadata(model_name: str, image_count: int) -> None:
"""
Write training metadata to a hidden file in the model's clips directory.
Args:
model_name: Name of the classification model
image_count: Number of images used in training
"""
clips_model_dir = os.path.join(CLIPS_DIR, model_name)
os.makedirs(clips_model_dir, exist_ok=True)
metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
metadata = {
"last_training_date": datetime.datetime.now().isoformat(),
"last_training_image_count": image_count,
}
try:
with open(metadata_path, "w") as f:
json.dump(metadata, f, indent=2)
logger.info(f"Wrote training metadata for {model_name}: {image_count} images")
except Exception as e:
logger.error(f"Failed to write training metadata for {model_name}: {e}")
def read_training_metadata(model_name: str) -> dict[str, any] | None:
"""
Read training metadata from the hidden file in the model's clips directory.
Args:
model_name: Name of the classification model
Returns:
Dictionary with last_training_date and last_training_image_count, or None if not found
"""
clips_model_dir = os.path.join(CLIPS_DIR, model_name)
metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
if not os.path.exists(metadata_path):
return None
try:
with open(metadata_path, "r") as f:
metadata = json.load(f)
return metadata
except Exception as e:
logger.error(f"Failed to read training metadata for {model_name}: {e}")
return None
def get_dataset_image_count(model_name: str) -> int:
"""
Count the total number of images in the model's dataset directory.
Args:
model_name: Name of the classification model
Returns:
Total count of images across all categories
"""
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
if not os.path.exists(dataset_dir):
return 0
total_count = 0
try:
for category in os.listdir(dataset_dir):
category_dir = os.path.join(dataset_dir, category)
if not os.path.isdir(category_dir):
continue
image_files = [
f
for f in os.listdir(category_dir)
if f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))
]
total_count += len(image_files)
except Exception as e:
logger.error(f"Failed to count dataset images for {model_name}: {e}")
return 0
return total_count
class ClassificationTrainingProcess(FrigateProcess):
def __init__(self, model_name: str) -> None:
super().__init__(
@ -42,7 +130,8 @@ class ClassificationTrainingProcess(FrigateProcess):
def run(self) -> None:
self.pre_run_setup()
self.__train_classification_model()
success = self.__train_classification_model()
exit(0 if success else 1)
def __generate_representative_dataset_factory(self, dataset_dir: str):
def generate_representative_dataset():
@ -65,17 +154,17 @@ class ClassificationTrainingProcess(FrigateProcess):
@redirect_output_to_logger(logger, logging.DEBUG)
def __train_classification_model(self) -> bool:
"""Train a classification model."""
try:
# import in the function so that tensorflow is not initialized multiple times
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
logger.info(f"Kicking off classification training for {self.model_name}.")
dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
os.makedirs(model_dir, exist_ok=True)
num_classes = len(
[
d
@ -84,6 +173,12 @@ class ClassificationTrainingProcess(FrigateProcess):
]
)
if num_classes < 2:
logger.error(
f"Training failed for {self.model_name}: Need at least 2 classes, found {num_classes}"
)
return False
# Start with imagenet base model with 35% of channels in each layer
base_model = MobileNetV2(
input_shape=(224, 224, 3),
@ -119,6 +214,11 @@ class ClassificationTrainingProcess(FrigateProcess):
subset="training",
)
total_images = train_gen.samples
logger.debug(
f"Training {self.model_name}: {total_images} images across {num_classes} classes"
)
# write labelmap
class_indices = train_gen.class_indices
index_to_class = {v: k for k, v in class_indices.items()}
@ -128,7 +228,9 @@ class ClassificationTrainingProcess(FrigateProcess):
f.write(f"{class_name}\n")
# train the model
logger.debug(f"Training {self.model_name} for {EPOCHS} epochs...")
model.fit(train_gen, epochs=EPOCHS, verbose=0)
logger.debug(f"Converting {self.model_name} to TFLite...")
# convert model to tflite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
@ -142,9 +244,28 @@ class ClassificationTrainingProcess(FrigateProcess):
tflite_model = converter.convert()
# write model
with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
model_path = os.path.join(model_dir, "model.tflite")
with open(model_path, "wb") as f:
f.write(tflite_model)
# verify model file was written successfully
if not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
logger.error(
f"Training failed for {self.model_name}: Model file was not created or is empty"
)
return False
# write training metadata with image count
dataset_image_count = get_dataset_image_count(self.model_name)
write_training_metadata(self.model_name, dataset_image_count)
logger.info(f"Finished training {self.model_name}")
return True
except Exception as e:
logger.error(f"Training failed for {self.model_name}: {e}", exc_info=True)
return False
def kickoff_model_training(
embeddingRequestor: EmbeddingsRequestor, model_name: str
@ -165,6 +286,10 @@ def kickoff_model_training(
training_process.start()
training_process.join()
# check if training succeeded by examining the exit code
training_success = training_process.exitcode == 0
if training_success:
# reload model and mark training as complete
embeddingRequestor.send_data(
EmbeddingsRequestEnum.reload_classification_model.value,
@ -177,6 +302,20 @@ def kickoff_model_training(
"state": ModelStatusTypesEnum.complete,
},
)
else:
logger.error(
f"Training subprocess failed for {model_name} (exit code: {training_process.exitcode})"
)
# mark training as failed so UI shows error state
# don't reload the model since it failed
requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": model_name,
"state": ModelStatusTypesEnum.failed,
},
)
requestor.stop()

View File

@ -13,6 +13,11 @@
"deleteModels": "Delete Models",
"editModel": "Edit Model"
},
"tooltip": {
"trainingInProgress": "Model is currently training",
"noNewImages": "No new images to train. Classify more images in the dataset first.",
"modelNotReady": "Model is not ready for training"
},
"toast": {
"success": {
"deletedCategory": "Deleted Class",
@ -30,7 +35,8 @@
"deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
"deleteModelFailed": "Failed to delete model: {{errorMessage}}",
"categorizeFailed": "Failed to categorize image: {{errorMessage}}",
"trainingFailed": "Failed to start model training: {{errorMessage}}",
"trainingFailed": "Model training failed. Check Frigate logs for details.",
"trainingFailedToStart": "Failed to start model training: {{errorMessage}}",
"updateModelFailed": "Failed to update model: {{errorMessage}}",
"renameCategoryFailed": "Failed to rename class: {{errorMessage}}"
}
@ -143,6 +149,8 @@
"step3": {
"selectImagesPrompt": "Select all images with: {{className}}",
"selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.",
"allImagesRequired_one": "Please classify all images. {{count}} image remaining.",
"allImagesRequired_other": "Please classify all images. {{count}} images remaining.",
"generating": {
"title": "Generating Sample Images",
"description": "Frigate is pulling representative images from your recordings. This may take a moment..."

View File

@ -10,6 +10,12 @@ import useSWR from "swr";
import { baseUrl } from "@/api/baseUrl";
import { isMobile } from "react-device-detect";
import { cn } from "@/lib/utils";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { TooltipPortal } from "@radix-ui/react-tooltip";
export type Step3FormData = {
examplesGenerated: boolean;
@ -317,6 +323,19 @@ export default function Step3ChooseExamples({
return unclassifiedImages.length === 0;
}, [unclassifiedImages]);
// For state models on the last class, require all images to be classified
const isLastClass = currentClassIndex === allClasses.length - 1;
const canProceed = useMemo(() => {
if (
step1Data.modelType === "state" &&
isLastClass &&
!allImagesClassified
) {
return false;
}
return true;
}, [step1Data.modelType, isLastClass, allImagesClassified]);
const handleBack = useCallback(() => {
if (currentClassIndex > 0) {
const previousClass = allClasses[currentClassIndex - 1];
@ -438,6 +457,8 @@ export default function Step3ChooseExamples({
<Button type="button" onClick={handleBack} className="sm:flex-1">
{t("button.back", { ns: "common" })}
</Button>
<Tooltip>
<TooltipTrigger asChild>
<Button
type="button"
onClick={
@ -447,11 +468,24 @@ export default function Step3ChooseExamples({
}
variant="select"
className="flex items-center justify-center gap-2 sm:flex-1"
disabled={!hasGenerated || isGenerating || isProcessing}
disabled={
!hasGenerated || isGenerating || isProcessing || !canProceed
}
>
{isProcessing && <ActivityIndicator className="size-4" />}
{t("button.continue", { ns: "common" })}
</Button>
</TooltipTrigger>
{!canProceed && (
<TooltipPortal>
<TooltipContent>
{t("wizard.step3.allImagesRequired", {
count: unclassifiedImages.length,
})}
</TooltipContent>
</TooltipPortal>
)}
</Tooltip>
</div>
)}
</div>

View File

@ -87,7 +87,8 @@ export type ModelState =
| "downloaded"
| "error"
| "training"
| "complete";
| "complete"
| "failed";
export type EmbeddingsReindexProgressType = {
thumbnails: number;

View File

@ -102,6 +102,12 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
position: "top-center",
});
setWasTraining(false);
refreshDataset();
} else if (modelState == "failed") {
toast.error(t("toast.error.trainingFailed"), {
position: "top-center",
});
setWasTraining(false);
}
// only refresh when modelState changes
// eslint-disable-next-line react-hooks/exhaustive-deps
@ -112,10 +118,20 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
`classification/${model.name}/train`,
);
const { data: dataset, mutate: refreshDataset } = useSWR<{
[id: string]: string[];
const { data: datasetResponse, mutate: refreshDataset } = useSWR<{
categories: { [id: string]: string[] };
training_metadata: {
has_trained: boolean;
last_training_date: string | null;
last_training_image_count: number;
current_image_count: number;
new_images_count: number;
} | null;
}>(`classification/${model.name}/dataset`);
const dataset = datasetResponse?.categories || {};
const trainingMetadata = datasetResponse?.training_metadata;
const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();
const refreshAll = useCallback(() => {
@ -177,7 +193,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
error.response?.data?.detail ||
"Unknown error";
toast.error(t("toast.error.trainingFailed", { errorMessage }), {
toast.error(t("toast.error.trainingFailedToStart", { errorMessage }), {
position: "top-center",
});
});
@ -421,19 +437,48 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
filterValues={{ classes: Object.keys(dataset || {}) }}
onUpdateFilter={setTrainFilter}
/>
<Tooltip>
<TooltipTrigger asChild>
<Button
className="flex justify-center gap-2"
onClick={trainModel}
variant="select"
disabled={modelState != "complete"}
variant={modelState == "failed" ? "destructive" : "select"}
disabled={
(modelState != "complete" && modelState != "failed") ||
(trainingMetadata?.new_images_count ?? 0) === 0
}
>
{modelState == "training" ? (
<ActivityIndicator size={20} />
) : (
<HiSparkles className="text-white" />
)}
{isDesktop && t("button.trainModel")}
{isDesktop && (
<>
{t("button.trainModel")}
{trainingMetadata?.new_images_count !== undefined &&
trainingMetadata.new_images_count > 0 && (
<span className="text-sm text-selected-foreground">
({trainingMetadata.new_images_count})
</span>
)}
</>
)}
</Button>
</TooltipTrigger>
{((trainingMetadata?.new_images_count ?? 0) === 0 ||
(modelState != "complete" && modelState != "failed")) && (
<TooltipPortal>
<TooltipContent>
{modelState == "training"
? t("tooltip.trainingInProgress")
: trainingMetadata?.new_images_count === 0
? t("tooltip.noNewImages")
: t("tooltip.modelNotReady")}
</TooltipContent>
</TooltipPortal>
)}
</Tooltip>
</div>
)}
</div>