From b75173a5020a2139d6664e007e4bdd5791e71ddf Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Wed, 5 Nov 2025 04:59:29 -0700 Subject: [PATCH] Add dialog to edit classification models --- .../locales/en/views/classificationModel.json | 15 +- .../ClassificationModelEditDialog.tsx | 477 ++++++++++++++++++ web/src/types/frigateConfig.ts | 1 + .../classification/ModelSelectionView.tsx | 24 +- 4 files changed, 512 insertions(+), 5 deletions(-) create mode 100644 web/src/components/classification/ClassificationModelEditDialog.tsx diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json index 291b2bdf3..4a873533a 100644 --- a/web/public/locales/en/views/classificationModel.json +++ b/web/public/locales/en/views/classificationModel.json @@ -10,7 +10,8 @@ "deleteImages": "Delete Images", "trainModel": "Train Model", "addClassification": "Add Classification", - "deleteModels": "Delete Models" + "deleteModels": "Delete Models", + "editModel": "Edit Model" }, "toast": { "success": { @@ -19,14 +20,16 @@ "deletedModel": "Successfully deleted {{count}} model(s)", "categorizedImage": "Successfully Classified Image", "trainedModel": "Successfully trained model.", - "trainingModel": "Successfully started model training." + "trainingModel": "Successfully started model training.", + "updatedModel": "Successfully updated model configuration" }, "error": { "deleteImageFailed": "Failed to delete: {{errorMessage}}", "deleteCategoryFailed": "Failed to delete class: {{errorMessage}}", "deleteModelFailed": "Failed to delete model: {{errorMessage}}", "categorizeFailed": "Failed to categorize image: {{errorMessage}}", - "trainingFailed": "Failed to start model training: {{errorMessage}}" + "trainingFailed": "Failed to start model training: {{errorMessage}}", + "updateModelFailed": "Failed to update model: {{errorMessage}}" } }, "deleteCategory": { @@ -38,6 +41,12 @@ "single": "Are you sure you want to delete {{name}}? This will permanently delete all associated data including images and training data. This action cannot be undone.", "desc": "Are you sure you want to delete {{count}} model(s)? This will permanently delete all associated data including images and training data. This action cannot be undone." }, + "edit": { + "title": "Edit Classification Model", + "descriptionState": "Edit the classes for this state classification model. Changes will require retraining the model.", + "descriptionObject": "Edit the object type and classification type for this object classification model.", + "stateClassesInfo": "Note: Changing state classes requires retraining the model with the updated classes." + }, "deleteDatasetImages": { "title": "Delete Dataset Images", "desc": "Are you sure you want to delete {{count}} images from {{dataset}}? This action cannot be undone and will require re-training the model." diff --git a/web/src/components/classification/ClassificationModelEditDialog.tsx b/web/src/components/classification/ClassificationModelEditDialog.tsx new file mode 100644 index 000000000..56a96b654 --- /dev/null +++ b/web/src/components/classification/ClassificationModelEditDialog.tsx @@ -0,0 +1,477 @@ +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from "@/components/ui/form"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { + CustomClassificationModelConfig, + FrigateConfig, +} from "@/types/frigateConfig"; +import { getTranslatedLabel } from "@/utils/i18n"; +import { zodResolver } from "@hookform/resolvers/zod"; +import axios from "axios"; +import { useCallback, useEffect, useMemo, useState } from "react"; +import { useForm } from "react-hook-form"; +import { useTranslation } from "react-i18next"; +import { LuPlus, LuX } from "react-icons/lu"; +import { toast } from "sonner"; +import useSWR from "swr"; +import { z } from "zod"; + +type ClassificationModelEditDialogProps = { + open: boolean; + model: CustomClassificationModelConfig; + onClose: () => void; + onSuccess: () => void; +}; + +type ObjectClassificationType = "sub_label" | "attribute"; + +type ObjectFormData = { + objectLabel: string; + objectType: ObjectClassificationType; +}; + +type StateFormData = { + classes: string[]; +}; + +export default function ClassificationModelEditDialog({ + open, + model, + onClose, + onSuccess, +}: ClassificationModelEditDialogProps) { + const { t } = useTranslation(["views/classificationModel"]); + const { data: config } = useSWR("config"); + const [isSaving, setIsSaving] = useState(false); + + const isStateModel = model.state_config !== undefined; + const isObjectModel = model.object_config !== undefined; + + const objectLabels = useMemo(() => { + if (!config) return []; + + const labels = new Set(); + + Object.values(config.cameras).forEach((cameraConfig) => { + if (!cameraConfig.enabled || !cameraConfig.enabled_in_config) { + return; + } + + cameraConfig.objects.track.forEach((label) => { + if (!config.model.all_attributes.includes(label)) { + labels.add(label); + } + }); + }); + + return [...labels].sort(); + }, [config]); + + // Define form schema based on model type + const formSchema = useMemo(() => { + if (isObjectModel) { + return z.object({ + objectLabel: z + .string() + .min(1, t("wizard.step1.errors.objectLabelRequired")), + objectType: z.enum(["sub_label", "attribute"]), + }); + } else { + // State model + return z.object({ + classes: z + .array(z.string()) + .min(1, t("wizard.step1.errors.classRequired")) + .refine( + (classes) => { + const nonEmpty = classes.filter((c) => c.trim().length > 0); + return nonEmpty.length >= 2; + }, + { message: t("wizard.step1.errors.stateRequiresTwoClasses") }, + ) + .refine( + (classes) => { + const nonEmpty = classes.filter((c) => c.trim().length > 0); + const unique = new Set(nonEmpty.map((c) => c.toLowerCase())); + return unique.size === nonEmpty.length; + }, + { message: t("wizard.step1.errors.classesUnique") }, + ), + }); + } + }, [isObjectModel, t]); + + const form = useForm({ + resolver: zodResolver(formSchema), + defaultValues: isObjectModel + ? ({ + objectLabel: model.object_config?.objects?.[0] || "", + objectType: + (model.object_config + ?.classification_type as ObjectClassificationType) || "sub_label", + } as ObjectFormData) + : ({ + classes: [""], // Will be populated from dataset + } as StateFormData), + mode: "onChange", + }); + + // Fetch dataset to get current classes for state models + const { data: dataset } = useSWR<{ + [id: string]: string[]; + }>(isStateModel ? `classification/${model.name}/dataset` : null, { + revalidateOnFocus: false, + }); + + // Update form with classes from dataset when loaded + useEffect(() => { + if (isStateModel && dataset) { + const classes = Object.keys(dataset).filter((key) => key !== "none"); + if (classes.length > 0) { + (form as ReturnType>).setValue( + "classes", + classes, + ); + } + } + }, [dataset, isStateModel, form]); + + const watchedClasses = isStateModel + ? (form as ReturnType>).watch("classes") + : undefined; + const watchedObjectType = isObjectModel + ? (form as ReturnType>).watch("objectType") + : undefined; + + const handleAddClass = useCallback(() => { + const currentClasses = ( + form as ReturnType> + ).getValues("classes"); + (form as ReturnType>).setValue( + "classes", + [...currentClasses, ""], + { + shouldValidate: true, + }, + ); + }, [form]); + + const handleRemoveClass = useCallback( + (index: number) => { + const currentClasses = ( + form as ReturnType> + ).getValues("classes"); + const newClasses = currentClasses.filter((_, i) => i !== index); + + // Ensure at least one field remains (even if empty) + if (newClasses.length === 0) { + (form as ReturnType>).setValue( + "classes", + [""], + { shouldValidate: true }, + ); + } else { + (form as ReturnType>).setValue( + "classes", + newClasses, + { shouldValidate: true }, + ); + } + }, + [form], + ); + + const onSubmit = useCallback( + async (data: ObjectFormData | StateFormData) => { + setIsSaving(true); + try { + if (isObjectModel) { + const objectData = data as ObjectFormData; + + // Update the config + await axios.put("/config/set", { + requires_restart: 0, + update_topic: `config/classification/custom/${model.name}`, + config_data: { + classification: { + custom: { + [model.name]: { + enabled: model.enabled, + name: model.name, + threshold: model.threshold, + object_config: { + objects: [objectData.objectLabel], + classification_type: objectData.objectType, + }, + }, + }, + }, + }, + }); + + toast.success(t("toast.success.updatedModel"), { + position: "top-center", + }); + } else { + // State model - update classes + // Note: For state models, updating classes requires renaming categories + // which is handled through the dataset API, not the config API + // We'll need to implement this by calling the rename endpoint for each class + // For now, we just show a message that this requires retraining + + toast.info(t("edit.stateClassesInfo"), { + position: "top-center", + }); + } + + onSuccess(); + onClose(); + } catch (err) { + const error = err as { + response?: { data?: { message?: string; detail?: string } }; + }; + const errorMessage = + error.response?.data?.message || + error.response?.data?.detail || + "Unknown error"; + toast.error(t("toast.error.updateModelFailed", { errorMessage }), { + position: "top-center", + }); + } finally { + setIsSaving(false); + } + }, + [isObjectModel, model, t, onSuccess, onClose], + ); + + const handleCancel = useCallback(() => { + form.reset(); + onClose(); + }, [form, onClose]); + + return ( + !open && handleCancel()}> + + + {t("edit.title")} + + {isStateModel + ? t("edit.descriptionState") + : t("edit.descriptionObject")} + + + +
+
+ + {isObjectModel && ( + <> + ( + + + {t("wizard.step1.objectLabel")} + + + + + )} + /> + + ( + + + {t("wizard.step1.classificationType")} + + + +
+ + +
+
+ + +
+
+
+ +
+ )} + /> + + )} + + {isStateModel && ( +
+
+ + {t("wizard.step1.states")} + + +
+
+ {watchedClasses?.map((_: string, index: number) => ( + >) + .control + } + name={`classes.${index}` as const} + render={({ field }) => ( + + +
+ + {watchedClasses && + watchedClasses.length > 1 && ( + + )} +
+
+
+ )} + /> + ))} +
+ {isStateModel && + "classes" in form.formState.errors && + form.formState.errors.classes && ( +

+ {form.formState.errors.classes.message} +

+ )} +
+ )} + +
+ + +
+ + +
+
+
+ ); +} diff --git a/web/src/types/frigateConfig.ts b/web/src/types/frigateConfig.ts index ffe4cc14d..f10563379 100644 --- a/web/src/types/frigateConfig.ts +++ b/web/src/types/frigateConfig.ts @@ -306,6 +306,7 @@ export type CustomClassificationModelConfig = { threshold: number; object_config?: { objects: string[]; + classification_type: string; }; state_config?: { cameras: { diff --git a/web/src/views/classification/ModelSelectionView.tsx b/web/src/views/classification/ModelSelectionView.tsx index b1b462497..c5e65e0e5 100644 --- a/web/src/views/classification/ModelSelectionView.tsx +++ b/web/src/views/classification/ModelSelectionView.tsx @@ -1,5 +1,6 @@ import { baseUrl } from "@/api/baseUrl"; import ClassificationModelWizardDialog from "@/components/classification/ClassificationModelWizardDialog"; +import ClassificationModelEditDialog from "@/components/classification/ClassificationModelEditDialog"; import ActivityIndicator from "@/components/indicators/activity-indicator"; import { ImageShadowOverlay } from "@/components/overlay/ImageShadowOverlay"; import { Button, buttonVariants } from "@/components/ui/button"; @@ -14,7 +15,7 @@ import { useCallback, useEffect, useMemo, useState } from "react"; import { useTranslation } from "react-i18next"; import { FaFolderPlus } from "react-icons/fa"; import { MdModelTraining } from "react-icons/md"; -import { LuTrash2 } from "react-icons/lu"; +import { LuPencil, LuTrash2 } from "react-icons/lu"; import { FiMoreVertical } from "react-icons/fi"; import useSWR from "swr"; import Heading from "@/components/ui/heading"; @@ -163,6 +164,7 @@ export default function ModelSelectionView({ key={config.name} config={config} onClick={() => onClick(config)} + onUpdate={() => refreshConfig()} onDelete={() => refreshConfig()} /> ))} @@ -201,9 +203,10 @@ function NoModelsView({ type ModelCardProps = { config: CustomClassificationModelConfig; onClick: () => void; + onUpdate: () => void; onDelete: () => void; }; -function ModelCard({ config, onClick, onDelete }: ModelCardProps) { +function ModelCard({ config, onClick, onUpdate, onDelete }: ModelCardProps) { const { t } = useTranslation(["views/classificationModel"]); const { data: dataset } = useSWR<{ @@ -211,6 +214,7 @@ function ModelCard({ config, onClick, onDelete }: ModelCardProps) { }>(`classification/${config.name}/dataset`, { revalidateOnFocus: false }); const [deleteDialogOpen, setDeleteDialogOpen] = useState(false); + const [editDialogOpen, setEditDialogOpen] = useState(false); const handleDelete = useCallback(async () => { try { @@ -250,6 +254,11 @@ function ModelCard({ config, onClick, onDelete }: ModelCardProps) { setDeleteDialogOpen(true); }, []); + const handleEditClick = useCallback((e: React.MouseEvent) => { + e.stopPropagation(); + setEditDialogOpen(true); + }, []); + const coverImage = useMemo(() => { if (!dataset) { return undefined; @@ -270,6 +279,13 @@ function ModelCard({ config, onClick, onDelete }: ModelCardProps) { return ( <> + setEditDialogOpen(false)} + onSuccess={() => onUpdate()} + /> + setDeleteDialogOpen(!deleteDialogOpen)} @@ -320,6 +336,10 @@ function ModelCard({ config, onClick, onDelete }: ModelCardProps) { align="end" onClick={(e) => e.stopPropagation()} > + + + {t("button.edit", { ns: "common" })} + {t("button.delete", { ns: "common" })}