Combine classification objects by event

This commit is contained in:
Nicolas Mowen 2025-10-07 08:28:17 -06:00
parent e6b3b8a693
commit 699a626cf0
2 changed files with 94 additions and 44 deletions

View File

@ -151,6 +151,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
write_classification_attempt( write_classification_attempt(
self.train_dir, self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
"none-none",
now, now,
"unknown", "unknown",
0.0, 0.0,
@ -171,6 +172,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
write_classification_attempt( write_classification_attempt(
self.train_dir, self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
"none-none",
now, now,
self.labelmap[best_id], self.labelmap[best_id],
score, score,
@ -293,6 +295,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
write_classification_attempt( write_classification_attempt(
self.train_dir, self.train_dir,
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR), cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
obj_data["id"],
now, now,
"unknown", "unknown",
0.0, 0.0,
@ -314,6 +317,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
write_classification_attempt( write_classification_attempt(
self.train_dir, self.train_dir,
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR), cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
obj_data["id"],
now, now,
self.labelmap[best_id], self.labelmap[best_id],
score, score,
@ -372,6 +376,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
def write_classification_attempt( def write_classification_attempt(
folder: str, folder: str,
frame: np.ndarray, frame: np.ndarray,
event_id: str,
timestamp: float, timestamp: float,
label: str, label: str,
score: float, score: float,
@ -379,7 +384,7 @@ def write_classification_attempt(
if "-" in label: if "-" in label:
label = label.replace("-", "_") label = label.replace("-", "_")
file = os.path.join(folder, f"{timestamp}-{label}-{score}.webp") file = os.path.join(folder, f"{event_id}-{timestamp}-{label}-{score}.webp")
os.makedirs(folder, exist_ok=True) os.makedirs(folder, exist_ok=True)
cv2.imwrite(file, frame) cv2.imwrite(file, frame)

View File

@ -1,4 +1,3 @@
import { baseUrl } from "@/api/baseUrl";
import TextEntryDialog from "@/components/overlay/dialog/TextEntryDialog"; import TextEntryDialog from "@/components/overlay/dialog/TextEntryDialog";
import { Button, buttonVariants } from "@/components/ui/button"; import { Button, buttonVariants } from "@/components/ui/button";
import { import {
@ -61,7 +60,11 @@ import { MdAutoFixHigh } from "react-icons/md";
import TrainFilterDialog from "@/components/overlay/dialog/TrainFilterDialog"; import TrainFilterDialog from "@/components/overlay/dialog/TrainFilterDialog";
import useApiFilter from "@/hooks/use-api-filter"; import useApiFilter from "@/hooks/use-api-filter";
import { ClassificationItemData, TrainFilter } from "@/types/classification"; import { ClassificationItemData, TrainFilter } from "@/types/classification";
import { ClassificationCard } from "@/components/card/ClassificationCard"; import {
ClassificationCard,
GroupedClassificationCard,
} from "@/components/card/ClassificationCard";
import { Event } from "@/types/event";
type ModelTrainingViewProps = { type ModelTrainingViewProps = {
model: CustomClassificationModelConfig; model: CustomClassificationModelConfig;
@ -687,11 +690,12 @@ function TrainGrid({
trainImages trainImages
.map((raw) => { .map((raw) => {
const parts = raw.replaceAll(".webp", "").split("-"); const parts = raw.replaceAll(".webp", "").split("-");
const rawScore = Number.parseFloat(parts[2]); const rawScore = Number.parseFloat(parts[4]);
return { return {
filename: raw, filename: raw,
filepath: `clips/${model.name}/train/${raw}`, filepath: `clips/${model.name}/train/${raw}`,
timestamp: Number.parseFloat(parts[0]), timestamp: Number.parseFloat(parts[2]),
eventId: `${parts[0]}-${parts[1]}`,
name: parts[1], name: parts[1],
score: rawScore, score: rawScore,
}; };
@ -740,7 +744,18 @@ function TrainGrid({
); );
} }
return <div />; return (
<ObjectTrainGrid
model={model}
contentRef={contentRef}
classes={classes}
trainData={trainData}
selectedImages={selectedImages}
onClickImages={onClickImages}
onRefresh={onRefresh}
onDelete={onDelete}
/>
);
} }
type StateTrainGridProps = { type StateTrainGridProps = {
@ -841,6 +856,32 @@ function ObjectTrainGrid({
}: ObjectTrainGridProps) { }: ObjectTrainGridProps) {
const { t } = useTranslation(["views/classificationModel"]); const { t } = useTranslation(["views/classificationModel"]);
// item data
const groups = useMemo(() => {
const groups: { [eventId: string]: ClassificationItemData[] } = {};
trainData
?.sort((a, b) => a.eventId!.localeCompare(b.eventId!))
.reverse()
.forEach((data) => {
if (groups[data.eventId!]) {
groups[data.eventId!].push(data);
} else {
groups[data.eventId!] = [data];
}
});
return groups;
}, [trainData]);
const eventIdsQuery = useMemo(() => Object.keys(groups).join(","), [groups]);
const { data: events } = useSWR<Event[]>([
"event_ids",
{ ids: eventIdsQuery },
]);
const threshold = useMemo(() => { const threshold = useMemo(() => {
return { return {
recognition: model.threshold, recognition: model.threshold,
@ -851,22 +892,23 @@ function ObjectTrainGrid({
return ( return (
<div <div
ref={contentRef} ref={contentRef}
className={cn( className="scrollbar-container flex flex-wrap gap-2 overflow-y-scroll p-1"
"scrollbar-container flex flex-wrap gap-2 overflow-y-auto p-2",
isMobile && "justify-center",
)}
> >
{trainData?.map((data) => ( {Object.entries(groups).map(([key, group]) => {
<ClassificationCard const event = events?.find((ev) => ev.id == key);
className="w-60 gap-2 rounded-lg bg-card p-2" return (
imgClassName="size-auto" <GroupedClassificationCard
data={data} key={key}
group={group}
event={event}
threshold={threshold} threshold={threshold}
selected={selectedImages.includes(data.filename)} selectedItems={selectedImages}
i18nLibrary="views/classificationModel" i18nLibrary="views/classificationModel"
showArea={false} onClick={() => {}}
onClick={(data, meta) => onClickImages([data.filename], meta)} onSelectEvent={() => {}}
> >
{(data) => (
<>
<ClassificationSelectionDialog <ClassificationSelectionDialog
classes={classes} classes={classes}
modelName={model.name} modelName={model.name}
@ -889,8 +931,11 @@ function ObjectTrainGrid({
{t("button.deleteClassificationAttempts")} {t("button.deleteClassificationAttempts")}
</TooltipContent> </TooltipContent>
</Tooltip> </Tooltip>
</ClassificationCard> </>
))} )}
</GroupedClassificationCard>
);
})}
</div> </div>
); );
} }