From 2b2ba6aee39976b90378df2838ab4e5c9b61b60f Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Wed, 17 Dec 2025 15:09:28 -0700 Subject: [PATCH] Handle attribute based models --- .../components/card/ClassificationCard.tsx | 85 ++++++++++--------- web/src/pages/FaceLibrary.tsx | 19 ++++- web/src/types/classification.ts | 6 ++ web/src/types/event.ts | 7 ++ .../classification/ModelTrainingView.tsx | 44 +++++++++- 5 files changed, 117 insertions(+), 44 deletions(-) diff --git a/web/src/components/card/ClassificationCard.tsx b/web/src/components/card/ClassificationCard.tsx index be348db05..360bb11bf 100644 --- a/web/src/components/card/ClassificationCard.tsx +++ b/web/src/components/card/ClassificationCard.tsx @@ -4,8 +4,8 @@ import { cn } from "@/lib/utils"; import { ClassificationItemData, ClassificationThreshold, + ClassifiedEvent, } from "@/types/classification"; -import { Event } from "@/types/event"; import { forwardRef, useMemo, useRef, useState } from "react"; import { isDesktop, isIOS, isMobile, isMobileOnly } from "react-device-detect"; import { useTranslation } from "react-i18next"; @@ -190,7 +190,7 @@ export const ClassificationCard = forwardRef< type GroupedClassificationCardProps = { group: ClassificationItemData[]; - event?: Event; + classifiedEvent?: ClassifiedEvent; threshold?: ClassificationThreshold; selectedItems: string[]; i18nLibrary: string; @@ -201,7 +201,7 @@ type GroupedClassificationCardProps = { }; export function GroupedClassificationCard({ group, - event, + classifiedEvent, threshold, selectedItems, i18nLibrary, @@ -236,14 +236,15 @@ export function GroupedClassificationCard({ const bestTyped: ClassificationItemData = best; return { ...bestTyped, - name: event - ? event.sub_label && event.sub_label !== "none" - ? event.sub_label - : t(noClassificationLabel) - : bestTyped.name, - score: event?.data?.sub_label_score, + name: + classifiedEvent?.label && classifiedEvent.label !== "none" + ? classifiedEvent.label + : classifiedEvent + ? t(noClassificationLabel) + : bestTyped.name, + score: classifiedEvent?.score, }; - }, [group, event, noClassificationLabel, t]); + }, [group, classifiedEvent, noClassificationLabel, t]); const bestScoreStatus = useMemo(() => { if (!bestItem?.score || !threshold) { @@ -329,36 +330,38 @@ export function GroupedClassificationCard({ )} > - {event?.sub_label && event.sub_label !== "none" - ? event.sub_label + {classifiedEvent?.label && classifiedEvent.label !== "none" + ? classifiedEvent.label : t(noClassificationLabel)} - {event?.sub_label && event.sub_label !== "none" && ( -
-
{`${Math.round((event.data.sub_label_score || 0) * 100)}%`}
- - - - - - {t("details.scoreInfo", { ns: i18nLibrary })} - - -
- )} + {classifiedEvent?.label && + classifiedEvent.label !== "none" && + classifiedEvent.score !== undefined && ( +
+
{`${Math.round((classifiedEvent.score || 0) * 100)}%`}
+ + + + + + {t("details.scoreInfo", { ns: i18nLibrary })} + + +
+ )}
{time && ( @@ -372,14 +375,14 @@ export function GroupedClassificationCard({ {isDesktop && (
- {event && ( + {classifiedEvent && (
{ - navigate(`/explore?event_id=${event.id}`); + navigate(`/explore?event_id=${classifiedEvent.id}`); }} > diff --git a/web/src/pages/FaceLibrary.tsx b/web/src/pages/FaceLibrary.tsx index 0f281895e..0a2789c00 100644 --- a/web/src/pages/FaceLibrary.tsx +++ b/web/src/pages/FaceLibrary.tsx @@ -68,7 +68,10 @@ import { ClassificationCard, GroupedClassificationCard, } from "@/components/card/ClassificationCard"; -import { ClassificationItemData } from "@/types/classification"; +import { + ClassificationItemData, + ClassifiedEvent, +} from "@/types/classification"; export default function FaceLibrary() { const { t } = useTranslation(["views/faceLibrary"]); @@ -922,10 +925,22 @@ function FaceAttemptGroup({ [onRefresh, t], ); + // Create ClassifiedEvent from Event (face recognition uses sub_label) + const classifiedEvent: ClassifiedEvent | undefined = useMemo(() => { + if (!event || !event.sub_label || event.sub_label === "none") { + return undefined; + } + return { + id: event.id, + label: event.sub_label, + score: event.data?.sub_label_score, + }; + }, [event]); + return ( { + if (!event || !model.object_config) { + return undefined; + } + + const classificationType = model.object_config.classification_type; + + if (classificationType === "attribute") { + // For attribute type, look at event.data[model.name] + const attributeValue = event.data[model.name] as string | undefined; + const attributeScore = event.data[`${model.name}_score`] as + | number + | undefined; + + if (attributeValue && attributeValue !== "none") { + return { + id: event.id, + label: attributeValue, + score: attributeScore, + }; + } + } else { + // For sub_label type, use event.sub_label + if (event.sub_label && event.sub_label !== "none") { + return { + id: event.id, + label: event.sub_label, + score: event.data?.sub_label_score, + }; + } + } + + return undefined; + }, + [model], + ); + // selection const [selectedEvent, setSelectedEvent] = useState(); @@ -1095,11 +1135,13 @@ function ObjectTrainGrid({ > {Object.entries(groups).map(([key, group]) => { const event = events?.find((ev) => ev.id == key); + const classifiedEvent = createClassifiedEvent(event); + return (