Improve object cropping implementation

This commit is contained in:
Nicolas Mowen 2025-10-22 13:38:07 -06:00
parent ced177d62c
commit 9b93667b87
6 changed files with 432 additions and 41 deletions

View File

@ -17,6 +17,8 @@ from frigate.api.auth import require_role
from frigate.api.defs.request.classification_body import (
AudioTranscriptionBody,
DeleteFaceImagesBody,
GenerateObjectExamplesBody,
GenerateStateExamplesBody,
RenameFaceBody,
)
from frigate.api.defs.response.classification_response import (
@ -30,6 +32,10 @@ from frigate.config.camera import DetectConfig
from frigate.const import CLIPS_DIR, FACE_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.models import Event
from frigate.util.classification import (
collect_object_classification_examples,
collect_state_classification_examples,
)
from frigate.util.path import get_event_snapshot
logger = logging.getLogger(__name__)
@ -756,3 +762,73 @@ def delete_classification_train_images(request: Request, name: str, body: dict =
content=({"success": True, "message": "Successfully deleted faces."}),
status_code=200,
)
@router.post(
    "/classification/generate_examples/state",
    response_model=GenericResponse,
    dependencies=[Depends(require_role(["admin"]))],
    summary="Generate state classification examples",
)
async def generate_state_examples(request: Request, body: GenerateStateExamplesBody):
    """Generate training example images for a state classification model.

    Converts each camera's normalized (x, y, w, h) crop from the request body
    into pixel coordinates in that camera's detect resolution, then delegates
    to collect_state_classification_examples to extract and save the images.

    Returns a 200 JSON response on success, or a 500 JSON response with the
    error message if example generation fails.
    """
    try:
        cameras_with_pixels = {}
        config: FrigateConfig = request.app.frigate_config

        for camera_name, crop in body.cameras.items():
            # Silently skip cameras that are not in the current config.
            if camera_name not in config.cameras:
                continue

            camera_config = config.cameras[camera_name]
            width = camera_config.detect.width
            height = camera_config.detect.height

            # Crop is normalized (x, y, w, h). Convert to pixel box corners,
            # clamping to the frame so malformed client input (values outside
            # 0-1, or x+w > 1) cannot produce out-of-bounds coordinates
            # downstream.
            x1 = max(0, min(width, int(crop[0] * width)))
            y1 = max(0, min(height, int(crop[1] * height)))
            x2 = max(x1, min(width, int((crop[0] + crop[2]) * width)))
            y2 = max(y1, min(height, int((crop[1] + crop[3]) * height)))

            cameras_with_pixels[camera_name] = (x1, y1, x2, y2)

        # NOTE(review): this helper appears to do blocking ffmpeg/file work;
        # confirm running it inside this async handler is acceptable.
        collect_state_classification_examples(body.model_name, cameras_with_pixels)

        return JSONResponse(
            content={"success": True, "message": "Example generation completed"},
            status_code=200,
        )
    except Exception as e:
        logger.error(f"Failed to generate state examples: {e}")
        return JSONResponse(
            content={
                "success": False,
                "message": f"Failed to generate examples: {str(e)}",
            },
            status_code=500,
        )
@router.post(
    "/classification/generate_examples/object",
    response_model=GenericResponse,
    dependencies=[Depends(require_role(["admin"]))],
    summary="Generate object classification examples",
)
async def generate_object_examples(request: Request, body: GenerateObjectExamplesBody):
    """Generate examples for object classification.

    Delegates to collect_object_classification_examples for the requested
    model and label, returning a 200 JSON response on success or a 500 JSON
    response containing the error message on failure.
    """
    try:
        collect_object_classification_examples(body.model_name, body.label)
    except Exception as e:
        # Any failure in collection is reported back to the caller verbatim.
        logger.error(f"Failed to generate object examples: {e}")
        return JSONResponse(
            content={
                "success": False,
                "message": f"Failed to generate examples: {str(e)}",
            },
            status_code=500,
        )

    return JSONResponse(
        content={"success": True, "message": "Example generation completed"},
        status_code=200,
    )

View File

@ -1,4 +1,4 @@
from typing import List
from typing import Dict, List, Tuple
from pydantic import BaseModel, Field
@ -15,3 +15,15 @@ class DeleteFaceImagesBody(BaseModel):
ids: List[str] = Field(
description="List of image filenames to delete from the face folder"
)
class GenerateStateExamplesBody(BaseModel):
    """Request body for generating state classification training examples."""

    # Name of the classification model to generate examples for.
    model_name: str
    cameras: Dict[str, Tuple[float, float, float, float]] = Field(
        description="Dictionary mapping camera names to crop coordinates (x, y, width, height) normalized 0-1"
    )
class GenerateObjectExamplesBody(BaseModel):
    """Request body for generating object classification training examples."""

    # Name of the classification model to generate examples for.
    model_name: str
    # Object label to collect examples of (e.g. "person", "car").
    label: str

View File

@ -145,7 +145,6 @@ class ClassificationTrainingProcess(FrigateProcess):
f.write(tflite_model)
@staticmethod
def kickoff_model_training(
embeddingRequestor: EmbeddingsRequestor, model_name: str
) -> None:
@ -222,12 +221,8 @@ def collect_state_classification_examples(
"/usr/lib/ffmpeg/7.0/bin/ffmpeg", timestamps, temp_dir, cameras
)
if len(keyframes) < 20:
logger.warning(f"Only extracted {len(keyframes)} keyframes, need at least 20")
return
# Step 4: Select 20 most visually distinct images (they're already cropped)
distinct_images = _select_distinct_images(keyframes, target_count=20)
# Step 4: Select 24 most visually distinct images (they're already cropped)
distinct_images = _select_distinct_images(keyframes, target_count=24)
# Step 5: Save to dataset directory (in "unknown" subfolder for unlabeled data)
unknown_dir = os.path.join(dataset_dir, "unknown")
@ -502,66 +497,52 @@ def _select_distinct_images(
@staticmethod
def collect_object_classification_examples(
model_name: str, label: str, cameras: list[str]
model_name: str,
label: str,
) -> None:
"""
Collect representative object classification examples from event thumbnails.
This function:
1. Queries events for the specified label and cameras
1. Queries events for the specified label
2. Selects 100 balanced events across different cameras and times
3. Retrieves thumbnails for selected events
4. Selects 20 most visually distinct thumbnails
5. Saves them to the dataset directory
3. Retrieves thumbnails for selected events (with 33% center crop applied)
4. Selects 24 most visually distinct thumbnails
5. Saves to dataset directory
Args:
model_name: Name of the classification model
label: Object label to collect (e.g., "person", "car")
cameras: List of camera names to collect examples from
"""
logger.info(
f"Collecting examples for {model_name} with label '{label}' from {len(cameras)} cameras"
)
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
temp_dir = os.path.join(dataset_dir, "temp")
os.makedirs(temp_dir, exist_ok=True)
# Step 1: Query events for the specified label and cameras
events = list(
Event.select()
.where(
(Event.label == label)
& (Event.camera.in_(cameras))
& (Event.false_positive == False)
)
.order_by(Event.start_time.asc())
Event.select().where((Event.label == label)).order_by(Event.start_time.asc())
)
if not events:
logger.warning(f"No events found for label '{label}' on cameras: {cameras}")
logger.warning(f"No events found for label '{label}'")
return
logger.info(f"Found {len(events)} events")
logger.debug(f"Found {len(events)} events")
# Step 2: Select balanced events (100 samples)
selected_events = _select_balanced_events(events, target_count=100)
logger.info(f"Selected {len(selected_events)} events")
logger.debug(f"Selected {len(selected_events)} events")
# Step 3: Extract thumbnails from events
thumbnails = _extract_event_thumbnails(selected_events, temp_dir)
logger.debug(f"Successfully extracted {len(thumbnails)} thumbnails")
if len(thumbnails) < 20:
logger.warning(f"Only extracted {len(thumbnails)} thumbnails, need at least 20")
return
# Step 4: Select 24 most visually distinct thumbnails
distinct_images = _select_distinct_images(thumbnails, target_count=24)
logger.debug(f"Selected {len(distinct_images)} distinct images")
logger.info(f"Successfully extracted {len(thumbnails)} thumbnails")
# Step 4: Select 20 most visually distinct thumbnails
distinct_images = _select_distinct_images(thumbnails, target_count=20)
logger.info(f"Selected {len(distinct_images)} distinct images")
# Step 5: Save to dataset directory (in "unknown" subfolder for unlabeled data)
# Step 5: Save to dataset directory
unknown_dir = os.path.join(dataset_dir, "unknown")
os.makedirs(unknown_dir, exist_ok=True)
@ -584,7 +565,7 @@ def collect_object_classification_examples(
except Exception as e:
logger.warning(f"Failed to clean up temp directory: {e}")
logger.info(
logger.debug(
f"Successfully collected {saved_count} classification examples in {unknown_dir}"
)
@ -663,7 +644,46 @@ def _extract_event_thumbnails(events: list[Event], output_dir: str) -> list[str]
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is not None:
resized = cv2.resize(img, (224, 224))
height, width = img.shape[:2]
# Calculate crop based on object size relative to the thumbnail region
crop_size = 1.0 # Default to no crop
if event.data and "box" in event.data and "region" in event.data:
# Box is [x, y, w, h] format
box = event.data["box"]
region = event.data["region"]
if len(box) == 4 and len(region) == 4:
box_w, box_h = box[2], box[3]
region_w, region_h = region[2], region[3]
# Calculate what percentage of the region the box occupies
box_area = (box_w * box_h) / (region_w * region_h)
# Crop inversely proportional to object size in thumbnail
# Small objects need more crop (zoom in), large objects need less
if box_area < 0.05: # Very small (< 5%)
crop_size = 0.4
elif box_area < 0.10: # Small (5-10%)
crop_size = 0.5
elif box_area < 0.20: # Medium-small (10-20%)
crop_size = 0.65
elif box_area < 0.35: # Medium (20-35%)
crop_size = 0.80
else: # Large (>35%)
crop_size = 0.95
crop_width = int(width * crop_size)
crop_height = int(height * crop_size)
# Calculate center crop coordinates
x1 = (width - crop_width) // 2
y1 = (height - crop_height) // 2
x2 = x1 + crop_width
y2 = y1 + crop_height
cropped = img[y1:y2, x1:x2]
resized = cv2.resize(cropped, (224, 224))
output_path = os.path.join(output_dir, f"thumbnail_{idx:04d}.jpg")
cv2.imwrite(output_path, resized)
thumbnail_paths.append(output_path)

View File

@ -89,6 +89,23 @@
"selectCamera": "Select Camera",
"noCameras": "Click + to add cameras",
"selectCameraPrompt": "Select a camera from the list to define its monitoring area"
},
"step3": {
"description": "Classify the example images below. These samples will be used to train your model.",
"generating": {
"title": "Generating Sample Images",
"description": "We're pulling representative images from your recordings. This may take a moment..."
},
"retryGenerate": "Retry Generation",
"selectClass": "Select class...",
"noImages": "No sample images generated",
"errors": {
"noCameras": "No cameras configured",
"noObjectLabel": "No object label selected",
"generateFailed": "Failed to generate examples: {{error}}",
"generationFailed": "Generation failed. Please try again."
},
"generateSuccess": "Successfully generated sample images"
}
}
}

View File

@ -10,6 +10,9 @@ import {
import { useReducer, useMemo } from "react";
import Step1NameAndDefine, { Step1FormData } from "./wizard/Step1NameAndDefine";
import Step2StateArea, { Step2FormData } from "./wizard/Step2StateArea";
import Step3ChooseExamples, {
Step3FormData,
} from "./wizard/Step3ChooseExamples";
import { cn } from "@/lib/utils";
import { isDesktop } from "react-device-detect";
@ -35,8 +38,7 @@ type WizardState = {
currentStep: number;
step1Data?: Step1FormData;
step2Data?: Step2FormData;
// Future steps can be added here
// step3Data?: Step3FormData;
step3Data?: Step3FormData;
};
type WizardAction =
@ -44,6 +46,7 @@ type WizardAction =
| { type: "PREVIOUS_STEP" }
| { type: "SET_STEP_1"; payload: Step1FormData }
| { type: "SET_STEP_2"; payload: Step2FormData }
| { type: "SET_STEP_3"; payload: Step3FormData }
| { type: "RESET" };
const initialState: WizardState = {
@ -64,6 +67,12 @@ function wizardReducer(state: WizardState, action: WizardAction): WizardState {
step2Data: action.payload,
currentStep: 2,
};
case "SET_STEP_3":
return {
...state,
step3Data: action.payload,
currentStep: 3,
};
case "NEXT_STEP":
return {
...state,
@ -107,6 +116,10 @@ export default function ClassificationModelWizardDialog({
dispatch({ type: "SET_STEP_2", payload: data });
};
const handleStep3Next = (data: Step3FormData) => {
dispatch({ type: "SET_STEP_3", payload: data });
};
const handleBack = () => {
dispatch({ type: "PREVIOUS_STEP" });
};
@ -163,6 +176,19 @@ export default function ClassificationModelWizardDialog({
onBack={handleBack}
/>
)}
{((wizardState.currentStep === 2 &&
wizardState.step1Data?.modelType === "state") ||
(wizardState.currentStep === 1 &&
wizardState.step1Data?.modelType === "object")) &&
wizardState.step1Data && (
<Step3ChooseExamples
step1Data={wizardState.step1Data}
step2Data={wizardState.step2Data}
initialData={wizardState.step3Data}
onNext={handleStep3Next}
onBack={handleBack}
/>
)}
</div>
</DialogContent>
</Dialog>

View File

@ -0,0 +1,240 @@
import { Button } from "@/components/ui/button";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { useTranslation } from "react-i18next";
import { useState, useEffect, useCallback, useMemo } from "react";
import ActivityIndicator from "@/components/indicators/activity-indicator";
import axios from "axios";
import { toast } from "sonner";
import { Step1FormData } from "./Step1NameAndDefine";
import { Step2FormData } from "./Step2StateArea";
import useSWR from "swr";
import { baseUrl } from "@/api/baseUrl";
// Data produced by this wizard step: whether sample generation completed,
// and the class the user assigned to each image (keyed by image filename).
export type Step3FormData = {
  examplesGenerated: boolean;
  imageClassifications?: { [imageName: string]: string };
};

type Step3ChooseExamplesProps = {
  step1Data: Step1FormData;
  // Camera crop areas from step 2; only present for "state" models.
  step2Data?: Step2FormData;
  // Previously-entered values when the user navigates back to this step.
  initialData?: Partial<Step3FormData>;
  onNext: (data: Step3FormData) => void;
  onBack: () => void;
};
/**
 * Wizard step that generates sample images for a new classification model
 * and lets the user assign a class to each one.
 *
 * On mount (unless examples were already generated in a previous visit),
 * it POSTs to the backend to generate examples — sending step-2 camera crop
 * areas for "state" models, or the object label for "object" models — then
 * lists the generated images from the model's dataset "unknown" folder and
 * renders a class selector over each. Continue is enabled only once every
 * displayed image has been classified.
 */
export default function Step3ChooseExamples({
  step1Data,
  step2Data,
  initialData,
  onNext,
  onBack,
}: Step3ChooseExamplesProps) {
  const { t } = useTranslation(["views/classificationModel"]);

  const [isGenerating, setIsGenerating] = useState(false);
  // Seeded from initialData so returning to this step skips regeneration.
  const [hasGenerated, setHasGenerated] = useState(
    initialData?.examplesGenerated || false,
  );
  const [imageClassifications, setImageClassifications] = useState<{
    [imageName: string]: string;
  }>(initialData?.imageClassifications || {});

  // Dataset listing is only fetched after generation completes (null key
  // disables the SWR fetch until then).
  const { data: dataset, mutate: refreshDataset } = useSWR<{
    [id: string]: string[];
  }>(hasGenerated ? `classification/${step1Data.modelName}/dataset` : null);

  // Generated-but-unlabeled images live in the "unknown" dataset folder.
  const unknownImages = useMemo(() => {
    if (!dataset || !dataset.unknown) return [];
    return dataset.unknown;
  }, [dataset]);

  const handleClassificationChange = useCallback(
    (imageName: string, className: string) => {
      setImageClassifications((prev) => ({
        ...prev,
        [imageName]: className,
      }));
    },
    [],
  );

  // Ask the backend to generate example images. State models require the
  // step-2 camera crop areas; object models require the chosen object label.
  const generateExamples = useCallback(async () => {
    setIsGenerating(true);
    try {
      if (step1Data.modelType === "state") {
        // For state models, use cameras and crop areas
        if (!step2Data?.cameraAreas || step2Data.cameraAreas.length === 0) {
          toast.error(t("wizard.step3.errors.noCameras"));
          setIsGenerating(false);
          return;
        }
        const cameras: { [key: string]: [number, number, number, number] } = {};
        step2Data.cameraAreas.forEach((area) => {
          cameras[area.camera] = area.crop;
        });
        await axios.post("/classification/generate_examples/state", {
          model_name: step1Data.modelName,
          cameras,
        });
      } else {
        // For object models, use label
        if (!step1Data.objectLabel) {
          toast.error(t("wizard.step3.errors.noObjectLabel"));
          setIsGenerating(false);
          return;
        }
        // For now, use all enabled cameras
        // TODO: In the future, we might want to let users select specific cameras
        await axios.post("/classification/generate_examples/object", {
          model_name: step1Data.modelName,
          label: step1Data.objectLabel,
        });
      }
      setHasGenerated(true);
      toast.success(t("wizard.step3.generateSuccess"));
      // Setting hasGenerated enables the SWR key above; this refresh ensures
      // the dataset listing is fetched promptly after generation.
      await refreshDataset();
    } catch (error) {
      // Surface the most specific error message the backend provided.
      const axiosError = error as {
        response?: { data?: { message?: string; detail?: string } };
        message?: string;
      };
      const errorMessage =
        axiosError.response?.data?.message ||
        axiosError.response?.data?.detail ||
        axiosError.message ||
        "Failed to generate examples";
      toast.error(
        t("wizard.step3.errors.generateFailed", { error: errorMessage }),
      );
    } finally {
      setIsGenerating(false);
    }
  }, [step1Data, step2Data, t, refreshDataset]);

  // Auto-start generation once when the step is first shown; skipped when
  // examples were already generated on a previous visit.
  useEffect(() => {
    if (!hasGenerated && !isGenerating) {
      generateExamples();
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, []);

  const handleContinue = useCallback(() => {
    onNext({ examplesGenerated: true, imageClassifications });
  }, [onNext, imageClassifications]);

  // True only when every displayed image (at most 24) has a class assigned;
  // gates the Continue button below.
  const allImagesClassified = useMemo(() => {
    if (!unknownImages || unknownImages.length === 0) return false;
    const imagesToClassify = unknownImages.slice(0, 24);
    return imagesToClassify.every((img) => imageClassifications[img]);
  }, [unknownImages, imageClassifications]);

  return (
    <div className="flex flex-col gap-6">
      {/* Three states: generating spinner, classification grid, or
          (not generating and not generated) the failure/retry view. */}
      {isGenerating ? (
        <div className="flex h-[50vh] flex-col items-center justify-center gap-4">
          <ActivityIndicator className="size-12" />
          <div className="text-center">
            <h3 className="mb-2 text-lg font-medium">
              {t("wizard.step3.generating.title")}
            </h3>
            <p className="text-sm text-muted-foreground">
              {t("wizard.step3.generating.description")}
            </p>
          </div>
        </div>
      ) : hasGenerated ? (
        <div className="flex flex-col gap-4">
          <div className="text-sm text-muted-foreground">
            {t("wizard.step3.description")}
          </div>
          <div className="rounded-lg bg-secondary/30 p-4">
            {!unknownImages || unknownImages.length === 0 ? (
              <div className="flex h-[40vh] items-center justify-center">
                <p className="text-muted-foreground">
                  {t("wizard.step3.noImages")}
                </p>
              </div>
            ) : (
              <div className="grid grid-cols-6 gap-3">
                {unknownImages.slice(0, 24).map((imageName, index) => (
                  <div
                    key={imageName}
                    className="group relative aspect-square cursor-pointer overflow-hidden rounded-lg border bg-background transition-all hover:ring-2 hover:ring-primary"
                  >
                    <img
                      src={`${baseUrl}clips/${step1Data.modelName}/dataset/unknown/${imageName}`}
                      alt={`Example ${index + 1}`}
                      className="h-full w-full object-cover"
                    />
                    {/* Per-image class selector overlaid on the thumbnail. */}
                    <div className="absolute bottom-0 left-0 right-0 p-2">
                      <Select
                        value={imageClassifications[imageName] || ""}
                        onValueChange={(value) =>
                          handleClassificationChange(imageName, value)
                        }
                      >
                        <SelectTrigger className="h-7 bg-background/20 text-xs">
                          <SelectValue
                            placeholder={t("wizard.step3.selectClass")}
                          />
                        </SelectTrigger>
                        <SelectContent>
                          {step1Data.classes.map((className) => (
                            <SelectItem
                              key={className}
                              value={className}
                              className="cursor-pointer text-xs"
                            >
                              {className}
                            </SelectItem>
                          ))}
                        </SelectContent>
                      </Select>
                    </div>
                  </div>
                ))}
              </div>
            )}
          </div>
        </div>
      ) : (
        <div className="flex h-[50vh] flex-col items-center justify-center gap-4">
          <p className="text-sm text-destructive">
            {t("wizard.step3.errors.generationFailed")}
          </p>
          <Button onClick={generateExamples} variant="select">
            {t("wizard.step3.retryGenerate")}
          </Button>
        </div>
      )}
      <div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
        <Button type="button" onClick={onBack} className="sm:flex-1">
          {t("button.back", { ns: "common" })}
        </Button>
        <Button
          type="button"
          onClick={handleContinue}
          variant="select"
          className="flex items-center justify-center gap-2 sm:flex-1"
          disabled={!hasGenerated || isGenerating || !allImagesClassified}
        >
          {t("button.continue", { ns: "common" })}
        </Button>
      </div>
    </div>
  );
}