Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot]
78da105a3b
Merge 08230cd096 into a374a60756 2025-11-09 14:41:19 +00:00
7 changed files with 119 additions and 380 deletions

View File

@ -37,8 +37,6 @@ from frigate.models import Event
from frigate.util.classification import ( from frigate.util.classification import (
collect_object_classification_examples, collect_object_classification_examples,
collect_state_classification_examples, collect_state_classification_examples,
get_dataset_image_count,
read_training_metadata,
) )
from frigate.util.file import get_event_snapshot from frigate.util.file import get_event_snapshot
@ -566,54 +564,23 @@ def get_classification_dataset(name: str):
dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset") dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")
if not os.path.exists(dataset_dir): if not os.path.exists(dataset_dir):
return JSONResponse( return JSONResponse(status_code=200, content={})
status_code=200, content={"categories": {}, "training_metadata": None}
)
for category_name in os.listdir(dataset_dir): for name in os.listdir(dataset_dir):
category_dir = os.path.join(dataset_dir, category_name) category_dir = os.path.join(dataset_dir, name)
if not os.path.isdir(category_dir): if not os.path.isdir(category_dir):
continue continue
dataset_dict[category_name] = [] dataset_dict[name] = []
for file in filter( for file in filter(
lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))), lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
os.listdir(category_dir), os.listdir(category_dir),
): ):
dataset_dict[category_name].append(file) dataset_dict[name].append(file)
# Get training metadata return JSONResponse(status_code=200, content=dataset_dict)
metadata = read_training_metadata(sanitize_filename(name))
current_image_count = get_dataset_image_count(sanitize_filename(name))
if metadata is None:
training_metadata = {
"has_trained": False,
"last_training_date": None,
"last_training_image_count": 0,
"current_image_count": current_image_count,
"new_images_count": current_image_count,
}
else:
last_training_count = metadata.get("last_training_image_count", 0)
new_images_count = max(0, current_image_count - last_training_count)
training_metadata = {
"has_trained": True,
"last_training_date": metadata.get("last_training_date"),
"last_training_image_count": last_training_count,
"current_image_count": current_image_count,
"new_images_count": new_images_count,
}
return JSONResponse(
status_code=200,
content={
"categories": dataset_dict,
"training_metadata": training_metadata,
},
)
@router.get( @router.get(

View File

@ -23,7 +23,6 @@ class ModelStatusTypesEnum(str, Enum):
error = "error" error = "error"
training = "training" training = "training"
complete = "complete" complete = "complete"
failed = "failed"
class TrackedObjectUpdateTypesEnum(str, Enum): class TrackedObjectUpdateTypesEnum(str, Enum):

View File

@ -1,7 +1,5 @@
"""Util for classification models.""" """Util for classification models."""
import datetime
import json
import logging import logging
import os import os
import random import random
@ -29,96 +27,10 @@ from frigate.util.process import FrigateProcess
BATCH_SIZE = 16 BATCH_SIZE = 16
EPOCHS = 50 EPOCHS = 50
LEARNING_RATE = 0.001 LEARNING_RATE = 0.001
TRAINING_METADATA_FILE = ".training_metadata.json"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def write_training_metadata(model_name: str, image_count: int) -> None:
"""
Write training metadata to a hidden file in the model's clips directory.
Args:
model_name: Name of the classification model
image_count: Number of images used in training
"""
clips_model_dir = os.path.join(CLIPS_DIR, model_name)
os.makedirs(clips_model_dir, exist_ok=True)
metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
metadata = {
"last_training_date": datetime.datetime.now().isoformat(),
"last_training_image_count": image_count,
}
try:
with open(metadata_path, "w") as f:
json.dump(metadata, f, indent=2)
logger.info(f"Wrote training metadata for {model_name}: {image_count} images")
except Exception as e:
logger.error(f"Failed to write training metadata for {model_name}: {e}")
def read_training_metadata(model_name: str) -> dict[str, any] | None:
"""
Read training metadata from the hidden file in the model's clips directory.
Args:
model_name: Name of the classification model
Returns:
Dictionary with last_training_date and last_training_image_count, or None if not found
"""
clips_model_dir = os.path.join(CLIPS_DIR, model_name)
metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
if not os.path.exists(metadata_path):
return None
try:
with open(metadata_path, "r") as f:
metadata = json.load(f)
return metadata
except Exception as e:
logger.error(f"Failed to read training metadata for {model_name}: {e}")
return None
def get_dataset_image_count(model_name: str) -> int:
"""
Count the total number of images in the model's dataset directory.
Args:
model_name: Name of the classification model
Returns:
Total count of images across all categories
"""
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
if not os.path.exists(dataset_dir):
return 0
total_count = 0
try:
for category in os.listdir(dataset_dir):
category_dir = os.path.join(dataset_dir, category)
if not os.path.isdir(category_dir):
continue
image_files = [
f
for f in os.listdir(category_dir)
if f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))
]
total_count += len(image_files)
except Exception as e:
logger.error(f"Failed to count dataset images for {model_name}: {e}")
return 0
return total_count
class ClassificationTrainingProcess(FrigateProcess): class ClassificationTrainingProcess(FrigateProcess):
def __init__(self, model_name: str) -> None: def __init__(self, model_name: str) -> None:
super().__init__( super().__init__(
@ -130,8 +42,7 @@ class ClassificationTrainingProcess(FrigateProcess):
def run(self) -> None: def run(self) -> None:
self.pre_run_setup() self.pre_run_setup()
success = self.__train_classification_model() self.__train_classification_model()
exit(0 if success else 1)
def __generate_representative_dataset_factory(self, dataset_dir: str): def __generate_representative_dataset_factory(self, dataset_dir: str):
def generate_representative_dataset(): def generate_representative_dataset():
@ -154,117 +65,85 @@ class ClassificationTrainingProcess(FrigateProcess):
@redirect_output_to_logger(logger, logging.DEBUG) @redirect_output_to_logger(logger, logging.DEBUG)
def __train_classification_model(self) -> bool: def __train_classification_model(self) -> bool:
"""Train a classification model.""" """Train a classification model."""
try:
# import in the function so that tensorflow is not initialized multiple times
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset") # import in the function so that tensorflow is not initialized multiple times
model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name) import tensorflow as tf
os.makedirs(model_dir, exist_ok=True) from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
num_classes = len( logger.info(f"Kicking off classification training for {self.model_name}.")
[ dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
d model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
for d in os.listdir(dataset_dir) os.makedirs(model_dir, exist_ok=True)
if os.path.isdir(os.path.join(dataset_dir, d)) num_classes = len(
] [
) d
for d in os.listdir(dataset_dir)
if os.path.isdir(os.path.join(dataset_dir, d))
]
)
if num_classes < 2: # Start with imagenet base model with 35% of channels in each layer
logger.error( base_model = MobileNetV2(
f"Training failed for {self.model_name}: Need at least 2 classes, found {num_classes}" input_shape=(224, 224, 3),
) include_top=False,
return False weights="imagenet",
alpha=0.35,
)
base_model.trainable = False # Freeze pre-trained layers
# Start with imagenet base model with 35% of channels in each layer model = models.Sequential(
base_model = MobileNetV2( [
input_shape=(224, 224, 3), base_model,
include_top=False, layers.GlobalAveragePooling2D(),
weights="imagenet", layers.Dense(128, activation="relu"),
alpha=0.35, layers.Dropout(0.3),
) layers.Dense(num_classes, activation="softmax"),
base_model.trainable = False # Freeze pre-trained layers ]
)
model = models.Sequential( model.compile(
[ optimizer=optimizers.Adam(learning_rate=LEARNING_RATE),
base_model, loss="categorical_crossentropy",
layers.GlobalAveragePooling2D(), metrics=["accuracy"],
layers.Dense(128, activation="relu"), )
layers.Dropout(0.3),
layers.Dense(num_classes, activation="softmax"),
]
)
model.compile( # create training set
optimizer=optimizers.Adam(learning_rate=LEARNING_RATE), datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2)
loss="categorical_crossentropy", train_gen = datagen.flow_from_directory(
metrics=["accuracy"], dataset_dir,
) target_size=(224, 224),
batch_size=BATCH_SIZE,
class_mode="categorical",
subset="training",
)
# create training set # write labelmap
datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2) class_indices = train_gen.class_indices
train_gen = datagen.flow_from_directory( index_to_class = {v: k for k, v in class_indices.items()}
dataset_dir, sorted_classes = [index_to_class[i] for i in range(len(index_to_class))]
target_size=(224, 224), with open(os.path.join(model_dir, "labelmap.txt"), "w") as f:
batch_size=BATCH_SIZE, for class_name in sorted_classes:
class_mode="categorical", f.write(f"{class_name}\n")
subset="training",
)
total_images = train_gen.samples # train the model
logger.debug( model.fit(train_gen, epochs=EPOCHS, verbose=0)
f"Training {self.model_name}: {total_images} images across {num_classes} classes"
)
# write labelmap # convert model to tflite
class_indices = train_gen.class_indices converter = tf.lite.TFLiteConverter.from_keras_model(model)
index_to_class = {v: k for k, v in class_indices.items()} converter.optimizations = [tf.lite.Optimize.DEFAULT]
sorted_classes = [index_to_class[i] for i in range(len(index_to_class))] converter.representative_dataset = (
with open(os.path.join(model_dir, "labelmap.txt"), "w") as f: self.__generate_representative_dataset_factory(dataset_dir)
for class_name in sorted_classes: )
f.write(f"{class_name}\n") converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()
# train the model # write model
logger.debug(f"Training {self.model_name} for {EPOCHS} epochs...") with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
model.fit(train_gen, epochs=EPOCHS, verbose=0) f.write(tflite_model)
logger.debug(f"Converting {self.model_name} to TFLite...")
# convert model to tflite
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = (
self.__generate_representative_dataset_factory(dataset_dir)
)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()
# write model
model_path = os.path.join(model_dir, "model.tflite")
with open(model_path, "wb") as f:
f.write(tflite_model)
# verify model file was written successfully
if not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
logger.error(
f"Training failed for {self.model_name}: Model file was not created or is empty"
)
return False
# write training metadata with image count
dataset_image_count = get_dataset_image_count(self.model_name)
write_training_metadata(self.model_name, dataset_image_count)
logger.info(f"Finished training {self.model_name}")
return True
except Exception as e:
logger.error(f"Training failed for {self.model_name}: {e}", exc_info=True)
return False
def kickoff_model_training( def kickoff_model_training(
@ -286,36 +165,18 @@ def kickoff_model_training(
training_process.start() training_process.start()
training_process.join() training_process.join()
# check if training succeeded by examining the exit code # reload model and mark training as complete
training_success = training_process.exitcode == 0 embeddingRequestor.send_data(
EmbeddingsRequestEnum.reload_classification_model.value,
if training_success: {"model_name": model_name},
# reload model and mark training as complete )
embeddingRequestor.send_data( requestor.send_data(
EmbeddingsRequestEnum.reload_classification_model.value, UPDATE_MODEL_STATE,
{"model_name": model_name}, {
) "model": model_name,
requestor.send_data( "state": ModelStatusTypesEnum.complete,
UPDATE_MODEL_STATE, },
{ )
"model": model_name,
"state": ModelStatusTypesEnum.complete,
},
)
else:
logger.error(
f"Training subprocess failed for {model_name} (exit code: {training_process.exitcode})"
)
# mark training as failed so UI shows error state
# don't reload the model since it failed
requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": model_name,
"state": ModelStatusTypesEnum.failed,
},
)
requestor.stop() requestor.stop()

View File

@ -13,11 +13,6 @@
"deleteModels": "Delete Models", "deleteModels": "Delete Models",
"editModel": "Edit Model" "editModel": "Edit Model"
}, },
"tooltip": {
"trainingInProgress": "Model is currently training",
"noNewImages": "No new images to train. Classify more images in the dataset first.",
"modelNotReady": "Model is not ready for training"
},
"toast": { "toast": {
"success": { "success": {
"deletedCategory": "Deleted Class", "deletedCategory": "Deleted Class",
@ -35,8 +30,7 @@
"deleteCategoryFailed": "Failed to delete class: {{errorMessage}}", "deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
"deleteModelFailed": "Failed to delete model: {{errorMessage}}", "deleteModelFailed": "Failed to delete model: {{errorMessage}}",
"categorizeFailed": "Failed to categorize image: {{errorMessage}}", "categorizeFailed": "Failed to categorize image: {{errorMessage}}",
"trainingFailed": "Model training failed. Check Frigate logs for details.", "trainingFailed": "Failed to start model training: {{errorMessage}}",
"trainingFailedToStart": "Failed to start model training: {{errorMessage}}",
"updateModelFailed": "Failed to update model: {{errorMessage}}", "updateModelFailed": "Failed to update model: {{errorMessage}}",
"renameCategoryFailed": "Failed to rename class: {{errorMessage}}" "renameCategoryFailed": "Failed to rename class: {{errorMessage}}"
} }
@ -149,8 +143,6 @@
"step3": { "step3": {
"selectImagesPrompt": "Select all images with: {{className}}", "selectImagesPrompt": "Select all images with: {{className}}",
"selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.", "selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.",
"allImagesRequired_one": "Please classify all images. {{count}} image remaining.",
"allImagesRequired_other": "Please classify all images. {{count}} images remaining.",
"generating": { "generating": {
"title": "Generating Sample Images", "title": "Generating Sample Images",
"description": "Frigate is pulling representative images from your recordings. This may take a moment..." "description": "Frigate is pulling representative images from your recordings. This may take a moment..."

View File

@ -10,12 +10,6 @@ import useSWR from "swr";
import { baseUrl } from "@/api/baseUrl"; import { baseUrl } from "@/api/baseUrl";
import { isMobile } from "react-device-detect"; import { isMobile } from "react-device-detect";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { TooltipPortal } from "@radix-ui/react-tooltip";
export type Step3FormData = { export type Step3FormData = {
examplesGenerated: boolean; examplesGenerated: boolean;
@ -323,19 +317,6 @@ export default function Step3ChooseExamples({
return unclassifiedImages.length === 0; return unclassifiedImages.length === 0;
}, [unclassifiedImages]); }, [unclassifiedImages]);
// For state models on the last class, require all images to be classified
const isLastClass = currentClassIndex === allClasses.length - 1;
const canProceed = useMemo(() => {
if (
step1Data.modelType === "state" &&
isLastClass &&
!allImagesClassified
) {
return false;
}
return true;
}, [step1Data.modelType, isLastClass, allImagesClassified]);
const handleBack = useCallback(() => { const handleBack = useCallback(() => {
if (currentClassIndex > 0) { if (currentClassIndex > 0) {
const previousClass = allClasses[currentClassIndex - 1]; const previousClass = allClasses[currentClassIndex - 1];
@ -457,35 +438,20 @@ export default function Step3ChooseExamples({
<Button type="button" onClick={handleBack} className="sm:flex-1"> <Button type="button" onClick={handleBack} className="sm:flex-1">
{t("button.back", { ns: "common" })} {t("button.back", { ns: "common" })}
</Button> </Button>
<Tooltip> <Button
<TooltipTrigger asChild> type="button"
<Button onClick={
type="button" allImagesClassified
onClick={ ? handleContinue
allImagesClassified : handleContinueClassification
? handleContinue }
: handleContinueClassification variant="select"
} className="flex items-center justify-center gap-2 sm:flex-1"
variant="select" disabled={!hasGenerated || isGenerating || isProcessing}
className="flex items-center justify-center gap-2 sm:flex-1" >
disabled={ {isProcessing && <ActivityIndicator className="size-4" />}
!hasGenerated || isGenerating || isProcessing || !canProceed {t("button.continue", { ns: "common" })}
} </Button>
>
{isProcessing && <ActivityIndicator className="size-4" />}
{t("button.continue", { ns: "common" })}
</Button>
</TooltipTrigger>
{!canProceed && (
<TooltipPortal>
<TooltipContent>
{t("wizard.step3.allImagesRequired", {
count: unclassifiedImages.length,
})}
</TooltipContent>
</TooltipPortal>
)}
</Tooltip>
</div> </div>
)} )}
</div> </div>

View File

@ -87,8 +87,7 @@ export type ModelState =
| "downloaded" | "downloaded"
| "error" | "error"
| "training" | "training"
| "complete" | "complete";
| "failed";
export type EmbeddingsReindexProgressType = { export type EmbeddingsReindexProgressType = {
thumbnails: number; thumbnails: number;

View File

@ -102,12 +102,6 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
position: "top-center", position: "top-center",
}); });
setWasTraining(false); setWasTraining(false);
refreshDataset();
} else if (modelState == "failed") {
toast.error(t("toast.error.trainingFailed"), {
position: "top-center",
});
setWasTraining(false);
} }
// only refresh when modelState changes // only refresh when modelState changes
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
@ -118,20 +112,10 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>( const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
`classification/${model.name}/train`, `classification/${model.name}/train`,
); );
const { data: datasetResponse, mutate: refreshDataset } = useSWR<{ const { data: dataset, mutate: refreshDataset } = useSWR<{
categories: { [id: string]: string[] }; [id: string]: string[];
training_metadata: {
has_trained: boolean;
last_training_date: string | null;
last_training_image_count: number;
current_image_count: number;
new_images_count: number;
} | null;
}>(`classification/${model.name}/dataset`); }>(`classification/${model.name}/dataset`);
const dataset = datasetResponse?.categories || {};
const trainingMetadata = datasetResponse?.training_metadata;
const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>(); const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();
const refreshAll = useCallback(() => { const refreshAll = useCallback(() => {
@ -193,7 +177,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
error.response?.data?.detail || error.response?.data?.detail ||
"Unknown error"; "Unknown error";
toast.error(t("toast.error.trainingFailedToStart", { errorMessage }), { toast.error(t("toast.error.trainingFailed", { errorMessage }), {
position: "top-center", position: "top-center",
}); });
}); });
@ -437,48 +421,19 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
filterValues={{ classes: Object.keys(dataset || {}) }} filterValues={{ classes: Object.keys(dataset || {}) }}
onUpdateFilter={setTrainFilter} onUpdateFilter={setTrainFilter}
/> />
<Tooltip> <Button
<TooltipTrigger asChild> className="flex justify-center gap-2"
<Button onClick={trainModel}
className="flex justify-center gap-2" variant="select"
onClick={trainModel} disabled={modelState != "complete"}
variant={modelState == "failed" ? "destructive" : "select"} >
disabled={ {modelState == "training" ? (
(modelState != "complete" && modelState != "failed") || <ActivityIndicator size={20} />
(trainingMetadata?.new_images_count ?? 0) === 0 ) : (
} <HiSparkles className="text-white" />
>
{modelState == "training" ? (
<ActivityIndicator size={20} />
) : (
<HiSparkles className="text-white" />
)}
{isDesktop && (
<>
{t("button.trainModel")}
{trainingMetadata?.new_images_count !== undefined &&
trainingMetadata.new_images_count > 0 && (
<span className="text-sm text-selected-foreground">
({trainingMetadata.new_images_count})
</span>
)}
</>
)}
</Button>
</TooltipTrigger>
{((trainingMetadata?.new_images_count ?? 0) === 0 ||
(modelState != "complete" && modelState != "failed")) && (
<TooltipPortal>
<TooltipContent>
{modelState == "training"
? t("tooltip.trainingInProgress")
: trainingMetadata?.new_images_count === 0
? t("tooltip.noNewImages")
: t("tooltip.modelNotReady")}
</TooltipContent>
</TooltipPortal>
)} )}
</Tooltip> {isDesktop && t("button.trainModel")}
</Button>
</div> </div>
)} )}
</div> </div>