diff --git a/web/src/components/classification/wizard/Step3ChooseExamples.tsx b/web/src/components/classification/wizard/Step3ChooseExamples.tsx index f638c01e3..e4c157526 100644 --- a/web/src/components/classification/wizard/Step3ChooseExamples.tsx +++ b/web/src/components/classification/wizard/Step3ChooseExamples.tsx @@ -165,18 +165,15 @@ export default function Step3ChooseExamples({ const isLastClass = currentClassIndex === allClasses.length - 1; if (isLastClass) { - // Assign remaining unclassified images - unknownImages.slice(0, 24).forEach((imageName) => { - if (!newClassifications[imageName]) { - // For state models with 2 classes, assign to the last class - // For object models, assign to "none" - if (step1Data.modelType === "state" && allClasses.length === 2) { - newClassifications[imageName] = allClasses[allClasses.length - 1]; - } else { + // For object models, assign remaining unclassified images to "none" + // For state models, this should never happen since we require all images to be classified + if (step1Data.modelType !== "state") { + unknownImages.slice(0, 24).forEach((imageName) => { + if (!newClassifications[imageName]) { newClassifications[imageName] = "none"; } - } - }); + }); + } // All done, trigger training immediately setImageClassifications(newClassifications); @@ -316,8 +313,15 @@ export default function Step3ChooseExamples({ return images; } - return images.filter((img) => !imageClassifications[img]); - }, [unknownImages, imageClassifications]); + // If we're viewing a previous class (going back), show images for that class + // Otherwise show only unclassified images + const currentClassInView = allClasses[currentClassIndex]; + return images.filter((img) => { + const imgClass = imageClassifications[img]; + // Show if: unclassified OR classified with current class we're viewing + return !imgClass || imgClass === currentClassInView; + }); + }, [unknownImages, imageClassifications, allClasses, currentClassIndex]); const allImagesClassified = useMemo(() => { return unclassifiedImages.length === 0; @@ -326,15 +330,26 @@ export default function Step3ChooseExamples({ // For state models on the last class, require all images to be classified const isLastClass = currentClassIndex === allClasses.length - 1; const canProceed = useMemo(() => { - if ( - step1Data.modelType === "state" && - isLastClass && - !allImagesClassified - ) { - return false; + if (step1Data.modelType === "state" && isLastClass) { + // Check if all 24 images will be classified after current selections are applied + const totalImages = unknownImages.slice(0, 24).length; + + // Count images that will be classified (either already classified or currently selected) + const allImages = unknownImages.slice(0, 24); + const willBeClassified = allImages.filter((img) => { + return imageClassifications[img] || selectedImages.has(img); + }).length; + + return willBeClassified >= totalImages; } return true; - }, [step1Data.modelType, isLastClass, allImagesClassified]); + }, [ + step1Data.modelType, + isLastClass, + unknownImages, + imageClassifications, + selectedImages, + ]); const handleBack = useCallback(() => { if (currentClassIndex > 0) {