diff --git a/web/src/components/classification/ClassificationModelEditDialog.tsx b/web/src/components/classification/ClassificationModelEditDialog.tsx index c47765d76..a3ff2df8a 100644 --- a/web/src/components/classification/ClassificationModelEditDialog.tsx +++ b/web/src/components/classification/ClassificationModelEditDialog.tsx @@ -37,7 +37,7 @@ import { useForm } from "react-hook-form"; import { useTranslation } from "react-i18next"; import { LuPlus, LuX } from "react-icons/lu"; import { toast } from "sonner"; -import useSWR from "swr"; +import useSWR, { mutate } from "swr"; import { z } from "zod"; type ClassificationModelEditDialogProps = { @@ -240,15 +240,61 @@ export default function ClassificationModelEditDialog({ position: "top-center", }); } else { - // State model - update classes - // Note: For state models, updating classes requires renaming categories - // which is handled through the dataset API, not the config API - // We'll need to implement this by calling the rename endpoint for each class - // For now, we just show a message that this requires retraining + const stateData = data as StateFormData; + const newClasses = stateData.classes.filter( + (c) => c.trim().length > 0, + ); + const oldClasses = dataset?.categories + ? Object.keys(dataset.categories).filter((key) => key !== "none") + : []; - toast.info(t("edit.stateClassesInfo"), { - position: "top-center", - }); + const renameMap = new Map(); + const maxLength = Math.max(oldClasses.length, newClasses.length); + + for (let i = 0; i < maxLength; i++) { + const oldClass = oldClasses[i]; + const newClass = newClasses[i]; + + if (oldClass && newClass && oldClass !== newClass) { + renameMap.set(oldClass, newClass); + } + } + + const renamePromises = Array.from(renameMap.entries()).map( + async ([oldName, newName]) => { + try { + await axios.put( + `/classification/${model.name}/dataset/${oldName}/rename`, + { + new_category: newName, + }, + ); + } catch (err) { + const error = err as { + response?: { data?: { message?: string; detail?: string } }; + }; + const errorMessage = + error.response?.data?.message || + error.response?.data?.detail || + "Unknown error"; + throw new Error( + `Failed to rename ${oldName} to ${newName}: ${errorMessage}`, + ); + } + }, + ); + + if (renamePromises.length > 0) { + await Promise.all(renamePromises); + await mutate(`classification/${model.name}/dataset`); + toast.success(t("toast.success.updatedModel"), { + position: "top-center", + }); + } else { + toast.info(t("edit.stateClassesInfo"), { + position: "top-center", + }); + } } onSuccess(); @@ -256,8 +302,10 @@ export default function ClassificationModelEditDialog({ } catch (err) { const error = err as { response?: { data?: { message?: string; detail?: string } }; + message?: string; }; const errorMessage = + error.message || error.response?.data?.message || error.response?.data?.detail || "Unknown error"; @@ -268,7 +316,7 @@ export default function ClassificationModelEditDialog({ setIsSaving(false); } }, - [isObjectModel, model, t, onSuccess, onClose], + [isObjectModel, model, dataset, t, onSuccess, onClose], ); const handleCancel = useCallback(() => {