mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-03-25 09:38:22 +03:00
Add ability to regenerate examples in classification wizard (#22604)
* add randomness to object classification also ensure train_dir is fresh if user has regenerated examples * frontend refresh button * fix radix dropdown issue * i18n
This commit is contained in:
parent
6c5801ac83
commit
91ef3b2ceb
@ -5,6 +5,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import shutil
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
@ -397,6 +398,8 @@ def collect_state_classification_examples(
|
|||||||
|
|
||||||
# Step 5: Save to train directory for later classification
|
# Step 5: Save to train directory for later classification
|
||||||
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
|
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
|
||||||
|
if os.path.exists(train_dir):
|
||||||
|
shutil.rmtree(train_dir)
|
||||||
os.makedirs(train_dir, exist_ok=True)
|
os.makedirs(train_dir, exist_ok=True)
|
||||||
|
|
||||||
saved_count = 0
|
saved_count = 0
|
||||||
@ -411,8 +414,6 @@ def collect_state_classification_examples(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to save image {image_path}: {e}")
|
logger.error(f"Failed to save image {image_path}: {e}")
|
||||||
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
shutil.rmtree(temp_dir)
|
shutil.rmtree(temp_dir)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -750,6 +751,8 @@ def collect_object_classification_examples(
|
|||||||
|
|
||||||
# Step 5: Save to train directory for later classification
|
# Step 5: Save to train directory for later classification
|
||||||
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
|
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
|
||||||
|
if os.path.exists(train_dir):
|
||||||
|
shutil.rmtree(train_dir)
|
||||||
os.makedirs(train_dir, exist_ok=True)
|
os.makedirs(train_dir, exist_ok=True)
|
||||||
|
|
||||||
saved_count = 0
|
saved_count = 0
|
||||||
@ -764,8 +767,6 @@ def collect_object_classification_examples(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to save image {image_path}: {e}")
|
logger.error(f"Failed to save image {image_path}: {e}")
|
||||||
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
shutil.rmtree(temp_dir)
|
shutil.rmtree(temp_dir)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -806,24 +807,25 @@ def _select_balanced_events(
|
|||||||
selected = []
|
selected = []
|
||||||
|
|
||||||
for group_events in grouped.values():
|
for group_events in grouped.values():
|
||||||
|
# Take top events by score, then randomly sample from them
|
||||||
sorted_events = sorted(
|
sorted_events = sorted(
|
||||||
group_events,
|
group_events,
|
||||||
key=lambda e: e.data.get("score", 0) if e.data else 0,
|
key=lambda e: e.data.get("score", 0) if e.data else 0,
|
||||||
reverse=True,
|
reverse=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
sample_size = min(samples_per_group, len(sorted_events))
|
# Consider top 3x candidates to allow randomness while preferring higher scores
|
||||||
selected.extend(sorted_events[:sample_size])
|
candidate_pool = sorted_events[: samples_per_group * 3]
|
||||||
|
sample_size = min(samples_per_group, len(candidate_pool))
|
||||||
|
selected.extend(random.sample(candidate_pool, sample_size))
|
||||||
|
|
||||||
if len(selected) < target_count:
|
if len(selected) < target_count:
|
||||||
remaining = [e for e in events if e not in selected]
|
remaining = [e for e in events if e not in selected]
|
||||||
remaining_sorted = sorted(
|
|
||||||
remaining,
|
|
||||||
key=lambda e: e.data.get("score", 0) if e.data else 0,
|
|
||||||
reverse=True,
|
|
||||||
)
|
|
||||||
needed = target_count - len(selected)
|
needed = target_count - len(selected)
|
||||||
selected.extend(remaining_sorted[:needed])
|
if len(remaining) > needed:
|
||||||
|
selected.extend(random.sample(remaining, needed))
|
||||||
|
else:
|
||||||
|
selected.extend(remaining)
|
||||||
|
|
||||||
return selected[:target_count]
|
return selected[:target_count]
|
||||||
|
|
||||||
|
|||||||
@ -180,9 +180,14 @@
|
|||||||
"classifyFailed": "Failed to classify images: {{error}}"
|
"classifyFailed": "Failed to classify images: {{error}}"
|
||||||
},
|
},
|
||||||
"generateSuccess": "Successfully generated sample images",
|
"generateSuccess": "Successfully generated sample images",
|
||||||
|
"refreshExamples": "Generate new examples",
|
||||||
|
"refreshConfirm": {
|
||||||
|
"title": "Generate New Examples?",
|
||||||
|
"description": "This will generate a new set of images and clear all selections, including any previous classes. You will need to re-select examples for all classes."
|
||||||
|
},
|
||||||
"missingStatesWarning": {
|
"missingStatesWarning": {
|
||||||
"title": "Missing State Examples",
|
"title": "Missing Class Examples",
|
||||||
"description": "It's recommended to select examples for all states for best results. You can continue without selecting all states, but the model will not be trained until all states have images. After continuing, use the Recent Classifications view to classify images for the missing states, then train the model."
|
"description": "Not all classes have examples. Try generating new examples to find the missing class, or continue and use the Recent Classifications view to add images later."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,7 +11,24 @@ import { baseUrl } from "@/api/baseUrl";
|
|||||||
import { isMobile } from "react-device-detect";
|
import { isMobile } from "react-device-detect";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||||
|
import {
|
||||||
|
AlertDialog,
|
||||||
|
AlertDialogAction,
|
||||||
|
AlertDialogCancel,
|
||||||
|
AlertDialogContent,
|
||||||
|
AlertDialogDescription,
|
||||||
|
AlertDialogFooter,
|
||||||
|
AlertDialogHeader,
|
||||||
|
AlertDialogTitle,
|
||||||
|
} from "@/components/ui/alert-dialog";
|
||||||
|
import {
|
||||||
|
Tooltip,
|
||||||
|
TooltipContent,
|
||||||
|
TooltipTrigger,
|
||||||
|
} from "@/components/ui/tooltip";
|
||||||
|
import { TooltipPortal } from "@radix-ui/react-tooltip";
|
||||||
import { IoIosWarning } from "react-icons/io";
|
import { IoIosWarning } from "react-icons/io";
|
||||||
|
import { LuRefreshCw } from "react-icons/lu";
|
||||||
|
|
||||||
export type Step3FormData = {
|
export type Step3FormData = {
|
||||||
examplesGenerated: boolean;
|
examplesGenerated: boolean;
|
||||||
@ -47,6 +64,7 @@ export default function Step3ChooseExamples({
|
|||||||
const [selectedImages, setSelectedImages] = useState<Set<string>>(new Set());
|
const [selectedImages, setSelectedImages] = useState<Set<string>>(new Set());
|
||||||
const [cacheKey, setCacheKey] = useState<number>(Date.now());
|
const [cacheKey, setCacheKey] = useState<number>(Date.now());
|
||||||
const [loadedImages, setLoadedImages] = useState<Set<string>>(new Set());
|
const [loadedImages, setLoadedImages] = useState<Set<string>>(new Set());
|
||||||
|
const [showRefreshConfirm, setShowRefreshConfirm] = useState(false);
|
||||||
|
|
||||||
const handleImageLoad = useCallback((imageName: string) => {
|
const handleImageLoad = useCallback((imageName: string) => {
|
||||||
setLoadedImages((prev) => new Set(prev).add(imageName));
|
setLoadedImages((prev) => new Set(prev).add(imageName));
|
||||||
@ -484,8 +502,52 @@ export default function Step3ChooseExamples({
|
|||||||
}
|
}
|
||||||
}, [currentClassIndex, allClasses, imageClassifications, onBack]);
|
}, [currentClassIndex, allClasses, imageClassifications, onBack]);
|
||||||
|
|
||||||
|
const doRefresh = useCallback(() => {
|
||||||
|
setCurrentClassIndex(0);
|
||||||
|
setSelectedImages(new Set());
|
||||||
|
setImageClassifications({});
|
||||||
|
setLoadedImages(new Set());
|
||||||
|
setShowRefreshConfirm(false);
|
||||||
|
generateExamples();
|
||||||
|
}, [generateExamples]);
|
||||||
|
|
||||||
|
const handleRefresh = useCallback(() => {
|
||||||
|
if (Object.keys(imageClassifications).length > 0) {
|
||||||
|
setShowRefreshConfirm(true);
|
||||||
|
} else {
|
||||||
|
doRefresh();
|
||||||
|
}
|
||||||
|
}, [imageClassifications, doRefresh]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col gap-6">
|
<div className="flex flex-col gap-6">
|
||||||
|
<AlertDialog
|
||||||
|
open={showRefreshConfirm}
|
||||||
|
onOpenChange={setShowRefreshConfirm}
|
||||||
|
>
|
||||||
|
<AlertDialogContent>
|
||||||
|
<AlertDialogHeader>
|
||||||
|
<AlertDialogTitle>
|
||||||
|
{t("wizard.step3.refreshConfirm.title")}
|
||||||
|
</AlertDialogTitle>
|
||||||
|
<AlertDialogDescription>
|
||||||
|
{t("wizard.step3.refreshConfirm.description")}
|
||||||
|
</AlertDialogDescription>
|
||||||
|
</AlertDialogHeader>
|
||||||
|
<AlertDialogFooter>
|
||||||
|
<AlertDialogCancel>
|
||||||
|
{t("button.cancel", { ns: "common" })}
|
||||||
|
</AlertDialogCancel>
|
||||||
|
<AlertDialogAction
|
||||||
|
onClick={doRefresh}
|
||||||
|
className="bg-destructive text-white hover:bg-destructive/90"
|
||||||
|
>
|
||||||
|
{t("button.continue", { ns: "common" })}
|
||||||
|
</AlertDialogAction>
|
||||||
|
</AlertDialogFooter>
|
||||||
|
</AlertDialogContent>
|
||||||
|
</AlertDialog>
|
||||||
|
|
||||||
{isTraining ? (
|
{isTraining ? (
|
||||||
<div className="flex flex-col items-center gap-6 py-12">
|
<div className="flex flex-col items-center gap-6 py-12">
|
||||||
<ActivityIndicator className="size-12" />
|
<ActivityIndicator className="size-12" />
|
||||||
@ -514,15 +576,43 @@ export default function Step3ChooseExamples({
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
) : hasGenerated ? (
|
) : hasGenerated ? (
|
||||||
<div className="flex flex-col gap-4">
|
<div className="relative flex flex-col gap-4">
|
||||||
|
<Tooltip open={showRefreshConfirm ? false : undefined}>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
className="absolute right-0 top-0 size-8"
|
||||||
|
onClick={handleRefresh}
|
||||||
|
disabled={isGenerating || isProcessing}
|
||||||
|
>
|
||||||
|
<LuRefreshCw className="size-4" />
|
||||||
|
</Button>
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipPortal>
|
||||||
|
<TooltipContent>
|
||||||
|
{t("wizard.step3.refreshExamples")}
|
||||||
|
</TooltipContent>
|
||||||
|
</TooltipPortal>
|
||||||
|
</Tooltip>
|
||||||
{showMissingStatesWarning && (
|
{showMissingStatesWarning && (
|
||||||
<Alert variant="destructive">
|
<Alert variant="destructive">
|
||||||
<IoIosWarning className="size-5" />
|
<IoIosWarning className="size-5" />
|
||||||
<AlertTitle>
|
<AlertTitle>
|
||||||
{t("wizard.step3.missingStatesWarning.title")}
|
{t("wizard.step3.missingStatesWarning.title")}
|
||||||
</AlertTitle>
|
</AlertTitle>
|
||||||
<AlertDescription>
|
<AlertDescription className="flex flex-col gap-2">
|
||||||
{t("wizard.step3.missingStatesWarning.description")}
|
{t("wizard.step3.missingStatesWarning.description")}
|
||||||
|
<Button
|
||||||
|
variant="secondary"
|
||||||
|
size="sm"
|
||||||
|
className="w-fit"
|
||||||
|
onClick={handleRefresh}
|
||||||
|
disabled={isGenerating || isProcessing}
|
||||||
|
>
|
||||||
|
<LuRefreshCw className="mr-1.5 size-3.5" />
|
||||||
|
{t("wizard.step3.refreshExamples")}
|
||||||
|
</Button>
|
||||||
</AlertDescription>
|
</AlertDescription>
|
||||||
</Alert>
|
</Alert>
|
||||||
)}
|
)}
|
||||||
|
|||||||
@ -342,7 +342,7 @@ function ModelCard({ config, onClick, onUpdate, onDelete }: ModelCardProps) {
|
|||||||
{config.name}
|
{config.name}
|
||||||
</div>
|
</div>
|
||||||
<div className="absolute bottom-2 right-2 z-40">
|
<div className="absolute bottom-2 right-2 z-40">
|
||||||
<DropdownMenu>
|
<DropdownMenu modal={false}>
|
||||||
<DropdownMenuTrigger asChild onClick={(e) => e.stopPropagation()}>
|
<DropdownMenuTrigger asChild onClick={(e) => e.stopPropagation()}>
|
||||||
<BlurredIconButton>
|
<BlurredIconButton>
|
||||||
<FiMoreVertical className="size-5 text-white" />
|
<FiMoreVertical className="size-5 text-white" />
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user