diff --git a/frigate/api/classification.py b/frigate/api/classification.py index 855e0dcc6..9c21384a5 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -25,7 +25,13 @@ def get_faces(): for name in os.listdir(FACE_DIR): face_dict[name] = [] - for file in os.listdir(os.path.join(FACE_DIR, name)): + + face_dir = os.path.join(FACE_DIR, name) + + if not os.path.isdir(face_dir): + continue + + for file in os.listdir(face_dir): face_dict[name].append(file) return JSONResponse(status_code=200, content=face_dict) @@ -41,18 +47,17 @@ async def register_face(request: Request, name: str, file: UploadFile): ) -@router.post("/faces/{name}/train") +@router.post("/faces/train/{name}/classify") def train_face(name: str, body: dict = None): json: dict[str, any] = body or {} - file_name = sanitize_filename(json.get("training_file", "")) - training_file = os.path.join(FACE_DIR, f"train/{file_name}") + training_file = os.path.join(FACE_DIR, f"train/{sanitize_filename(json.get("training_file", ""))}") - if not file_name or not os.path.isfile(training_file): + if not training_file or not os.path.isfile(training_file): return JSONResponse( content=( { "success": False, - "message": f"Invalid filename or no file exists: {file_name}", + "message": f"Invalid filename or no file exists: {training_file}", } ), status_code=404, @@ -66,7 +71,7 @@ def train_face(name: str, body: dict = None): content=( { "success": True, - "message": f"Successfully saved {file_name} as {new_name}.", + "message": f"Successfully saved {training_file} as {new_name}.", } ), status_code=200, diff --git a/frigate/util/model.py b/frigate/util/model.py index c95bc48b9..28bfbdad1 100644 --- a/frigate/util/model.py +++ b/frigate/util/model.py @@ -163,7 +163,10 @@ class FaceClassificationModel: self.config = config self.db = db self.landmark_detector = cv2.face.createFacemarkLBF() - self.landmark_detector.loadModel("/config/model_cache/facedet/landmarkdet.yaml") + + if os.path.isfile("/config/model_cache/facedet/landmarkdet.yaml"): + self.landmark_detector.loadModel("/config/model_cache/facedet/landmarkdet.yaml") + self.recognizer: cv2.face.LBPHFaceRecognizer = ( cv2.face.LBPHFaceRecognizer_create( radius=2, threshold=(1 - config.min_score) * 1000 @@ -178,13 +181,21 @@ class FaceClassificationModel: dir = "/media/frigate/clips/faces" for idx, name in enumerate(os.listdir(dir)): - if name == "debug": + if name == "train": + continue + + face_folder = os.path.join(dir, name) + + if not os.path.isdir(face_folder): continue self.label_map[idx] = name - face_folder = os.path.join(dir, name) for image in os.listdir(face_folder): img = cv2.imread(os.path.join(face_folder, image)) + + if img is None: + continue + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) img = self.__align_face(img, img.shape[1], img.shape[0]) faces.append(img) diff --git a/web/src/components/icons/AddFaceIcon.tsx b/web/src/components/icons/AddFaceIcon.tsx new file mode 100644 index 000000000..ce06120cc --- /dev/null +++ b/web/src/components/icons/AddFaceIcon.tsx @@ -0,0 +1,25 @@ +import { forwardRef } from "react"; +import { LuPlus, LuScanFace } from "react-icons/lu"; +import { cn } from "@/lib/utils"; + +type AddFaceIconProps = { + className?: string; + onClick?: () => void; +}; + +const AddFaceIcon = forwardRef( + ({ className, onClick }, ref) => { + return ( +
+ + +
+ ); + }, +); + +export default AddFaceIcon; diff --git a/web/src/pages/FaceLibrary.tsx b/web/src/pages/FaceLibrary.tsx index 28bdd1ea0..c721916a5 100644 --- a/web/src/pages/FaceLibrary.tsx +++ b/web/src/pages/FaceLibrary.tsx @@ -1,7 +1,14 @@ import { baseUrl } from "@/api/baseUrl"; -import Chip from "@/components/indicators/Chip"; +import AddFaceIcon from "@/components/icons/AddFaceIcon"; import UploadImageDialog from "@/components/overlay/dialog/UploadImageDialog"; import { Button } from "@/components/ui/button"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuLabel, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; import { ScrollArea, ScrollBar } from "@/components/ui/scroll-area"; import { Toaster } from "@/components/ui/sonner"; import { ToggleGroup, ToggleGroupItem } from "@/components/ui/toggle-group"; @@ -13,8 +20,7 @@ import { import useOptimisticState from "@/hooks/use-optimistic-state"; import axios from "axios"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; -import { isDesktop } from "react-device-detect"; -import { LuImagePlus, LuTrash, LuTrash2 } from "react-icons/lu"; +import { LuImagePlus, LuTrash2 } from "react-icons/lu"; import { toast } from "sonner"; import useSWR from "swr"; @@ -154,7 +160,11 @@ export default function FaceLibrary() { {pageToggle && (pageToggle == "train" ? ( - + ) : ( void; }; -function TrainingGrid({ attemptImages, onRefresh }: TrainingGridProps) { +function TrainingGrid({ + attemptImages, + faceNames, + onRefresh, +}: TrainingGridProps) { return (
{attemptImages.map((image: string) => ( - + ))}
); @@ -183,9 +203,10 @@ function TrainingGrid({ attemptImages, onRefresh }: TrainingGridProps) { type FaceAttemptProps = { image: string; + faceNames: string[]; onRefresh: () => void; }; -function FaceAttempt({ image, onRefresh }: FaceAttemptProps) { +function FaceAttempt({ image, faceNames, onRefresh }: FaceAttemptProps) { const data = useMemo(() => { const parts = image.split("-"); @@ -196,6 +217,33 @@ function FaceAttempt({ image, onRefresh }: FaceAttemptProps) { }; }, [image]); + const onTrainAttempt = useCallback( + (trainName: string) => { + axios + .post(`/faces/train/${trainName}/classify`, { training_file: image }) + .then((resp) => { + if (resp.status == 200) { + toast.success(`Successfully trained face.`, { + position: "top-center", + }); + onRefresh(); + } + }) + .catch((error) => { + if (error.response?.data?.message) { + toast.error(`Failed to train: ${error.response.data.message}`, { + position: "top-center", + }); + } else { + toast.error(`Failed to train: ${error.message}`, { + position: "top-center", + }); + } + }); + }, + [image, onRefresh], + ); + const onDelete = useCallback(() => { axios .post(`/faces/train/delete`, { ids: [image] }) @@ -232,6 +280,28 @@ function FaceAttempt({ image, onRefresh }: FaceAttemptProps) {
{Number.parseFloat(data.score) * 100}%
+ + + + + + + + + Train Face as: + {faceNames.map((faceName) => ( + onTrainAttempt(faceName)} + > + {faceName} + + ))} + + + Train Face as Person +