Handle attribute based models

This commit is contained in:
Nicolas Mowen 2025-12-17 15:09:28 -07:00
parent b6ea89b820
commit 2b2ba6aee3
5 changed files with 117 additions and 44 deletions

View File

@ -4,8 +4,8 @@ import { cn } from "@/lib/utils";
import {
ClassificationItemData,
ClassificationThreshold,
ClassifiedEvent,
} from "@/types/classification";
import { Event } from "@/types/event";
import { forwardRef, useMemo, useRef, useState } from "react";
import { isDesktop, isIOS, isMobile, isMobileOnly } from "react-device-detect";
import { useTranslation } from "react-i18next";
@ -190,7 +190,7 @@ export const ClassificationCard = forwardRef<
type GroupedClassificationCardProps = {
group: ClassificationItemData[];
event?: Event;
classifiedEvent?: ClassifiedEvent;
threshold?: ClassificationThreshold;
selectedItems: string[];
i18nLibrary: string;
@ -201,7 +201,7 @@ type GroupedClassificationCardProps = {
};
export function GroupedClassificationCard({
group,
event,
classifiedEvent,
threshold,
selectedItems,
i18nLibrary,
@ -236,14 +236,15 @@ export function GroupedClassificationCard({
const bestTyped: ClassificationItemData = best;
return {
...bestTyped,
name: event
? event.sub_label && event.sub_label !== "none"
? event.sub_label
: t(noClassificationLabel)
: bestTyped.name,
score: event?.data?.sub_label_score,
name:
classifiedEvent?.label && classifiedEvent.label !== "none"
? classifiedEvent.label
: classifiedEvent
? t(noClassificationLabel)
: bestTyped.name,
score: classifiedEvent?.score,
};
}, [group, event, noClassificationLabel, t]);
}, [group, classifiedEvent, noClassificationLabel, t]);
const bestScoreStatus = useMemo(() => {
if (!bestItem?.score || !threshold) {
@ -329,36 +330,38 @@ export function GroupedClassificationCard({
)}
>
<ContentTitle className="flex items-center gap-2 font-normal capitalize">
{event?.sub_label && event.sub_label !== "none"
? event.sub_label
{classifiedEvent?.label && classifiedEvent.label !== "none"
? classifiedEvent.label
: t(noClassificationLabel)}
{event?.sub_label && event.sub_label !== "none" && (
<div className="flex items-center gap-1">
<div
className={cn(
"",
bestScoreStatus == "match" && "text-success",
bestScoreStatus == "potential" && "text-orange-400",
bestScoreStatus == "unknown" && "text-danger",
)}
>{`${Math.round((event.data.sub_label_score || 0) * 100)}%`}</div>
<Popover>
<PopoverTrigger asChild>
<button
className="focus:outline-none"
aria-label={t("details.scoreInfo", {
ns: i18nLibrary,
})}
>
<LuInfo className="size-3" />
</button>
</PopoverTrigger>
<PopoverContent className="w-80 text-sm">
{t("details.scoreInfo", { ns: i18nLibrary })}
</PopoverContent>
</Popover>
</div>
)}
{classifiedEvent?.label &&
classifiedEvent.label !== "none" &&
classifiedEvent.score !== undefined && (
<div className="flex items-center gap-1">
<div
className={cn(
"",
bestScoreStatus == "match" && "text-success",
bestScoreStatus == "potential" && "text-orange-400",
bestScoreStatus == "unknown" && "text-danger",
)}
>{`${Math.round((classifiedEvent.score || 0) * 100)}%`}</div>
<Popover>
<PopoverTrigger asChild>
<button
className="focus:outline-none"
aria-label={t("details.scoreInfo", {
ns: i18nLibrary,
})}
>
<LuInfo className="size-3" />
</button>
</PopoverTrigger>
<PopoverContent className="w-80 text-sm">
{t("details.scoreInfo", { ns: i18nLibrary })}
</PopoverContent>
</Popover>
</div>
)}
</ContentTitle>
<ContentDescription className={cn("", isMobile && "px-2")}>
{time && (
@ -372,14 +375,14 @@ export function GroupedClassificationCard({
</div>
{isDesktop && (
<div className="flex flex-row justify-between">
{event && (
{classifiedEvent && (
<Tooltip>
<TooltipTrigger asChild>
<div
className="cursor-pointer"
tabIndex={-1}
onClick={() => {
navigate(`/explore?event_id=${event.id}`);
navigate(`/explore?event_id=${classifiedEvent.id}`);
}}
>
<LuSearch className="size-4 text-secondary-foreground" />

View File

@ -68,7 +68,10 @@ import {
ClassificationCard,
GroupedClassificationCard,
} from "@/components/card/ClassificationCard";
import { ClassificationItemData } from "@/types/classification";
import {
ClassificationItemData,
ClassifiedEvent,
} from "@/types/classification";
export default function FaceLibrary() {
const { t } = useTranslation(["views/faceLibrary"]);
@ -922,10 +925,22 @@ function FaceAttemptGroup({
[onRefresh, t],
);
// Create ClassifiedEvent from Event (face recognition uses sub_label)
const classifiedEvent: ClassifiedEvent | undefined = useMemo(() => {
if (!event || !event.sub_label || event.sub_label === "none") {
return undefined;
}
return {
id: event.id,
label: event.sub_label,
score: event.data?.sub_label_score,
};
}, [event]);
return (
<GroupedClassificationCard
group={group}
event={event}
classifiedEvent={classifiedEvent}
threshold={threshold}
selectedItems={selectedFaces}
i18nLibrary="views/faceLibrary"

View File

@ -21,6 +21,12 @@ export type ClassificationThreshold = {
unknown: number;
};
export type ClassifiedEvent = {
id: string;
label?: string;
score?: number;
};
export type ClassificationDatasetResponse = {
categories: {
[id: string]: string[];

View File

@ -24,5 +24,12 @@ export interface Event {
type: "object" | "audio" | "manual";
recognized_license_plate?: string;
path_data: [number[], number][];
// Allow arbitrary keys for attributes (e.g., model_name, model_name_score)
[key: string]:
| number
| number[]
| string
| [number[], number][]
| undefined;
};
}

View File

@ -62,6 +62,7 @@ import useApiFilter from "@/hooks/use-api-filter";
import {
ClassificationDatasetResponse,
ClassificationItemData,
ClassifiedEvent,
TrainFilter,
} from "@/types/classification";
import {
@ -1033,6 +1034,45 @@ function ObjectTrainGrid({
};
}, [model]);
// Helper function to create ClassifiedEvent from Event
const createClassifiedEvent = useCallback(
(event: Event | undefined): ClassifiedEvent | undefined => {
if (!event || !model.object_config) {
return undefined;
}
const classificationType = model.object_config.classification_type;
if (classificationType === "attribute") {
// For attribute type, look at event.data[model.name]
const attributeValue = event.data[model.name] as string | undefined;
const attributeScore = event.data[`${model.name}_score`] as
| number
| undefined;
if (attributeValue && attributeValue !== "none") {
return {
id: event.id,
label: attributeValue,
score: attributeScore,
};
}
} else {
// For sub_label type, use event.sub_label
if (event.sub_label && event.sub_label !== "none") {
return {
id: event.id,
label: event.sub_label,
score: event.data?.sub_label_score,
};
}
}
return undefined;
},
[model],
);
// selection
const [selectedEvent, setSelectedEvent] = useState<Event>();
@ -1095,11 +1135,13 @@ function ObjectTrainGrid({
>
{Object.entries(groups).map(([key, group]) => {
const event = events?.find((ev) => ev.id == key);
const classifiedEvent = createClassifiedEvent(event);
return (
<div key={key} className="aspect-square w-full">
<GroupedClassificationCard
group={group}
event={event}
classifiedEvent={classifiedEvent}
threshold={threshold}
selectedItems={selectedImages}
i18nLibrary="views/classificationModel"