From 699a626cf070d0c2df8edb103c2c9c6d978e61c3 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Tue, 7 Oct 2025 08:28:17 -0600 Subject: [PATCH] Combine classification objects by event --- .../real_time/custom_classification.py | 7 +- .../classification/ModelTrainingView.tsx | 131 ++++++++++++------ 2 files changed, 94 insertions(+), 44 deletions(-) diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 45a4b2223..8d9ea9f57 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -151,6 +151,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): write_classification_attempt( self.train_dir, cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), + "none-none", now, "unknown", 0.0, @@ -171,6 +172,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): write_classification_attempt( self.train_dir, cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), + "none-none", now, self.labelmap[best_id], score, @@ -293,6 +295,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): write_classification_attempt( self.train_dir, cv2.cvtColor(crop, cv2.COLOR_RGB2BGR), + obj_data["id"], now, "unknown", 0.0, @@ -314,6 +317,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): write_classification_attempt( self.train_dir, cv2.cvtColor(crop, cv2.COLOR_RGB2BGR), + obj_data["id"], now, self.labelmap[best_id], score, @@ -372,6 +376,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): def write_classification_attempt( folder: str, frame: np.ndarray, + event_id: str, timestamp: float, label: str, score: float, @@ -379,7 +384,7 @@ def write_classification_attempt( if "-" in label: label = label.replace("-", "_") - file = os.path.join(folder, f"{timestamp}-{label}-{score}.webp") + file = os.path.join(folder, f"{event_id}-{timestamp}-{label}-{score}.webp") os.makedirs(folder, exist_ok=True) cv2.imwrite(file, frame) diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx index 8519ba82c..b82ee7136 100644 --- a/web/src/views/classification/ModelTrainingView.tsx +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -1,4 +1,3 @@ -import { baseUrl } from "@/api/baseUrl"; import TextEntryDialog from "@/components/overlay/dialog/TextEntryDialog"; import { Button, buttonVariants } from "@/components/ui/button"; import { @@ -61,7 +60,11 @@ import { MdAutoFixHigh } from "react-icons/md"; import TrainFilterDialog from "@/components/overlay/dialog/TrainFilterDialog"; import useApiFilter from "@/hooks/use-api-filter"; import { ClassificationItemData, TrainFilter } from "@/types/classification"; -import { ClassificationCard } from "@/components/card/ClassificationCard"; +import { + ClassificationCard, + GroupedClassificationCard, +} from "@/components/card/ClassificationCard"; +import { Event } from "@/types/event"; type ModelTrainingViewProps = { model: CustomClassificationModelConfig; @@ -687,11 +690,12 @@ function TrainGrid({ trainImages .map((raw) => { const parts = raw.replaceAll(".webp", "").split("-"); - const rawScore = Number.parseFloat(parts[2]); + const rawScore = Number.parseFloat(parts[4]); return { filename: raw, filepath: `clips/${model.name}/train/${raw}`, - timestamp: Number.parseFloat(parts[0]), + timestamp: Number.parseFloat(parts[2]), + eventId: `${parts[0]}-${parts[1]}`, name: parts[1], score: rawScore, }; @@ -740,7 +744,18 @@ function TrainGrid({ ); } - return
; + return ( + + ); } type StateTrainGridProps = { @@ -841,6 +856,32 @@ function ObjectTrainGrid({ }: ObjectTrainGridProps) { const { t } = useTranslation(["views/classificationModel"]); + // item data + + const groups = useMemo(() => { + const groups: { [eventId: string]: ClassificationItemData[] } = {}; + + trainData + ?.sort((a, b) => a.eventId!.localeCompare(b.eventId!)) + .reverse() + .forEach((data) => { + if (groups[data.eventId!]) { + groups[data.eventId!].push(data); + } else { + groups[data.eventId!] = [data]; + } + }); + + return groups; + }, [trainData]); + + const eventIdsQuery = useMemo(() => Object.keys(groups).join(","), [groups]); + + const { data: events } = useSWR([ + "event_ids", + { ids: eventIdsQuery }, + ]); + const threshold = useMemo(() => { return { recognition: model.threshold, @@ -851,46 +892,50 @@ function ObjectTrainGrid({ return (
- {trainData?.map((data) => ( - onClickImages([data.filename], meta)} - > - { + const event = events?.find((ev) => ev.id == key); + return ( + {}} + onSelectEvent={() => {}} > - - - - - { - e.stopPropagation(); - onDelete([data.filename]); - }} - /> - - - {t("button.deleteClassificationAttempts")} - - - - ))} + {(data) => ( + <> + + + + + + { + e.stopPropagation(); + onDelete([data.filename]); + }} + /> + + + {t("button.deleteClassificationAttempts")} + + + + )} + + ); + })}
); }