From bc9a63d9050865214962a568783ecce8f00c66eb Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Tue, 7 Oct 2025 07:16:59 -0600 Subject: [PATCH] Refactor state training grid to use classification card --- .../components/card/ClassificationCard.tsx | 6 +- .../classification/ModelTrainingView.tsx | 166 +++++++++--------- 2 files changed, 91 insertions(+), 81 deletions(-) diff --git a/web/src/components/card/ClassificationCard.tsx b/web/src/components/card/ClassificationCard.tsx index 235224db6..896ad16f4 100644 --- a/web/src/components/card/ClassificationCard.tsx +++ b/web/src/components/card/ClassificationCard.tsx @@ -16,6 +16,7 @@ type ClassificationCardProps = { threshold?: ClassificationThreshold; selected: boolean; i18nLibrary: string; + showArea?: boolean; onClick: (data: ClassificationItemData, meta: boolean) => void; children?: React.ReactNode; }; @@ -26,6 +27,7 @@ export function ClassificationCard({ threshold, selected, i18nLibrary, + showArea = true, onClick, children, }: ClassificationCardProps) { @@ -55,12 +57,12 @@ export function ClassificationCard({ }); const imageArea = useMemo(() => { - if (imgRef.current == null || !imageLoaded) { + if (!showArea || imgRef.current == null || !imageLoaded) { return undefined; } return imgRef.current.naturalWidth * imgRef.current.naturalHeight; - }, [imageLoaded]); + }, [showArea, imageLoaded]); return ( <> diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx index 68a2bfbbc..afe4a225e 100644 --- a/web/src/views/classification/ModelTrainingView.tsx +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -60,7 +60,7 @@ import { IoMdArrowRoundBack } from "react-icons/io"; import { MdAutoFixHigh } from "react-icons/md"; import TrainFilterDialog from "@/components/overlay/dialog/TrainFilterDialog"; import useApiFilter from "@/hooks/use-api-filter"; -import { TrainFilter } from "@/types/classification"; +import { ClassificationItemData, TrainFilter } from "@/types/classification"; import { ClassificationCard } from "@/components/card/ClassificationCard"; type ModelTrainingViewProps = { @@ -682,20 +682,18 @@ function TrainGrid({ onRefresh, onDelete, }: TrainGridProps) { - const { t } = useTranslation(["views/classificationModel"]); - - const trainData = useMemo( + const trainData = useMemo( () => trainImages .map((raw) => { const parts = raw.replaceAll(".webp", "").split("-"); const rawScore = Number.parseFloat(parts[2]); return { - raw, - timestamp: parts[0], - label: parts[1], - score: rawScore * 100, - truePositive: rawScore >= model.threshold, + filename: raw, + filepath: `clips/${model.name}/train/${raw}`, + timestamp: Number.parseFloat(parts[0]), + name: parts[1], + score: rawScore, }; }) .filter((data) => { @@ -703,10 +701,7 @@ function TrainGrid({ return true; } - if ( - trainFilter.classes && - !trainFilter.classes.includes(data.label) - ) { + if (trainFilter.classes && !trainFilter.classes.includes(data.name)) { return false; } @@ -726,10 +721,57 @@ function TrainGrid({ return true; }) - .sort((a, b) => b.timestamp.localeCompare(a.timestamp)), + .sort((a, b) => b.timestamp - a.timestamp), [model, trainImages, trainFilter], ); + if (model.state_config) { + return ( + + ); + } + + return
; +} + +type StateTrainGridProps = { + model: CustomClassificationModelConfig; + contentRef: MutableRefObject; + classes: string[]; + trainData?: ClassificationItemData[]; + selectedImages: string[]; + onClickImages: (images: string[], ctrl: boolean) => void; + onRefresh: () => void; + onDelete: (ids: string[]) => void; +}; +function StateTrainGrid({ + model, + contentRef, + classes, + trainData, + selectedImages, + onClickImages, + onRefresh, + onDelete, +}: StateTrainGridProps) { + const { t } = useTranslation(["views/classificationModel"]); + + const threshold = useMemo(() => { + return { + recognition: model.threshold, + unknown: model.threshold, + }; + }, [model]); + return (
{trainData?.map((data) => ( -
{ - e.stopPropagation(); - onClickImages([data.raw], e.ctrlKey || e.metaKey); - }} + onClickImages([data.filename], meta)} > -
- -
-
-
-
-
- {data.label.replaceAll("_", " ")} -
-
- {data.score}% -
-
-
- - - - - - { - e.stopPropagation(); - onDelete([data.raw]); - }} - /> - - - {t("button.deleteClassificationAttempts")} - - -
-
-
-
+ + + + + { + e.stopPropagation(); + onDelete([data.filename]); + }} + /> + + + {t("button.deleteClassificationAttempts")} + + + ))}
);