Cleanup model running

This commit is contained in:
Nicolas Mowen 2025-01-02 15:07:50 -07:00
parent c7a787c858
commit c9445ac3f5
4 changed files with 128 additions and 17 deletions

View File

@ -25,7 +25,13 @@ def get_faces():
for name in os.listdir(FACE_DIR): for name in os.listdir(FACE_DIR):
face_dict[name] = [] face_dict[name] = []
for file in os.listdir(os.path.join(FACE_DIR, name)):
face_dir = os.path.join(FACE_DIR, name)
if not os.path.isdir(face_dir):
continue
for file in os.listdir(face_dir):
face_dict[name].append(file) face_dict[name].append(file)
return JSONResponse(status_code=200, content=face_dict) return JSONResponse(status_code=200, content=face_dict)
@ -41,18 +47,17 @@ async def register_face(request: Request, name: str, file: UploadFile):
) )
@router.post("/faces/{name}/train") @router.post("/faces/train/{name}/classify")
def train_face(name: str, body: dict = None): def train_face(name: str, body: dict = None):
json: dict[str, any] = body or {} json: dict[str, any] = body or {}
file_name = sanitize_filename(json.get("training_file", "")) training_file = os.path.join(FACE_DIR, f"train/{sanitize_filename(json.get("training_file", ""))}")
training_file = os.path.join(FACE_DIR, f"train/{file_name}")
if not file_name or not os.path.isfile(training_file): if not training_file or not os.path.isfile(training_file):
return JSONResponse( return JSONResponse(
content=( content=(
{ {
"success": False, "success": False,
"message": f"Invalid filename or no file exists: {file_name}", "message": f"Invalid filename or no file exists: {training_file}",
} }
), ),
status_code=404, status_code=404,
@ -66,7 +71,7 @@ def train_face(name: str, body: dict = None):
content=( content=(
{ {
"success": True, "success": True,
"message": f"Successfully saved {file_name} as {new_name}.", "message": f"Successfully saved {training_file} as {new_name}.",
} }
), ),
status_code=200, status_code=200,

View File

@ -163,7 +163,10 @@ class FaceClassificationModel:
self.config = config self.config = config
self.db = db self.db = db
self.landmark_detector = cv2.face.createFacemarkLBF() self.landmark_detector = cv2.face.createFacemarkLBF()
self.landmark_detector.loadModel("/config/model_cache/facedet/landmarkdet.yaml")
if os.path.isfile("/config/model_cache/facedet/landmarkdet.yaml"):
self.landmark_detector.loadModel("/config/model_cache/facedet/landmarkdet.yaml")
self.recognizer: cv2.face.LBPHFaceRecognizer = ( self.recognizer: cv2.face.LBPHFaceRecognizer = (
cv2.face.LBPHFaceRecognizer_create( cv2.face.LBPHFaceRecognizer_create(
radius=2, threshold=(1 - config.min_score) * 1000 radius=2, threshold=(1 - config.min_score) * 1000
@ -178,13 +181,21 @@ class FaceClassificationModel:
dir = "/media/frigate/clips/faces" dir = "/media/frigate/clips/faces"
for idx, name in enumerate(os.listdir(dir)): for idx, name in enumerate(os.listdir(dir)):
if name == "debug": if name == "train":
continue
face_folder = os.path.join(dir, name)
if not os.path.isdir(face_folder):
continue continue
self.label_map[idx] = name self.label_map[idx] = name
face_folder = os.path.join(dir, name)
for image in os.listdir(face_folder): for image in os.listdir(face_folder):
img = cv2.imread(os.path.join(face_folder, image)) img = cv2.imread(os.path.join(face_folder, image))
if img is None:
continue
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img = self.__align_face(img, img.shape[1], img.shape[0]) img = self.__align_face(img, img.shape[1], img.shape[0])
faces.append(img) faces.append(img)

View File

@ -0,0 +1,25 @@
import { forwardRef } from "react";
import { LuPlus, LuScanFace } from "react-icons/lu";
import { cn } from "@/lib/utils";
type AddFaceIconProps = {
className?: string;
onClick?: () => void;
};
const AddFaceIcon = forwardRef<HTMLDivElement, AddFaceIconProps>(
({ className, onClick }, ref) => {
return (
<div
ref={ref}
className={cn("relative flex items-center", className)}
onClick={onClick}
>
<LuScanFace className="size-full" />
<LuPlus className="absolute size-4 translate-x-3 translate-y-3" />
</div>
);
},
);
export default AddFaceIcon;

View File

@ -1,7 +1,14 @@
import { baseUrl } from "@/api/baseUrl"; import { baseUrl } from "@/api/baseUrl";
import Chip from "@/components/indicators/Chip"; import AddFaceIcon from "@/components/icons/AddFaceIcon";
import UploadImageDialog from "@/components/overlay/dialog/UploadImageDialog"; import UploadImageDialog from "@/components/overlay/dialog/UploadImageDialog";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { ScrollArea, ScrollBar } from "@/components/ui/scroll-area"; import { ScrollArea, ScrollBar } from "@/components/ui/scroll-area";
import { Toaster } from "@/components/ui/sonner"; import { Toaster } from "@/components/ui/sonner";
import { ToggleGroup, ToggleGroupItem } from "@/components/ui/toggle-group"; import { ToggleGroup, ToggleGroupItem } from "@/components/ui/toggle-group";
@ -13,8 +20,7 @@ import {
import useOptimisticState from "@/hooks/use-optimistic-state"; import useOptimisticState from "@/hooks/use-optimistic-state";
import axios from "axios"; import axios from "axios";
import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { isDesktop } from "react-device-detect"; import { LuImagePlus, LuTrash2 } from "react-icons/lu";
import { LuImagePlus, LuTrash, LuTrash2 } from "react-icons/lu";
import { toast } from "sonner"; import { toast } from "sonner";
import useSWR from "swr"; import useSWR from "swr";
@ -154,7 +160,11 @@ export default function FaceLibrary() {
</div> </div>
{pageToggle && {pageToggle &&
(pageToggle == "train" ? ( (pageToggle == "train" ? (
<TrainingGrid attemptImages={trainImages} onRefresh={refreshFaces} /> <TrainingGrid
attemptImages={trainImages}
faceNames={faces}
onRefresh={refreshFaces}
/>
) : ( ) : (
<FaceGrid <FaceGrid
faceImages={faceImages} faceImages={faceImages}
@ -169,13 +179,23 @@ export default function FaceLibrary() {
type TrainingGridProps = { type TrainingGridProps = {
attemptImages: string[]; attemptImages: string[];
faceNames: string[];
onRefresh: () => void; onRefresh: () => void;
}; };
function TrainingGrid({ attemptImages, onRefresh }: TrainingGridProps) { function TrainingGrid({
attemptImages,
faceNames,
onRefresh,
}: TrainingGridProps) {
return ( return (
<div className="scrollbar-container flex flex-wrap gap-2 overflow-y-scroll"> <div className="scrollbar-container flex flex-wrap gap-2 overflow-y-scroll">
{attemptImages.map((image: string) => ( {attemptImages.map((image: string) => (
<FaceAttempt key={image} image={image} onRefresh={onRefresh} /> <FaceAttempt
key={image}
image={image}
faceNames={faceNames}
onRefresh={onRefresh}
/>
))} ))}
</div> </div>
); );
@ -183,9 +203,10 @@ function TrainingGrid({ attemptImages, onRefresh }: TrainingGridProps) {
type FaceAttemptProps = { type FaceAttemptProps = {
image: string; image: string;
faceNames: string[];
onRefresh: () => void; onRefresh: () => void;
}; };
function FaceAttempt({ image, onRefresh }: FaceAttemptProps) { function FaceAttempt({ image, faceNames, onRefresh }: FaceAttemptProps) {
const data = useMemo(() => { const data = useMemo(() => {
const parts = image.split("-"); const parts = image.split("-");
@ -196,6 +217,33 @@ function FaceAttempt({ image, onRefresh }: FaceAttemptProps) {
}; };
}, [image]); }, [image]);
const onTrainAttempt = useCallback(
(trainName: string) => {
axios
.post(`/faces/train/${trainName}/classify`, { training_file: image })
.then((resp) => {
if (resp.status == 200) {
toast.success(`Successfully trained face.`, {
position: "top-center",
});
onRefresh();
}
})
.catch((error) => {
if (error.response?.data?.message) {
toast.error(`Failed to train: ${error.response.data.message}`, {
position: "top-center",
});
} else {
toast.error(`Failed to train: ${error.message}`, {
position: "top-center",
});
}
});
},
[image, onRefresh],
);
const onDelete = useCallback(() => { const onDelete = useCallback(() => {
axios axios
.post(`/faces/train/delete`, { ids: [image] }) .post(`/faces/train/delete`, { ids: [image] })
@ -232,6 +280,28 @@ function FaceAttempt({ image, onRefresh }: FaceAttemptProps) {
<div>{Number.parseFloat(data.score) * 100}%</div> <div>{Number.parseFloat(data.score) * 100}%</div>
</div> </div>
<div className="flex flex-row items-start justify-end gap-5 md:gap-4"> <div className="flex flex-row items-start justify-end gap-5 md:gap-4">
<Tooltip>
<DropdownMenu>
<DropdownMenuTrigger>
<TooltipTrigger>
<AddFaceIcon className="size-5 cursor-pointer text-primary-variant hover:text-primary" />
</TooltipTrigger>
</DropdownMenuTrigger>
<DropdownMenuContent>
<DropdownMenuLabel>Train Face as:</DropdownMenuLabel>
{faceNames.map((faceName) => (
<DropdownMenuItem
key={faceName}
className="cursor-pointer capitalize"
onClick={() => onTrainAttempt(faceName)}
>
{faceName}
</DropdownMenuItem>
))}
</DropdownMenuContent>
</DropdownMenu>
<TooltipContent>Train Face as Person</TooltipContent>
</Tooltip>
<Tooltip> <Tooltip>
<TooltipTrigger> <TooltipTrigger>
<LuTrash2 <LuTrash2