mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-05-01 19:17:41 +03:00
Implement classification model training
This commit is contained in:
parent
b3cee44f06
commit
f6c9413944
@ -21,7 +21,7 @@ from frigate.api.defs.request.classification_body import (
|
||||
from frigate.api.defs.tags import Tags
|
||||
from frigate.config import FrigateConfig
|
||||
from frigate.config.camera import DetectConfig
|
||||
from frigate.const import CLIPS_DIR, FACE_DIR, MODEL_CACHE_DIR
|
||||
from frigate.const import CLIPS_DIR, FACE_DIR
|
||||
from frigate.embeddings import EmbeddingsContext
|
||||
from frigate.models import Event
|
||||
from frigate.util.classification import train_classification_model
|
||||
@ -537,6 +537,59 @@ def delete_classification_dataset_images(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/classification/{name}/dataset/categorize",
|
||||
dependencies=[Depends(require_role(["admin"]))],
|
||||
)
|
||||
def categorize_classification_image(
|
||||
request: Request, name: str, body: dict = None
|
||||
):
|
||||
config: FrigateConfig = request.app.frigate_config
|
||||
|
||||
if name not in config.classification.custom:
|
||||
return JSONResponse(
|
||||
content=(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"{name} is not a known classification model.",
|
||||
}
|
||||
),
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
json: dict[str, Any] = body or {}
|
||||
category = sanitize_filename(json.get("category", ""))
|
||||
training_file_name = sanitize_filename(json.get("training_file", ""))
|
||||
training_file = os.path.join(CLIPS_DIR, name, "train", training_file_name)
|
||||
|
||||
if training_file_name and not os.path.isfile(training_file):
|
||||
return JSONResponse(
|
||||
content=(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"Invalid filename or no file exists: {training_file_name}",
|
||||
}
|
||||
),
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
new_name = f"{category}-{datetime.datetime.now().timestamp()}.png"
|
||||
new_file_folder = os.path.join(CLIPS_DIR, name, "dataset", category)
|
||||
|
||||
if not os.path.exists(new_file_folder):
|
||||
os.mkdir(new_file_folder)
|
||||
|
||||
# use opencv because webp images can not be used to train
|
||||
img = cv2.imread(training_file)
|
||||
cv2.imwrite(os.path.join(new_file_folder, new_name), img)
|
||||
os.unlink(training_file)
|
||||
|
||||
return JSONResponse(
|
||||
content=({"success": True, "message": "Successfully deleted faces."}),
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/classification/{name}/train/delete",
|
||||
dependencies=[Depends(require_role(["admin"]))],
|
||||
|
||||
@ -8,11 +8,13 @@
|
||||
"toast": {
|
||||
"success": {
|
||||
"deletedCategory": "Deleted Category",
|
||||
"deletedImage": "Deleted Images"
|
||||
"deletedImage": "Deleted Images",
|
||||
"categorizedImage": "Successfully Categorized Image"
|
||||
},
|
||||
"error": {
|
||||
"deleteImageFailed": "Failed to delete: {{errorMessage}}",
|
||||
"deleteCategoryFailed": "Failed to delete category: {{errorMessage}}"
|
||||
"deleteCategoryFailed": "Failed to delete category: {{errorMessage}}",
|
||||
"categorizeFailed": "Failed to categorize image: {{errorMessage}}"
|
||||
}
|
||||
},
|
||||
"deleteCategory": {
|
||||
@ -38,5 +40,10 @@
|
||||
"title": "Train",
|
||||
"aria": "Select Train"
|
||||
},
|
||||
"categories": "Categories"
|
||||
"categories": "Categories",
|
||||
"createCategory": {
|
||||
"new": "Create New Category"
|
||||
},
|
||||
"categorizeImageAs": "Categorize Image As:",
|
||||
"categorizeImage": "Categorize Image"
|
||||
}
|
||||
|
||||
155
web/src/components/overlay/ClassificationSelectionDialog.tsx
Normal file
155
web/src/components/overlay/ClassificationSelectionDialog.tsx
Normal file
@ -0,0 +1,155 @@
|
||||
import {
|
||||
Drawer,
|
||||
DrawerClose,
|
||||
DrawerContent,
|
||||
DrawerDescription,
|
||||
DrawerHeader,
|
||||
DrawerTitle,
|
||||
DrawerTrigger,
|
||||
} from "@/components/ui/drawer";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuLabel,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { isDesktop, isMobile } from "react-device-detect";
|
||||
import { LuPlus } from "react-icons/lu";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { cn } from "@/lib/utils";
|
||||
import React, { ReactNode, useCallback, useMemo, useState } from "react";
|
||||
import TextEntryDialog from "./dialog/TextEntryDialog";
|
||||
import { Button } from "../ui/button";
|
||||
import { MdCategory } from "react-icons/md";
|
||||
import axios from "axios";
|
||||
import { toast } from "sonner";
|
||||
|
||||
type ClassificationSelectionDialogProps = {
|
||||
className?: string;
|
||||
categories: string[];
|
||||
modelName: string;
|
||||
image: string;
|
||||
onRefresh: () => void;
|
||||
children: ReactNode;
|
||||
};
|
||||
export default function ClassificationSelectionDialog({
|
||||
className,
|
||||
categories,
|
||||
modelName,
|
||||
image,
|
||||
onRefresh,
|
||||
children,
|
||||
}: ClassificationSelectionDialogProps) {
|
||||
const { t } = useTranslation(["views/classificationModel"]);
|
||||
|
||||
const onCategorizeImage = useCallback(
|
||||
(category: string) => {
|
||||
axios
|
||||
.post(`/classification/${modelName}/dataset/categorize`, {
|
||||
category,
|
||||
training_file: image,
|
||||
})
|
||||
.then((resp) => {
|
||||
if (resp.status == 200) {
|
||||
toast.success(t("toast.success.categorizedImage"), {
|
||||
position: "top-center",
|
||||
});
|
||||
onRefresh();
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
const errorMessage =
|
||||
error.response?.data?.message ||
|
||||
error.response?.data?.detail ||
|
||||
"Unknown error";
|
||||
toast.error(t("toast.error.categorizeFailed", { errorMessage }), {
|
||||
position: "top-center",
|
||||
});
|
||||
});
|
||||
},
|
||||
[modelName, image, onRefresh, t],
|
||||
);
|
||||
|
||||
const isChildButton = useMemo(
|
||||
() => React.isValidElement(children) && children.type === Button,
|
||||
[children],
|
||||
);
|
||||
|
||||
// control
|
||||
const [newFace, setNewFace] = useState(false);
|
||||
|
||||
// components
|
||||
const Selector = isDesktop ? DropdownMenu : Drawer;
|
||||
const SelectorTrigger = isDesktop ? DropdownMenuTrigger : DrawerTrigger;
|
||||
const SelectorContent = isDesktop ? DropdownMenuContent : DrawerContent;
|
||||
const SelectorItem = isDesktop
|
||||
? DropdownMenuItem
|
||||
: (props: React.HTMLAttributes<HTMLDivElement>) => (
|
||||
<DrawerClose asChild>
|
||||
<div {...props} className={cn(props.className, "my-2")} />
|
||||
</DrawerClose>
|
||||
);
|
||||
|
||||
return (
|
||||
<div className={className ?? ""}>
|
||||
{newFace && (
|
||||
<TextEntryDialog
|
||||
open={true}
|
||||
setOpen={setNewFace}
|
||||
title={t("createCategory.new")}
|
||||
onSave={(newCat) => onCategorizeImage(newCat)}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Tooltip>
|
||||
<Selector>
|
||||
<SelectorTrigger asChild>
|
||||
<TooltipTrigger asChild={isChildButton}>{children}</TooltipTrigger>
|
||||
</SelectorTrigger>
|
||||
<SelectorContent
|
||||
className={cn("", isMobile && "mx-1 gap-2 rounded-t-2xl px-4")}
|
||||
>
|
||||
{isMobile && (
|
||||
<DrawerHeader className="sr-only">
|
||||
<DrawerTitle>Details</DrawerTitle>
|
||||
<DrawerDescription>Details</DrawerDescription>
|
||||
</DrawerHeader>
|
||||
)}
|
||||
<DropdownMenuLabel>{t("categorizeImageAs")}</DropdownMenuLabel>
|
||||
<div
|
||||
className={cn(
|
||||
"flex max-h-[40dvh] flex-col overflow-y-auto",
|
||||
isMobile && "gap-2 pb-4",
|
||||
)}
|
||||
>
|
||||
<SelectorItem
|
||||
className="flex cursor-pointer gap-2 smart-capitalize"
|
||||
onClick={() => setNewFace(true)}
|
||||
>
|
||||
<LuPlus />
|
||||
{t("createCategory.new")}
|
||||
</SelectorItem>
|
||||
{categories.sort().map((category) => (
|
||||
<SelectorItem
|
||||
key={category}
|
||||
className="flex cursor-pointer gap-2 smart-capitalize"
|
||||
onClick={() => onCategorizeImage(category)}
|
||||
>
|
||||
<MdCategory />
|
||||
{category}
|
||||
</SelectorItem>
|
||||
))}
|
||||
</div>
|
||||
</SelectorContent>
|
||||
</Selector>
|
||||
<TooltipContent>{t("categorizeImage")}</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -43,6 +43,8 @@ import { Trans, useTranslation } from "react-i18next";
|
||||
import { LuPencil, LuTrash2 } from "react-icons/lu";
|
||||
import { toast } from "sonner";
|
||||
import useSWR from "swr";
|
||||
import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
|
||||
import { TbCategoryPlus } from "react-icons/tb";
|
||||
|
||||
type ModelTrainingViewProps = {
|
||||
model: CustomClassificationModelConfig;
|
||||
@ -242,7 +244,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
|
||||
<div className="flex flex-row justify-between gap-2 px-2 pt-2 align-middle">
|
||||
<div className="flex flex-row justify-between gap-2 p-2 align-middle">
|
||||
<LibrarySelector
|
||||
pageToggle={pageToggle}
|
||||
dataset={dataset || {}}
|
||||
@ -278,8 +280,10 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
||||
{pageToggle == "train" ? (
|
||||
<TrainGrid
|
||||
model={model}
|
||||
categories={Object.keys(dataset || {})}
|
||||
trainImages={trainImages || []}
|
||||
selectedImages={selectedImages}
|
||||
onRefresh={refreshTrain}
|
||||
onClickImages={onClickImages}
|
||||
onDelete={onDelete}
|
||||
/>
|
||||
@ -557,16 +561,20 @@ function DatasetGrid({
|
||||
|
||||
type TrainGridProps = {
|
||||
model: CustomClassificationModelConfig;
|
||||
categories: string[];
|
||||
trainImages: string[];
|
||||
selectedImages: string[];
|
||||
onClickImages: (images: string[], ctrl: boolean) => void;
|
||||
onRefresh: () => void;
|
||||
onDelete: (ids: string[]) => void;
|
||||
};
|
||||
function TrainGrid({
|
||||
model,
|
||||
categories,
|
||||
trainImages,
|
||||
selectedImages,
|
||||
onClickImages,
|
||||
onRefresh,
|
||||
onDelete,
|
||||
}: TrainGridProps) {
|
||||
const { t } = useTranslation(["views/classificationModel"]);
|
||||
@ -586,7 +594,7 @@ function TrainGrid({
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="grid grid-cols-10 gap-2 overflow-y-auto p-2">
|
||||
<div className="grid grid-cols-10 gap-2 overflow-y-auto px-2">
|
||||
{trainData?.map((data) => (
|
||||
<div
|
||||
key={data.timestamp}
|
||||
@ -619,6 +627,14 @@ function TrainGrid({
|
||||
<div>{data.score}%</div>
|
||||
</div>
|
||||
<div className="flex flex-row items-start justify-end gap-5 md:gap-4">
|
||||
<ClassificationSelectionDialog
|
||||
categories={categories}
|
||||
modelName={model.name}
|
||||
image={data.raw}
|
||||
onRefresh={onRefresh}
|
||||
>
|
||||
<TbCategoryPlus className="size-5 cursor-pointer text-primary-variant hover:text-primary" />
|
||||
</ClassificationSelectionDialog>
|
||||
<Tooltip>
|
||||
<TooltipTrigger>
|
||||
<LuTrash2
|
||||
|
||||
Loading…
Reference in New Issue
Block a user