Improve flow handling for classification state

This commit is contained in:
Nicolas Mowen 2025-11-10 06:13:14 -07:00
parent 55dcbc6371
commit 8aad89a83a

View File

@ -165,18 +165,15 @@ export default function Step3ChooseExamples({
const isLastClass = currentClassIndex === allClasses.length - 1;
if (isLastClass) {
// Assign remaining unclassified images
unknownImages.slice(0, 24).forEach((imageName) => {
if (!newClassifications[imageName]) {
// For state models with 2 classes, assign to the last class
// For object models, assign to "none"
if (step1Data.modelType === "state" && allClasses.length === 2) {
newClassifications[imageName] = allClasses[allClasses.length - 1];
} else {
// For object models, assign remaining unclassified images to "none"
// For state models, this should never happen since we require all images to be classified
if (step1Data.modelType !== "state") {
unknownImages.slice(0, 24).forEach((imageName) => {
if (!newClassifications[imageName]) {
newClassifications[imageName] = "none";
}
}
});
});
}
// All done, trigger training immediately
setImageClassifications(newClassifications);
@ -316,8 +313,15 @@ export default function Step3ChooseExamples({
return images;
}
return images.filter((img) => !imageClassifications[img]);
}, [unknownImages, imageClassifications]);
// If we're viewing a previous class (going back), show images for that class
// Otherwise show only unclassified images
const currentClassInView = allClasses[currentClassIndex];
return images.filter((img) => {
const imgClass = imageClassifications[img];
// Show if: unclassified OR classified with current class we're viewing
return !imgClass || imgClass === currentClassInView;
});
}, [unknownImages, imageClassifications, allClasses, currentClassIndex]);
const allImagesClassified = useMemo(() => {
return unclassifiedImages.length === 0;
@ -326,15 +330,26 @@ export default function Step3ChooseExamples({
// For state models on the last class, require all images to be classified
const isLastClass = currentClassIndex === allClasses.length - 1;
const canProceed = useMemo(() => {
if (
step1Data.modelType === "state" &&
isLastClass &&
!allImagesClassified
) {
return false;
if (step1Data.modelType === "state" && isLastClass) {
// Check if all 24 images will be classified after current selections are applied
const totalImages = unknownImages.slice(0, 24).length;
// Count images that will be classified (either already classified or currently selected)
const allImages = unknownImages.slice(0, 24);
const willBeClassified = allImages.filter((img) => {
return imageClassifications[img] || selectedImages.has(img);
}).length;
return willBeClassified >= totalImages;
}
return true;
}, [step1Data.modelType, isLastClass, allImagesClassified]);
}, [
step1Data.modelType,
isLastClass,
unknownImages,
imageClassifications,
selectedImages,
]);
const handleBack = useCallback(() => {
if (currentClassIndex > 0) {