diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 65a62b568..e4048b6ec 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -466,6 +466,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): now, self.labelmap[best_id], score, + max_files=200, ) if score < self.model_config.threshold: @@ -529,6 +530,7 @@ def write_classification_attempt( timestamp: float, label: str, score: float, + max_files: int = 100, ) -> None: if "-" in label: label = label.replace("-", "_") @@ -544,5 +546,5 @@ def write_classification_attempt( ) # delete oldest face image if maximum is reached - if len(files) > 100: + if len(files) > max_files: os.unlink(os.path.join(folder, files[-1])) diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json index 59c8e53d4..ebc819e4e 100644 --- a/web/public/locales/en/views/classificationModel.json +++ b/web/public/locales/en/views/classificationModel.json @@ -10,7 +10,8 @@ "deleteImages": "Delete Images", "trainModel": "Train Model", "addClassification": "Add Classification", - "deleteModels": "Delete Models" + "deleteModels": "Delete Models", + "editModel": "Edit Model" }, "toast": { "success": { @@ -20,14 +21,16 @@ "deletedModel_other": "Successfully deleted {{count}} models", "categorizedImage": "Successfully Classified Image", "trainedModel": "Successfully trained model.", - "trainingModel": "Successfully started model training." + "trainingModel": "Successfully started model training.", + "updatedModel": "Successfully updated model configuration" }, "error": { "deleteImageFailed": "Failed to delete: {{errorMessage}}", "deleteCategoryFailed": "Failed to delete class: {{errorMessage}}", "deleteModelFailed": "Failed to delete model: {{errorMessage}}", "categorizeFailed": "Failed to categorize image: {{errorMessage}}", - "trainingFailed": "Failed to start model training: {{errorMessage}}" + "trainingFailed": "Failed to start model training: {{errorMessage}}", + "updateModelFailed": "Failed to update model: {{errorMessage}}" } }, "deleteCategory": { @@ -39,6 +42,12 @@ "single": "Are you sure you want to delete {{name}}? This will permanently delete all associated data including images and training data. This action cannot be undone.", "desc": "Are you sure you want to delete {{count}} model(s)? This will permanently delete all associated data including images and training data. This action cannot be undone." }, + "edit": { + "title": "Edit Classification Model", + "descriptionState": "Edit the classes for this state classification model. Changes will require retraining the model.", + "descriptionObject": "Edit the object type and classification type for this object classification model.", + "stateClassesInfo": "Note: Changing state classes requires retraining the model with the updated classes." + }, "deleteDatasetImages": { "title": "Delete Dataset Images", "desc": "Are you sure you want to delete {{count}} images from {{dataset}}? This action cannot be undone and will require re-training the model." diff --git a/web/src/components/card/ClassificationCard.tsx b/web/src/components/card/ClassificationCard.tsx index 73be455bc..bde452770 100644 --- a/web/src/components/card/ClassificationCard.tsx +++ b/web/src/components/card/ClassificationCard.tsx @@ -7,7 +7,7 @@ import { } from "@/types/classification"; import { Event } from "@/types/event"; import { forwardRef, useMemo, useRef, useState } from "react"; -import { isDesktop, isMobile } from "react-device-detect"; +import { isDesktop, isMobile, isMobileOnly } from "react-device-detect"; import { useTranslation } from "react-i18next"; import TimeAgo from "../dynamic/TimeAgo"; import { Tooltip, TooltipContent, TooltipTrigger } from "../ui/tooltip"; @@ -264,8 +264,8 @@ export function GroupedClassificationCard({ const Overlay = isDesktop ? Dialog : MobilePage; const Trigger = isDesktop ? DialogTrigger : MobilePageTrigger; - const Header = isDesktop ? DialogHeader : MobilePageHeader; const Content = isDesktop ? DialogContent : MobilePageContent; + const Header = isDesktop ? DialogHeader : MobilePageHeader; const ContentTitle = isDesktop ? DialogTitle : MobilePageTitle; const ContentDescription = isDesktop ? DialogDescription @@ -298,9 +298,9 @@ export function GroupedClassificationCard({ e.preventDefault()} > @@ -308,16 +308,16 @@ export function GroupedClassificationCard({ - - + + {event?.sub_label && event.sub_label !== "none" ? event.sub_label : t(noClassificationLabel)} @@ -390,7 +390,7 @@ export function GroupedClassificationCard({ className={cn( "grid w-full auto-rows-min grid-cols-2 gap-2 sm:grid-cols-3 md:grid-cols-4 lg:grid-cols-6 xl:grid-cols-6 2xl:grid-cols-8", isDesktop && "p-2", - isMobile && "scrollbar-container flex-1 overflow-y-auto", + isMobile && "px-4 pb-4", )} > {group.map((data: ClassificationItemData) => ( diff --git a/web/src/components/classification/ClassificationModelEditDialog.tsx b/web/src/components/classification/ClassificationModelEditDialog.tsx new file mode 100644 index 000000000..ff80a1a29 --- /dev/null +++ b/web/src/components/classification/ClassificationModelEditDialog.tsx @@ -0,0 +1,477 @@ +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from "@/components/ui/form"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { + CustomClassificationModelConfig, + FrigateConfig, +} from "@/types/frigateConfig"; +import { getTranslatedLabel } from "@/utils/i18n"; +import { zodResolver } from "@hookform/resolvers/zod"; +import axios from "axios"; +import { useCallback, useEffect, useMemo, useState } from "react"; +import { useForm } from "react-hook-form"; +import { useTranslation } from "react-i18next"; +import { LuPlus, LuX } from "react-icons/lu"; +import { toast } from "sonner"; +import useSWR from "swr"; +import { z } from "zod"; + +type ClassificationModelEditDialogProps = { + open: boolean; + model: CustomClassificationModelConfig; + onClose: () => void; + onSuccess: () => void; +}; + +type ObjectClassificationType = "sub_label" | "attribute"; + +type ObjectFormData = { + objectLabel: string; + objectType: ObjectClassificationType; +}; + +type StateFormData = { + classes: string[]; +}; + +export default function ClassificationModelEditDialog({ + open, + model, + onClose, + onSuccess, +}: ClassificationModelEditDialogProps) { + const { t } = useTranslation(["views/classificationModel"]); + const { data: config } = useSWR("config"); + const [isSaving, setIsSaving] = useState(false); + + const isStateModel = model.state_config !== undefined; + const isObjectModel = model.object_config !== undefined; + + const objectLabels = useMemo(() => { + if (!config) return []; + + const labels = new Set(); + + Object.values(config.cameras).forEach((cameraConfig) => { + if (!cameraConfig.enabled || !cameraConfig.enabled_in_config) { + return; + } + + cameraConfig.objects.track.forEach((label) => { + if (!config.model.all_attributes.includes(label)) { + labels.add(label); + } + }); + }); + + return [...labels].sort(); + }, [config]); + + // Define form schema based on model type + const formSchema = useMemo(() => { + if (isObjectModel) { + return z.object({ + objectLabel: z + .string() + .min(1, t("wizard.step1.errors.objectLabelRequired")), + objectType: z.enum(["sub_label", "attribute"]), + }); + } else { + // State model + return z.object({ + classes: z + .array(z.string()) + .min(1, t("wizard.step1.errors.classRequired")) + .refine( + (classes) => { + const nonEmpty = classes.filter((c) => c.trim().length > 0); + return nonEmpty.length >= 2; + }, + { message: t("wizard.step1.errors.stateRequiresTwoClasses") }, + ) + .refine( + (classes) => { + const nonEmpty = classes.filter((c) => c.trim().length > 0); + const unique = new Set(nonEmpty.map((c) => c.toLowerCase())); + return unique.size === nonEmpty.length; + }, + { message: t("wizard.step1.errors.classesUnique") }, + ), + }); + } + }, [isObjectModel, t]); + + const form = useForm({ + resolver: zodResolver(formSchema), + defaultValues: isObjectModel + ? ({ + objectLabel: model.object_config?.objects?.[0] || "", + objectType: + (model.object_config + ?.classification_type as ObjectClassificationType) || "sub_label", + } as ObjectFormData) + : ({ + classes: [""], // Will be populated from dataset + } as StateFormData), + mode: "onChange", + }); + + // Fetch dataset to get current classes for state models + const { data: dataset } = useSWR<{ + [id: string]: string[]; + }>(isStateModel ? `classification/${model.name}/dataset` : null, { + revalidateOnFocus: false, + }); + + // Update form with classes from dataset when loaded + useEffect(() => { + if (isStateModel && dataset) { + const classes = Object.keys(dataset).filter((key) => key !== "none"); + if (classes.length > 0) { + (form as ReturnType>).setValue( + "classes", + classes, + ); + } + } + }, [dataset, isStateModel, form]); + + const watchedClasses = isStateModel + ? (form as ReturnType>).watch("classes") + : undefined; + const watchedObjectType = isObjectModel + ? (form as ReturnType>).watch("objectType") + : undefined; + + const handleAddClass = useCallback(() => { + const currentClasses = ( + form as ReturnType> + ).getValues("classes"); + (form as ReturnType>).setValue( + "classes", + [...currentClasses, ""], + { + shouldValidate: true, + }, + ); + }, [form]); + + const handleRemoveClass = useCallback( + (index: number) => { + const currentClasses = ( + form as ReturnType> + ).getValues("classes"); + const newClasses = currentClasses.filter((_, i) => i !== index); + + // Ensure at least one field remains (even if empty) + if (newClasses.length === 0) { + (form as ReturnType>).setValue( + "classes", + [""], + { shouldValidate: true }, + ); + } else { + (form as ReturnType>).setValue( + "classes", + newClasses, + { shouldValidate: true }, + ); + } + }, + [form], + ); + + const onSubmit = useCallback( + async (data: ObjectFormData | StateFormData) => { + setIsSaving(true); + try { + if (isObjectModel) { + const objectData = data as ObjectFormData; + + // Update the config + await axios.put("/config/set", { + requires_restart: 0, + update_topic: `config/classification/custom/${model.name}`, + config_data: { + classification: { + custom: { + [model.name]: { + enabled: model.enabled, + name: model.name, + threshold: model.threshold, + object_config: { + objects: [objectData.objectLabel], + classification_type: objectData.objectType, + }, + }, + }, + }, + }, + }); + + toast.success(t("toast.success.updatedModel"), { + position: "top-center", + }); + } else { + // State model - update classes + // Note: For state models, updating classes requires renaming categories + // which is handled through the dataset API, not the config API + // We'll need to implement this by calling the rename endpoint for each class + // For now, we just show a message that this requires retraining + + toast.info(t("edit.stateClassesInfo"), { + position: "top-center", + }); + } + + onSuccess(); + onClose(); + } catch (err) { + const error = err as { + response?: { data?: { message?: string; detail?: string } }; + }; + const errorMessage = + error.response?.data?.message || + error.response?.data?.detail || + "Unknown error"; + toast.error(t("toast.error.updateModelFailed", { errorMessage }), { + position: "top-center", + }); + } finally { + setIsSaving(false); + } + }, + [isObjectModel, model, t, onSuccess, onClose], + ); + + const handleCancel = useCallback(() => { + form.reset(); + onClose(); + }, [form, onClose]); + + return ( + !open && handleCancel()}> + + + {t("edit.title")} + + {isStateModel + ? t("edit.descriptionState") + : t("edit.descriptionObject")} + + + + + + + {isObjectModel && ( + <> + ( + + + {t("wizard.step1.objectLabel")} + + + + + + + + + {objectLabels.map((label) => ( + + {getTranslatedLabel(label)} + + ))} + + + + + )} + /> + + ( + + + {t("wizard.step1.classificationType")} + + + + + + + {t("wizard.step1.classificationSubLabel")} + + + + + + {t("wizard.step1.classificationAttribute")} + + + + + + + )} + /> + > + )} + + {isStateModel && ( + + + + {t("wizard.step1.states")} + + + + + + + {watchedClasses?.map((_: string, index: number) => ( + >) + .control + } + name={`classes.${index}` as const} + render={({ field }) => ( + + + + + {watchedClasses && + watchedClasses.length > 1 && ( + handleRemoveClass(index)} + > + + + )} + + + + )} + /> + ))} + + {isStateModel && + "classes" in form.formState.errors && + form.formState.errors.classes && ( + + {form.formState.errors.classes.message} + + )} + + )} + + + + {t("button.cancel", { ns: "common" })} + + + {isSaving + ? t("button.saving", { ns: "common" }) + : t("button.save", { ns: "common" })} + + + + + + + + ); +} diff --git a/web/src/types/frigateConfig.ts b/web/src/types/frigateConfig.ts index ffe4cc14d..f10563379 100644 --- a/web/src/types/frigateConfig.ts +++ b/web/src/types/frigateConfig.ts @@ -306,6 +306,7 @@ export type CustomClassificationModelConfig = { threshold: number; object_config?: { objects: string[]; + classification_type: string; }; state_config?: { cameras: { diff --git a/web/src/views/classification/ModelSelectionView.tsx b/web/src/views/classification/ModelSelectionView.tsx index b1b462497..c5e65e0e5 100644 --- a/web/src/views/classification/ModelSelectionView.tsx +++ b/web/src/views/classification/ModelSelectionView.tsx @@ -1,5 +1,6 @@ import { baseUrl } from "@/api/baseUrl"; import ClassificationModelWizardDialog from "@/components/classification/ClassificationModelWizardDialog"; +import ClassificationModelEditDialog from "@/components/classification/ClassificationModelEditDialog"; import ActivityIndicator from "@/components/indicators/activity-indicator"; import { ImageShadowOverlay } from "@/components/overlay/ImageShadowOverlay"; import { Button, buttonVariants } from "@/components/ui/button"; @@ -14,7 +15,7 @@ import { useCallback, useEffect, useMemo, useState } from "react"; import { useTranslation } from "react-i18next"; import { FaFolderPlus } from "react-icons/fa"; import { MdModelTraining } from "react-icons/md"; -import { LuTrash2 } from "react-icons/lu"; +import { LuPencil, LuTrash2 } from "react-icons/lu"; import { FiMoreVertical } from "react-icons/fi"; import useSWR from "swr"; import Heading from "@/components/ui/heading"; @@ -163,6 +164,7 @@ export default function ModelSelectionView({ key={config.name} config={config} onClick={() => onClick(config)} + onUpdate={() => refreshConfig()} onDelete={() => refreshConfig()} /> ))} @@ -201,9 +203,10 @@ function NoModelsView({ type ModelCardProps = { config: CustomClassificationModelConfig; onClick: () => void; + onUpdate: () => void; onDelete: () => void; }; -function ModelCard({ config, onClick, onDelete }: ModelCardProps) { +function ModelCard({ config, onClick, onUpdate, onDelete }: ModelCardProps) { const { t } = useTranslation(["views/classificationModel"]); const { data: dataset } = useSWR<{ @@ -211,6 +214,7 @@ function ModelCard({ config, onClick, onDelete }: ModelCardProps) { }>(`classification/${config.name}/dataset`, { revalidateOnFocus: false }); const [deleteDialogOpen, setDeleteDialogOpen] = useState(false); + const [editDialogOpen, setEditDialogOpen] = useState(false); const handleDelete = useCallback(async () => { try { @@ -250,6 +254,11 @@ function ModelCard({ config, onClick, onDelete }: ModelCardProps) { setDeleteDialogOpen(true); }, []); + const handleEditClick = useCallback((e: React.MouseEvent) => { + e.stopPropagation(); + setEditDialogOpen(true); + }, []); + const coverImage = useMemo(() => { if (!dataset) { return undefined; @@ -270,6 +279,13 @@ function ModelCard({ config, onClick, onDelete }: ModelCardProps) { return ( <> + setEditDialogOpen(false)} + onSuccess={() => onUpdate()} + /> + setDeleteDialogOpen(!deleteDialogOpen)} @@ -320,6 +336,10 @@ function ModelCard({ config, onClick, onDelete }: ModelCardProps) { align="end" onClick={(e) => e.stopPropagation()} > + + + {t("button.edit", { ns: "common" })} + {t("button.delete", { ns: "common" })} diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx index 92a4adcdf..a27a06a9e 100644 --- a/web/src/views/classification/ModelTrainingView.tsx +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -327,31 +327,39 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { - - navigate(-1)} - > - - {isDesktop && ( - - {t("button.back", { ns: "common" })} - - )} - - {}} - /> - + {(isDesktop || !selectedImages?.length) && ( + + navigate(-1)} + > + + {isDesktop && ( + + {t("button.back", { ns: "common" })} + + )} + + + {}} + /> + + )} {selectedImages?.length > 0 ? ( - - + + {`${selectedImages.length} selected`} {"|"}
+ {form.formState.errors.classes.message} +