diff --git a/web/src/components/overlay/ClassificationSelectionDialog.tsx b/web/src/components/overlay/ClassificationSelectionDialog.tsx
index 6398348a4..8e2037f18 100644
--- a/web/src/components/overlay/ClassificationSelectionDialog.tsx
+++ b/web/src/components/overlay/ClassificationSelectionDialog.tsx
@@ -35,6 +35,7 @@ type ClassificationSelectionDialogProps = {
modelName: string;
image: string;
onRefresh: () => void;
+ onCategorize?: (category: string) => void; // Optional custom categorize handler
children: ReactNode;
};
export default function ClassificationSelectionDialog({
@@ -43,12 +44,20 @@ export default function ClassificationSelectionDialog({
modelName,
image,
onRefresh,
+ onCategorize,
children,
}: ClassificationSelectionDialogProps) {
const { t } = useTranslation(["views/classificationModel"]);
const onCategorizeImage = useCallback(
(category: string) => {
+ // If custom categorize handler is provided, use it instead
+ if (onCategorize) {
+ onCategorize(category);
+ return;
+ }
+
+ // Default behavior: categorize single image
axios
.post(`/classification/${modelName}/dataset/categorize`, {
category,
@@ -72,7 +81,7 @@ export default function ClassificationSelectionDialog({
});
});
},
- [modelName, image, onRefresh, t],
+ [modelName, image, onRefresh, onCategorize, t],
);
const isChildButton = useMemo(
diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx
index 00e811d09..87e5407dd 100644
--- a/web/src/views/classification/ModelTrainingView.tsx
+++ b/web/src/views/classification/ModelTrainingView.tsx
@@ -458,6 +458,89 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
>
)}
+ {pageToggle === "train" && (
+ {
+ // Batch categorize all selected images
+ let successCount = 0;
+ let failCount = 0;
+ const totalCount = selectedImages.length;
+
+ selectedImages.forEach((filename, index) => {
+ axios
+ .post(`/classification/${model.name}/dataset/categorize`, {
+ category,
+ training_file: filename,
+ })
+ .then((resp) => {
+ if (resp.status == 200) {
+ successCount++;
+ } else {
+ failCount++;
+ }
+
+ // Show final toast after all requests complete
+ if (index === totalCount - 1) {
+ if (successCount === totalCount) {
+ toast.success(
+ t("toast.success.batchCategorized", {
+ count: successCount,
+ }),
+ {
+ position: "top-center",
+ },
+ );
+ } else if (successCount > 0) {
+ toast.warning(
+ t("toast.warning.partialBatchCategorized", {
+ success: successCount,
+ total: totalCount,
+ }),
+ {
+ position: "top-center",
+ },
+ );
+ } else {
+ toast.error(
+ t("toast.error.batchCategorizeFailed", {
+ count: totalCount,
+ }),
+ {
+ position: "top-center",
+ },
+ );
+ }
+ setSelectedImages([]);
+ refreshAll();
+ }
+ })
+ .catch(() => {
+ failCount++;
+ if (index === totalCount - 1) {
+ toast.error(
+ t("toast.error.batchCategorizeFailed", {
+ count: totalCount,
+ }),
+ {
+ position: "top-center",
+ },
+ );
+ setSelectedImages([]);
+ refreshAll();
+ }
+ });
+ });
+ }}
+ >
+
+
+ )}