Combine classification objects by event

This commit is contained in:
Nicolas Mowen 2025-10-07 08:28:17 -06:00
parent e6b3b8a693
commit 699a626cf0
2 changed files with 94 additions and 44 deletions

View File

@ -151,6 +151,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
write_classification_attempt(
self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
"none-none",
now,
"unknown",
0.0,
@ -171,6 +172,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
write_classification_attempt(
self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
"none-none",
now,
self.labelmap[best_id],
score,
@ -293,6 +295,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
write_classification_attempt(
self.train_dir,
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
obj_data["id"],
now,
"unknown",
0.0,
@ -314,6 +317,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
write_classification_attempt(
self.train_dir,
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
obj_data["id"],
now,
self.labelmap[best_id],
score,
@ -372,6 +376,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
def write_classification_attempt(
folder: str,
frame: np.ndarray,
event_id: str,
timestamp: float,
label: str,
score: float,
@ -379,7 +384,7 @@ def write_classification_attempt(
if "-" in label:
label = label.replace("-", "_")
file = os.path.join(folder, f"{timestamp}-{label}-{score}.webp")
file = os.path.join(folder, f"{event_id}-{timestamp}-{label}-{score}.webp")
os.makedirs(folder, exist_ok=True)
cv2.imwrite(file, frame)

View File

@ -1,4 +1,3 @@
import { baseUrl } from "@/api/baseUrl";
import TextEntryDialog from "@/components/overlay/dialog/TextEntryDialog";
import { Button, buttonVariants } from "@/components/ui/button";
import {
@ -61,7 +60,11 @@ import { MdAutoFixHigh } from "react-icons/md";
import TrainFilterDialog from "@/components/overlay/dialog/TrainFilterDialog";
import useApiFilter from "@/hooks/use-api-filter";
import { ClassificationItemData, TrainFilter } from "@/types/classification";
import { ClassificationCard } from "@/components/card/ClassificationCard";
import {
ClassificationCard,
GroupedClassificationCard,
} from "@/components/card/ClassificationCard";
import { Event } from "@/types/event";
type ModelTrainingViewProps = {
model: CustomClassificationModelConfig;
@ -687,11 +690,12 @@ function TrainGrid({
trainImages
.map((raw) => {
const parts = raw.replaceAll(".webp", "").split("-");
const rawScore = Number.parseFloat(parts[2]);
const rawScore = Number.parseFloat(parts[4]);
return {
filename: raw,
filepath: `clips/${model.name}/train/${raw}`,
timestamp: Number.parseFloat(parts[0]),
timestamp: Number.parseFloat(parts[2]),
eventId: `${parts[0]}-${parts[1]}`,
name: parts[1],
score: rawScore,
};
@ -740,7 +744,18 @@ function TrainGrid({
);
}
return <div />;
return (
<ObjectTrainGrid
model={model}
contentRef={contentRef}
classes={classes}
trainData={trainData}
selectedImages={selectedImages}
onClickImages={onClickImages}
onRefresh={onRefresh}
onDelete={onDelete}
/>
);
}
type StateTrainGridProps = {
@ -841,6 +856,32 @@ function ObjectTrainGrid({
}: ObjectTrainGridProps) {
const { t } = useTranslation(["views/classificationModel"]);
// item data
const groups = useMemo(() => {
const groups: { [eventId: string]: ClassificationItemData[] } = {};
trainData
?.sort((a, b) => a.eventId!.localeCompare(b.eventId!))
.reverse()
.forEach((data) => {
if (groups[data.eventId!]) {
groups[data.eventId!].push(data);
} else {
groups[data.eventId!] = [data];
}
});
return groups;
}, [trainData]);
const eventIdsQuery = useMemo(() => Object.keys(groups).join(","), [groups]);
const { data: events } = useSWR<Event[]>([
"event_ids",
{ ids: eventIdsQuery },
]);
const threshold = useMemo(() => {
return {
recognition: model.threshold,
@ -851,46 +892,50 @@ function ObjectTrainGrid({
return (
<div
ref={contentRef}
className={cn(
"scrollbar-container flex flex-wrap gap-2 overflow-y-auto p-2",
isMobile && "justify-center",
)}
className="scrollbar-container flex flex-wrap gap-2 overflow-y-scroll p-1"
>
{trainData?.map((data) => (
<ClassificationCard
className="w-60 gap-2 rounded-lg bg-card p-2"
imgClassName="size-auto"
data={data}
threshold={threshold}
selected={selectedImages.includes(data.filename)}
i18nLibrary="views/classificationModel"
showArea={false}
onClick={(data, meta) => onClickImages([data.filename], meta)}
>
<ClassificationSelectionDialog
classes={classes}
modelName={model.name}
image={data.filename}
onRefresh={onRefresh}
{Object.entries(groups).map(([key, group]) => {
const event = events?.find((ev) => ev.id == key);
return (
<GroupedClassificationCard
key={key}
group={group}
event={event}
threshold={threshold}
selectedItems={selectedImages}
i18nLibrary="views/classificationModel"
onClick={() => {}}
onSelectEvent={() => {}}
>
<TbCategoryPlus className="size-5 cursor-pointer text-primary-variant hover:text-primary" />
</ClassificationSelectionDialog>
<Tooltip>
<TooltipTrigger>
<LuTrash2
className="size-5 cursor-pointer text-primary-variant hover:text-primary"
onClick={(e) => {
e.stopPropagation();
onDelete([data.filename]);
}}
/>
</TooltipTrigger>
<TooltipContent>
{t("button.deleteClassificationAttempts")}
</TooltipContent>
</Tooltip>
</ClassificationCard>
))}
{(data) => (
<>
<ClassificationSelectionDialog
classes={classes}
modelName={model.name}
image={data.filename}
onRefresh={onRefresh}
>
<TbCategoryPlus className="size-5 cursor-pointer text-primary-variant hover:text-primary" />
</ClassificationSelectionDialog>
<Tooltip>
<TooltipTrigger>
<LuTrash2
className="size-5 cursor-pointer text-primary-variant hover:text-primary"
onClick={(e) => {
e.stopPropagation();
onDelete([data.filename]);
}}
/>
</TooltipTrigger>
<TooltipContent>
{t("button.deleteClassificationAttempts")}
</TooltipContent>
</Tooltip>
</>
)}
</GroupedClassificationCard>
);
})}
</div>
);
}