From b62de79c39937f0527e4c75a361de1692f69f79c Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Sun, 9 Nov 2025 13:24:30 -0700 Subject: [PATCH] Improve train state by showing number of images to classify and adding tooltip --- frigate/api/classification.py | 45 +++++++-- frigate/util/classification.py | 92 +++++++++++++++++++ .../locales/en/views/classificationModel.json | 5 + .../classification/ModelTrainingView.tsx | 68 +++++++++++--- 4 files changed, 190 insertions(+), 20 deletions(-) diff --git a/frigate/api/classification.py b/frigate/api/classification.py index b3194176e..87de52884 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -37,6 +37,8 @@ from frigate.models import Event from frigate.util.classification import ( collect_object_classification_examples, collect_state_classification_examples, + get_dataset_image_count, + read_training_metadata, ) from frigate.util.file import get_event_snapshot @@ -564,23 +566,54 @@ def get_classification_dataset(name: str): dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset") if not os.path.exists(dataset_dir): - return JSONResponse(status_code=200, content={}) + return JSONResponse( + status_code=200, content={"categories": {}, "training_metadata": None} + ) - for name in os.listdir(dataset_dir): - category_dir = os.path.join(dataset_dir, name) + for category_name in os.listdir(dataset_dir): + category_dir = os.path.join(dataset_dir, category_name) if not os.path.isdir(category_dir): continue - dataset_dict[name] = [] + dataset_dict[category_name] = [] for file in filter( lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))), os.listdir(category_dir), ): - dataset_dict[name].append(file) + dataset_dict[category_name].append(file) - return JSONResponse(status_code=200, content=dataset_dict) + # Get training metadata + metadata = read_training_metadata(sanitize_filename(name)) + current_image_count = get_dataset_image_count(sanitize_filename(name)) + + if metadata is None: + training_metadata = { + "has_trained": False, + "last_training_date": None, + "last_training_image_count": 0, + "current_image_count": current_image_count, + "new_images_count": current_image_count, + } + else: + last_training_count = metadata.get("last_training_image_count", 0) + new_images_count = max(0, current_image_count - last_training_count) + training_metadata = { + "has_trained": True, + "last_training_date": metadata.get("last_training_date"), + "last_training_image_count": last_training_count, + "current_image_count": current_image_count, + "new_images_count": new_images_count, + } + + return JSONResponse( + status_code=200, + content={ + "categories": dataset_dict, + "training_metadata": training_metadata, + }, + ) @router.get( diff --git a/frigate/util/classification.py b/frigate/util/classification.py index 43dfd7fd7..3cab97805 100644 --- a/frigate/util/classification.py +++ b/frigate/util/classification.py @@ -1,5 +1,7 @@ """Util for classification models.""" +import datetime +import json import logging import os import random @@ -27,10 +29,96 @@ from frigate.util.process import FrigateProcess BATCH_SIZE = 16 EPOCHS = 50 LEARNING_RATE = 0.001 +TRAINING_METADATA_FILE = ".training_metadata.json" logger = logging.getLogger(__name__) +def write_training_metadata(model_name: str, image_count: int) -> None: + """ + Write training metadata to a hidden file in the model's clips directory. + + Args: + model_name: Name of the classification model + image_count: Number of images used in training + """ + clips_model_dir = os.path.join(CLIPS_DIR, model_name) + os.makedirs(clips_model_dir, exist_ok=True) + + metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE) + metadata = { + "last_training_date": datetime.datetime.now().isoformat(), + "last_training_image_count": image_count, + } + + try: + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + logger.info(f"Wrote training metadata for {model_name}: {image_count} images") + except Exception as e: + logger.error(f"Failed to write training metadata for {model_name}: {e}") + + +def read_training_metadata(model_name: str) -> dict[str, any] | None: + """ + Read training metadata from the hidden file in the model's clips directory. + + Args: + model_name: Name of the classification model + + Returns: + Dictionary with last_training_date and last_training_image_count, or None if not found + """ + clips_model_dir = os.path.join(CLIPS_DIR, model_name) + metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE) + + if not os.path.exists(metadata_path): + return None + + try: + with open(metadata_path, "r") as f: + metadata = json.load(f) + return metadata + except Exception as e: + logger.error(f"Failed to read training metadata for {model_name}: {e}") + return None + + +def get_dataset_image_count(model_name: str) -> int: + """ + Count the total number of images in the model's dataset directory. + + Args: + model_name: Name of the classification model + + Returns: + Total count of images across all categories + """ + dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset") + + if not os.path.exists(dataset_dir): + return 0 + + total_count = 0 + try: + for category in os.listdir(dataset_dir): + category_dir = os.path.join(dataset_dir, category) + if not os.path.isdir(category_dir): + continue + + image_files = [ + f + for f in os.listdir(category_dir) + if f.lower().endswith((".webp", ".png", ".jpg", ".jpeg")) + ] + total_count += len(image_files) + except Exception as e: + logger.error(f"Failed to count dataset images for {model_name}: {e}") + return 0 + + return total_count + + class ClassificationTrainingProcess(FrigateProcess): def __init__(self, model_name: str) -> None: super().__init__( @@ -145,6 +233,10 @@ class ClassificationTrainingProcess(FrigateProcess): with open(os.path.join(model_dir, "model.tflite"), "wb") as f: f.write(tflite_model) + # write training metadata with image count + dataset_image_count = get_dataset_image_count(self.model_name) + write_training_metadata(self.model_name, dataset_image_count) + def kickoff_model_training( embeddingRequestor: EmbeddingsRequestor, model_name: str diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json index 65118f227..4fbb95327 100644 --- a/web/public/locales/en/views/classificationModel.json +++ b/web/public/locales/en/views/classificationModel.json @@ -13,6 +13,11 @@ "deleteModels": "Delete Models", "editModel": "Edit Model" }, + "tooltip": { + "trainingInProgress": "Model is currently training", + "noNewImages": "No new images to train. Classify more images in the dataset first.", + "modelNotReady": "Model is not ready for training" + }, "toast": { "success": { "deletedCategory": "Deleted Class", diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx index 9e9025691..35b5584d7 100644 --- a/web/src/views/classification/ModelTrainingView.tsx +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -102,6 +102,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { position: "top-center", }); setWasTraining(false); + refreshDataset(); } // only refresh when modelState changes // eslint-disable-next-line react-hooks/exhaustive-deps @@ -112,10 +113,20 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { const { data: trainImages, mutate: refreshTrain } = useSWR( `classification/${model.name}/train`, ); - const { data: dataset, mutate: refreshDataset } = useSWR<{ - [id: string]: string[]; + const { data: datasetResponse, mutate: refreshDataset } = useSWR<{ + categories: { [id: string]: string[] }; + training_metadata: { + has_trained: boolean; + last_training_date: string | null; + last_training_image_count: number; + current_image_count: number; + new_images_count: number; + } | null; }>(`classification/${model.name}/dataset`); + const dataset = datasetResponse?.categories || {}; + const trainingMetadata = datasetResponse?.training_metadata; + const [trainFilter, setTrainFilter] = useApiFilter(); const refreshAll = useCallback(() => { @@ -421,19 +432,48 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { filterValues={{ classes: Object.keys(dataset || {}) }} onUpdateFilter={setTrainFilter} /> - + + {((trainingMetadata?.new_images_count ?? 0) === 0 || + modelState != "complete") && ( + + + {modelState == "training" + ? t("tooltip.trainingInProgress") + : trainingMetadata?.new_images_count === 0 + ? t("tooltip.noNewImages") + : t("tooltip.modelNotReady")} + + )} - {isDesktop && t("button.trainModel")} - + )}