mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-03-27 02:28:22 +03:00
Improve image selection mechanism
This commit is contained in:
parent
855021dfc4
commit
0a569fa3c0
@ -167,8 +167,7 @@ def train_face(request: Request, name: str, body: dict = None):
|
|||||||
new_name = f"{sanitized_name}-{datetime.datetime.now().timestamp()}.webp"
|
new_name = f"{sanitized_name}-{datetime.datetime.now().timestamp()}.webp"
|
||||||
new_file_folder = os.path.join(FACE_DIR, f"{sanitized_name}")
|
new_file_folder = os.path.join(FACE_DIR, f"{sanitized_name}")
|
||||||
|
|
||||||
if not os.path.exists(new_file_folder):
|
os.makedirs(new_file_folder, exist_ok=True)
|
||||||
os.mkdir(new_file_folder)
|
|
||||||
|
|
||||||
if training_file_name:
|
if training_file_name:
|
||||||
shutil.move(training_file, os.path.join(new_file_folder, new_name))
|
shutil.move(training_file, os.path.join(new_file_folder, new_name))
|
||||||
@ -716,8 +715,7 @@ def categorize_classification_image(request: Request, name: str, body: dict = No
|
|||||||
CLIPS_DIR, sanitize_filename(name), "dataset", category
|
CLIPS_DIR, sanitize_filename(name), "dataset", category
|
||||||
)
|
)
|
||||||
|
|
||||||
if not os.path.exists(new_file_folder):
|
os.makedirs(new_file_folder, exist_ok=True)
|
||||||
os.mkdir(new_file_folder)
|
|
||||||
|
|
||||||
# use opencv because webp images can not be used to train
|
# use opencv because webp images can not be used to train
|
||||||
img = cv2.imread(training_file)
|
img = cv2.imread(training_file)
|
||||||
|
|||||||
@ -53,9 +53,17 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
|||||||
self.tensor_output_details: dict[str, Any] | None = None
|
self.tensor_output_details: dict[str, Any] | None = None
|
||||||
self.labelmap: dict[int, str] = {}
|
self.labelmap: dict[int, str] = {}
|
||||||
self.classifications_per_second = EventsPerSecond()
|
self.classifications_per_second = EventsPerSecond()
|
||||||
self.inference_speed = InferenceSpeed(
|
|
||||||
self.metrics.classification_speeds[self.model_config.name]
|
if (
|
||||||
)
|
self.metrics
|
||||||
|
and self.model_config.name in self.metrics.classification_speeds
|
||||||
|
):
|
||||||
|
self.inference_speed = InferenceSpeed(
|
||||||
|
self.metrics.classification_speeds[self.model_config.name]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.inference_speed = None
|
||||||
|
|
||||||
self.last_run = datetime.datetime.now().timestamp()
|
self.last_run = datetime.datetime.now().timestamp()
|
||||||
self.__build_detector()
|
self.__build_detector()
|
||||||
|
|
||||||
@ -83,12 +91,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
|||||||
|
|
||||||
def __update_metrics(self, duration: float) -> None:
|
def __update_metrics(self, duration: float) -> None:
|
||||||
self.classifications_per_second.update()
|
self.classifications_per_second.update()
|
||||||
self.inference_speed.update(duration)
|
if self.inference_speed:
|
||||||
|
self.inference_speed.update(duration)
|
||||||
|
|
||||||
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
|
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
|
||||||
self.metrics.classification_cps[
|
if self.metrics and self.model_config.name in self.metrics.classification_cps:
|
||||||
self.model_config.name
|
self.metrics.classification_cps[
|
||||||
].value = self.classifications_per_second.eps()
|
self.model_config.name
|
||||||
|
].value = self.classifications_per_second.eps()
|
||||||
camera = frame_data.get("camera")
|
camera = frame_data.get("camera")
|
||||||
|
|
||||||
if camera not in self.model_config.state_config.cameras:
|
if camera not in self.model_config.state_config.cameras:
|
||||||
@ -223,9 +233,17 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
self.detected_objects: dict[str, float] = {}
|
self.detected_objects: dict[str, float] = {}
|
||||||
self.labelmap: dict[int, str] = {}
|
self.labelmap: dict[int, str] = {}
|
||||||
self.classifications_per_second = EventsPerSecond()
|
self.classifications_per_second = EventsPerSecond()
|
||||||
self.inference_speed = InferenceSpeed(
|
|
||||||
self.metrics.classification_speeds[self.model_config.name]
|
if (
|
||||||
)
|
self.metrics
|
||||||
|
and self.model_config.name in self.metrics.classification_speeds
|
||||||
|
):
|
||||||
|
self.inference_speed = InferenceSpeed(
|
||||||
|
self.metrics.classification_speeds[self.model_config.name]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.inference_speed = None
|
||||||
|
|
||||||
self.__build_detector()
|
self.__build_detector()
|
||||||
|
|
||||||
@redirect_output_to_logger(logger, logging.DEBUG)
|
@redirect_output_to_logger(logger, logging.DEBUG)
|
||||||
@ -251,12 +269,14 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
|
|
||||||
def __update_metrics(self, duration: float) -> None:
|
def __update_metrics(self, duration: float) -> None:
|
||||||
self.classifications_per_second.update()
|
self.classifications_per_second.update()
|
||||||
self.inference_speed.update(duration)
|
if self.inference_speed:
|
||||||
|
self.inference_speed.update(duration)
|
||||||
|
|
||||||
def process_frame(self, obj_data, frame):
|
def process_frame(self, obj_data, frame):
|
||||||
self.metrics.classification_cps[
|
if self.metrics and self.model_config.name in self.metrics.classification_cps:
|
||||||
self.model_config.name
|
self.metrics.classification_cps[
|
||||||
].value = self.classifications_per_second.eps()
|
self.model_config.name
|
||||||
|
].value = self.classifications_per_second.eps()
|
||||||
|
|
||||||
if obj_data["false_positive"]:
|
if obj_data["false_positive"]:
|
||||||
return
|
return
|
||||||
|
|||||||
@ -10,7 +10,6 @@ from frigate.data_processing.real_time.whisper_online import FasterWhisperASR
|
|||||||
|
|
||||||
|
|
||||||
class DataProcessorMetrics:
|
class DataProcessorMetrics:
|
||||||
manager: SyncManager
|
|
||||||
image_embeddings_speed: Synchronized
|
image_embeddings_speed: Synchronized
|
||||||
image_embeddings_eps: Synchronized
|
image_embeddings_eps: Synchronized
|
||||||
text_embeddings_speed: Synchronized
|
text_embeddings_speed: Synchronized
|
||||||
@ -29,7 +28,6 @@ class DataProcessorMetrics:
|
|||||||
classification_cps: dict[str, Synchronized]
|
classification_cps: dict[str, Synchronized]
|
||||||
|
|
||||||
def __init__(self, manager: SyncManager, custom_classification_models: list[str]):
|
def __init__(self, manager: SyncManager, custom_classification_models: list[str]):
|
||||||
self.manager = manager
|
|
||||||
self.image_embeddings_speed = manager.Value("d", 0.0)
|
self.image_embeddings_speed = manager.Value("d", 0.0)
|
||||||
self.image_embeddings_eps = manager.Value("d", 0.0)
|
self.image_embeddings_eps = manager.Value("d", 0.0)
|
||||||
self.text_embeddings_speed = manager.Value("d", 0.0)
|
self.text_embeddings_speed = manager.Value("d", 0.0)
|
||||||
@ -52,12 +50,6 @@ class DataProcessorMetrics:
|
|||||||
self.classification_speeds[key] = manager.Value("d", 0.0)
|
self.classification_speeds[key] = manager.Value("d", 0.0)
|
||||||
self.classification_cps[key] = manager.Value("d", 0.0)
|
self.classification_cps[key] = manager.Value("d", 0.0)
|
||||||
|
|
||||||
def add_classification_model(self, model_name: str) -> None:
|
|
||||||
"""Add metrics for a new classification model dynamically."""
|
|
||||||
if model_name not in self.classification_speeds:
|
|
||||||
self.classification_speeds[model_name] = self.manager.Value("d", 0.0)
|
|
||||||
self.classification_cps[model_name] = self.manager.Value("d", 0.0)
|
|
||||||
|
|
||||||
|
|
||||||
class DataProcessorModelRunner:
|
class DataProcessorModelRunner:
|
||||||
def __init__(self, requestor, device: str = "CPU", model_size: str = "large"):
|
def __init__(self, requestor, device: str = "CPU", model_size: str = "large"):
|
||||||
|
|||||||
@ -304,9 +304,6 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.metrics:
|
|
||||||
self.metrics.add_classification_model(model_name)
|
|
||||||
|
|
||||||
if model_config.state_config is not None:
|
if model_config.state_config is not None:
|
||||||
processor = CustomStateClassificationProcessor(
|
processor = CustomStateClassificationProcessor(
|
||||||
self.config, model_config, self.requestor, self.metrics
|
self.config, model_config, self.requestor, self.metrics
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
{
|
{
|
||||||
|
"documentTitle": "Classification Models",
|
||||||
"button": {
|
"button": {
|
||||||
"deleteClassificationAttempts": "Delete Classification Images",
|
"deleteClassificationAttempts": "Delete Classification Images",
|
||||||
"renameCategory": "Rename Class",
|
"renameCategory": "Rename Class",
|
||||||
@ -96,6 +97,10 @@
|
|||||||
"selectCameraPrompt": "Select a camera from the list to define its monitoring area"
|
"selectCameraPrompt": "Select a camera from the list to define its monitoring area"
|
||||||
},
|
},
|
||||||
"step3": {
|
"step3": {
|
||||||
|
"selectImagesPrompt": "Select all images with: {{className}}",
|
||||||
|
"selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.",
|
||||||
|
"allImagesClassified": "All images classified!",
|
||||||
|
"readyToContinue": "Click Continue to finish and start training.",
|
||||||
"description": "Classify the example images below. These samples will be used to train your model.",
|
"description": "Classify the example images below. These samples will be used to train your model.",
|
||||||
"generating": {
|
"generating": {
|
||||||
"title": "Generating Sample Images",
|
"title": "Generating Sample Images",
|
||||||
|
|||||||
@ -135,7 +135,7 @@ export default function ClassificationModelWizardDialog({
|
|||||||
<DialogContent
|
<DialogContent
|
||||||
className={cn(
|
className={cn(
|
||||||
"",
|
"",
|
||||||
isDesktop && "max-h-[75dvh] max-w-6xl overflow-y-auto",
|
isDesktop && "max-h-[75dvh] max-w-[40%] overflow-y-auto",
|
||||||
)}
|
)}
|
||||||
onInteractOutside={(e) => {
|
onInteractOutside={(e) => {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
|
|||||||
@ -1,11 +1,4 @@
|
|||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import {
|
|
||||||
Select,
|
|
||||||
SelectContent,
|
|
||||||
SelectItem,
|
|
||||||
SelectTrigger,
|
|
||||||
SelectValue,
|
|
||||||
} from "@/components/ui/select";
|
|
||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
import { useState, useEffect, useCallback, useMemo } from "react";
|
import { useState, useEffect, useCallback, useMemo } from "react";
|
||||||
import ActivityIndicator from "@/components/indicators/activity-indicator";
|
import ActivityIndicator from "@/components/indicators/activity-indicator";
|
||||||
@ -48,6 +41,8 @@ export default function Step3ChooseExamples({
|
|||||||
}>(initialData?.imageClassifications || {});
|
}>(initialData?.imageClassifications || {});
|
||||||
const [isTraining, setIsTraining] = useState(false);
|
const [isTraining, setIsTraining] = useState(false);
|
||||||
const [isProcessing, setIsProcessing] = useState(false);
|
const [isProcessing, setIsProcessing] = useState(false);
|
||||||
|
const [currentClassIndex, setCurrentClassIndex] = useState(0);
|
||||||
|
const [selectedImages, setSelectedImages] = useState<Set<string>>(new Set());
|
||||||
|
|
||||||
const { data: trainImages, mutate: refreshTrainImages } = useSWR<string[]>(
|
const { data: trainImages, mutate: refreshTrainImages } = useSWR<string[]>(
|
||||||
hasGenerated ? `classification/${step1Data.modelName}/train` : null,
|
hasGenerated ? `classification/${step1Data.modelName}/train` : null,
|
||||||
@ -58,16 +53,165 @@ export default function Step3ChooseExamples({
|
|||||||
return trainImages;
|
return trainImages;
|
||||||
}, [trainImages]);
|
}, [trainImages]);
|
||||||
|
|
||||||
const handleClassificationChange = useCallback(
|
const toggleImageSelection = useCallback((imageName: string) => {
|
||||||
(imageName: string, className: string) => {
|
setSelectedImages((prev) => {
|
||||||
setImageClassifications((prev) => ({
|
const newSet = new Set(prev);
|
||||||
...prev,
|
if (newSet.has(imageName)) {
|
||||||
[imageName]: className,
|
newSet.delete(imageName);
|
||||||
}));
|
} else {
|
||||||
|
newSet.add(imageName);
|
||||||
|
}
|
||||||
|
return newSet;
|
||||||
|
});
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// Get all classes (excluding "none" - it will be auto-assigned)
|
||||||
|
const allClasses = useMemo(() => {
|
||||||
|
return [...step1Data.classes];
|
||||||
|
}, [step1Data.classes]);
|
||||||
|
|
||||||
|
const currentClass = allClasses[currentClassIndex];
|
||||||
|
|
||||||
|
const processClassificationsAndTrain = useCallback(
|
||||||
|
async (classifications: { [imageName: string]: string }) => {
|
||||||
|
// Step 1: Create config for the new model
|
||||||
|
const modelConfig: {
|
||||||
|
enabled: boolean;
|
||||||
|
name: string;
|
||||||
|
threshold: number;
|
||||||
|
state_config?: {
|
||||||
|
cameras: Record<string, { crop: number[] }>;
|
||||||
|
motion: boolean;
|
||||||
|
};
|
||||||
|
object_config?: { objects: string[]; classification_type: string };
|
||||||
|
} = {
|
||||||
|
enabled: true,
|
||||||
|
name: step1Data.modelName,
|
||||||
|
threshold: 0.8,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (step1Data.modelType === "state") {
|
||||||
|
// State model config
|
||||||
|
const cameras: Record<string, { crop: number[] }> = {};
|
||||||
|
step2Data?.cameraAreas.forEach((area) => {
|
||||||
|
cameras[area.camera] = {
|
||||||
|
crop: area.crop,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
modelConfig.state_config = {
|
||||||
|
cameras,
|
||||||
|
motion: true,
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
// Object model config
|
||||||
|
modelConfig.object_config = {
|
||||||
|
objects: step1Data.objectLabel ? [step1Data.objectLabel] : [],
|
||||||
|
classification_type: step1Data.objectType || "sub_label",
|
||||||
|
} as { objects: string[]; classification_type: string };
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update config via config API
|
||||||
|
await axios.put("/config/set", {
|
||||||
|
requires_restart: 0,
|
||||||
|
update_topic: `config/classification/custom/${step1Data.modelName}`,
|
||||||
|
config_data: {
|
||||||
|
classification: {
|
||||||
|
custom: {
|
||||||
|
[step1Data.modelName]: modelConfig,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Step 2: Classify each image by moving it to the correct category folder
|
||||||
|
const categorizePromises = Object.entries(classifications).map(
|
||||||
|
([imageName, className]) => {
|
||||||
|
if (!className) return Promise.resolve();
|
||||||
|
return axios.post(
|
||||||
|
`/classification/${step1Data.modelName}/dataset/categorize`,
|
||||||
|
{
|
||||||
|
training_file: imageName,
|
||||||
|
category: className === "none" ? "none" : className,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
await Promise.all(categorizePromises);
|
||||||
|
|
||||||
|
// Step 3: Kick off training
|
||||||
|
await axios.post(`/classification/${step1Data.modelName}/train`);
|
||||||
|
|
||||||
|
toast.success(t("wizard.step3.trainingStarted"));
|
||||||
|
setIsTraining(true);
|
||||||
},
|
},
|
||||||
[],
|
[step1Data, step2Data, t],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const handleContinueClassification = useCallback(async () => {
|
||||||
|
// Mark selected images with current class
|
||||||
|
const newClassifications = { ...imageClassifications };
|
||||||
|
selectedImages.forEach((imageName) => {
|
||||||
|
newClassifications[imageName] = currentClass;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Check if we're on the last class to select
|
||||||
|
const isLastClass = currentClassIndex === allClasses.length - 1;
|
||||||
|
|
||||||
|
if (isLastClass) {
|
||||||
|
// Assign remaining unclassified images
|
||||||
|
unknownImages.slice(0, 24).forEach((imageName) => {
|
||||||
|
if (!newClassifications[imageName]) {
|
||||||
|
// For state models with 2 classes, assign to the last class
|
||||||
|
// For object models, assign to "none"
|
||||||
|
if (step1Data.modelType === "state" && allClasses.length === 2) {
|
||||||
|
newClassifications[imageName] = allClasses[allClasses.length - 1];
|
||||||
|
} else {
|
||||||
|
newClassifications[imageName] = "none";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// All done, trigger training immediately
|
||||||
|
setImageClassifications(newClassifications);
|
||||||
|
setIsProcessing(true);
|
||||||
|
|
||||||
|
try {
|
||||||
|
await processClassificationsAndTrain(newClassifications);
|
||||||
|
} catch (error) {
|
||||||
|
const axiosError = error as {
|
||||||
|
response?: { data?: { message?: string; detail?: string } };
|
||||||
|
message?: string;
|
||||||
|
};
|
||||||
|
const errorMessage =
|
||||||
|
axiosError.response?.data?.message ||
|
||||||
|
axiosError.response?.data?.detail ||
|
||||||
|
axiosError.message ||
|
||||||
|
"Failed to classify images";
|
||||||
|
|
||||||
|
toast.error(
|
||||||
|
t("wizard.step3.errors.classifyFailed", { error: errorMessage }),
|
||||||
|
);
|
||||||
|
setIsProcessing(false);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Move to next class
|
||||||
|
setImageClassifications(newClassifications);
|
||||||
|
setCurrentClassIndex((prev) => prev + 1);
|
||||||
|
setSelectedImages(new Set());
|
||||||
|
}
|
||||||
|
}, [
|
||||||
|
selectedImages,
|
||||||
|
currentClass,
|
||||||
|
currentClassIndex,
|
||||||
|
allClasses,
|
||||||
|
imageClassifications,
|
||||||
|
unknownImages,
|
||||||
|
step1Data,
|
||||||
|
processClassificationsAndTrain,
|
||||||
|
t,
|
||||||
|
]);
|
||||||
|
|
||||||
const generateExamples = useCallback(async () => {
|
const generateExamples = useCallback(async () => {
|
||||||
setIsGenerating(true);
|
setIsGenerating(true);
|
||||||
|
|
||||||
@ -138,76 +282,7 @@ export default function Step3ChooseExamples({
|
|||||||
const handleContinue = useCallback(async () => {
|
const handleContinue = useCallback(async () => {
|
||||||
setIsProcessing(true);
|
setIsProcessing(true);
|
||||||
try {
|
try {
|
||||||
// Step 1: Create config for the new model
|
await processClassificationsAndTrain(imageClassifications);
|
||||||
const modelConfig: {
|
|
||||||
enabled: boolean;
|
|
||||||
name: string;
|
|
||||||
threshold: number;
|
|
||||||
state_config?: {
|
|
||||||
cameras: Record<string, { crop: number[] }>;
|
|
||||||
motion: boolean;
|
|
||||||
};
|
|
||||||
object_config?: { objects: string[]; classification_type: string };
|
|
||||||
} = {
|
|
||||||
enabled: true,
|
|
||||||
name: step1Data.modelName,
|
|
||||||
threshold: 0.8,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (step1Data.modelType === "state") {
|
|
||||||
// State model config
|
|
||||||
const cameras: Record<string, { crop: number[] }> = {};
|
|
||||||
step2Data?.cameraAreas.forEach((area) => {
|
|
||||||
cameras[area.camera] = {
|
|
||||||
crop: area.crop,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
modelConfig.state_config = {
|
|
||||||
cameras,
|
|
||||||
motion: true,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
// Object model config
|
|
||||||
modelConfig.object_config = {
|
|
||||||
objects: step1Data.objectLabel ? [step1Data.objectLabel] : [],
|
|
||||||
classification_type: step1Data.objectType || "sub_label",
|
|
||||||
} as { objects: string[]; classification_type: string };
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update config via config API
|
|
||||||
await axios.put("/config/set", {
|
|
||||||
requires_restart: 0,
|
|
||||||
update_topic: `config/classification/custom/${step1Data.modelName}`,
|
|
||||||
config_data: {
|
|
||||||
classification: {
|
|
||||||
custom: {
|
|
||||||
[step1Data.modelName]: modelConfig,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Step 2: Classify each image by moving it to the correct category folder
|
|
||||||
const categorizePromises = Object.entries(imageClassifications).map(
|
|
||||||
([imageName, className]) => {
|
|
||||||
if (!className) return Promise.resolve();
|
|
||||||
return axios.post(
|
|
||||||
`/classification/${step1Data.modelName}/dataset/categorize`,
|
|
||||||
{
|
|
||||||
training_file: imageName,
|
|
||||||
category: className === "none" ? "none" : className,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
},
|
|
||||||
);
|
|
||||||
await Promise.all(categorizePromises);
|
|
||||||
|
|
||||||
// Step 3: Kick off training
|
|
||||||
await axios.post(`/classification/${step1Data.modelName}/train`);
|
|
||||||
|
|
||||||
toast.success(t("wizard.step3.trainingStarted"));
|
|
||||||
setIsTraining(true);
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const axiosError = error as {
|
const axiosError = error as {
|
||||||
response?: { data?: { message?: string; detail?: string } };
|
response?: { data?: { message?: string; detail?: string } };
|
||||||
@ -224,13 +299,23 @@ export default function Step3ChooseExamples({
|
|||||||
);
|
);
|
||||||
setIsProcessing(false);
|
setIsProcessing(false);
|
||||||
}
|
}
|
||||||
}, [imageClassifications, step1Data, step2Data, t]);
|
}, [imageClassifications, processClassificationsAndTrain, t]);
|
||||||
|
|
||||||
|
const unclassifiedImages = useMemo(() => {
|
||||||
|
if (!unknownImages) return [];
|
||||||
|
const images = unknownImages.slice(0, 24);
|
||||||
|
|
||||||
|
// Only filter if we have any classifications
|
||||||
|
if (Object.keys(imageClassifications).length === 0) {
|
||||||
|
return images;
|
||||||
|
}
|
||||||
|
|
||||||
|
return images.filter((img) => !imageClassifications[img]);
|
||||||
|
}, [unknownImages, imageClassifications]);
|
||||||
|
|
||||||
const allImagesClassified = useMemo(() => {
|
const allImagesClassified = useMemo(() => {
|
||||||
if (!unknownImages || unknownImages.length === 0) return false;
|
return unclassifiedImages.length === 0;
|
||||||
const imagesToClassify = unknownImages.slice(0, 24);
|
}, [unclassifiedImages]);
|
||||||
return imagesToClassify.every((img) => imageClassifications[img]);
|
|
||||||
}, [unknownImages, imageClassifications]);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col gap-6">
|
<div className="flex flex-col gap-6">
|
||||||
@ -263,9 +348,18 @@ export default function Step3ChooseExamples({
|
|||||||
</div>
|
</div>
|
||||||
) : hasGenerated ? (
|
) : hasGenerated ? (
|
||||||
<div className="flex flex-col gap-4">
|
<div className="flex flex-col gap-4">
|
||||||
<div className="text-sm text-muted-foreground">
|
{!allImagesClassified && (
|
||||||
{t("wizard.step3.description")}
|
<div className="text-center">
|
||||||
</div>
|
<h3 className="text-lg font-medium">
|
||||||
|
{t("wizard.step3.selectImagesPrompt", {
|
||||||
|
className: currentClass,
|
||||||
|
})}
|
||||||
|
</h3>
|
||||||
|
<p className="text-sm text-muted-foreground">
|
||||||
|
{t("wizard.step3.selectImagesDescription")}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
"rounded-lg bg-secondary/30 p-4",
|
"rounded-lg bg-secondary/30 p-4",
|
||||||
@ -278,53 +372,36 @@ export default function Step3ChooseExamples({
|
|||||||
{t("wizard.step3.noImages")}
|
{t("wizard.step3.noImages")}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
) : allImagesClassified && isProcessing ? (
|
||||||
|
<div className="flex h-[40vh] flex-col items-center justify-center gap-4">
|
||||||
|
<ActivityIndicator className="size-12" />
|
||||||
|
<p className="text-lg font-medium">
|
||||||
|
{t("wizard.step3.classifying")}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<div className="grid grid-cols-2 gap-3 sm:grid-cols-6">
|
<div className="grid grid-cols-2 gap-4 sm:grid-cols-6">
|
||||||
{unknownImages.slice(0, 24).map((imageName, index) => (
|
{unclassifiedImages.map((imageName, index) => {
|
||||||
<div
|
const isSelected = selectedImages.has(imageName);
|
||||||
key={imageName}
|
return (
|
||||||
className="group relative aspect-square overflow-hidden rounded-lg border bg-background"
|
<div
|
||||||
>
|
key={imageName}
|
||||||
<img
|
className={cn(
|
||||||
src={`${baseUrl}clips/${step1Data.modelName}/train/${imageName}`}
|
"aspect-square cursor-pointer overflow-hidden rounded-lg border-2 bg-background transition-all",
|
||||||
alt={`Example ${index + 1}`}
|
isSelected
|
||||||
className="h-full w-full object-cover"
|
? "border-selected ring-2 ring-selected"
|
||||||
/>
|
: "border-border hover:border-primary",
|
||||||
<div className="absolute bottom-0 left-0 right-0 p-2">
|
)}
|
||||||
<Select
|
onClick={() => toggleImageSelection(imageName)}
|
||||||
value={imageClassifications[imageName] || ""}
|
>
|
||||||
onValueChange={(value) =>
|
<img
|
||||||
handleClassificationChange(imageName, value)
|
src={`${baseUrl}clips/${step1Data.modelName}/train/${imageName}`}
|
||||||
}
|
alt={`Example ${index + 1}`}
|
||||||
>
|
className="h-full w-full object-cover"
|
||||||
<SelectTrigger className="h-7 bg-background/20 text-xs">
|
/>
|
||||||
<SelectValue
|
|
||||||
placeholder={t("wizard.step3.selectClass")}
|
|
||||||
/>
|
|
||||||
</SelectTrigger>
|
|
||||||
<SelectContent>
|
|
||||||
{step1Data.modelType === "object" && (
|
|
||||||
<SelectItem
|
|
||||||
value="none"
|
|
||||||
className="cursor-pointer text-xs"
|
|
||||||
>
|
|
||||||
{t("wizard.step3.none")}
|
|
||||||
</SelectItem>
|
|
||||||
)}
|
|
||||||
{step1Data.classes.map((className) => (
|
|
||||||
<SelectItem
|
|
||||||
key={className}
|
|
||||||
value={className}
|
|
||||||
className="cursor-pointer text-xs"
|
|
||||||
>
|
|
||||||
{className}
|
|
||||||
</SelectItem>
|
|
||||||
))}
|
|
||||||
</SelectContent>
|
|
||||||
</Select>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
);
|
||||||
))}
|
})}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
@ -347,15 +424,14 @@ export default function Step3ChooseExamples({
|
|||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
type="button"
|
type="button"
|
||||||
onClick={handleContinue}
|
onClick={
|
||||||
|
allImagesClassified
|
||||||
|
? handleContinue
|
||||||
|
: handleContinueClassification
|
||||||
|
}
|
||||||
variant="select"
|
variant="select"
|
||||||
className="flex items-center justify-center gap-2 sm:flex-1"
|
className="flex items-center justify-center gap-2 sm:flex-1"
|
||||||
disabled={
|
disabled={!hasGenerated || isGenerating || isProcessing}
|
||||||
!hasGenerated ||
|
|
||||||
isGenerating ||
|
|
||||||
!allImagesClassified ||
|
|
||||||
isProcessing
|
|
||||||
}
|
|
||||||
>
|
>
|
||||||
{isProcessing && <ActivityIndicator className="size-4" />}
|
{isProcessing && <ActivityIndicator className="size-4" />}
|
||||||
{t("button.continue", { ns: "common" })}
|
{t("button.continue", { ns: "common" })}
|
||||||
|
|||||||
@ -10,7 +10,7 @@ import {
|
|||||||
CustomClassificationModelConfig,
|
CustomClassificationModelConfig,
|
||||||
FrigateConfig,
|
FrigateConfig,
|
||||||
} from "@/types/frigateConfig";
|
} from "@/types/frigateConfig";
|
||||||
import { useMemo, useState } from "react";
|
import { useEffect, useMemo, useState } from "react";
|
||||||
import { isMobile } from "react-device-detect";
|
import { isMobile } from "react-device-detect";
|
||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
import { FaFolderPlus } from "react-icons/fa";
|
import { FaFolderPlus } from "react-icons/fa";
|
||||||
@ -37,6 +37,12 @@ export default function ModelSelectionView({
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// title
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
document.title = t("documentTitle");
|
||||||
|
}, [t]);
|
||||||
|
|
||||||
// data
|
// data
|
||||||
|
|
||||||
const classificationConfigs = useMemo(() => {
|
const classificationConfigs = useMemo(() => {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user