import { Button } from "@/components/ui/button";
import {
  Dialog,
  DialogContent,
  DialogDescription,
  DialogHeader,
  DialogTitle,
} from "@/components/ui/dialog";
import {
  Form,
  FormControl,
  FormField,
  FormItem,
  FormLabel,
  FormMessage,
} from "@/components/ui/form";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group";
import {
  Select,
  SelectContent,
  SelectItem,
  SelectTrigger,
  SelectValue,
} from "@/components/ui/select";
import {
  CustomClassificationModelConfig,
  FrigateConfig,
} from "@/types/frigateConfig";
import { ClassificationDatasetResponse } from "@/types/classification";
import { getTranslatedLabel } from "@/utils/i18n";
import { zodResolver } from "@hookform/resolvers/zod";
import axios from "axios";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useForm } from "react-hook-form";
import { useTranslation } from "react-i18next";
import { LuPlus, LuX } from "react-icons/lu";
import { toast } from "sonner";
import useSWR, { mutate } from "swr";
import { z } from "zod";

type ClassificationModelEditDialogProps = {
  open: boolean;
  model: CustomClassificationModelConfig;
  onClose: () => void;
  onSuccess: () => void;
};

type ObjectClassificationType = "sub_label" | "attribute";

type ObjectFormData = {
  objectLabel: string;
  objectType: ObjectClassificationType;
};

type StateFormData = {
  classes: string[];
};

/** Form API narrowed to the state-model field shape. */
type StateFormApi = ReturnType<typeof useForm<StateFormData>>;

/** Form API narrowed to the object-model field shape. */
type ObjectFormApi = ReturnType<typeof useForm<ObjectFormData>>;

/** A single dataset category rename: `[from, to]`. */
type ClassRename = readonly [from: string, to: string];

/** Loose shape of errors thrown by axios (or plain `Error`s). */
type ApiError = {
  response?: { data?: { message?: string; detail?: string } };
  message?: string;
};

/**
 * Extracts a human-readable message from a thrown value.
 *
 * The server-provided `message`/`detail` is preferred over `error.message`.
 * Axios always sets `error.message` to a generic transport string (e.g.
 * "Request failed with status code 400"). Checking it first would hide the
 * server's explanation.
 */
function describeError(err: unknown): string {
  const error = err as ApiError;
  return (
    error.response?.data?.message ||
    error.response?.data?.detail ||
    error.message ||
    "Unknown error"
  );
}

/**
 * Orders renames so no rename targets a name that still exists at the time
 * it runs.
 *
 * This handles chains such as `a -> b, b -> c`, where `b -> c` must run first.
 *
 * @param renames  `[from, to]` pairs. Sources are distinct; targets are
 *   distinct (guaranteed by the form's uniqueness validation).
 * @param existing Category names currently present in the dataset.
 * @returns The renames in a safe sequential execution order.
 * @throws Error when the remaining renames can never become safe. This
 *   happens when a target is held by a class that is not being renamed away,
 *   or when the renames form a cycle (e.g. a swap).
 */
function orderRenames(
  renames: readonly ClassRename[],
  existing: Iterable<string>,
): ClassRename[] {
  const occupied = new Set(existing);
  const ordered: ClassRename[] = [];
  let pending = [...renames];

  while (pending.length > 0) {
    const ready = pending.filter(([, to]) => !occupied.has(to));
    if (ready.length === 0) {
      const blocked = pending
        .map(([from, to]) => `${from} to ${to}`)
        .join(", ");
      throw new Error(`Cannot rename ${blocked}: target name already exists`);
    }
    for (const rename of ready) {
      const [from, to] = rename;
      occupied.delete(from);
      occupied.add(to);
      ordered.push(rename);
    }
    pending = pending.filter((rename) => !ready.includes(rename));
  }

  return ordered;
}

/**
 * Dialog for editing an existing custom classification model.
 *
 * - Object models: edits the tracked object label and the classification
 *   type, and persists them through `PUT /config/set`.
 * - State models: edits the class names loaded from the model's dataset.
 *   A changed name is applied as a dataset category rename. Rows that were
 *   added or removed are not persisted here. When nothing was renamed, an
 *   info toast is shown instead.
 *
 * After a successful save, `onSuccess()` and then `onClose()` are called.
 * Failures are reported with an error toast and the dialog stays open.
 */
export default function ClassificationModelEditDialog({
  open,
  model,
  onClose,
  onSuccess,
}: ClassificationModelEditDialogProps) {
  const { t } = useTranslation(["views/classificationModel"]);
  const { data: config } = useSWR<FrigateConfig>("config");
  const [isSaving, setIsSaving] = useState(false);

  const isStateModel = model.state_config !== undefined;
  const isObjectModel = model.object_config !== undefined;

  // Labels tracked by enabled cameras, excluding attribute labels.
  const objectLabels = useMemo(() => {
    if (!config) return [];

    const labels = new Set<string>();

    Object.values(config.cameras).forEach((cameraConfig) => {
      if (!cameraConfig.enabled || !cameraConfig.enabled_in_config) {
        return;
      }

      cameraConfig.objects.track.forEach((label) => {
        if (!config.model.all_attributes.includes(label)) {
          labels.add(label);
        }
      });
    });

    return [...labels].sort();
  }, [config]);

  // Define form schema based on model type
  const formSchema = useMemo(() => {
    if (isObjectModel) {
      return z.object({
        objectLabel: z
          .string()
          .min(1, t("wizard.step1.errors.objectLabelRequired")),
        objectType: z.enum(["sub_label", "attribute"]),
      });
    } else {
      // State model
      return z.object({
        classes: z
          .array(z.string())
          .min(1, t("wizard.step1.errors.classRequired"))
          .refine(
            (classes) => {
              const nonEmpty = classes.filter((c) => c.trim().length > 0);
              return nonEmpty.length >= 2;
            },
            { message: t("wizard.step1.errors.stateRequiresTwoClasses") },
          )
          .refine(
            (classes) => {
              const nonEmpty = classes.filter((c) => c.trim().length > 0);
              const unique = new Set(nonEmpty.map((c) => c.toLowerCase()));
              return unique.size === nonEmpty.length;
            },
            { message: t("wizard.step1.errors.classesUnique") },
          ),
      });
    }
  }, [isObjectModel, t]);

  const form = useForm({
    resolver: zodResolver(formSchema),
    defaultValues: isObjectModel
      ? ({
          objectLabel: model.object_config?.objects?.[0] || "",
          objectType:
            (model.object_config
              ?.classification_type as ObjectClassificationType) || "sub_label",
        } as ObjectFormData)
      : ({
          classes: [""], // Will be populated from dataset
        } as StateFormData),
    mode: "onChange",
  });

  // Same runtime object as `form`, typed for the state-model fields.
  const stateForm = form as StateFormApi;

  // Fetch dataset to get current classes for state models
  const { data: dataset } = useSWR<ClassificationDatasetResponse>(
    isStateModel ? `classification/${model.name}/dataset` : null,
    {
      revalidateOnFocus: false,
    },
  );

  // Records the dataset class each form row started from (null for rows added
  // in this session). It is kept index-aligned with the `classes` field so
  // renames are keyed by row identity. Matching by position is wrong, because
  // positions shift when rows are removed or left blank.
  const rowOrigins = useRef<(string | null)[]>([]);

  // Populate the form with classes from the dataset whenever the dialog is
  // open and the dataset is available. `open` is a dependency so the rows
  // are restored after a cancel (which resets the form) if this component
  // stays mounted while closed.
  useEffect(() => {
    if (!open || !isStateModel || !dataset?.categories) {
      return;
    }

    const classes = Object.keys(dataset.categories).filter(
      (key) => key !== "none",
    );

    if (classes.length > 0) {
      rowOrigins.current = [...classes];
      stateForm.setValue("classes", classes);
    }
  }, [open, dataset, isStateModel, stateForm]);

  const watchedClasses = isStateModel ? stateForm.watch("classes") : undefined;
  const watchedObjectType = isObjectModel
    ? (form as ObjectFormApi).watch("objectType")
    : undefined;

  const handleAddClass = useCallback(() => {
    const currentClasses = stateForm.getValues("classes");

    // Realign origins with the current rows, then mark the new row as "new".
    rowOrigins.current = [
      ...currentClasses.map((_, i) => rowOrigins.current[i] ?? null),
      null,
    ];
    stateForm.setValue("classes", [...currentClasses, ""], {
      shouldValidate: true,
    });
  }, [stateForm]);

  const handleRemoveClass = useCallback(
    (index: number) => {
      const currentClasses = stateForm.getValues("classes");
      const keepOthers = (_: unknown, i: number) => i !== index;

      const newClasses = currentClasses.filter(keepOthers);
      const newOrigins = currentClasses
        .map((_, i) => rowOrigins.current[i] ?? null)
        .filter(keepOthers);

      // Ensure at least one field remains (even if empty)
      if (newClasses.length === 0) {
        rowOrigins.current = [null];
        stateForm.setValue("classes", [""], { shouldValidate: true });
      } else {
        rowOrigins.current = newOrigins;
        stateForm.setValue("classes", newClasses, { shouldValidate: true });
      }
    },
    [stateForm],
  );

  const onSubmit = useCallback(
    async (data: ObjectFormData | StateFormData) => {
      setIsSaving(true);

      try {
        if (isObjectModel) {
          const objectData = data as ObjectFormData;

          // Update the config
          await axios.put("/config/set", {
            requires_restart: 0,
            update_topic: `config/classification/custom/${model.name}`,
            config_data: {
              classification: {
                custom: {
                  [model.name]: {
                    enabled: model.enabled,
                    name: model.name,
                    threshold: model.threshold,
                    object_config: {
                      objects: [objectData.objectLabel],
                      classification_type: objectData.objectType,
                    },
                  },
                },
              },
            },
          });

          toast.success(t("toast.success.updatedModel"), {
            position: "top-center",
          });
        } else {
          const stateData = data as StateFormData;

          // Pair each non-empty row with the dataset class it started from.
          // New rows (no origin) and removed rows produce no rename.
          const renames: ClassRename[] = [];
          stateData.classes.forEach((newName, i) => {
            const oldName = rowOrigins.current[i];
            if (oldName && newName.trim().length > 0 && oldName !== newName) {
              renames.push([oldName, newName]);
            }
          });

          if (renames.length > 0) {
            // Throws before any request if the renames cannot be applied
            // safely, e.g. a collision with a kept class or a swap.
            const ordered = orderRenames(
              renames,
              Object.keys(dataset?.categories ?? {}),
            );

            try {
              // Sequential on purpose: later renames may target names that
              // earlier ones free up.
              for (const [oldName, newName] of ordered) {
                try {
                  await axios.put(
                    `/classification/${model.name}/dataset/${encodeURIComponent(oldName)}/rename`,
                    {
                      new_category: newName,
                    },
                  );
                } catch (err) {
                  throw new Error(
                    `Failed to rename ${oldName} to ${newName}: ${describeError(err)}`,
                  );
                }
              }
            } finally {
              // Refresh even after a partial failure: some renames may
              // already have been applied on the server.
              await mutate(`classification/${model.name}/dataset`);
            }

            toast.success(t("toast.success.updatedModel"), {
              position: "top-center",
            });
          } else {
            toast.info(t("edit.stateClassesInfo"), {
              position: "top-center",
            });
          }
        }

        onSuccess();
        onClose();
      } catch (err) {
        const errorMessage = describeError(err);
        toast.error(t("toast.error.updateModelFailed", { errorMessage }), {
          position: "top-center",
        });
      } finally {
        setIsSaving(false);
      }
    },
    [isObjectModel, model, dataset, t, onSuccess, onClose],
  );

  const handleCancel = useCallback(() => {
    form.reset();
    // The reset restores the default single empty row, which has no origin.
    rowOrigins.current = [];
    onClose();
  }, [form, onClose]);

  // NOTE(review): the JSX below appears to have had its element tags stripped
  // in this copy of the file. It is reproduced verbatim; restore the markup
  // from the original source.
  return ( !open && handleCancel()}> {t("edit.title")} {isStateModel ? t("edit.descriptionState") : t("edit.descriptionObject")}
{isObjectModel && ( <> ( {t("wizard.step1.objectLabel")} )} /> ( {t("wizard.step1.classificationType")}
)} /> )} {isStateModel && (
{t("wizard.step1.states")}
{watchedClasses?.map((_: string, index: number) => ( >) .control } name={`classes.${index}` as const} render={({ field }) => (
{watchedClasses && watchedClasses.length > 1 && ( )}
)} /> ))}
{isStateModel && "classes" in form.formState.errors && form.formState.errors.classes && (

{form.formState.errors.classes.message}

)}
)}
  );
}