diff --git a/frigate/api/app.py b/frigate/api/app.py index 7d6b89e27..5d09ecf00 100644 --- a/frigate/api/app.py +++ b/frigate/api/app.py @@ -387,27 +387,28 @@ def config_set(request: Request, body: AppConfigSetBody): old_config: FrigateConfig = request.app.frigate_config request.app.frigate_config = config - if body.update_topic and body.update_topic.startswith("config/cameras/"): - _, _, camera, field = body.update_topic.split("/") + if body.update_topic: + if body.update_topic.startswith("config/cameras/"): + _, _, camera, field = body.update_topic.split("/") - if field == "add": - settings = config.cameras[camera] - elif field == "remove": - settings = old_config.cameras[camera] - else: - settings = config.get_nested_object(body.update_topic) + if field == "add": + settings = config.cameras[camera] + elif field == "remove": + settings = old_config.cameras[camera] + else: + settings = config.get_nested_object(body.update_topic) - request.app.config_publisher.publish_update( - CameraConfigUpdateTopic(CameraConfigUpdateEnum[field], camera), - settings, - ) - elif body.update_topic and "/config/" in body.update_topic[1:]: - # Handle nested config updates (e.g., config/classification/custom/{name}) - settings = config.get_nested_object(body.update_topic) - if settings: - request.app.config_publisher.publisher.publish( - body.update_topic, settings + request.app.config_publisher.publish_update( + CameraConfigUpdateTopic(CameraConfigUpdateEnum[field], camera), + settings, ) + else: + # Handle nested config updates (e.g., config/classification/custom/{name}) + settings = config.get_nested_object(body.update_topic) + if settings: + request.app.config_publisher.publisher.publish( + body.update_topic, settings + ) return JSONResponse( content=( diff --git a/frigate/api/classification.py b/frigate/api/classification.py index 7961a8e82..59d4376cb 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -3,7 +3,9 @@ import datetime import logging import os +import random import shutil +import string from typing import Any import cv2 @@ -707,7 +709,9 @@ def categorize_classification_image(request: Request, name: str, body: dict = No status_code=404, ) - new_name = f"{category}-{datetime.datetime.now().timestamp()}.png" + random_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6)) + timestamp = datetime.datetime.now().timestamp() + new_name = f"{category}-{timestamp}-{random_id}.png" new_file_folder = os.path.join( CLIPS_DIR, sanitize_filename(name), "dataset", category ) diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 1fb9dfc97..0251b35cb 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -353,6 +353,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): ) def handle_request(self, topic, request_data): + logger.info(f"comparing with topic {topic} and the {request_data.get('model_name')} and the {self.model_config.name}") if topic == EmbeddingsRequestEnum.reload_classification_model.value: if request_data.get("model_name") == self.model_config.name: logger.info( diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index e536d5f11..fe04d8b17 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -286,16 +286,11 @@ class EmbeddingMaintainer(threading.Thread): topic, model_config = self.classification_config_subscriber.check_for_update() if topic and model_config: - # Extract model name from topic: config/classification/custom/{model_name} model_name = topic.split("/")[-1] - logger.info( - f"Received classification config update for model: {model_name}" - ) - self.config.classification.custom[model_name] = model_config - existing_processor_index = None - for i, processor in enumerate(self.realtime_processors): + # Check if processor already exists + for processor in self.realtime_processors: if isinstance( processor, ( @@ -304,14 +299,10 @@ class EmbeddingMaintainer(threading.Thread): ), ): if processor.model_config.name == model_name: - existing_processor_index = i - break - - if existing_processor_index is not None: - logger.info( - f"Removing existing classification processor for model: {model_name}" - ) - self.realtime_processors.pop(existing_processor_index) + logger.debug( + f"Classification processor for model {model_name} already exists, skipping" + ) + return if model_config.state_config is not None: processor = CustomStateClassificationProcessor( @@ -326,7 +317,9 @@ class EmbeddingMaintainer(threading.Thread): ) self.realtime_processors.append(processor) - logger.info(f"Added classification processor for model: {model_name}") + logger.info( + f"Added classification processor for model: {model_name} (type: {type(processor).__name__})" + ) def _process_requests(self) -> None: """Process embeddings requests""" diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json index e5275b8c8..0ac7391fa 100644 --- a/web/public/locales/en/views/classificationModel.json +++ b/web/public/locales/en/views/classificationModel.json @@ -101,6 +101,10 @@ "title": "Generating Sample Images", "description": "We're pulling representative images from your recordings. This may take a moment..." }, + "training": { + "title": "Training Model", + "description": "Your model is being trained in the background. You can close this wizard and the training will continue." + }, "retryGenerate": "Retry Generation", "selectClass": "Select class...", "none": "None", diff --git a/web/src/components/classification/wizard/Step3ChooseExamples.tsx b/web/src/components/classification/wizard/Step3ChooseExamples.tsx index a4a96c5b3..b626a796f 100644 --- a/web/src/components/classification/wizard/Step3ChooseExamples.tsx +++ b/web/src/components/classification/wizard/Step3ChooseExamples.tsx @@ -46,6 +46,7 @@ export default function Step3ChooseExamples({ const [imageClassifications, setImageClassifications] = useState<{ [imageName: string]: string; }>(initialData?.imageClassifications || {}); + const [isTraining, setIsTraining] = useState(false); const { data: trainImages, mutate: refreshTrainImages } = useSWR( hasGenerated ? `classification/${step1Data.modelName}/train` : null, @@ -186,25 +187,25 @@ export default function Step3ChooseExamples({ }); // Step 2: Classify each image by moving it to the correct category folder - for (const [imageName, className] of Object.entries( - imageClassifications, - )) { - if (!className) continue; - - await axios.post( - `/classification/${step1Data.modelName}/dataset/categorize`, - { - training_file: imageName, - category: className === "none" ? "none" : className, - }, - ); - } + const categorizePromises = Object.entries(imageClassifications).map( + ([imageName, className]) => { + if (!className) return Promise.resolve(); + return axios.post( + `/classification/${step1Data.modelName}/dataset/categorize`, + { + training_file: imageName, + category: className === "none" ? "none" : className, + }, + ); + }, + ); + await Promise.all(categorizePromises); // Step 3: Kick off training await axios.post(`/classification/${step1Data.modelName}/train`); toast.success(t("wizard.step3.trainingStarted")); - onClose(); + setIsTraining(true); } catch (error) { const axiosError = error as { response?: { data?: { message?: string; detail?: string } }; @@ -220,7 +221,7 @@ export default function Step3ChooseExamples({ t("wizard.step3.errors.classifyFailed", { error: errorMessage }), ); } - }, [onClose, imageClassifications, step1Data, step2Data, t]); + }, [imageClassifications, step1Data, step2Data, t]); const allImagesClassified = useMemo(() => { if (!unknownImages || unknownImages.length === 0) return false; @@ -230,7 +231,22 @@ export default function Step3ChooseExamples({ return (
- {isGenerating ? ( + {isTraining ? ( +
+ +
+

+ {t("wizard.step3.training.title")} +

+

+ {t("wizard.step3.training.description")} +

+
+ +
+ ) : isGenerating ? (
@@ -321,20 +337,22 @@ export default function Step3ChooseExamples({
)} -
- - -
+ {!isTraining && ( +
+ + +
+ )}
); } diff --git a/web/src/views/classification/ModelSelectionView.tsx b/web/src/views/classification/ModelSelectionView.tsx index 372c47429..4860d3285 100644 --- a/web/src/views/classification/ModelSelectionView.tsx +++ b/web/src/views/classification/ModelSelectionView.tsx @@ -30,9 +30,12 @@ export default function ModelSelectionView({ const { t } = useTranslation(["views/classificationModel"]); const [page, setPage] = useState("objects"); const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100); - const { data: config } = useSWR("config", { - revalidateOnFocus: false, - }); + const { data: config, mutate: refreshConfig } = useSWR( + "config", + { + revalidateOnFocus: false, + }, + ); // data @@ -71,7 +74,10 @@ export default function ModelSelectionView({ <> setNewModel(false)} + onClose={() => { + setNewModel(false); + refreshConfig(); + }} /> setNewModel(true)} />; @@ -82,7 +88,10 @@ export default function ModelSelectionView({
setNewModel(false)} + onClose={() => { + setNewModel(false); + refreshConfig(); + }} />