Handle attribute based models

This commit is contained in:
Nicolas Mowen 2025-12-17 15:09:28 -07:00
parent b6ea89b820
commit 2b2ba6aee3
5 changed files with 117 additions and 44 deletions

View File

@ -4,8 +4,8 @@ import { cn } from "@/lib/utils";
import { import {
ClassificationItemData, ClassificationItemData,
ClassificationThreshold, ClassificationThreshold,
ClassifiedEvent,
} from "@/types/classification"; } from "@/types/classification";
import { Event } from "@/types/event";
import { forwardRef, useMemo, useRef, useState } from "react"; import { forwardRef, useMemo, useRef, useState } from "react";
import { isDesktop, isIOS, isMobile, isMobileOnly } from "react-device-detect"; import { isDesktop, isIOS, isMobile, isMobileOnly } from "react-device-detect";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
@ -190,7 +190,7 @@ export const ClassificationCard = forwardRef<
type GroupedClassificationCardProps = { type GroupedClassificationCardProps = {
group: ClassificationItemData[]; group: ClassificationItemData[];
event?: Event; classifiedEvent?: ClassifiedEvent;
threshold?: ClassificationThreshold; threshold?: ClassificationThreshold;
selectedItems: string[]; selectedItems: string[];
i18nLibrary: string; i18nLibrary: string;
@ -201,7 +201,7 @@ type GroupedClassificationCardProps = {
}; };
export function GroupedClassificationCard({ export function GroupedClassificationCard({
group, group,
event, classifiedEvent,
threshold, threshold,
selectedItems, selectedItems,
i18nLibrary, i18nLibrary,
@ -236,14 +236,15 @@ export function GroupedClassificationCard({
const bestTyped: ClassificationItemData = best; const bestTyped: ClassificationItemData = best;
return { return {
...bestTyped, ...bestTyped,
name: event name:
? event.sub_label && event.sub_label !== "none" classifiedEvent?.label && classifiedEvent.label !== "none"
? event.sub_label ? classifiedEvent.label
: t(noClassificationLabel) : classifiedEvent
: bestTyped.name, ? t(noClassificationLabel)
score: event?.data?.sub_label_score, : bestTyped.name,
score: classifiedEvent?.score,
}; };
}, [group, event, noClassificationLabel, t]); }, [group, classifiedEvent, noClassificationLabel, t]);
const bestScoreStatus = useMemo(() => { const bestScoreStatus = useMemo(() => {
if (!bestItem?.score || !threshold) { if (!bestItem?.score || !threshold) {
@ -329,36 +330,38 @@ export function GroupedClassificationCard({
)} )}
> >
<ContentTitle className="flex items-center gap-2 font-normal capitalize"> <ContentTitle className="flex items-center gap-2 font-normal capitalize">
{event?.sub_label && event.sub_label !== "none" {classifiedEvent?.label && classifiedEvent.label !== "none"
? event.sub_label ? classifiedEvent.label
: t(noClassificationLabel)} : t(noClassificationLabel)}
{event?.sub_label && event.sub_label !== "none" && ( {classifiedEvent?.label &&
<div className="flex items-center gap-1"> classifiedEvent.label !== "none" &&
<div classifiedEvent.score !== undefined && (
className={cn( <div className="flex items-center gap-1">
"", <div
bestScoreStatus == "match" && "text-success", className={cn(
bestScoreStatus == "potential" && "text-orange-400", "",
bestScoreStatus == "unknown" && "text-danger", bestScoreStatus == "match" && "text-success",
)} bestScoreStatus == "potential" && "text-orange-400",
>{`${Math.round((event.data.sub_label_score || 0) * 100)}%`}</div> bestScoreStatus == "unknown" && "text-danger",
<Popover> )}
<PopoverTrigger asChild> >{`${Math.round((classifiedEvent.score || 0) * 100)}%`}</div>
<button <Popover>
className="focus:outline-none" <PopoverTrigger asChild>
aria-label={t("details.scoreInfo", { <button
ns: i18nLibrary, className="focus:outline-none"
})} aria-label={t("details.scoreInfo", {
> ns: i18nLibrary,
<LuInfo className="size-3" /> })}
</button> >
</PopoverTrigger> <LuInfo className="size-3" />
<PopoverContent className="w-80 text-sm"> </button>
{t("details.scoreInfo", { ns: i18nLibrary })} </PopoverTrigger>
</PopoverContent> <PopoverContent className="w-80 text-sm">
</Popover> {t("details.scoreInfo", { ns: i18nLibrary })}
</div> </PopoverContent>
)} </Popover>
</div>
)}
</ContentTitle> </ContentTitle>
<ContentDescription className={cn("", isMobile && "px-2")}> <ContentDescription className={cn("", isMobile && "px-2")}>
{time && ( {time && (
@ -372,14 +375,14 @@ export function GroupedClassificationCard({
</div> </div>
{isDesktop && ( {isDesktop && (
<div className="flex flex-row justify-between"> <div className="flex flex-row justify-between">
{event && ( {classifiedEvent && (
<Tooltip> <Tooltip>
<TooltipTrigger asChild> <TooltipTrigger asChild>
<div <div
className="cursor-pointer" className="cursor-pointer"
tabIndex={-1} tabIndex={-1}
onClick={() => { onClick={() => {
navigate(`/explore?event_id=${event.id}`); navigate(`/explore?event_id=${classifiedEvent.id}`);
}} }}
> >
<LuSearch className="size-4 text-secondary-foreground" /> <LuSearch className="size-4 text-secondary-foreground" />

View File

@ -68,7 +68,10 @@ import {
ClassificationCard, ClassificationCard,
GroupedClassificationCard, GroupedClassificationCard,
} from "@/components/card/ClassificationCard"; } from "@/components/card/ClassificationCard";
import { ClassificationItemData } from "@/types/classification"; import {
ClassificationItemData,
ClassifiedEvent,
} from "@/types/classification";
export default function FaceLibrary() { export default function FaceLibrary() {
const { t } = useTranslation(["views/faceLibrary"]); const { t } = useTranslation(["views/faceLibrary"]);
@ -922,10 +925,22 @@ function FaceAttemptGroup({
[onRefresh, t], [onRefresh, t],
); );
// Create ClassifiedEvent from Event (face recognition uses sub_label)
const classifiedEvent: ClassifiedEvent | undefined = useMemo(() => {
if (!event || !event.sub_label || event.sub_label === "none") {
return undefined;
}
return {
id: event.id,
label: event.sub_label,
score: event.data?.sub_label_score,
};
}, [event]);
return ( return (
<GroupedClassificationCard <GroupedClassificationCard
group={group} group={group}
event={event} classifiedEvent={classifiedEvent}
threshold={threshold} threshold={threshold}
selectedItems={selectedFaces} selectedItems={selectedFaces}
i18nLibrary="views/faceLibrary" i18nLibrary="views/faceLibrary"

View File

@ -21,6 +21,12 @@ export type ClassificationThreshold = {
unknown: number; unknown: number;
}; };
export type ClassifiedEvent = {
id: string;
label?: string;
score?: number;
};
export type ClassificationDatasetResponse = { export type ClassificationDatasetResponse = {
categories: { categories: {
[id: string]: string[]; [id: string]: string[];

View File

@ -24,5 +24,12 @@ export interface Event {
type: "object" | "audio" | "manual"; type: "object" | "audio" | "manual";
recognized_license_plate?: string; recognized_license_plate?: string;
path_data: [number[], number][]; path_data: [number[], number][];
// Allow arbitrary keys for attributes (e.g., model_name, model_name_score)
[key: string]:
| number
| number[]
| string
| [number[], number][]
| undefined;
}; };
} }

View File

@ -62,6 +62,7 @@ import useApiFilter from "@/hooks/use-api-filter";
import { import {
ClassificationDatasetResponse, ClassificationDatasetResponse,
ClassificationItemData, ClassificationItemData,
ClassifiedEvent,
TrainFilter, TrainFilter,
} from "@/types/classification"; } from "@/types/classification";
import { import {
@ -1033,6 +1034,45 @@ function ObjectTrainGrid({
}; };
}, [model]); }, [model]);
// Helper function to create ClassifiedEvent from Event
const createClassifiedEvent = useCallback(
(event: Event | undefined): ClassifiedEvent | undefined => {
if (!event || !model.object_config) {
return undefined;
}
const classificationType = model.object_config.classification_type;
if (classificationType === "attribute") {
// For attribute type, look at event.data[model.name]
const attributeValue = event.data[model.name] as string | undefined;
const attributeScore = event.data[`${model.name}_score`] as
| number
| undefined;
if (attributeValue && attributeValue !== "none") {
return {
id: event.id,
label: attributeValue,
score: attributeScore,
};
}
} else {
// For sub_label type, use event.sub_label
if (event.sub_label && event.sub_label !== "none") {
return {
id: event.id,
label: event.sub_label,
score: event.data?.sub_label_score,
};
}
}
return undefined;
},
[model],
);
// selection // selection
const [selectedEvent, setSelectedEvent] = useState<Event>(); const [selectedEvent, setSelectedEvent] = useState<Event>();
@ -1095,11 +1135,13 @@ function ObjectTrainGrid({
> >
{Object.entries(groups).map(([key, group]) => { {Object.entries(groups).map(([key, group]) => {
const event = events?.find((ev) => ev.id == key); const event = events?.find((ev) => ev.id == key);
const classifiedEvent = createClassifiedEvent(event);
return ( return (
<div key={key} className="aspect-square w-full"> <div key={key} className="aspect-square w-full">
<GroupedClassificationCard <GroupedClassificationCard
group={group} group={group}
event={event} classifiedEvent={classifiedEvent}
threshold={threshold} threshold={threshold}
selectedItems={selectedImages} selectedItems={selectedImages}
i18nLibrary="views/classificationModel" i18nLibrary="views/classificationModel"