Implement classification model training

This commit is contained in:
Nicolas Mowen 2025-06-04 07:06:31 -06:00
parent b3cee44f06
commit f6c9413944
4 changed files with 237 additions and 6 deletions

View File

@ -21,7 +21,7 @@ from frigate.api.defs.request.classification_body import (
from frigate.api.defs.tags import Tags from frigate.api.defs.tags import Tags
from frigate.config import FrigateConfig from frigate.config import FrigateConfig
from frigate.config.camera import DetectConfig from frigate.config.camera import DetectConfig
from frigate.const import CLIPS_DIR, FACE_DIR, MODEL_CACHE_DIR from frigate.const import CLIPS_DIR, FACE_DIR
from frigate.embeddings import EmbeddingsContext from frigate.embeddings import EmbeddingsContext
from frigate.models import Event from frigate.models import Event
from frigate.util.classification import train_classification_model from frigate.util.classification import train_classification_model
@ -537,6 +537,59 @@ def delete_classification_dataset_images(
) )
@router.post(
"/classification/{name}/dataset/categorize",
dependencies=[Depends(require_role(["admin"]))],
)
def categorize_classification_image(
request: Request, name: str, body: dict = None
):
config: FrigateConfig = request.app.frigate_config
if name not in config.classification.custom:
return JSONResponse(
content=(
{
"success": False,
"message": f"{name} is not a known classification model.",
}
),
status_code=404,
)
json: dict[str, Any] = body or {}
category = sanitize_filename(json.get("category", ""))
training_file_name = sanitize_filename(json.get("training_file", ""))
training_file = os.path.join(CLIPS_DIR, name, "train", training_file_name)
if training_file_name and not os.path.isfile(training_file):
return JSONResponse(
content=(
{
"success": False,
"message": f"Invalid filename or no file exists: {training_file_name}",
}
),
status_code=404,
)
new_name = f"{category}-{datetime.datetime.now().timestamp()}.png"
new_file_folder = os.path.join(CLIPS_DIR, name, "dataset", category)
if not os.path.exists(new_file_folder):
os.mkdir(new_file_folder)
# use opencv because webp images can not be used to train
img = cv2.imread(training_file)
cv2.imwrite(os.path.join(new_file_folder, new_name), img)
os.unlink(training_file)
return JSONResponse(
content=({"success": True, "message": "Successfully deleted faces."}),
status_code=200,
)
@router.post( @router.post(
"/classification/{name}/train/delete", "/classification/{name}/train/delete",
dependencies=[Depends(require_role(["admin"]))], dependencies=[Depends(require_role(["admin"]))],

View File

@ -8,11 +8,13 @@
"toast": { "toast": {
"success": { "success": {
"deletedCategory": "Deleted Category", "deletedCategory": "Deleted Category",
"deletedImage": "Deleted Images" "deletedImage": "Deleted Images",
"categorizedImage": "Successfully Categorized Image"
}, },
"error": { "error": {
"deleteImageFailed": "Failed to delete: {{errorMessage}}", "deleteImageFailed": "Failed to delete: {{errorMessage}}",
"deleteCategoryFailed": "Failed to delete category: {{errorMessage}}" "deleteCategoryFailed": "Failed to delete category: {{errorMessage}}",
"categorizeFailed": "Failed to categorize image: {{errorMessage}}"
} }
}, },
"deleteCategory": { "deleteCategory": {
@ -38,5 +40,10 @@
"title": "Train", "title": "Train",
"aria": "Select Train" "aria": "Select Train"
}, },
"categories": "Categories" "categories": "Categories",
"createCategory": {
"new": "Create New Category"
},
"categorizeImageAs": "Categorize Image As:",
"categorizeImage": "Categorize Image"
} }

View File

@ -0,0 +1,155 @@
import {
Drawer,
DrawerClose,
DrawerContent,
DrawerDescription,
DrawerHeader,
DrawerTitle,
DrawerTrigger,
} from "@/components/ui/drawer";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { isDesktop, isMobile } from "react-device-detect";
import { LuPlus } from "react-icons/lu";
import { useTranslation } from "react-i18next";
import { cn } from "@/lib/utils";
import React, { ReactNode, useCallback, useMemo, useState } from "react";
import TextEntryDialog from "./dialog/TextEntryDialog";
import { Button } from "../ui/button";
import { MdCategory } from "react-icons/md";
import axios from "axios";
import { toast } from "sonner";
type ClassificationSelectionDialogProps = {
className?: string;
categories: string[];
modelName: string;
image: string;
onRefresh: () => void;
children: ReactNode;
};
export default function ClassificationSelectionDialog({
className,
categories,
modelName,
image,
onRefresh,
children,
}: ClassificationSelectionDialogProps) {
const { t } = useTranslation(["views/classificationModel"]);
const onCategorizeImage = useCallback(
(category: string) => {
axios
.post(`/classification/${modelName}/dataset/categorize`, {
category,
training_file: image,
})
.then((resp) => {
if (resp.status == 200) {
toast.success(t("toast.success.categorizedImage"), {
position: "top-center",
});
onRefresh();
}
})
.catch((error) => {
const errorMessage =
error.response?.data?.message ||
error.response?.data?.detail ||
"Unknown error";
toast.error(t("toast.error.categorizeFailed", { errorMessage }), {
position: "top-center",
});
});
},
[modelName, image, onRefresh, t],
);
const isChildButton = useMemo(
() => React.isValidElement(children) && children.type === Button,
[children],
);
// control
const [newFace, setNewFace] = useState(false);
// components
const Selector = isDesktop ? DropdownMenu : Drawer;
const SelectorTrigger = isDesktop ? DropdownMenuTrigger : DrawerTrigger;
const SelectorContent = isDesktop ? DropdownMenuContent : DrawerContent;
const SelectorItem = isDesktop
? DropdownMenuItem
: (props: React.HTMLAttributes<HTMLDivElement>) => (
<DrawerClose asChild>
<div {...props} className={cn(props.className, "my-2")} />
</DrawerClose>
);
return (
<div className={className ?? ""}>
{newFace && (
<TextEntryDialog
open={true}
setOpen={setNewFace}
title={t("createCategory.new")}
onSave={(newCat) => onCategorizeImage(newCat)}
/>
)}
<Tooltip>
<Selector>
<SelectorTrigger asChild>
<TooltipTrigger asChild={isChildButton}>{children}</TooltipTrigger>
</SelectorTrigger>
<SelectorContent
className={cn("", isMobile && "mx-1 gap-2 rounded-t-2xl px-4")}
>
{isMobile && (
<DrawerHeader className="sr-only">
<DrawerTitle>Details</DrawerTitle>
<DrawerDescription>Details</DrawerDescription>
</DrawerHeader>
)}
<DropdownMenuLabel>{t("categorizeImageAs")}</DropdownMenuLabel>
<div
className={cn(
"flex max-h-[40dvh] flex-col overflow-y-auto",
isMobile && "gap-2 pb-4",
)}
>
<SelectorItem
className="flex cursor-pointer gap-2 smart-capitalize"
onClick={() => setNewFace(true)}
>
<LuPlus />
{t("createCategory.new")}
</SelectorItem>
{categories.sort().map((category) => (
<SelectorItem
key={category}
className="flex cursor-pointer gap-2 smart-capitalize"
onClick={() => onCategorizeImage(category)}
>
<MdCategory />
{category}
</SelectorItem>
))}
</div>
</SelectorContent>
</Selector>
<TooltipContent>{t("categorizeImage")}</TooltipContent>
</Tooltip>
</div>
);
}

View File

@ -43,6 +43,8 @@ import { Trans, useTranslation } from "react-i18next";
import { LuPencil, LuTrash2 } from "react-icons/lu"; import { LuPencil, LuTrash2 } from "react-icons/lu";
import { toast } from "sonner"; import { toast } from "sonner";
import useSWR from "swr"; import useSWR from "swr";
import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
import { TbCategoryPlus } from "react-icons/tb";
type ModelTrainingViewProps = { type ModelTrainingViewProps = {
model: CustomClassificationModelConfig; model: CustomClassificationModelConfig;
@ -242,7 +244,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
</AlertDialogContent> </AlertDialogContent>
</AlertDialog> </AlertDialog>
<div className="flex flex-row justify-between gap-2 px-2 pt-2 align-middle"> <div className="flex flex-row justify-between gap-2 p-2 align-middle">
<LibrarySelector <LibrarySelector
pageToggle={pageToggle} pageToggle={pageToggle}
dataset={dataset || {}} dataset={dataset || {}}
@ -278,8 +280,10 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
{pageToggle == "train" ? ( {pageToggle == "train" ? (
<TrainGrid <TrainGrid
model={model} model={model}
categories={Object.keys(dataset || {})}
trainImages={trainImages || []} trainImages={trainImages || []}
selectedImages={selectedImages} selectedImages={selectedImages}
onRefresh={refreshTrain}
onClickImages={onClickImages} onClickImages={onClickImages}
onDelete={onDelete} onDelete={onDelete}
/> />
@ -557,16 +561,20 @@ function DatasetGrid({
type TrainGridProps = { type TrainGridProps = {
model: CustomClassificationModelConfig; model: CustomClassificationModelConfig;
categories: string[];
trainImages: string[]; trainImages: string[];
selectedImages: string[]; selectedImages: string[];
onClickImages: (images: string[], ctrl: boolean) => void; onClickImages: (images: string[], ctrl: boolean) => void;
onRefresh: () => void;
onDelete: (ids: string[]) => void; onDelete: (ids: string[]) => void;
}; };
function TrainGrid({ function TrainGrid({
model, model,
categories,
trainImages, trainImages,
selectedImages, selectedImages,
onClickImages, onClickImages,
onRefresh,
onDelete, onDelete,
}: TrainGridProps) { }: TrainGridProps) {
const { t } = useTranslation(["views/classificationModel"]); const { t } = useTranslation(["views/classificationModel"]);
@ -586,7 +594,7 @@ function TrainGrid({
); );
return ( return (
<div className="grid grid-cols-10 gap-2 overflow-y-auto p-2"> <div className="grid grid-cols-10 gap-2 overflow-y-auto px-2">
{trainData?.map((data) => ( {trainData?.map((data) => (
<div <div
key={data.timestamp} key={data.timestamp}
@ -619,6 +627,14 @@ function TrainGrid({
<div>{data.score}%</div> <div>{data.score}%</div>
</div> </div>
<div className="flex flex-row items-start justify-end gap-5 md:gap-4"> <div className="flex flex-row items-start justify-end gap-5 md:gap-4">
<ClassificationSelectionDialog
categories={categories}
modelName={model.name}
image={data.raw}
onRefresh={onRefresh}
>
<TbCategoryPlus className="size-5 cursor-pointer text-primary-variant hover:text-primary" />
</ClassificationSelectionDialog>
<Tooltip> <Tooltip>
<TooltipTrigger> <TooltipTrigger>
<LuTrash2 <LuTrash2