mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-12-06 05:24:11 +03:00
Improve classification (#20863)
Some checks are pending
CI / AMD64 Build (push) Waiting to run
CI / ARM Build (push) Waiting to run
CI / Jetson Jetpack 6 (push) Waiting to run
CI / AMD64 Extra Build (push) Blocked by required conditions
CI / ARM Extra Build (push) Blocked by required conditions
CI / Synaptics Build (push) Blocked by required conditions
CI / Assemble and push default build (push) Blocked by required conditions
Some checks are pending
CI / AMD64 Build (push) Waiting to run
CI / ARM Build (push) Waiting to run
CI / Jetson Jetpack 6 (push) Waiting to run
CI / AMD64 Extra Build (push) Blocked by required conditions
CI / ARM Extra Build (push) Blocked by required conditions
CI / Synaptics Build (push) Blocked by required conditions
CI / Assemble and push default build (push) Blocked by required conditions
This commit is contained in:
parent
a374a60756
commit
99a363c047
@ -37,6 +37,8 @@ from frigate.models import Event
|
||||
from frigate.util.classification import (
|
||||
collect_object_classification_examples,
|
||||
collect_state_classification_examples,
|
||||
get_dataset_image_count,
|
||||
read_training_metadata,
|
||||
)
|
||||
from frigate.util.file import get_event_snapshot
|
||||
|
||||
@ -564,23 +566,54 @@ def get_classification_dataset(name: str):
|
||||
dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")
|
||||
|
||||
if not os.path.exists(dataset_dir):
|
||||
return JSONResponse(status_code=200, content={})
|
||||
return JSONResponse(
|
||||
status_code=200, content={"categories": {}, "training_metadata": None}
|
||||
)
|
||||
|
||||
for name in os.listdir(dataset_dir):
|
||||
category_dir = os.path.join(dataset_dir, name)
|
||||
for category_name in os.listdir(dataset_dir):
|
||||
category_dir = os.path.join(dataset_dir, category_name)
|
||||
|
||||
if not os.path.isdir(category_dir):
|
||||
continue
|
||||
|
||||
dataset_dict[name] = []
|
||||
dataset_dict[category_name] = []
|
||||
|
||||
for file in filter(
|
||||
lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
|
||||
os.listdir(category_dir),
|
||||
):
|
||||
dataset_dict[name].append(file)
|
||||
dataset_dict[category_name].append(file)
|
||||
|
||||
return JSONResponse(status_code=200, content=dataset_dict)
|
||||
# Get training metadata
|
||||
metadata = read_training_metadata(sanitize_filename(name))
|
||||
current_image_count = get_dataset_image_count(sanitize_filename(name))
|
||||
|
||||
if metadata is None:
|
||||
training_metadata = {
|
||||
"has_trained": False,
|
||||
"last_training_date": None,
|
||||
"last_training_image_count": 0,
|
||||
"current_image_count": current_image_count,
|
||||
"new_images_count": current_image_count,
|
||||
}
|
||||
else:
|
||||
last_training_count = metadata.get("last_training_image_count", 0)
|
||||
new_images_count = max(0, current_image_count - last_training_count)
|
||||
training_metadata = {
|
||||
"has_trained": True,
|
||||
"last_training_date": metadata.get("last_training_date"),
|
||||
"last_training_image_count": last_training_count,
|
||||
"current_image_count": current_image_count,
|
||||
"new_images_count": new_images_count,
|
||||
}
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"categories": dataset_dict,
|
||||
"training_metadata": training_metadata,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
|
||||
@ -23,6 +23,7 @@ class ModelStatusTypesEnum(str, Enum):
|
||||
error = "error"
|
||||
training = "training"
|
||||
complete = "complete"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class TrackedObjectUpdateTypesEnum(str, Enum):
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
"""Util for classification models."""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
@ -27,10 +29,96 @@ from frigate.util.process import FrigateProcess
|
||||
BATCH_SIZE = 16
|
||||
EPOCHS = 50
|
||||
LEARNING_RATE = 0.001
|
||||
TRAINING_METADATA_FILE = ".training_metadata.json"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def write_training_metadata(model_name: str, image_count: int) -> None:
    """
    Write training metadata to a hidden file in the model's clips directory.

    The write is best-effort: any failure (including failure to create the
    model directory) is logged and swallowed so that a metadata problem never
    propagates to the caller.

    Args:
        model_name: Name of the classification model
        image_count: Number of images used in training
    """
    clips_model_dir = os.path.join(CLIPS_DIR, model_name)
    metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
    # Write to a sibling temp file first and rename it into place so that a
    # concurrent reader never observes a partially written JSON document.
    tmp_path = f"{metadata_path}.tmp"
    metadata = {
        "last_training_date": datetime.datetime.now().isoformat(),
        "last_training_image_count": image_count,
    }

    try:
        # Directory creation is part of the guarded section: it can fail for
        # the same reasons the file write can (permissions, read-only fs, ...).
        os.makedirs(clips_model_dir, exist_ok=True)
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)
        os.replace(tmp_path, metadata_path)
    except Exception as e:
        logger.error(f"Failed to write training metadata for {model_name}: {e}")
        # Do not leave a stale temp file behind; it may not exist at all if
        # the failure happened before it was created.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return

    logger.info(f"Wrote training metadata for {model_name}: {image_count} images")
|
||||
|
||||
|
||||
# NOTE(review): `any` in the return annotation below is the builtin function,
# not `typing.Any`; it is left untouched here so the signature stays unchanged.
def read_training_metadata(model_name: str) -> dict[str, any] | None:
    """
    Read training metadata from the hidden file in the model's clips directory.

    Args:
        model_name: Name of the classification model

    Returns:
        Dictionary with last_training_date and last_training_image_count, or
        None if the file does not exist, cannot be read or parsed, or does not
        contain a JSON object.
    """
    clips_model_dir = os.path.join(CLIPS_DIR, model_name)
    metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)

    try:
        with open(metadata_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
    except FileNotFoundError:
        # A missing file simply means the model has never been trained; this
        # is an expected state, so it is not logged as an error.
        return None
    except Exception as e:
        logger.error(f"Failed to read training metadata for {model_name}: {e}")
        return None

    # Callers use dict methods (e.g. `.get`) on the result, so reject valid
    # JSON that is not an object (list, string, number, null, ...).
    if not isinstance(metadata, dict):
        logger.error(
            f"Failed to read training metadata for {model_name}: "
            f"expected a JSON object, got {type(metadata).__name__}"
        )
        return None

    return metadata
|
||||
|
||||
|
||||
def get_dataset_image_count(model_name: str) -> int:
    """
    Count the total number of images in the model's dataset directory.

    Only files directly inside each category sub-directory are counted, and
    only those whose (case-insensitive) name ends in .webp, .png, .jpg or
    .jpeg.

    Args:
        model_name: Name of the classification model

    Returns:
        Total count of images across all categories; 0 if the dataset
        directory does not exist or an error occurs while scanning it.
    """
    image_suffixes = (".webp", ".png", ".jpg", ".jpeg")
    dataset_root = os.path.join(CLIPS_DIR, model_name, "dataset")

    if not os.path.exists(dataset_root):
        return 0

    try:
        candidate_dirs = (
            os.path.join(dataset_root, entry) for entry in os.listdir(dataset_root)
        )
        # Non-directory entries at the dataset root are ignored.
        return sum(
            sum(
                1
                for filename in os.listdir(class_dir)
                if filename.lower().endswith(image_suffixes)
            )
            for class_dir in candidate_dirs
            if os.path.isdir(class_dir)
        )
    except Exception as e:
        logger.error(f"Failed to count dataset images for {model_name}: {e}")
        return 0
|
||||
|
||||
|
||||
class ClassificationTrainingProcess(FrigateProcess):
|
||||
def __init__(self, model_name: str) -> None:
|
||||
super().__init__(
|
||||
@ -42,7 +130,8 @@ class ClassificationTrainingProcess(FrigateProcess):
|
||||
|
||||
def run(self) -> None:
|
||||
self.pre_run_setup()
|
||||
self.__train_classification_model()
|
||||
success = self.__train_classification_model()
|
||||
exit(0 if success else 1)
|
||||
|
||||
def __generate_representative_dataset_factory(self, dataset_dir: str):
|
||||
def generate_representative_dataset():
|
||||
@ -65,17 +154,17 @@ class ClassificationTrainingProcess(FrigateProcess):
|
||||
@redirect_output_to_logger(logger, logging.DEBUG)
|
||||
def __train_classification_model(self) -> bool:
|
||||
"""Train a classification model."""
|
||||
|
||||
try:
|
||||
# import in the function so that tensorflow is not initialized multiple times
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import layers, models, optimizers
|
||||
from tensorflow.keras.applications import MobileNetV2
|
||||
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||
|
||||
logger.info(f"Kicking off classification training for {self.model_name}.")
|
||||
dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
|
||||
model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
num_classes = len(
|
||||
[
|
||||
d
|
||||
@ -84,6 +173,12 @@ class ClassificationTrainingProcess(FrigateProcess):
|
||||
]
|
||||
)
|
||||
|
||||
if num_classes < 2:
|
||||
logger.error(
|
||||
f"Training failed for {self.model_name}: Need at least 2 classes, found {num_classes}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Start with imagenet base model with 35% of channels in each layer
|
||||
base_model = MobileNetV2(
|
||||
input_shape=(224, 224, 3),
|
||||
@ -119,6 +214,11 @@ class ClassificationTrainingProcess(FrigateProcess):
|
||||
subset="training",
|
||||
)
|
||||
|
||||
total_images = train_gen.samples
|
||||
logger.debug(
|
||||
f"Training {self.model_name}: {total_images} images across {num_classes} classes"
|
||||
)
|
||||
|
||||
# write labelmap
|
||||
class_indices = train_gen.class_indices
|
||||
index_to_class = {v: k for k, v in class_indices.items()}
|
||||
@ -128,7 +228,9 @@ class ClassificationTrainingProcess(FrigateProcess):
|
||||
f.write(f"{class_name}\n")
|
||||
|
||||
# train the model
|
||||
logger.debug(f"Training {self.model_name} for {EPOCHS} epochs...")
|
||||
model.fit(train_gen, epochs=EPOCHS, verbose=0)
|
||||
logger.debug(f"Converting {self.model_name} to TFLite...")
|
||||
|
||||
# convert model to tflite
|
||||
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
||||
@ -142,9 +244,28 @@ class ClassificationTrainingProcess(FrigateProcess):
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# write model
|
||||
with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
|
||||
model_path = os.path.join(model_dir, "model.tflite")
|
||||
with open(model_path, "wb") as f:
|
||||
f.write(tflite_model)
|
||||
|
||||
# verify model file was written successfully
|
||||
if not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
|
||||
logger.error(
|
||||
f"Training failed for {self.model_name}: Model file was not created or is empty"
|
||||
)
|
||||
return False
|
||||
|
||||
# write training metadata with image count
|
||||
dataset_image_count = get_dataset_image_count(self.model_name)
|
||||
write_training_metadata(self.model_name, dataset_image_count)
|
||||
|
||||
logger.info(f"Finished training {self.model_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training failed for {self.model_name}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def kickoff_model_training(
|
||||
embeddingRequestor: EmbeddingsRequestor, model_name: str
|
||||
@ -165,6 +286,10 @@ def kickoff_model_training(
|
||||
training_process.start()
|
||||
training_process.join()
|
||||
|
||||
# check if training succeeded by examining the exit code
|
||||
training_success = training_process.exitcode == 0
|
||||
|
||||
if training_success:
|
||||
# reload model and mark training as complete
|
||||
embeddingRequestor.send_data(
|
||||
EmbeddingsRequestEnum.reload_classification_model.value,
|
||||
@ -177,6 +302,20 @@ def kickoff_model_training(
|
||||
"state": ModelStatusTypesEnum.complete,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Training subprocess failed for {model_name} (exit code: {training_process.exitcode})"
|
||||
)
|
||||
# mark training as failed so UI shows error state
|
||||
# don't reload the model since it failed
|
||||
requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": model_name,
|
||||
"state": ModelStatusTypesEnum.failed,
|
||||
},
|
||||
)
|
||||
|
||||
requestor.stop()
|
||||
|
||||
|
||||
|
||||
@ -13,6 +13,11 @@
|
||||
"deleteModels": "Delete Models",
|
||||
"editModel": "Edit Model"
|
||||
},
|
||||
"tooltip": {
|
||||
"trainingInProgress": "Model is currently training",
|
||||
"noNewImages": "No new images to train. Classify more images in the dataset first.",
|
||||
"modelNotReady": "Model is not ready for training"
|
||||
},
|
||||
"toast": {
|
||||
"success": {
|
||||
"deletedCategory": "Deleted Class",
|
||||
@ -30,7 +35,8 @@
|
||||
"deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
|
||||
"deleteModelFailed": "Failed to delete model: {{errorMessage}}",
|
||||
"categorizeFailed": "Failed to categorize image: {{errorMessage}}",
|
||||
"trainingFailed": "Failed to start model training: {{errorMessage}}",
|
||||
"trainingFailed": "Model training failed. Check Frigate logs for details.",
|
||||
"trainingFailedToStart": "Failed to start model training: {{errorMessage}}",
|
||||
"updateModelFailed": "Failed to update model: {{errorMessage}}",
|
||||
"renameCategoryFailed": "Failed to rename class: {{errorMessage}}"
|
||||
}
|
||||
@ -143,6 +149,8 @@
|
||||
"step3": {
|
||||
"selectImagesPrompt": "Select all images with: {{className}}",
|
||||
"selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.",
|
||||
"allImagesRequired_one": "Please classify all images. {{count}} image remaining.",
|
||||
"allImagesRequired_other": "Please classify all images. {{count}} images remaining.",
|
||||
"generating": {
|
||||
"title": "Generating Sample Images",
|
||||
"description": "Frigate is pulling representative images from your recordings. This may take a moment..."
|
||||
|
||||
@ -10,6 +10,12 @@ import useSWR from "swr";
|
||||
import { baseUrl } from "@/api/baseUrl";
|
||||
import { isMobile } from "react-device-detect";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { TooltipPortal } from "@radix-ui/react-tooltip";
|
||||
|
||||
export type Step3FormData = {
|
||||
examplesGenerated: boolean;
|
||||
@ -317,6 +323,19 @@ export default function Step3ChooseExamples({
|
||||
return unclassifiedImages.length === 0;
|
||||
}, [unclassifiedImages]);
|
||||
|
||||
// For state models on the last class, require all images to be classified
|
||||
const isLastClass = currentClassIndex === allClasses.length - 1;
|
||||
const canProceed = useMemo(() => {
|
||||
if (
|
||||
step1Data.modelType === "state" &&
|
||||
isLastClass &&
|
||||
!allImagesClassified
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}, [step1Data.modelType, isLastClass, allImagesClassified]);
|
||||
|
||||
const handleBack = useCallback(() => {
|
||||
if (currentClassIndex > 0) {
|
||||
const previousClass = allClasses[currentClassIndex - 1];
|
||||
@ -438,6 +457,8 @@ export default function Step3ChooseExamples({
|
||||
<Button type="button" onClick={handleBack} className="sm:flex-1">
|
||||
{t("button.back", { ns: "common" })}
|
||||
</Button>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
type="button"
|
||||
onClick={
|
||||
@ -447,11 +468,24 @@ export default function Step3ChooseExamples({
|
||||
}
|
||||
variant="select"
|
||||
className="flex items-center justify-center gap-2 sm:flex-1"
|
||||
disabled={!hasGenerated || isGenerating || isProcessing}
|
||||
disabled={
|
||||
!hasGenerated || isGenerating || isProcessing || !canProceed
|
||||
}
|
||||
>
|
||||
{isProcessing && <ActivityIndicator className="size-4" />}
|
||||
{t("button.continue", { ns: "common" })}
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
{!canProceed && (
|
||||
<TooltipPortal>
|
||||
<TooltipContent>
|
||||
{t("wizard.step3.allImagesRequired", {
|
||||
count: unclassifiedImages.length,
|
||||
})}
|
||||
</TooltipContent>
|
||||
</TooltipPortal>
|
||||
)}
|
||||
</Tooltip>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@ -87,7 +87,8 @@ export type ModelState =
|
||||
| "downloaded"
|
||||
| "error"
|
||||
| "training"
|
||||
| "complete";
|
||||
| "complete"
|
||||
| "failed";
|
||||
|
||||
export type EmbeddingsReindexProgressType = {
|
||||
thumbnails: number;
|
||||
|
||||
@ -102,6 +102,12 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
||||
position: "top-center",
|
||||
});
|
||||
setWasTraining(false);
|
||||
refreshDataset();
|
||||
} else if (modelState == "failed") {
|
||||
toast.error(t("toast.error.trainingFailed"), {
|
||||
position: "top-center",
|
||||
});
|
||||
setWasTraining(false);
|
||||
}
|
||||
// only refresh when modelState changes
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
@ -112,10 +118,20 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
||||
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
|
||||
`classification/${model.name}/train`,
|
||||
);
|
||||
const { data: dataset, mutate: refreshDataset } = useSWR<{
|
||||
[id: string]: string[];
|
||||
const { data: datasetResponse, mutate: refreshDataset } = useSWR<{
|
||||
categories: { [id: string]: string[] };
|
||||
training_metadata: {
|
||||
has_trained: boolean;
|
||||
last_training_date: string | null;
|
||||
last_training_image_count: number;
|
||||
current_image_count: number;
|
||||
new_images_count: number;
|
||||
} | null;
|
||||
}>(`classification/${model.name}/dataset`);
|
||||
|
||||
const dataset = datasetResponse?.categories || {};
|
||||
const trainingMetadata = datasetResponse?.training_metadata;
|
||||
|
||||
const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();
|
||||
|
||||
const refreshAll = useCallback(() => {
|
||||
@ -177,7 +193,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
||||
error.response?.data?.detail ||
|
||||
"Unknown error";
|
||||
|
||||
toast.error(t("toast.error.trainingFailed", { errorMessage }), {
|
||||
toast.error(t("toast.error.trainingFailedToStart", { errorMessage }), {
|
||||
position: "top-center",
|
||||
});
|
||||
});
|
||||
@ -421,19 +437,48 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
||||
filterValues={{ classes: Object.keys(dataset || {}) }}
|
||||
onUpdateFilter={setTrainFilter}
|
||||
/>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
className="flex justify-center gap-2"
|
||||
onClick={trainModel}
|
||||
variant="select"
|
||||
disabled={modelState != "complete"}
|
||||
variant={modelState == "failed" ? "destructive" : "select"}
|
||||
disabled={
|
||||
(modelState != "complete" && modelState != "failed") ||
|
||||
(trainingMetadata?.new_images_count ?? 0) === 0
|
||||
}
|
||||
>
|
||||
{modelState == "training" ? (
|
||||
<ActivityIndicator size={20} />
|
||||
) : (
|
||||
<HiSparkles className="text-white" />
|
||||
)}
|
||||
{isDesktop && t("button.trainModel")}
|
||||
{isDesktop && (
|
||||
<>
|
||||
{t("button.trainModel")}
|
||||
{trainingMetadata?.new_images_count !== undefined &&
|
||||
trainingMetadata.new_images_count > 0 && (
|
||||
<span className="text-sm text-selected-foreground">
|
||||
({trainingMetadata.new_images_count})
|
||||
</span>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
{((trainingMetadata?.new_images_count ?? 0) === 0 ||
|
||||
(modelState != "complete" && modelState != "failed")) && (
|
||||
<TooltipPortal>
|
||||
<TooltipContent>
|
||||
{modelState == "training"
|
||||
? t("tooltip.trainingInProgress")
|
||||
: trainingMetadata?.new_images_count === 0
|
||||
? t("tooltip.noNewImages")
|
||||
: t("tooltip.modelNotReady")}
|
||||
</TooltipContent>
|
||||
</TooltipPortal>
|
||||
)}
|
||||
</Tooltip>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
Loading…
Reference in New Issue
Block a user