From 8b65ce394614517135d7a97401d8fd917bf44732 Mon Sep 17 00:00:00 2001 From: Teagan glenn Date: Sat, 21 Feb 2026 23:49:15 -0700 Subject: [PATCH] Add checkbox selection mode for classification and face grids --- .../locales/en/views/classificationModel.json | 6 +- web/public/locales/en/views/faceLibrary.json | 6 +- .../components/card/ClassificationCard.tsx | 14 ++ .../overlay/ClassificationSelectionDialog.tsx | 113 +++++++++++--- web/src/pages/FaceLibrary.tsx | 110 +++++++++++++- .../classification/ModelTrainingView.tsx | 143 +++++++++++++++++- 6 files changed, 354 insertions(+), 38 deletions(-) diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json index 499b25d35..1583aeb01 100644 --- a/web/public/locales/en/views/classificationModel.json +++ b/web/public/locales/en/views/classificationModel.json @@ -14,7 +14,11 @@ "addClassification": "Add Classification", "deleteModels": "Delete Models", "editModel": "Edit Model", - "categorizeImages": "Classify Images" + "categorizeImages": "Classify Images", + "enableSelection": "Enable Selection", + "disableSelection": "Disable Selection", + "selectImage": "Select Image", + "selectGroup": "Select Group" }, "tooltip": { "trainingInProgress": "Model is currently training", diff --git a/web/public/locales/en/views/faceLibrary.json b/web/public/locales/en/views/faceLibrary.json index 593715261..5194be9f9 100644 --- a/web/public/locales/en/views/faceLibrary.json +++ b/web/public/locales/en/views/faceLibrary.json @@ -54,7 +54,11 @@ "deleteFace": "Delete Face", "uploadImage": "Upload Image", "reprocessFace": "Reprocess Face", - "trainFaces": "Train Faces" + "trainFaces": "Train Faces", + "enableSelection": "Enable Selection", + "disableSelection": "Disable Selection", + "selectImage": "Select Image", + "selectGroup": "Select Group" }, "imageEntry": { "validation": { diff --git a/web/src/components/card/ClassificationCard.tsx b/web/src/components/card/ClassificationCard.tsx index 6581d109a..d0dd5529d 100644 --- a/web/src/components/card/ClassificationCard.tsx +++ b/web/src/components/card/ClassificationCard.tsx @@ -44,6 +44,7 @@ type ClassificationCardProps = { i18nLibrary: string; showArea?: boolean; count?: number; + topLeftContent?: React.ReactNode; onClick: (data: ClassificationItemData, meta: boolean) => void; children?: React.ReactNode; }; @@ -61,6 +62,7 @@ export const ClassificationCard = forwardRef< i18nLibrary, showArea = true, count, + topLeftContent, onClick, children, }, @@ -143,6 +145,15 @@ export const ClassificationCard = forwardRef< onLoad={() => setImageLoaded(true)} src={`${baseUrl}${data.filepath}`} /> + {topLeftContent && ( +
e.stopPropagation()} + onMouseDown={(e) => e.stopPropagation()} + > + {topLeftContent} +
+ )} {count && (
@@ -199,6 +210,7 @@ type GroupedClassificationCardProps = { i18nLibrary: string; objectType: string; noClassificationLabel?: string; + topLeftContent?: React.ReactNode; onClick: (data: ClassificationItemData | undefined) => void; children?: (data: ClassificationItemData) => React.ReactNode; }; @@ -209,6 +221,7 @@ export function GroupedClassificationCard({ selectedItems, i18nLibrary, noClassificationLabel = "details.none", + topLeftContent, onClick, children, }: GroupedClassificationCardProps) { @@ -295,6 +308,7 @@ export function GroupedClassificationCard({ clickable={true} i18nLibrary={i18nLibrary} count={group.length} + topLeftContent={topLeftContent} onClick={(_, meta) => { if (meta || selectedItems.length > 0) { onClick(undefined); diff --git a/web/src/components/overlay/ClassificationSelectionDialog.tsx b/web/src/components/overlay/ClassificationSelectionDialog.tsx index 8e2037f18..60625dbb9 100644 --- a/web/src/components/overlay/ClassificationSelectionDialog.tsx +++ b/web/src/components/overlay/ClassificationSelectionDialog.tsx @@ -33,9 +33,10 @@ type ClassificationSelectionDialogProps = { className?: string; classes: string[]; modelName: string; - image: string; + image?: string; + images?: string[]; onRefresh: () => void; - onCategorize?: (category: string) => void; // Optional custom categorize handler + onCategorize?: (category: string, images: string[]) => void; children: ReactNode; }; export default function ClassificationSelectionDialog({ @@ -43,6 +44,7 @@ export default function ClassificationSelectionDialog({ classes, modelName, image, + images, onRefresh, onCategorize, children, @@ -51,37 +53,98 @@ export default function ClassificationSelectionDialog({ const onCategorizeImage = useCallback( (category: string) => { + const targetImages = images?.length ? images : image ? [image] : []; + // If custom categorize handler is provided, use it instead if (onCategorize) { - onCategorize(category); + onCategorize(category, targetImages); return; } - // Default behavior: categorize single image - axios - .post(`/classification/${modelName}/dataset/categorize`, { - category, - training_file: image, - }) - .then((resp) => { - if (resp.status == 200) { - toast.success(t("toast.success.categorizedImage"), { + if (targetImages.length === 0) { + toast.error(t("toast.error.batchCategorizeFailed", { count: 0 }), { + position: "top-center", + }); + return; + } + + if (targetImages.length === 1) { + // Default behavior: categorize a single image. + axios + .post(`/classification/${modelName}/dataset/categorize`, { + category, + training_file: targetImages[0], + }) + .then((resp) => { + if (resp.status == 200) { + toast.success(t("toast.success.categorizedImage"), { + position: "top-center", + }); + onRefresh(); + } + }) + .catch((error) => { + const errorMessage = + error.response?.data?.message || + error.response?.data?.detail || + "Unknown error"; + toast.error(t("toast.error.categorizeFailed", { errorMessage }), { position: "top-center", }); - onRefresh(); - } - }) - .catch((error) => { - const errorMessage = - error.response?.data?.message || - error.response?.data?.detail || - "Unknown error"; - toast.error(t("toast.error.categorizeFailed", { errorMessage }), { - position: "top-center", }); - }); + return; + } + + const requests = targetImages.map((filename) => + axios + .post(`/classification/${modelName}/dataset/categorize`, { + category, + training_file: filename, + }) + .then(() => true) + .catch(() => false), + ); + + Promise.allSettled(requests).then((results) => { + const successCount = results.filter( + (result) => result.status === "fulfilled" && result.value, + ).length; + const totalCount = results.length; + + if (successCount === totalCount) { + toast.success( + t("toast.success.batchCategorized", { + count: successCount, + }), + { + position: "top-center", + }, + ); + } else if (successCount > 0) { + toast.warning( + t("toast.warning.partialBatchCategorized", { + success: successCount, + total: totalCount, + }), + { + position: "top-center", + }, + ); + } else { + toast.error( + t("toast.error.batchCategorizeFailed", { + count: totalCount, + }), + { + position: "top-center", + }, + ); + } + + onRefresh(); + }); }, - [modelName, image, onRefresh, onCategorize, t], + [modelName, image, images, onRefresh, onCategorize, t], ); const isChildButton = useMemo( @@ -105,7 +168,7 @@ export default function ClassificationSelectionDialog({ ); return ( -
+
([]); + const [selectionModeEnabled, setSelectionModeEnabled] = useState(false); + + const toggleSelectionMode = useCallback(() => { + setSelectionModeEnabled((prev) => { + const next = !prev; + if (!next) { + setSelectedFaces([]); + } + return next; + }); + }, []); const onClickFaces = useCallback( (images: string[], ctrl: boolean) => { - if (selectedFaces.length == 0 && !ctrl) { + if (!selectionModeEnabled && selectedFaces.length == 0 && !ctrl) { return; } @@ -181,7 +194,7 @@ export default function FaceLibrary() { setSelectedFaces(newSelectedFaces); }, - [selectedFaces, setSelectedFaces], + [selectionModeEnabled, selectedFaces, setSelectedFaces], ); const [deleteDialogOpen, setDeleteDialogOpen] = useState<{ @@ -466,6 +479,19 @@ export default function FaceLibrary() { )} +
) : (
+
+ ) : undefined + } onClick={(data) => { if (data) { onClickFaces([data.filename], true); @@ -1045,6 +1132,7 @@ type FaceGridProps = { faceImages: string[]; pageToggle: string; selectedFaces: string[]; + showSelectionCheckboxes: boolean; onClickFaces: (images: string[], ctrl: boolean) => void; onDelete: (name: string, ids: string[]) => void; }; @@ -1053,6 +1141,7 @@ function FaceGrid({ faceImages, pageToggle, selectedFaces, + showSelectionCheckboxes, onClickFaces, onDelete, }: FaceGridProps) { @@ -1088,9 +1177,22 @@ function FaceGrid({ filepath: `clips/faces/${pageToggle}/${image}`, }} selected={selectedFaces.includes(image)} - clickable={selectedFaces.length > 0} + clickable={selectedFaces.length > 0 || showSelectionCheckboxes} i18nLibrary="views/faceLibrary" - onClick={(data, meta) => onClickFaces([data.filename], meta)} + topLeftContent={ + showSelectionCheckboxes ? ( +
+ onClickFaces([image], true)} + aria-label={t("button.selectImage")} + /> +
+ ) : undefined + } + onClick={(data, meta) => + onClickFaces([data.filename], meta || showSelectionCheckboxes) + } > diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx index 464807475..44c85fe02 100644 --- a/web/src/views/classification/ModelTrainingView.tsx +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -46,7 +46,7 @@ import { } from "react"; import { isDesktop, isMobileOnly } from "react-device-detect"; import { Trans, useTranslation } from "react-i18next"; -import { LuPencil, LuTrash2 } from "react-icons/lu"; +import { LuListChecks, LuPencil, LuTrash2 } from "react-icons/lu"; import { toast } from "sonner"; import useSWR from "swr"; import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog"; @@ -76,6 +76,7 @@ import SearchDetailDialog, { import { SearchResult } from "@/types/search"; import { HiSparkles } from "react-icons/hi"; import { capitalizeFirstLetter } from "@/utils/stringUtil"; +import { Checkbox } from "@/components/ui/checkbox"; type ModelTrainingViewProps = { model: CustomClassificationModelConfig; @@ -150,10 +151,21 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { // image multiselect const [selectedImages, setSelectedImages] = useState([]); + const [selectionModeEnabled, setSelectionModeEnabled] = useState(false); + + const toggleSelectionMode = useCallback(() => { + setSelectionModeEnabled((prev) => { + const next = !prev; + if (!next) { + setSelectedImages([]); + } + return next; + }); + }, []); const onClickImages = useCallback( (images: string[], ctrl: boolean) => { - if (selectedImages.length == 0 && !ctrl) { + if (!selectionModeEnabled && selectedImages.length == 0 && !ctrl) { return; } @@ -179,7 +191,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { setSelectedImages(newSelectedImages); }, - [selectedImages, setSelectedImages], + [selectionModeEnabled, selectedImages, setSelectedImages], ); // actions @@ -525,6 +537,19 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { )} +
) : (
+ @@ -839,6 +879,7 @@ type DatasetGridProps = { categoryName: string; images: string[]; selectedImages: string[]; + showSelectionCheckboxes: boolean; onClickImages: (images: string[], ctrl: boolean) => void; onDelete: (ids: string[]) => void; }; @@ -848,6 +889,7 @@ function DatasetGrid({ categoryName, images, selectedImages, + showSelectionCheckboxes, onClickImages, onDelete, }: DatasetGridProps) { @@ -872,10 +914,23 @@ function DatasetGrid({ name: "", }} showArea={false} - clickable={selectedImages.length > 0} + clickable={selectedImages.length > 0 || showSelectionCheckboxes} selected={selectedImages.includes(image)} i18nLibrary="views/classificationModel" - onClick={(data, _) => onClickImages([data.filename], true)} + topLeftContent={ + showSelectionCheckboxes ? ( +
+ onClickImages([image], true)} + aria-label={t("button.selectImage")} + /> +
+ ) : undefined + } + onClick={(data, meta) => + onClickImages([data.filename], meta || showSelectionCheckboxes) + } > @@ -905,6 +960,7 @@ type TrainGridProps = { trainImages: string[]; trainFilter?: TrainFilter; selectedImages: string[]; + showSelectionCheckboxes: boolean; onClickImages: (images: string[], ctrl: boolean) => void; onRefresh: () => void; onDelete: (ids: string[]) => void; @@ -916,6 +972,7 @@ function TrainGrid({ trainImages, trainFilter, selectedImages, + showSelectionCheckboxes, onClickImages, onRefresh, onDelete, @@ -972,6 +1029,7 @@ function TrainGrid({ classes={classes} trainData={trainData} selectedImages={selectedImages} + showSelectionCheckboxes={showSelectionCheckboxes} onClickImages={onClickImages} onRefresh={onRefresh} onDelete={onDelete} @@ -986,6 +1044,7 @@ function TrainGrid({ classes={classes} trainData={trainData} selectedImages={selectedImages} + showSelectionCheckboxes={showSelectionCheckboxes} onClickImages={onClickImages} onRefresh={onRefresh} /> @@ -998,6 +1057,7 @@ type StateTrainGridProps = { classes: string[]; trainData?: ClassificationItemData[]; selectedImages: string[]; + showSelectionCheckboxes: boolean; onClickImages: (images: string[], ctrl: boolean) => void; onRefresh: () => void; onDelete: (ids: string[]) => void; @@ -1008,9 +1068,12 @@ function StateTrainGrid({ classes, trainData, selectedImages, + showSelectionCheckboxes, onClickImages, onRefresh, }: StateTrainGridProps) { + const { t } = useTranslation(["views/classificationModel"]); + const threshold = useMemo(() => { return { recognition: model.threshold, @@ -1031,15 +1094,29 @@ function StateTrainGrid({ data={data} threshold={threshold} selected={selectedImages.includes(data.filename)} - clickable={selectedImages.length > 0} + clickable={selectedImages.length > 0 || showSelectionCheckboxes} i18nLibrary="views/classificationModel" showArea={false} - onClick={(data, meta) => onClickImages([data.filename], meta)} + topLeftContent={ + showSelectionCheckboxes ? ( +
+ onClickImages([data.filename], true)} + aria-label={t("button.selectImage")} + /> +
+ ) : undefined + } + onClick={(data, meta) => + onClickImages([data.filename], meta || showSelectionCheckboxes) + } > @@ -1059,6 +1136,7 @@ type ObjectTrainGridProps = { classes: string[]; trainData?: ClassificationItemData[]; selectedImages: string[]; + showSelectionCheckboxes: boolean; onClickImages: (images: string[], ctrl: boolean) => void; onRefresh: () => void; }; @@ -1068,9 +1146,12 @@ function ObjectTrainGrid({ classes, trainData, selectedImages, + showSelectionCheckboxes, onClickImages, onRefresh, }: ObjectTrainGridProps) { + const { t } = useTranslation(["views/classificationModel"]); + // item data const groups = useMemo(() => { @@ -1172,6 +1253,32 @@ function ObjectTrainGrid({ [selectedImages, onClickImages], ); + const toggleGroupSelection = useCallback( + (group: ClassificationItemData[]) => { + const selectedCount = group.filter((item) => + selectedImages.includes(item.filename), + ).length; + const allSelected = selectedCount === group.length; + + if (allSelected) { + onClickImages( + group + .filter((item) => selectedImages.includes(item.filename)) + .map((item) => item.filename), + false, + ); + } else { + onClickImages( + group + .filter((item) => !selectedImages.includes(item.filename)) + .map((item) => item.filename), + true, + ); + } + }, + [onClickImages, selectedImages], + ); + return ( <> + + selectedImages.includes(item.filename), + ).length === group.length + ? true + : group.some((item) => + selectedImages.includes(item.filename), + ) + ? "indeterminate" + : false + } + onCheckedChange={() => toggleGroupSelection(group)} + aria-label={t("button.selectGroup")} + /> +
+ ) : undefined + } onClick={(data) => { if (data) { onClickImages([data.filename], true); @@ -1219,6 +1347,7 @@ function ObjectTrainGrid({ classes={classes} modelName={model.name} image={data.filename} + images={selectedImages} onRefresh={onRefresh} >