Classification Improvements (#20807)

* Don't show model selection or back button when in multi select mode

* Add dialog to edit classification models

* Fix header spacing

* Cleanup desktop

* Incrase max number of object classifications

* fix iOS mobile card

* Cleanup
This commit is contained in:
Nicolas Mowen 2025-11-05 07:11:12 -07:00 committed by GitHub
parent 043bd9e6ee
commit 81faa8899d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 560 additions and 43 deletions

View File

@ -466,6 +466,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
now, now,
self.labelmap[best_id], self.labelmap[best_id],
score, score,
max_files=200,
) )
if score < self.model_config.threshold: if score < self.model_config.threshold:
@ -529,6 +530,7 @@ def write_classification_attempt(
timestamp: float, timestamp: float,
label: str, label: str,
score: float, score: float,
max_files: int = 100,
) -> None: ) -> None:
if "-" in label: if "-" in label:
label = label.replace("-", "_") label = label.replace("-", "_")
@ -544,5 +546,5 @@ def write_classification_attempt(
) )
# delete oldest face image if maximum is reached # delete oldest face image if maximum is reached
if len(files) > 100: if len(files) > max_files:
os.unlink(os.path.join(folder, files[-1])) os.unlink(os.path.join(folder, files[-1]))

View File

@ -10,7 +10,8 @@
"deleteImages": "Delete Images", "deleteImages": "Delete Images",
"trainModel": "Train Model", "trainModel": "Train Model",
"addClassification": "Add Classification", "addClassification": "Add Classification",
"deleteModels": "Delete Models" "deleteModels": "Delete Models",
"editModel": "Edit Model"
}, },
"toast": { "toast": {
"success": { "success": {
@ -20,14 +21,16 @@
"deletedModel_other": "Successfully deleted {{count}} models", "deletedModel_other": "Successfully deleted {{count}} models",
"categorizedImage": "Successfully Classified Image", "categorizedImage": "Successfully Classified Image",
"trainedModel": "Successfully trained model.", "trainedModel": "Successfully trained model.",
"trainingModel": "Successfully started model training." "trainingModel": "Successfully started model training.",
"updatedModel": "Successfully updated model configuration"
}, },
"error": { "error": {
"deleteImageFailed": "Failed to delete: {{errorMessage}}", "deleteImageFailed": "Failed to delete: {{errorMessage}}",
"deleteCategoryFailed": "Failed to delete class: {{errorMessage}}", "deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
"deleteModelFailed": "Failed to delete model: {{errorMessage}}", "deleteModelFailed": "Failed to delete model: {{errorMessage}}",
"categorizeFailed": "Failed to categorize image: {{errorMessage}}", "categorizeFailed": "Failed to categorize image: {{errorMessage}}",
"trainingFailed": "Failed to start model training: {{errorMessage}}" "trainingFailed": "Failed to start model training: {{errorMessage}}",
"updateModelFailed": "Failed to update model: {{errorMessage}}"
} }
}, },
"deleteCategory": { "deleteCategory": {
@ -39,6 +42,12 @@
"single": "Are you sure you want to delete {{name}}? This will permanently delete all associated data including images and training data. This action cannot be undone.", "single": "Are you sure you want to delete {{name}}? This will permanently delete all associated data including images and training data. This action cannot be undone.",
"desc": "Are you sure you want to delete {{count}} model(s)? This will permanently delete all associated data including images and training data. This action cannot be undone." "desc": "Are you sure you want to delete {{count}} model(s)? This will permanently delete all associated data including images and training data. This action cannot be undone."
}, },
"edit": {
"title": "Edit Classification Model",
"descriptionState": "Edit the classes for this state classification model. Changes will require retraining the model.",
"descriptionObject": "Edit the object type and classification type for this object classification model.",
"stateClassesInfo": "Note: Changing state classes requires retraining the model with the updated classes."
},
"deleteDatasetImages": { "deleteDatasetImages": {
"title": "Delete Dataset Images", "title": "Delete Dataset Images",
"desc": "Are you sure you want to delete {{count}} images from {{dataset}}? This action cannot be undone and will require re-training the model." "desc": "Are you sure you want to delete {{count}} images from {{dataset}}? This action cannot be undone and will require re-training the model."

View File

@ -7,7 +7,7 @@ import {
} from "@/types/classification"; } from "@/types/classification";
import { Event } from "@/types/event"; import { Event } from "@/types/event";
import { forwardRef, useMemo, useRef, useState } from "react"; import { forwardRef, useMemo, useRef, useState } from "react";
import { isDesktop, isMobile } from "react-device-detect"; import { isDesktop, isMobile, isMobileOnly } from "react-device-detect";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import TimeAgo from "../dynamic/TimeAgo"; import TimeAgo from "../dynamic/TimeAgo";
import { Tooltip, TooltipContent, TooltipTrigger } from "../ui/tooltip"; import { Tooltip, TooltipContent, TooltipTrigger } from "../ui/tooltip";
@ -264,8 +264,8 @@ export function GroupedClassificationCard({
const Overlay = isDesktop ? Dialog : MobilePage; const Overlay = isDesktop ? Dialog : MobilePage;
const Trigger = isDesktop ? DialogTrigger : MobilePageTrigger; const Trigger = isDesktop ? DialogTrigger : MobilePageTrigger;
const Header = isDesktop ? DialogHeader : MobilePageHeader;
const Content = isDesktop ? DialogContent : MobilePageContent; const Content = isDesktop ? DialogContent : MobilePageContent;
const Header = isDesktop ? DialogHeader : MobilePageHeader;
const ContentTitle = isDesktop ? DialogTitle : MobilePageTitle; const ContentTitle = isDesktop ? DialogTitle : MobilePageTitle;
const ContentDescription = isDesktop const ContentDescription = isDesktop
? DialogDescription ? DialogDescription
@ -298,9 +298,9 @@ export function GroupedClassificationCard({
<Trigger asChild></Trigger> <Trigger asChild></Trigger>
<Content <Content
className={cn( className={cn(
"", "scrollbar-container",
isDesktop && "min-w-[50%] max-w-[65%]", isDesktop && "min-w-[50%] max-w-[65%]",
isMobile && "flex flex-col", isMobile && "overflow-y-auto",
)} )}
onOpenAutoFocus={(e) => e.preventDefault()} onOpenAutoFocus={(e) => e.preventDefault()}
> >
@ -308,16 +308,16 @@ export function GroupedClassificationCard({
<Header <Header
className={cn( className={cn(
"mx-2 flex flex-row items-center gap-4", "mx-2 flex flex-row items-center gap-4",
isMobile && "flex-shrink-0", isMobileOnly && "top-0 mx-4",
)} )}
> >
<div> <div
<ContentTitle
className={cn( className={cn(
"flex items-center gap-2 font-normal capitalize", "",
isMobile && "px-2", isMobile && "flex flex-col items-center justify-center",
)} )}
> >
<ContentTitle className="flex items-center gap-2 font-normal capitalize">
{event?.sub_label && event.sub_label !== "none" {event?.sub_label && event.sub_label !== "none"
? event.sub_label ? event.sub_label
: t(noClassificationLabel)} : t(noClassificationLabel)}
@ -390,7 +390,7 @@ export function GroupedClassificationCard({
className={cn( className={cn(
"grid w-full auto-rows-min grid-cols-2 gap-2 sm:grid-cols-3 md:grid-cols-4 lg:grid-cols-6 xl:grid-cols-6 2xl:grid-cols-8", "grid w-full auto-rows-min grid-cols-2 gap-2 sm:grid-cols-3 md:grid-cols-4 lg:grid-cols-6 xl:grid-cols-6 2xl:grid-cols-8",
isDesktop && "p-2", isDesktop && "p-2",
isMobile && "scrollbar-container flex-1 overflow-y-auto", isMobile && "px-4 pb-4",
)} )}
> >
{group.map((data: ClassificationItemData) => ( {group.map((data: ClassificationItemData) => (

View File

@ -0,0 +1,477 @@
import { Button } from "@/components/ui/button";
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import {
Form,
FormControl,
FormField,
FormItem,
FormLabel,
FormMessage,
} from "@/components/ui/form";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import {
CustomClassificationModelConfig,
FrigateConfig,
} from "@/types/frigateConfig";
import { getTranslatedLabel } from "@/utils/i18n";
import { zodResolver } from "@hookform/resolvers/zod";
import axios from "axios";
import { useCallback, useEffect, useMemo, useState } from "react";
import { useForm } from "react-hook-form";
import { useTranslation } from "react-i18next";
import { LuPlus, LuX } from "react-icons/lu";
import { toast } from "sonner";
import useSWR from "swr";
import { z } from "zod";
type ClassificationModelEditDialogProps = {
open: boolean;
model: CustomClassificationModelConfig;
onClose: () => void;
onSuccess: () => void;
};
type ObjectClassificationType = "sub_label" | "attribute";
type ObjectFormData = {
objectLabel: string;
objectType: ObjectClassificationType;
};
type StateFormData = {
classes: string[];
};
export default function ClassificationModelEditDialog({
open,
model,
onClose,
onSuccess,
}: ClassificationModelEditDialogProps) {
const { t } = useTranslation(["views/classificationModel"]);
const { data: config } = useSWR<FrigateConfig>("config");
const [isSaving, setIsSaving] = useState(false);
const isStateModel = model.state_config !== undefined;
const isObjectModel = model.object_config !== undefined;
const objectLabels = useMemo(() => {
if (!config) return [];
const labels = new Set<string>();
Object.values(config.cameras).forEach((cameraConfig) => {
if (!cameraConfig.enabled || !cameraConfig.enabled_in_config) {
return;
}
cameraConfig.objects.track.forEach((label) => {
if (!config.model.all_attributes.includes(label)) {
labels.add(label);
}
});
});
return [...labels].sort();
}, [config]);
// Define form schema based on model type
const formSchema = useMemo(() => {
if (isObjectModel) {
return z.object({
objectLabel: z
.string()
.min(1, t("wizard.step1.errors.objectLabelRequired")),
objectType: z.enum(["sub_label", "attribute"]),
});
} else {
// State model
return z.object({
classes: z
.array(z.string())
.min(1, t("wizard.step1.errors.classRequired"))
.refine(
(classes) => {
const nonEmpty = classes.filter((c) => c.trim().length > 0);
return nonEmpty.length >= 2;
},
{ message: t("wizard.step1.errors.stateRequiresTwoClasses") },
)
.refine(
(classes) => {
const nonEmpty = classes.filter((c) => c.trim().length > 0);
const unique = new Set(nonEmpty.map((c) => c.toLowerCase()));
return unique.size === nonEmpty.length;
},
{ message: t("wizard.step1.errors.classesUnique") },
),
});
}
}, [isObjectModel, t]);
const form = useForm<ObjectFormData | StateFormData>({
resolver: zodResolver(formSchema),
defaultValues: isObjectModel
? ({
objectLabel: model.object_config?.objects?.[0] || "",
objectType:
(model.object_config
?.classification_type as ObjectClassificationType) || "sub_label",
} as ObjectFormData)
: ({
classes: [""], // Will be populated from dataset
} as StateFormData),
mode: "onChange",
});
// Fetch dataset to get current classes for state models
const { data: dataset } = useSWR<{
[id: string]: string[];
}>(isStateModel ? `classification/${model.name}/dataset` : null, {
revalidateOnFocus: false,
});
// Update form with classes from dataset when loaded
useEffect(() => {
if (isStateModel && dataset) {
const classes = Object.keys(dataset).filter((key) => key !== "none");
if (classes.length > 0) {
(form as ReturnType<typeof useForm<StateFormData>>).setValue(
"classes",
classes,
);
}
}
}, [dataset, isStateModel, form]);
const watchedClasses = isStateModel
? (form as ReturnType<typeof useForm<StateFormData>>).watch("classes")
: undefined;
const watchedObjectType = isObjectModel
? (form as ReturnType<typeof useForm<ObjectFormData>>).watch("objectType")
: undefined;
const handleAddClass = useCallback(() => {
const currentClasses = (
form as ReturnType<typeof useForm<StateFormData>>
).getValues("classes");
(form as ReturnType<typeof useForm<StateFormData>>).setValue(
"classes",
[...currentClasses, ""],
{
shouldValidate: true,
},
);
}, [form]);
const handleRemoveClass = useCallback(
(index: number) => {
const currentClasses = (
form as ReturnType<typeof useForm<StateFormData>>
).getValues("classes");
const newClasses = currentClasses.filter((_, i) => i !== index);
// Ensure at least one field remains (even if empty)
if (newClasses.length === 0) {
(form as ReturnType<typeof useForm<StateFormData>>).setValue(
"classes",
[""],
{ shouldValidate: true },
);
} else {
(form as ReturnType<typeof useForm<StateFormData>>).setValue(
"classes",
newClasses,
{ shouldValidate: true },
);
}
},
[form],
);
const onSubmit = useCallback(
async (data: ObjectFormData | StateFormData) => {
setIsSaving(true);
try {
if (isObjectModel) {
const objectData = data as ObjectFormData;
// Update the config
await axios.put("/config/set", {
requires_restart: 0,
update_topic: `config/classification/custom/${model.name}`,
config_data: {
classification: {
custom: {
[model.name]: {
enabled: model.enabled,
name: model.name,
threshold: model.threshold,
object_config: {
objects: [objectData.objectLabel],
classification_type: objectData.objectType,
},
},
},
},
},
});
toast.success(t("toast.success.updatedModel"), {
position: "top-center",
});
} else {
// State model - update classes
// Note: For state models, updating classes requires renaming categories
// which is handled through the dataset API, not the config API
// We'll need to implement this by calling the rename endpoint for each class
// For now, we just show a message that this requires retraining
toast.info(t("edit.stateClassesInfo"), {
position: "top-center",
});
}
onSuccess();
onClose();
} catch (err) {
const error = err as {
response?: { data?: { message?: string; detail?: string } };
};
const errorMessage =
error.response?.data?.message ||
error.response?.data?.detail ||
"Unknown error";
toast.error(t("toast.error.updateModelFailed", { errorMessage }), {
position: "top-center",
});
} finally {
setIsSaving(false);
}
},
[isObjectModel, model, t, onSuccess, onClose],
);
const handleCancel = useCallback(() => {
form.reset();
onClose();
}, [form, onClose]);
return (
<Dialog open={open} onOpenChange={(open) => !open && handleCancel()}>
<DialogContent>
<DialogHeader>
<DialogTitle>{t("edit.title")}</DialogTitle>
<DialogDescription>
{isStateModel
? t("edit.descriptionState")
: t("edit.descriptionObject")}
</DialogDescription>
</DialogHeader>
<div className="space-y-6">
<Form {...form}>
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
{isObjectModel && (
<>
<FormField
control={form.control}
name="objectLabel"
render={({ field }) => (
<FormItem>
<FormLabel className="text-primary-variant">
{t("wizard.step1.objectLabel")}
</FormLabel>
<Select
onValueChange={field.onChange}
defaultValue={field.value}
>
<FormControl>
<SelectTrigger className="h-8">
<SelectValue
placeholder={t(
"wizard.step1.objectLabelPlaceholder",
)}
/>
</SelectTrigger>
</FormControl>
<SelectContent>
{objectLabels.map((label) => (
<SelectItem
key={label}
value={label}
className="cursor-pointer hover:bg-secondary-highlight"
>
{getTranslatedLabel(label)}
</SelectItem>
))}
</SelectContent>
</Select>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="objectType"
render={({ field }) => (
<FormItem>
<FormLabel className="text-primary-variant">
{t("wizard.step1.classificationType")}
</FormLabel>
<FormControl>
<RadioGroup
onValueChange={field.onChange}
defaultValue={field.value}
className="flex flex-col gap-4 pt-2"
>
<div className="flex items-center gap-2">
<RadioGroupItem
className={
watchedObjectType === "sub_label"
? "bg-selected from-selected/50 to-selected/90 text-selected"
: "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
}
id="sub_label"
value="sub_label"
/>
<Label
className="cursor-pointer"
htmlFor="sub_label"
>
{t("wizard.step1.classificationSubLabel")}
</Label>
</div>
<div className="flex items-center gap-2">
<RadioGroupItem
className={
watchedObjectType === "attribute"
? "bg-selected from-selected/50 to-selected/90 text-selected"
: "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
}
id="attribute"
value="attribute"
/>
<Label
className="cursor-pointer"
htmlFor="attribute"
>
{t("wizard.step1.classificationAttribute")}
</Label>
</div>
</RadioGroup>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
</>
)}
{isStateModel && (
<div className="space-y-2">
<div className="flex items-center justify-between">
<FormLabel className="text-primary-variant">
{t("wizard.step1.states")}
</FormLabel>
<Button
type="button"
variant="secondary"
className="size-6 rounded-md bg-secondary-foreground p-1 text-background"
onClick={handleAddClass}
>
<LuPlus />
</Button>
</div>
<div className="space-y-2">
{watchedClasses?.map((_: string, index: number) => (
<FormField
key={index}
control={
(form as ReturnType<typeof useForm<StateFormData>>)
.control
}
name={`classes.${index}` as const}
render={({ field }) => (
<FormItem>
<FormControl>
<div className="flex items-center gap-2">
<Input
className="text-md h-8"
placeholder={t(
"wizard.step1.classPlaceholder",
)}
{...field}
/>
{watchedClasses &&
watchedClasses.length > 1 && (
<Button
type="button"
variant="ghost"
size="sm"
className="h-8 w-8 p-0"
onClick={() => handleRemoveClass(index)}
>
<LuX className="size-4" />
</Button>
)}
</div>
</FormControl>
</FormItem>
)}
/>
))}
</div>
{isStateModel &&
"classes" in form.formState.errors &&
form.formState.errors.classes && (
<p className="text-sm font-medium text-destructive">
{form.formState.errors.classes.message}
</p>
)}
</div>
)}
<div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
<Button
type="button"
onClick={handleCancel}
className="sm:flex-1"
disabled={isSaving}
>
{t("button.cancel", { ns: "common" })}
</Button>
<Button
type="submit"
variant="select"
className="flex items-center justify-center gap-2 sm:flex-1"
disabled={!form.formState.isValid || isSaving}
>
{isSaving
? t("button.saving", { ns: "common" })
: t("button.save", { ns: "common" })}
</Button>
</div>
</form>
</Form>
</div>
</DialogContent>
</Dialog>
);
}

View File

@ -306,6 +306,7 @@ export type CustomClassificationModelConfig = {
threshold: number; threshold: number;
object_config?: { object_config?: {
objects: string[]; objects: string[];
classification_type: string;
}; };
state_config?: { state_config?: {
cameras: { cameras: {

View File

@ -1,5 +1,6 @@
import { baseUrl } from "@/api/baseUrl"; import { baseUrl } from "@/api/baseUrl";
import ClassificationModelWizardDialog from "@/components/classification/ClassificationModelWizardDialog"; import ClassificationModelWizardDialog from "@/components/classification/ClassificationModelWizardDialog";
import ClassificationModelEditDialog from "@/components/classification/ClassificationModelEditDialog";
import ActivityIndicator from "@/components/indicators/activity-indicator"; import ActivityIndicator from "@/components/indicators/activity-indicator";
import { ImageShadowOverlay } from "@/components/overlay/ImageShadowOverlay"; import { ImageShadowOverlay } from "@/components/overlay/ImageShadowOverlay";
import { Button, buttonVariants } from "@/components/ui/button"; import { Button, buttonVariants } from "@/components/ui/button";
@ -14,7 +15,7 @@ import { useCallback, useEffect, useMemo, useState } from "react";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { FaFolderPlus } from "react-icons/fa"; import { FaFolderPlus } from "react-icons/fa";
import { MdModelTraining } from "react-icons/md"; import { MdModelTraining } from "react-icons/md";
import { LuTrash2 } from "react-icons/lu"; import { LuPencil, LuTrash2 } from "react-icons/lu";
import { FiMoreVertical } from "react-icons/fi"; import { FiMoreVertical } from "react-icons/fi";
import useSWR from "swr"; import useSWR from "swr";
import Heading from "@/components/ui/heading"; import Heading from "@/components/ui/heading";
@ -163,6 +164,7 @@ export default function ModelSelectionView({
key={config.name} key={config.name}
config={config} config={config}
onClick={() => onClick(config)} onClick={() => onClick(config)}
onUpdate={() => refreshConfig()}
onDelete={() => refreshConfig()} onDelete={() => refreshConfig()}
/> />
))} ))}
@ -201,9 +203,10 @@ function NoModelsView({
type ModelCardProps = { type ModelCardProps = {
config: CustomClassificationModelConfig; config: CustomClassificationModelConfig;
onClick: () => void; onClick: () => void;
onUpdate: () => void;
onDelete: () => void; onDelete: () => void;
}; };
function ModelCard({ config, onClick, onDelete }: ModelCardProps) { function ModelCard({ config, onClick, onUpdate, onDelete }: ModelCardProps) {
const { t } = useTranslation(["views/classificationModel"]); const { t } = useTranslation(["views/classificationModel"]);
const { data: dataset } = useSWR<{ const { data: dataset } = useSWR<{
@ -211,6 +214,7 @@ function ModelCard({ config, onClick, onDelete }: ModelCardProps) {
}>(`classification/${config.name}/dataset`, { revalidateOnFocus: false }); }>(`classification/${config.name}/dataset`, { revalidateOnFocus: false });
const [deleteDialogOpen, setDeleteDialogOpen] = useState(false); const [deleteDialogOpen, setDeleteDialogOpen] = useState(false);
const [editDialogOpen, setEditDialogOpen] = useState(false);
const handleDelete = useCallback(async () => { const handleDelete = useCallback(async () => {
try { try {
@ -250,6 +254,11 @@ function ModelCard({ config, onClick, onDelete }: ModelCardProps) {
setDeleteDialogOpen(true); setDeleteDialogOpen(true);
}, []); }, []);
const handleEditClick = useCallback((e: React.MouseEvent) => {
e.stopPropagation();
setEditDialogOpen(true);
}, []);
const coverImage = useMemo(() => { const coverImage = useMemo(() => {
if (!dataset) { if (!dataset) {
return undefined; return undefined;
@ -270,6 +279,13 @@ function ModelCard({ config, onClick, onDelete }: ModelCardProps) {
return ( return (
<> <>
<ClassificationModelEditDialog
open={editDialogOpen}
model={config}
onClose={() => setEditDialogOpen(false)}
onSuccess={() => onUpdate()}
/>
<AlertDialog <AlertDialog
open={deleteDialogOpen} open={deleteDialogOpen}
onOpenChange={() => setDeleteDialogOpen(!deleteDialogOpen)} onOpenChange={() => setDeleteDialogOpen(!deleteDialogOpen)}
@ -320,6 +336,10 @@ function ModelCard({ config, onClick, onDelete }: ModelCardProps) {
align="end" align="end"
onClick={(e) => e.stopPropagation()} onClick={(e) => e.stopPropagation()}
> >
<DropdownMenuItem onClick={handleEditClick}>
<LuPencil className="mr-2 size-4" />
<span>{t("button.edit", { ns: "common" })}</span>
</DropdownMenuItem>
<DropdownMenuItem onClick={handleDeleteClick}> <DropdownMenuItem onClick={handleDeleteClick}>
<LuTrash2 className="mr-2 size-4" /> <LuTrash2 className="mr-2 size-4" />
<span>{t("button.delete", { ns: "common" })}</span> <span>{t("button.delete", { ns: "common" })}</span>

View File

@ -327,6 +327,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
</AlertDialog> </AlertDialog>
<div className="flex flex-row justify-between gap-2 p-2 align-middle"> <div className="flex flex-row justify-between gap-2 p-2 align-middle">
{(isDesktop || !selectedImages?.length) && (
<div className="flex flex-row items-center justify-center gap-2"> <div className="flex flex-row items-center justify-center gap-2">
<Button <Button
className="flex items-center gap-2.5 rounded-lg" className="flex items-center gap-2.5 rounded-lg"
@ -340,6 +341,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
</div> </div>
)} )}
</Button> </Button>
<LibrarySelector <LibrarySelector
pageToggle={pageToggle} pageToggle={pageToggle}
dataset={dataset || {}} dataset={dataset || {}}
@ -349,9 +351,15 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
onRename={() => {}} onRename={() => {}}
/> />
</div> </div>
)}
{selectedImages?.length > 0 ? ( {selectedImages?.length > 0 ? (
<div className="flex items-center justify-center gap-2"> <div
<div className="mx-1 flex w-48 items-center justify-center text-sm text-muted-foreground"> className={cn(
"flex w-full items-center justify-end gap-2",
isMobileOnly && "justify-between",
)}
>
<div className="flex w-48 items-center justify-center text-sm text-muted-foreground">
<div className="p-1">{`${selectedImages.length} selected`}</div> <div className="p-1">{`${selectedImages.length} selected`}</div>
<div className="p-1">{"|"}</div> <div className="p-1">{"|"}</div>
<div <div