Improve train state by showing number of images to classify and adding tooltip

Nicolas Mowen 2025-11-09 13:24:30 -07:00
parent a374a60756
commit b62de79c39
4 changed files with 190 additions and 20 deletions


@@ -37,6 +37,8 @@ from frigate.models import Event
 from frigate.util.classification import (
     collect_object_classification_examples,
     collect_state_classification_examples,
+    get_dataset_image_count,
+    read_training_metadata,
 )
 from frigate.util.file import get_event_snapshot
@@ -564,23 +566,54 @@ def get_classification_dataset(name: str):
     dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")

     if not os.path.exists(dataset_dir):
-        return JSONResponse(status_code=200, content={})
+        return JSONResponse(
+            status_code=200, content={"categories": {}, "training_metadata": None}
+        )

-    for name in os.listdir(dataset_dir):
-        category_dir = os.path.join(dataset_dir, name)
+    for category_name in os.listdir(dataset_dir):
+        category_dir = os.path.join(dataset_dir, category_name)

         if not os.path.isdir(category_dir):
             continue

-        dataset_dict[name] = []
+        dataset_dict[category_name] = []

         for file in filter(
             lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
             os.listdir(category_dir),
         ):
-            dataset_dict[name].append(file)
+            dataset_dict[category_name].append(file)

-    return JSONResponse(status_code=200, content=dataset_dict)
+    # Get training metadata
+    metadata = read_training_metadata(sanitize_filename(name))
+    current_image_count = get_dataset_image_count(sanitize_filename(name))
+
+    if metadata is None:
+        training_metadata = {
+            "has_trained": False,
+            "last_training_date": None,
+            "last_training_image_count": 0,
+            "current_image_count": current_image_count,
+            "new_images_count": current_image_count,
+        }
+    else:
+        last_training_count = metadata.get("last_training_image_count", 0)
+        new_images_count = max(0, current_image_count - last_training_count)
+        training_metadata = {
+            "has_trained": True,
+            "last_training_date": metadata.get("last_training_date"),
+            "last_training_image_count": last_training_count,
+            "current_image_count": current_image_count,
+            "new_images_count": new_images_count,
+        }
+
+    return JSONResponse(
+        status_code=200,
+        content={
+            "categories": dataset_dict,
+            "training_metadata": training_metadata,
+        },
+    )


 @router.get(
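
Based on the handler above, a response from `classification/{name}/dataset` now wraps the per-class file lists under `categories` and adds a `training_metadata` block. It would look roughly like this (class names, file names, and counts invented for illustration):

    {
      "categories": {
        "cat": ["0001.webp", "0002.webp"],
        "none": ["0003.webp"]
      },
      "training_metadata": {
        "has_trained": true,
        "last_training_date": "2025-11-09T13:20:01.123456",
        "last_training_image_count": 2,
        "current_image_count": 3,
        "new_images_count": 1
      }
    }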


@@ -1,5 +1,7 @@
 """Util for classification models."""

+import datetime
+import json
 import logging
 import os
 import random
@@ -27,10 +29,96 @@ from frigate.util.process import FrigateProcess
 BATCH_SIZE = 16
 EPOCHS = 50
 LEARNING_RATE = 0.001
+TRAINING_METADATA_FILE = ".training_metadata.json"

 logger = logging.getLogger(__name__)


+def write_training_metadata(model_name: str, image_count: int) -> None:
+    """
+    Write training metadata to a hidden file in the model's clips directory.
+
+    Args:
+        model_name: Name of the classification model
+        image_count: Number of images used in training
+    """
+    clips_model_dir = os.path.join(CLIPS_DIR, model_name)
+    os.makedirs(clips_model_dir, exist_ok=True)
+
+    metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
+    metadata = {
+        "last_training_date": datetime.datetime.now().isoformat(),
+        "last_training_image_count": image_count,
+    }
+
+    try:
+        with open(metadata_path, "w") as f:
+            json.dump(metadata, f, indent=2)
+        logger.info(f"Wrote training metadata for {model_name}: {image_count} images")
+    except Exception as e:
+        logger.error(f"Failed to write training metadata for {model_name}: {e}")
+
+
+def read_training_metadata(model_name: str) -> dict[str, any] | None:
+    """
+    Read training metadata from the hidden file in the model's clips directory.
+
+    Args:
+        model_name: Name of the classification model
+
+    Returns:
+        Dictionary with last_training_date and last_training_image_count, or None if not found
+    """
+    clips_model_dir = os.path.join(CLIPS_DIR, model_name)
+    metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
+
+    if not os.path.exists(metadata_path):
+        return None
+
+    try:
+        with open(metadata_path, "r") as f:
+            metadata = json.load(f)
+        return metadata
+    except Exception as e:
+        logger.error(f"Failed to read training metadata for {model_name}: {e}")
+        return None
+
+
+def get_dataset_image_count(model_name: str) -> int:
+    """
+    Count the total number of images in the model's dataset directory.
+
+    Args:
+        model_name: Name of the classification model
+
+    Returns:
+        Total count of images across all categories
+    """
+    dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
+
+    if not os.path.exists(dataset_dir):
+        return 0
+
+    total_count = 0
+
+    try:
+        for category in os.listdir(dataset_dir):
+            category_dir = os.path.join(dataset_dir, category)
+
+            if not os.path.isdir(category_dir):
+                continue
+
+            image_files = [
+                f
+                for f in os.listdir(category_dir)
+                if f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))
+            ]
+            total_count += len(image_files)
+    except Exception as e:
+        logger.error(f"Failed to count dataset images for {model_name}: {e}")
+        return 0
+
+    return total_count
+
+
 class ClassificationTrainingProcess(FrigateProcess):
     def __init__(self, model_name: str) -> None:
         super().__init__(
@@ -145,6 +233,10 @@ class ClassificationTrainingProcess(FrigateProcess):
         with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
             f.write(tflite_model)

+        # write training metadata with image count
+        dataset_image_count = get_dataset_image_count(self.model_name)
+        write_training_metadata(self.model_name, dataset_image_count)
+

 def kickoff_model_training(
     embeddingRequestor: EmbeddingsRequestor, model_name: str
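
For reference, `write_training_metadata` stores only the two keys shown above, so the hidden `.training_metadata.json` under the model's clips directory (`<CLIPS_DIR>/<model_name>/.training_metadata.json`, assuming the default CLIPS_DIR layout) would look roughly like this after a training run, with an invented timestamp and count:

    {
      "last_training_date": "2025-11-09T13:20:01.123456",
      "last_training_image_count": 40
    }

`read_training_metadata` returns this dict as-is; the API layer derives `has_trained`, `current_image_count`, and `new_images_count` from it.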


@@ -13,6 +13,11 @@
     "deleteModels": "Delete Models",
     "editModel": "Edit Model"
   },
+  "tooltip": {
+    "trainingInProgress": "Model is currently training",
+    "noNewImages": "No new images to train. Classify more images in the dataset first.",
+    "modelNotReady": "Model is not ready for training"
+  },
   "toast": {
     "success": {
       "deletedCategory": "Deleted Class",


@@ -102,6 +102,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
         position: "top-center",
       });
       setWasTraining(false);
+      refreshDataset();
     }
     // only refresh when modelState changes
     // eslint-disable-next-line react-hooks/exhaustive-deps
@@ -112,10 +113,20 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
   const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
     `classification/${model.name}/train`,
   );
-  const { data: dataset, mutate: refreshDataset } = useSWR<{
-    [id: string]: string[];
+  const { data: datasetResponse, mutate: refreshDataset } = useSWR<{
+    categories: { [id: string]: string[] };
+    training_metadata: {
+      has_trained: boolean;
+      last_training_date: string | null;
+      last_training_image_count: number;
+      current_image_count: number;
+      new_images_count: number;
+    } | null;
   }>(`classification/${model.name}/dataset`);
+  const dataset = datasetResponse?.categories || {};
+  const trainingMetadata = datasetResponse?.training_metadata;

   const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();

   const refreshAll = useCallback(() => {
@@ -421,19 +432,48 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
             filterValues={{ classes: Object.keys(dataset || {}) }}
             onUpdateFilter={setTrainFilter}
           />
-          <Button
-            className="flex justify-center gap-2"
-            onClick={trainModel}
-            variant="select"
-            disabled={modelState != "complete"}
-          >
-            {modelState == "training" ? (
-              <ActivityIndicator size={20} />
-            ) : (
-              <HiSparkles className="text-white" />
-            )}
-            {isDesktop && t("button.trainModel")}
-          </Button>
+          <Tooltip>
+            <TooltipTrigger asChild>
+              <Button
+                className="flex justify-center gap-2"
+                onClick={trainModel}
+                variant="select"
+                disabled={
+                  modelState != "complete" ||
+                  (trainingMetadata?.new_images_count ?? 0) === 0
+                }
+              >
+                {modelState == "training" ? (
+                  <ActivityIndicator size={20} />
+                ) : (
+                  <HiSparkles className="text-white" />
+                )}
+                {isDesktop && (
+                  <>
+                    {t("button.trainModel")}
+                    {trainingMetadata?.new_images_count !== undefined &&
+                      trainingMetadata.new_images_count > 0 && (
+                        <span className="text-sm text-selected-foreground">
+                          ({trainingMetadata.new_images_count})
+                        </span>
+                      )}
+                  </>
+                )}
+              </Button>
+            </TooltipTrigger>
+            {((trainingMetadata?.new_images_count ?? 0) === 0 ||
+              modelState != "complete") && (
+              <TooltipPortal>
+                <TooltipContent>
+                  {modelState == "training"
+                    ? t("tooltip.trainingInProgress")
+                    : trainingMetadata?.new_images_count === 0
+                      ? t("tooltip.noNewImages")
+                      : t("tooltip.modelNotReady")}
+                </TooltipContent>
+              </TooltipPortal>
+            )}
+          </Tooltip>
         </div>
       )}
     </div>
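
A worked example of the end-to-end behavior, with made-up numbers: if the last run recorded last_training_image_count = 40 and the dataset now holds current_image_count = 52 images, the API reports new_images_count = max(0, 52 - 40) = 12, so on desktop the button renders as "Train Model (12)" and stays enabled. Immediately after retraining the counts match again, new_images_count drops to 0, the button is disabled, and the "No new images to train" tooltip is shown.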