diff --git a/frigate/api/classification.py b/frigate/api/classification.py index 59d4376cb..e9052097a 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -167,8 +167,7 @@ def train_face(request: Request, name: str, body: dict = None): new_name = f"{sanitized_name}-{datetime.datetime.now().timestamp()}.webp" new_file_folder = os.path.join(FACE_DIR, f"{sanitized_name}") - if not os.path.exists(new_file_folder): - os.mkdir(new_file_folder) + os.makedirs(new_file_folder, exist_ok=True) if training_file_name: shutil.move(training_file, os.path.join(new_file_folder, new_name)) @@ -716,8 +715,7 @@ def categorize_classification_image(request: Request, name: str, body: dict = No CLIPS_DIR, sanitize_filename(name), "dataset", category ) - if not os.path.exists(new_file_folder): - os.mkdir(new_file_folder) + os.makedirs(new_file_folder, exist_ok=True) # use opencv because webp images can not be used to train img = cv2.imread(training_file) diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 1fb9dfc97..ac6387785 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -53,9 +53,17 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): self.tensor_output_details: dict[str, Any] | None = None self.labelmap: dict[int, str] = {} self.classifications_per_second = EventsPerSecond() - self.inference_speed = InferenceSpeed( - self.metrics.classification_speeds[self.model_config.name] - ) + + if ( + self.metrics + and self.model_config.name in self.metrics.classification_speeds + ): + self.inference_speed = InferenceSpeed( + self.metrics.classification_speeds[self.model_config.name] + ) + else: + self.inference_speed = None + self.last_run = datetime.datetime.now().timestamp() self.__build_detector() @@ -83,12 +91,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): def __update_metrics(self, duration: float) -> None: self.classifications_per_second.update() - self.inference_speed.update(duration) + if self.inference_speed: + self.inference_speed.update(duration) def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray): - self.metrics.classification_cps[ - self.model_config.name - ].value = self.classifications_per_second.eps() + if self.metrics and self.model_config.name in self.metrics.classification_cps: + self.metrics.classification_cps[ + self.model_config.name + ].value = self.classifications_per_second.eps() camera = frame_data.get("camera") if camera not in self.model_config.state_config.cameras: @@ -223,9 +233,17 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): self.detected_objects: dict[str, float] = {} self.labelmap: dict[int, str] = {} self.classifications_per_second = EventsPerSecond() - self.inference_speed = InferenceSpeed( - self.metrics.classification_speeds[self.model_config.name] - ) + + if ( + self.metrics + and self.model_config.name in self.metrics.classification_speeds + ): + self.inference_speed = InferenceSpeed( + self.metrics.classification_speeds[self.model_config.name] + ) + else: + self.inference_speed = None + self.__build_detector() @redirect_output_to_logger(logger, logging.DEBUG) @@ -251,12 +269,14 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): def __update_metrics(self, duration: float) -> None: self.classifications_per_second.update() - self.inference_speed.update(duration) + if self.inference_speed: + self.inference_speed.update(duration) def process_frame(self, obj_data, frame): - self.metrics.classification_cps[ - self.model_config.name - ].value = self.classifications_per_second.eps() + if self.metrics and self.model_config.name in self.metrics.classification_cps: + self.metrics.classification_cps[ + self.model_config.name + ].value = self.classifications_per_second.eps() if obj_data["false_positive"]: return diff --git a/frigate/data_processing/types.py b/frigate/data_processing/types.py index 5eef3a044..263a8b987 100644 --- a/frigate/data_processing/types.py +++ b/frigate/data_processing/types.py @@ -10,7 +10,6 @@ from frigate.data_processing.real_time.whisper_online import FasterWhisperASR class DataProcessorMetrics: - manager: SyncManager image_embeddings_speed: Synchronized image_embeddings_eps: Synchronized text_embeddings_speed: Synchronized @@ -29,7 +28,6 @@ class DataProcessorMetrics: classification_cps: dict[str, Synchronized] def __init__(self, manager: SyncManager, custom_classification_models: list[str]): - self.manager = manager self.image_embeddings_speed = manager.Value("d", 0.0) self.image_embeddings_eps = manager.Value("d", 0.0) self.text_embeddings_speed = manager.Value("d", 0.0) @@ -52,12 +50,6 @@ class DataProcessorMetrics: self.classification_speeds[key] = manager.Value("d", 0.0) self.classification_cps[key] = manager.Value("d", 0.0) - def add_classification_model(self, model_name: str) -> None: - """Add metrics for a new classification model dynamically.""" - if model_name not in self.classification_speeds: - self.classification_speeds[model_name] = self.manager.Value("d", 0.0) - self.classification_cps[model_name] = self.manager.Value("d", 0.0) - class DataProcessorModelRunner: def __init__(self, requestor, device: str = "CPU", model_size: str = "large"): diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index e67e14842..fe04d8b17 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -304,9 +304,6 @@ class EmbeddingMaintainer(threading.Thread): ) return - if self.metrics: - self.metrics.add_classification_model(model_name) - if model_config.state_config is not None: processor = CustomStateClassificationProcessor( self.config, model_config, self.requestor, self.metrics diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json index 0ac7391fa..a49082584 100644 --- a/web/public/locales/en/views/classificationModel.json +++ b/web/public/locales/en/views/classificationModel.json @@ -1,4 +1,5 @@ { + "documentTitle": "Classification Models", "button": { "deleteClassificationAttempts": "Delete Classification Images", "renameCategory": "Rename Class", @@ -96,6 +97,10 @@ "selectCameraPrompt": "Select a camera from the list to define its monitoring area" }, "step3": { + "selectImagesPrompt": "Select all images with: {{className}}", + "selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.", + "allImagesClassified": "All images classified!", + "readyToContinue": "Click Continue to finish and start training.", "description": "Classify the example images below. These samples will be used to train your model.", "generating": { "title": "Generating Sample Images", diff --git a/web/src/components/classification/ClassificationModelWizardDialog.tsx b/web/src/components/classification/ClassificationModelWizardDialog.tsx index 9301ed5d4..d32cee12b 100644 --- a/web/src/components/classification/ClassificationModelWizardDialog.tsx +++ b/web/src/components/classification/ClassificationModelWizardDialog.tsx @@ -135,7 +135,7 @@ export default function ClassificationModelWizardDialog({ { e.preventDefault(); diff --git a/web/src/components/classification/wizard/Step3ChooseExamples.tsx b/web/src/components/classification/wizard/Step3ChooseExamples.tsx index f798f599b..ab269e883 100644 --- a/web/src/components/classification/wizard/Step3ChooseExamples.tsx +++ b/web/src/components/classification/wizard/Step3ChooseExamples.tsx @@ -1,11 +1,4 @@ import { Button } from "@/components/ui/button"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; import { useTranslation } from "react-i18next"; import { useState, useEffect, useCallback, useMemo } from "react"; import ActivityIndicator from "@/components/indicators/activity-indicator"; @@ -48,6 +41,8 @@ export default function Step3ChooseExamples({ }>(initialData?.imageClassifications || {}); const [isTraining, setIsTraining] = useState(false); const [isProcessing, setIsProcessing] = useState(false); + const [currentClassIndex, setCurrentClassIndex] = useState(0); + const [selectedImages, setSelectedImages] = useState>(new Set()); const { data: trainImages, mutate: refreshTrainImages } = useSWR( hasGenerated ? `classification/${step1Data.modelName}/train` : null, @@ -58,16 +53,165 @@ export default function Step3ChooseExamples({ return trainImages; }, [trainImages]); - const handleClassificationChange = useCallback( - (imageName: string, className: string) => { - setImageClassifications((prev) => ({ - ...prev, - [imageName]: className, - })); + const toggleImageSelection = useCallback((imageName: string) => { + setSelectedImages((prev) => { + const newSet = new Set(prev); + if (newSet.has(imageName)) { + newSet.delete(imageName); + } else { + newSet.add(imageName); + } + return newSet; + }); + }, []); + + // Get all classes (excluding "none" - it will be auto-assigned) + const allClasses = useMemo(() => { + return [...step1Data.classes]; + }, [step1Data.classes]); + + const currentClass = allClasses[currentClassIndex]; + + const processClassificationsAndTrain = useCallback( + async (classifications: { [imageName: string]: string }) => { + // Step 1: Create config for the new model + const modelConfig: { + enabled: boolean; + name: string; + threshold: number; + state_config?: { + cameras: Record; + motion: boolean; + }; + object_config?: { objects: string[]; classification_type: string }; + } = { + enabled: true, + name: step1Data.modelName, + threshold: 0.8, + }; + + if (step1Data.modelType === "state") { + // State model config + const cameras: Record = {}; + step2Data?.cameraAreas.forEach((area) => { + cameras[area.camera] = { + crop: area.crop, + }; + }); + + modelConfig.state_config = { + cameras, + motion: true, + }; + } else { + // Object model config + modelConfig.object_config = { + objects: step1Data.objectLabel ? [step1Data.objectLabel] : [], + classification_type: step1Data.objectType || "sub_label", + } as { objects: string[]; classification_type: string }; + } + + // Update config via config API + await axios.put("/config/set", { + requires_restart: 0, + update_topic: `config/classification/custom/${step1Data.modelName}`, + config_data: { + classification: { + custom: { + [step1Data.modelName]: modelConfig, + }, + }, + }, + }); + + // Step 2: Classify each image by moving it to the correct category folder + const categorizePromises = Object.entries(classifications).map( + ([imageName, className]) => { + if (!className) return Promise.resolve(); + return axios.post( + `/classification/${step1Data.modelName}/dataset/categorize`, + { + training_file: imageName, + category: className === "none" ? "none" : className, + }, + ); + }, + ); + await Promise.all(categorizePromises); + + // Step 3: Kick off training + await axios.post(`/classification/${step1Data.modelName}/train`); + + toast.success(t("wizard.step3.trainingStarted")); + setIsTraining(true); }, - [], + [step1Data, step2Data, t], ); + const handleContinueClassification = useCallback(async () => { + // Mark selected images with current class + const newClassifications = { ...imageClassifications }; + selectedImages.forEach((imageName) => { + newClassifications[imageName] = currentClass; + }); + + // Check if we're on the last class to select + const isLastClass = currentClassIndex === allClasses.length - 1; + + if (isLastClass) { + // Assign remaining unclassified images + unknownImages.slice(0, 24).forEach((imageName) => { + if (!newClassifications[imageName]) { + // For state models with 2 classes, assign to the last class + // For object models, assign to "none" + if (step1Data.modelType === "state" && allClasses.length === 2) { + newClassifications[imageName] = allClasses[allClasses.length - 1]; + } else { + newClassifications[imageName] = "none"; + } + } + }); + + // All done, trigger training immediately + setImageClassifications(newClassifications); + setIsProcessing(true); + + try { + await processClassificationsAndTrain(newClassifications); + } catch (error) { + const axiosError = error as { + response?: { data?: { message?: string; detail?: string } }; + message?: string; + }; + const errorMessage = + axiosError.response?.data?.message || + axiosError.response?.data?.detail || + axiosError.message || + "Failed to classify images"; + + toast.error( + t("wizard.step3.errors.classifyFailed", { error: errorMessage }), + ); + setIsProcessing(false); + } + } else { + // Move to next class + setImageClassifications(newClassifications); + setCurrentClassIndex((prev) => prev + 1); + setSelectedImages(new Set()); + } + }, [ + selectedImages, + currentClass, + currentClassIndex, + allClasses, + imageClassifications, + unknownImages, + step1Data, + processClassificationsAndTrain, + t, + ]); + const generateExamples = useCallback(async () => { setIsGenerating(true); @@ -138,76 +282,7 @@ export default function Step3ChooseExamples({ const handleContinue = useCallback(async () => { setIsProcessing(true); try { - // Step 1: Create config for the new model - const modelConfig: { - enabled: boolean; - name: string; - threshold: number; - state_config?: { - cameras: Record; - motion: boolean; - }; - object_config?: { objects: string[]; classification_type: string }; - } = { - enabled: true, - name: step1Data.modelName, - threshold: 0.8, - }; - - if (step1Data.modelType === "state") { - // State model config - const cameras: Record = {}; - step2Data?.cameraAreas.forEach((area) => { - cameras[area.camera] = { - crop: area.crop, - }; - }); - - modelConfig.state_config = { - cameras, - motion: true, - }; - } else { - // Object model config - modelConfig.object_config = { - objects: step1Data.objectLabel ? [step1Data.objectLabel] : [], - classification_type: step1Data.objectType || "sub_label", - } as { objects: string[]; classification_type: string }; - } - - // Update config via config API - await axios.put("/config/set", { - requires_restart: 0, - update_topic: `config/classification/custom/${step1Data.modelName}`, - config_data: { - classification: { - custom: { - [step1Data.modelName]: modelConfig, - }, - }, - }, - }); - - // Step 2: Classify each image by moving it to the correct category folder - const categorizePromises = Object.entries(imageClassifications).map( - ([imageName, className]) => { - if (!className) return Promise.resolve(); - return axios.post( - `/classification/${step1Data.modelName}/dataset/categorize`, - { - training_file: imageName, - category: className === "none" ? "none" : className, - }, - ); - }, - ); - await Promise.all(categorizePromises); - - // Step 3: Kick off training - await axios.post(`/classification/${step1Data.modelName}/train`); - - toast.success(t("wizard.step3.trainingStarted")); - setIsTraining(true); + await processClassificationsAndTrain(imageClassifications); } catch (error) { const axiosError = error as { response?: { data?: { message?: string; detail?: string } }; @@ -224,13 +299,23 @@ export default function Step3ChooseExamples({ ); setIsProcessing(false); } - }, [imageClassifications, step1Data, step2Data, t]); + }, [imageClassifications, processClassificationsAndTrain, t]); + + const unclassifiedImages = useMemo(() => { + if (!unknownImages) return []; + const images = unknownImages.slice(0, 24); + + // Only filter if we have any classifications + if (Object.keys(imageClassifications).length === 0) { + return images; + } + + return images.filter((img) => !imageClassifications[img]); + }, [unknownImages, imageClassifications]); const allImagesClassified = useMemo(() => { - if (!unknownImages || unknownImages.length === 0) return false; - const imagesToClassify = unknownImages.slice(0, 24); - return imagesToClassify.every((img) => imageClassifications[img]); - }, [unknownImages, imageClassifications]); + return unclassifiedImages.length === 0; + }, [unclassifiedImages]); return (
@@ -263,9 +348,18 @@ export default function Step3ChooseExamples({
) : hasGenerated ? (
-
- {t("wizard.step3.description")} -
+ {!allImagesClassified && ( +
+

+ {t("wizard.step3.selectImagesPrompt", { + className: currentClass, + })} +

+

+ {t("wizard.step3.selectImagesDescription")} +

+
+ )}
+ ) : allImagesClassified && isProcessing ? ( +
+ +

+ {t("wizard.step3.classifying")} +

+
) : ( -
- {unknownImages.slice(0, 24).map((imageName, index) => ( -
- {`Example -
- +
+ {unclassifiedImages.map((imageName, index) => { + const isSelected = selectedImages.has(imageName); + return ( +
toggleImageSelection(imageName)} + > + {`Example
-
- ))} + ); + })}
)}
@@ -347,15 +424,14 @@ export default function Step3ChooseExamples({