diff --git a/.cursor/rules/frontend-always-use-translation-files.mdc b/.cursor/rules/frontend-always-use-translation-files.mdc new file mode 100644 index 000000000..35034069b --- /dev/null +++ b/.cursor/rules/frontend-always-use-translation-files.mdc @@ -0,0 +1,6 @@ +--- +globs: ["**/*.ts", "**/*.tsx"] +alwaysApply: false +--- + +Never write strings in the frontend directly, always write to and reference the relevant translations file. \ No newline at end of file diff --git a/docs/docs/configuration/custom_classification/object_classification.md b/docs/docs/configuration/custom_classification/object_classification.md index 9465716b7..983fce852 100644 --- a/docs/docs/configuration/custom_classification/object_classification.md +++ b/docs/docs/configuration/custom_classification/object_classification.md @@ -12,7 +12,18 @@ Object classification models are lightweight and run very fast on CPU. Inference Training the model does briefly use a high amount of system resources for about 1–3 minutes per training run. On lower-power devices, training may take longer. When running the `-tensorrt` image, Nvidia GPUs will automatically be used to accelerate training. -### Sub label vs Attribute +## Classes + +Classes are the categories your model will learn to distinguish between. Each class represents a distinct visual category that the model will predict. + +For object classification: + +- Define classes that represent different types or attributes of the detected object +- Examples: For `person` objects, classes might be `delivery_person`, `resident`, `stranger` +- Include a `none` class for objects that don't fit any specific category +- Keep classes visually distinct to improve accuracy + +### Classification Type - **Sub label**: diff --git a/docs/docs/configuration/custom_classification/state_classification.md b/docs/docs/configuration/custom_classification/state_classification.md index afc79eff8..c22661f26 100644 --- a/docs/docs/configuration/custom_classification/state_classification.md +++ b/docs/docs/configuration/custom_classification/state_classification.md @@ -12,6 +12,17 @@ State classification models are lightweight and run very fast on CPU. Inference Training the model does briefly use a high amount of system resources for about 1–3 minutes per training run. On lower-power devices, training may take longer. When running the `-tensorrt` image, Nvidia GPUs will automatically be used to accelerate training. +## Classes + +Classes are the different states an area on your camera can be in. Each class represents a distinct visual state that the model will learn to recognize. + +For state classification: + +- Define classes that represent mutually exclusive states +- Examples: `open` and `closed` for a garage door, `on` and `off` for lights +- Use at least 2 classes (typically binary states work best) +- Keep class names clear and descriptive + ## Example use cases - **Door state**: Detect if a garage or front door is open vs closed. 
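For reference, a minimal sketch of how such a custom classification model might be configured. This is hedged: the keys mirror the `modelConfig` payload the wizard submits to `/config/set` later in this diff (`threshold`, `state_config.cameras.<camera>.crop`, `object_config.objects`, `classification_type`); the model names and crop values are hypothetical. The classes themselves are not declared in config; in this PR they come from the dataset folders populated through the categorize endpoint.

```yaml
classification:
  custom:
    # State model: classifies a fixed, normalized crop of one or more cameras
    garage_door:
      enabled: true
      threshold: 0.8
      state_config:
        motion: true
        cameras:
          garage_cam:
            crop: [0.55, 0.20, 0.85, 0.50] # [x1, y1, x2, y2], values 0-1
    # Object model: adds a classification to detected objects of a given label
    package_carrier:
      enabled: true
      threshold: 0.8
      object_config:
        objects: [person]
        classification_type: sub_label # or "attribute"
```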
diff --git a/frigate/api/app.py b/frigate/api/app.py index f84190407..5d09ecf00 100644 --- a/frigate/api/app.py +++ b/frigate/api/app.py @@ -387,20 +387,28 @@ def config_set(request: Request, body: AppConfigSetBody): old_config: FrigateConfig = request.app.frigate_config request.app.frigate_config = config - if body.update_topic and body.update_topic.startswith("config/cameras/"): - _, _, camera, field = body.update_topic.split("/") + if body.update_topic: + if body.update_topic.startswith("config/cameras/"): + _, _, camera, field = body.update_topic.split("/") - if field == "add": - settings = config.cameras[camera] - elif field == "remove": - settings = old_config.cameras[camera] + if field == "add": + settings = config.cameras[camera] + elif field == "remove": + settings = old_config.cameras[camera] + else: + settings = config.get_nested_object(body.update_topic) + + request.app.config_publisher.publish_update( + CameraConfigUpdateTopic(CameraConfigUpdateEnum[field], camera), + settings, + ) else: + # Handle nested config updates (e.g., config/classification/custom/{name}) settings = config.get_nested_object(body.update_topic) - - request.app.config_publisher.publish_update( - CameraConfigUpdateTopic(CameraConfigUpdateEnum[field], camera), - settings, - ) + if settings: + request.app.config_publisher.publisher.publish( + body.update_topic, settings + ) return JSONResponse( content=( diff --git a/frigate/api/classification.py b/frigate/api/classification.py index 623ceba32..e9052097a 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -3,7 +3,9 @@ import datetime import logging import os +import random import shutil +import string from typing import Any import cv2 @@ -17,6 +19,8 @@ from frigate.api.auth import require_role from frigate.api.defs.request.classification_body import ( AudioTranscriptionBody, DeleteFaceImagesBody, + GenerateObjectExamplesBody, + GenerateStateExamplesBody, RenameFaceBody, ) from frigate.api.defs.response.classification_response import ( @@ -30,6 +34,10 @@ from frigate.config.camera import DetectConfig from frigate.const import CLIPS_DIR, FACE_DIR from frigate.embeddings import EmbeddingsContext from frigate.models import Event +from frigate.util.classification import ( + collect_object_classification_examples, + collect_state_classification_examples, +) from frigate.util.path import get_event_snapshot logger = logging.getLogger(__name__) @@ -159,8 +167,7 @@ def train_face(request: Request, name: str, body: dict = None): new_name = f"{sanitized_name}-{datetime.datetime.now().timestamp()}.webp" new_file_folder = os.path.join(FACE_DIR, f"{sanitized_name}") - if not os.path.exists(new_file_folder): - os.mkdir(new_file_folder) + os.makedirs(new_file_folder, exist_ok=True) if training_file_name: shutil.move(training_file, os.path.join(new_file_folder, new_name)) @@ -701,13 +708,14 @@ def categorize_classification_image(request: Request, name: str, body: dict = No status_code=404, ) - new_name = f"{category}-{datetime.datetime.now().timestamp()}.png" + random_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6)) + timestamp = datetime.datetime.now().timestamp() + new_name = f"{category}-{timestamp}-{random_id}.png" new_file_folder = os.path.join( CLIPS_DIR, sanitize_filename(name), "dataset", category ) - if not os.path.exists(new_file_folder): - os.mkdir(new_file_folder) + os.makedirs(new_file_folder, exist_ok=True) # use opencv because webp images can not be used to train img = cv2.imread(training_file) @@ -756,3 
+764,43 @@ def delete_classification_train_images(request: Request, name: str, body: dict = content=({"success": True, "message": "Successfully deleted faces."}), status_code=200, ) + + +@router.post( + "/classification/generate_examples/state", + response_model=GenericResponse, + dependencies=[Depends(require_role(["admin"]))], + summary="Generate state classification examples", +) +async def generate_state_examples(request: Request, body: GenerateStateExamplesBody): + """Generate examples for state classification.""" + model_name = sanitize_filename(body.model_name) + cameras_normalized = { + camera_name: tuple(crop) + for camera_name, crop in body.cameras.items() + if camera_name in request.app.frigate_config.cameras + } + + collect_state_classification_examples(model_name, cameras_normalized) + + return JSONResponse( + content={"success": True, "message": "Example generation completed"}, + status_code=200, + ) + + +@router.post( + "/classification/generate_examples/object", + response_model=GenericResponse, + dependencies=[Depends(require_role(["admin"]))], + summary="Generate object classification examples", +) +async def generate_object_examples(request: Request, body: GenerateObjectExamplesBody): + """Generate examples for object classification.""" + model_name = sanitize_filename(body.model_name) + collect_object_classification_examples(model_name, body.label) + + return JSONResponse( + content={"success": True, "message": "Example generation completed"}, + status_code=200, + ) diff --git a/frigate/api/defs/request/classification_body.py b/frigate/api/defs/request/classification_body.py index dabff0912..fb6a7dd0f 100644 --- a/frigate/api/defs/request/classification_body.py +++ b/frigate/api/defs/request/classification_body.py @@ -1,17 +1,31 @@ -from typing import List +from typing import Dict, List, Tuple from pydantic import BaseModel, Field class RenameFaceBody(BaseModel): - new_name: str + new_name: str = Field(description="New name for the face") class AudioTranscriptionBody(BaseModel): - event_id: str + event_id: str = Field(description="ID of the event to transcribe audio for") class DeleteFaceImagesBody(BaseModel): ids: List[str] = Field( description="List of image filenames to delete from the face folder" ) + + +class GenerateStateExamplesBody(BaseModel): + model_name: str = Field(description="Name of the classification model") + cameras: Dict[str, Tuple[float, float, float, float]] = Field( + description="Dictionary mapping camera names to normalized crop coordinates in [x1, y1, x2, y2] format (values 0-1)" + ) + + +class GenerateObjectExamplesBody(BaseModel): + model_name: str = Field(description="Name of the classification model") + label: str = Field( + description="Object label to collect examples for (e.g., 'person', 'car')" + ) diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 1fb9dfc97..ac6387785 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -53,9 +53,17 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): self.tensor_output_details: dict[str, Any] | None = None self.labelmap: dict[int, str] = {} self.classifications_per_second = EventsPerSecond() - self.inference_speed = InferenceSpeed( - self.metrics.classification_speeds[self.model_config.name] - ) + + if ( + self.metrics + and self.model_config.name in self.metrics.classification_speeds + ): + self.inference_speed = 
InferenceSpeed( + self.metrics.classification_speeds[self.model_config.name] + ) + else: + self.inference_speed = None + self.last_run = datetime.datetime.now().timestamp() self.__build_detector() @@ -83,12 +91,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): def __update_metrics(self, duration: float) -> None: self.classifications_per_second.update() - self.inference_speed.update(duration) + if self.inference_speed: + self.inference_speed.update(duration) def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray): - self.metrics.classification_cps[ - self.model_config.name - ].value = self.classifications_per_second.eps() + if self.metrics and self.model_config.name in self.metrics.classification_cps: + self.metrics.classification_cps[ + self.model_config.name + ].value = self.classifications_per_second.eps() camera = frame_data.get("camera") if camera not in self.model_config.state_config.cameras: @@ -223,9 +233,17 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): self.detected_objects: dict[str, float] = {} self.labelmap: dict[int, str] = {} self.classifications_per_second = EventsPerSecond() - self.inference_speed = InferenceSpeed( - self.metrics.classification_speeds[self.model_config.name] - ) + + if ( + self.metrics + and self.model_config.name in self.metrics.classification_speeds + ): + self.inference_speed = InferenceSpeed( + self.metrics.classification_speeds[self.model_config.name] + ) + else: + self.inference_speed = None + self.__build_detector() @redirect_output_to_logger(logger, logging.DEBUG) @@ -251,12 +269,14 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): def __update_metrics(self, duration: float) -> None: self.classifications_per_second.update() - self.inference_speed.update(duration) + if self.inference_speed: + self.inference_speed.update(duration) def process_frame(self, obj_data, frame): - self.metrics.classification_cps[ - self.model_config.name - ].value = self.classifications_per_second.eps() + if self.metrics and self.model_config.name in self.metrics.classification_cps: + self.metrics.classification_cps[ + self.model_config.name + ].value = self.classifications_per_second.eps() if obj_data["false_positive"]: return diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index 55e3d57ba..fe04d8b17 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -9,6 +9,7 @@ from typing import Any from peewee import DoesNotExist +from frigate.comms.config_updater import ConfigSubscriber from frigate.comms.detections_updater import DetectionSubscriber, DetectionTypeEnum from frigate.comms.embeddings_updater import ( EmbeddingsRequestEnum, @@ -95,6 +96,9 @@ class EmbeddingMaintainer(threading.Thread): CameraConfigUpdateEnum.semantic_search, ], ) + self.classification_config_subscriber = ConfigSubscriber( + "config/classification/custom/" + ) # Configure Frigate DB db = SqliteVecQueueDatabase( @@ -255,6 +259,7 @@ class EmbeddingMaintainer(threading.Thread): """Maintain a SQLite-vec database for semantic search.""" while not self.stop_event.is_set(): self.config_updater.check_for_updates() + self._check_classification_config_updates() self._process_requests() self._process_updates() self._process_recordings_updates() @@ -265,6 +270,7 @@ class EmbeddingMaintainer(threading.Thread): self._process_event_metadata() self.config_updater.stop() + self.classification_config_subscriber.stop() self.event_subscriber.stop() self.event_end_subscriber.stop() 
self.recordings_subscriber.stop() @@ -275,6 +281,46 @@ class EmbeddingMaintainer(threading.Thread): self.requestor.stop() logger.info("Exiting embeddings maintenance...") + def _check_classification_config_updates(self) -> None: + """Check for classification config updates and add new processors.""" + topic, model_config = self.classification_config_subscriber.check_for_update() + + if topic and model_config: + model_name = topic.split("/")[-1] + self.config.classification.custom[model_name] = model_config + + # Check if processor already exists + for processor in self.realtime_processors: + if isinstance( + processor, + ( + CustomStateClassificationProcessor, + CustomObjectClassificationProcessor, + ), + ): + if processor.model_config.name == model_name: + logger.debug( + f"Classification processor for model {model_name} already exists, skipping" + ) + return + + if model_config.state_config is not None: + processor = CustomStateClassificationProcessor( + self.config, model_config, self.requestor, self.metrics + ) + else: + processor = CustomObjectClassificationProcessor( + self.config, + model_config, + self.event_metadata_publisher, + self.metrics, + ) + + self.realtime_processors.append(processor) + logger.info( + f"Added classification processor for model: {model_name} (type: {type(processor).__name__})" + ) + def _process_requests(self) -> None: """Process embeddings requests""" diff --git a/frigate/util/classification.py b/frigate/util/classification.py index e4133ded4..ab17a9444 100644 --- a/frigate/util/classification.py +++ b/frigate/util/classification.py @@ -2,12 +2,15 @@ import logging import os +import random +from collections import defaultdict import cv2 import numpy as np from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsRequestor from frigate.comms.inter_process import InterProcessRequestor +from frigate.config import FfmpegConfig from frigate.const import ( CLIPS_DIR, MODEL_CACHE_DIR, @@ -15,7 +18,10 @@ from frigate.const import ( UPDATE_MODEL_STATE, ) from frigate.log import redirect_output_to_logger +from frigate.models import Event, Recordings, ReviewSegment from frigate.types import ModelStatusTypesEnum +from frigate.util.image import get_image_from_recording +from frigate.util.path import get_event_thumbnail_bytes from frigate.util.process import FrigateProcess BATCH_SIZE = 16 @@ -69,6 +75,7 @@ class ClassificationTrainingProcess(FrigateProcess): logger.info(f"Kicking off classification training for {self.model_name}.") dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset") model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name) + os.makedirs(model_dir, exist_ok=True) num_classes = len( [ d @@ -139,7 +146,6 @@ class ClassificationTrainingProcess(FrigateProcess): f.write(tflite_model) -@staticmethod def kickoff_model_training( embeddingRequestor: EmbeddingsRequestor, model_name: str ) -> None: @@ -172,3 +178,520 @@ def kickoff_model_training( }, ) requestor.stop() + + +@staticmethod +def collect_state_classification_examples( + model_name: str, cameras: dict[str, tuple[float, float, float, float]] +) -> None: + """ + Collect representative state classification examples from review items. + + This function: + 1. Queries review items from specified cameras + 2. Selects 100 balanced timestamps across the data + 3. Extracts keyframes from recordings (cropped to specified regions) + 4. Selects 20 most visually distinct images + 5. 
Saves them to the dataset directory + + Args: + model_name: Name of the classification model + cameras: Dict mapping camera names to normalized crop coordinates [x1, y1, x2, y2] (0-1) + """ + dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset") + temp_dir = os.path.join(dataset_dir, "temp") + os.makedirs(temp_dir, exist_ok=True) + + # Step 1: Get review items for the cameras + camera_names = list(cameras.keys()) + review_items = list( + ReviewSegment.select() + .where(ReviewSegment.camera.in_(camera_names)) + .where(ReviewSegment.end_time.is_null(False)) + .order_by(ReviewSegment.start_time.asc()) + ) + + if not review_items: + logger.warning(f"No review items found for cameras: {camera_names}") + return + + # Step 2: Create balanced timestamp selection (100 samples) + timestamps = _select_balanced_timestamps(review_items, target_count=100) + + # Step 3: Extract keyframes from recordings with crops applied + keyframes = _extract_keyframes( + "/usr/lib/ffmpeg/7.0/bin/ffmpeg", timestamps, temp_dir, cameras + ) + + # Step 4: Select 24 most visually distinct images (they're already cropped) + distinct_images = _select_distinct_images(keyframes, target_count=24) + + # Step 5: Save to train directory for later classification + train_dir = os.path.join(CLIPS_DIR, model_name, "train") + os.makedirs(train_dir, exist_ok=True) + + saved_count = 0 + for idx, image_path in enumerate(distinct_images): + dest_path = os.path.join(train_dir, f"example_{idx:03d}.jpg") + try: + img = cv2.imread(image_path) + + if img is not None: + cv2.imwrite(dest_path, img) + saved_count += 1 + except Exception as e: + logger.error(f"Failed to save image {image_path}: {e}") + + import shutil + + try: + shutil.rmtree(temp_dir) + except Exception as e: + logger.warning(f"Failed to clean up temp directory: {e}") + + +def _select_balanced_timestamps( + review_items: list[ReviewSegment], target_count: int = 100 +) -> list[dict]: + """ + Select balanced timestamps from review items. 
+ + Strategy: + - Group review items by camera and time of day + - Sample evenly across groups to ensure diversity + - For each selected review item, pick a random timestamp within its duration + + Returns: + List of dicts with keys: camera, timestamp, review_item + """ + # Group by camera and hour of day for temporal diversity + grouped = defaultdict(list) + + for item in review_items: + camera = item.camera + # Group by 6-hour blocks for temporal diversity + hour_block = int(item.start_time // (6 * 3600)) + key = f"{camera}_{hour_block}" + grouped[key].append(item) + + # Calculate how many samples per group + num_groups = len(grouped) + if num_groups == 0: + return [] + + samples_per_group = max(1, target_count // num_groups) + timestamps = [] + + # Sample from each group + for group_items in grouped.values(): + # Take samples_per_group items from this group + sample_size = min(samples_per_group, len(group_items)) + sampled_items = random.sample(group_items, sample_size) + + for item in sampled_items: + # Pick a random timestamp within the review item's duration + duration = item.end_time - item.start_time + if duration <= 0: + continue + + # Sample from middle 80% to avoid edge artifacts + offset = random.uniform(duration * 0.1, duration * 0.9) + timestamp = item.start_time + offset + + timestamps.append( + { + "camera": item.camera, + "timestamp": timestamp, + "review_item": item, + } + ) + + # If we don't have enough, sample more from larger groups + while len(timestamps) < target_count and len(timestamps) < len(review_items): + for group_items in grouped.values(): + if len(timestamps) >= target_count: + break + + # Pick a random item not already sampled + item = random.choice(group_items) + duration = item.end_time - item.start_time + if duration <= 0: + continue + + offset = random.uniform(duration * 0.1, duration * 0.9) + timestamp = item.start_time + offset + + # Check if we already have a timestamp near this one + if not any(abs(t["timestamp"] - timestamp) < 1.0 for t in timestamps): + timestamps.append( + { + "camera": item.camera, + "timestamp": timestamp, + "review_item": item, + } + ) + + return timestamps[:target_count] + + +def _extract_keyframes( + ffmpeg_path: str, + timestamps: list[dict], + output_dir: str, + camera_crops: dict[str, tuple[float, float, float, float]], +) -> list[str]: + """ + Extract keyframes from recordings at specified timestamps and crop to specified regions. 
+ + Args: + ffmpeg_path: Path to ffmpeg binary + timestamps: List of timestamp dicts from _select_balanced_timestamps + output_dir: Directory to save extracted frames + camera_crops: Dict mapping camera names to normalized crop coordinates [x1, y1, x2, y2] (0-1) + + Returns: + List of paths to successfully extracted and cropped keyframe images + """ + keyframe_paths = [] + + for idx, ts_info in enumerate(timestamps): + camera = ts_info["camera"] + timestamp = ts_info["timestamp"] + + if camera not in camera_crops: + logger.warning(f"No crop coordinates for camera {camera}") + continue + + norm_x1, norm_y1, norm_x2, norm_y2 = camera_crops[camera] + + try: + recording = ( + Recordings.select() + .where( + (timestamp >= Recordings.start_time) + & (timestamp <= Recordings.end_time) + & (Recordings.camera == camera) + ) + .order_by(Recordings.start_time.desc()) + .limit(1) + .get() + ) + except Exception: + continue + + relative_time = timestamp - recording.start_time + + try: + config = FfmpegConfig(path="/usr/lib/ffmpeg/7.0") + image_data = get_image_from_recording( + config, + recording.path, + relative_time, + codec="mjpeg", + height=None, + ) + + if image_data: + nparr = np.frombuffer(image_data, np.uint8) + img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + + if img is not None: + height, width = img.shape[:2] + + x1 = int(norm_x1 * width) + y1 = int(norm_y1 * height) + x2 = int(norm_x2 * width) + y2 = int(norm_y2 * height) + + x1_clipped = max(0, min(x1, width)) + y1_clipped = max(0, min(y1, height)) + x2_clipped = max(0, min(x2, width)) + y2_clipped = max(0, min(y2, height)) + + if x2_clipped > x1_clipped and y2_clipped > y1_clipped: + cropped = img[y1_clipped:y2_clipped, x1_clipped:x2_clipped] + resized = cv2.resize(cropped, (224, 224)) + + output_path = os.path.join(output_dir, f"frame_{idx:04d}.jpg") + cv2.imwrite(output_path, resized) + keyframe_paths.append(output_path) + + except Exception as e: + logger.debug( + f"Failed to extract frame from {recording.path} at {relative_time}s: {e}" + ) + continue + + return keyframe_paths + + +def _select_distinct_images( + image_paths: list[str], target_count: int = 20 +) -> list[str]: + """ + Select the most visually distinct images from a set of keyframes. + + Uses a greedy algorithm based on image histograms: + 1. Start with a random image + 2. Iteratively add the image that is most different from already selected images + 3. 
Difference is measured using histogram comparison + + Args: + image_paths: List of paths to candidate images + target_count: Number of distinct images to select + + Returns: + List of paths to selected images + """ + if len(image_paths) <= target_count: + return image_paths + + histograms = {} + valid_paths = [] + + for path in image_paths: + try: + img = cv2.imread(path) + + if img is None: + continue + + hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) + hist = cv2.calcHist( + [hsv], [0, 1, 2], None, [8, 8, 8], [0, 180, 0, 256, 0, 256] + ) + hist = cv2.normalize(hist, hist).flatten() + histograms[path] = hist + valid_paths.append(path) + except Exception as e: + logger.debug(f"Failed to process image {path}: {e}") + continue + + if len(valid_paths) <= target_count: + return valid_paths + + selected = [] + first_image = random.choice(valid_paths) + selected.append(first_image) + remaining = [p for p in valid_paths if p != first_image] + + while len(selected) < target_count and remaining: + max_min_distance = -1 + best_candidate = None + + for candidate in remaining: + min_distance = float("inf") + + for selected_img in selected: + distance = cv2.compareHist( + histograms[candidate], + histograms[selected_img], + cv2.HISTCMP_BHATTACHARYYA, + ) + min_distance = min(min_distance, distance) + + if min_distance > max_min_distance: + max_min_distance = min_distance + best_candidate = candidate + + if best_candidate: + selected.append(best_candidate) + remaining.remove(best_candidate) + else: + break + + return selected + + +@staticmethod +def collect_object_classification_examples( + model_name: str, + label: str, +) -> None: + """ + Collect representative object classification examples from event thumbnails. + + This function: + 1. Queries events for the specified label + 2. Selects 100 balanced events across different cameras and times + 3. Retrieves thumbnails for selected events (with 33% center crop applied) + 4. Selects 24 most visually distinct thumbnails + 5. 
Saves to dataset directory + + Args: + model_name: Name of the classification model + label: Object label to collect (e.g., "person", "car") + cameras: List of camera names to collect examples from + """ + dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset") + temp_dir = os.path.join(dataset_dir, "temp") + os.makedirs(temp_dir, exist_ok=True) + + # Step 1: Query events for the specified label and cameras + events = list( + Event.select().where((Event.label == label)).order_by(Event.start_time.asc()) + ) + + if not events: + logger.warning(f"No events found for label '{label}'") + return + + logger.debug(f"Found {len(events)} events") + + # Step 2: Select balanced events (100 samples) + selected_events = _select_balanced_events(events, target_count=100) + logger.debug(f"Selected {len(selected_events)} events") + + # Step 3: Extract thumbnails from events + thumbnails = _extract_event_thumbnails(selected_events, temp_dir) + logger.debug(f"Successfully extracted {len(thumbnails)} thumbnails") + + # Step 4: Select 24 most visually distinct thumbnails + distinct_images = _select_distinct_images(thumbnails, target_count=24) + logger.debug(f"Selected {len(distinct_images)} distinct images") + + # Step 5: Save to train directory for later classification + train_dir = os.path.join(CLIPS_DIR, model_name, "train") + os.makedirs(train_dir, exist_ok=True) + + saved_count = 0 + for idx, image_path in enumerate(distinct_images): + dest_path = os.path.join(train_dir, f"example_{idx:03d}.jpg") + try: + img = cv2.imread(image_path) + + if img is not None: + cv2.imwrite(dest_path, img) + saved_count += 1 + except Exception as e: + logger.error(f"Failed to save image {image_path}: {e}") + + import shutil + + try: + shutil.rmtree(temp_dir) + except Exception as e: + logger.warning(f"Failed to clean up temp directory: {e}") + + logger.debug( + f"Successfully collected {saved_count} classification examples in {train_dir}" + ) + + +def _select_balanced_events( + events: list[Event], target_count: int = 100 +) -> list[Event]: + """ + Select balanced events from the event list. + + Strategy: + - Group events by camera and time of day + - Sample evenly across groups to ensure diversity + - Prioritize events with higher scores + + Returns: + List of selected events + """ + grouped = defaultdict(list) + + for event in events: + camera = event.camera + hour_block = int(event.start_time // (6 * 3600)) + key = f"{camera}_{hour_block}" + grouped[key].append(event) + + num_groups = len(grouped) + if num_groups == 0: + return [] + + samples_per_group = max(1, target_count // num_groups) + selected = [] + + for group_events in grouped.values(): + sorted_events = sorted( + group_events, + key=lambda e: e.data.get("score", 0) if e.data else 0, + reverse=True, + ) + + sample_size = min(samples_per_group, len(sorted_events)) + selected.extend(sorted_events[:sample_size]) + + if len(selected) < target_count: + remaining = [e for e in events if e not in selected] + remaining_sorted = sorted( + remaining, + key=lambda e: e.data.get("score", 0) if e.data else 0, + reverse=True, + ) + needed = target_count - len(selected) + selected.extend(remaining_sorted[:needed]) + + return selected[:target_count] + + +def _extract_event_thumbnails(events: list[Event], output_dir: str) -> list[str]: + """ + Extract thumbnails from events and save to disk. 
+ + Args: + events: List of Event objects + output_dir: Directory to save thumbnails + + Returns: + List of paths to successfully extracted thumbnail images + """ + thumbnail_paths = [] + + for idx, event in enumerate(events): + try: + thumbnail_bytes = get_event_thumbnail_bytes(event) + + if thumbnail_bytes: + nparr = np.frombuffer(thumbnail_bytes, np.uint8) + img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + + if img is not None: + height, width = img.shape[:2] + + crop_size = 1.0 + if event.data and "box" in event.data and "region" in event.data: + box = event.data["box"] + region = event.data["region"] + + if len(box) == 4 and len(region) == 4: + box_w, box_h = box[2], box[3] + region_w, region_h = region[2], region[3] + + box_area = (box_w * box_h) / (region_w * region_h) + + if box_area < 0.05: + crop_size = 0.4 + elif box_area < 0.10: + crop_size = 0.5 + elif box_area < 0.20: + crop_size = 0.65 + elif box_area < 0.35: + crop_size = 0.80 + else: + crop_size = 0.95 + + crop_width = int(width * crop_size) + crop_height = int(height * crop_size) + + x1 = (width - crop_width) // 2 + y1 = (height - crop_height) // 2 + x2 = x1 + crop_width + y2 = y1 + crop_height + + cropped = img[y1:y2, x1:x2] + resized = cv2.resize(cropped, (224, 224)) + output_path = os.path.join(output_dir, f"thumbnail_{idx:04d}.jpg") + cv2.imwrite(output_path, resized) + thumbnail_paths.append(output_path) + + except Exception as e: + logger.debug(f"Failed to extract thumbnail for event {event.id}: {e}") + continue + + return thumbnail_paths diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json index dcfc5a1b2..a13870221 100644 --- a/web/public/locales/en/views/classificationModel.json +++ b/web/public/locales/en/views/classificationModel.json @@ -1,4 +1,5 @@ { + "documentTitle": "Classification Models", "button": { "deleteClassificationAttempts": "Delete Classification Images", "renameCategory": "Rename Class", @@ -50,8 +51,85 @@ }, "categorizeImageAs": "Classify Image As:", "categorizeImage": "Classify Image", + "noModels": { + "object": { + "title": "No Object Classification Models", + "description": "Create a custom model to classify detected objects.", + "buttonText": "Create Object Model" + }, + "state": { + "title": "No State Classification Models", + "description": "Create a custom model to monitor and classify state changes in specific camera areas.", + "buttonText": "Create State Model" + } + }, "wizard": { "title": "Create New Classification", - "description": "Create a new state or object classification model." + "steps": { + "nameAndDefine": "Name & Define", + "stateArea": "State Area", + "chooseExamples": "Choose Examples" + }, + "step1": { + "description": "State models monitor fixed camera areas for changes (e.g., door open/closed). Object models add classifications to detected objects (e.g., known animals, delivery persons, etc.).", + "name": "Name", + "namePlaceholder": "Enter model name...", + "type": "Type", + "typeState": "State", + "typeObject": "Object", + "objectLabel": "Object Label", + "objectLabelPlaceholder": "Select object type...", + "classificationType": "Classification Type", + "classificationTypeTip": "Learn about classification types", + "classificationTypeDesc": "Sub Labels add additional text to the object label (e.g., 'Person: UPS'). 
Attributes are searchable metadata stored separately in the object metadata.", + "classificationSubLabel": "Sub Label", + "classificationAttribute": "Attribute", + "classes": "Classes", + "classesTip": "Learn about classes", + "classesStateDesc": "Define the different states your camera area can be in. For example: 'open' and 'closed' for a garage door.", + "classesObjectDesc": "Define the different categories to classify detected objects into. For example: 'delivery_person', 'resident', 'stranger' for person classification.", + "classPlaceholder": "Enter class name...", + "errors": { + "nameRequired": "Model name is required", + "nameLength": "Model name must be 64 characters or less", + "nameOnlyNumbers": "Model name cannot contain only numbers", + "classRequired": "At least 1 class is required", + "classesUnique": "Class names must be unique", + "stateRequiresTwoClasses": "State models require at least 2 classes", + "objectLabelRequired": "Please select an object label", + "objectTypeRequired": "Please select a classification type" + } + }, + "step2": { + "description": "Select cameras and define the area to monitor for each camera. The model will classify the state of these areas.", + "cameras": "Cameras", + "selectCamera": "Select Camera", + "noCameras": "Click + to add cameras", + "selectCameraPrompt": "Select a camera from the list to define its monitoring area" + }, + "step3": { + "selectImagesPrompt": "Select all images with: {{className}}", + "selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.", + "generating": { + "title": "Generating Sample Images", + "description": "Frigate is pulling representative images from your recordings. This may take a moment..." + }, + "training": { + "title": "Training Model", + "description": "Your model is being trained in the background. Close this dialog, and your model will start running as soon as training is complete." + }, + "retryGenerate": "Retry Generation", + "noImages": "No sample images generated", + "classifying": "Classifying & Training...", + "trainingStarted": "Training started successfully", + "errors": { + "noCameras": "No cameras configured", + "noObjectLabel": "No object label selected", + "generateFailed": "Failed to generate examples: {{error}}", + "generationFailed": "Generation failed. Please try again.", + "classifyFailed": "Failed to classify images: {{error}}" + }, + "generateSuccess": "Successfully generated sample images" + } } } diff --git a/web/public/locales/en/views/faceLibrary.json b/web/public/locales/en/views/faceLibrary.json index 6febf85f0..08050977e 100644 --- a/web/public/locales/en/views/faceLibrary.json +++ b/web/public/locales/en/views/faceLibrary.json @@ -5,10 +5,6 @@ "invalidName": "Invalid name. Names can only include letters, numbers, spaces, apostrophes, underscores, and hyphens." }, "details": { - "subLabelScore": "Sub Label Score", - "scoreInfo": "The sub label score is the weighted score for all of the recognized face confidences, so this may differ from the score shown on the snapshot.", - "face": "Face Details", - "faceDesc": "Details of the tracked object that generated this face", "timestamp": "Timestamp", "unknown": "Unknown" }, @@ -19,8 +15,6 @@ }, "collections": "Collections", "createFaceLibrary": { - "title": "Create Collection", - "desc": "Create a new collection", "new": "Create New Face", "nextSteps": "To build a strong foundation:
  • Use the Recent Recognitions tab to select and train on images for each detected person.
  • Focus on straight-on images for best results; avoid training images that capture faces at an angle.
  • " }, @@ -37,8 +31,6 @@ "aria": "Select recent recognitions", "empty": "There are no recent face recognition attempts" }, - "selectItem": "Select {{item}}", - "selectFace": "Select Face", "deleteFaceLibrary": { "title": "Delete Name", "desc": "Are you sure you want to delete the collection {{name}}? This will permanently delete all associated faces." @@ -69,7 +61,6 @@ "maxSize": "Max size: {{size}}MB" }, "nofaces": "No faces available", - "pixels": "{{area}}px", "trainFaceAs": "Train Face as:", "trainFace": "Train Face", "toast": { diff --git a/web/src/components/card/ClassificationCard.tsx b/web/src/components/card/ClassificationCard.tsx index 2fbe36804..de99c60fe 100644 --- a/web/src/components/card/ClassificationCard.tsx +++ b/web/src/components/card/ClassificationCard.tsx @@ -126,6 +126,7 @@ export const ClassificationCard = forwardRef< imgClassName, isMobile && "w-full", )} + loading="lazy" onLoad={() => setImageLoaded(true)} src={`${baseUrl}${data.filepath}`} /> diff --git a/web/src/components/classification/ClassificationModelWizardDialog.tsx b/web/src/components/classification/ClassificationModelWizardDialog.tsx index 621c9ea90..e67a95f89 100644 --- a/web/src/components/classification/ClassificationModelWizardDialog.tsx +++ b/web/src/components/classification/ClassificationModelWizardDialog.tsx @@ -7,58 +7,198 @@ import { DialogHeader, DialogTitle, } from "../ui/dialog"; -import { useState } from "react"; +import { useReducer, useMemo } from "react"; +import Step1NameAndDefine, { Step1FormData } from "./wizard/Step1NameAndDefine"; +import Step2StateArea, { Step2FormData } from "./wizard/Step2StateArea"; +import Step3ChooseExamples, { + Step3FormData, +} from "./wizard/Step3ChooseExamples"; +import { cn } from "@/lib/utils"; +import { isDesktop } from "react-device-detect"; -const STEPS = [ - "classificationWizard.steps.nameAndDefine", - "classificationWizard.steps.stateArea", - "classificationWizard.steps.chooseExamples", - "classificationWizard.steps.train", +const OBJECT_STEPS = [ + "wizard.steps.nameAndDefine", + "wizard.steps.chooseExamples", +]; + +const STATE_STEPS = [ + "wizard.steps.nameAndDefine", + "wizard.steps.stateArea", + "wizard.steps.chooseExamples", ]; type ClassificationModelWizardDialogProps = { open: boolean; onClose: () => void; + defaultModelType?: "state" | "object"; }; + +type WizardState = { + currentStep: number; + step1Data?: Step1FormData; + step2Data?: Step2FormData; + step3Data?: Step3FormData; +}; + +type WizardAction = + | { type: "NEXT_STEP"; payload?: Partial } + | { type: "PREVIOUS_STEP" } + | { type: "SET_STEP_1"; payload: Step1FormData } + | { type: "SET_STEP_2"; payload: Step2FormData } + | { type: "SET_STEP_3"; payload: Step3FormData } + | { type: "RESET" }; + +const initialState: WizardState = { + currentStep: 0, +}; + +function wizardReducer(state: WizardState, action: WizardAction): WizardState { + switch (action.type) { + case "SET_STEP_1": + return { + ...state, + step1Data: action.payload, + currentStep: 1, + }; + case "SET_STEP_2": + return { + ...state, + step2Data: action.payload, + currentStep: 2, + }; + case "SET_STEP_3": + return { + ...state, + step3Data: action.payload, + currentStep: 3, + }; + case "NEXT_STEP": + return { + ...state, + ...action.payload, + currentStep: state.currentStep + 1, + }; + case "PREVIOUS_STEP": + return { + ...state, + currentStep: Math.max(0, state.currentStep - 1), + }; + case "RESET": + return initialState; + default: + return state; + } +} + export default function 
ClassificationModelWizardDialog({ open, onClose, + defaultModelType, }: ClassificationModelWizardDialogProps) { const { t } = useTranslation(["views/classificationModel"]); - // step management - const [currentStep, _] = useState(0); + const [wizardState, dispatch] = useReducer(wizardReducer, initialState); + + const steps = useMemo(() => { + if (!wizardState.step1Data) { + return OBJECT_STEPS; + } + return wizardState.step1Data.modelType === "state" + ? STATE_STEPS + : OBJECT_STEPS; + }, [wizardState.step1Data]); + + const handleStep1Next = (data: Step1FormData) => { + dispatch({ type: "SET_STEP_1", payload: data }); + }; + + const handleStep2Next = (data: Step2FormData) => { + dispatch({ type: "SET_STEP_2", payload: data }); + }; + + const handleBack = () => { + dispatch({ type: "PREVIOUS_STEP" }); + }; + + const handleCancel = () => { + dispatch({ type: "RESET" }); + onClose(); + }; return ( { if (!open) { - onClose; + handleCancel(); } }} > 0 && + "max-h-[90%] max-w-[70%] overflow-y-auto xl:max-h-[80%]", + )} onInteractOutside={(e) => { e.preventDefault(); }} > {t("wizard.title")} - {currentStep === 0 && ( - {t("wizard.description")} + {wizardState.currentStep === 0 && ( + + {t("wizard.step1.description")} + )} + {wizardState.currentStep === 1 && + wizardState.step1Data?.modelType === "state" && ( + + {t("wizard.step2.description")} + + )}
    -
    + {wizardState.currentStep === 0 && ( + + )} + {wizardState.currentStep === 1 && + wizardState.step1Data?.modelType === "state" && ( + + )} + {((wizardState.currentStep === 2 && + wizardState.step1Data?.modelType === "state") || + (wizardState.currentStep === 1 && + wizardState.step1Data?.modelType === "object")) && + wizardState.step1Data && ( + + )}
    diff --git a/web/src/components/classification/wizard/Step1NameAndDefine.tsx b/web/src/components/classification/wizard/Step1NameAndDefine.tsx new file mode 100644 index 000000000..94e1741da --- /dev/null +++ b/web/src/components/classification/wizard/Step1NameAndDefine.tsx @@ -0,0 +1,498 @@ +import { Button } from "@/components/ui/button"; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from "@/components/ui/form"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { useForm } from "react-hook-form"; +import { zodResolver } from "@hookform/resolvers/zod"; +import { z } from "zod"; +import { useTranslation } from "react-i18next"; +import { useMemo } from "react"; +import { LuX, LuPlus, LuInfo, LuExternalLink } from "react-icons/lu"; +import useSWR from "swr"; +import { FrigateConfig } from "@/types/frigateConfig"; +import { getTranslatedLabel } from "@/utils/i18n"; +import { useDocDomain } from "@/hooks/use-doc-domain"; +import { + Popover, + PopoverContent, + PopoverTrigger, +} from "@/components/ui/popover"; + +export type ModelType = "state" | "object"; +export type ObjectClassificationType = "sub_label" | "attribute"; + +export type Step1FormData = { + modelName: string; + modelType: ModelType; + objectLabel?: string; + objectType?: ObjectClassificationType; + classes: string[]; +}; + +type Step1NameAndDefineProps = { + initialData?: Partial; + defaultModelType?: "state" | "object"; + onNext: (data: Step1FormData) => void; + onCancel: () => void; +}; + +export default function Step1NameAndDefine({ + initialData, + defaultModelType, + onNext, + onCancel, +}: Step1NameAndDefineProps) { + const { t } = useTranslation(["views/classificationModel"]); + const { data: config } = useSWR("config"); + const { getLocaleDocUrl } = useDocDomain(); + + const objectLabels = useMemo(() => { + if (!config) return []; + + const labels = new Set(); + + Object.values(config.cameras).forEach((cameraConfig) => { + if (!cameraConfig.enabled || !cameraConfig.enabled_in_config) { + return; + } + + cameraConfig.objects.track.forEach((label) => { + if (!config.model.all_attributes.includes(label)) { + labels.add(label); + } + }); + }); + + return [...labels].sort(); + }, [config]); + + const step1FormData = z + .object({ + modelName: z + .string() + .min(1, t("wizard.step1.errors.nameRequired")) + .max(64, t("wizard.step1.errors.nameLength")) + .refine((value) => !/^\d+$/.test(value), { + message: t("wizard.step1.errors.nameOnlyNumbers"), + }), + modelType: z.enum(["state", "object"]), + objectLabel: z.string().optional(), + objectType: z.enum(["sub_label", "attribute"]).optional(), + classes: z + .array(z.string()) + .min(1, t("wizard.step1.errors.classRequired")) + .refine( + (classes) => { + const nonEmpty = classes.filter((c) => c.trim().length > 0); + return nonEmpty.length >= 1; + }, + { message: t("wizard.step1.errors.classRequired") }, + ) + .refine( + (classes) => { + const nonEmpty = classes.filter((c) => c.trim().length > 0); + const unique = new Set(nonEmpty.map((c) => c.toLowerCase())); + return unique.size === nonEmpty.length; + }, + { message: t("wizard.step1.errors.classesUnique") }, + ), + }) + .refine( + (data) => { + // State models require at least 2 classes + if (data.modelType === "state") { + const 
nonEmpty = data.classes.filter((c) => c.trim().length > 0); + return nonEmpty.length >= 2; + } + return true; + }, + { + message: t("wizard.step1.errors.stateRequiresTwoClasses"), + path: ["classes"], + }, + ) + .refine( + (data) => { + if (data.modelType === "object") { + return data.objectLabel !== undefined && data.objectLabel !== ""; + } + return true; + }, + { + message: t("wizard.step1.errors.objectLabelRequired"), + path: ["objectLabel"], + }, + ) + .refine( + (data) => { + if (data.modelType === "object") { + return data.objectType !== undefined; + } + return true; + }, + { + message: t("wizard.step1.errors.objectTypeRequired"), + path: ["objectType"], + }, + ); + + const form = useForm>({ + resolver: zodResolver(step1FormData), + defaultValues: { + modelName: initialData?.modelName || "", + modelType: initialData?.modelType || defaultModelType || "state", + objectLabel: initialData?.objectLabel, + objectType: initialData?.objectType || "sub_label", + classes: initialData?.classes?.length ? initialData.classes : [""], + }, + mode: "onChange", + }); + + const watchedClasses = form.watch("classes"); + const watchedModelType = form.watch("modelType"); + const watchedObjectType = form.watch("objectType"); + + const handleAddClass = () => { + const currentClasses = form.getValues("classes"); + form.setValue("classes", [...currentClasses, ""], { shouldValidate: true }); + }; + + const handleRemoveClass = (index: number) => { + const currentClasses = form.getValues("classes"); + const newClasses = currentClasses.filter((_, i) => i !== index); + + // Ensure at least one field remains (even if empty) + if (newClasses.length === 0) { + form.setValue("classes", [""], { shouldValidate: true }); + } else { + form.setValue("classes", newClasses, { shouldValidate: true }); + } + }; + + const onSubmit = (data: z.infer) => { + // Filter out empty classes + const filteredClasses = data.classes.filter((c) => c.trim().length > 0); + onNext({ + ...data, + classes: filteredClasses, + }); + }; + + return ( +
    +
    + + ( + + + {t("wizard.step1.name")} + + + + + + + )} + /> + + ( + + + {t("wizard.step1.type")} + + + +
    + + +
    +
    + + +
    +
    +
    + +
    + )} + /> + + {watchedModelType === "object" && ( + <> + ( + + + {t("wizard.step1.objectLabel")} + + + + + )} + /> + + ( + +
    + + {t("wizard.step1.classificationType")} + + + + + + +
    +
    + {t("wizard.step1.classificationTypeDesc")} +
    + +
    +
    +
    +
    + + +
    + + +
    +
    + + +
    +
    +
    + +
    + )} + /> + + )} + +
    +
    +
    + + {t("wizard.step1.classes")} + + + + + + +
    +
    + {watchedModelType === "state" + ? t("wizard.step1.classesStateDesc") + : t("wizard.step1.classesObjectDesc")} +
    + +
    +
    +
    +
    + +
    +
    + {watchedClasses.map((_, index) => ( + ( + + +
    + + {watchedClasses.length > 1 && ( + + )} +
    +
    +
    + )} + /> + ))} +
    + {form.formState.errors.classes && ( +

    + {form.formState.errors.classes.message} +

    + )} +
    + + + +
    + + +
    +
    + ); +} diff --git a/web/src/components/classification/wizard/Step2StateArea.tsx b/web/src/components/classification/wizard/Step2StateArea.tsx new file mode 100644 index 000000000..38c2fcad7 --- /dev/null +++ b/web/src/components/classification/wizard/Step2StateArea.tsx @@ -0,0 +1,479 @@ +import { Button } from "@/components/ui/button"; +import { useTranslation } from "react-i18next"; +import { useState, useMemo, useRef, useCallback, useEffect } from "react"; +import useSWR from "swr"; +import { FrigateConfig } from "@/types/frigateConfig"; +import { + Popover, + PopoverContent, + PopoverTrigger, +} from "@/components/ui/popover"; +import { LuX, LuPlus } from "react-icons/lu"; +import { Stage, Layer, Rect, Transformer } from "react-konva"; +import Konva from "konva"; +import { useResizeObserver } from "@/hooks/resize-observer"; +import { useApiHost } from "@/api"; +import { resolveCameraName } from "@/hooks/use-camera-friendly-name"; +import Heading from "@/components/ui/heading"; +import { isMobile } from "react-device-detect"; +import { cn } from "@/lib/utils"; + +export type CameraAreaConfig = { + camera: string; + crop: [number, number, number, number]; +}; + +export type Step2FormData = { + cameraAreas: CameraAreaConfig[]; +}; + +type Step2StateAreaProps = { + initialData?: Partial; + onNext: (data: Step2FormData) => void; + onBack: () => void; +}; + +export default function Step2StateArea({ + initialData, + onNext, + onBack, +}: Step2StateAreaProps) { + const { t } = useTranslation(["views/classificationModel"]); + const { data: config } = useSWR("config"); + const apiHost = useApiHost(); + + const [cameraAreas, setCameraAreas] = useState( + initialData?.cameraAreas || [], + ); + const [selectedCameraIndex, setSelectedCameraIndex] = useState(0); + const [isPopoverOpen, setIsPopoverOpen] = useState(false); + const [imageLoaded, setImageLoaded] = useState(false); + + const containerRef = useRef(null); + const imageRef = useRef(null); + const stageRef = useRef(null); + const rectRef = useRef(null); + const transformerRef = useRef(null); + + const [{ width: containerWidth }] = useResizeObserver(containerRef); + + const availableCameras = useMemo(() => { + if (!config) return []; + + const selectedCameraNames = cameraAreas.map((ca) => ca.camera); + return Object.entries(config.cameras) + .sort() + .filter( + ([name, cam]) => + cam.enabled && + cam.enabled_in_config && + !selectedCameraNames.includes(name), + ) + .map(([name]) => ({ + name, + displayName: resolveCameraName(config, name), + })); + }, [config, cameraAreas]); + + const selectedCamera = useMemo(() => { + if (cameraAreas.length === 0) return null; + return cameraAreas[selectedCameraIndex]; + }, [cameraAreas, selectedCameraIndex]); + + const selectedCameraConfig = useMemo(() => { + if (!config || !selectedCamera) return null; + return config.cameras[selectedCamera.camera]; + }, [config, selectedCamera]); + + const imageSize = useMemo(() => { + if (!containerWidth || !selectedCameraConfig) { + return { width: 0, height: 0 }; + } + + const containerAspectRatio = 16 / 9; + const containerHeight = containerWidth / containerAspectRatio; + + const cameraAspectRatio = + selectedCameraConfig.detect.width / selectedCameraConfig.detect.height; + + // Fit camera within 16:9 container + let imageWidth, imageHeight; + if (cameraAspectRatio > containerAspectRatio) { + imageWidth = containerWidth; + imageHeight = imageWidth / cameraAspectRatio; + } else { + imageHeight = containerHeight; + imageWidth = imageHeight * cameraAspectRatio; + } 
+ + return { width: imageWidth, height: imageHeight }; + }, [containerWidth, selectedCameraConfig]); + + const handleAddCamera = useCallback( + (cameraName: string) => { + // Calculate a square crop in pixel space + const camera = config?.cameras[cameraName]; + if (!camera) return; + + const cameraAspect = camera.detect.width / camera.detect.height; + const cropSize = 0.3; + let x1, y1, x2, y2; + + if (cameraAspect >= 1) { + const pixelSize = cropSize * camera.detect.height; + const normalizedWidth = pixelSize / camera.detect.width; + x1 = (1 - normalizedWidth) / 2; + y1 = (1 - cropSize) / 2; + x2 = x1 + normalizedWidth; + y2 = y1 + cropSize; + } else { + const pixelSize = cropSize * camera.detect.width; + const normalizedHeight = pixelSize / camera.detect.height; + x1 = (1 - cropSize) / 2; + y1 = (1 - normalizedHeight) / 2; + x2 = x1 + cropSize; + y2 = y1 + normalizedHeight; + } + + const newArea: CameraAreaConfig = { + camera: cameraName, + crop: [x1, y1, x2, y2], + }; + setCameraAreas([...cameraAreas, newArea]); + setSelectedCameraIndex(cameraAreas.length); + setIsPopoverOpen(false); + }, + [cameraAreas, config], + ); + + const handleRemoveCamera = useCallback( + (index: number) => { + const newAreas = cameraAreas.filter((_, i) => i !== index); + setCameraAreas(newAreas); + if (selectedCameraIndex >= newAreas.length) { + setSelectedCameraIndex(Math.max(0, newAreas.length - 1)); + } + }, + [cameraAreas, selectedCameraIndex], + ); + + const handleCropChange = useCallback( + (crop: [number, number, number, number]) => { + const newAreas = [...cameraAreas]; + newAreas[selectedCameraIndex] = { + ...newAreas[selectedCameraIndex], + crop, + }; + setCameraAreas(newAreas); + }, + [cameraAreas, selectedCameraIndex], + ); + + useEffect(() => { + setImageLoaded(false); + }, [selectedCamera]); + + useEffect(() => { + const rect = rectRef.current; + const transformer = transformerRef.current; + + if ( + rect && + transformer && + selectedCamera && + imageSize.width > 0 && + imageLoaded + ) { + rect.scaleX(1); + rect.scaleY(1); + transformer.nodes([rect]); + transformer.getLayer()?.batchDraw(); + } + }, [selectedCamera, imageSize, imageLoaded]); + + const handleRectChange = useCallback(() => { + const rect = rectRef.current; + + if (rect && imageSize.width > 0) { + const actualWidth = rect.width() * rect.scaleX(); + const actualHeight = rect.height() * rect.scaleY(); + + // Average dimensions to maintain perfect square + const size = (actualWidth + actualHeight) / 2; + + rect.width(size); + rect.height(size); + rect.scaleX(1); + rect.scaleY(1); + + const x1 = rect.x() / imageSize.width; + const y1 = rect.y() / imageSize.height; + const x2 = (rect.x() + size) / imageSize.width; + const y2 = (rect.y() + size) / imageSize.height; + + handleCropChange([x1, y1, x2, y2]); + } + }, [imageSize, handleCropChange]); + + const handleContinue = useCallback(() => { + onNext({ cameraAreas }); + }, [cameraAreas, onNext]); + + const canContinue = cameraAreas.length > 0; + + return ( +
    +
    +
    +
    +

    {t("wizard.step2.cameras")}

    + {availableCameras.length > 0 ? ( + + + + + e.preventDefault()} + > +
    + + {t("wizard.step2.selectCamera")} + +
    + {availableCameras.map((cam) => ( + + ))} +
    +
    +
    +
    + ) : ( + + )} +
    + +
    + {cameraAreas.map((area, index) => { + const isSelected = index === selectedCameraIndex; + const displayName = resolveCameraName(config, area.camera); + + return ( +
    setSelectedCameraIndex(index)} + > + {displayName} + +
    + ); + })} +
    + + {cameraAreas.length === 0 && ( +
    + {t("wizard.step2.noCameras")} +
    + )} +
    + +
    +
    + {selectedCamera && selectedCameraConfig && imageSize.width > 0 ? ( +
    + {resolveCameraName(config, setImageLoaded(true)} + /> + + + { + const rect = rectRef.current; + if (!rect) return pos; + + const size = rect.width(); + const x = Math.max( + 0, + Math.min(pos.x, imageSize.width - size), + ); + const y = Math.max( + 0, + Math.min(pos.y, imageSize.height - size), + ); + + return { x, y }; + }} + onDragEnd={handleRectChange} + onTransformEnd={handleRectChange} + /> + { + const minSize = 50; + const maxSize = Math.min( + imageSize.width, + imageSize.height, + ); + + // Clamp dimensions to stage bounds first + const clampedWidth = Math.max( + minSize, + Math.min(newBox.width, maxSize), + ); + const clampedHeight = Math.max( + minSize, + Math.min(newBox.height, maxSize), + ); + + // Enforce square using average + const size = (clampedWidth + clampedHeight) / 2; + + // Clamp position to keep square within bounds + const x = Math.max( + 0, + Math.min(newBox.x, imageSize.width - size), + ); + const y = Math.max( + 0, + Math.min(newBox.y, imageSize.height - size), + ); + + return { + ...newBox, + x, + y, + width: size, + height: size, + }; + }} + /> + + +
    + ) : ( +
    + {t("wizard.step2.selectCameraPrompt")} +
    + )} +
    +
    +
    + +
    + + +
    +
    + ); +} diff --git a/web/src/components/classification/wizard/Step3ChooseExamples.tsx b/web/src/components/classification/wizard/Step3ChooseExamples.tsx new file mode 100644 index 000000000..06bb2bbad --- /dev/null +++ b/web/src/components/classification/wizard/Step3ChooseExamples.tsx @@ -0,0 +1,444 @@ +import { Button } from "@/components/ui/button"; +import { useTranslation } from "react-i18next"; +import { useState, useEffect, useCallback, useMemo } from "react"; +import ActivityIndicator from "@/components/indicators/activity-indicator"; +import axios from "axios"; +import { toast } from "sonner"; +import { Step1FormData } from "./Step1NameAndDefine"; +import { Step2FormData } from "./Step2StateArea"; +import useSWR from "swr"; +import { baseUrl } from "@/api/baseUrl"; +import { isMobile } from "react-device-detect"; +import { cn } from "@/lib/utils"; + +export type Step3FormData = { + examplesGenerated: boolean; + imageClassifications?: { [imageName: string]: string }; +}; + +type Step3ChooseExamplesProps = { + step1Data: Step1FormData; + step2Data?: Step2FormData; + initialData?: Partial; + onClose: () => void; + onBack: () => void; +}; + +export default function Step3ChooseExamples({ + step1Data, + step2Data, + initialData, + onClose, + onBack, +}: Step3ChooseExamplesProps) { + const { t } = useTranslation(["views/classificationModel"]); + const [isGenerating, setIsGenerating] = useState(false); + const [hasGenerated, setHasGenerated] = useState( + initialData?.examplesGenerated || false, + ); + const [imageClassifications, setImageClassifications] = useState<{ + [imageName: string]: string; + }>(initialData?.imageClassifications || {}); + const [isTraining, setIsTraining] = useState(false); + const [isProcessing, setIsProcessing] = useState(false); + const [currentClassIndex, setCurrentClassIndex] = useState(0); + const [selectedImages, setSelectedImages] = useState>(new Set()); + + const { data: trainImages, mutate: refreshTrainImages } = useSWR( + hasGenerated ? `classification/${step1Data.modelName}/train` : null, + ); + + const unknownImages = useMemo(() => { + if (!trainImages) return []; + return trainImages; + }, [trainImages]); + + const toggleImageSelection = useCallback((imageName: string) => { + setSelectedImages((prev) => { + const newSet = new Set(prev); + if (newSet.has(imageName)) { + newSet.delete(imageName); + } else { + newSet.add(imageName); + } + return newSet; + }); + }, []); + + // Get all classes (excluding "none" - it will be auto-assigned) + const allClasses = useMemo(() => { + return [...step1Data.classes]; + }, [step1Data.classes]); + + const currentClass = allClasses[currentClassIndex]; + + const processClassificationsAndTrain = useCallback( + async (classifications: { [imageName: string]: string }) => { + // Step 1: Create config for the new model + const modelConfig: { + enabled: boolean; + name: string; + threshold: number; + state_config?: { + cameras: Record; + motion: boolean; + }; + object_config?: { objects: string[]; classification_type: string }; + } = { + enabled: true, + name: step1Data.modelName, + threshold: 0.8, + }; + + if (step1Data.modelType === "state") { + // State model config + const cameras: Record = {}; + step2Data?.cameraAreas.forEach((area) => { + cameras[area.camera] = { + crop: area.crop, + }; + }); + + modelConfig.state_config = { + cameras, + motion: true, + }; + } else { + // Object model config + modelConfig.object_config = { + objects: step1Data.objectLabel ? 
[step1Data.objectLabel] : [], + classification_type: step1Data.objectType || "sub_label", + } as { objects: string[]; classification_type: string }; + } + + // Update config via config API + await axios.put("/config/set", { + requires_restart: 0, + update_topic: `config/classification/custom/${step1Data.modelName}`, + config_data: { + classification: { + custom: { + [step1Data.modelName]: modelConfig, + }, + }, + }, + }); + + // Step 2: Classify each image by moving it to the correct category folder + const categorizePromises = Object.entries(classifications).map( + ([imageName, className]) => { + if (!className) return Promise.resolve(); + return axios.post( + `/classification/${step1Data.modelName}/dataset/categorize`, + { + training_file: imageName, + category: className === "none" ? "none" : className, + }, + ); + }, + ); + await Promise.all(categorizePromises); + + // Step 3: Kick off training + await axios.post(`/classification/${step1Data.modelName}/train`); + + toast.success(t("wizard.step3.trainingStarted")); + setIsTraining(true); + }, + [step1Data, step2Data, t], + ); + + const handleContinueClassification = useCallback(async () => { + // Mark selected images with current class + const newClassifications = { ...imageClassifications }; + selectedImages.forEach((imageName) => { + newClassifications[imageName] = currentClass; + }); + + // Check if we're on the last class to select + const isLastClass = currentClassIndex === allClasses.length - 1; + + if (isLastClass) { + // Assign remaining unclassified images + unknownImages.slice(0, 24).forEach((imageName) => { + if (!newClassifications[imageName]) { + // For state models with 2 classes, assign to the last class + // For object models, assign to "none" + if (step1Data.modelType === "state" && allClasses.length === 2) { + newClassifications[imageName] = allClasses[allClasses.length - 1]; + } else { + newClassifications[imageName] = "none"; + } + } + }); + + // All done, trigger training immediately + setImageClassifications(newClassifications); + setIsProcessing(true); + + try { + await processClassificationsAndTrain(newClassifications); + } catch (error) { + const axiosError = error as { + response?: { data?: { message?: string; detail?: string } }; + message?: string; + }; + const errorMessage = + axiosError.response?.data?.message || + axiosError.response?.data?.detail || + axiosError.message || + "Failed to classify images"; + + toast.error( + t("wizard.step3.errors.classifyFailed", { error: errorMessage }), + ); + setIsProcessing(false); + } + } else { + // Move to next class + setImageClassifications(newClassifications); + setCurrentClassIndex((prev) => prev + 1); + setSelectedImages(new Set()); + } + }, [ + selectedImages, + currentClass, + currentClassIndex, + allClasses, + imageClassifications, + unknownImages, + step1Data, + processClassificationsAndTrain, + t, + ]); + + const generateExamples = useCallback(async () => { + setIsGenerating(true); + + try { + if (step1Data.modelType === "state") { + // For state models, use cameras and crop areas + if (!step2Data?.cameraAreas || step2Data.cameraAreas.length === 0) { + toast.error(t("wizard.step3.errors.noCameras")); + setIsGenerating(false); + return; + } + + const cameras: { [key: string]: [number, number, number, number] } = {}; + step2Data.cameraAreas.forEach((area) => { + cameras[area.camera] = area.crop; + }); + + await axios.post("/classification/generate_examples/state", { + model_name: step1Data.modelName, + cameras, + }); + } else { + // For object models, use 
label + if (!step1Data.objectLabel) { + toast.error(t("wizard.step3.errors.noObjectLabel")); + setIsGenerating(false); + return; + } + + // For now, use all enabled cameras + // TODO: In the future, we might want to let users select specific cameras + await axios.post("/classification/generate_examples/object", { + model_name: step1Data.modelName, + label: step1Data.objectLabel, + }); + } + + setHasGenerated(true); + toast.success(t("wizard.step3.generateSuccess")); + + await refreshTrainImages(); + } catch (error) { + const axiosError = error as { + response?: { data?: { message?: string; detail?: string } }; + message?: string; + }; + const errorMessage = + axiosError.response?.data?.message || + axiosError.response?.data?.detail || + axiosError.message || + "Failed to generate examples"; + + toast.error( + t("wizard.step3.errors.generateFailed", { error: errorMessage }), + ); + } finally { + setIsGenerating(false); + } + }, [step1Data, step2Data, t, refreshTrainImages]); + + useEffect(() => { + if (!hasGenerated && !isGenerating) { + generateExamples(); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + const handleContinue = useCallback(async () => { + setIsProcessing(true); + try { + await processClassificationsAndTrain(imageClassifications); + } catch (error) { + const axiosError = error as { + response?: { data?: { message?: string; detail?: string } }; + message?: string; + }; + const errorMessage = + axiosError.response?.data?.message || + axiosError.response?.data?.detail || + axiosError.message || + "Failed to classify images"; + + toast.error( + t("wizard.step3.errors.classifyFailed", { error: errorMessage }), + ); + setIsProcessing(false); + } + }, [imageClassifications, processClassificationsAndTrain, t]); + + const unclassifiedImages = useMemo(() => { + if (!unknownImages) return []; + const images = unknownImages.slice(0, 24); + + // Only filter if we have any classifications + if (Object.keys(imageClassifications).length === 0) { + return images; + } + + return images.filter((img) => !imageClassifications[img]); + }, [unknownImages, imageClassifications]); + + const allImagesClassified = useMemo(() => { + return unclassifiedImages.length === 0; + }, [unclassifiedImages]); + + return ( +
    + {isTraining ? ( +
    + +
    +

    + {t("wizard.step3.training.title")} +

    +

    + {t("wizard.step3.training.description")} +

    +
    + +
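For reference, `processClassificationsAndTrain` above first persists the new model through `PUT /config/set`. A minimal sketch of the resulting request for a hypothetical state model (the `garage_door` name, camera, and crop box are illustrative; the `0.8` threshold and `motion: true` mirror the wizard defaults shown above):

```ts
import axios from "axios";

// Sketch of the config update the wizard issues for a hypothetical state model
// named "garage_door"; the camera name and crop coordinates are example values.
await axios.put("/config/set", {
  requires_restart: 0,
  update_topic: "config/classification/custom/garage_door",
  config_data: {
    classification: {
      custom: {
        garage_door: {
          enabled: true,
          name: "garage_door",
          threshold: 0.8,
          state_config: {
            cameras: {
              garage: { crop: [120, 80, 320, 280] },
            },
            motion: true,
          },
        },
      },
    },
  },
});

// For an object model the wizard sends object_config instead, e.g.
// object_config: { objects: ["person"], classification_type: "sub_label" }
```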
    + ) : isGenerating ? ( +
    + +
    +

    + {t("wizard.step3.generating.title")} +

    +

    + {t("wizard.step3.generating.description")} +

    +
    +
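The `generateExamples` callback above posts to one of two endpoints depending on the model type. A sketch of both request shapes, using hypothetical model names and illustrative camera/crop values:

```ts
import axios from "axios";

// State models: send the per-camera crop boxes chosen in step 2 along with
// the model name. The crop tuple mirrors Step2FormData's cameraAreas entries.
await axios.post("/classification/generate_examples/state", {
  model_name: "garage_door",
  cameras: {
    garage: [120, 80, 320, 280],
  },
});

// Object models: only the tracked object label is needed; the backend
// collects example crops for that label from all enabled cameras.
await axios.post("/classification/generate_examples/object", {
  model_name: "delivery_person",
  label: "person",
});
```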
    + ) : hasGenerated ? ( +
    + {!allImagesClassified && ( +
    +

    + {t("wizard.step3.selectImagesPrompt", { + className: currentClass, + })} +

    +

    + {t("wizard.step3.selectImagesDescription")} +

    +
    + )} +
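On the final class, `handleContinueClassification` auto-assigns whatever the user did not select rather than requiring a label for every image. The rule reduces to roughly the following restatement of the logic above (no new behavior):

```ts
// Images the user never selected fall back to the "other" bucket:
// for a binary state model that is the remaining state class,
// for object models it is the implicit "none" category.
function fallbackClass(
  modelType: "state" | "object",
  allClasses: string[],
): string {
  if (modelType === "state" && allClasses.length === 2) {
    return allClasses[allClasses.length - 1];
  }
  return "none";
}

// Applied to the first 24 generated images that are still unclassified.
function applyFallback(
  classifications: Record<string, string>,
  images: string[],
  modelType: "state" | "object",
  allClasses: string[],
): Record<string, string> {
  const result = { ...classifications };
  images.slice(0, 24).forEach((imageName) => {
    if (!result[imageName]) {
      result[imageName] = fallbackClass(modelType, allClasses);
    }
  });
  return result;
}
```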
    + {!unknownImages || unknownImages.length === 0 ? ( +
    +

    + {t("wizard.step3.noImages")} +

    + +
    + ) : allImagesClassified && isProcessing ? ( +
    + +

    + {t("wizard.step3.classifying")} +

    +
    + ) : ( +
+              {unclassifiedImages.map((imageName, index) => {
+                const isSelected = selectedImages.has(imageName);
+                return (
    toggleImageSelection(imageName)} + > + {`Example +
    + ); + })} +
    + )} +
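Once every image has a class, the wizard moves each file into its dataset category and then kicks off a training run. A sketch of the calls it issues, using placeholder file names and the hypothetical `garage_door` model:

```ts
import axios from "axios";

// Placeholder labels for two generated training images.
const classifications: Record<string, string> = {
  "example-1.png": "open",
  "example-2.png": "closed",
};

// One categorize request per labeled image, mirroring processClassificationsAndTrain.
await Promise.all(
  Object.entries(classifications).map(([trainingFile, category]) =>
    axios.post("/classification/garage_door/dataset/categorize", {
      training_file: trainingFile,
      category,
    }),
  ),
);

// Start training once everything has been categorized.
await axios.post("/classification/garage_door/train");
```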
    +
    + ) : ( +
    +

    + {t("wizard.step3.errors.generationFailed")} +

    + +
+      )}
+
+      {!isTraining && (
    + + +
    + )} +
    + ); +} diff --git a/web/src/views/classification/ModelSelectionView.tsx b/web/src/views/classification/ModelSelectionView.tsx index 6d6287b4d..db9c66194 100644 --- a/web/src/views/classification/ModelSelectionView.tsx +++ b/web/src/views/classification/ModelSelectionView.tsx @@ -10,11 +10,14 @@ import { CustomClassificationModelConfig, FrigateConfig, } from "@/types/frigateConfig"; -import { useMemo, useState } from "react"; +import { useEffect, useMemo, useState } from "react"; import { isMobile } from "react-device-detect"; import { useTranslation } from "react-i18next"; import { FaFolderPlus } from "react-icons/fa"; +import { MdModelTraining } from "react-icons/md"; import useSWR from "swr"; +import Heading from "@/components/ui/heading"; +import { useOverlayState } from "@/hooks/use-overlay-state"; const allModelTypes = ["objects", "states"] as const; type ModelType = (typeof allModelTypes)[number]; @@ -26,11 +29,24 @@ export default function ModelSelectionView({ onClick, }: ModelSelectionViewProps) { const { t } = useTranslation(["views/classificationModel"]); - const [page, setPage] = useState("objects"); - const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100); - const { data: config } = useSWR("config", { - revalidateOnFocus: false, - }); + const [page, setPage] = useOverlayState("objects", "objects"); + const [pageToggle, setPageToggle] = useOptimisticState( + page || "objects", + setPage, + 100, + ); + const { data: config, mutate: refreshConfig } = useSWR( + "config", + { + revalidateOnFocus: false, + }, + ); + + // title + + useEffect(() => { + document.title = t("documentTitle"); + }, [t]); // data @@ -64,15 +80,15 @@ export default function ModelSelectionView({ return ; } - if (classificationConfigs.length == 0) { - return
    You need to setup a custom model configuration.
    ; - } - return (
    setNewModel(false)} + defaultModelType={pageToggle === "objects" ? "object" : "state"} + onClose={() => { + setNewModel(false); + refreshConfig(); + }} />
    @@ -84,7 +100,6 @@ export default function ModelSelectionView({ value={pageToggle} onValueChange={(value: ModelType) => { if (value) { - // Restrict viewer navigation setPageToggle(value); } }} @@ -117,13 +132,46 @@ export default function ModelSelectionView({
    - {selectedClassificationConfigs.map((config) => ( - onClick(config)} + {selectedClassificationConfigs.length === 0 ? ( + setNewModel(true)} + modelType={pageToggle} /> - ))} + ) : ( + selectedClassificationConfigs.map((config) => ( + onClick(config)} + /> + )) + )} +
+
+  );
+}
+
+function NoModelsView({
+  onCreateModel,
+  modelType,
+}: {
+  onCreateModel: () => void;
+  modelType: ModelType;
+}) {
+  const { t } = useTranslation(["views/classificationModel"]);
+  const typeKey = modelType === "objects" ? "object" : "state";
+
+  return (
    +
    + + {t(`noModels.${typeKey}.title`)} +
    + {t(`noModels.${typeKey}.description`)} +
    +
    ); @@ -139,13 +187,17 @@ function ModelCard({ config, onClick }: ModelCardProps) { }>(`classification/${config.name}/dataset`, { revalidateOnFocus: false }); const coverImage = useMemo(() => { - if (!dataset?.length) { + if (!dataset) { return undefined; } const keys = Object.keys(dataset).filter((key) => key != "none"); const selectedKey = keys[0]; + if (!dataset[selectedKey]) { + return undefined; + } + return { name: selectedKey, img: dataset[selectedKey][0], diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx index 6f9e479c2..f313d6829 100644 --- a/web/src/views/classification/ModelTrainingView.tsx +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -642,6 +642,7 @@ function DatasetGrid({ filepath: `clips/${modelName}/dataset/${categoryName}/${image}`, name: "", }} + showArea={false} selected={selectedImages.includes(image)} i18nLibrary="views/classificationModel" onClick={(data, _) => onClickImages([data.filename], true)}
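The `ModelCard` change above guards against a dataset whose first non-`none` category is missing before deriving a cover image. Pulled out as a standalone helper (the dataset shape is assumed from the surrounding code, and the empty-list check is an extra precaution), the logic is roughly:

```ts
type Dataset = Record<string, string[]>;

// Derive a cover image from the first non-"none" category, or nothing at all
// when the dataset is missing, empty, or only contains "none".
function deriveCoverImage(
  dataset: Dataset | undefined,
): { name: string; img: string } | undefined {
  if (!dataset) {
    return undefined;
  }

  const keys = Object.keys(dataset).filter((key) => key !== "none");
  const selectedKey = keys[0];

  if (!selectedKey || !dataset[selectedKey] || dataset[selectedKey].length === 0) {
    return undefined;
  }

  return { name: selectedKey, img: dataset[selectedKey][0] };
}
```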