From 7f093d81a9003458e9a8f0f34c9a51b8f3cb5ef8 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Wed, 22 Oct 2025 14:30:33 -0600 Subject: [PATCH] Finalize training and image selection step --- frigate/api/app.py | 7 ++ frigate/embeddings/maintainer.py | 57 +++++++++ frigate/util/classification.py | 18 +-- .../locales/en/views/classificationModel.json | 6 +- .../wizard/Step3ChooseExamples.tsx | 117 ++++++++++++++++-- 5 files changed, 183 insertions(+), 22 deletions(-) diff --git a/frigate/api/app.py b/frigate/api/app.py index f84190407..7d6b89e27 100644 --- a/frigate/api/app.py +++ b/frigate/api/app.py @@ -401,6 +401,13 @@ def config_set(request: Request, body: AppConfigSetBody): CameraConfigUpdateTopic(CameraConfigUpdateEnum[field], camera), settings, ) + elif body.update_topic and "/config/" in body.update_topic[1:]: + # Handle nested config updates (e.g., config/classification/custom/{name}) + settings = config.get_nested_object(body.update_topic) + if settings: + request.app.config_publisher.publisher.publish( + body.update_topic, settings + ) return JSONResponse( content=( diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index 55e3d57ba..78e0ab067 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -9,6 +9,7 @@ from typing import Any from peewee import DoesNotExist +from frigate.comms.config_updater import ConfigSubscriber from frigate.comms.detections_updater import DetectionSubscriber, DetectionTypeEnum from frigate.comms.embeddings_updater import ( EmbeddingsRequestEnum, @@ -95,6 +96,9 @@ class EmbeddingMaintainer(threading.Thread): CameraConfigUpdateEnum.semantic_search, ], ) + self.classification_config_subscriber = ConfigSubscriber( + "config/classification/custom/" + ) # Configure Frigate DB db = SqliteVecQueueDatabase( @@ -255,6 +259,7 @@ class EmbeddingMaintainer(threading.Thread): """Maintain a SQLite-vec database for semantic search.""" while not self.stop_event.is_set(): self.config_updater.check_for_updates() + self._check_classification_config_updates() self._process_requests() self._process_updates() self._process_recordings_updates() @@ -265,6 +270,7 @@ class EmbeddingMaintainer(threading.Thread): self._process_event_metadata() self.config_updater.stop() + self.classification_config_subscriber.stop() self.event_subscriber.stop() self.event_end_subscriber.stop() self.recordings_subscriber.stop() @@ -275,6 +281,57 @@ class EmbeddingMaintainer(threading.Thread): self.requestor.stop() logger.info("Exiting embeddings maintenance...") + def _check_classification_config_updates(self) -> None: + """Check for classification config updates and add new processors.""" + topic, model_config = self.classification_config_subscriber.check_for_update() + + if topic and model_config: + # Extract model name from topic: config/classification/custom/{model_name} + model_name = topic.split("/")[-1] + logger.info( + f"Received classification config update for model: {model_name}" + ) + + # Update config + self.config.classification.custom[model_name] = model_config + + # Check if processor already exists for this model + existing_processor_index = None + for i, processor in enumerate(self.realtime_processors): + if isinstance( + processor, + ( + CustomStateClassificationProcessor, + CustomObjectClassificationProcessor, + ), + ): + if processor.model_config.name == model_name: + existing_processor_index = i + break + + # Remove existing processor if found + if existing_processor_index is not None: + logger.info( + f"Removing existing classification processor for model: {model_name}" + ) + self.realtime_processors.pop(existing_processor_index) + + # Add new processor + if model_config.state_config is not None: + processor = CustomStateClassificationProcessor( + self.config, model_config, self.requestor, self.metrics + ) + else: + processor = CustomObjectClassificationProcessor( + self.config, + model_config, + self.event_metadata_publisher, + self.metrics, + ) + + self.realtime_processors.append(processor) + logger.info(f"Added classification processor for model: {model_name}") + def _process_requests(self) -> None: """Process embeddings requests""" diff --git a/frigate/util/classification.py b/frigate/util/classification.py index 4b6e51dce..3cd9a2b70 100644 --- a/frigate/util/classification.py +++ b/frigate/util/classification.py @@ -225,13 +225,13 @@ def collect_state_classification_examples( # Step 4: Select 24 most visually distinct images (they're already cropped) distinct_images = _select_distinct_images(keyframes, target_count=24) - # Step 5: Save to dataset directory (in "unknown" subfolder for unlabeled data) - unknown_dir = os.path.join(dataset_dir, "unknown") - os.makedirs(unknown_dir, exist_ok=True) + # Step 5: Save to train directory for later classification + train_dir = os.path.join(CLIPS_DIR, model_name, "train") + os.makedirs(train_dir, exist_ok=True) saved_count = 0 for idx, image_path in enumerate(distinct_images): - dest_path = os.path.join(unknown_dir, f"example_{idx:03d}.jpg") + dest_path = os.path.join(train_dir, f"example_{idx:03d}.jpg") try: img = cv2.imread(image_path) @@ -549,13 +549,13 @@ def collect_object_classification_examples( distinct_images = _select_distinct_images(thumbnails, target_count=24) logger.debug(f"Selected {len(distinct_images)} distinct images") - # Step 5: Save to dataset directory - unknown_dir = os.path.join(dataset_dir, "unknown") - os.makedirs(unknown_dir, exist_ok=True) + # Step 5: Save to train directory for later classification + train_dir = os.path.join(CLIPS_DIR, model_name, "train") + os.makedirs(train_dir, exist_ok=True) saved_count = 0 for idx, image_path in enumerate(distinct_images): - dest_path = os.path.join(unknown_dir, f"example_{idx:03d}.jpg") + dest_path = os.path.join(train_dir, f"example_{idx:03d}.jpg") try: img = cv2.imread(image_path) @@ -573,7 +573,7 @@ def collect_object_classification_examples( logger.warning(f"Failed to clean up temp directory: {e}") logger.debug( - f"Successfully collected {saved_count} classification examples in {unknown_dir}" + f"Successfully collected {saved_count} classification examples in {train_dir}" ) diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json index 8f6614f37..a9d9381de 100644 --- a/web/public/locales/en/views/classificationModel.json +++ b/web/public/locales/en/views/classificationModel.json @@ -98,12 +98,16 @@ }, "retryGenerate": "Retry Generation", "selectClass": "Select class...", + "none": "None", "noImages": "No sample images generated", + "classifying": "Classifying & Training...", + "trainingStarted": "Training started successfully", "errors": { "noCameras": "No cameras configured", "noObjectLabel": "No object label selected", "generateFailed": "Failed to generate examples: {{error}}", - "generationFailed": "Generation failed. Please try again." + "generationFailed": "Generation failed. Please try again.", + "classifyFailed": "Failed to classify images: {{error}}" }, "generateSuccess": "Successfully generated sample images" } diff --git a/web/src/components/classification/wizard/Step3ChooseExamples.tsx b/web/src/components/classification/wizard/Step3ChooseExamples.tsx index 5172cae8a..e373fde21 100644 --- a/web/src/components/classification/wizard/Step3ChooseExamples.tsx +++ b/web/src/components/classification/wizard/Step3ChooseExamples.tsx @@ -45,14 +45,14 @@ export default function Step3ChooseExamples({ [imageName: string]: string; }>(initialData?.imageClassifications || {}); - const { data: dataset, mutate: refreshDataset } = useSWR<{ - [id: string]: string[]; - }>(hasGenerated ? `classification/${step1Data.modelName}/dataset` : null); + const { data: trainImages, mutate: refreshTrainImages } = useSWR( + hasGenerated ? `classification/${step1Data.modelName}/train` : null, + ); const unknownImages = useMemo(() => { - if (!dataset || !dataset.unknown) return []; - return dataset.unknown; - }, [dataset]); + if (!trainImages) return []; + return trainImages; + }, [trainImages]); const handleClassificationChange = useCallback( (imageName: string, className: string) => { @@ -104,7 +104,7 @@ export default function Step3ChooseExamples({ setHasGenerated(true); toast.success(t("wizard.step3.generateSuccess")); - await refreshDataset(); + await refreshTrainImages(); } catch (error) { const axiosError = error as { response?: { data?: { message?: string; detail?: string } }; @@ -122,7 +122,7 @@ export default function Step3ChooseExamples({ } finally { setIsGenerating(false); } - }, [step1Data, step2Data, t, refreshDataset]); + }, [step1Data, step2Data, t, refreshTrainImages]); useEffect(() => { if (!hasGenerated && !isGenerating) { @@ -131,9 +131,94 @@ export default function Step3ChooseExamples({ // eslint-disable-next-line react-hooks/exhaustive-deps }, []); - const handleContinue = useCallback(() => { - onNext({ examplesGenerated: true, imageClassifications }); - }, [onNext, imageClassifications]); + const handleContinue = useCallback(async () => { + try { + // Step 1: Create config for the new model + const modelConfig: { + enabled: boolean; + name: string; + threshold: number; + state_config?: { + cameras: Record; + motion: boolean; + }; + object_config?: { objects: string[]; classification_type: string }; + } = { + enabled: true, + name: step1Data.modelName, + threshold: 0.8, + }; + + if (step1Data.modelType === "state") { + // State model config + const cameras: Record = {}; + step2Data?.cameraAreas.forEach((area) => { + cameras[area.camera] = { + crop: area.crop, + }; + }); + + modelConfig.state_config = { + cameras, + motion: true, + }; + } else { + // Object model config + modelConfig.object_config = { + objects: step1Data.objectLabel ? [step1Data.objectLabel] : [], + classification_type: step1Data.objectType || "sub_label", + } as { objects: string[]; classification_type: string }; + } + + // Update config via config API + await axios.put("/config/set", { + requires_restart: 0, + update_topic: `config/classification/custom/${step1Data.modelName}`, + config_data: { + classification: { + custom: { + [step1Data.modelName]: modelConfig, + }, + }, + }, + }); + + // Step 2: Classify each image by moving it to the correct category folder + for (const [imageName, className] of Object.entries( + imageClassifications, + )) { + if (!className) continue; + + await axios.post( + `/classification/${step1Data.modelName}/dataset/categorize`, + { + training_file: imageName, + category: className === "none" ? "none" : className, + }, + ); + } + + // Step 3: Kick off training + await axios.post(`/classification/${step1Data.modelName}/train`); + + toast.success(t("wizard.step3.trainingStarted")); + onNext({ examplesGenerated: true, imageClassifications }); + } catch (error) { + const axiosError = error as { + response?: { data?: { message?: string; detail?: string } }; + message?: string; + }; + const errorMessage = + axiosError.response?.data?.message || + axiosError.response?.data?.detail || + axiosError.message || + "Failed to classify images"; + + toast.error( + t("wizard.step3.errors.classifyFailed", { error: errorMessage }), + ); + } + }, [onNext, imageClassifications, step1Data, step2Data, t]); const allImagesClassified = useMemo(() => { if (!unknownImages || unknownImages.length === 0) return false; @@ -175,7 +260,7 @@ export default function Step3ChooseExamples({ className="group relative aspect-square cursor-pointer overflow-hidden rounded-lg border bg-background transition-all hover:ring-2 hover:ring-primary" > {`Example @@ -192,6 +277,14 @@ export default function Step3ChooseExamples({ /> + {step1Data.modelType === "object" && ( + + {t("wizard.step3.none")} + + )} {step1Data.classes.map((className) => (