diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json
index 499b25d35..1583aeb01 100644
--- a/web/public/locales/en/views/classificationModel.json
+++ b/web/public/locales/en/views/classificationModel.json
@@ -14,7 +14,11 @@
"addClassification": "Add Classification",
"deleteModels": "Delete Models",
"editModel": "Edit Model",
- "categorizeImages": "Classify Images"
+ "categorizeImages": "Classify Images",
+ "enableSelection": "Enable Selection",
+ "disableSelection": "Disable Selection",
+ "selectImage": "Select Image",
+ "selectGroup": "Select Group"
},
"tooltip": {
"trainingInProgress": "Model is currently training",
diff --git a/web/public/locales/en/views/faceLibrary.json b/web/public/locales/en/views/faceLibrary.json
index 593715261..5194be9f9 100644
--- a/web/public/locales/en/views/faceLibrary.json
+++ b/web/public/locales/en/views/faceLibrary.json
@@ -54,7 +54,11 @@
"deleteFace": "Delete Face",
"uploadImage": "Upload Image",
"reprocessFace": "Reprocess Face",
- "trainFaces": "Train Faces"
+ "trainFaces": "Train Faces",
+ "enableSelection": "Enable Selection",
+ "disableSelection": "Disable Selection",
+ "selectImage": "Select Image",
+ "selectGroup": "Select Group"
},
"imageEntry": {
"validation": {
diff --git a/web/src/components/card/ClassificationCard.tsx b/web/src/components/card/ClassificationCard.tsx
index 6581d109a..d0dd5529d 100644
--- a/web/src/components/card/ClassificationCard.tsx
+++ b/web/src/components/card/ClassificationCard.tsx
@@ -44,6 +44,7 @@ type ClassificationCardProps = {
i18nLibrary: string;
showArea?: boolean;
count?: number;
+ topLeftContent?: React.ReactNode;
onClick: (data: ClassificationItemData, meta: boolean) => void;
children?: React.ReactNode;
};
@@ -61,6 +62,7 @@ export const ClassificationCard = forwardRef<
i18nLibrary,
showArea = true,
count,
+ topLeftContent,
onClick,
children,
},
@@ -143,6 +145,15 @@ export const ClassificationCard = forwardRef<
onLoad={() => setImageLoaded(true)}
src={`${baseUrl}${data.filepath}`}
/>
+ {topLeftContent && (
+
e.stopPropagation()}
+ onMouseDown={(e) => e.stopPropagation()}
+ >
+ {topLeftContent}
+
+ )}
@@ -199,6 +210,7 @@ type GroupedClassificationCardProps = {
i18nLibrary: string;
objectType: string;
noClassificationLabel?: string;
+ topLeftContent?: React.ReactNode;
onClick: (data: ClassificationItemData | undefined) => void;
children?: (data: ClassificationItemData) => React.ReactNode;
};
@@ -209,6 +221,7 @@ export function GroupedClassificationCard({
selectedItems,
i18nLibrary,
noClassificationLabel = "details.none",
+ topLeftContent,
onClick,
children,
}: GroupedClassificationCardProps) {
@@ -295,6 +308,7 @@ export function GroupedClassificationCard({
clickable={true}
i18nLibrary={i18nLibrary}
count={group.length}
+ topLeftContent={topLeftContent}
onClick={(_, meta) => {
if (meta || selectedItems.length > 0) {
onClick(undefined);
diff --git a/web/src/components/overlay/ClassificationSelectionDialog.tsx b/web/src/components/overlay/ClassificationSelectionDialog.tsx
index 8e2037f18..60625dbb9 100644
--- a/web/src/components/overlay/ClassificationSelectionDialog.tsx
+++ b/web/src/components/overlay/ClassificationSelectionDialog.tsx
@@ -33,9 +33,10 @@ type ClassificationSelectionDialogProps = {
className?: string;
classes: string[];
modelName: string;
- image: string;
+ image?: string;
+ images?: string[];
onRefresh: () => void;
- onCategorize?: (category: string) => void; // Optional custom categorize handler
+ onCategorize?: (category: string, images: string[]) => void;
children: ReactNode;
};
export default function ClassificationSelectionDialog({
@@ -43,6 +44,7 @@ export default function ClassificationSelectionDialog({
classes,
modelName,
image,
+ images,
onRefresh,
onCategorize,
children,
@@ -51,37 +53,98 @@ export default function ClassificationSelectionDialog({
const onCategorizeImage = useCallback(
(category: string) => {
+ const targetImages = images?.length ? images : image ? [image] : [];
+
// If custom categorize handler is provided, use it instead
if (onCategorize) {
- onCategorize(category);
+ onCategorize(category, targetImages);
return;
}
- // Default behavior: categorize single image
- axios
- .post(`/classification/${modelName}/dataset/categorize`, {
- category,
- training_file: image,
- })
- .then((resp) => {
- if (resp.status == 200) {
- toast.success(t("toast.success.categorizedImage"), {
+ if (targetImages.length === 0) {
+ toast.error(t("toast.error.batchCategorizeFailed", { count: 0 }), {
+ position: "top-center",
+ });
+ return;
+ }
+
+ if (targetImages.length === 1) {
+ // Default behavior: categorize a single image.
+ axios
+ .post(`/classification/${modelName}/dataset/categorize`, {
+ category,
+ training_file: targetImages[0],
+ })
+ .then((resp) => {
+ if (resp.status == 200) {
+ toast.success(t("toast.success.categorizedImage"), {
+ position: "top-center",
+ });
+ onRefresh();
+ }
+ })
+ .catch((error) => {
+ const errorMessage =
+ error.response?.data?.message ||
+ error.response?.data?.detail ||
+ "Unknown error";
+ toast.error(t("toast.error.categorizeFailed", { errorMessage }), {
position: "top-center",
});
- onRefresh();
- }
- })
- .catch((error) => {
- const errorMessage =
- error.response?.data?.message ||
- error.response?.data?.detail ||
- "Unknown error";
- toast.error(t("toast.error.categorizeFailed", { errorMessage }), {
- position: "top-center",
});
- });
+ return;
+ }
+
+ const requests = targetImages.map((filename) =>
+ axios
+ .post(`/classification/${modelName}/dataset/categorize`, {
+ category,
+ training_file: filename,
+ })
+ .then(() => true)
+ .catch(() => false),
+ );
+
+ Promise.allSettled(requests).then((results) => {
+ const successCount = results.filter(
+ (result) => result.status === "fulfilled" && result.value,
+ ).length;
+ const totalCount = results.length;
+
+ if (successCount === totalCount) {
+ toast.success(
+ t("toast.success.batchCategorized", {
+ count: successCount,
+ }),
+ {
+ position: "top-center",
+ },
+ );
+ } else if (successCount > 0) {
+ toast.warning(
+ t("toast.warning.partialBatchCategorized", {
+ success: successCount,
+ total: totalCount,
+ }),
+ {
+ position: "top-center",
+ },
+ );
+ } else {
+ toast.error(
+ t("toast.error.batchCategorizeFailed", {
+ count: totalCount,
+ }),
+ {
+ position: "top-center",
+ },
+ );
+ }
+
+ onRefresh();
+ });
},
- [modelName, image, onRefresh, onCategorize, t],
+ [modelName, image, images, onRefresh, onCategorize, t],
);
const isChildButton = useMemo(
@@ -105,7 +168,7 @@ export default function ClassificationSelectionDialog({
);
return (
-
+
([]);
+ const [selectionModeEnabled, setSelectionModeEnabled] = useState(false);
+
+ const toggleSelectionMode = useCallback(() => {
+ setSelectionModeEnabled((prev) => {
+ const next = !prev;
+ if (!next) {
+ setSelectedFaces([]);
+ }
+ return next;
+ });
+ }, []);
const onClickFaces = useCallback(
(images: string[], ctrl: boolean) => {
- if (selectedFaces.length == 0 && !ctrl) {
+ if (!selectionModeEnabled && selectedFaces.length == 0 && !ctrl) {
return;
}
@@ -181,7 +194,7 @@ export default function FaceLibrary() {
setSelectedFaces(newSelectedFaces);
},
- [selectedFaces, setSelectedFaces],
+ [selectionModeEnabled, selectedFaces, setSelectedFaces],
);
const [deleteDialogOpen, setDeleteDialogOpen] = useState<{
@@ -466,6 +479,19 @@ export default function FaceLibrary() {
)}
+
) : (
+
+ ) : undefined
+ }
onClick={(data) => {
if (data) {
onClickFaces([data.filename], true);
@@ -1045,6 +1132,7 @@ type FaceGridProps = {
faceImages: string[];
pageToggle: string;
selectedFaces: string[];
+ showSelectionCheckboxes: boolean;
onClickFaces: (images: string[], ctrl: boolean) => void;
onDelete: (name: string, ids: string[]) => void;
};
@@ -1053,6 +1141,7 @@ function FaceGrid({
faceImages,
pageToggle,
selectedFaces,
+ showSelectionCheckboxes,
onClickFaces,
onDelete,
}: FaceGridProps) {
@@ -1088,9 +1177,22 @@ function FaceGrid({
filepath: `clips/faces/${pageToggle}/${image}`,
}}
selected={selectedFaces.includes(image)}
- clickable={selectedFaces.length > 0}
+ clickable={selectedFaces.length > 0 || showSelectionCheckboxes}
i18nLibrary="views/faceLibrary"
- onClick={(data, meta) => onClickFaces([data.filename], meta)}
+ topLeftContent={
+ showSelectionCheckboxes ? (
+
+ onClickFaces([image], true)}
+ aria-label={t("button.selectImage")}
+ />
+
+ ) : undefined
+ }
+ onClick={(data, meta) =>
+ onClickFaces([data.filename], meta || showSelectionCheckboxes)
+ }
>
diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx
index 464807475..44c85fe02 100644
--- a/web/src/views/classification/ModelTrainingView.tsx
+++ b/web/src/views/classification/ModelTrainingView.tsx
@@ -46,7 +46,7 @@ import {
} from "react";
import { isDesktop, isMobileOnly } from "react-device-detect";
import { Trans, useTranslation } from "react-i18next";
-import { LuPencil, LuTrash2 } from "react-icons/lu";
+import { LuListChecks, LuPencil, LuTrash2 } from "react-icons/lu";
import { toast } from "sonner";
import useSWR from "swr";
import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
@@ -76,6 +76,7 @@ import SearchDetailDialog, {
import { SearchResult } from "@/types/search";
import { HiSparkles } from "react-icons/hi";
import { capitalizeFirstLetter } from "@/utils/stringUtil";
+import { Checkbox } from "@/components/ui/checkbox";
type ModelTrainingViewProps = {
model: CustomClassificationModelConfig;
@@ -150,10 +151,21 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
// image multiselect
const [selectedImages, setSelectedImages] = useState([]);
+ const [selectionModeEnabled, setSelectionModeEnabled] = useState(false);
+
+ const toggleSelectionMode = useCallback(() => {
+ setSelectionModeEnabled((prev) => {
+ const next = !prev;
+ if (!next) {
+ setSelectedImages([]);
+ }
+ return next;
+ });
+ }, []);
const onClickImages = useCallback(
(images: string[], ctrl: boolean) => {
- if (selectedImages.length == 0 && !ctrl) {
+ if (!selectionModeEnabled && selectedImages.length == 0 && !ctrl) {
return;
}
@@ -179,7 +191,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
setSelectedImages(newSelectedImages);
},
- [selectedImages, setSelectedImages],
+ [selectionModeEnabled, selectedImages, setSelectedImages],
);
// actions
@@ -525,6 +537,19 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
)}
+
) : (
+
@@ -839,6 +879,7 @@ type DatasetGridProps = {
categoryName: string;
images: string[];
selectedImages: string[];
+ showSelectionCheckboxes: boolean;
onClickImages: (images: string[], ctrl: boolean) => void;
onDelete: (ids: string[]) => void;
};
@@ -848,6 +889,7 @@ function DatasetGrid({
categoryName,
images,
selectedImages,
+ showSelectionCheckboxes,
onClickImages,
onDelete,
}: DatasetGridProps) {
@@ -872,10 +914,23 @@ function DatasetGrid({
name: "",
}}
showArea={false}
- clickable={selectedImages.length > 0}
+ clickable={selectedImages.length > 0 || showSelectionCheckboxes}
selected={selectedImages.includes(image)}
i18nLibrary="views/classificationModel"
- onClick={(data, _) => onClickImages([data.filename], true)}
+ topLeftContent={
+ showSelectionCheckboxes ? (
+
+ onClickImages([image], true)}
+ aria-label={t("button.selectImage")}
+ />
+
+ ) : undefined
+ }
+ onClick={(data, meta) =>
+ onClickImages([data.filename], meta || showSelectionCheckboxes)
+ }
>
@@ -905,6 +960,7 @@ type TrainGridProps = {
trainImages: string[];
trainFilter?: TrainFilter;
selectedImages: string[];
+ showSelectionCheckboxes: boolean;
onClickImages: (images: string[], ctrl: boolean) => void;
onRefresh: () => void;
onDelete: (ids: string[]) => void;
@@ -916,6 +972,7 @@ function TrainGrid({
trainImages,
trainFilter,
selectedImages,
+ showSelectionCheckboxes,
onClickImages,
onRefresh,
onDelete,
@@ -972,6 +1029,7 @@ function TrainGrid({
classes={classes}
trainData={trainData}
selectedImages={selectedImages}
+ showSelectionCheckboxes={showSelectionCheckboxes}
onClickImages={onClickImages}
onRefresh={onRefresh}
onDelete={onDelete}
@@ -986,6 +1044,7 @@ function TrainGrid({
classes={classes}
trainData={trainData}
selectedImages={selectedImages}
+ showSelectionCheckboxes={showSelectionCheckboxes}
onClickImages={onClickImages}
onRefresh={onRefresh}
/>
@@ -998,6 +1057,7 @@ type StateTrainGridProps = {
classes: string[];
trainData?: ClassificationItemData[];
selectedImages: string[];
+ showSelectionCheckboxes: boolean;
onClickImages: (images: string[], ctrl: boolean) => void;
onRefresh: () => void;
onDelete: (ids: string[]) => void;
@@ -1008,9 +1068,12 @@ function StateTrainGrid({
classes,
trainData,
selectedImages,
+ showSelectionCheckboxes,
onClickImages,
onRefresh,
}: StateTrainGridProps) {
+ const { t } = useTranslation(["views/classificationModel"]);
+
const threshold = useMemo(() => {
return {
recognition: model.threshold,
@@ -1031,15 +1094,29 @@ function StateTrainGrid({
data={data}
threshold={threshold}
selected={selectedImages.includes(data.filename)}
- clickable={selectedImages.length > 0}
+ clickable={selectedImages.length > 0 || showSelectionCheckboxes}
i18nLibrary="views/classificationModel"
showArea={false}
- onClick={(data, meta) => onClickImages([data.filename], meta)}
+ topLeftContent={
+ showSelectionCheckboxes ? (
+
+ onClickImages([data.filename], true)}
+ aria-label={t("button.selectImage")}
+ />
+
+ ) : undefined
+ }
+ onClick={(data, meta) =>
+ onClickImages([data.filename], meta || showSelectionCheckboxes)
+ }
>
@@ -1059,6 +1136,7 @@ type ObjectTrainGridProps = {
classes: string[];
trainData?: ClassificationItemData[];
selectedImages: string[];
+ showSelectionCheckboxes: boolean;
onClickImages: (images: string[], ctrl: boolean) => void;
onRefresh: () => void;
};
@@ -1068,9 +1146,12 @@ function ObjectTrainGrid({
classes,
trainData,
selectedImages,
+ showSelectionCheckboxes,
onClickImages,
onRefresh,
}: ObjectTrainGridProps) {
+ const { t } = useTranslation(["views/classificationModel"]);
+
// item data
const groups = useMemo(() => {
@@ -1172,6 +1253,32 @@ function ObjectTrainGrid({
[selectedImages, onClickImages],
);
+ const toggleGroupSelection = useCallback(
+ (group: ClassificationItemData[]) => {
+ const selectedCount = group.filter((item) =>
+ selectedImages.includes(item.filename),
+ ).length;
+ const allSelected = selectedCount === group.length;
+
+ if (allSelected) {
+ onClickImages(
+ group
+ .filter((item) => selectedImages.includes(item.filename))
+ .map((item) => item.filename),
+ false,
+ );
+ } else {
+ onClickImages(
+ group
+ .filter((item) => !selectedImages.includes(item.filename))
+ .map((item) => item.filename),
+ true,
+ );
+ }
+ },
+ [onClickImages, selectedImages],
+ );
+
return (
<>
+
+ selectedImages.includes(item.filename),
+ ).length === group.length
+ ? true
+ : group.some((item) =>
+ selectedImages.includes(item.filename),
+ )
+ ? "indeterminate"
+ : false
+ }
+ onCheckedChange={() => toggleGroupSelection(group)}
+ aria-label={t("button.selectGroup")}
+ />
+
+ ) : undefined
+ }
onClick={(data) => {
if (data) {
onClickImages([data.filename], true);
@@ -1219,6 +1347,7 @@ function ObjectTrainGrid({
classes={classes}
modelName={model.name}
image={data.filename}
+ images={selectedImages}
onRefresh={onRefresh}
>