Add checkbox selection mode for classification and face grids

This commit is contained in:
Teagan glenn 2026-02-21 23:49:15 -07:00
parent 310d52de4e
commit 8b65ce3946
6 changed files with 354 additions and 38 deletions

View File

@ -14,7 +14,11 @@
"addClassification": "Add Classification",
"deleteModels": "Delete Models",
"editModel": "Edit Model",
"categorizeImages": "Classify Images"
"categorizeImages": "Classify Images",
"enableSelection": "Enable Selection",
"disableSelection": "Disable Selection",
"selectImage": "Select Image",
"selectGroup": "Select Group"
},
"tooltip": {
"trainingInProgress": "Model is currently training",

View File

@ -54,7 +54,11 @@
"deleteFace": "Delete Face",
"uploadImage": "Upload Image",
"reprocessFace": "Reprocess Face",
"trainFaces": "Train Faces"
"trainFaces": "Train Faces",
"enableSelection": "Enable Selection",
"disableSelection": "Disable Selection",
"selectImage": "Select Image",
"selectGroup": "Select Group"
},
"imageEntry": {
"validation": {

View File

@ -44,6 +44,7 @@ type ClassificationCardProps = {
i18nLibrary: string;
showArea?: boolean;
count?: number;
topLeftContent?: React.ReactNode;
onClick: (data: ClassificationItemData, meta: boolean) => void;
children?: React.ReactNode;
};
@ -61,6 +62,7 @@ export const ClassificationCard = forwardRef<
i18nLibrary,
showArea = true,
count,
topLeftContent,
onClick,
children,
},
@ -143,6 +145,15 @@ export const ClassificationCard = forwardRef<
onLoad={() => setImageLoaded(true)}
src={`${baseUrl}${data.filepath}`}
/>
{topLeftContent && (
<div
className="absolute left-2 top-2 z-10"
onClick={(e) => e.stopPropagation()}
onMouseDown={(e) => e.stopPropagation()}
>
{topLeftContent}
</div>
)}
<ImageShadowOverlay upperClassName="z-0" lowerClassName="h-[30%] z-0" />
{count && (
<div className="absolute right-2 top-2 flex flex-row items-center gap-1">
@ -199,6 +210,7 @@ type GroupedClassificationCardProps = {
i18nLibrary: string;
objectType: string;
noClassificationLabel?: string;
topLeftContent?: React.ReactNode;
onClick: (data: ClassificationItemData | undefined) => void;
children?: (data: ClassificationItemData) => React.ReactNode;
};
@ -209,6 +221,7 @@ export function GroupedClassificationCard({
selectedItems,
i18nLibrary,
noClassificationLabel = "details.none",
topLeftContent,
onClick,
children,
}: GroupedClassificationCardProps) {
@ -295,6 +308,7 @@ export function GroupedClassificationCard({
clickable={true}
i18nLibrary={i18nLibrary}
count={group.length}
topLeftContent={topLeftContent}
onClick={(_, meta) => {
if (meta || selectedItems.length > 0) {
onClick(undefined);

View File

@ -33,9 +33,10 @@ type ClassificationSelectionDialogProps = {
className?: string;
classes: string[];
modelName: string;
image: string;
image?: string;
images?: string[];
onRefresh: () => void;
onCategorize?: (category: string) => void; // Optional custom categorize handler
onCategorize?: (category: string, images: string[]) => void;
children: ReactNode;
};
export default function ClassificationSelectionDialog({
@ -43,6 +44,7 @@ export default function ClassificationSelectionDialog({
classes,
modelName,
image,
images,
onRefresh,
onCategorize,
children,
@ -51,37 +53,98 @@ export default function ClassificationSelectionDialog({
const onCategorizeImage = useCallback(
(category: string) => {
const targetImages = images?.length ? images : image ? [image] : [];
// If custom categorize handler is provided, use it instead
if (onCategorize) {
onCategorize(category);
onCategorize(category, targetImages);
return;
}
// Default behavior: categorize single image
axios
.post(`/classification/${modelName}/dataset/categorize`, {
category,
training_file: image,
})
.then((resp) => {
if (resp.status == 200) {
toast.success(t("toast.success.categorizedImage"), {
if (targetImages.length === 0) {
toast.error(t("toast.error.batchCategorizeFailed", { count: 0 }), {
position: "top-center",
});
return;
}
if (targetImages.length === 1) {
// Default behavior: categorize a single image.
axios
.post(`/classification/${modelName}/dataset/categorize`, {
category,
training_file: targetImages[0],
})
.then((resp) => {
if (resp.status == 200) {
toast.success(t("toast.success.categorizedImage"), {
position: "top-center",
});
onRefresh();
}
})
.catch((error) => {
const errorMessage =
error.response?.data?.message ||
error.response?.data?.detail ||
"Unknown error";
toast.error(t("toast.error.categorizeFailed", { errorMessage }), {
position: "top-center",
});
onRefresh();
}
})
.catch((error) => {
const errorMessage =
error.response?.data?.message ||
error.response?.data?.detail ||
"Unknown error";
toast.error(t("toast.error.categorizeFailed", { errorMessage }), {
position: "top-center",
});
});
return;
}
const requests = targetImages.map((filename) =>
axios
.post(`/classification/${modelName}/dataset/categorize`, {
category,
training_file: filename,
})
.then(() => true)
.catch(() => false),
);
Promise.allSettled(requests).then((results) => {
const successCount = results.filter(
(result) => result.status === "fulfilled" && result.value,
).length;
const totalCount = results.length;
if (successCount === totalCount) {
toast.success(
t("toast.success.batchCategorized", {
count: successCount,
}),
{
position: "top-center",
},
);
} else if (successCount > 0) {
toast.warning(
t("toast.warning.partialBatchCategorized", {
success: successCount,
total: totalCount,
}),
{
position: "top-center",
},
);
} else {
toast.error(
t("toast.error.batchCategorizeFailed", {
count: totalCount,
}),
{
position: "top-center",
},
);
}
onRefresh();
});
},
[modelName, image, onRefresh, onCategorize, t],
[modelName, image, images, onRefresh, onCategorize, t],
);
const isChildButton = useMemo(
@ -105,7 +168,7 @@ export default function ClassificationSelectionDialog({
);
return (
<div className={className ?? "flex"}>
<div className={className ?? "flex"} data-card-action="true">
<TextEntryDialog
open={newClass}
setOpen={setNewClass}

View File

@ -57,6 +57,7 @@ import { Trans, useTranslation } from "react-i18next";
import {
LuFolderCheck,
LuImagePlus,
LuListChecks,
LuPencil,
LuRefreshCw,
LuScanFace,
@ -72,6 +73,7 @@ import {
ClassificationItemData,
ClassifiedEvent,
} from "@/types/classification";
import { Checkbox } from "@/components/ui/checkbox";
export default function FaceLibrary() {
const { t } = useTranslation(["views/faceLibrary"]);
@ -152,10 +154,21 @@ export default function FaceLibrary() {
// face multiselect
const [selectedFaces, setSelectedFaces] = useState<string[]>([]);
const [selectionModeEnabled, setSelectionModeEnabled] = useState(false);
const toggleSelectionMode = useCallback(() => {
setSelectionModeEnabled((prev) => {
const next = !prev;
if (!next) {
setSelectedFaces([]);
}
return next;
});
}, []);
const onClickFaces = useCallback(
(images: string[], ctrl: boolean) => {
if (selectedFaces.length == 0 && !ctrl) {
if (!selectionModeEnabled && selectedFaces.length == 0 && !ctrl) {
return;
}
@ -181,7 +194,7 @@ export default function FaceLibrary() {
setSelectedFaces(newSelectedFaces);
},
[selectedFaces, setSelectedFaces],
[selectionModeEnabled, selectedFaces, setSelectedFaces],
);
const [deleteDialogOpen, setDeleteDialogOpen] = useState<{
@ -466,6 +479,19 @@ export default function FaceLibrary() {
</Button>
</FaceSelectionDialog>
)}
<Button
className="flex gap-2"
variant={selectionModeEnabled ? "select" : "default"}
onClick={toggleSelectionMode}
>
<LuListChecks className="size-7 rounded-md p-1 text-secondary-foreground" />
{isDesktop &&
t(
selectionModeEnabled
? "button.disableSelection"
: "button.enableSelection",
)}
</Button>
<Button
className="flex gap-2"
onClick={() =>
@ -478,6 +504,19 @@ export default function FaceLibrary() {
</div>
) : (
<div className="flex items-center justify-center gap-2">
<Button
className="flex gap-2"
variant={selectionModeEnabled ? "select" : "default"}
onClick={toggleSelectionMode}
>
<LuListChecks className="size-7 rounded-md p-1 text-secondary-foreground" />
{isDesktop &&
t(
selectionModeEnabled
? "button.disableSelection"
: "button.enableSelection",
)}
</Button>
<Button className="flex gap-2" onClick={() => setAddFace(true)}>
<LuScanFace className="size-7 rounded-md p-1 text-secondary-foreground" />
{isDesktop && t("button.addFace")}
@ -505,6 +544,7 @@ export default function FaceLibrary() {
attemptImages={trainImages}
faceNames={faces}
selectedFaces={selectedFaces}
showSelectionCheckboxes={selectionModeEnabled}
onClickFaces={onClickFaces}
onRefresh={refreshFaces}
/>
@ -514,6 +554,7 @@ export default function FaceLibrary() {
faceImages={faceImages}
pageToggle={pageToggle}
selectedFaces={selectedFaces}
showSelectionCheckboxes={selectionModeEnabled}
onClickFaces={onClickFaces}
onDelete={onDelete}
/>
@ -721,6 +762,7 @@ type TrainingGridProps = {
attemptImages: string[];
faceNames: string[];
selectedFaces: string[];
showSelectionCheckboxes: boolean;
onClickFaces: (images: string[], ctrl: boolean) => void;
onRefresh: (
data?:
@ -738,6 +780,7 @@ function TrainingGrid({
attemptImages,
faceNames,
selectedFaces,
showSelectionCheckboxes,
onClickFaces,
onRefresh,
}: TrainingGridProps) {
@ -817,6 +860,7 @@ function TrainingGrid({
event={event}
faceNames={faceNames}
selectedFaces={selectedFaces}
showSelectionCheckboxes={showSelectionCheckboxes}
onClickFaces={onClickFaces}
onRefresh={onRefresh}
/>
@ -833,6 +877,7 @@ type FaceAttemptGroupProps = {
event?: Event;
faceNames: string[];
selectedFaces: string[];
showSelectionCheckboxes: boolean;
onClickFaces: (image: string[], ctrl: boolean) => void;
onRefresh: (
data?:
@ -850,6 +895,7 @@ function FaceAttemptGroup({
event,
faceNames,
selectedFaces,
showSelectionCheckboxes,
onClickFaces,
onRefresh,
}: FaceAttemptGroupProps) {
@ -999,6 +1045,29 @@ function FaceAttemptGroup({
};
}, [event]);
const toggleGroupSelection = useCallback(() => {
const selectedCount = group.filter((face) =>
selectedFaces.includes(face.filename),
).length;
const allSelected = selectedCount === group.length;
if (allSelected) {
onClickFaces(
group
.filter((face) => selectedFaces.includes(face.filename))
.map((face) => face.filename),
false,
);
} else {
onClickFaces(
group
.filter((face) => !selectedFaces.includes(face.filename))
.map((face) => face.filename),
true,
);
}
}, [group, onClickFaces, selectedFaces]);
return (
<GroupedClassificationCard
group={group}
@ -1008,6 +1077,24 @@ function FaceAttemptGroup({
i18nLibrary="views/faceLibrary"
objectType="person"
noClassificationLabel="details.unknown"
topLeftContent={
showSelectionCheckboxes ? (
<div className="rounded bg-black/60 p-1">
<Checkbox
checked={
group.filter((face) => selectedFaces.includes(face.filename))
.length === group.length
? true
: group.some((face) => selectedFaces.includes(face.filename))
? "indeterminate"
: false
}
onCheckedChange={toggleGroupSelection}
aria-label={t("button.selectGroup")}
/>
</div>
) : undefined
}
onClick={(data) => {
if (data) {
onClickFaces([data.filename], true);
@ -1045,6 +1132,7 @@ type FaceGridProps = {
faceImages: string[];
pageToggle: string;
selectedFaces: string[];
showSelectionCheckboxes: boolean;
onClickFaces: (images: string[], ctrl: boolean) => void;
onDelete: (name: string, ids: string[]) => void;
};
@ -1053,6 +1141,7 @@ function FaceGrid({
faceImages,
pageToggle,
selectedFaces,
showSelectionCheckboxes,
onClickFaces,
onDelete,
}: FaceGridProps) {
@ -1088,9 +1177,22 @@ function FaceGrid({
filepath: `clips/faces/${pageToggle}/${image}`,
}}
selected={selectedFaces.includes(image)}
clickable={selectedFaces.length > 0}
clickable={selectedFaces.length > 0 || showSelectionCheckboxes}
i18nLibrary="views/faceLibrary"
onClick={(data, meta) => onClickFaces([data.filename], meta)}
topLeftContent={
showSelectionCheckboxes ? (
<div className="rounded bg-black/60 p-1">
<Checkbox
checked={selectedFaces.includes(image)}
onCheckedChange={() => onClickFaces([image], true)}
aria-label={t("button.selectImage")}
/>
</div>
) : undefined
}
onClick={(data, meta) =>
onClickFaces([data.filename], meta || showSelectionCheckboxes)
}
>
<Tooltip>
<TooltipTrigger>

View File

@ -46,7 +46,7 @@ import {
} from "react";
import { isDesktop, isMobileOnly } from "react-device-detect";
import { Trans, useTranslation } from "react-i18next";
import { LuPencil, LuTrash2 } from "react-icons/lu";
import { LuListChecks, LuPencil, LuTrash2 } from "react-icons/lu";
import { toast } from "sonner";
import useSWR from "swr";
import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
@ -76,6 +76,7 @@ import SearchDetailDialog, {
import { SearchResult } from "@/types/search";
import { HiSparkles } from "react-icons/hi";
import { capitalizeFirstLetter } from "@/utils/stringUtil";
import { Checkbox } from "@/components/ui/checkbox";
type ModelTrainingViewProps = {
model: CustomClassificationModelConfig;
@ -150,10 +151,21 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
// image multiselect
const [selectedImages, setSelectedImages] = useState<string[]>([]);
const [selectionModeEnabled, setSelectionModeEnabled] = useState(false);
const toggleSelectionMode = useCallback(() => {
setSelectionModeEnabled((prev) => {
const next = !prev;
if (!next) {
setSelectedImages([]);
}
return next;
});
}, []);
const onClickImages = useCallback(
(images: string[], ctrl: boolean) => {
if (selectedImages.length == 0 && !ctrl) {
if (!selectionModeEnabled && selectedImages.length == 0 && !ctrl) {
return;
}
@ -179,7 +191,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
setSelectedImages(newSelectedImages);
},
[selectedImages, setSelectedImages],
[selectionModeEnabled, selectedImages, setSelectedImages],
);
// actions
@ -525,6 +537,19 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
</Button>
</ClassificationSelectionDialog>
)}
<Button
className="flex gap-2"
variant={selectionModeEnabled ? "select" : "default"}
onClick={toggleSelectionMode}
>
<LuListChecks className="size-7 rounded-md p-1 text-secondary-foreground" />
{isDesktop &&
t(
selectionModeEnabled
? "button.disableSelection"
: "button.enableSelection",
)}
</Button>
<Button
className="flex gap-2"
onClick={() => setDeleteDialogOpen(selectedImages)}
@ -535,6 +560,19 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
</div>
) : (
<div className="flex flex-row gap-2">
<Button
className="flex gap-2"
variant={selectionModeEnabled ? "select" : "default"}
onClick={toggleSelectionMode}
>
<LuListChecks className="size-7 rounded-md p-1 text-secondary-foreground" />
{isDesktop &&
t(
selectionModeEnabled
? "button.disableSelection"
: "button.enableSelection",
)}
</Button>
<TrainFilterDialog
filter={trainFilter}
filterValues={{ classes: Object.keys(dataset || {}) }}
@ -593,6 +631,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
trainImages={trainImages || []}
trainFilter={trainFilter}
selectedImages={selectedImages}
showSelectionCheckboxes={selectionModeEnabled}
onRefresh={refreshAll}
onClickImages={onClickImages}
onDelete={onDelete}
@ -604,6 +643,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
categoryName={pageToggle}
images={dataset?.[pageToggle] || []}
selectedImages={selectedImages}
showSelectionCheckboxes={selectionModeEnabled}
onClickImages={onClickImages}
onDelete={onDelete}
/>
@ -839,6 +879,7 @@ type DatasetGridProps = {
categoryName: string;
images: string[];
selectedImages: string[];
showSelectionCheckboxes: boolean;
onClickImages: (images: string[], ctrl: boolean) => void;
onDelete: (ids: string[]) => void;
};
@ -848,6 +889,7 @@ function DatasetGrid({
categoryName,
images,
selectedImages,
showSelectionCheckboxes,
onClickImages,
onDelete,
}: DatasetGridProps) {
@ -872,10 +914,23 @@ function DatasetGrid({
name: "",
}}
showArea={false}
clickable={selectedImages.length > 0}
clickable={selectedImages.length > 0 || showSelectionCheckboxes}
selected={selectedImages.includes(image)}
i18nLibrary="views/classificationModel"
onClick={(data, _) => onClickImages([data.filename], true)}
topLeftContent={
showSelectionCheckboxes ? (
<div className="rounded bg-black/60 p-1">
<Checkbox
checked={selectedImages.includes(image)}
onCheckedChange={() => onClickImages([image], true)}
aria-label={t("button.selectImage")}
/>
</div>
) : undefined
}
onClick={(data, meta) =>
onClickImages([data.filename], meta || showSelectionCheckboxes)
}
>
<Tooltip>
<TooltipTrigger>
@ -905,6 +960,7 @@ type TrainGridProps = {
trainImages: string[];
trainFilter?: TrainFilter;
selectedImages: string[];
showSelectionCheckboxes: boolean;
onClickImages: (images: string[], ctrl: boolean) => void;
onRefresh: () => void;
onDelete: (ids: string[]) => void;
@ -916,6 +972,7 @@ function TrainGrid({
trainImages,
trainFilter,
selectedImages,
showSelectionCheckboxes,
onClickImages,
onRefresh,
onDelete,
@ -972,6 +1029,7 @@ function TrainGrid({
classes={classes}
trainData={trainData}
selectedImages={selectedImages}
showSelectionCheckboxes={showSelectionCheckboxes}
onClickImages={onClickImages}
onRefresh={onRefresh}
onDelete={onDelete}
@ -986,6 +1044,7 @@ function TrainGrid({
classes={classes}
trainData={trainData}
selectedImages={selectedImages}
showSelectionCheckboxes={showSelectionCheckboxes}
onClickImages={onClickImages}
onRefresh={onRefresh}
/>
@ -998,6 +1057,7 @@ type StateTrainGridProps = {
classes: string[];
trainData?: ClassificationItemData[];
selectedImages: string[];
showSelectionCheckboxes: boolean;
onClickImages: (images: string[], ctrl: boolean) => void;
onRefresh: () => void;
onDelete: (ids: string[]) => void;
@ -1008,9 +1068,12 @@ function StateTrainGrid({
classes,
trainData,
selectedImages,
showSelectionCheckboxes,
onClickImages,
onRefresh,
}: StateTrainGridProps) {
const { t } = useTranslation(["views/classificationModel"]);
const threshold = useMemo(() => {
return {
recognition: model.threshold,
@ -1031,15 +1094,29 @@ function StateTrainGrid({
data={data}
threshold={threshold}
selected={selectedImages.includes(data.filename)}
clickable={selectedImages.length > 0}
clickable={selectedImages.length > 0 || showSelectionCheckboxes}
i18nLibrary="views/classificationModel"
showArea={false}
onClick={(data, meta) => onClickImages([data.filename], meta)}
topLeftContent={
showSelectionCheckboxes ? (
<div className="rounded bg-black/60 p-1">
<Checkbox
checked={selectedImages.includes(data.filename)}
onCheckedChange={() => onClickImages([data.filename], true)}
aria-label={t("button.selectImage")}
/>
</div>
) : undefined
}
onClick={(data, meta) =>
onClickImages([data.filename], meta || showSelectionCheckboxes)
}
>
<ClassificationSelectionDialog
classes={classes}
modelName={model.name}
image={data.filename}
images={selectedImages}
onRefresh={onRefresh}
>
<BlurredIconButton>
@ -1059,6 +1136,7 @@ type ObjectTrainGridProps = {
classes: string[];
trainData?: ClassificationItemData[];
selectedImages: string[];
showSelectionCheckboxes: boolean;
onClickImages: (images: string[], ctrl: boolean) => void;
onRefresh: () => void;
};
@ -1068,9 +1146,12 @@ function ObjectTrainGrid({
classes,
trainData,
selectedImages,
showSelectionCheckboxes,
onClickImages,
onRefresh,
}: ObjectTrainGridProps) {
const { t } = useTranslation(["views/classificationModel"]);
// item data
const groups = useMemo(() => {
@ -1172,6 +1253,32 @@ function ObjectTrainGrid({
[selectedImages, onClickImages],
);
const toggleGroupSelection = useCallback(
(group: ClassificationItemData[]) => {
const selectedCount = group.filter((item) =>
selectedImages.includes(item.filename),
).length;
const allSelected = selectedCount === group.length;
if (allSelected) {
onClickImages(
group
.filter((item) => selectedImages.includes(item.filename))
.map((item) => item.filename),
false,
);
} else {
onClickImages(
group
.filter((item) => !selectedImages.includes(item.filename))
.map((item) => item.filename),
true,
);
}
},
[onClickImages, selectedImages],
);
return (
<>
<SearchDetailDialog
@ -1205,6 +1312,27 @@ function ObjectTrainGrid({
i18nLibrary="views/classificationModel"
objectType={model.object_config?.objects?.at(0) ?? "Object"}
noClassificationLabel="details.none"
topLeftContent={
showSelectionCheckboxes ? (
<div className="rounded bg-black/60 p-1">
<Checkbox
checked={
group.filter((item) =>
selectedImages.includes(item.filename),
).length === group.length
? true
: group.some((item) =>
selectedImages.includes(item.filename),
)
? "indeterminate"
: false
}
onCheckedChange={() => toggleGroupSelection(group)}
aria-label={t("button.selectGroup")}
/>
</div>
) : undefined
}
onClick={(data) => {
if (data) {
onClickImages([data.filename], true);
@ -1219,6 +1347,7 @@ function ObjectTrainGrid({
classes={classes}
modelName={model.name}
image={data.filename}
images={selectedImages}
onRefresh={onRefresh}
>
<BlurredIconButton>