diff --git a/frigate/api/classification.py b/frigate/api/classification.py index 5e1087d17..32a466a03 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -890,7 +890,8 @@ def rename_classification_category( dependencies=[Depends(require_role(["admin"]))], summary="Categorize a classification image", description="""Categorizes a specific classification image for a given classification model and category. - The image must exist in the specified category. Returns a success message or an error if the name or category is invalid.""", + Accepts either a training file from the train directory or an event_id to extract + the object crop from. Returns a success message or an error if the name or category is invalid.""", ) def categorize_classification_image(request: Request, name: str, body: dict = None): config: FrigateConfig = request.app.frigate_config @@ -909,19 +910,17 @@ def categorize_classification_image(request: Request, name: str, body: dict = No json: dict[str, Any] = body or {} category = sanitize_filename(json.get("category", "")) training_file_name = sanitize_filename(json.get("training_file", "")) - training_file = os.path.join( - CLIPS_DIR, sanitize_filename(name), "train", training_file_name - ) + event_id = json.get("event_id") - if training_file_name and not os.path.isfile(training_file): + if not training_file_name and not event_id: return JSONResponse( content=( { "success": False, - "message": f"Invalid filename or no file exists: {training_file_name}", + "message": "A training file or event_id must be passed.", } ), - status_code=404, + status_code=400, ) random_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6)) @@ -933,10 +932,116 @@ def categorize_classification_image(request: Request, name: str, body: dict = No os.makedirs(new_file_folder, exist_ok=True) - # use opencv because webp images can not be used to train - img = cv2.imread(training_file) - cv2.imwrite(os.path.join(new_file_folder, new_name), img) - os.unlink(training_file) + if training_file_name: + # Use existing training file + training_file = os.path.join( + CLIPS_DIR, sanitize_filename(name), "train", training_file_name + ) + + if not os.path.isfile(training_file): + return JSONResponse( + content=( + { + "success": False, + "message": f"Invalid filename or no file exists: {training_file_name}", + } + ), + status_code=404, + ) + + # use opencv because webp images can not be used to train + img = cv2.imread(training_file) + cv2.imwrite(os.path.join(new_file_folder, new_name), img) + os.unlink(training_file) + else: + # Extract from event + try: + event: Event = Event.get(Event.id == event_id) + except DoesNotExist: + return JSONResponse( + content=( + { + "success": False, + "message": f"Invalid event_id or no event exists: {event_id}", + } + ), + status_code=404, + ) + + snapshot = get_event_snapshot(event) + + if snapshot is None: + return JSONResponse( + content=( + { + "success": False, + "message": f"Failed to read snapshot for event {event_id}.", + } + ), + status_code=500, + ) + + # Get object bounding box for the first detection + if not event.data.get("attributes") or len(event.data["attributes"]) == 0: + return JSONResponse( + content=( + { + "success": False, + "message": f"Event {event_id} has no detection attributes.", + } + ), + status_code=400, + ) + + # Use the first attribute's box + box = event.data["attributes"][0]["box"] + + try: + # Extract the crop from the snapshot + frame = snapshot + + height, width = frame.shape[:2] + + # Convert relative coordinates to absolute + x1 = int(box[0] * width) + y1 = int(box[1] * height) + x2 = int(box[2] * width) + y2 = int(box[3] * height) + + # Ensure coordinates are within frame boundaries + x1 = max(0, x1) + y1 = max(0, y1) + x2 = min(width, x2) + y2 = min(height, y2) + + # Extract the crop + crop = frame[y1:y2, x1:x2] + + if crop.size == 0: + return JSONResponse( + content=( + { + "success": False, + "message": f"Failed to extract crop from event {event_id}.", + } + ), + status_code=500, + ) + + # Save the crop + cv2.imwrite(os.path.join(new_file_folder, new_name), crop) + + except Exception as e: + logger.error(f"Failed to extract classification crop: {e}") + return JSONResponse( + content=( + { + "success": False, + "message": f"Failed to process event {event_id}: {str(e)}", + } + ), + status_code=500, + ) return JSONResponse( content=({"success": True, "message": "Successfully categorized image."}), diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json index a07114b5c..499b25d35 100644 --- a/web/public/locales/en/views/classificationModel.json +++ b/web/public/locales/en/views/classificationModel.json @@ -13,7 +13,8 @@ "trainModel": "Train Model", "addClassification": "Add Classification", "deleteModels": "Delete Models", - "editModel": "Edit Model" + "editModel": "Edit Model", + "categorizeImages": "Classify Images" }, "tooltip": { "trainingInProgress": "Model is currently training", @@ -28,6 +29,7 @@ "deletedModel_one": "Successfully deleted {{count}} model", "deletedModel_other": "Successfully deleted {{count}} models", "categorizedImage": "Successfully Classified Image", + "batchCategorized": "Successfully classified {{count}} images", "trainedModel": "Successfully trained model.", "trainingModel": "Successfully started model training.", "updatedModel": "Successfully updated model configuration", @@ -38,10 +40,14 @@ "deleteCategoryFailed": "Failed to delete class: {{errorMessage}}", "deleteModelFailed": "Failed to delete model: {{errorMessage}}", "categorizeFailed": "Failed to categorize image: {{errorMessage}}", + "batchCategorizeFailed": "Failed to classify {{count}} images", "trainingFailed": "Model training failed. Check Frigate logs for details.", "trainingFailedToStart": "Failed to start model training: {{errorMessage}}", "updateModelFailed": "Failed to update model: {{errorMessage}}", "renameCategoryFailed": "Failed to rename class: {{errorMessage}}" + }, + "warning": { + "partialBatchCategorized": "Classified {{success}} of {{total}} images successfully." } }, "deleteCategory": { diff --git a/web/public/locales/en/views/explore.json b/web/public/locales/en/views/explore.json index 53b04e6c4..2deb4611b 100644 --- a/web/public/locales/en/views/explore.json +++ b/web/public/locales/en/views/explore.json @@ -143,6 +143,15 @@ }, "recognizedLicensePlate": "Recognized License Plate", "attributes": "Classification Attributes", + "assignment": { + "title": "Assign To", + "assignToFace": "Assign to Face", + "assignToClassification": "Assign to {{model}}", + "faceSuccess": "Successfully assigned to face: {{name}}", + "faceFailed": "Failed to assign to face: {{errorMessage}}", + "classificationSuccess": "Successfully assigned to {{model}} - {{category}}", + "classificationFailed": "Failed to assign classification: {{errorMessage}}" + }, "estimatedSpeed": "Estimated Speed", "objects": "Objects", "camera": "Camera", diff --git a/web/public/locales/en/views/faceLibrary.json b/web/public/locales/en/views/faceLibrary.json index 354049156..593715261 100644 --- a/web/public/locales/en/views/faceLibrary.json +++ b/web/public/locales/en/views/faceLibrary.json @@ -53,7 +53,8 @@ "renameFace": "Rename Face", "deleteFace": "Delete Face", "uploadImage": "Upload Image", - "reprocessFace": "Reprocess Face" + "reprocessFace": "Reprocess Face", + "trainFaces": "Train Faces" }, "imageEntry": { "validation": { @@ -77,6 +78,7 @@ "deletedName_other": "{{count}} faces have been successfully deleted.", "renamedFace": "Successfully renamed face to {{name}}", "trainedFace": "Successfully trained face.", + "batchTrainedFaces": "Successfully trained {{count}} faces.", "updatedFaceScore": "Successfully updated face score to {{name}} ({{score}})." }, "error": { @@ -86,7 +88,11 @@ "deleteNameFailed": "Failed to delete name: {{errorMessage}}", "renameFaceFailed": "Failed to rename face: {{errorMessage}}", "trainFailed": "Failed to train: {{errorMessage}}", + "batchTrainFailed": "Failed to train {{count}} faces.", "updateFaceScoreFailed": "Failed to update face score: {{errorMessage}}" + }, + "warning": { + "partialBatchTrained": "Trained {{success}} of {{total}} faces successfully." } } } diff --git a/web/src/components/overlay/ClassificationSelectionDialog.tsx b/web/src/components/overlay/ClassificationSelectionDialog.tsx index 6398348a4..8e2037f18 100644 --- a/web/src/components/overlay/ClassificationSelectionDialog.tsx +++ b/web/src/components/overlay/ClassificationSelectionDialog.tsx @@ -35,6 +35,7 @@ type ClassificationSelectionDialogProps = { modelName: string; image: string; onRefresh: () => void; + onCategorize?: (category: string) => void; // Optional custom categorize handler children: ReactNode; }; export default function ClassificationSelectionDialog({ @@ -43,12 +44,20 @@ export default function ClassificationSelectionDialog({ modelName, image, onRefresh, + onCategorize, children, }: ClassificationSelectionDialogProps) { const { t } = useTranslation(["views/classificationModel"]); const onCategorizeImage = useCallback( (category: string) => { + // If custom categorize handler is provided, use it instead + if (onCategorize) { + onCategorize(category); + return; + } + + // Default behavior: categorize single image axios .post(`/classification/${modelName}/dataset/categorize`, { category, @@ -72,7 +81,7 @@ export default function ClassificationSelectionDialog({ }); }); }, - [modelName, image, onRefresh, t], + [modelName, image, onRefresh, onCategorize, t], ); const isChildButton = useMemo( diff --git a/web/src/components/overlay/detail/SearchDetailDialog.tsx b/web/src/components/overlay/detail/SearchDetailDialog.tsx index 01e211eec..9f430fef5 100644 --- a/web/src/components/overlay/detail/SearchDetailDialog.tsx +++ b/web/src/components/overlay/detail/SearchDetailDialog.tsx @@ -94,6 +94,11 @@ import { useDetailStream } from "@/context/detail-stream-context"; import { PiSlidersHorizontalBold } from "react-icons/pi"; import { HiSparkles } from "react-icons/hi"; import { useAudioTranscriptionProcessState } from "@/api/ws"; +import FaceSelectionDialog from "@/components/overlay/FaceSelectionDialog"; +import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog"; +import { FaceLibraryData } from "@/types/face"; +import AddFaceIcon from "@/components/icons/AddFaceIcon"; +import { TbCategoryPlus } from "react-icons/tb"; const SEARCH_TABS = ["snapshot", "tracking_details"] as const; export type SearchTab = (typeof SEARCH_TABS)[number]; @@ -702,6 +707,21 @@ function ObjectDetailsTab({ : null, ); + // Fetch available faces for assignment + const { data: faceData } = useSWR( + config?.face_recognition?.enabled ? "faces" : null, + ); + + const availableFaceNames = useMemo(() => { + if (!faceData) return []; + return Object.keys(faceData).filter((name) => name !== "train").sort(); + }, [faceData]); + + const availableClassificationModels = useMemo(() => { + if (!config?.classification?.custom) return []; + return Object.keys(config.classification.custom).sort(); + }, [config]); + // mutation / revalidation const mutate = useGlobalMutation(); @@ -1216,6 +1236,85 @@ function ObjectDetailsTab({ }); }, [search, t]); + // face and classification assignment + + const onAssignToFace = useCallback( + (faceName: string) => { + if (!search) { + return; + } + + axios + .post(`/faces/train/${faceName}/classify`, { + event_id: search.id, + }) + .then((resp) => { + if (resp.status == 200) { + toast.success(t("details.assignment.faceSuccess", { name: faceName }), { + position: "top-center", + }); + // Refresh the event data + mutate((key) => isEventsKey(key)); + } + }) + .catch((error) => { + const errorMessage = + error.response?.data?.message || + error.response?.data?.detail || + "Unknown error"; + toast.error( + t("details.assignment.faceFailed", { errorMessage }), + { + position: "top-center", + }, + ); + }); + }, + [search, t, mutate, isEventsKey], + ); + + const onAssignToClassification = useCallback( + (modelName: string, category: string) => { + if (!search) { + return; + } + + axios + .post(`/classification/${modelName}/dataset/categorize`, { + event_id: search.id, + category, + }) + .then((resp) => { + if (resp.status == 200) { + toast.success( + t("details.assignment.classificationSuccess", { + model: modelName, + category, + }), + { + position: "top-center", + }, + ); + // Refresh the event data + mutate((key) => isEventsKey(key)); + } + }) + .catch((error) => { + const errorMessage = + error.response?.data?.message || + error.response?.data?.detail || + "Unknown error"; + toast.error( + t("details.assignment.classificationFailed", { errorMessage }), + { + position: "top-center", + }, + ); + }); + }, + [search, t, mutate, isEventsKey], + ); + // audio transcription processing state const { payload: audioTranscriptionProcessState } = @@ -1474,6 +1573,56 @@ function ObjectDetailsTab({ + {isAdmin && (availableFaceNames.length > 0 || availableClassificationModels.length > 0) && ( +
+
+ {t("details.assignment.title")} +
+
+ {config?.face_recognition?.enabled && availableFaceNames.length > 0 && ( + + + + )} + {availableClassificationModels.length > 0 && + availableClassificationModels.map((modelName) => { + const model = config?.classification?.custom?.[modelName]; + if (!model) return null; + + const displayName = model.name || modelName; + const classes = modelAttributes?.[displayName] ?? []; + if (classes.length === 0) return null; + + return ( + {}} + onCategorize={(category) => + onAssignToClassification(modelName, category) + } + > + + + ); + })} +
+
+ )} + {isAdmin && search.data.type === "object" && config?.plus?.enabled && diff --git a/web/src/pages/FaceLibrary.tsx b/web/src/pages/FaceLibrary.tsx index 7595b3cd9..039eb5ddf 100644 --- a/web/src/pages/FaceLibrary.tsx +++ b/web/src/pages/FaceLibrary.tsx @@ -406,6 +406,66 @@ export default function FaceLibrary() { )} + {pageToggle === "train" && ( + { + const requests = selectedFaces.map((filename) => + axios + .post(`/faces/train/${name}/classify`, { + training_file: filename, + }) + .then(() => true) + .catch(() => false), + ); + + Promise.allSettled(requests).then((results) => { + const successCount = results.filter( + (result) => result.status === "fulfilled" && result.value, + ).length; + const totalCount = results.length; + + if (successCount === totalCount) { + toast.success( + t("toast.success.batchTrainedFaces", { + count: successCount, + }), + { + position: "top-center", + }, + ); + } else if (successCount > 0) { + toast.warning( + t("toast.warning.partialBatchTrained", { + success: successCount, + total: totalCount, + }), + { + position: "top-center", + }, + ); + } else { + toast.error( + t("toast.error.batchTrainFailed", { + count: totalCount, + }), + { + position: "top-center", + }, + ); + } + + setSelectedFaces([]); + refreshFaces(); + }); + }} + > + + + )} + + )}