diff --git a/web/src/components/classification/ClassificationModelEditDialog.tsx b/web/src/components/classification/ClassificationModelEditDialog.tsx index ff80a1a29..c47765d76 100644 --- a/web/src/components/classification/ClassificationModelEditDialog.tsx +++ b/web/src/components/classification/ClassificationModelEditDialog.tsx @@ -28,6 +28,7 @@ import { CustomClassificationModelConfig, FrigateConfig, } from "@/types/frigateConfig"; +import { ClassificationDatasetResponse } from "@/types/classification"; import { getTranslatedLabel } from "@/utils/i18n"; import { zodResolver } from "@hookform/resolvers/zod"; import axios from "axios"; @@ -140,16 +141,19 @@ export default function ClassificationModelEditDialog({ }); // Fetch dataset to get current classes for state models - const { data: dataset } = useSWR<{ - [id: string]: string[]; - }>(isStateModel ? `classification/${model.name}/dataset` : null, { - revalidateOnFocus: false, - }); + const { data: dataset } = useSWR( + isStateModel ? `classification/${model.name}/dataset` : null, + { + revalidateOnFocus: false, + }, + ); // Update form with classes from dataset when loaded useEffect(() => { - if (isStateModel && dataset) { - const classes = Object.keys(dataset).filter((key) => key !== "none"); + if (isStateModel && dataset?.categories) { + const classes = Object.keys(dataset.categories).filter( + (key) => key !== "none", + ); if (classes.length > 0) { (form as ReturnType>).setValue( "classes", diff --git a/web/src/types/classification.ts b/web/src/types/classification.ts index 092021342..10c130459 100644 --- a/web/src/types/classification.ts +++ b/web/src/types/classification.ts @@ -20,3 +20,17 @@ export type ClassificationThreshold = { recognition: number; unknown: number; }; + +export type ClassificationDatasetResponse = { + categories: { + [id: string]: string[]; + }; + training_metadata: { + has_trained: boolean; + last_training_date: string | null; + last_training_image_count: number; + current_image_count: number; + new_images_count: number; + dataset_changed: boolean; + } | null; +}; diff --git a/web/src/views/classification/ModelSelectionView.tsx b/web/src/views/classification/ModelSelectionView.tsx index c5e65e0e5..e72d2b6c1 100644 --- a/web/src/views/classification/ModelSelectionView.tsx +++ b/web/src/views/classification/ModelSelectionView.tsx @@ -11,6 +11,7 @@ import { CustomClassificationModelConfig, FrigateConfig, } from "@/types/frigateConfig"; +import { ClassificationDatasetResponse } from "@/types/classification"; import { useCallback, useEffect, useMemo, useState } from "react"; import { useTranslation } from "react-i18next"; import { FaFolderPlus } from "react-icons/fa"; @@ -209,9 +210,10 @@ type ModelCardProps = { function ModelCard({ config, onClick, onUpdate, onDelete }: ModelCardProps) { const { t } = useTranslation(["views/classificationModel"]); - const { data: dataset } = useSWR<{ - [id: string]: string[]; - }>(`classification/${config.name}/dataset`, { revalidateOnFocus: false }); + const { data: dataset } = useSWR( + `classification/${config.name}/dataset`, + { revalidateOnFocus: false }, + ); const [deleteDialogOpen, setDeleteDialogOpen] = useState(false); const [editDialogOpen, setEditDialogOpen] = useState(false); @@ -260,20 +262,25 @@ function ModelCard({ config, onClick, onUpdate, onDelete }: ModelCardProps) { }, []); const coverImage = useMemo(() => { - if (!dataset) { + if (!dataset || !dataset.categories) { return undefined; } - const keys = Object.keys(dataset).filter((key) => key != "none"); - const selectedKey = keys[0]; + const keys = Object.keys(dataset.categories).filter((key) => key != "none"); + if (keys.length === 0) { + return undefined; + } - if (!dataset[selectedKey]) { + const selectedKey = keys[0]; + const images = dataset.categories[selectedKey]; + + if (!images || images.length === 0) { return undefined; } return { name: selectedKey, - img: dataset[selectedKey][0], + img: images[0], }; }, [dataset]); @@ -317,11 +324,19 @@ function ModelCard({ config, onClick, onUpdate, onDelete }: ModelCardProps) { )} onClick={onClick} > - - + {coverImage ? ( + <> + + + + ) : ( +
+ +
+ )}
{config.name}
diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx index b0664534c..53328e0e2 100644 --- a/web/src/views/classification/ModelTrainingView.tsx +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -59,7 +59,11 @@ import { useNavigate } from "react-router-dom"; import { IoMdArrowRoundBack } from "react-icons/io"; import TrainFilterDialog from "@/components/overlay/dialog/TrainFilterDialog"; import useApiFilter from "@/hooks/use-api-filter"; -import { ClassificationItemData, TrainFilter } from "@/types/classification"; +import { + ClassificationDatasetResponse, + ClassificationItemData, + TrainFilter, +} from "@/types/classification"; import { ClassificationCard, GroupedClassificationCard, @@ -118,17 +122,10 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { const { data: trainImages, mutate: refreshTrain } = useSWR( `classification/${model.name}/train`, ); - const { data: datasetResponse, mutate: refreshDataset } = useSWR<{ - categories: { [id: string]: string[] }; - training_metadata: { - has_trained: boolean; - last_training_date: string | null; - last_training_image_count: number; - current_image_count: number; - new_images_count: number; - dataset_changed: boolean; - } | null; - }>(`classification/${model.name}/dataset`); + const { data: datasetResponse, mutate: refreshDataset } = + useSWR( + `classification/${model.name}/dataset`, + ); const dataset = datasetResponse?.categories || {}; const trainingMetadata = datasetResponse?.training_metadata;