Implement classification model training

This commit is contained in:
Nicolas Mowen 2025-06-04 07:06:31 -06:00
parent b3cee44f06
commit f6c9413944
4 changed files with 237 additions and 6 deletions

View File

@ -21,7 +21,7 @@ from frigate.api.defs.request.classification_body import (
from frigate.api.defs.tags import Tags
from frigate.config import FrigateConfig
from frigate.config.camera import DetectConfig
from frigate.const import CLIPS_DIR, FACE_DIR, MODEL_CACHE_DIR
from frigate.const import CLIPS_DIR, FACE_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.models import Event
from frigate.util.classification import train_classification_model
@ -537,6 +537,59 @@ def delete_classification_dataset_images(
)
@router.post(
"/classification/{name}/dataset/categorize",
dependencies=[Depends(require_role(["admin"]))],
)
def categorize_classification_image(
request: Request, name: str, body: dict = None
):
config: FrigateConfig = request.app.frigate_config
if name not in config.classification.custom:
return JSONResponse(
content=(
{
"success": False,
"message": f"{name} is not a known classification model.",
}
),
status_code=404,
)
json: dict[str, Any] = body or {}
category = sanitize_filename(json.get("category", ""))
training_file_name = sanitize_filename(json.get("training_file", ""))
training_file = os.path.join(CLIPS_DIR, name, "train", training_file_name)
if training_file_name and not os.path.isfile(training_file):
return JSONResponse(
content=(
{
"success": False,
"message": f"Invalid filename or no file exists: {training_file_name}",
}
),
status_code=404,
)
new_name = f"{category}-{datetime.datetime.now().timestamp()}.png"
new_file_folder = os.path.join(CLIPS_DIR, name, "dataset", category)
if not os.path.exists(new_file_folder):
os.mkdir(new_file_folder)
# use opencv because webp images can not be used to train
img = cv2.imread(training_file)
cv2.imwrite(os.path.join(new_file_folder, new_name), img)
os.unlink(training_file)
return JSONResponse(
content=({"success": True, "message": "Successfully deleted faces."}),
status_code=200,
)
@router.post(
"/classification/{name}/train/delete",
dependencies=[Depends(require_role(["admin"]))],

View File

@ -8,11 +8,13 @@
"toast": {
"success": {
"deletedCategory": "Deleted Category",
"deletedImage": "Deleted Images"
"deletedImage": "Deleted Images",
"categorizedImage": "Successfully Categorized Image"
},
"error": {
"deleteImageFailed": "Failed to delete: {{errorMessage}}",
"deleteCategoryFailed": "Failed to delete category: {{errorMessage}}"
"deleteCategoryFailed": "Failed to delete category: {{errorMessage}}",
"categorizeFailed": "Failed to categorize image: {{errorMessage}}"
}
},
"deleteCategory": {
@ -38,5 +40,10 @@
"title": "Train",
"aria": "Select Train"
},
"categories": "Categories"
"categories": "Categories",
"createCategory": {
"new": "Create New Category"
},
"categorizeImageAs": "Categorize Image As:",
"categorizeImage": "Categorize Image"
}

View File

@ -0,0 +1,155 @@
import {
Drawer,
DrawerClose,
DrawerContent,
DrawerDescription,
DrawerHeader,
DrawerTitle,
DrawerTrigger,
} from "@/components/ui/drawer";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { isDesktop, isMobile } from "react-device-detect";
import { LuPlus } from "react-icons/lu";
import { useTranslation } from "react-i18next";
import { cn } from "@/lib/utils";
import React, { ReactNode, useCallback, useMemo, useState } from "react";
import TextEntryDialog from "./dialog/TextEntryDialog";
import { Button } from "../ui/button";
import { MdCategory } from "react-icons/md";
import axios from "axios";
import { toast } from "sonner";
type ClassificationSelectionDialogProps = {
className?: string;
categories: string[];
modelName: string;
image: string;
onRefresh: () => void;
children: ReactNode;
};
export default function ClassificationSelectionDialog({
className,
categories,
modelName,
image,
onRefresh,
children,
}: ClassificationSelectionDialogProps) {
const { t } = useTranslation(["views/classificationModel"]);
const onCategorizeImage = useCallback(
(category: string) => {
axios
.post(`/classification/${modelName}/dataset/categorize`, {
category,
training_file: image,
})
.then((resp) => {
if (resp.status == 200) {
toast.success(t("toast.success.categorizedImage"), {
position: "top-center",
});
onRefresh();
}
})
.catch((error) => {
const errorMessage =
error.response?.data?.message ||
error.response?.data?.detail ||
"Unknown error";
toast.error(t("toast.error.categorizeFailed", { errorMessage }), {
position: "top-center",
});
});
},
[modelName, image, onRefresh, t],
);
const isChildButton = useMemo(
() => React.isValidElement(children) && children.type === Button,
[children],
);
// control
const [newFace, setNewFace] = useState(false);
// components
const Selector = isDesktop ? DropdownMenu : Drawer;
const SelectorTrigger = isDesktop ? DropdownMenuTrigger : DrawerTrigger;
const SelectorContent = isDesktop ? DropdownMenuContent : DrawerContent;
const SelectorItem = isDesktop
? DropdownMenuItem
: (props: React.HTMLAttributes<HTMLDivElement>) => (
<DrawerClose asChild>
<div {...props} className={cn(props.className, "my-2")} />
</DrawerClose>
);
return (
<div className={className ?? ""}>
{newFace && (
<TextEntryDialog
open={true}
setOpen={setNewFace}
title={t("createCategory.new")}
onSave={(newCat) => onCategorizeImage(newCat)}
/>
)}
<Tooltip>
<Selector>
<SelectorTrigger asChild>
<TooltipTrigger asChild={isChildButton}>{children}</TooltipTrigger>
</SelectorTrigger>
<SelectorContent
className={cn("", isMobile && "mx-1 gap-2 rounded-t-2xl px-4")}
>
{isMobile && (
<DrawerHeader className="sr-only">
<DrawerTitle>Details</DrawerTitle>
<DrawerDescription>Details</DrawerDescription>
</DrawerHeader>
)}
<DropdownMenuLabel>{t("categorizeImageAs")}</DropdownMenuLabel>
<div
className={cn(
"flex max-h-[40dvh] flex-col overflow-y-auto",
isMobile && "gap-2 pb-4",
)}
>
<SelectorItem
className="flex cursor-pointer gap-2 smart-capitalize"
onClick={() => setNewFace(true)}
>
<LuPlus />
{t("createCategory.new")}
</SelectorItem>
{categories.sort().map((category) => (
<SelectorItem
key={category}
className="flex cursor-pointer gap-2 smart-capitalize"
onClick={() => onCategorizeImage(category)}
>
<MdCategory />
{category}
</SelectorItem>
))}
</div>
</SelectorContent>
</Selector>
<TooltipContent>{t("categorizeImage")}</TooltipContent>
</Tooltip>
</div>
);
}

View File

@ -43,6 +43,8 @@ import { Trans, useTranslation } from "react-i18next";
import { LuPencil, LuTrash2 } from "react-icons/lu";
import { toast } from "sonner";
import useSWR from "swr";
import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
import { TbCategoryPlus } from "react-icons/tb";
type ModelTrainingViewProps = {
model: CustomClassificationModelConfig;
@ -242,7 +244,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
</AlertDialogContent>
</AlertDialog>
<div className="flex flex-row justify-between gap-2 px-2 pt-2 align-middle">
<div className="flex flex-row justify-between gap-2 p-2 align-middle">
<LibrarySelector
pageToggle={pageToggle}
dataset={dataset || {}}
@ -278,8 +280,10 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
{pageToggle == "train" ? (
<TrainGrid
model={model}
categories={Object.keys(dataset || {})}
trainImages={trainImages || []}
selectedImages={selectedImages}
onRefresh={refreshTrain}
onClickImages={onClickImages}
onDelete={onDelete}
/>
@ -557,16 +561,20 @@ function DatasetGrid({
type TrainGridProps = {
model: CustomClassificationModelConfig;
categories: string[];
trainImages: string[];
selectedImages: string[];
onClickImages: (images: string[], ctrl: boolean) => void;
onRefresh: () => void;
onDelete: (ids: string[]) => void;
};
function TrainGrid({
model,
categories,
trainImages,
selectedImages,
onClickImages,
onRefresh,
onDelete,
}: TrainGridProps) {
const { t } = useTranslation(["views/classificationModel"]);
@ -586,7 +594,7 @@ function TrainGrid({
);
return (
<div className="grid grid-cols-10 gap-2 overflow-y-auto p-2">
<div className="grid grid-cols-10 gap-2 overflow-y-auto px-2">
{trainData?.map((data) => (
<div
key={data.timestamp}
@ -619,6 +627,14 @@ function TrainGrid({
<div>{data.score}%</div>
</div>
<div className="flex flex-row items-start justify-end gap-5 md:gap-4">
<ClassificationSelectionDialog
categories={categories}
modelName={model.name}
image={data.raw}
onRefresh={onRefresh}
>
<TbCategoryPlus className="size-5 cursor-pointer text-primary-variant hover:text-primary" />
</ClassificationSelectionDialog>
<Tooltip>
<TooltipTrigger>
<LuTrash2