diff --git a/frigate/util/classification.py b/frigate/util/classification.py index 643f77d3b..ada3ee1f7 100644 --- a/frigate/util/classification.py +++ b/frigate/util/classification.py @@ -5,6 +5,7 @@ import json import logging import os import random +import shutil from collections import defaultdict import cv2 @@ -397,6 +398,8 @@ def collect_state_classification_examples( # Step 5: Save to train directory for later classification train_dir = os.path.join(CLIPS_DIR, model_name, "train") + if os.path.exists(train_dir): + shutil.rmtree(train_dir) os.makedirs(train_dir, exist_ok=True) saved_count = 0 @@ -411,8 +414,6 @@ def collect_state_classification_examples( except Exception as e: logger.error(f"Failed to save image {image_path}: {e}") - import shutil - try: shutil.rmtree(temp_dir) except Exception as e: @@ -750,6 +751,8 @@ def collect_object_classification_examples( # Step 5: Save to train directory for later classification train_dir = os.path.join(CLIPS_DIR, model_name, "train") + if os.path.exists(train_dir): + shutil.rmtree(train_dir) os.makedirs(train_dir, exist_ok=True) saved_count = 0 @@ -764,8 +767,6 @@ def collect_object_classification_examples( except Exception as e: logger.error(f"Failed to save image {image_path}: {e}") - import shutil - try: shutil.rmtree(temp_dir) except Exception as e: @@ -806,24 +807,25 @@ def _select_balanced_events( selected = [] for group_events in grouped.values(): + # Take top events by score, then randomly sample from them sorted_events = sorted( group_events, key=lambda e: e.data.get("score", 0) if e.data else 0, reverse=True, ) - sample_size = min(samples_per_group, len(sorted_events)) - selected.extend(sorted_events[:sample_size]) + # Consider top 3x candidates to allow randomness while preferring higher scores + candidate_pool = sorted_events[: samples_per_group * 3] + sample_size = min(samples_per_group, len(candidate_pool)) + selected.extend(random.sample(candidate_pool, sample_size)) if len(selected) < target_count: remaining = [e for e in events if e not in selected] - remaining_sorted = sorted( - remaining, - key=lambda e: e.data.get("score", 0) if e.data else 0, - reverse=True, - ) needed = target_count - len(selected) - selected.extend(remaining_sorted[:needed]) + if len(remaining) > needed: + selected.extend(random.sample(remaining, needed)) + else: + selected.extend(remaining) return selected[:target_count] diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json index 03704fb50..17f881a3c 100644 --- a/web/public/locales/en/views/classificationModel.json +++ b/web/public/locales/en/views/classificationModel.json @@ -180,9 +180,14 @@ "classifyFailed": "Failed to classify images: {{error}}" }, "generateSuccess": "Successfully generated sample images", + "refreshExamples": "Generate new examples", + "refreshConfirm": { + "title": "Generate New Examples?", + "description": "This will generate a new set of images and clear all selections, including any previous classes. You will need to re-select examples for all classes." + }, "missingStatesWarning": { - "title": "Missing State Examples", - "description": "It's recommended to select examples for all states for best results. You can continue without selecting all states, but the model will not be trained until all states have images. After continuing, use the Recent Classifications view to classify images for the missing states, then train the model." + "title": "Missing Class Examples", + "description": "Not all classes have examples. Try generating new examples to find the missing class, or continue and use the Recent Classifications view to add images later." } } } diff --git a/web/src/components/classification/wizard/Step3ChooseExamples.tsx b/web/src/components/classification/wizard/Step3ChooseExamples.tsx index e3dd04afc..c6693d029 100644 --- a/web/src/components/classification/wizard/Step3ChooseExamples.tsx +++ b/web/src/components/classification/wizard/Step3ChooseExamples.tsx @@ -11,7 +11,24 @@ import { baseUrl } from "@/api/baseUrl"; import { isMobile } from "react-device-detect"; import { cn } from "@/lib/utils"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/ui/tooltip"; +import { TooltipPortal } from "@radix-ui/react-tooltip"; import { IoIosWarning } from "react-icons/io"; +import { LuRefreshCw } from "react-icons/lu"; export type Step3FormData = { examplesGenerated: boolean; @@ -47,6 +64,7 @@ export default function Step3ChooseExamples({ const [selectedImages, setSelectedImages] = useState>(new Set()); const [cacheKey, setCacheKey] = useState(Date.now()); const [loadedImages, setLoadedImages] = useState>(new Set()); + const [showRefreshConfirm, setShowRefreshConfirm] = useState(false); const handleImageLoad = useCallback((imageName: string) => { setLoadedImages((prev) => new Set(prev).add(imageName)); @@ -484,8 +502,52 @@ export default function Step3ChooseExamples({ } }, [currentClassIndex, allClasses, imageClassifications, onBack]); + const doRefresh = useCallback(() => { + setCurrentClassIndex(0); + setSelectedImages(new Set()); + setImageClassifications({}); + setLoadedImages(new Set()); + setShowRefreshConfirm(false); + generateExamples(); + }, [generateExamples]); + + const handleRefresh = useCallback(() => { + if (Object.keys(imageClassifications).length > 0) { + setShowRefreshConfirm(true); + } else { + doRefresh(); + } + }, [imageClassifications, doRefresh]); + return (
+ + + + + {t("wizard.step3.refreshConfirm.title")} + + + {t("wizard.step3.refreshConfirm.description")} + + + + + {t("button.cancel", { ns: "common" })} + + + {t("button.continue", { ns: "common" })} + + + + + {isTraining ? (
@@ -514,15 +576,43 @@ export default function Step3ChooseExamples({
) : hasGenerated ? ( -
+
+ + + + + + + {t("wizard.step3.refreshExamples")} + + + {showMissingStatesWarning && ( {t("wizard.step3.missingStatesWarning.title")} - + {t("wizard.step3.missingStatesWarning.description")} + )} diff --git a/web/src/views/classification/ModelSelectionView.tsx b/web/src/views/classification/ModelSelectionView.tsx index 3cd450bba..4b4ef492d 100644 --- a/web/src/views/classification/ModelSelectionView.tsx +++ b/web/src/views/classification/ModelSelectionView.tsx @@ -342,7 +342,7 @@ function ModelCard({ config, onClick, onUpdate, onDelete }: ModelCardProps) { {config.name}
- + e.stopPropagation()}>