diff --git a/web/src/components/overlay/ClassificationSelectionDialog.tsx b/web/src/components/overlay/ClassificationSelectionDialog.tsx index 6398348a4..8e2037f18 100644 --- a/web/src/components/overlay/ClassificationSelectionDialog.tsx +++ b/web/src/components/overlay/ClassificationSelectionDialog.tsx @@ -35,6 +35,7 @@ type ClassificationSelectionDialogProps = { modelName: string; image: string; onRefresh: () => void; + onCategorize?: (category: string) => void; // Optional custom categorize handler children: ReactNode; }; export default function ClassificationSelectionDialog({ @@ -43,12 +44,20 @@ export default function ClassificationSelectionDialog({ modelName, image, onRefresh, + onCategorize, children, }: ClassificationSelectionDialogProps) { const { t } = useTranslation(["views/classificationModel"]); const onCategorizeImage = useCallback( (category: string) => { + // If custom categorize handler is provided, use it instead + if (onCategorize) { + onCategorize(category); + return; + } + + // Default behavior: categorize single image axios .post(`/classification/${modelName}/dataset/categorize`, { category, @@ -72,7 +81,7 @@ export default function ClassificationSelectionDialog({ }); }); }, - [modelName, image, onRefresh, t], + [modelName, image, onRefresh, onCategorize, t], ); const isChildButton = useMemo( diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx index 00e811d09..87e5407dd 100644 --- a/web/src/views/classification/ModelTrainingView.tsx +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -458,6 +458,89 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { )} + {pageToggle === "train" && ( + { + // Batch categorize all selected images + let successCount = 0; + let failCount = 0; + const totalCount = selectedImages.length; + + selectedImages.forEach((filename, index) => { + axios + .post(`/classification/${model.name}/dataset/categorize`, { + category, + training_file: filename, + }) + .then((resp) => { + if (resp.status == 200) { + successCount++; + } else { + failCount++; + } + + // Show final toast after all requests complete + if (index === totalCount - 1) { + if (successCount === totalCount) { + toast.success( + t("toast.success.batchCategorized", { + count: successCount, + }), + { + position: "top-center", + }, + ); + } else if (successCount > 0) { + toast.warning( + t("toast.warning.partialBatchCategorized", { + success: successCount, + total: totalCount, + }), + { + position: "top-center", + }, + ); + } else { + toast.error( + t("toast.error.batchCategorizeFailed", { + count: totalCount, + }), + { + position: "top-center", + }, + ); + } + setSelectedImages([]); + refreshAll(); + } + }) + .catch(() => { + failCount++; + if (index === totalCount - 1) { + toast.error( + t("toast.error.batchCategorizeFailed", { + count: totalCount, + }), + { + position: "top-center", + }, + ); + setSelectedImages([]); + refreshAll(); + } + }); + }); + }} + > + + + )}