Implement Wizard for Creating Classification Models (#20622)

* Implement extraction of images for state classification models

* Add object classification dataset preparation

* Add first step wizard

* Update i18n

* Add state classification image selection step

* Improve box handling

* Add object selector

* Improve object cropping implementation

* Fix state classification selection

* Finalize training and image selection step

* Cleanup

* Design optimizations

* Cleanup mobile styling

* Update no models screen

* Cleanups and fixes

* Fix bugs

* Improve model training and creation process

* Cleanup

* Dynamically add metrics for new model

* Add loading when hitting continue

* Improve image selection mechanism

* Remove unused translation keys

* Adjust wording

* Add retry button for image generation

* Make no models view more specific

* Adjust plus icon

* Adjust form label

* Start with correct type selected

* Cleanup sizing and more font colors

* Small tweaks

* Add tips and more info

* Cleanup dialog sizing

* Add cursor rule for frontend

* Cleanup

* Remove underline

* Add lazy loading
Nicolas Mowen 2025-10-23 13:27:28 -06:00 committed by GitHub
parent 4df7793587
commit f5a57edcc9
18 changed files with 2450 additions and 79 deletions

View File

@ -0,0 +1,6 @@
---
globs: ["**/*.ts", "**/*.tsx"]
alwaysApply: false
---
Never write strings in the frontend directly; always add them to, and reference them from, the relevant translations file.
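
For illustration, a minimal sketch of the pattern this rule asks for (the component name and translation key below are hypothetical, not taken from this PR):

```tsx
import { useTranslation } from "react-i18next";

// Hypothetical component: the visible label comes from a translation key
// defined in the relevant translations file, never from an inline literal.
export default function SaveButton({ onSave }: { onSave: () => void }) {
  const { t } = useTranslation(["components/dialog"]);

  // Avoid: <button onClick={onSave}>Save</button>
  return <button onClick={onSave}>{t("button.save")}</button>;
}
```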

View File

@ -12,7 +12,18 @@ Object classification models are lightweight and run very fast on CPU. Inference
Training the model does briefly use a high amount of system resources for about 1-3 minutes per training run. On lower-power devices, training may take longer.
When running the `-tensorrt` image, Nvidia GPUs will automatically be used to accelerate training.
-### Sub label vs Attribute
+## Classes
Classes are the categories your model will learn to distinguish between. Each class represents a distinct visual category that the model will predict.
For object classification:
- Define classes that represent different types or attributes of the detected object
- Examples: For `person` objects, classes might be `delivery_person`, `resident`, `stranger`
- Include a `none` class for objects that don't fit any specific category
- Keep classes visually distinct to improve accuracy
### Classification Type
- **Sub label**:

View File

@ -12,6 +12,17 @@ State classification models are lightweight and run very fast on CPU. Inference
Training the model does briefly use a high amount of system resources for about 1-3 minutes per training run. On lower-power devices, training may take longer.
When running the `-tensorrt` image, Nvidia GPUs will automatically be used to accelerate training.
## Classes
Classes are the different states an area on your camera can be in. Each class represents a distinct visual state that the model will learn to recognize.
For state classification:
- Define classes that represent mutually exclusive states
- Examples: `open` and `closed` for a garage door, `on` and `off` for lights
- Use at least 2 classes (typically binary states work best)
- Keep class names clear and descriptive
## Example use cases
- **Door state**: Detect if a garage or front door is open vs closed.

View File

@ -387,20 +387,28 @@ def config_set(request: Request, body: AppConfigSetBody):
old_config: FrigateConfig = request.app.frigate_config
request.app.frigate_config = config
-if body.update_topic and body.update_topic.startswith("config/cameras/"):
-_, _, camera, field = body.update_topic.split("/")
+if body.update_topic:
+if body.update_topic.startswith("config/cameras/"):
+_, _, camera, field = body.update_topic.split("/")
if field == "add":
settings = config.cameras[camera]
elif field == "remove":
settings = old_config.cameras[camera]
else:
settings = config.get_nested_object(body.update_topic)
request.app.config_publisher.publish_update(
CameraConfigUpdateTopic(CameraConfigUpdateEnum[field], camera),
settings,
)
+else:
+# Handle nested config updates (e.g., config/classification/custom/{name})
+settings = config.get_nested_object(body.update_topic)
+if settings:
+request.app.config_publisher.publisher.publish(
+body.update_topic, settings
+)
return JSONResponse(
content=(

View File

@ -3,7 +3,9 @@
import datetime
import logging
import os
import random
import shutil
import string
from typing import Any
import cv2
@ -17,6 +19,8 @@ from frigate.api.auth import require_role
from frigate.api.defs.request.classification_body import (
AudioTranscriptionBody,
DeleteFaceImagesBody,
GenerateObjectExamplesBody,
GenerateStateExamplesBody,
RenameFaceBody,
)
from frigate.api.defs.response.classification_response import (
@ -30,6 +34,10 @@ from frigate.config.camera import DetectConfig
from frigate.const import CLIPS_DIR, FACE_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.models import Event
from frigate.util.classification import (
collect_object_classification_examples,
collect_state_classification_examples,
)
from frigate.util.path import get_event_snapshot
logger = logging.getLogger(__name__)
@ -159,8 +167,7 @@ def train_face(request: Request, name: str, body: dict = None):
new_name = f"{sanitized_name}-{datetime.datetime.now().timestamp()}.webp" new_name = f"{sanitized_name}-{datetime.datetime.now().timestamp()}.webp"
new_file_folder = os.path.join(FACE_DIR, f"{sanitized_name}") new_file_folder = os.path.join(FACE_DIR, f"{sanitized_name}")
if not os.path.exists(new_file_folder): os.makedirs(new_file_folder, exist_ok=True)
os.mkdir(new_file_folder)
if training_file_name: if training_file_name:
shutil.move(training_file, os.path.join(new_file_folder, new_name)) shutil.move(training_file, os.path.join(new_file_folder, new_name))
@ -701,13 +708,14 @@ def categorize_classification_image(request: Request, name: str, body: dict = No
status_code=404,
)
-new_name = f"{category}-{datetime.datetime.now().timestamp()}.png"
+random_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
+timestamp = datetime.datetime.now().timestamp()
+new_name = f"{category}-{timestamp}-{random_id}.png"
new_file_folder = os.path.join(
CLIPS_DIR, sanitize_filename(name), "dataset", category
)
-if not os.path.exists(new_file_folder):
-os.mkdir(new_file_folder)
+os.makedirs(new_file_folder, exist_ok=True)
# use opencv because webp images can not be used to train
img = cv2.imread(training_file)
@ -756,3 +764,43 @@ def delete_classification_train_images(request: Request, name: str, body: dict =
content=({"success": True, "message": "Successfully deleted faces."}), content=({"success": True, "message": "Successfully deleted faces."}),
status_code=200, status_code=200,
) )
@router.post(
"/classification/generate_examples/state",
response_model=GenericResponse,
dependencies=[Depends(require_role(["admin"]))],
summary="Generate state classification examples",
)
async def generate_state_examples(request: Request, body: GenerateStateExamplesBody):
"""Generate examples for state classification."""
model_name = sanitize_filename(body.model_name)
cameras_normalized = {
camera_name: tuple(crop)
for camera_name, crop in body.cameras.items()
if camera_name in request.app.frigate_config.cameras
}
collect_state_classification_examples(model_name, cameras_normalized)
return JSONResponse(
content={"success": True, "message": "Example generation completed"},
status_code=200,
)
@router.post(
"/classification/generate_examples/object",
response_model=GenericResponse,
dependencies=[Depends(require_role(["admin"]))],
summary="Generate object classification examples",
)
async def generate_object_examples(request: Request, body: GenerateObjectExamplesBody):
"""Generate examples for object classification."""
model_name = sanitize_filename(body.model_name)
collect_object_classification_examples(model_name, body.label)
return JSONResponse(
content={"success": True, "message": "Example generation completed"},
status_code=200,
)
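
For reference, a minimal sketch of how a frontend client might call these two endpoints, assuming an axios instance already configured with the Frigate API base URL (the helper names are illustrative and not part of this PR):

```ts
import axios from "axios";

// Ask the backend to pull representative frames for a new state model.
// `cameras` maps each camera name to a normalized [x1, y1, x2, y2] crop (0-1),
// matching GenerateStateExamplesBody.
export async function generateStateExamples(
  modelName: string,
  cameras: Record<string, [number, number, number, number]>,
) {
  return axios.post("classification/generate_examples/state", {
    model_name: modelName,
    cameras,
  });
}

// Ask the backend to pull example thumbnails of a tracked object label
// (e.g. "person") for a new object model, matching GenerateObjectExamplesBody.
export async function generateObjectExamples(modelName: string, label: string) {
  return axios.post("classification/generate_examples/object", {
    model_name: modelName,
    label,
  });
}
```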

View File

@ -1,17 +1,31 @@
-from typing import List
+from typing import Dict, List, Tuple
from pydantic import BaseModel, Field
class RenameFaceBody(BaseModel):
-new_name: str
+new_name: str = Field(description="New name for the face")
class AudioTranscriptionBody(BaseModel):
-event_id: str
+event_id: str = Field(description="ID of the event to transcribe audio for")
class DeleteFaceImagesBody(BaseModel):
ids: List[str] = Field(
description="List of image filenames to delete from the face folder"
)
class GenerateStateExamplesBody(BaseModel):
model_name: str = Field(description="Name of the classification model")
cameras: Dict[str, Tuple[float, float, float, float]] = Field(
description="Dictionary mapping camera names to normalized crop coordinates in [x1, y1, x2, y2] format (values 0-1)"
)
class GenerateObjectExamplesBody(BaseModel):
model_name: str = Field(description="Name of the classification model")
label: str = Field(
description="Object label to collect examples for (e.g., 'person', 'car')"
)

View File

@ -53,9 +53,17 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
self.tensor_output_details: dict[str, Any] | None = None
self.labelmap: dict[int, str] = {}
self.classifications_per_second = EventsPerSecond()
-self.inference_speed = InferenceSpeed(
-self.metrics.classification_speeds[self.model_config.name]
-)
+if (
+self.metrics
+and self.model_config.name in self.metrics.classification_speeds
+):
+self.inference_speed = InferenceSpeed(
+self.metrics.classification_speeds[self.model_config.name]
+)
+else:
+self.inference_speed = None
self.last_run = datetime.datetime.now().timestamp()
self.__build_detector()
@ -83,12 +91,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
def __update_metrics(self, duration: float) -> None:
self.classifications_per_second.update()
-self.inference_speed.update(duration)
+if self.inference_speed:
+self.inference_speed.update(duration)
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
-self.metrics.classification_cps[
-self.model_config.name
-].value = self.classifications_per_second.eps()
+if self.metrics and self.model_config.name in self.metrics.classification_cps:
+self.metrics.classification_cps[
+self.model_config.name
+].value = self.classifications_per_second.eps()
camera = frame_data.get("camera")
if camera not in self.model_config.state_config.cameras:
@ -223,9 +233,17 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
self.detected_objects: dict[str, float] = {}
self.labelmap: dict[int, str] = {}
self.classifications_per_second = EventsPerSecond()
-self.inference_speed = InferenceSpeed(
-self.metrics.classification_speeds[self.model_config.name]
-)
+if (
+self.metrics
+and self.model_config.name in self.metrics.classification_speeds
+):
+self.inference_speed = InferenceSpeed(
+self.metrics.classification_speeds[self.model_config.name]
+)
+else:
+self.inference_speed = None
self.__build_detector()
@redirect_output_to_logger(logger, logging.DEBUG)
@ -251,12 +269,14 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
def __update_metrics(self, duration: float) -> None:
self.classifications_per_second.update()
-self.inference_speed.update(duration)
+if self.inference_speed:
+self.inference_speed.update(duration)
def process_frame(self, obj_data, frame):
-self.metrics.classification_cps[
-self.model_config.name
-].value = self.classifications_per_second.eps()
+if self.metrics and self.model_config.name in self.metrics.classification_cps:
+self.metrics.classification_cps[
+self.model_config.name
+].value = self.classifications_per_second.eps()
if obj_data["false_positive"]:
return

View File

@ -9,6 +9,7 @@ from typing import Any
from peewee import DoesNotExist
from frigate.comms.config_updater import ConfigSubscriber
from frigate.comms.detections_updater import DetectionSubscriber, DetectionTypeEnum
from frigate.comms.embeddings_updater import (
EmbeddingsRequestEnum,
@ -95,6 +96,9 @@ class EmbeddingMaintainer(threading.Thread):
CameraConfigUpdateEnum.semantic_search,
],
)
self.classification_config_subscriber = ConfigSubscriber(
"config/classification/custom/"
)
# Configure Frigate DB
db = SqliteVecQueueDatabase(
@ -255,6 +259,7 @@ class EmbeddingMaintainer(threading.Thread):
"""Maintain a SQLite-vec database for semantic search.""" """Maintain a SQLite-vec database for semantic search."""
while not self.stop_event.is_set(): while not self.stop_event.is_set():
self.config_updater.check_for_updates() self.config_updater.check_for_updates()
self._check_classification_config_updates()
self._process_requests() self._process_requests()
self._process_updates() self._process_updates()
self._process_recordings_updates() self._process_recordings_updates()
@ -265,6 +270,7 @@ class EmbeddingMaintainer(threading.Thread):
self._process_event_metadata()
self.config_updater.stop()
self.classification_config_subscriber.stop()
self.event_subscriber.stop()
self.event_end_subscriber.stop()
self.recordings_subscriber.stop()
@ -275,6 +281,46 @@ class EmbeddingMaintainer(threading.Thread):
self.requestor.stop()
logger.info("Exiting embeddings maintenance...")
def _check_classification_config_updates(self) -> None:
"""Check for classification config updates and add new processors."""
topic, model_config = self.classification_config_subscriber.check_for_update()
if topic and model_config:
model_name = topic.split("/")[-1]
self.config.classification.custom[model_name] = model_config
# Check if processor already exists
for processor in self.realtime_processors:
if isinstance(
processor,
(
CustomStateClassificationProcessor,
CustomObjectClassificationProcessor,
),
):
if processor.model_config.name == model_name:
logger.debug(
f"Classification processor for model {model_name} already exists, skipping"
)
return
if model_config.state_config is not None:
processor = CustomStateClassificationProcessor(
self.config, model_config, self.requestor, self.metrics
)
else:
processor = CustomObjectClassificationProcessor(
self.config,
model_config,
self.event_metadata_publisher,
self.metrics,
)
self.realtime_processors.append(processor)
logger.info(
f"Added classification processor for model: {model_name} (type: {type(processor).__name__})"
)
def _process_requests(self) -> None:
"""Process embeddings requests"""

View File

@ -2,12 +2,15 @@
import logging
import os
import random
from collections import defaultdict
import cv2
import numpy as np
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsRequestor
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FfmpegConfig
from frigate.const import (
CLIPS_DIR,
MODEL_CACHE_DIR,
@ -15,7 +18,10 @@ from frigate.const import (
UPDATE_MODEL_STATE,
)
from frigate.log import redirect_output_to_logger
from frigate.models import Event, Recordings, ReviewSegment
from frigate.types import ModelStatusTypesEnum
from frigate.util.image import get_image_from_recording
from frigate.util.path import get_event_thumbnail_bytes
from frigate.util.process import FrigateProcess
BATCH_SIZE = 16
@ -69,6 +75,7 @@ class ClassificationTrainingProcess(FrigateProcess):
logger.info(f"Kicking off classification training for {self.model_name}.") logger.info(f"Kicking off classification training for {self.model_name}.")
dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset") dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name) model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
os.makedirs(model_dir, exist_ok=True)
num_classes = len( num_classes = len(
[ [
d d
@ -139,7 +146,6 @@ class ClassificationTrainingProcess(FrigateProcess):
f.write(tflite_model)
-@staticmethod
def kickoff_model_training(
embeddingRequestor: EmbeddingsRequestor, model_name: str
) -> None:
@ -172,3 +178,520 @@ def kickoff_model_training(
},
)
requestor.stop()
@staticmethod
def collect_state_classification_examples(
model_name: str, cameras: dict[str, tuple[float, float, float, float]]
) -> None:
"""
Collect representative state classification examples from review items.
This function:
1. Queries review items from specified cameras
2. Selects 100 balanced timestamps across the data
3. Extracts keyframes from recordings (cropped to specified regions)
4. Selects 20 most visually distinct images
5. Saves them to the dataset directory
Args:
model_name: Name of the classification model
cameras: Dict mapping camera names to normalized crop coordinates [x1, y1, x2, y2] (0-1)
"""
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
temp_dir = os.path.join(dataset_dir, "temp")
os.makedirs(temp_dir, exist_ok=True)
# Step 1: Get review items for the cameras
camera_names = list(cameras.keys())
review_items = list(
ReviewSegment.select()
.where(ReviewSegment.camera.in_(camera_names))
.where(ReviewSegment.end_time.is_null(False))
.order_by(ReviewSegment.start_time.asc())
)
if not review_items:
logger.warning(f"No review items found for cameras: {camera_names}")
return
# Step 2: Create balanced timestamp selection (100 samples)
timestamps = _select_balanced_timestamps(review_items, target_count=100)
# Step 3: Extract keyframes from recordings with crops applied
keyframes = _extract_keyframes(
"/usr/lib/ffmpeg/7.0/bin/ffmpeg", timestamps, temp_dir, cameras
)
# Step 4: Select 24 most visually distinct images (they're already cropped)
distinct_images = _select_distinct_images(keyframes, target_count=24)
# Step 5: Save to train directory for later classification
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
os.makedirs(train_dir, exist_ok=True)
saved_count = 0
for idx, image_path in enumerate(distinct_images):
dest_path = os.path.join(train_dir, f"example_{idx:03d}.jpg")
try:
img = cv2.imread(image_path)
if img is not None:
cv2.imwrite(dest_path, img)
saved_count += 1
except Exception as e:
logger.error(f"Failed to save image {image_path}: {e}")
import shutil
try:
shutil.rmtree(temp_dir)
except Exception as e:
logger.warning(f"Failed to clean up temp directory: {e}")
def _select_balanced_timestamps(
review_items: list[ReviewSegment], target_count: int = 100
) -> list[dict]:
"""
Select balanced timestamps from review items.
Strategy:
- Group review items by camera and time of day
- Sample evenly across groups to ensure diversity
- For each selected review item, pick a random timestamp within its duration
Returns:
List of dicts with keys: camera, timestamp, review_item
"""
# Group by camera and hour of day for temporal diversity
grouped = defaultdict(list)
for item in review_items:
camera = item.camera
# Group by 6-hour blocks for temporal diversity
hour_block = int(item.start_time // (6 * 3600))
key = f"{camera}_{hour_block}"
grouped[key].append(item)
# Calculate how many samples per group
num_groups = len(grouped)
if num_groups == 0:
return []
samples_per_group = max(1, target_count // num_groups)
timestamps = []
# Sample from each group
for group_items in grouped.values():
# Take samples_per_group items from this group
sample_size = min(samples_per_group, len(group_items))
sampled_items = random.sample(group_items, sample_size)
for item in sampled_items:
# Pick a random timestamp within the review item's duration
duration = item.end_time - item.start_time
if duration <= 0:
continue
# Sample from middle 80% to avoid edge artifacts
offset = random.uniform(duration * 0.1, duration * 0.9)
timestamp = item.start_time + offset
timestamps.append(
{
"camera": item.camera,
"timestamp": timestamp,
"review_item": item,
}
)
# If we don't have enough, sample more from larger groups
while len(timestamps) < target_count and len(timestamps) < len(review_items):
for group_items in grouped.values():
if len(timestamps) >= target_count:
break
# Pick a random item not already sampled
item = random.choice(group_items)
duration = item.end_time - item.start_time
if duration <= 0:
continue
offset = random.uniform(duration * 0.1, duration * 0.9)
timestamp = item.start_time + offset
# Check if we already have a timestamp near this one
if not any(abs(t["timestamp"] - timestamp) < 1.0 for t in timestamps):
timestamps.append(
{
"camera": item.camera,
"timestamp": timestamp,
"review_item": item,
}
)
return timestamps[:target_count]
def _extract_keyframes(
ffmpeg_path: str,
timestamps: list[dict],
output_dir: str,
camera_crops: dict[str, tuple[float, float, float, float]],
) -> list[str]:
"""
Extract keyframes from recordings at specified timestamps and crop to specified regions.
Args:
ffmpeg_path: Path to ffmpeg binary
timestamps: List of timestamp dicts from _select_balanced_timestamps
output_dir: Directory to save extracted frames
camera_crops: Dict mapping camera names to normalized crop coordinates [x1, y1, x2, y2] (0-1)
Returns:
List of paths to successfully extracted and cropped keyframe images
"""
keyframe_paths = []
for idx, ts_info in enumerate(timestamps):
camera = ts_info["camera"]
timestamp = ts_info["timestamp"]
if camera not in camera_crops:
logger.warning(f"No crop coordinates for camera {camera}")
continue
norm_x1, norm_y1, norm_x2, norm_y2 = camera_crops[camera]
try:
recording = (
Recordings.select()
.where(
(timestamp >= Recordings.start_time)
& (timestamp <= Recordings.end_time)
& (Recordings.camera == camera)
)
.order_by(Recordings.start_time.desc())
.limit(1)
.get()
)
except Exception:
continue
relative_time = timestamp - recording.start_time
try:
config = FfmpegConfig(path="/usr/lib/ffmpeg/7.0")
image_data = get_image_from_recording(
config,
recording.path,
relative_time,
codec="mjpeg",
height=None,
)
if image_data:
nparr = np.frombuffer(image_data, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is not None:
height, width = img.shape[:2]
x1 = int(norm_x1 * width)
y1 = int(norm_y1 * height)
x2 = int(norm_x2 * width)
y2 = int(norm_y2 * height)
x1_clipped = max(0, min(x1, width))
y1_clipped = max(0, min(y1, height))
x2_clipped = max(0, min(x2, width))
y2_clipped = max(0, min(y2, height))
if x2_clipped > x1_clipped and y2_clipped > y1_clipped:
cropped = img[y1_clipped:y2_clipped, x1_clipped:x2_clipped]
resized = cv2.resize(cropped, (224, 224))
output_path = os.path.join(output_dir, f"frame_{idx:04d}.jpg")
cv2.imwrite(output_path, resized)
keyframe_paths.append(output_path)
except Exception as e:
logger.debug(
f"Failed to extract frame from {recording.path} at {relative_time}s: {e}"
)
continue
return keyframe_paths
def _select_distinct_images(
image_paths: list[str], target_count: int = 20
) -> list[str]:
"""
Select the most visually distinct images from a set of keyframes.
Uses a greedy algorithm based on image histograms:
1. Start with a random image
2. Iteratively add the image that is most different from already selected images
3. Difference is measured using histogram comparison
Args:
image_paths: List of paths to candidate images
target_count: Number of distinct images to select
Returns:
List of paths to selected images
"""
if len(image_paths) <= target_count:
return image_paths
histograms = {}
valid_paths = []
for path in image_paths:
try:
img = cv2.imread(path)
if img is None:
continue
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
hist = cv2.calcHist(
[hsv], [0, 1, 2], None, [8, 8, 8], [0, 180, 0, 256, 0, 256]
)
hist = cv2.normalize(hist, hist).flatten()
histograms[path] = hist
valid_paths.append(path)
except Exception as e:
logger.debug(f"Failed to process image {path}: {e}")
continue
if len(valid_paths) <= target_count:
return valid_paths
selected = []
first_image = random.choice(valid_paths)
selected.append(first_image)
remaining = [p for p in valid_paths if p != first_image]
while len(selected) < target_count and remaining:
max_min_distance = -1
best_candidate = None
for candidate in remaining:
min_distance = float("inf")
for selected_img in selected:
distance = cv2.compareHist(
histograms[candidate],
histograms[selected_img],
cv2.HISTCMP_BHATTACHARYYA,
)
min_distance = min(min_distance, distance)
if min_distance > max_min_distance:
max_min_distance = min_distance
best_candidate = candidate
if best_candidate:
selected.append(best_candidate)
remaining.remove(best_candidate)
else:
break
return selected
@staticmethod
def collect_object_classification_examples(
model_name: str,
label: str,
) -> None:
"""
Collect representative object classification examples from event thumbnails.
This function:
1. Queries events for the specified label
2. Selects 100 balanced events across different cameras and times
3. Retrieves thumbnails for selected events (with 33% center crop applied)
4. Selects 24 most visually distinct thumbnails
5. Saves to dataset directory
Args:
model_name: Name of the classification model
label: Object label to collect (e.g., "person", "car")
cameras: List of camera names to collect examples from
"""
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
temp_dir = os.path.join(dataset_dir, "temp")
os.makedirs(temp_dir, exist_ok=True)
# Step 1: Query events for the specified label and cameras
events = list(
Event.select().where((Event.label == label)).order_by(Event.start_time.asc())
)
if not events:
logger.warning(f"No events found for label '{label}'")
return
logger.debug(f"Found {len(events)} events")
# Step 2: Select balanced events (100 samples)
selected_events = _select_balanced_events(events, target_count=100)
logger.debug(f"Selected {len(selected_events)} events")
# Step 3: Extract thumbnails from events
thumbnails = _extract_event_thumbnails(selected_events, temp_dir)
logger.debug(f"Successfully extracted {len(thumbnails)} thumbnails")
# Step 4: Select 24 most visually distinct thumbnails
distinct_images = _select_distinct_images(thumbnails, target_count=24)
logger.debug(f"Selected {len(distinct_images)} distinct images")
# Step 5: Save to train directory for later classification
train_dir = os.path.join(CLIPS_DIR, model_name, "train")
os.makedirs(train_dir, exist_ok=True)
saved_count = 0
for idx, image_path in enumerate(distinct_images):
dest_path = os.path.join(train_dir, f"example_{idx:03d}.jpg")
try:
img = cv2.imread(image_path)
if img is not None:
cv2.imwrite(dest_path, img)
saved_count += 1
except Exception as e:
logger.error(f"Failed to save image {image_path}: {e}")
import shutil
try:
shutil.rmtree(temp_dir)
except Exception as e:
logger.warning(f"Failed to clean up temp directory: {e}")
logger.debug(
f"Successfully collected {saved_count} classification examples in {train_dir}"
)
def _select_balanced_events(
events: list[Event], target_count: int = 100
) -> list[Event]:
"""
Select balanced events from the event list.
Strategy:
- Group events by camera and time of day
- Sample evenly across groups to ensure diversity
- Prioritize events with higher scores
Returns:
List of selected events
"""
grouped = defaultdict(list)
for event in events:
camera = event.camera
hour_block = int(event.start_time // (6 * 3600))
key = f"{camera}_{hour_block}"
grouped[key].append(event)
num_groups = len(grouped)
if num_groups == 0:
return []
samples_per_group = max(1, target_count // num_groups)
selected = []
for group_events in grouped.values():
sorted_events = sorted(
group_events,
key=lambda e: e.data.get("score", 0) if e.data else 0,
reverse=True,
)
sample_size = min(samples_per_group, len(sorted_events))
selected.extend(sorted_events[:sample_size])
if len(selected) < target_count:
remaining = [e for e in events if e not in selected]
remaining_sorted = sorted(
remaining,
key=lambda e: e.data.get("score", 0) if e.data else 0,
reverse=True,
)
needed = target_count - len(selected)
selected.extend(remaining_sorted[:needed])
return selected[:target_count]
def _extract_event_thumbnails(events: list[Event], output_dir: str) -> list[str]:
"""
Extract thumbnails from events and save to disk.
Args:
events: List of Event objects
output_dir: Directory to save thumbnails
Returns:
List of paths to successfully extracted thumbnail images
"""
thumbnail_paths = []
for idx, event in enumerate(events):
try:
thumbnail_bytes = get_event_thumbnail_bytes(event)
if thumbnail_bytes:
nparr = np.frombuffer(thumbnail_bytes, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
if img is not None:
height, width = img.shape[:2]
crop_size = 1.0
if event.data and "box" in event.data and "region" in event.data:
box = event.data["box"]
region = event.data["region"]
if len(box) == 4 and len(region) == 4:
box_w, box_h = box[2], box[3]
region_w, region_h = region[2], region[3]
box_area = (box_w * box_h) / (region_w * region_h)
if box_area < 0.05:
crop_size = 0.4
elif box_area < 0.10:
crop_size = 0.5
elif box_area < 0.20:
crop_size = 0.65
elif box_area < 0.35:
crop_size = 0.80
else:
crop_size = 0.95
crop_width = int(width * crop_size)
crop_height = int(height * crop_size)
x1 = (width - crop_width) // 2
y1 = (height - crop_height) // 2
x2 = x1 + crop_width
y2 = y1 + crop_height
cropped = img[y1:y2, x1:x2]
resized = cv2.resize(cropped, (224, 224))
output_path = os.path.join(output_dir, f"thumbnail_{idx:04d}.jpg")
cv2.imwrite(output_path, resized)
thumbnail_paths.append(output_path)
except Exception as e:
logger.debug(f"Failed to extract thumbnail for event {event.id}: {e}")
continue
return thumbnail_paths

View File

@ -1,4 +1,5 @@
{
"documentTitle": "Classification Models",
"button": {
"deleteClassificationAttempts": "Delete Classification Images",
"renameCategory": "Rename Class",
@ -50,8 +51,85 @@
},
"categorizeImageAs": "Classify Image As:",
"categorizeImage": "Classify Image",
"noModels": {
"object": {
"title": "No Object Classification Models",
"description": "Create a custom model to classify detected objects.",
"buttonText": "Create Object Model"
},
"state": {
"title": "No State Classification Models",
"description": "Create a custom model to monitor and classify state changes in specific camera areas.",
"buttonText": "Create State Model"
}
},
"wizard": { "wizard": {
"title": "Create New Classification", "title": "Create New Classification",
"description": "Create a new state or object classification model." "steps": {
"nameAndDefine": "Name & Define",
"stateArea": "State Area",
"chooseExamples": "Choose Examples"
},
"step1": {
"description": "State models monitor fixed camera areas for changes (e.g., door open/closed). Object models add classifications to detected objects (e.g., known animals, delivery persons, etc.).",
"name": "Name",
"namePlaceholder": "Enter model name...",
"type": "Type",
"typeState": "State",
"typeObject": "Object",
"objectLabel": "Object Label",
"objectLabelPlaceholder": "Select object type...",
"classificationType": "Classification Type",
"classificationTypeTip": "Learn about classification types",
"classificationTypeDesc": "Sub Labels add additional text to the object label (e.g., 'Person: UPS'). Attributes are searchable metadata stored separately in the object metadata.",
"classificationSubLabel": "Sub Label",
"classificationAttribute": "Attribute",
"classes": "Classes",
"classesTip": "Learn about classes",
"classesStateDesc": "Define the different states your camera area can be in. For example: 'open' and 'closed' for a garage door.",
"classesObjectDesc": "Define the different categories to classify detected objects into. For example: 'delivery_person', 'resident', 'stranger' for person classification.",
"classPlaceholder": "Enter class name...",
"errors": {
"nameRequired": "Model name is required",
"nameLength": "Model name must be 64 characters or less",
"nameOnlyNumbers": "Model name cannot contain only numbers",
"classRequired": "At least 1 class is required",
"classesUnique": "Class names must be unique",
"stateRequiresTwoClasses": "State models require at least 2 classes",
"objectLabelRequired": "Please select an object label",
"objectTypeRequired": "Please select a classification type"
}
},
"step2": {
"description": "Select cameras and define the area to monitor for each camera. The model will classify the state of these areas.",
"cameras": "Cameras",
"selectCamera": "Select Camera",
"noCameras": "Click + to add cameras",
"selectCameraPrompt": "Select a camera from the list to define its monitoring area"
},
"step3": {
"selectImagesPrompt": "Select all images with: {{className}}",
"selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.",
"generating": {
"title": "Generating Sample Images",
"description": "Frigate is pulling representative images from your recordings. This may take a moment..."
},
"training": {
"title": "Training Model",
"description": "Your model is being trained in the background. Close this dialog, and your model will start running as soon as training is complete."
},
"retryGenerate": "Retry Generation",
"noImages": "No sample images generated",
"classifying": "Classifying & Training...",
"trainingStarted": "Training started successfully",
"errors": {
"noCameras": "No cameras configured",
"noObjectLabel": "No object label selected",
"generateFailed": "Failed to generate examples: {{error}}",
"generationFailed": "Generation failed. Please try again.",
"classifyFailed": "Failed to classify images: {{error}}"
},
"generateSuccess": "Successfully generated sample images"
}
}
}

View File

@ -5,10 +5,6 @@
"invalidName": "Invalid name. Names can only include letters, numbers, spaces, apostrophes, underscores, and hyphens." "invalidName": "Invalid name. Names can only include letters, numbers, spaces, apostrophes, underscores, and hyphens."
}, },
"details": { "details": {
"subLabelScore": "Sub Label Score",
"scoreInfo": "The sub label score is the weighted score for all of the recognized face confidences, so this may differ from the score shown on the snapshot.",
"face": "Face Details",
"faceDesc": "Details of the tracked object that generated this face",
"timestamp": "Timestamp", "timestamp": "Timestamp",
"unknown": "Unknown" "unknown": "Unknown"
},
@ -19,8 +15,6 @@
}, },
"collections": "Collections", "collections": "Collections",
"createFaceLibrary": { "createFaceLibrary": {
"title": "Create Collection",
"desc": "Create a new collection",
"new": "Create New Face", "new": "Create New Face",
"nextSteps": "To build a strong foundation:<li>Use the Recent Recognitions tab to select and train on images for each detected person.</li><li>Focus on straight-on images for best results; avoid training images that capture faces at an angle.</li></ul>" "nextSteps": "To build a strong foundation:<li>Use the Recent Recognitions tab to select and train on images for each detected person.</li><li>Focus on straight-on images for best results; avoid training images that capture faces at an angle.</li></ul>"
}, },
@ -37,8 +31,6 @@
"aria": "Select recent recognitions", "aria": "Select recent recognitions",
"empty": "There are no recent face recognition attempts" "empty": "There are no recent face recognition attempts"
}, },
"selectItem": "Select {{item}}",
"selectFace": "Select Face",
"deleteFaceLibrary": { "deleteFaceLibrary": {
"title": "Delete Name", "title": "Delete Name",
"desc": "Are you sure you want to delete the collection {{name}}? This will permanently delete all associated faces." "desc": "Are you sure you want to delete the collection {{name}}? This will permanently delete all associated faces."
@ -69,7 +61,6 @@
"maxSize": "Max size: {{size}}MB" "maxSize": "Max size: {{size}}MB"
}, },
"nofaces": "No faces available", "nofaces": "No faces available",
"pixels": "{{area}}px",
"trainFaceAs": "Train Face as:", "trainFaceAs": "Train Face as:",
"trainFace": "Train Face", "trainFace": "Train Face",
"toast": { "toast": {

View File

@ -126,6 +126,7 @@ export const ClassificationCard = forwardRef<
imgClassName,
isMobile && "w-full",
)}
loading="lazy"
onLoad={() => setImageLoaded(true)}
src={`${baseUrl}${data.filepath}`}
/>

View File

@ -7,58 +7,198 @@ import {
DialogHeader,
DialogTitle,
} from "../ui/dialog";
-import { useState } from "react";
+import { useReducer, useMemo } from "react";
import Step1NameAndDefine, { Step1FormData } from "./wizard/Step1NameAndDefine";
import Step2StateArea, { Step2FormData } from "./wizard/Step2StateArea";
import Step3ChooseExamples, {
Step3FormData,
} from "./wizard/Step3ChooseExamples";
import { cn } from "@/lib/utils";
import { isDesktop } from "react-device-detect";
-const STEPS = [
-"classificationWizard.steps.nameAndDefine",
-"classificationWizard.steps.stateArea",
-"classificationWizard.steps.chooseExamples",
-"classificationWizard.steps.train",
+const OBJECT_STEPS = [
+"wizard.steps.nameAndDefine",
+"wizard.steps.chooseExamples",
+];
+const STATE_STEPS = [
+"wizard.steps.nameAndDefine",
+"wizard.steps.stateArea",
+"wizard.steps.chooseExamples",
];
type ClassificationModelWizardDialogProps = {
open: boolean;
onClose: () => void;
defaultModelType?: "state" | "object";
};
type WizardState = {
currentStep: number;
step1Data?: Step1FormData;
step2Data?: Step2FormData;
step3Data?: Step3FormData;
};
type WizardAction =
| { type: "NEXT_STEP"; payload?: Partial<WizardState> }
| { type: "PREVIOUS_STEP" }
| { type: "SET_STEP_1"; payload: Step1FormData }
| { type: "SET_STEP_2"; payload: Step2FormData }
| { type: "SET_STEP_3"; payload: Step3FormData }
| { type: "RESET" };
const initialState: WizardState = {
currentStep: 0,
};
function wizardReducer(state: WizardState, action: WizardAction): WizardState {
switch (action.type) {
case "SET_STEP_1":
return {
...state,
step1Data: action.payload,
currentStep: 1,
};
case "SET_STEP_2":
return {
...state,
step2Data: action.payload,
currentStep: 2,
};
case "SET_STEP_3":
return {
...state,
step3Data: action.payload,
currentStep: 3,
};
case "NEXT_STEP":
return {
...state,
...action.payload,
currentStep: state.currentStep + 1,
};
case "PREVIOUS_STEP":
return {
...state,
currentStep: Math.max(0, state.currentStep - 1),
};
case "RESET":
return initialState;
default:
return state;
}
}
export default function ClassificationModelWizardDialog({
open,
onClose,
defaultModelType,
}: ClassificationModelWizardDialogProps) {
const { t } = useTranslation(["views/classificationModel"]);
-// step management
-const [currentStep, _] = useState(0);
+const [wizardState, dispatch] = useReducer(wizardReducer, initialState);
const steps = useMemo(() => {
if (!wizardState.step1Data) {
return OBJECT_STEPS;
}
return wizardState.step1Data.modelType === "state"
? STATE_STEPS
: OBJECT_STEPS;
}, [wizardState.step1Data]);
const handleStep1Next = (data: Step1FormData) => {
dispatch({ type: "SET_STEP_1", payload: data });
};
const handleStep2Next = (data: Step2FormData) => {
dispatch({ type: "SET_STEP_2", payload: data });
};
const handleBack = () => {
dispatch({ type: "PREVIOUS_STEP" });
};
const handleCancel = () => {
dispatch({ type: "RESET" });
onClose();
};
return (
<Dialog
open={open}
onOpenChange={(open) => {
if (!open) {
-onClose;
+handleCancel();
}
}}
>
<DialogContent
-className="max-h-[90dvh] max-w-4xl overflow-y-auto"
+className={cn(
"",
isDesktop &&
wizardState.currentStep == 0 &&
"max-h-[90%] overflow-y-auto xl:max-h-[80%]",
isDesktop &&
wizardState.currentStep > 0 &&
"max-h-[90%] max-w-[70%] overflow-y-auto xl:max-h-[80%]",
)}
onInteractOutside={(e) => {
e.preventDefault();
}}
>
<StepIndicator
-steps={STEPS}
-currentStep={currentStep}
+steps={steps}
+currentStep={wizardState.currentStep}
variant="dots"
className="mb-4 justify-start"
/>
<DialogHeader>
<DialogTitle>{t("wizard.title")}</DialogTitle>
-{currentStep === 0 && (
-<DialogDescription>{t("wizard.description")}</DialogDescription>
+{wizardState.currentStep === 0 && (
+<DialogDescription>
+{t("wizard.step1.description")}
+</DialogDescription>
)}
{wizardState.currentStep === 1 &&
wizardState.step1Data?.modelType === "state" && (
<DialogDescription>
{t("wizard.step2.description")}
</DialogDescription>
)}
</DialogHeader>
<div className="pb-4">
-<div className="size-full"></div>
+{wizardState.currentStep === 0 && (
<Step1NameAndDefine
initialData={wizardState.step1Data}
defaultModelType={defaultModelType}
onNext={handleStep1Next}
onCancel={handleCancel}
/>
)}
{wizardState.currentStep === 1 &&
wizardState.step1Data?.modelType === "state" && (
<Step2StateArea
initialData={wizardState.step2Data}
onNext={handleStep2Next}
onBack={handleBack}
/>
)}
{((wizardState.currentStep === 2 &&
wizardState.step1Data?.modelType === "state") ||
(wizardState.currentStep === 1 &&
wizardState.step1Data?.modelType === "object")) &&
wizardState.step1Data && (
<Step3ChooseExamples
step1Data={wizardState.step1Data}
step2Data={wizardState.step2Data}
initialData={wizardState.step3Data}
onClose={onClose}
onBack={handleBack}
/>
)}
</div>
</DialogContent>
</Dialog>

View File

@ -0,0 +1,498 @@
import { Button } from "@/components/ui/button";
import {
Form,
FormControl,
FormField,
FormItem,
FormLabel,
FormMessage,
} from "@/components/ui/form";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { useForm } from "react-hook-form";
import { zodResolver } from "@hookform/resolvers/zod";
import { z } from "zod";
import { useTranslation } from "react-i18next";
import { useMemo } from "react";
import { LuX, LuPlus, LuInfo, LuExternalLink } from "react-icons/lu";
import useSWR from "swr";
import { FrigateConfig } from "@/types/frigateConfig";
import { getTranslatedLabel } from "@/utils/i18n";
import { useDocDomain } from "@/hooks/use-doc-domain";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/ui/popover";
export type ModelType = "state" | "object";
export type ObjectClassificationType = "sub_label" | "attribute";
export type Step1FormData = {
modelName: string;
modelType: ModelType;
objectLabel?: string;
objectType?: ObjectClassificationType;
classes: string[];
};
type Step1NameAndDefineProps = {
initialData?: Partial<Step1FormData>;
defaultModelType?: "state" | "object";
onNext: (data: Step1FormData) => void;
onCancel: () => void;
};
export default function Step1NameAndDefine({
initialData,
defaultModelType,
onNext,
onCancel,
}: Step1NameAndDefineProps) {
const { t } = useTranslation(["views/classificationModel"]);
const { data: config } = useSWR<FrigateConfig>("config");
const { getLocaleDocUrl } = useDocDomain();
const objectLabels = useMemo(() => {
if (!config) return [];
const labels = new Set<string>();
Object.values(config.cameras).forEach((cameraConfig) => {
if (!cameraConfig.enabled || !cameraConfig.enabled_in_config) {
return;
}
cameraConfig.objects.track.forEach((label) => {
if (!config.model.all_attributes.includes(label)) {
labels.add(label);
}
});
});
return [...labels].sort();
}, [config]);
const step1FormData = z
.object({
modelName: z
.string()
.min(1, t("wizard.step1.errors.nameRequired"))
.max(64, t("wizard.step1.errors.nameLength"))
.refine((value) => !/^\d+$/.test(value), {
message: t("wizard.step1.errors.nameOnlyNumbers"),
}),
modelType: z.enum(["state", "object"]),
objectLabel: z.string().optional(),
objectType: z.enum(["sub_label", "attribute"]).optional(),
classes: z
.array(z.string())
.min(1, t("wizard.step1.errors.classRequired"))
.refine(
(classes) => {
const nonEmpty = classes.filter((c) => c.trim().length > 0);
return nonEmpty.length >= 1;
},
{ message: t("wizard.step1.errors.classRequired") },
)
.refine(
(classes) => {
const nonEmpty = classes.filter((c) => c.trim().length > 0);
const unique = new Set(nonEmpty.map((c) => c.toLowerCase()));
return unique.size === nonEmpty.length;
},
{ message: t("wizard.step1.errors.classesUnique") },
),
})
.refine(
(data) => {
// State models require at least 2 classes
if (data.modelType === "state") {
const nonEmpty = data.classes.filter((c) => c.trim().length > 0);
return nonEmpty.length >= 2;
}
return true;
},
{
message: t("wizard.step1.errors.stateRequiresTwoClasses"),
path: ["classes"],
},
)
.refine(
(data) => {
if (data.modelType === "object") {
return data.objectLabel !== undefined && data.objectLabel !== "";
}
return true;
},
{
message: t("wizard.step1.errors.objectLabelRequired"),
path: ["objectLabel"],
},
)
.refine(
(data) => {
if (data.modelType === "object") {
return data.objectType !== undefined;
}
return true;
},
{
message: t("wizard.step1.errors.objectTypeRequired"),
path: ["objectType"],
},
);
const form = useForm<z.infer<typeof step1FormData>>({
resolver: zodResolver(step1FormData),
defaultValues: {
modelName: initialData?.modelName || "",
modelType: initialData?.modelType || defaultModelType || "state",
objectLabel: initialData?.objectLabel,
objectType: initialData?.objectType || "sub_label",
classes: initialData?.classes?.length ? initialData.classes : [""],
},
mode: "onChange",
});
const watchedClasses = form.watch("classes");
const watchedModelType = form.watch("modelType");
const watchedObjectType = form.watch("objectType");
const handleAddClass = () => {
const currentClasses = form.getValues("classes");
form.setValue("classes", [...currentClasses, ""], { shouldValidate: true });
};
const handleRemoveClass = (index: number) => {
const currentClasses = form.getValues("classes");
const newClasses = currentClasses.filter((_, i) => i !== index);
// Ensure at least one field remains (even if empty)
if (newClasses.length === 0) {
form.setValue("classes", [""], { shouldValidate: true });
} else {
form.setValue("classes", newClasses, { shouldValidate: true });
}
};
const onSubmit = (data: z.infer<typeof step1FormData>) => {
// Filter out empty classes
const filteredClasses = data.classes.filter((c) => c.trim().length > 0);
onNext({
...data,
classes: filteredClasses,
});
};
return (
<div className="space-y-6">
<Form {...form}>
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
<FormField
control={form.control}
name="modelName"
render={({ field }) => (
<FormItem>
<FormLabel className="text-primary-variant">
{t("wizard.step1.name")}
</FormLabel>
<FormControl>
<Input
className="h-8"
placeholder={t("wizard.step1.namePlaceholder")}
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="modelType"
render={({ field }) => (
<FormItem>
<FormLabel className="text-primary-variant">
{t("wizard.step1.type")}
</FormLabel>
<FormControl>
<RadioGroup
onValueChange={field.onChange}
defaultValue={field.value}
className="flex flex-col gap-4 pt-2"
>
<div className="flex items-center gap-2">
<RadioGroupItem
className={
watchedModelType === "state"
? "bg-selected from-selected/50 to-selected/90 text-selected"
: "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
}
id="state"
value="state"
/>
<Label className="cursor-pointer" htmlFor="state">
{t("wizard.step1.typeState")}
</Label>
</div>
<div className="flex items-center gap-2">
<RadioGroupItem
className={
watchedModelType === "object"
? "bg-selected from-selected/50 to-selected/90 text-selected"
: "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
}
id="object"
value="object"
/>
<Label className="cursor-pointer" htmlFor="object">
{t("wizard.step1.typeObject")}
</Label>
</div>
</RadioGroup>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
{watchedModelType === "object" && (
<>
<FormField
control={form.control}
name="objectLabel"
render={({ field }) => (
<FormItem>
<FormLabel className="text-primary-variant">
{t("wizard.step1.objectLabel")}
</FormLabel>
<Select
onValueChange={field.onChange}
defaultValue={field.value}
>
<FormControl>
<SelectTrigger className="h-8">
<SelectValue
placeholder={t(
"wizard.step1.objectLabelPlaceholder",
)}
/>
</SelectTrigger>
</FormControl>
<SelectContent>
{objectLabels.map((label) => (
<SelectItem
key={label}
value={label}
className="cursor-pointer hover:bg-secondary-highlight"
>
{getTranslatedLabel(label)}
</SelectItem>
))}
</SelectContent>
</Select>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="objectType"
render={({ field }) => (
<FormItem>
<div className="flex items-center gap-1">
<FormLabel className="text-primary-variant">
{t("wizard.step1.classificationType")}
</FormLabel>
<Popover>
<PopoverTrigger asChild>
<Button
variant="ghost"
size="sm"
className="h-4 w-4 p-0"
>
<LuInfo className="size-3" />
</Button>
</PopoverTrigger>
<PopoverContent className="pointer-events-auto w-80 text-xs">
<div className="flex flex-col gap-2">
<div className="text-sm">
{t("wizard.step1.classificationTypeDesc")}
</div>
<div className="mt-3 flex items-center text-primary">
<a
href={getLocaleDocUrl(
"configuration/custom_classification/object_classification#classification-type",
)}
target="_blank"
rel="noopener noreferrer"
className="inline cursor-pointer"
>
{t("readTheDocumentation", { ns: "common" })}
<LuExternalLink className="ml-2 inline-flex size-3" />
</a>
</div>
</div>
</PopoverContent>
</Popover>
</div>
<FormControl>
<RadioGroup
onValueChange={field.onChange}
defaultValue={field.value}
className="flex flex-col gap-4 pt-2"
>
<div className="flex items-center gap-2">
<RadioGroupItem
className={
watchedObjectType === "sub_label"
? "bg-selected from-selected/50 to-selected/90 text-selected"
: "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
}
id="sub_label"
value="sub_label"
/>
<Label className="cursor-pointer" htmlFor="sub_label">
{t("wizard.step1.classificationSubLabel")}
</Label>
</div>
<div className="flex items-center gap-2">
<RadioGroupItem
className={
watchedObjectType === "attribute"
? "bg-selected from-selected/50 to-selected/90 text-selected"
: "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
}
id="attribute"
value="attribute"
/>
<Label className="cursor-pointer" htmlFor="attribute">
{t("wizard.step1.classificationAttribute")}
</Label>
</div>
</RadioGroup>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
</>
)}
<div className="space-y-2">
<div className="flex items-center justify-between">
<div className="flex items-center gap-1">
<FormLabel className="text-primary-variant">
{t("wizard.step1.classes")}
</FormLabel>
<Popover>
<PopoverTrigger asChild>
<Button variant="ghost" size="sm" className="h-4 w-4 p-0">
<LuInfo className="size-3" />
</Button>
</PopoverTrigger>
<PopoverContent className="pointer-events-auto w-80 text-xs">
<div className="flex flex-col gap-2">
<div className="text-sm">
{watchedModelType === "state"
? t("wizard.step1.classesStateDesc")
: t("wizard.step1.classesObjectDesc")}
</div>
<div className="mt-3 flex items-center text-primary">
<a
href={getLocaleDocUrl(
watchedModelType === "state"
? "configuration/custom_classification/state_classification"
: "configuration/custom_classification/object_classification",
)}
target="_blank"
rel="noopener noreferrer"
className="inline cursor-pointer"
>
{t("readTheDocumentation", { ns: "common" })}
<LuExternalLink className="ml-2 inline-flex size-3" />
</a>
</div>
</div>
</PopoverContent>
</Popover>
</div>
<Button
type="button"
variant="secondary"
className="size-6 rounded-md bg-secondary-foreground p-1 text-background"
onClick={handleAddClass}
>
<LuPlus />
</Button>
</div>
<div className="space-y-2">
{watchedClasses.map((_, index) => (
<FormField
key={index}
control={form.control}
name={`classes.${index}`}
render={({ field }) => (
<FormItem>
<FormControl>
<div className="flex items-center gap-2">
<Input
className="h-8"
placeholder={t("wizard.step1.classPlaceholder")}
{...field}
/>
{watchedClasses.length > 1 && (
<Button
type="button"
variant="ghost"
size="sm"
className="h-8 w-8 p-0"
onClick={() => handleRemoveClass(index)}
>
<LuX className="size-4" />
</Button>
)}
</div>
</FormControl>
</FormItem>
)}
/>
))}
</div>
{form.formState.errors.classes && (
<p className="text-sm font-medium text-destructive">
{form.formState.errors.classes.message}
</p>
)}
</div>
</form>
</Form>
<div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
<Button type="button" onClick={onCancel} className="sm:flex-1">
{t("button.cancel", { ns: "common" })}
</Button>
<Button
type="button"
onClick={form.handleSubmit(onSubmit)}
variant="select"
className="flex items-center justify-center gap-2 sm:flex-1"
disabled={!form.formState.isValid}
>
{t("button.continue", { ns: "common" })}
</Button>
</div>
</div>
);
}
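
The submit button above is gated on `form.formState.isValid`, and the classes list surfaces `form.formState.errors.classes`, but the schema wiring lives earlier in the file and is not shown in this excerpt. A minimal sketch of a schema that would produce this behaviour, assuming zod with react-hook-form's zodResolver (the names and constraints below are illustrative, not the committed schema):

import { z } from "zod";

// Illustrative only: mirrors the fields referenced by Step 1 of the wizard.
const step1SchemaSketch = z.object({
  modelName: z.string().min(1),
  modelType: z.enum(["state", "object"]),
  objectLabel: z.string().optional(),
  objectType: z.enum(["sub_label", "attribute"]).optional(),
  // At least one non-empty class name is required before continuing.
  classes: z.array(z.string().min(1)).min(1),
});

type Step1FormSketch = z.infer<typeof step1SchemaSketch>;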

View File

@ -0,0 +1,479 @@
import { Button } from "@/components/ui/button";
import { useTranslation } from "react-i18next";
import { useState, useMemo, useRef, useCallback, useEffect } from "react";
import useSWR from "swr";
import { FrigateConfig } from "@/types/frigateConfig";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/ui/popover";
import { LuX, LuPlus } from "react-icons/lu";
import { Stage, Layer, Rect, Transformer } from "react-konva";
import Konva from "konva";
import { useResizeObserver } from "@/hooks/resize-observer";
import { useApiHost } from "@/api";
import { resolveCameraName } from "@/hooks/use-camera-friendly-name";
import Heading from "@/components/ui/heading";
import { isMobile } from "react-device-detect";
import { cn } from "@/lib/utils";
export type CameraAreaConfig = {
camera: string;
crop: [number, number, number, number];
};
export type Step2FormData = {
cameraAreas: CameraAreaConfig[];
};
type Step2StateAreaProps = {
initialData?: Partial<Step2FormData>;
onNext: (data: Step2FormData) => void;
onBack: () => void;
};
export default function Step2StateArea({
initialData,
onNext,
onBack,
}: Step2StateAreaProps) {
const { t } = useTranslation(["views/classificationModel"]);
const { data: config } = useSWR<FrigateConfig>("config");
const apiHost = useApiHost();
const [cameraAreas, setCameraAreas] = useState<CameraAreaConfig[]>(
initialData?.cameraAreas || [],
);
const [selectedCameraIndex, setSelectedCameraIndex] = useState<number>(0);
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
const [imageLoaded, setImageLoaded] = useState(false);
const containerRef = useRef<HTMLDivElement>(null);
const imageRef = useRef<HTMLImageElement>(null);
const stageRef = useRef<Konva.Stage>(null);
const rectRef = useRef<Konva.Rect>(null);
const transformerRef = useRef<Konva.Transformer>(null);
const [{ width: containerWidth }] = useResizeObserver(containerRef);
const availableCameras = useMemo(() => {
if (!config) return [];
const selectedCameraNames = cameraAreas.map((ca) => ca.camera);
return Object.entries(config.cameras)
.sort()
.filter(
([name, cam]) =>
cam.enabled &&
cam.enabled_in_config &&
!selectedCameraNames.includes(name),
)
.map(([name]) => ({
name,
displayName: resolveCameraName(config, name),
}));
}, [config, cameraAreas]);
const selectedCamera = useMemo(() => {
if (cameraAreas.length === 0) return null;
return cameraAreas[selectedCameraIndex];
}, [cameraAreas, selectedCameraIndex]);
const selectedCameraConfig = useMemo(() => {
if (!config || !selectedCamera) return null;
return config.cameras[selectedCamera.camera];
}, [config, selectedCamera]);
const imageSize = useMemo(() => {
if (!containerWidth || !selectedCameraConfig) {
return { width: 0, height: 0 };
}
const containerAspectRatio = 16 / 9;
const containerHeight = containerWidth / containerAspectRatio;
const cameraAspectRatio =
selectedCameraConfig.detect.width / selectedCameraConfig.detect.height;
// Fit camera within 16:9 container
let imageWidth, imageHeight;
if (cameraAspectRatio > containerAspectRatio) {
imageWidth = containerWidth;
imageHeight = imageWidth / cameraAspectRatio;
} else {
imageHeight = containerHeight;
imageWidth = imageHeight * cameraAspectRatio;
}
return { width: imageWidth, height: imageHeight };
}, [containerWidth, selectedCameraConfig]);
const handleAddCamera = useCallback(
(cameraName: string) => {
// Calculate a square crop in pixel space
const camera = config?.cameras[cameraName];
if (!camera) return;
const cameraAspect = camera.detect.width / camera.detect.height;
const cropSize = 0.3;
let x1, y1, x2, y2;
if (cameraAspect >= 1) {
const pixelSize = cropSize * camera.detect.height;
const normalizedWidth = pixelSize / camera.detect.width;
x1 = (1 - normalizedWidth) / 2;
y1 = (1 - cropSize) / 2;
x2 = x1 + normalizedWidth;
y2 = y1 + cropSize;
} else {
const pixelSize = cropSize * camera.detect.width;
const normalizedHeight = pixelSize / camera.detect.height;
x1 = (1 - cropSize) / 2;
y1 = (1 - normalizedHeight) / 2;
x2 = x1 + cropSize;
y2 = y1 + normalizedHeight;
}
const newArea: CameraAreaConfig = {
camera: cameraName,
crop: [x1, y1, x2, y2],
};
setCameraAreas([...cameraAreas, newArea]);
setSelectedCameraIndex(cameraAreas.length);
setIsPopoverOpen(false);
},
[cameraAreas, config],
);
const handleRemoveCamera = useCallback(
(index: number) => {
const newAreas = cameraAreas.filter((_, i) => i !== index);
setCameraAreas(newAreas);
if (selectedCameraIndex >= newAreas.length) {
setSelectedCameraIndex(Math.max(0, newAreas.length - 1));
}
},
[cameraAreas, selectedCameraIndex],
);
const handleCropChange = useCallback(
(crop: [number, number, number, number]) => {
const newAreas = [...cameraAreas];
newAreas[selectedCameraIndex] = {
...newAreas[selectedCameraIndex],
crop,
};
setCameraAreas(newAreas);
},
[cameraAreas, selectedCameraIndex],
);
useEffect(() => {
setImageLoaded(false);
}, [selectedCamera]);
useEffect(() => {
const rect = rectRef.current;
const transformer = transformerRef.current;
if (
rect &&
transformer &&
selectedCamera &&
imageSize.width > 0 &&
imageLoaded
) {
rect.scaleX(1);
rect.scaleY(1);
transformer.nodes([rect]);
transformer.getLayer()?.batchDraw();
}
}, [selectedCamera, imageSize, imageLoaded]);
const handleRectChange = useCallback(() => {
const rect = rectRef.current;
if (rect && imageSize.width > 0) {
const actualWidth = rect.width() * rect.scaleX();
const actualHeight = rect.height() * rect.scaleY();
// Average dimensions to maintain perfect square
const size = (actualWidth + actualHeight) / 2;
rect.width(size);
rect.height(size);
rect.scaleX(1);
rect.scaleY(1);
const x1 = rect.x() / imageSize.width;
const y1 = rect.y() / imageSize.height;
const x2 = (rect.x() + size) / imageSize.width;
const y2 = (rect.y() + size) / imageSize.height;
handleCropChange([x1, y1, x2, y2]);
}
}, [imageSize, handleCropChange]);
const handleContinue = useCallback(() => {
onNext({ cameraAreas });
}, [cameraAreas, onNext]);
const canContinue = cameraAreas.length > 0;
return (
<div className="flex flex-col gap-4">
<div
className={cn(
"flex gap-4 overflow-hidden",
isMobile ? "flex-col" : "flex-row",
)}
>
<div
className={cn(
"flex flex-shrink-0 flex-col gap-2 overflow-y-auto rounded-lg bg-secondary p-4",
isMobile ? "w-full" : "w-64",
)}
>
<div className="flex items-center justify-between">
<h3 className="text-sm font-medium">{t("wizard.step2.cameras")}</h3>
{availableCameras.length > 0 ? (
<Popover
open={isPopoverOpen}
onOpenChange={setIsPopoverOpen}
modal={true}
>
<PopoverTrigger asChild>
<Button
type="button"
variant="secondary"
className="size-6 rounded-md bg-secondary-foreground p-1 text-background"
aria-label="Add camera"
>
<LuPlus />
</Button>
</PopoverTrigger>
<PopoverContent
className="scrollbar-container w-64 border bg-background p-3 shadow-lg"
align="start"
sideOffset={5}
onOpenAutoFocus={(e) => e.preventDefault()}
>
<div className="flex flex-col gap-2">
<Heading as="h4" className="text-sm text-primary-variant">
{t("wizard.step2.selectCamera")}
</Heading>
<div className="scrollbar-container flex max-h-[30vh] flex-col gap-1 overflow-y-auto">
{availableCameras.map((cam) => (
<Button
key={cam.name}
type="button"
variant="ghost"
size="sm"
className="h-auto justify-start p-2 capitalize text-primary"
onClick={() => {
handleAddCamera(cam.name);
}}
>
{cam.displayName}
</Button>
))}
</div>
</div>
</PopoverContent>
</Popover>
) : (
<Button
variant="secondary"
className="size-6 cursor-not-allowed rounded-md bg-muted p-1 text-muted-foreground"
disabled
>
<LuPlus />
</Button>
)}
</div>
<div className="flex flex-col gap-1">
{cameraAreas.map((area, index) => {
const isSelected = index === selectedCameraIndex;
const displayName = resolveCameraName(config, area.camera);
return (
<div
key={area.camera}
className={`flex items-center justify-between rounded-md p-2 ${
isSelected
? "bg-selected/20 ring-1 ring-selected"
: "hover:bg-secondary/50"
} cursor-pointer`}
onClick={() => setSelectedCameraIndex(index)}
>
<span className="text-sm capitalize">{displayName}</span>
<Button
type="button"
variant="ghost"
size="sm"
className="h-6 w-6 p-0"
onClick={(e) => {
e.stopPropagation();
handleRemoveCamera(index);
}}
>
<LuX className="size-4" />
</Button>
</div>
);
})}
</div>
{cameraAreas.length === 0 && (
<div className="flex flex-1 items-center justify-center text-center text-sm text-muted-foreground">
{t("wizard.step2.noCameras")}
</div>
)}
</div>
<div className="flex flex-1 items-center justify-center overflow-hidden rounded-lg p-4">
<div
ref={containerRef}
className="flex items-center justify-center"
style={{
width: "100%",
aspectRatio: "16 / 9",
maxHeight: "100%",
}}
>
{selectedCamera && selectedCameraConfig && imageSize.width > 0 ? (
<div
style={{
width: imageSize.width,
height: imageSize.height,
position: "relative",
}}
>
<img
ref={imageRef}
src={`${apiHost}api/${selectedCamera.camera}/latest.jpg?h=500`}
alt={resolveCameraName(config, selectedCamera.camera)}
className="h-full w-full object-contain"
onLoad={() => setImageLoaded(true)}
/>
<Stage
ref={stageRef}
width={imageSize.width}
height={imageSize.height}
className="absolute inset-0"
>
<Layer>
<Rect
ref={rectRef}
x={selectedCamera.crop[0] * imageSize.width}
y={selectedCamera.crop[1] * imageSize.height}
width={
(selectedCamera.crop[2] - selectedCamera.crop[0]) *
imageSize.width
}
height={
(selectedCamera.crop[3] - selectedCamera.crop[1]) *
imageSize.height
}
stroke="#3b82f6"
strokeWidth={2}
fill="rgba(59, 130, 246, 0.1)"
draggable
dragBoundFunc={(pos) => {
const rect = rectRef.current;
if (!rect) return pos;
const size = rect.width();
const x = Math.max(
0,
Math.min(pos.x, imageSize.width - size),
);
const y = Math.max(
0,
Math.min(pos.y, imageSize.height - size),
);
return { x, y };
}}
onDragEnd={handleRectChange}
onTransformEnd={handleRectChange}
/>
<Transformer
ref={transformerRef}
rotateEnabled={false}
enabledAnchors={[
"top-left",
"top-right",
"bottom-left",
"bottom-right",
]}
boundBoxFunc={(_oldBox, newBox) => {
const minSize = 50;
const maxSize = Math.min(
imageSize.width,
imageSize.height,
);
// Clamp dimensions to stage bounds first
const clampedWidth = Math.max(
minSize,
Math.min(newBox.width, maxSize),
);
const clampedHeight = Math.max(
minSize,
Math.min(newBox.height, maxSize),
);
// Enforce square using average
const size = (clampedWidth + clampedHeight) / 2;
// Clamp position to keep square within bounds
const x = Math.max(
0,
Math.min(newBox.x, imageSize.width - size),
);
const y = Math.max(
0,
Math.min(newBox.y, imageSize.height - size),
);
return {
...newBox,
x,
y,
width: size,
height: size,
};
}}
/>
</Layer>
</Stage>
</div>
) : (
<div className="flex items-center justify-center text-muted-foreground">
{t("wizard.step2.selectCameraPrompt")}
</div>
)}
</div>
</div>
</div>
<div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
<Button type="button" onClick={onBack} className="sm:flex-1">
{t("button.back", { ns: "common" })}
</Button>
<Button
type="button"
onClick={handleContinue}
variant="select"
className="flex items-center justify-center gap-2 sm:flex-1"
disabled={!canContinue}
>
{t("button.continue", { ns: "common" })}
</Button>
</div>
</div>
);
}
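
For reference, `handleAddCamera` above seeds each camera with a centered square crop in normalized coordinates, sized to 30% of the shorter frame dimension. A standalone sketch of that math, with a 1920×1080 detect resolution assumed purely for the worked numbers:

// Standalone sketch of the default-crop math used in handleAddCamera above.
function defaultSquareCrop(
  width: number,
  height: number,
  cropSize = 0.3,
): [number, number, number, number] {
  const aspect = width / height;
  if (aspect >= 1) {
    const pixelSize = cropSize * height; // square side in pixels
    const normalizedWidth = pixelSize / width;
    const x1 = (1 - normalizedWidth) / 2;
    const y1 = (1 - cropSize) / 2;
    return [x1, y1, x1 + normalizedWidth, y1 + cropSize];
  }
  const pixelSize = cropSize * width;
  const normalizedHeight = pixelSize / height;
  const x1 = (1 - cropSize) / 2;
  const y1 = (1 - normalizedHeight) / 2;
  return [x1, y1, x1 + cropSize, y1 + normalizedHeight];
}

// For a 1920x1080 camera this yields roughly [0.4156, 0.35, 0.5844, 0.65],
// i.e. a centered 324x324 pixel square.
console.log(defaultSquareCrop(1920, 1080));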

View File

@ -0,0 +1,444 @@
import { Button } from "@/components/ui/button";
import { useTranslation } from "react-i18next";
import { useState, useEffect, useCallback, useMemo } from "react";
import ActivityIndicator from "@/components/indicators/activity-indicator";
import axios from "axios";
import { toast } from "sonner";
import { Step1FormData } from "./Step1NameAndDefine";
import { Step2FormData } from "./Step2StateArea";
import useSWR from "swr";
import { baseUrl } from "@/api/baseUrl";
import { isMobile } from "react-device-detect";
import { cn } from "@/lib/utils";
export type Step3FormData = {
examplesGenerated: boolean;
imageClassifications?: { [imageName: string]: string };
};
type Step3ChooseExamplesProps = {
step1Data: Step1FormData;
step2Data?: Step2FormData;
initialData?: Partial<Step3FormData>;
onClose: () => void;
onBack: () => void;
};
export default function Step3ChooseExamples({
step1Data,
step2Data,
initialData,
onClose,
onBack,
}: Step3ChooseExamplesProps) {
const { t } = useTranslation(["views/classificationModel"]);
const [isGenerating, setIsGenerating] = useState(false);
const [hasGenerated, setHasGenerated] = useState(
initialData?.examplesGenerated || false,
);
const [imageClassifications, setImageClassifications] = useState<{
[imageName: string]: string;
}>(initialData?.imageClassifications || {});
const [isTraining, setIsTraining] = useState(false);
const [isProcessing, setIsProcessing] = useState(false);
const [currentClassIndex, setCurrentClassIndex] = useState(0);
const [selectedImages, setSelectedImages] = useState<Set<string>>(new Set());
const { data: trainImages, mutate: refreshTrainImages } = useSWR<string[]>(
hasGenerated ? `classification/${step1Data.modelName}/train` : null,
);
const unknownImages = useMemo(() => {
if (!trainImages) return [];
return trainImages;
}, [trainImages]);
const toggleImageSelection = useCallback((imageName: string) => {
setSelectedImages((prev) => {
const newSet = new Set(prev);
if (newSet.has(imageName)) {
newSet.delete(imageName);
} else {
newSet.add(imageName);
}
return newSet;
});
}, []);
// Get all classes (excluding "none" - it will be auto-assigned)
const allClasses = useMemo(() => {
return [...step1Data.classes];
}, [step1Data.classes]);
const currentClass = allClasses[currentClassIndex];
const processClassificationsAndTrain = useCallback(
async (classifications: { [imageName: string]: string }) => {
// Step 1: Create config for the new model
const modelConfig: {
enabled: boolean;
name: string;
threshold: number;
state_config?: {
cameras: Record<string, { crop: number[] }>;
motion: boolean;
};
object_config?: { objects: string[]; classification_type: string };
} = {
enabled: true,
name: step1Data.modelName,
threshold: 0.8,
};
if (step1Data.modelType === "state") {
// State model config
const cameras: Record<string, { crop: number[] }> = {};
step2Data?.cameraAreas.forEach((area) => {
cameras[area.camera] = {
crop: area.crop,
};
});
modelConfig.state_config = {
cameras,
motion: true,
};
} else {
// Object model config
modelConfig.object_config = {
objects: step1Data.objectLabel ? [step1Data.objectLabel] : [],
classification_type: step1Data.objectType || "sub_label",
} as { objects: string[]; classification_type: string };
}
// Update config via config API
await axios.put("/config/set", {
requires_restart: 0,
update_topic: `config/classification/custom/${step1Data.modelName}`,
config_data: {
classification: {
custom: {
[step1Data.modelName]: modelConfig,
},
},
},
});
// Step 2: Classify each image by moving it to the correct category folder
const categorizePromises = Object.entries(classifications).map(
([imageName, className]) => {
if (!className) return Promise.resolve();
return axios.post(
`/classification/${step1Data.modelName}/dataset/categorize`,
{
training_file: imageName,
category: className === "none" ? "none" : className,
},
);
},
);
await Promise.all(categorizePromises);
// Step 3: Kick off training
await axios.post(`/classification/${step1Data.modelName}/train`);
toast.success(t("wizard.step3.trainingStarted"));
setIsTraining(true);
},
[step1Data, step2Data, t],
);
const handleContinueClassification = useCallback(async () => {
// Mark selected images with current class
const newClassifications = { ...imageClassifications };
selectedImages.forEach((imageName) => {
newClassifications[imageName] = currentClass;
});
// Check if we're on the last class to select
const isLastClass = currentClassIndex === allClasses.length - 1;
if (isLastClass) {
// Assign remaining unclassified images
unknownImages.slice(0, 24).forEach((imageName) => {
if (!newClassifications[imageName]) {
// For state models with 2 classes, assign to the last class
// For object models, assign to "none"
if (step1Data.modelType === "state" && allClasses.length === 2) {
newClassifications[imageName] = allClasses[allClasses.length - 1];
} else {
newClassifications[imageName] = "none";
}
}
});
// All done, trigger training immediately
setImageClassifications(newClassifications);
setIsProcessing(true);
try {
await processClassificationsAndTrain(newClassifications);
} catch (error) {
const axiosError = error as {
response?: { data?: { message?: string; detail?: string } };
message?: string;
};
const errorMessage =
axiosError.response?.data?.message ||
axiosError.response?.data?.detail ||
axiosError.message ||
"Failed to classify images";
toast.error(
t("wizard.step3.errors.classifyFailed", { error: errorMessage }),
);
setIsProcessing(false);
}
} else {
// Move to next class
setImageClassifications(newClassifications);
setCurrentClassIndex((prev) => prev + 1);
setSelectedImages(new Set());
}
}, [
selectedImages,
currentClass,
currentClassIndex,
allClasses,
imageClassifications,
unknownImages,
step1Data,
processClassificationsAndTrain,
t,
]);
const generateExamples = useCallback(async () => {
setIsGenerating(true);
try {
if (step1Data.modelType === "state") {
// For state models, use cameras and crop areas
if (!step2Data?.cameraAreas || step2Data.cameraAreas.length === 0) {
toast.error(t("wizard.step3.errors.noCameras"));
setIsGenerating(false);
return;
}
const cameras: { [key: string]: [number, number, number, number] } = {};
step2Data.cameraAreas.forEach((area) => {
cameras[area.camera] = area.crop;
});
await axios.post("/classification/generate_examples/state", {
model_name: step1Data.modelName,
cameras,
});
} else {
// For object models, use label
if (!step1Data.objectLabel) {
toast.error(t("wizard.step3.errors.noObjectLabel"));
setIsGenerating(false);
return;
}
// For now, use all enabled cameras
// TODO: In the future, we might want to let users select specific cameras
await axios.post("/classification/generate_examples/object", {
model_name: step1Data.modelName,
label: step1Data.objectLabel,
});
}
setHasGenerated(true);
toast.success(t("wizard.step3.generateSuccess"));
await refreshTrainImages();
} catch (error) {
const axiosError = error as {
response?: { data?: { message?: string; detail?: string } };
message?: string;
};
const errorMessage =
axiosError.response?.data?.message ||
axiosError.response?.data?.detail ||
axiosError.message ||
"Failed to generate examples";
toast.error(
t("wizard.step3.errors.generateFailed", { error: errorMessage }),
);
} finally {
setIsGenerating(false);
}
}, [step1Data, step2Data, t, refreshTrainImages]);
useEffect(() => {
if (!hasGenerated && !isGenerating) {
generateExamples();
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
const handleContinue = useCallback(async () => {
setIsProcessing(true);
try {
await processClassificationsAndTrain(imageClassifications);
} catch (error) {
const axiosError = error as {
response?: { data?: { message?: string; detail?: string } };
message?: string;
};
const errorMessage =
axiosError.response?.data?.message ||
axiosError.response?.data?.detail ||
axiosError.message ||
"Failed to classify images";
toast.error(
t("wizard.step3.errors.classifyFailed", { error: errorMessage }),
);
setIsProcessing(false);
}
}, [imageClassifications, processClassificationsAndTrain, t]);
const unclassifiedImages = useMemo(() => {
if (!unknownImages) return [];
const images = unknownImages.slice(0, 24);
// Only filter if we have any classifications
if (Object.keys(imageClassifications).length === 0) {
return images;
}
return images.filter((img) => !imageClassifications[img]);
}, [unknownImages, imageClassifications]);
const allImagesClassified = useMemo(() => {
return unclassifiedImages.length === 0;
}, [unclassifiedImages]);
return (
<div className="flex flex-col gap-6">
{isTraining ? (
<div className="flex flex-col items-center gap-6 py-12">
<ActivityIndicator className="size-12" />
<div className="text-center">
<h3 className="mb-2 text-lg font-medium">
{t("wizard.step3.training.title")}
</h3>
<p className="text-sm text-muted-foreground">
{t("wizard.step3.training.description")}
</p>
</div>
<Button onClick={onClose} variant="select" className="mt-4">
{t("button.close", { ns: "common" })}
</Button>
</div>
) : isGenerating ? (
<div className="flex h-[50vh] flex-col items-center justify-center gap-4">
<ActivityIndicator className="size-12" />
<div className="text-center">
<h3 className="mb-2 text-lg font-medium">
{t("wizard.step3.generating.title")}
</h3>
<p className="text-sm text-muted-foreground">
{t("wizard.step3.generating.description")}
</p>
</div>
</div>
) : hasGenerated ? (
<div className="flex flex-col gap-4">
{!allImagesClassified && (
<div className="text-center">
<h3 className="text-lg font-medium">
{t("wizard.step3.selectImagesPrompt", {
className: currentClass,
})}
</h3>
<p className="text-sm text-muted-foreground">
{t("wizard.step3.selectImagesDescription")}
</p>
</div>
)}
<div
className={cn(
"rounded-lg bg-secondary/30 p-4",
isMobile && "max-h-[60vh] overflow-y-auto",
)}
>
{!unknownImages || unknownImages.length === 0 ? (
<div className="flex h-[40vh] flex-col items-center justify-center gap-4">
<p className="text-muted-foreground">
{t("wizard.step3.noImages")}
</p>
<Button onClick={generateExamples} variant="select">
{t("wizard.step3.retryGenerate")}
</Button>
</div>
) : allImagesClassified && isProcessing ? (
<div className="flex h-[40vh] flex-col items-center justify-center gap-4">
<ActivityIndicator className="size-12" />
<p className="text-lg font-medium">
{t("wizard.step3.classifying")}
</p>
</div>
) : (
<div className="grid grid-cols-2 gap-4 sm:grid-cols-6">
{unclassifiedImages.map((imageName, index) => {
const isSelected = selectedImages.has(imageName);
return (
<div
key={imageName}
className={cn(
"aspect-square cursor-pointer overflow-hidden rounded-lg border-2 bg-background transition-all",
isSelected && "border-selected ring-2 ring-selected",
)}
onClick={() => toggleImageSelection(imageName)}
>
<img
src={`${baseUrl}clips/${step1Data.modelName}/train/${imageName}`}
alt={`Example ${index + 1}`}
className="h-full w-full object-cover"
/>
</div>
);
})}
</div>
)}
</div>
</div>
) : (
<div className="flex h-[50vh] flex-col items-center justify-center gap-4">
<p className="text-sm text-destructive">
{t("wizard.step3.errors.generationFailed")}
</p>
<Button onClick={generateExamples} variant="select">
{t("wizard.step3.retryGenerate")}
</Button>
</div>
)}
{!isTraining && (
<div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
<Button type="button" onClick={onBack} className="sm:flex-1">
{t("button.back", { ns: "common" })}
</Button>
<Button
type="button"
onClick={
allImagesClassified
? handleContinue
: handleContinueClassification
}
variant="select"
className="flex items-center justify-center gap-2 sm:flex-1"
disabled={!hasGenerated || isGenerating || isProcessing}
>
{isProcessing && <ActivityIndicator className="size-4" />}
{t("button.continue", { ns: "common" })}
</Button>
</div>
)}
</div>
);
}
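
`processClassificationsAndTrain` drives three API calls in sequence: a `PUT /config/set` that registers the model config without a restart, one `POST /classification/{model}/dataset/categorize` per labeled image, and a final `POST /classification/{model}/train`. A condensed sketch of the same sequence for a hypothetical `garage_door` state model (model name, camera, crop, and file name are invented; the endpoints and payload shapes mirror the component above, and the axios instance is assumed to carry the Frigate API base URL):

import axios from "axios";

async function createAndTrainStateModelSketch() {
  // 1. Write the model config (no restart required).
  await axios.put("/config/set", {
    requires_restart: 0,
    update_topic: "config/classification/custom/garage_door",
    config_data: {
      classification: {
        custom: {
          garage_door: {
            enabled: true,
            name: "garage_door",
            threshold: 0.8,
            state_config: {
              cameras: { garage: { crop: [0.42, 0.35, 0.58, 0.65] } },
              motion: true,
            },
          },
        },
      },
    },
  });

  // 2. Move each selected training image into its category folder.
  await axios.post("/classification/garage_door/dataset/categorize", {
    training_file: "example-0.webp",
    category: "open",
  });

  // 3. Kick off a training run.
  await axios.post("/classification/garage_door/train");
}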

View File

@ -10,11 +10,14 @@ import {
  CustomClassificationModelConfig,
  FrigateConfig,
} from "@/types/frigateConfig";
-import { useMemo, useState } from "react";
+import { useEffect, useMemo, useState } from "react";
import { isMobile } from "react-device-detect";
import { useTranslation } from "react-i18next";
import { FaFolderPlus } from "react-icons/fa";
+import { MdModelTraining } from "react-icons/md";
import useSWR from "swr";
+import Heading from "@/components/ui/heading";
+import { useOverlayState } from "@/hooks/use-overlay-state";

const allModelTypes = ["objects", "states"] as const;
type ModelType = (typeof allModelTypes)[number];
@ -26,11 +29,24 @@ export default function ModelSelectionView({
  onClick,
}: ModelSelectionViewProps) {
  const { t } = useTranslation(["views/classificationModel"]);
-  const [page, setPage] = useState<ModelType>("objects");
-  const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
-  const { data: config } = useSWR<FrigateConfig>("config", {
-    revalidateOnFocus: false,
-  });
+  const [page, setPage] = useOverlayState<ModelType>("objects", "objects");
+  const [pageToggle, setPageToggle] = useOptimisticState(
+    page || "objects",
+    setPage,
+    100,
+  );
+  const { data: config, mutate: refreshConfig } = useSWR<FrigateConfig>(
+    "config",
+    {
+      revalidateOnFocus: false,
+    },
+  );
+
+  // title
+  useEffect(() => {
+    document.title = t("documentTitle");
+  }, [t]);

  // data
@ -64,15 +80,15 @@ export default function ModelSelectionView({
    return <ActivityIndicator />;
  }

-  if (classificationConfigs.length == 0) {
-    return <div>You need to setup a custom model configuration.</div>;
-  }
  return (
    <div className="flex size-full flex-col p-2">
      <ClassificationModelWizardDialog
        open={newModel}
-        onClose={() => setNewModel(false)}
+        defaultModelType={pageToggle === "objects" ? "object" : "state"}
+        onClose={() => {
+          setNewModel(false);
+          refreshConfig();
+        }}
      />

      <div className="flex h-12 w-full items-center justify-between">
@ -84,7 +100,6 @@ export default function ModelSelectionView({
            value={pageToggle}
            onValueChange={(value: ModelType) => {
              if (value) {
-                // Restrict viewer navigation
                setPageToggle(value);
              }
            }}
@ -117,13 +132,46 @@ export default function ModelSelectionView({
        </div>
      </div>

      <div className="flex size-full gap-2 p-2">
-        {selectedClassificationConfigs.map((config) => (
-          <ModelCard
-            key={config.name}
-            config={config}
-            onClick={() => onClick(config)}
-          />
-        ))}
+        {selectedClassificationConfigs.length === 0 ? (
+          <NoModelsView
+            onCreateModel={() => setNewModel(true)}
+            modelType={pageToggle}
+          />
+        ) : (
+          selectedClassificationConfigs.map((config) => (
+            <ModelCard
+              key={config.name}
+              config={config}
+              onClick={() => onClick(config)}
+            />
+          ))
+        )}
      </div>
    </div>
  );
}

+function NoModelsView({
+  onCreateModel,
+  modelType,
+}: {
+  onCreateModel: () => void;
+  modelType: ModelType;
+}) {
+  const { t } = useTranslation(["views/classificationModel"]);
+  const typeKey = modelType === "objects" ? "object" : "state";
+
+  return (
+    <div className="flex size-full items-center justify-center">
+      <div className="flex flex-col items-center gap-2">
+        <MdModelTraining className="size-8" />
+        <Heading as="h4">{t(`noModels.${typeKey}.title`)}</Heading>
+        <div className="mb-3 text-center text-secondary-foreground">
+          {t(`noModels.${typeKey}.description`)}
+        </div>
+        <Button size="sm" variant="select" onClick={onCreateModel}>
+          {t(`noModels.${typeKey}.buttonText`)}
+        </Button>
+      </div>
+    </div>
+  );
@ -139,13 +187,17 @@ function ModelCard({ config, onClick }: ModelCardProps) {
  }>(`classification/${config.name}/dataset`, { revalidateOnFocus: false });

  const coverImage = useMemo(() => {
-    if (!dataset?.length) {
+    if (!dataset) {
      return undefined;
    }

    const keys = Object.keys(dataset).filter((key) => key != "none");
    const selectedKey = keys[0];

+    if (!dataset[selectedKey]) {
+      return undefined;
+    }
+
    return {
      name: selectedKey,
      img: dataset[selectedKey][0],
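
The guard change above matters because the dataset endpoint returns a record keyed by category name rather than an array, so the removed `!dataset?.length` check was effectively always true and the cover image never resolved. A small sketch with an invented response illustrating the difference:

// Hypothetical dataset response shape (category name -> image files):
const dataset: Record<string, string[]> = { open: ["a.webp"], none: [] };
// `dataset.length` is undefined on a record, so `!dataset?.length` always
// bailed out. The new code only bails while the fetch is pending, then
// guards the selected category before indexing into it.
const keys = Object.keys(dataset).filter((key) => key !== "none");
const selectedKey = keys[0];
const cover =
  selectedKey && dataset[selectedKey] ? dataset[selectedKey][0] : undefined;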

View File

@ -642,6 +642,7 @@ function DatasetGrid({
                  filepath: `clips/${modelName}/dataset/${categoryName}/${image}`,
                  name: "",
                }}
+                showArea={false}
                selected={selectedImages.includes(image)}
                i18nLibrary="views/classificationModel"
                onClick={(data, _) => onClickImages([data.filename], true)}