From 9b93667b877cb04a0a8ec758d0167c252a6e5a62 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Wed, 22 Oct 2025 13:38:07 -0600 Subject: [PATCH] Improve object cropping implementation --- frigate/api/classification.py | 76 ++++++ .../api/defs/request/classification_body.py | 14 +- frigate/util/classification.py | 96 ++++--- .../locales/en/views/classificationModel.json | 17 ++ .../ClassificationModelWizardDialog.tsx | 30 ++- .../wizard/Step3ChooseExamples.tsx | 240 ++++++++++++++++++ 6 files changed, 432 insertions(+), 41 deletions(-) create mode 100644 web/src/components/classification/wizard/Step3ChooseExamples.tsx diff --git a/frigate/api/classification.py b/frigate/api/classification.py index 623ceba32..b15f7bf06 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -17,6 +17,8 @@ from frigate.api.auth import require_role from frigate.api.defs.request.classification_body import ( AudioTranscriptionBody, DeleteFaceImagesBody, + GenerateObjectExamplesBody, + GenerateStateExamplesBody, RenameFaceBody, ) from frigate.api.defs.response.classification_response import ( @@ -30,6 +32,10 @@ from frigate.config.camera import DetectConfig from frigate.const import CLIPS_DIR, FACE_DIR from frigate.embeddings import EmbeddingsContext from frigate.models import Event +from frigate.util.classification import ( + collect_object_classification_examples, + collect_state_classification_examples, +) from frigate.util.path import get_event_snapshot logger = logging.getLogger(__name__) @@ -756,3 +762,73 @@ def delete_classification_train_images(request: Request, name: str, body: dict = content=({"success": True, "message": "Successfully deleted faces."}), status_code=200, ) + + +@router.post( + "/classification/generate_examples/state", + response_model=GenericResponse, + dependencies=[Depends(require_role(["admin"]))], + summary="Generate state classification examples", +) +async def generate_state_examples(request: Request, body: GenerateStateExamplesBody): + """Generate examples for state classification.""" + try: + cameras_with_pixels = {} + config: FrigateConfig = request.app.frigate_config + + for camera_name, crop in body.cameras.items(): + if camera_name not in config.cameras: + continue + + camera_config = config.cameras[camera_name] + width = camera_config.detect.width + height = camera_config.detect.height + + x1 = int(crop[0] * width) + y1 = int(crop[1] * height) + x2 = int((crop[0] + crop[2]) * width) + y2 = int((crop[1] + crop[3]) * height) + + cameras_with_pixels[camera_name] = (x1, y1, x2, y2) + + collect_state_classification_examples(body.model_name, cameras_with_pixels) + + return JSONResponse( + content={"success": True, "message": "Example generation completed"}, + status_code=200, + ) + except Exception as e: + logger.error(f"Failed to generate state examples: {e}") + return JSONResponse( + content={ + "success": False, + "message": f"Failed to generate examples: {str(e)}", + }, + status_code=500, + ) + + +@router.post( + "/classification/generate_examples/object", + response_model=GenericResponse, + dependencies=[Depends(require_role(["admin"]))], + summary="Generate object classification examples", +) +async def generate_object_examples(request: Request, body: GenerateObjectExamplesBody): + """Generate examples for object classification.""" + try: + collect_object_classification_examples(body.model_name, body.label) + + return JSONResponse( + content={"success": True, "message": "Example generation completed"}, + status_code=200, + ) + except Exception as e: + logger.error(f"Failed to generate object examples: {e}") + return JSONResponse( + content={ + "success": False, + "message": f"Failed to generate examples: {str(e)}", + }, + status_code=500, + ) diff --git a/frigate/api/defs/request/classification_body.py b/frigate/api/defs/request/classification_body.py index dabff0912..d38ba4b0f 100644 --- a/frigate/api/defs/request/classification_body.py +++ b/frigate/api/defs/request/classification_body.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Dict, List, Tuple from pydantic import BaseModel, Field @@ -15,3 +15,15 @@ class DeleteFaceImagesBody(BaseModel): ids: List[str] = Field( description="List of image filenames to delete from the face folder" ) + + +class GenerateStateExamplesBody(BaseModel): + model_name: str + cameras: Dict[str, Tuple[float, float, float, float]] = Field( + description="Dictionary mapping camera names to crop coordinates (x, y, width, height) normalized 0-1" + ) + + +class GenerateObjectExamplesBody(BaseModel): + model_name: str + label: str diff --git a/frigate/util/classification.py b/frigate/util/classification.py index c03244fdb..aac4d6eb6 100644 --- a/frigate/util/classification.py +++ b/frigate/util/classification.py @@ -145,7 +145,6 @@ class ClassificationTrainingProcess(FrigateProcess): f.write(tflite_model) -@staticmethod def kickoff_model_training( embeddingRequestor: EmbeddingsRequestor, model_name: str ) -> None: @@ -222,12 +221,8 @@ def collect_state_classification_examples( "/usr/lib/ffmpeg/7.0/bin/ffmpeg", timestamps, temp_dir, cameras ) - if len(keyframes) < 20: - logger.warning(f"Only extracted {len(keyframes)} keyframes, need at least 20") - return - - # Step 4: Select 20 most visually distinct images (they're already cropped) - distinct_images = _select_distinct_images(keyframes, target_count=20) + # Step 4: Select 24 most visually distinct images (they're already cropped) + distinct_images = _select_distinct_images(keyframes, target_count=24) # Step 5: Save to dataset directory (in "unknown" subfolder for unlabeled data) unknown_dir = os.path.join(dataset_dir, "unknown") @@ -502,66 +497,52 @@ def _select_distinct_images( @staticmethod def collect_object_classification_examples( - model_name: str, label: str, cameras: list[str] + model_name: str, + label: str, ) -> None: """ Collect representative object classification examples from event thumbnails. This function: - 1. Queries events for the specified label and cameras + 1. Queries events for the specified label 2. Selects 100 balanced events across different cameras and times - 3. Retrieves thumbnails for selected events - 4. Selects 20 most visually distinct thumbnails - 5. Saves them to the dataset directory + 3. Retrieves thumbnails for selected events (with 33% center crop applied) + 4. Selects 24 most visually distinct thumbnails + 5. Saves to dataset directory Args: model_name: Name of the classification model label: Object label to collect (e.g., "person", "car") cameras: List of camera names to collect examples from """ - logger.info( - f"Collecting examples for {model_name} with label '{label}' from {len(cameras)} cameras" - ) - dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset") temp_dir = os.path.join(dataset_dir, "temp") os.makedirs(temp_dir, exist_ok=True) # Step 1: Query events for the specified label and cameras events = list( - Event.select() - .where( - (Event.label == label) - & (Event.camera.in_(cameras)) - & (Event.false_positive == False) - ) - .order_by(Event.start_time.asc()) + Event.select().where((Event.label == label)).order_by(Event.start_time.asc()) ) if not events: - logger.warning(f"No events found for label '{label}' on cameras: {cameras}") + logger.warning(f"No events found for label '{label}'") return - logger.info(f"Found {len(events)} events") + logger.debug(f"Found {len(events)} events") # Step 2: Select balanced events (100 samples) selected_events = _select_balanced_events(events, target_count=100) - logger.info(f"Selected {len(selected_events)} events") + logger.debug(f"Selected {len(selected_events)} events") # Step 3: Extract thumbnails from events thumbnails = _extract_event_thumbnails(selected_events, temp_dir) + logger.debug(f"Successfully extracted {len(thumbnails)} thumbnails") - if len(thumbnails) < 20: - logger.warning(f"Only extracted {len(thumbnails)} thumbnails, need at least 20") - return + # Step 4: Select 24 most visually distinct thumbnails + distinct_images = _select_distinct_images(thumbnails, target_count=24) + logger.debug(f"Selected {len(distinct_images)} distinct images") - logger.info(f"Successfully extracted {len(thumbnails)} thumbnails") - - # Step 4: Select 20 most visually distinct thumbnails - distinct_images = _select_distinct_images(thumbnails, target_count=20) - logger.info(f"Selected {len(distinct_images)} distinct images") - - # Step 5: Save to dataset directory (in "unknown" subfolder for unlabeled data) + # Step 5: Save to dataset directory unknown_dir = os.path.join(dataset_dir, "unknown") os.makedirs(unknown_dir, exist_ok=True) @@ -584,7 +565,7 @@ def collect_object_classification_examples( except Exception as e: logger.warning(f"Failed to clean up temp directory: {e}") - logger.info( + logger.debug( f"Successfully collected {saved_count} classification examples in {unknown_dir}" ) @@ -663,7 +644,46 @@ def _extract_event_thumbnails(events: list[Event], output_dir: str) -> list[str] img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if img is not None: - resized = cv2.resize(img, (224, 224)) + height, width = img.shape[:2] + + # Calculate crop based on object size relative to the thumbnail region + crop_size = 1.0 # Default to no crop + if event.data and "box" in event.data and "region" in event.data: + # Box is [x, y, w, h] format + box = event.data["box"] + region = event.data["region"] + + if len(box) == 4 and len(region) == 4: + box_w, box_h = box[2], box[3] + region_w, region_h = region[2], region[3] + + # Calculate what percentage of the region the box occupies + box_area = (box_w * box_h) / (region_w * region_h) + + # Crop inversely proportional to object size in thumbnail + # Small objects need more crop (zoom in), large objects need less + if box_area < 0.05: # Very small (< 5%) + crop_size = 0.4 + elif box_area < 0.10: # Small (5-10%) + crop_size = 0.5 + elif box_area < 0.20: # Medium-small (10-20%) + crop_size = 0.65 + elif box_area < 0.35: # Medium (20-35%) + crop_size = 0.80 + else: # Large (>35%) + crop_size = 0.95 + + crop_width = int(width * crop_size) + crop_height = int(height * crop_size) + + # Calculate center crop coordinates + x1 = (width - crop_width) // 2 + y1 = (height - crop_height) // 2 + x2 = x1 + crop_width + y2 = y1 + crop_height + + cropped = img[y1:y2, x1:x2] + resized = cv2.resize(cropped, (224, 224)) output_path = os.path.join(output_dir, f"thumbnail_{idx:04d}.jpg") cv2.imwrite(output_path, resized) thumbnail_paths.append(output_path) diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json index 833240d54..8f6614f37 100644 --- a/web/public/locales/en/views/classificationModel.json +++ b/web/public/locales/en/views/classificationModel.json @@ -89,6 +89,23 @@ "selectCamera": "Select Camera", "noCameras": "Click + to add cameras", "selectCameraPrompt": "Select a camera from the list to define its monitoring area" + }, + "step3": { + "description": "Classify the example images below. These samples will be used to train your model.", + "generating": { + "title": "Generating Sample Images", + "description": "We're pulling representative images from your recordings. This may take a moment..." + }, + "retryGenerate": "Retry Generation", + "selectClass": "Select class...", + "noImages": "No sample images generated", + "errors": { + "noCameras": "No cameras configured", + "noObjectLabel": "No object label selected", + "generateFailed": "Failed to generate examples: {{error}}", + "generationFailed": "Generation failed. Please try again." + }, + "generateSuccess": "Successfully generated sample images" } } } diff --git a/web/src/components/classification/ClassificationModelWizardDialog.tsx b/web/src/components/classification/ClassificationModelWizardDialog.tsx index 5da96e9a4..7efa57b66 100644 --- a/web/src/components/classification/ClassificationModelWizardDialog.tsx +++ b/web/src/components/classification/ClassificationModelWizardDialog.tsx @@ -10,6 +10,9 @@ import { import { useReducer, useMemo } from "react"; import Step1NameAndDefine, { Step1FormData } from "./wizard/Step1NameAndDefine"; import Step2StateArea, { Step2FormData } from "./wizard/Step2StateArea"; +import Step3ChooseExamples, { + Step3FormData, +} from "./wizard/Step3ChooseExamples"; import { cn } from "@/lib/utils"; import { isDesktop } from "react-device-detect"; @@ -35,8 +38,7 @@ type WizardState = { currentStep: number; step1Data?: Step1FormData; step2Data?: Step2FormData; - // Future steps can be added here - // step3Data?: Step3FormData; + step3Data?: Step3FormData; }; type WizardAction = @@ -44,6 +46,7 @@ type WizardAction = | { type: "PREVIOUS_STEP" } | { type: "SET_STEP_1"; payload: Step1FormData } | { type: "SET_STEP_2"; payload: Step2FormData } + | { type: "SET_STEP_3"; payload: Step3FormData } | { type: "RESET" }; const initialState: WizardState = { @@ -64,6 +67,12 @@ function wizardReducer(state: WizardState, action: WizardAction): WizardState { step2Data: action.payload, currentStep: 2, }; + case "SET_STEP_3": + return { + ...state, + step3Data: action.payload, + currentStep: 3, + }; case "NEXT_STEP": return { ...state, @@ -107,6 +116,10 @@ export default function ClassificationModelWizardDialog({ dispatch({ type: "SET_STEP_2", payload: data }); }; + const handleStep3Next = (data: Step3FormData) => { + dispatch({ type: "SET_STEP_3", payload: data }); + }; + const handleBack = () => { dispatch({ type: "PREVIOUS_STEP" }); }; @@ -163,6 +176,19 @@ export default function ClassificationModelWizardDialog({ onBack={handleBack} /> )} + {((wizardState.currentStep === 2 && + wizardState.step1Data?.modelType === "state") || + (wizardState.currentStep === 1 && + wizardState.step1Data?.modelType === "object")) && + wizardState.step1Data && ( + + )} diff --git a/web/src/components/classification/wizard/Step3ChooseExamples.tsx b/web/src/components/classification/wizard/Step3ChooseExamples.tsx new file mode 100644 index 000000000..5172cae8a --- /dev/null +++ b/web/src/components/classification/wizard/Step3ChooseExamples.tsx @@ -0,0 +1,240 @@ +import { Button } from "@/components/ui/button"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { useTranslation } from "react-i18next"; +import { useState, useEffect, useCallback, useMemo } from "react"; +import ActivityIndicator from "@/components/indicators/activity-indicator"; +import axios from "axios"; +import { toast } from "sonner"; +import { Step1FormData } from "./Step1NameAndDefine"; +import { Step2FormData } from "./Step2StateArea"; +import useSWR from "swr"; +import { baseUrl } from "@/api/baseUrl"; + +export type Step3FormData = { + examplesGenerated: boolean; + imageClassifications?: { [imageName: string]: string }; +}; + +type Step3ChooseExamplesProps = { + step1Data: Step1FormData; + step2Data?: Step2FormData; + initialData?: Partial; + onNext: (data: Step3FormData) => void; + onBack: () => void; +}; + +export default function Step3ChooseExamples({ + step1Data, + step2Data, + initialData, + onNext, + onBack, +}: Step3ChooseExamplesProps) { + const { t } = useTranslation(["views/classificationModel"]); + const [isGenerating, setIsGenerating] = useState(false); + const [hasGenerated, setHasGenerated] = useState( + initialData?.examplesGenerated || false, + ); + const [imageClassifications, setImageClassifications] = useState<{ + [imageName: string]: string; + }>(initialData?.imageClassifications || {}); + + const { data: dataset, mutate: refreshDataset } = useSWR<{ + [id: string]: string[]; + }>(hasGenerated ? `classification/${step1Data.modelName}/dataset` : null); + + const unknownImages = useMemo(() => { + if (!dataset || !dataset.unknown) return []; + return dataset.unknown; + }, [dataset]); + + const handleClassificationChange = useCallback( + (imageName: string, className: string) => { + setImageClassifications((prev) => ({ + ...prev, + [imageName]: className, + })); + }, + [], + ); + + const generateExamples = useCallback(async () => { + setIsGenerating(true); + + try { + if (step1Data.modelType === "state") { + // For state models, use cameras and crop areas + if (!step2Data?.cameraAreas || step2Data.cameraAreas.length === 0) { + toast.error(t("wizard.step3.errors.noCameras")); + setIsGenerating(false); + return; + } + + const cameras: { [key: string]: [number, number, number, number] } = {}; + step2Data.cameraAreas.forEach((area) => { + cameras[area.camera] = area.crop; + }); + + await axios.post("/classification/generate_examples/state", { + model_name: step1Data.modelName, + cameras, + }); + } else { + // For object models, use label + if (!step1Data.objectLabel) { + toast.error(t("wizard.step3.errors.noObjectLabel")); + setIsGenerating(false); + return; + } + + // For now, use all enabled cameras + // TODO: In the future, we might want to let users select specific cameras + await axios.post("/classification/generate_examples/object", { + model_name: step1Data.modelName, + label: step1Data.objectLabel, + }); + } + + setHasGenerated(true); + toast.success(t("wizard.step3.generateSuccess")); + + await refreshDataset(); + } catch (error) { + const axiosError = error as { + response?: { data?: { message?: string; detail?: string } }; + message?: string; + }; + const errorMessage = + axiosError.response?.data?.message || + axiosError.response?.data?.detail || + axiosError.message || + "Failed to generate examples"; + + toast.error( + t("wizard.step3.errors.generateFailed", { error: errorMessage }), + ); + } finally { + setIsGenerating(false); + } + }, [step1Data, step2Data, t, refreshDataset]); + + useEffect(() => { + if (!hasGenerated && !isGenerating) { + generateExamples(); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + const handleContinue = useCallback(() => { + onNext({ examplesGenerated: true, imageClassifications }); + }, [onNext, imageClassifications]); + + const allImagesClassified = useMemo(() => { + if (!unknownImages || unknownImages.length === 0) return false; + const imagesToClassify = unknownImages.slice(0, 24); + return imagesToClassify.every((img) => imageClassifications[img]); + }, [unknownImages, imageClassifications]); + + return ( +
+ {isGenerating ? ( +
+ +
+

+ {t("wizard.step3.generating.title")} +

+

+ {t("wizard.step3.generating.description")} +

+
+
+ ) : hasGenerated ? ( +
+
+ {t("wizard.step3.description")} +
+
+ {!unknownImages || unknownImages.length === 0 ? ( +
+

+ {t("wizard.step3.noImages")} +

+
+ ) : ( +
+ {unknownImages.slice(0, 24).map((imageName, index) => ( +
+ {`Example +
+ +
+
+ ))} +
+ )} +
+
+ ) : ( +
+

+ {t("wizard.step3.errors.generationFailed")} +

+ +
+ )} + +
+ + +
+
+ ); +}