Implement Wizard for Creating Classification Models (#20622)
* Implement extraction of images for classification state models
* Add object classification dataset preparation
* Add first step wizard
* Update i18n
* Add state classification image selection step
* Improve box handling
* Add object selector
* Improve object cropping implementation
* Fix state classification selection
* Finalize training and image selection step
* Cleanup
* Design optimizations
* Cleanup mobile styling
* Update no models screen
* Cleanups and fixes
* Fix bugs
* Improve model training and creation process
* Cleanup
* Dynamically add metrics for new model
* Add loading when hitting continue
* Improve image selection mechanism
* Remove unused translation keys
* Adjust wording
* Add retry button for image generation
* Make no models view more specific
* Adjust plus icon
* Adjust form label
* Start with correct type selected
* Cleanup sizing and more font colors
* Small tweaks
* Add tips and more info
* Cleanup dialog sizing
* Add cursor rule for frontend
* Cleanup
* remove underline
* Lazy loading
This commit is contained in:
parent 4df7793587
commit f5a57edcc9

498 .cursor/rules/frontend-always-use-translation-files.mdc (new file, 6 lines)
@@ -0,0 +1,6 @@
---
globs: ["**/*.ts", "**/*.tsx"]
alwaysApply: false
---

Never write strings in the frontend directly; always write to and reference the relevant translations file.
@@ -12,7 +12,18 @@ Object classification models are lightweight and run very fast on CPU. Inference

Training the model does briefly use a high amount of system resources for about 1–3 minutes per training run. On lower-power devices, training may take longer.

When running the `-tensorrt` image, Nvidia GPUs will automatically be used to accelerate training.

### Sub label vs Attribute

## Classes

Classes are the categories your model will learn to distinguish between. Each class represents a distinct visual category that the model will predict.

For object classification:

- Define classes that represent different types or attributes of the detected object
- Examples: For `person` objects, classes might be `delivery_person`, `resident`, `stranger`
- Include a `none` class for objects that don't fit any specific category
- Keep classes visually distinct to improve accuracy

### Classification Type

- **Sub label**:

@@ -12,6 +12,17 @@ State classification models are lightweight and run very fast on CPU. Inference

Training the model does briefly use a high amount of system resources for about 1–3 minutes per training run. On lower-power devices, training may take longer.

When running the `-tensorrt` image, Nvidia GPUs will automatically be used to accelerate training.

## Classes

Classes are the different states an area on your camera can be in. Each class represents a distinct visual state that the model will learn to recognize.

For state classification:

- Define classes that represent mutually exclusive states
- Examples: `open` and `closed` for a garage door, `on` and `off` for lights
- Use at least 2 classes (typically binary states work best)
- Keep class names clear and descriptive

## Example use cases

- **Door state**: Detect if a garage or front door is open vs closed.
@@ -387,20 +387,28 @@ def config_set(request: Request, body: AppConfigSetBody):

    old_config: FrigateConfig = request.app.frigate_config
    request.app.frigate_config = config

    if body.update_topic and body.update_topic.startswith("config/cameras/"):
        _, _, camera, field = body.update_topic.split("/")
    if body.update_topic:
        if body.update_topic.startswith("config/cameras/"):
            _, _, camera, field = body.update_topic.split("/")

        if field == "add":
            settings = config.cameras[camera]
        elif field == "remove":
            settings = old_config.cameras[camera]
            if field == "add":
                settings = config.cameras[camera]
            elif field == "remove":
                settings = old_config.cameras[camera]
        else:
            settings = config.get_nested_object(body.update_topic)

        request.app.config_publisher.publish_update(
            CameraConfigUpdateTopic(CameraConfigUpdateEnum[field], camera),
            settings,
        )
        else:
            # Handle nested config updates (e.g., config/classification/custom/{name})
            settings = config.get_nested_object(body.update_topic)

            request.app.config_publisher.publish_update(
                CameraConfigUpdateTopic(CameraConfigUpdateEnum[field], camera),
                settings,
            )
            if settings:
                request.app.config_publisher.publisher.publish(
                    body.update_topic, settings
                )

    return JSONResponse(
        content=(
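The routing above splits on the update topic's shape: four-segment camera topics are parsed into a camera and a field, anything else (such as the new per-model classification topics) is resolved as a nested config object. A minimal standalone sketch of that topic parsing (hypothetical helper, not part of the diff):

```python
# Minimal sketch of the update-topic routing shown above.
def parse_update_topic(topic: str) -> tuple[str, str] | None:
    """Return (camera, field) for camera topics, else None."""
    if topic.startswith("config/cameras/"):
        # e.g. "config/cameras/front_door/add" -> ("front_door", "add")
        _, _, camera, field = topic.split("/")
        return camera, field
    # anything else (e.g. "config/classification/custom/my_model")
    # is resolved as a nested config object instead
    return None

assert parse_update_topic("config/cameras/front_door/add") == ("front_door", "add")
assert parse_update_topic("config/classification/custom/my_model") is None
```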
@@ -3,7 +3,9 @@

import datetime
import logging
import os
import random
import shutil
import string
from typing import Any

import cv2

@@ -17,6 +19,8 @@ from frigate.api.auth import require_role

from frigate.api.defs.request.classification_body import (
    AudioTranscriptionBody,
    DeleteFaceImagesBody,
    GenerateObjectExamplesBody,
    GenerateStateExamplesBody,
    RenameFaceBody,
)
from frigate.api.defs.response.classification_response import (

@@ -30,6 +34,10 @@ from frigate.config.camera import DetectConfig

from frigate.const import CLIPS_DIR, FACE_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.models import Event
from frigate.util.classification import (
    collect_object_classification_examples,
    collect_state_classification_examples,
)
from frigate.util.path import get_event_snapshot

logger = logging.getLogger(__name__)

@@ -159,8 +167,7 @@ def train_face(request: Request, name: str, body: dict = None):

    new_name = f"{sanitized_name}-{datetime.datetime.now().timestamp()}.webp"
    new_file_folder = os.path.join(FACE_DIR, f"{sanitized_name}")

    if not os.path.exists(new_file_folder):
        os.mkdir(new_file_folder)
    os.makedirs(new_file_folder, exist_ok=True)

    if training_file_name:
        shutil.move(training_file, os.path.join(new_file_folder, new_name))
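The `os.makedirs(..., exist_ok=True)` swap above removes a check-then-create race: two concurrent requests could both pass the `os.path.exists` check, and the loser's `os.mkdir` would then raise `FileExistsError`. A minimal sketch of the difference:

```python
import os
import tempfile

folder = os.path.join(tempfile.gettempdir(), "faces", "jane")

# Old pattern: not atomic; a second caller can lose the race and crash.
if not os.path.exists(folder):
    os.mkdir(folder)  # FileExistsError if created in between; also fails
                      # if the parent "faces" directory doesn't exist yet

# New pattern: idempotent, and creates intermediate directories too.
os.makedirs(folder, exist_ok=True)
```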
@@ -701,13 +708,14 @@ def categorize_classification_image(request: Request, name: str, body: dict = None):

            status_code=404,
        )

    new_name = f"{category}-{datetime.datetime.now().timestamp()}.png"
    random_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6))
    timestamp = datetime.datetime.now().timestamp()
    new_name = f"{category}-{timestamp}-{random_id}.png"
    new_file_folder = os.path.join(
        CLIPS_DIR, sanitize_filename(name), "dataset", category
    )

    if not os.path.exists(new_file_folder):
        os.mkdir(new_file_folder)
    os.makedirs(new_file_folder, exist_ok=True)

    # use opencv because webp images can not be used to train
    img = cv2.imread(training_file)
@@ -756,3 +764,43 @@ def delete_classification_train_images(request: Request, name: str, body: dict =

        content=({"success": True, "message": "Successfully deleted faces."}),
        status_code=200,
    )


@router.post(
    "/classification/generate_examples/state",
    response_model=GenericResponse,
    dependencies=[Depends(require_role(["admin"]))],
    summary="Generate state classification examples",
)
async def generate_state_examples(request: Request, body: GenerateStateExamplesBody):
    """Generate examples for state classification."""
    model_name = sanitize_filename(body.model_name)
    cameras_normalized = {
        camera_name: tuple(crop)
        for camera_name, crop in body.cameras.items()
        if camera_name in request.app.frigate_config.cameras
    }

    collect_state_classification_examples(model_name, cameras_normalized)

    return JSONResponse(
        content={"success": True, "message": "Example generation completed"},
        status_code=200,
    )


@router.post(
    "/classification/generate_examples/object",
    response_model=GenericResponse,
    dependencies=[Depends(require_role(["admin"]))],
    summary="Generate object classification examples",
)
async def generate_object_examples(request: Request, body: GenerateObjectExamplesBody):
    """Generate examples for object classification."""
    model_name = sanitize_filename(body.model_name)
    collect_object_classification_examples(model_name, body.label)

    return JSONResponse(
        content={"success": True, "message": "Example generation completed"},
        status_code=200,
    )
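A hypothetical client-side call to the new state endpoint, for orientation only (it assumes Frigate's API is reachable at `localhost:5000` under the usual `/api` prefix and that the session carries the admin role the route requires):

```python
import requests

resp = requests.post(
    "http://localhost:5000/api/classification/generate_examples/state",
    json={
        "model_name": "garage_door",
        # normalized [x1, y1, x2, y2] crop per camera, values 0-1
        "cameras": {"garage_cam": [0.25, 0.1, 0.75, 0.6]},
    },
)
print(resp.json())  # {"success": true, "message": "Example generation completed"}
```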
@@ -1,17 +1,31 @@

from typing import List
from typing import Dict, List, Tuple

from pydantic import BaseModel, Field


class RenameFaceBody(BaseModel):
    new_name: str
    new_name: str = Field(description="New name for the face")


class AudioTranscriptionBody(BaseModel):
    event_id: str
    event_id: str = Field(description="ID of the event to transcribe audio for")


class DeleteFaceImagesBody(BaseModel):
    ids: List[str] = Field(
        description="List of image filenames to delete from the face folder"
    )


class GenerateStateExamplesBody(BaseModel):
    model_name: str = Field(description="Name of the classification model")
    cameras: Dict[str, Tuple[float, float, float, float]] = Field(
        description="Dictionary mapping camera names to normalized crop coordinates in [x1, y1, x2, y2] format (values 0-1)"
    )


class GenerateObjectExamplesBody(BaseModel):
    model_name: str = Field(description="Name of the classification model")
    label: str = Field(
        description="Object label to collect examples for (e.g., 'person', 'car')"
    )
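A quick validation sketch for the new request bodies; it mirrors the model above as a standalone snippet rather than importing it from frigate. The fixed-arity `Tuple` means pydantic rejects crops that aren't exactly four floats:

```python
from typing import Dict, Tuple

from pydantic import BaseModel, Field, ValidationError


class GenerateStateExamplesBody(BaseModel):
    model_name: str = Field(description="Name of the classification model")
    cameras: Dict[str, Tuple[float, float, float, float]] = Field(
        description="Camera name -> normalized [x1, y1, x2, y2] crop (0-1)"
    )


body = GenerateStateExamplesBody(
    model_name="garage_door",
    cameras={"garage_cam": (0.25, 0.1, 0.75, 0.6)},
)
print(body.cameras["garage_cam"])  # (0.25, 0.1, 0.75, 0.6)

try:
    GenerateStateExamplesBody(model_name="x", cameras={"cam": (0.1, 0.2)})
except ValidationError:
    print("rejected: crop must be exactly 4 floats")
```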
@@ -53,9 +53,17 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):

        self.tensor_output_details: dict[str, Any] | None = None
        self.labelmap: dict[int, str] = {}
        self.classifications_per_second = EventsPerSecond()
        self.inference_speed = InferenceSpeed(
            self.metrics.classification_speeds[self.model_config.name]
        )

        if (
            self.metrics
            and self.model_config.name in self.metrics.classification_speeds
        ):
            self.inference_speed = InferenceSpeed(
                self.metrics.classification_speeds[self.model_config.name]
            )
        else:
            self.inference_speed = None

        self.last_run = datetime.datetime.now().timestamp()
        self.__build_detector()

@@ -83,12 +91,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):

    def __update_metrics(self, duration: float) -> None:
        self.classifications_per_second.update()
        self.inference_speed.update(duration)
        if self.inference_speed:
            self.inference_speed.update(duration)

    def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
        self.metrics.classification_cps[
            self.model_config.name
        ].value = self.classifications_per_second.eps()
        if self.metrics and self.model_config.name in self.metrics.classification_cps:
            self.metrics.classification_cps[
                self.model_config.name
            ].value = self.classifications_per_second.eps()
        camera = frame_data.get("camera")

        if camera not in self.model_config.state_config.cameras:

@@ -223,9 +233,17 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):

        self.detected_objects: dict[str, float] = {}
        self.labelmap: dict[int, str] = {}
        self.classifications_per_second = EventsPerSecond()
        self.inference_speed = InferenceSpeed(
            self.metrics.classification_speeds[self.model_config.name]
        )

        if (
            self.metrics
            and self.model_config.name in self.metrics.classification_speeds
        ):
            self.inference_speed = InferenceSpeed(
                self.metrics.classification_speeds[self.model_config.name]
            )
        else:
            self.inference_speed = None

        self.__build_detector()

    @redirect_output_to_logger(logger, logging.DEBUG)

@@ -251,12 +269,14 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):

    def __update_metrics(self, duration: float) -> None:
        self.classifications_per_second.update()
        self.inference_speed.update(duration)
        if self.inference_speed:
            self.inference_speed.update(duration)

    def process_frame(self, obj_data, frame):
        self.metrics.classification_cps[
            self.model_config.name
        ].value = self.classifications_per_second.eps()
        if self.metrics and self.model_config.name in self.metrics.classification_cps:
            self.metrics.classification_cps[
                self.model_config.name
            ].value = self.classifications_per_second.eps()

        if obj_data["false_positive"]:
            return
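The guards added in both processors follow the same shape: a model created at runtime through the wizard may not yet have metric slots registered, so every metrics touch is wrapped in a presence check. A distilled, standalone sketch of that pattern (names here are illustrative, not Frigate's):

```python
class Speedometer:
    """Stand-in for InferenceSpeed: tracks a rolling average duration."""

    def __init__(self) -> None:
        self.avg = 0.0

    def update(self, duration: float) -> None:
        self.avg = 0.9 * self.avg + 0.1 * duration


class Processor:
    def __init__(self, name: str, speeds: dict[str, Speedometer] | None) -> None:
        # Only bind a speedometer if a slot was pre-registered for this
        # model; wizard-created models may not have one yet.
        if speeds and name in speeds:
            self.inference_speed = speeds[name]
        else:
            self.inference_speed = None

    def record(self, duration: float) -> None:
        if self.inference_speed:  # silently skip when metrics are absent
            self.inference_speed.update(duration)


p = Processor("new_model", speeds={})  # no slot registered yet
p.record(0.012)  # no AttributeError; metrics are simply skipped
```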
@@ -9,6 +9,7 @@ from typing import Any

from peewee import DoesNotExist

from frigate.comms.config_updater import ConfigSubscriber
from frigate.comms.detections_updater import DetectionSubscriber, DetectionTypeEnum
from frigate.comms.embeddings_updater import (
    EmbeddingsRequestEnum,

@@ -95,6 +96,9 @@ class EmbeddingMaintainer(threading.Thread):

                CameraConfigUpdateEnum.semantic_search,
            ],
        )
        self.classification_config_subscriber = ConfigSubscriber(
            "config/classification/custom/"
        )

        # Configure Frigate DB
        db = SqliteVecQueueDatabase(

@@ -255,6 +259,7 @@ class EmbeddingMaintainer(threading.Thread):

        """Maintain a SQLite-vec database for semantic search."""
        while not self.stop_event.is_set():
            self.config_updater.check_for_updates()
            self._check_classification_config_updates()
            self._process_requests()
            self._process_updates()
            self._process_recordings_updates()

@@ -265,6 +270,7 @@ class EmbeddingMaintainer(threading.Thread):

            self._process_event_metadata()

        self.config_updater.stop()
        self.classification_config_subscriber.stop()
        self.event_subscriber.stop()
        self.event_end_subscriber.stop()
        self.recordings_subscriber.stop()

@@ -275,6 +281,46 @@ class EmbeddingMaintainer(threading.Thread):

        self.requestor.stop()
        logger.info("Exiting embeddings maintenance...")

    def _check_classification_config_updates(self) -> None:
        """Check for classification config updates and add new processors."""
        topic, model_config = self.classification_config_subscriber.check_for_update()

        if topic and model_config:
            model_name = topic.split("/")[-1]
            self.config.classification.custom[model_name] = model_config

            # Check if processor already exists
            for processor in self.realtime_processors:
                if isinstance(
                    processor,
                    (
                        CustomStateClassificationProcessor,
                        CustomObjectClassificationProcessor,
                    ),
                ):
                    if processor.model_config.name == model_name:
                        logger.debug(
                            f"Classification processor for model {model_name} already exists, skipping"
                        )
                        return

            if model_config.state_config is not None:
                processor = CustomStateClassificationProcessor(
                    self.config, model_config, self.requestor, self.metrics
                )
            else:
                processor = CustomObjectClassificationProcessor(
                    self.config,
                    model_config,
                    self.event_metadata_publisher,
                    self.metrics,
                )

            self.realtime_processors.append(processor)
            logger.info(
                f"Added classification processor for model: {model_name} (type: {type(processor).__name__})"
            )

    def _process_requests(self) -> None:
        """Process embeddings requests"""
@@ -2,12 +2,15 @@

import logging
import os
import random
from collections import defaultdict

import cv2
import numpy as np

from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsRequestor
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FfmpegConfig
from frigate.const import (
    CLIPS_DIR,
    MODEL_CACHE_DIR,

@@ -15,7 +18,10 @@ from frigate.const import (

    UPDATE_MODEL_STATE,
)
from frigate.log import redirect_output_to_logger
from frigate.models import Event, Recordings, ReviewSegment
from frigate.types import ModelStatusTypesEnum
from frigate.util.image import get_image_from_recording
from frigate.util.path import get_event_thumbnail_bytes
from frigate.util.process import FrigateProcess

BATCH_SIZE = 16

@@ -69,6 +75,7 @@ class ClassificationTrainingProcess(FrigateProcess):

        logger.info(f"Kicking off classification training for {self.model_name}.")
        dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
        model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
        os.makedirs(model_dir, exist_ok=True)
        num_classes = len(
            [
                d

@@ -139,7 +146,6 @@ class ClassificationTrainingProcess(FrigateProcess):

            f.write(tflite_model)


@staticmethod
def kickoff_model_training(
    embeddingRequestor: EmbeddingsRequestor, model_name: str
) -> None:

@@ -172,3 +178,520 @@ def kickoff_model_training(

        },
    )
    requestor.stop()
@staticmethod
def collect_state_classification_examples(
    model_name: str, cameras: dict[str, tuple[float, float, float, float]]
) -> None:
    """
    Collect representative state classification examples from review items.

    This function:
    1. Queries review items from specified cameras
    2. Selects 100 balanced timestamps across the data
    3. Extracts keyframes from recordings (cropped to specified regions)
    4. Selects 24 most visually distinct images
    5. Saves them to the dataset directory

    Args:
        model_name: Name of the classification model
        cameras: Dict mapping camera names to normalized crop coordinates [x1, y1, x2, y2] (0-1)
    """
    dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
    temp_dir = os.path.join(dataset_dir, "temp")
    os.makedirs(temp_dir, exist_ok=True)

    # Step 1: Get review items for the cameras
    camera_names = list(cameras.keys())
    review_items = list(
        ReviewSegment.select()
        .where(ReviewSegment.camera.in_(camera_names))
        .where(ReviewSegment.end_time.is_null(False))
        .order_by(ReviewSegment.start_time.asc())
    )

    if not review_items:
        logger.warning(f"No review items found for cameras: {camera_names}")
        return

    # Step 2: Create balanced timestamp selection (100 samples)
    timestamps = _select_balanced_timestamps(review_items, target_count=100)

    # Step 3: Extract keyframes from recordings with crops applied
    keyframes = _extract_keyframes(
        "/usr/lib/ffmpeg/7.0/bin/ffmpeg", timestamps, temp_dir, cameras
    )

    # Step 4: Select 24 most visually distinct images (they're already cropped)
    distinct_images = _select_distinct_images(keyframes, target_count=24)

    # Step 5: Save to train directory for later classification
    train_dir = os.path.join(CLIPS_DIR, model_name, "train")
    os.makedirs(train_dir, exist_ok=True)

    saved_count = 0
    for idx, image_path in enumerate(distinct_images):
        dest_path = os.path.join(train_dir, f"example_{idx:03d}.jpg")
        try:
            img = cv2.imread(image_path)

            if img is not None:
                cv2.imwrite(dest_path, img)
                saved_count += 1
        except Exception as e:
            logger.error(f"Failed to save image {image_path}: {e}")

    import shutil

    try:
        shutil.rmtree(temp_dir)
    except Exception as e:
        logger.warning(f"Failed to clean up temp directory: {e}")
def _select_balanced_timestamps(
    review_items: list[ReviewSegment], target_count: int = 100
) -> list[dict]:
    """
    Select balanced timestamps from review items.

    Strategy:
    - Group review items by camera and time of day
    - Sample evenly across groups to ensure diversity
    - For each selected review item, pick a random timestamp within its duration

    Returns:
        List of dicts with keys: camera, timestamp, review_item
    """
    # Group by camera and hour of day for temporal diversity
    grouped = defaultdict(list)

    for item in review_items:
        camera = item.camera
        # Group by 6-hour blocks for temporal diversity
        hour_block = int(item.start_time // (6 * 3600))
        key = f"{camera}_{hour_block}"
        grouped[key].append(item)

    # Calculate how many samples per group
    num_groups = len(grouped)
    if num_groups == 0:
        return []

    samples_per_group = max(1, target_count // num_groups)
    timestamps = []

    # Sample from each group
    for group_items in grouped.values():
        # Take samples_per_group items from this group
        sample_size = min(samples_per_group, len(group_items))
        sampled_items = random.sample(group_items, sample_size)

        for item in sampled_items:
            # Pick a random timestamp within the review item's duration
            duration = item.end_time - item.start_time
            if duration <= 0:
                continue

            # Sample from middle 80% to avoid edge artifacts
            offset = random.uniform(duration * 0.1, duration * 0.9)
            timestamp = item.start_time + offset

            timestamps.append(
                {
                    "camera": item.camera,
                    "timestamp": timestamp,
                    "review_item": item,
                }
            )

    # If we don't have enough, sample more from larger groups
    while len(timestamps) < target_count and len(timestamps) < len(review_items):
        for group_items in grouped.values():
            if len(timestamps) >= target_count:
                break

            # Pick a random item not already sampled
            item = random.choice(group_items)
            duration = item.end_time - item.start_time
            if duration <= 0:
                continue

            offset = random.uniform(duration * 0.1, duration * 0.9)
            timestamp = item.start_time + offset

            # Check if we already have a timestamp near this one
            if not any(abs(t["timestamp"] - timestamp) < 1.0 for t in timestamps):
                timestamps.append(
                    {
                        "camera": item.camera,
                        "timestamp": timestamp,
                        "review_item": item,
                    }
                )

    return timestamps[:target_count]
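The per-group quota above is plain integer division, so group counts matter: with `target_count=100` and, say, 7 camera/time-block groups, `samples_per_group = max(1, 100 // 7) = 14`, giving at most 98 first-pass samples; the top-up loop then fills the remainder. A toy run of the grouping key (timestamps here are made up for the demo):

```python
from collections import defaultdict

events = [
    ("front", 1_699_985_000),   # some epoch seconds
    ("front", 1_699_995_000),   # < 6h later, same block as above
    ("front", 1_700_015_000),   # crosses a 6h boundary -> next block
    ("garage", 1_699_985_000),
]

grouped = defaultdict(list)
for camera, start_time in events:
    hour_block = int(start_time // (6 * 3600))  # 6-hour buckets
    grouped[f"{camera}_{hour_block}"].append(start_time)

print({k: len(v) for k, v in grouped.items()})
# {'front_78703': 2, 'front_78704': 1, 'garage_78703': 1}
```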
def _extract_keyframes(
    ffmpeg_path: str,
    timestamps: list[dict],
    output_dir: str,
    camera_crops: dict[str, tuple[float, float, float, float]],
) -> list[str]:
    """
    Extract keyframes from recordings at specified timestamps and crop to specified regions.

    Args:
        ffmpeg_path: Path to ffmpeg binary
        timestamps: List of timestamp dicts from _select_balanced_timestamps
        output_dir: Directory to save extracted frames
        camera_crops: Dict mapping camera names to normalized crop coordinates [x1, y1, x2, y2] (0-1)

    Returns:
        List of paths to successfully extracted and cropped keyframe images
    """
    keyframe_paths = []

    for idx, ts_info in enumerate(timestamps):
        camera = ts_info["camera"]
        timestamp = ts_info["timestamp"]

        if camera not in camera_crops:
            logger.warning(f"No crop coordinates for camera {camera}")
            continue

        norm_x1, norm_y1, norm_x2, norm_y2 = camera_crops[camera]

        try:
            recording = (
                Recordings.select()
                .where(
                    (timestamp >= Recordings.start_time)
                    & (timestamp <= Recordings.end_time)
                    & (Recordings.camera == camera)
                )
                .order_by(Recordings.start_time.desc())
                .limit(1)
                .get()
            )
        except Exception:
            continue

        relative_time = timestamp - recording.start_time

        try:
            config = FfmpegConfig(path="/usr/lib/ffmpeg/7.0")
            image_data = get_image_from_recording(
                config,
                recording.path,
                relative_time,
                codec="mjpeg",
                height=None,
            )

            if image_data:
                nparr = np.frombuffer(image_data, np.uint8)
                img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

                if img is not None:
                    height, width = img.shape[:2]

                    x1 = int(norm_x1 * width)
                    y1 = int(norm_y1 * height)
                    x2 = int(norm_x2 * width)
                    y2 = int(norm_y2 * height)

                    x1_clipped = max(0, min(x1, width))
                    y1_clipped = max(0, min(y1, height))
                    x2_clipped = max(0, min(x2, width))
                    y2_clipped = max(0, min(y2, height))

                    if x2_clipped > x1_clipped and y2_clipped > y1_clipped:
                        cropped = img[y1_clipped:y2_clipped, x1_clipped:x2_clipped]
                        resized = cv2.resize(cropped, (224, 224))

                        output_path = os.path.join(output_dir, f"frame_{idx:04d}.jpg")
                        cv2.imwrite(output_path, resized)
                        keyframe_paths.append(output_path)

        except Exception as e:
            logger.debug(
                f"Failed to extract frame from {recording.path} at {relative_time}s: {e}"
            )
            continue

    return keyframe_paths
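The normalized crop coordinates carried through from the request body are only resolved to pixels at extraction time, so the same model definition works across cameras with different resolutions. A worked example of the conversion and clipping above:

```python
# Normalized -> pixel crop for a 1920x1080 frame and crop (0.25, 0.1, 0.75, 0.6).
width, height = 1920, 1080
norm_x1, norm_y1, norm_x2, norm_y2 = 0.25, 0.1, 0.75, 0.6

x1, y1 = int(norm_x1 * width), int(norm_y1 * height)  # (480, 108)
x2, y2 = int(norm_x2 * width), int(norm_y2 * height)  # (1440, 648)

# Clipping keeps the crop inside the frame even for out-of-range inputs.
x1, x2 = max(0, min(x1, width)), max(0, min(x2, width))
y1, y2 = max(0, min(y1, height)), max(0, min(y2, height))

print((x1, y1, x2, y2))  # (480, 108, 1440, 648) -> a 960x540 region
```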
def _select_distinct_images(
    image_paths: list[str], target_count: int = 20
) -> list[str]:
    """
    Select the most visually distinct images from a set of keyframes.

    Uses a greedy algorithm based on image histograms:
    1. Start with a random image
    2. Iteratively add the image that is most different from already selected images
    3. Difference is measured using histogram comparison

    Args:
        image_paths: List of paths to candidate images
        target_count: Number of distinct images to select

    Returns:
        List of paths to selected images
    """
    if len(image_paths) <= target_count:
        return image_paths

    histograms = {}
    valid_paths = []

    for path in image_paths:
        try:
            img = cv2.imread(path)

            if img is None:
                continue

            hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
            hist = cv2.calcHist(
                [hsv], [0, 1, 2], None, [8, 8, 8], [0, 180, 0, 256, 0, 256]
            )
            hist = cv2.normalize(hist, hist).flatten()
            histograms[path] = hist
            valid_paths.append(path)
        except Exception as e:
            logger.debug(f"Failed to process image {path}: {e}")
            continue

    if len(valid_paths) <= target_count:
        return valid_paths

    selected = []
    first_image = random.choice(valid_paths)
    selected.append(first_image)
    remaining = [p for p in valid_paths if p != first_image]

    while len(selected) < target_count and remaining:
        max_min_distance = -1
        best_candidate = None

        for candidate in remaining:
            min_distance = float("inf")

            for selected_img in selected:
                distance = cv2.compareHist(
                    histograms[candidate],
                    histograms[selected_img],
                    cv2.HISTCMP_BHATTACHARYYA,
                )
                min_distance = min(min_distance, distance)

            if min_distance > max_min_distance:
                max_min_distance = min_distance
                best_candidate = candidate

        if best_candidate:
            selected.append(best_candidate)
            remaining.remove(best_candidate)
        else:
            break

    return selected
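The greedy loop above is farthest-point sampling: each round picks the candidate whose nearest already-selected neighbor is farthest away (max-min), with Bhattacharyya distance between HSV histograms as the metric. The same rule on plain numbers, without the OpenCV dependency:

```python
# Farthest-point sampling on toy 1-D values (illustration of the
# max-min selection rule used above, not Frigate code).
def farthest_point_sample(points: list[float], k: int) -> list[float]:
    selected = [points[0]]  # deterministic seed for the demo
    remaining = points[1:]

    while len(selected) < k and remaining:
        # Pick the candidate whose *nearest* selected point is farthest away.
        best = max(remaining, key=lambda c: min(abs(c - s) for s in selected))
        selected.append(best)
        remaining.remove(best)
    return selected


print(farthest_point_sample([0.0, 0.1, 0.15, 5.0, 5.1, 9.9], k=3))
# [0.0, 9.9, 5.0] -- clusters get covered, near-duplicates are skipped
```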
@staticmethod
def collect_object_classification_examples(
    model_name: str,
    label: str,
) -> None:
    """
    Collect representative object classification examples from event thumbnails.

    This function:
    1. Queries events for the specified label
    2. Selects 100 balanced events across different cameras and times
    3. Retrieves thumbnails for selected events (with an adaptive center crop applied)
    4. Selects 24 most visually distinct thumbnails
    5. Saves to dataset directory

    Args:
        model_name: Name of the classification model
        label: Object label to collect (e.g., "person", "car")
    """
    dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
    temp_dir = os.path.join(dataset_dir, "temp")
    os.makedirs(temp_dir, exist_ok=True)

    # Step 1: Query events for the specified label
    events = list(
        Event.select().where((Event.label == label)).order_by(Event.start_time.asc())
    )

    if not events:
        logger.warning(f"No events found for label '{label}'")
        return

    logger.debug(f"Found {len(events)} events")

    # Step 2: Select balanced events (100 samples)
    selected_events = _select_balanced_events(events, target_count=100)
    logger.debug(f"Selected {len(selected_events)} events")

    # Step 3: Extract thumbnails from events
    thumbnails = _extract_event_thumbnails(selected_events, temp_dir)
    logger.debug(f"Successfully extracted {len(thumbnails)} thumbnails")

    # Step 4: Select 24 most visually distinct thumbnails
    distinct_images = _select_distinct_images(thumbnails, target_count=24)
    logger.debug(f"Selected {len(distinct_images)} distinct images")

    # Step 5: Save to train directory for later classification
    train_dir = os.path.join(CLIPS_DIR, model_name, "train")
    os.makedirs(train_dir, exist_ok=True)

    saved_count = 0
    for idx, image_path in enumerate(distinct_images):
        dest_path = os.path.join(train_dir, f"example_{idx:03d}.jpg")
        try:
            img = cv2.imread(image_path)

            if img is not None:
                cv2.imwrite(dest_path, img)
                saved_count += 1
        except Exception as e:
            logger.error(f"Failed to save image {image_path}: {e}")

    import shutil

    try:
        shutil.rmtree(temp_dir)
    except Exception as e:
        logger.warning(f"Failed to clean up temp directory: {e}")

    logger.debug(
        f"Successfully collected {saved_count} classification examples in {train_dir}"
    )


def _select_balanced_events(
    events: list[Event], target_count: int = 100
) -> list[Event]:
    """
    Select balanced events from the event list.

    Strategy:
    - Group events by camera and time of day
    - Sample evenly across groups to ensure diversity
    - Prioritize events with higher scores

    Returns:
        List of selected events
    """
    grouped = defaultdict(list)

    for event in events:
        camera = event.camera
        hour_block = int(event.start_time // (6 * 3600))
        key = f"{camera}_{hour_block}"
        grouped[key].append(event)

    num_groups = len(grouped)
    if num_groups == 0:
        return []

    samples_per_group = max(1, target_count // num_groups)
    selected = []

    for group_events in grouped.values():
        sorted_events = sorted(
            group_events,
            key=lambda e: e.data.get("score", 0) if e.data else 0,
            reverse=True,
        )

        sample_size = min(samples_per_group, len(sorted_events))
        selected.extend(sorted_events[:sample_size])

    if len(selected) < target_count:
        remaining = [e for e in events if e not in selected]
        remaining_sorted = sorted(
            remaining,
            key=lambda e: e.data.get("score", 0) if e.data else 0,
            reverse=True,
        )
        needed = target_count - len(selected)
        selected.extend(remaining_sorted[:needed])

    return selected[:target_count]
def _extract_event_thumbnails(events: list[Event], output_dir: str) -> list[str]:
    """
    Extract thumbnails from events and save to disk.

    Args:
        events: List of Event objects
        output_dir: Directory to save thumbnails

    Returns:
        List of paths to successfully extracted thumbnail images
    """
    thumbnail_paths = []

    for idx, event in enumerate(events):
        try:
            thumbnail_bytes = get_event_thumbnail_bytes(event)

            if thumbnail_bytes:
                nparr = np.frombuffer(thumbnail_bytes, np.uint8)
                img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

                if img is not None:
                    height, width = img.shape[:2]

                    crop_size = 1.0
                    if event.data and "box" in event.data and "region" in event.data:
                        box = event.data["box"]
                        region = event.data["region"]

                        if len(box) == 4 and len(region) == 4:
                            box_w, box_h = box[2], box[3]
                            region_w, region_h = region[2], region[3]

                            box_area = (box_w * box_h) / (region_w * region_h)

                            if box_area < 0.05:
                                crop_size = 0.4
                            elif box_area < 0.10:
                                crop_size = 0.5
                            elif box_area < 0.20:
                                crop_size = 0.65
                            elif box_area < 0.35:
                                crop_size = 0.80
                            else:
                                crop_size = 0.95

                    crop_width = int(width * crop_size)
                    crop_height = int(height * crop_size)

                    x1 = (width - crop_width) // 2
                    y1 = (height - crop_height) // 2
                    x2 = x1 + crop_width
                    y2 = y1 + crop_height

                    cropped = img[y1:y2, x1:x2]
                    resized = cv2.resize(cropped, (224, 224))
                    output_path = os.path.join(output_dir, f"thumbnail_{idx:04d}.jpg")
                    cv2.imwrite(output_path, resized)
                    thumbnail_paths.append(output_path)

        except Exception as e:
            logger.debug(f"Failed to extract thumbnail for event {event.id}: {e}")
            continue

    return thumbnail_paths
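The tiering above zooms harder the smaller the object is relative to its region: a 100x80 box inside a 320x320 region gives box_area = 8000 / 102400 ≈ 0.078, which lands in the 0.05–0.10 band, so crop_size 0.5 keeps just the middle half of the thumbnail before resizing to 224x224. The mapping as a standalone function (thresholds copied from the code; the function name is ours):

```python
def crop_size_for(box_area_ratio: float) -> float:
    """Smaller objects get a tighter center crop (more zoom)."""
    if box_area_ratio < 0.05:
        return 0.4
    elif box_area_ratio < 0.10:
        return 0.5
    elif box_area_ratio < 0.20:
        return 0.65
    elif box_area_ratio < 0.35:
        return 0.80
    return 0.95


box_w, box_h = 100, 80
region_w, region_h = 320, 320
ratio = (box_w * box_h) / (region_w * region_h)  # 0.078125
print(crop_size_for(ratio))  # 0.5 -> keep the middle half of the thumbnail
```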
@@ -1,4 +1,5 @@

{
  "documentTitle": "Classification Models",
  "button": {
    "deleteClassificationAttempts": "Delete Classification Images",
    "renameCategory": "Rename Class",

@@ -50,8 +51,85 @@

  },
  "categorizeImageAs": "Classify Image As:",
  "categorizeImage": "Classify Image",
  "noModels": {
    "object": {
      "title": "No Object Classification Models",
      "description": "Create a custom model to classify detected objects.",
      "buttonText": "Create Object Model"
    },
    "state": {
      "title": "No State Classification Models",
      "description": "Create a custom model to monitor and classify state changes in specific camera areas.",
      "buttonText": "Create State Model"
    }
  },
  "wizard": {
    "title": "Create New Classification",
    "description": "Create a new state or object classification model.",
    "steps": {
      "nameAndDefine": "Name & Define",
      "stateArea": "State Area",
      "chooseExamples": "Choose Examples"
    },
    "step1": {
      "description": "State models monitor fixed camera areas for changes (e.g., door open/closed). Object models add classifications to detected objects (e.g., known animals, delivery persons, etc.).",
      "name": "Name",
      "namePlaceholder": "Enter model name...",
      "type": "Type",
      "typeState": "State",
      "typeObject": "Object",
      "objectLabel": "Object Label",
      "objectLabelPlaceholder": "Select object type...",
      "classificationType": "Classification Type",
      "classificationTypeTip": "Learn about classification types",
      "classificationTypeDesc": "Sub Labels add additional text to the object label (e.g., 'Person: UPS'). Attributes are searchable metadata stored separately in the object metadata.",
      "classificationSubLabel": "Sub Label",
      "classificationAttribute": "Attribute",
      "classes": "Classes",
      "classesTip": "Learn about classes",
      "classesStateDesc": "Define the different states your camera area can be in. For example: 'open' and 'closed' for a garage door.",
      "classesObjectDesc": "Define the different categories to classify detected objects into. For example: 'delivery_person', 'resident', 'stranger' for person classification.",
      "classPlaceholder": "Enter class name...",
      "errors": {
        "nameRequired": "Model name is required",
        "nameLength": "Model name must be 64 characters or less",
        "nameOnlyNumbers": "Model name cannot contain only numbers",
        "classRequired": "At least 1 class is required",
        "classesUnique": "Class names must be unique",
        "stateRequiresTwoClasses": "State models require at least 2 classes",
        "objectLabelRequired": "Please select an object label",
        "objectTypeRequired": "Please select a classification type"
      }
    },
    "step2": {
      "description": "Select cameras and define the area to monitor for each camera. The model will classify the state of these areas.",
      "cameras": "Cameras",
      "selectCamera": "Select Camera",
      "noCameras": "Click + to add cameras",
      "selectCameraPrompt": "Select a camera from the list to define its monitoring area"
    },
    "step3": {
      "selectImagesPrompt": "Select all images with: {{className}}",
      "selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.",
      "generating": {
        "title": "Generating Sample Images",
        "description": "Frigate is pulling representative images from your recordings. This may take a moment..."
      },
      "training": {
        "title": "Training Model",
        "description": "Your model is being trained in the background. Close this dialog, and your model will start running as soon as training is complete."
      },
      "retryGenerate": "Retry Generation",
      "noImages": "No sample images generated",
      "classifying": "Classifying & Training...",
      "trainingStarted": "Training started successfully",
      "errors": {
        "noCameras": "No cameras configured",
        "noObjectLabel": "No object label selected",
        "generateFailed": "Failed to generate examples: {{error}}",
        "generationFailed": "Generation failed. Please try again.",
        "classifyFailed": "Failed to classify images: {{error}}"
      },
      "generateSuccess": "Successfully generated sample images"
    }
  }
}
@@ -5,10 +5,6 @@

    "invalidName": "Invalid name. Names can only include letters, numbers, spaces, apostrophes, underscores, and hyphens."
  },
  "details": {
    "subLabelScore": "Sub Label Score",
    "scoreInfo": "The sub label score is the weighted score for all of the recognized face confidences, so this may differ from the score shown on the snapshot.",
    "face": "Face Details",
    "faceDesc": "Details of the tracked object that generated this face",
    "timestamp": "Timestamp",
    "unknown": "Unknown"
  },

@@ -19,8 +15,6 @@

  },
  "collections": "Collections",
  "createFaceLibrary": {
    "title": "Create Collection",
    "desc": "Create a new collection",
    "new": "Create New Face",
    "nextSteps": "To build a strong foundation:<li>Use the Recent Recognitions tab to select and train on images for each detected person.</li><li>Focus on straight-on images for best results; avoid training images that capture faces at an angle.</li></ul>"
  },

@@ -37,8 +31,6 @@

    "aria": "Select recent recognitions",
    "empty": "There are no recent face recognition attempts"
  },
  "selectItem": "Select {{item}}",
  "selectFace": "Select Face",
  "deleteFaceLibrary": {
    "title": "Delete Name",
    "desc": "Are you sure you want to delete the collection {{name}}? This will permanently delete all associated faces."

@@ -69,7 +61,6 @@

    "maxSize": "Max size: {{size}}MB"
  },
  "nofaces": "No faces available",
  "pixels": "{{area}}px",
  "trainFaceAs": "Train Face as:",
  "trainFace": "Train Face",
  "toast": {
@@ -126,6 +126,7 @@ export const ClassificationCard = forwardRef<

          imgClassName,
          isMobile && "w-full",
        )}
        loading="lazy"
        onLoad={() => setImageLoaded(true)}
        src={`${baseUrl}${data.filepath}`}
      />
@@ -7,58 +7,198 @@ import {

  DialogHeader,
  DialogTitle,
} from "../ui/dialog";
import { useState } from "react";
import { useReducer, useMemo } from "react";
import Step1NameAndDefine, { Step1FormData } from "./wizard/Step1NameAndDefine";
import Step2StateArea, { Step2FormData } from "./wizard/Step2StateArea";
import Step3ChooseExamples, {
  Step3FormData,
} from "./wizard/Step3ChooseExamples";
import { cn } from "@/lib/utils";
import { isDesktop } from "react-device-detect";

const STEPS = [
  "classificationWizard.steps.nameAndDefine",
  "classificationWizard.steps.stateArea",
  "classificationWizard.steps.chooseExamples",
  "classificationWizard.steps.train",
const OBJECT_STEPS = [
  "wizard.steps.nameAndDefine",
  "wizard.steps.chooseExamples",
];

const STATE_STEPS = [
  "wizard.steps.nameAndDefine",
  "wizard.steps.stateArea",
  "wizard.steps.chooseExamples",
];

type ClassificationModelWizardDialogProps = {
  open: boolean;
  onClose: () => void;
  defaultModelType?: "state" | "object";
};

type WizardState = {
  currentStep: number;
  step1Data?: Step1FormData;
  step2Data?: Step2FormData;
  step3Data?: Step3FormData;
};

type WizardAction =
  | { type: "NEXT_STEP"; payload?: Partial<WizardState> }
  | { type: "PREVIOUS_STEP" }
  | { type: "SET_STEP_1"; payload: Step1FormData }
  | { type: "SET_STEP_2"; payload: Step2FormData }
  | { type: "SET_STEP_3"; payload: Step3FormData }
  | { type: "RESET" };

const initialState: WizardState = {
  currentStep: 0,
};

function wizardReducer(state: WizardState, action: WizardAction): WizardState {
  switch (action.type) {
    case "SET_STEP_1":
      return {
        ...state,
        step1Data: action.payload,
        currentStep: 1,
      };
    case "SET_STEP_2":
      return {
        ...state,
        step2Data: action.payload,
        currentStep: 2,
      };
    case "SET_STEP_3":
      return {
        ...state,
        step3Data: action.payload,
        currentStep: 3,
      };
    case "NEXT_STEP":
      return {
        ...state,
        ...action.payload,
        currentStep: state.currentStep + 1,
      };
    case "PREVIOUS_STEP":
      return {
        ...state,
        currentStep: Math.max(0, state.currentStep - 1),
      };
    case "RESET":
      return initialState;
    default:
      return state;
  }
}

export default function ClassificationModelWizardDialog({
  open,
  onClose,
  defaultModelType,
}: ClassificationModelWizardDialogProps) {
  const { t } = useTranslation(["views/classificationModel"]);

  // step management
  const [currentStep, _] = useState(0);
  const [wizardState, dispatch] = useReducer(wizardReducer, initialState);

  const steps = useMemo(() => {
    if (!wizardState.step1Data) {
      return OBJECT_STEPS;
    }
    return wizardState.step1Data.modelType === "state"
      ? STATE_STEPS
      : OBJECT_STEPS;
  }, [wizardState.step1Data]);

  const handleStep1Next = (data: Step1FormData) => {
    dispatch({ type: "SET_STEP_1", payload: data });
  };

  const handleStep2Next = (data: Step2FormData) => {
    dispatch({ type: "SET_STEP_2", payload: data });
  };

  const handleBack = () => {
    dispatch({ type: "PREVIOUS_STEP" });
  };

  const handleCancel = () => {
    dispatch({ type: "RESET" });
    onClose();
  };

  return (
    <Dialog
      open={open}
      onOpenChange={(open) => {
        if (!open) {
          onClose;
          handleCancel();
        }
      }}
    >
      <DialogContent
        className="max-h-[90dvh] max-w-4xl overflow-y-auto"
        className={cn(
          "",
          isDesktop &&
            wizardState.currentStep == 0 &&
            "max-h-[90%] overflow-y-auto xl:max-h-[80%]",
          isDesktop &&
            wizardState.currentStep > 0 &&
            "max-h-[90%] max-w-[70%] overflow-y-auto xl:max-h-[80%]",
        )}
        onInteractOutside={(e) => {
          e.preventDefault();
        }}
      >
        <StepIndicator
          steps={STEPS}
          currentStep={currentStep}
          steps={steps}
          currentStep={wizardState.currentStep}
          variant="dots"
          className="mb-4 justify-start"
        />
        <DialogHeader>
          <DialogTitle>{t("wizard.title")}</DialogTitle>
          {currentStep === 0 && (
            <DialogDescription>{t("wizard.description")}</DialogDescription>
          {wizardState.currentStep === 0 && (
            <DialogDescription>
              {t("wizard.step1.description")}
            </DialogDescription>
          )}
          {wizardState.currentStep === 1 &&
            wizardState.step1Data?.modelType === "state" && (
              <DialogDescription>
                {t("wizard.step2.description")}
              </DialogDescription>
            )}
        </DialogHeader>

        <div className="pb-4">
          <div className="size-full"></div>
          {wizardState.currentStep === 0 && (
            <Step1NameAndDefine
              initialData={wizardState.step1Data}
              defaultModelType={defaultModelType}
              onNext={handleStep1Next}
              onCancel={handleCancel}
            />
          )}
          {wizardState.currentStep === 1 &&
            wizardState.step1Data?.modelType === "state" && (
              <Step2StateArea
                initialData={wizardState.step2Data}
                onNext={handleStep2Next}
                onBack={handleBack}
              />
            )}
          {((wizardState.currentStep === 2 &&
            wizardState.step1Data?.modelType === "state") ||
            (wizardState.currentStep === 1 &&
              wizardState.step1Data?.modelType === "object")) &&
            wizardState.step1Data && (
              <Step3ChooseExamples
                step1Data={wizardState.step1Data}
                step2Data={wizardState.step2Data}
                initialData={wizardState.step3Data}
                onClose={onClose}
                onBack={handleBack}
              />
            )}
        </div>
      </DialogContent>
    </Dialog>
498 web/src/components/classification/wizard/Step1NameAndDefine.tsx (new file)
@ -0,0 +1,498 @@
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { z } from "zod";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useMemo } from "react";
|
||||
import { LuX, LuPlus, LuInfo, LuExternalLink } from "react-icons/lu";
|
||||
import useSWR from "swr";
|
||||
import { FrigateConfig } from "@/types/frigateConfig";
|
||||
import { getTranslatedLabel } from "@/utils/i18n";
|
||||
import { useDocDomain } from "@/hooks/use-doc-domain";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/ui/popover";
|
||||
|
||||
export type ModelType = "state" | "object";
|
||||
export type ObjectClassificationType = "sub_label" | "attribute";
|
||||
|
||||
export type Step1FormData = {
|
||||
modelName: string;
|
||||
modelType: ModelType;
|
||||
objectLabel?: string;
|
||||
objectType?: ObjectClassificationType;
|
||||
classes: string[];
|
||||
};
|
||||
|
||||
type Step1NameAndDefineProps = {
|
||||
initialData?: Partial<Step1FormData>;
|
||||
defaultModelType?: "state" | "object";
|
||||
onNext: (data: Step1FormData) => void;
|
||||
onCancel: () => void;
|
||||
};
|
||||
|
||||
export default function Step1NameAndDefine({
|
||||
initialData,
|
||||
defaultModelType,
|
||||
onNext,
|
||||
onCancel,
|
||||
}: Step1NameAndDefineProps) {
|
||||
const { t } = useTranslation(["views/classificationModel"]);
|
||||
const { data: config } = useSWR<FrigateConfig>("config");
|
||||
const { getLocaleDocUrl } = useDocDomain();
|
||||
|
||||
const objectLabels = useMemo(() => {
|
||||
if (!config) return [];
|
||||
|
||||
const labels = new Set<string>();
|
||||
|
||||
Object.values(config.cameras).forEach((cameraConfig) => {
|
||||
if (!cameraConfig.enabled || !cameraConfig.enabled_in_config) {
|
||||
return;
|
||||
}
|
||||
|
||||
cameraConfig.objects.track.forEach((label) => {
|
||||
if (!config.model.all_attributes.includes(label)) {
|
||||
labels.add(label);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return [...labels].sort();
|
||||
}, [config]);
|
||||
|
||||
const step1FormData = z
|
||||
.object({
|
||||
modelName: z
|
||||
.string()
|
||||
.min(1, t("wizard.step1.errors.nameRequired"))
|
||||
.max(64, t("wizard.step1.errors.nameLength"))
|
||||
.refine((value) => !/^\d+$/.test(value), {
|
||||
message: t("wizard.step1.errors.nameOnlyNumbers"),
|
||||
}),
|
||||
modelType: z.enum(["state", "object"]),
|
||||
objectLabel: z.string().optional(),
|
||||
objectType: z.enum(["sub_label", "attribute"]).optional(),
|
||||
classes: z
|
||||
.array(z.string())
|
||||
.min(1, t("wizard.step1.errors.classRequired"))
|
||||
.refine(
|
||||
(classes) => {
|
||||
const nonEmpty = classes.filter((c) => c.trim().length > 0);
|
||||
return nonEmpty.length >= 1;
|
||||
},
|
||||
{ message: t("wizard.step1.errors.classRequired") },
|
||||
)
|
||||
.refine(
|
||||
(classes) => {
|
||||
const nonEmpty = classes.filter((c) => c.trim().length > 0);
|
||||
const unique = new Set(nonEmpty.map((c) => c.toLowerCase()));
|
||||
return unique.size === nonEmpty.length;
|
||||
},
|
||||
{ message: t("wizard.step1.errors.classesUnique") },
|
||||
),
|
||||
})
|
||||
.refine(
|
||||
(data) => {
|
||||
// State models require at least 2 classes
|
||||
if (data.modelType === "state") {
|
||||
const nonEmpty = data.classes.filter((c) => c.trim().length > 0);
|
||||
return nonEmpty.length >= 2;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{
|
||||
message: t("wizard.step1.errors.stateRequiresTwoClasses"),
|
||||
path: ["classes"],
|
||||
},
|
||||
)
|
||||
.refine(
|
||||
(data) => {
|
||||
if (data.modelType === "object") {
|
||||
return data.objectLabel !== undefined && data.objectLabel !== "";
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{
|
||||
message: t("wizard.step1.errors.objectLabelRequired"),
|
||||
path: ["objectLabel"],
|
||||
},
|
||||
)
|
||||
.refine(
|
||||
(data) => {
|
||||
if (data.modelType === "object") {
|
||||
return data.objectType !== undefined;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{
|
||||
message: t("wizard.step1.errors.objectTypeRequired"),
|
||||
path: ["objectType"],
|
||||
},
|
||||
);
|
||||
|
||||
const form = useForm<z.infer<typeof step1FormData>>({
|
||||
resolver: zodResolver(step1FormData),
|
||||
defaultValues: {
|
||||
modelName: initialData?.modelName || "",
|
||||
modelType: initialData?.modelType || defaultModelType || "state",
|
||||
objectLabel: initialData?.objectLabel,
|
||||
objectType: initialData?.objectType || "sub_label",
|
||||
classes: initialData?.classes?.length ? initialData.classes : [""],
|
||||
},
|
||||
mode: "onChange",
|
||||
});

  const watchedClasses = form.watch("classes");
  const watchedModelType = form.watch("modelType");
  const watchedObjectType = form.watch("objectType");

  const handleAddClass = () => {
    const currentClasses = form.getValues("classes");
    form.setValue("classes", [...currentClasses, ""], { shouldValidate: true });
  };

  const handleRemoveClass = (index: number) => {
    const currentClasses = form.getValues("classes");
    const newClasses = currentClasses.filter((_, i) => i !== index);

    // Ensure at least one field remains (even if empty)
    if (newClasses.length === 0) {
      form.setValue("classes", [""], { shouldValidate: true });
    } else {
      form.setValue("classes", newClasses, { shouldValidate: true });
    }
  };

  const onSubmit = (data: z.infer<typeof step1FormData>) => {
    // Filter out empty classes
    const filteredClasses = data.classes.filter((c) => c.trim().length > 0);
    onNext({
      ...data,
      classes: filteredClasses,
    });
  };

  return (
    <div className="space-y-6">
      <Form {...form}>
        <form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
          <FormField
            control={form.control}
            name="modelName"
            render={({ field }) => (
              <FormItem>
                <FormLabel className="text-primary-variant">
                  {t("wizard.step1.name")}
                </FormLabel>
                <FormControl>
                  <Input
                    className="h-8"
                    placeholder={t("wizard.step1.namePlaceholder")}
                    {...field}
                  />
                </FormControl>
                <FormMessage />
              </FormItem>
            )}
          />

          <FormField
            control={form.control}
            name="modelType"
            render={({ field }) => (
              <FormItem>
                <FormLabel className="text-primary-variant">
                  {t("wizard.step1.type")}
                </FormLabel>
                <FormControl>
                  <RadioGroup
                    onValueChange={field.onChange}
                    defaultValue={field.value}
                    className="flex flex-col gap-4 pt-2"
                  >
                    <div className="flex items-center gap-2">
                      <RadioGroupItem
                        className={
                          watchedModelType === "state"
                            ? "bg-selected from-selected/50 to-selected/90 text-selected"
                            : "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
                        }
                        id="state"
                        value="state"
                      />
                      <Label className="cursor-pointer" htmlFor="state">
                        {t("wizard.step1.typeState")}
                      </Label>
                    </div>
                    <div className="flex items-center gap-2">
                      <RadioGroupItem
                        className={
                          watchedModelType === "object"
                            ? "bg-selected from-selected/50 to-selected/90 text-selected"
                            : "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
                        }
                        id="object"
                        value="object"
                      />
                      <Label className="cursor-pointer" htmlFor="object">
                        {t("wizard.step1.typeObject")}
                      </Label>
                    </div>
                  </RadioGroup>
                </FormControl>
                <FormMessage />
              </FormItem>
            )}
          />

          {watchedModelType === "object" && (
            <>
              <FormField
                control={form.control}
                name="objectLabel"
                render={({ field }) => (
                  <FormItem>
                    <FormLabel className="text-primary-variant">
                      {t("wizard.step1.objectLabel")}
                    </FormLabel>
                    <Select
                      onValueChange={field.onChange}
                      defaultValue={field.value}
                    >
                      <FormControl>
                        <SelectTrigger className="h-8">
                          <SelectValue
                            placeholder={t(
                              "wizard.step1.objectLabelPlaceholder",
                            )}
                          />
                        </SelectTrigger>
                      </FormControl>
                      <SelectContent>
                        {objectLabels.map((label) => (
                          <SelectItem
                            key={label}
                            value={label}
                            className="cursor-pointer hover:bg-secondary-highlight"
                          >
                            {getTranslatedLabel(label)}
                          </SelectItem>
                        ))}
                      </SelectContent>
                    </Select>
                    <FormMessage />
                  </FormItem>
                )}
              />

              <FormField
                control={form.control}
                name="objectType"
                render={({ field }) => (
                  <FormItem>
                    <div className="flex items-center gap-1">
                      <FormLabel className="text-primary-variant">
                        {t("wizard.step1.classificationType")}
                      </FormLabel>
                      <Popover>
                        <PopoverTrigger asChild>
                          <Button
                            variant="ghost"
                            size="sm"
                            className="h-4 w-4 p-0"
                          >
                            <LuInfo className="size-3" />
                          </Button>
                        </PopoverTrigger>
                        <PopoverContent className="pointer-events-auto w-80 text-xs">
                          <div className="flex flex-col gap-2">
                            <div className="text-sm">
                              {t("wizard.step1.classificationTypeDesc")}
                            </div>
                            <div className="mt-3 flex items-center text-primary">
                              <a
                                href={getLocaleDocUrl(
                                  "configuration/custom_classification/object_classification#classification-type",
                                )}
                                target="_blank"
                                rel="noopener noreferrer"
                                className="inline cursor-pointer"
                              >
                                {t("readTheDocumentation", { ns: "common" })}
                                <LuExternalLink className="ml-2 inline-flex size-3" />
                              </a>
                            </div>
                          </div>
                        </PopoverContent>
                      </Popover>
                    </div>
                    <FormControl>
                      <RadioGroup
                        onValueChange={field.onChange}
                        defaultValue={field.value}
                        className="flex flex-col gap-4 pt-2"
                      >
                        <div className="flex items-center gap-2">
                          <RadioGroupItem
                            className={
                              watchedObjectType === "sub_label"
                                ? "bg-selected from-selected/50 to-selected/90 text-selected"
                                : "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
                            }
                            id="sub_label"
                            value="sub_label"
                          />
                          <Label className="cursor-pointer" htmlFor="sub_label">
                            {t("wizard.step1.classificationSubLabel")}
                          </Label>
                        </div>
                        <div className="flex items-center gap-2">
                          <RadioGroupItem
                            className={
                              watchedObjectType === "attribute"
                                ? "bg-selected from-selected/50 to-selected/90 text-selected"
                                : "bg-secondary from-secondary/50 to-secondary/90 text-secondary"
                            }
                            id="attribute"
                            value="attribute"
                          />
                          <Label className="cursor-pointer" htmlFor="attribute">
                            {t("wizard.step1.classificationAttribute")}
                          </Label>
                        </div>
                      </RadioGroup>
                    </FormControl>
                    <FormMessage />
                  </FormItem>
                )}
              />
            </>
          )}

          <div className="space-y-2">
            <div className="flex items-center justify-between">
              <div className="flex items-center gap-1">
                <FormLabel className="text-primary-variant">
                  {t("wizard.step1.classes")}
                </FormLabel>
                <Popover>
                  <PopoverTrigger asChild>
                    <Button variant="ghost" size="sm" className="h-4 w-4 p-0">
                      <LuInfo className="size-3" />
                    </Button>
                  </PopoverTrigger>
                  <PopoverContent className="pointer-events-auto w-80 text-xs">
                    <div className="flex flex-col gap-2">
                      <div className="text-sm">
                        {watchedModelType === "state"
                          ? t("wizard.step1.classesStateDesc")
                          : t("wizard.step1.classesObjectDesc")}
                      </div>
                      <div className="mt-3 flex items-center text-primary">
                        <a
                          href={getLocaleDocUrl(
                            watchedModelType === "state"
                              ? "configuration/custom_classification/state_classification"
                              : "configuration/custom_classification/object_classification",
                          )}
                          target="_blank"
                          rel="noopener noreferrer"
                          className="inline cursor-pointer"
                        >
                          {t("readTheDocumentation", { ns: "common" })}
                          <LuExternalLink className="ml-2 inline-flex size-3" />
                        </a>
                      </div>
                    </div>
                  </PopoverContent>
                </Popover>
              </div>
              <Button
                type="button"
                variant="secondary"
                className="size-6 rounded-md bg-secondary-foreground p-1 text-background"
                onClick={handleAddClass}
              >
                <LuPlus />
              </Button>
            </div>
            <div className="space-y-2">
              {watchedClasses.map((_, index) => (
                <FormField
                  key={index}
                  control={form.control}
                  name={`classes.${index}`}
                  render={({ field }) => (
                    <FormItem>
                      <FormControl>
                        <div className="flex items-center gap-2">
                          <Input
                            className="h-8"
                            placeholder={t("wizard.step1.classPlaceholder")}
                            {...field}
                          />
                          {watchedClasses.length > 1 && (
                            <Button
                              type="button"
                              variant="ghost"
                              size="sm"
                              className="h-8 w-8 p-0"
                              onClick={() => handleRemoveClass(index)}
                            >
                              <LuX className="size-4" />
                            </Button>
                          )}
                        </div>
                      </FormControl>
                    </FormItem>
                  )}
                />
              ))}
            </div>
            {form.formState.errors.classes && (
              <p className="text-sm font-medium text-destructive">
                {form.formState.errors.classes.message}
              </p>
            )}
          </div>
        </form>
      </Form>

      <div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
        <Button type="button" onClick={onCancel} className="sm:flex-1">
          {t("button.cancel", { ns: "common" })}
        </Button>
        <Button
          type="button"
          onClick={form.handleSubmit(onSubmit)}
          variant="select"
          className="flex items-center justify-center gap-2 sm:flex-1"
          disabled={!form.formState.isValid}
        >
          {t("button.continue", { ns: "common" })}
        </Button>
      </div>
    </div>
  );
}
479
web/src/components/classification/wizard/Step2StateArea.tsx
Normal file
@ -0,0 +1,479 @@
import { Button } from "@/components/ui/button";
import { useTranslation } from "react-i18next";
import { useState, useMemo, useRef, useCallback, useEffect } from "react";
import useSWR from "swr";
import { FrigateConfig } from "@/types/frigateConfig";
import {
  Popover,
  PopoverContent,
  PopoverTrigger,
} from "@/components/ui/popover";
import { LuX, LuPlus } from "react-icons/lu";
import { Stage, Layer, Rect, Transformer } from "react-konva";
import Konva from "konva";
import { useResizeObserver } from "@/hooks/resize-observer";
import { useApiHost } from "@/api";
import { resolveCameraName } from "@/hooks/use-camera-friendly-name";
import Heading from "@/components/ui/heading";
import { isMobile } from "react-device-detect";
import { cn } from "@/lib/utils";

export type CameraAreaConfig = {
  camera: string;
  crop: [number, number, number, number];
};

export type Step2FormData = {
  cameraAreas: CameraAreaConfig[];
};

type Step2StateAreaProps = {
  initialData?: Partial<Step2FormData>;
  onNext: (data: Step2FormData) => void;
  onBack: () => void;
};

export default function Step2StateArea({
  initialData,
  onNext,
  onBack,
}: Step2StateAreaProps) {
  const { t } = useTranslation(["views/classificationModel"]);
  const { data: config } = useSWR<FrigateConfig>("config");
  const apiHost = useApiHost();

  const [cameraAreas, setCameraAreas] = useState<CameraAreaConfig[]>(
    initialData?.cameraAreas || [],
  );
  const [selectedCameraIndex, setSelectedCameraIndex] = useState<number>(0);
  const [isPopoverOpen, setIsPopoverOpen] = useState(false);
  const [imageLoaded, setImageLoaded] = useState(false);

  const containerRef = useRef<HTMLDivElement>(null);
  const imageRef = useRef<HTMLImageElement>(null);
  const stageRef = useRef<Konva.Stage>(null);
  const rectRef = useRef<Konva.Rect>(null);
  const transformerRef = useRef<Konva.Transformer>(null);

  const [{ width: containerWidth }] = useResizeObserver(containerRef);

  const availableCameras = useMemo(() => {
    if (!config) return [];

    const selectedCameraNames = cameraAreas.map((ca) => ca.camera);
    return Object.entries(config.cameras)
      .sort()
      .filter(
        ([name, cam]) =>
          cam.enabled &&
          cam.enabled_in_config &&
          !selectedCameraNames.includes(name),
      )
      .map(([name]) => ({
        name,
        displayName: resolveCameraName(config, name),
      }));
  }, [config, cameraAreas]);

  const selectedCamera = useMemo(() => {
    if (cameraAreas.length === 0) return null;
    return cameraAreas[selectedCameraIndex];
  }, [cameraAreas, selectedCameraIndex]);

  const selectedCameraConfig = useMemo(() => {
    if (!config || !selectedCamera) return null;
    return config.cameras[selectedCamera.camera];
  }, [config, selectedCamera]);

  const imageSize = useMemo(() => {
    if (!containerWidth || !selectedCameraConfig) {
      return { width: 0, height: 0 };
    }

    const containerAspectRatio = 16 / 9;
    const containerHeight = containerWidth / containerAspectRatio;

    const cameraAspectRatio =
      selectedCameraConfig.detect.width / selectedCameraConfig.detect.height;

    // Fit camera within 16:9 container
    let imageWidth, imageHeight;
    if (cameraAspectRatio > containerAspectRatio) {
      imageWidth = containerWidth;
      imageHeight = imageWidth / cameraAspectRatio;
    } else {
      imageHeight = containerHeight;
      imageWidth = imageHeight * cameraAspectRatio;
    }

    return { width: imageWidth, height: imageHeight };
  }, [containerWidth, selectedCameraConfig]);

|
||||
(cameraName: string) => {
|
||||
// Calculate a square crop in pixel space
|
||||
const camera = config?.cameras[cameraName];
|
||||
if (!camera) return;
|
||||
|
||||
const cameraAspect = camera.detect.width / camera.detect.height;
|
||||
const cropSize = 0.3;
|
||||
let x1, y1, x2, y2;
|
||||
|
||||
if (cameraAspect >= 1) {
|
||||
const pixelSize = cropSize * camera.detect.height;
|
||||
const normalizedWidth = pixelSize / camera.detect.width;
|
||||
x1 = (1 - normalizedWidth) / 2;
|
||||
y1 = (1 - cropSize) / 2;
|
||||
x2 = x1 + normalizedWidth;
|
||||
y2 = y1 + cropSize;
|
||||
} else {
|
||||
const pixelSize = cropSize * camera.detect.width;
|
||||
const normalizedHeight = pixelSize / camera.detect.height;
|
||||
x1 = (1 - cropSize) / 2;
|
||||
y1 = (1 - normalizedHeight) / 2;
|
||||
x2 = x1 + cropSize;
|
||||
y2 = y1 + normalizedHeight;
|
||||
}
|
||||
|
||||
const newArea: CameraAreaConfig = {
|
||||
camera: cameraName,
|
||||
crop: [x1, y1, x2, y2],
|
||||
};
|
||||
setCameraAreas([...cameraAreas, newArea]);
|
||||
setSelectedCameraIndex(cameraAreas.length);
|
||||
setIsPopoverOpen(false);
|
||||
},
|
||||
[cameraAreas, config],
|
||||
);
|
||||
|
||||
const handleRemoveCamera = useCallback(
|
||||
(index: number) => {
|
||||
const newAreas = cameraAreas.filter((_, i) => i !== index);
|
||||
setCameraAreas(newAreas);
|
||||
if (selectedCameraIndex >= newAreas.length) {
|
||||
setSelectedCameraIndex(Math.max(0, newAreas.length - 1));
|
||||
}
|
||||
},
|
||||
[cameraAreas, selectedCameraIndex],
|
||||
);
|
||||
|
||||
const handleCropChange = useCallback(
|
||||
(crop: [number, number, number, number]) => {
|
||||
const newAreas = [...cameraAreas];
|
||||
newAreas[selectedCameraIndex] = {
|
||||
...newAreas[selectedCameraIndex],
|
||||
crop,
|
||||
};
|
||||
setCameraAreas(newAreas);
|
||||
},
|
||||
[cameraAreas, selectedCameraIndex],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
setImageLoaded(false);
|
||||
}, [selectedCamera]);
|
||||
|
||||
useEffect(() => {
|
||||
const rect = rectRef.current;
|
||||
const transformer = transformerRef.current;
|
||||
|
||||
if (
|
||||
rect &&
|
||||
transformer &&
|
||||
selectedCamera &&
|
||||
imageSize.width > 0 &&
|
||||
imageLoaded
|
||||
) {
|
||||
rect.scaleX(1);
|
||||
rect.scaleY(1);
|
||||
transformer.nodes([rect]);
|
||||
transformer.getLayer()?.batchDraw();
|
||||
}
|
||||
}, [selectedCamera, imageSize, imageLoaded]);
|
||||
|
||||
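  // Konva expresses drag-resize as a scale factor on the node; bake that
  // scale back into width/height, average the two to keep the crop square,
  // then convert stage pixels back to normalized coordinates before saving.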
  const handleRectChange = useCallback(() => {
    const rect = rectRef.current;

    if (rect && imageSize.width > 0) {
      const actualWidth = rect.width() * rect.scaleX();
      const actualHeight = rect.height() * rect.scaleY();

      // Average dimensions to maintain perfect square
      const size = (actualWidth + actualHeight) / 2;

      rect.width(size);
      rect.height(size);
      rect.scaleX(1);
      rect.scaleY(1);

      const x1 = rect.x() / imageSize.width;
      const y1 = rect.y() / imageSize.height;
      const x2 = (rect.x() + size) / imageSize.width;
      const y2 = (rect.y() + size) / imageSize.height;

      handleCropChange([x1, y1, x2, y2]);
    }
  }, [imageSize, handleCropChange]);

  const handleContinue = useCallback(() => {
    onNext({ cameraAreas });
  }, [cameraAreas, onNext]);

  const canContinue = cameraAreas.length > 0;

  return (
    <div className="flex flex-col gap-4">
      <div
        className={cn(
          "flex gap-4 overflow-hidden",
          isMobile ? "flex-col" : "flex-row",
        )}
      >
        <div
          className={cn(
            "flex flex-shrink-0 flex-col gap-2 overflow-y-auto rounded-lg bg-secondary p-4",
            isMobile ? "w-full" : "w-64",
          )}
        >
          <div className="flex items-center justify-between">
            <h3 className="text-sm font-medium">{t("wizard.step2.cameras")}</h3>
            {availableCameras.length > 0 ? (
              <Popover
                open={isPopoverOpen}
                onOpenChange={setIsPopoverOpen}
                modal={true}
              >
                <PopoverTrigger asChild>
                  <Button
                    type="button"
                    variant="secondary"
                    className="size-6 rounded-md bg-secondary-foreground p-1 text-background"
                    aria-label="Add camera"
                  >
                    <LuPlus />
                  </Button>
                </PopoverTrigger>
                <PopoverContent
                  className="scrollbar-container w-64 border bg-background p-3 shadow-lg"
                  align="start"
                  sideOffset={5}
                  onOpenAutoFocus={(e) => e.preventDefault()}
                >
                  <div className="flex flex-col gap-2">
                    <Heading as="h4" className="text-sm text-primary-variant">
                      {t("wizard.step2.selectCamera")}
                    </Heading>
                    <div className="scrollbar-container flex max-h-[30vh] flex-col gap-1 overflow-y-auto">
                      {availableCameras.map((cam) => (
                        <Button
                          key={cam.name}
                          type="button"
                          variant="ghost"
                          size="sm"
                          className="h-auto justify-start p-2 capitalize text-primary"
                          onClick={() => {
                            handleAddCamera(cam.name);
                          }}
                        >
                          {cam.displayName}
                        </Button>
                      ))}
                    </div>
                  </div>
                </PopoverContent>
              </Popover>
            ) : (
              <Button
                variant="secondary"
                className="size-6 cursor-not-allowed rounded-md bg-muted p-1 text-muted-foreground"
                disabled
              >
                <LuPlus />
              </Button>
            )}
          </div>

          <div className="flex flex-col gap-1">
            {cameraAreas.map((area, index) => {
              const isSelected = index === selectedCameraIndex;
              const displayName = resolveCameraName(config, area.camera);

              return (
                <div
                  key={area.camera}
                  className={`flex items-center justify-between rounded-md p-2 ${
                    isSelected
                      ? "bg-selected/20 ring-1 ring-selected"
                      : "hover:bg-secondary/50"
                  } cursor-pointer`}
                  onClick={() => setSelectedCameraIndex(index)}
                >
                  <span className="text-sm capitalize">{displayName}</span>
                  <Button
                    type="button"
                    variant="ghost"
                    size="sm"
                    className="h-6 w-6 p-0"
                    onClick={(e) => {
                      e.stopPropagation();
                      handleRemoveCamera(index);
                    }}
                  >
                    <LuX className="size-4" />
                  </Button>
                </div>
              );
            })}
          </div>

          {cameraAreas.length === 0 && (
            <div className="flex flex-1 items-center justify-center text-center text-sm text-muted-foreground">
              {t("wizard.step2.noCameras")}
            </div>
          )}
        </div>

        <div className="flex flex-1 items-center justify-center overflow-hidden rounded-lg p-4">
          <div
            ref={containerRef}
            className="flex items-center justify-center"
            style={{
              width: "100%",
              aspectRatio: "16 / 9",
              maxHeight: "100%",
            }}
          >
            {selectedCamera && selectedCameraConfig && imageSize.width > 0 ? (
              <div
                style={{
                  width: imageSize.width,
                  height: imageSize.height,
                  position: "relative",
                }}
              >
                <img
                  ref={imageRef}
                  src={`${apiHost}api/${selectedCamera.camera}/latest.jpg?h=500`}
                  alt={resolveCameraName(config, selectedCamera.camera)}
                  className="h-full w-full object-contain"
                  onLoad={() => setImageLoaded(true)}
                />
                <Stage
                  ref={stageRef}
                  width={imageSize.width}
                  height={imageSize.height}
                  className="absolute inset-0"
                >
                  <Layer>
                    <Rect
                      ref={rectRef}
                      x={selectedCamera.crop[0] * imageSize.width}
                      y={selectedCamera.crop[1] * imageSize.height}
                      width={
                        (selectedCamera.crop[2] - selectedCamera.crop[0]) *
                        imageSize.width
                      }
                      height={
                        (selectedCamera.crop[3] - selectedCamera.crop[1]) *
                        imageSize.height
                      }
                      stroke="#3b82f6"
                      strokeWidth={2}
                      fill="rgba(59, 130, 246, 0.1)"
                      draggable
                      dragBoundFunc={(pos) => {
                        const rect = rectRef.current;
                        if (!rect) return pos;

                        const size = rect.width();
                        const x = Math.max(
                          0,
                          Math.min(pos.x, imageSize.width - size),
                        );
                        const y = Math.max(
                          0,
                          Math.min(pos.y, imageSize.height - size),
                        );

                        return { x, y };
                      }}
                      onDragEnd={handleRectChange}
                      onTransformEnd={handleRectChange}
                    />
                    <Transformer
                      ref={transformerRef}
                      rotateEnabled={false}
                      enabledAnchors={[
                        "top-left",
                        "top-right",
                        "bottom-left",
                        "bottom-right",
                      ]}
                      boundBoxFunc={(_oldBox, newBox) => {
                        const minSize = 50;
                        const maxSize = Math.min(
                          imageSize.width,
                          imageSize.height,
                        );

                        // Clamp dimensions to stage bounds first
                        const clampedWidth = Math.max(
                          minSize,
                          Math.min(newBox.width, maxSize),
                        );
                        const clampedHeight = Math.max(
                          minSize,
                          Math.min(newBox.height, maxSize),
                        );

                        // Enforce square using average
                        const size = (clampedWidth + clampedHeight) / 2;

                        // Clamp position to keep square within bounds
                        const x = Math.max(
                          0,
                          Math.min(newBox.x, imageSize.width - size),
                        );
                        const y = Math.max(
                          0,
                          Math.min(newBox.y, imageSize.height - size),
                        );

                        return {
                          ...newBox,
                          x,
                          y,
                          width: size,
                          height: size,
                        };
                      }}
                    />
                  </Layer>
                </Stage>
              </div>
            ) : (
              <div className="flex items-center justify-center text-muted-foreground">
                {t("wizard.step2.selectCameraPrompt")}
              </div>
            )}
          </div>
        </div>
      </div>

      <div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
        <Button type="button" onClick={onBack} className="sm:flex-1">
          {t("button.back", { ns: "common" })}
        </Button>
        <Button
          type="button"
          onClick={handleContinue}
          variant="select"
          className="flex items-center justify-center gap-2 sm:flex-1"
          disabled={!canContinue}
        >
          {t("button.continue", { ns: "common" })}
        </Button>
      </div>
    </div>
  );
}
444
web/src/components/classification/wizard/Step3ChooseExamples.tsx
Normal file
@ -0,0 +1,444 @@
import { Button } from "@/components/ui/button";
import { useTranslation } from "react-i18next";
import { useState, useEffect, useCallback, useMemo } from "react";
import ActivityIndicator from "@/components/indicators/activity-indicator";
import axios from "axios";
import { toast } from "sonner";
import { Step1FormData } from "./Step1NameAndDefine";
import { Step2FormData } from "./Step2StateArea";
import useSWR from "swr";
import { baseUrl } from "@/api/baseUrl";
import { isMobile } from "react-device-detect";
import { cn } from "@/lib/utils";

export type Step3FormData = {
  examplesGenerated: boolean;
  imageClassifications?: { [imageName: string]: string };
};

type Step3ChooseExamplesProps = {
  step1Data: Step1FormData;
  step2Data?: Step2FormData;
  initialData?: Partial<Step3FormData>;
  onClose: () => void;
  onBack: () => void;
};

export default function Step3ChooseExamples({
  step1Data,
  step2Data,
  initialData,
  onClose,
  onBack,
}: Step3ChooseExamplesProps) {
  const { t } = useTranslation(["views/classificationModel"]);
  const [isGenerating, setIsGenerating] = useState(false);
  const [hasGenerated, setHasGenerated] = useState(
    initialData?.examplesGenerated || false,
  );
  const [imageClassifications, setImageClassifications] = useState<{
    [imageName: string]: string;
  }>(initialData?.imageClassifications || {});
  const [isTraining, setIsTraining] = useState(false);
  const [isProcessing, setIsProcessing] = useState(false);
  const [currentClassIndex, setCurrentClassIndex] = useState(0);
  const [selectedImages, setSelectedImages] = useState<Set<string>>(new Set());

  const { data: trainImages, mutate: refreshTrainImages } = useSWR<string[]>(
    hasGenerated ? `classification/${step1Data.modelName}/train` : null,
  );

  const unknownImages = useMemo(() => {
    if (!trainImages) return [];
    return trainImages;
  }, [trainImages]);

  const toggleImageSelection = useCallback((imageName: string) => {
    setSelectedImages((prev) => {
      const newSet = new Set(prev);
      if (newSet.has(imageName)) {
        newSet.delete(imageName);
      } else {
        newSet.add(imageName);
      }
      return newSet;
    });
  }, []);

  // Get all classes (excluding "none" - it will be auto-assigned)
  const allClasses = useMemo(() => {
    return [...step1Data.classes];
  }, [step1Data.classes]);

  const currentClass = allClasses[currentClassIndex];

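  // Finalize the wizard in three sequential steps: write the model's config,
  // move each labeled example into its category folder, then kick off a
  // training run.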
  const processClassificationsAndTrain = useCallback(
    async (classifications: { [imageName: string]: string }) => {
      // Step 1: Create config for the new model
      const modelConfig: {
        enabled: boolean;
        name: string;
        threshold: number;
        state_config?: {
          cameras: Record<string, { crop: number[] }>;
          motion: boolean;
        };
        object_config?: { objects: string[]; classification_type: string };
      } = {
        enabled: true,
        name: step1Data.modelName,
        threshold: 0.8,
      };

      if (step1Data.modelType === "state") {
        // State model config
        const cameras: Record<string, { crop: number[] }> = {};
        step2Data?.cameraAreas.forEach((area) => {
          cameras[area.camera] = {
            crop: area.crop,
          };
        });

        modelConfig.state_config = {
          cameras,
          motion: true,
        };
      } else {
        // Object model config
        modelConfig.object_config = {
          objects: step1Data.objectLabel ? [step1Data.objectLabel] : [],
          classification_type: step1Data.objectType || "sub_label",
        } as { objects: string[]; classification_type: string };
      }

      // Update config via config API
      await axios.put("/config/set", {
        requires_restart: 0,
        update_topic: `config/classification/custom/${step1Data.modelName}`,
        config_data: {
          classification: {
            custom: {
              [step1Data.modelName]: modelConfig,
            },
          },
        },
      });

      // Step 2: Classify each image by moving it to the correct category folder
      const categorizePromises = Object.entries(classifications).map(
        ([imageName, className]) => {
          if (!className) return Promise.resolve();
          return axios.post(
            `/classification/${step1Data.modelName}/dataset/categorize`,
            {
              training_file: imageName,
              category: className === "none" ? "none" : className,
            },
          );
        },
      );
      await Promise.all(categorizePromises);

      // Step 3: Kick off training
      await axios.post(`/classification/${step1Data.modelName}/train`);

      toast.success(t("wizard.step3.trainingStarted"));
      setIsTraining(true);
    },
    [step1Data, step2Data, t],
  );

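  // Advance the class-by-class selection flow. On the final class, anything
  // still unlabeled falls through to the remaining state (binary state
  // models) or to "none" (object models), and training starts immediately.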
  const handleContinueClassification = useCallback(async () => {
    // Mark selected images with current class
    const newClassifications = { ...imageClassifications };
    selectedImages.forEach((imageName) => {
      newClassifications[imageName] = currentClass;
    });

    // Check if we're on the last class to select
    const isLastClass = currentClassIndex === allClasses.length - 1;

    if (isLastClass) {
      // Assign remaining unclassified images
      unknownImages.slice(0, 24).forEach((imageName) => {
        if (!newClassifications[imageName]) {
          // For state models with 2 classes, assign to the last class
          // For object models, assign to "none"
          if (step1Data.modelType === "state" && allClasses.length === 2) {
            newClassifications[imageName] = allClasses[allClasses.length - 1];
          } else {
            newClassifications[imageName] = "none";
          }
        }
      });

      // All done, trigger training immediately
      setImageClassifications(newClassifications);
      setIsProcessing(true);

      try {
        await processClassificationsAndTrain(newClassifications);
      } catch (error) {
        const axiosError = error as {
          response?: { data?: { message?: string; detail?: string } };
          message?: string;
        };
        const errorMessage =
          axiosError.response?.data?.message ||
          axiosError.response?.data?.detail ||
          axiosError.message ||
          "Failed to classify images";

        toast.error(
          t("wizard.step3.errors.classifyFailed", { error: errorMessage }),
        );
        setIsProcessing(false);
      }
    } else {
      // Move to next class
      setImageClassifications(newClassifications);
      setCurrentClassIndex((prev) => prev + 1);
      setSelectedImages(new Set());
    }
  }, [
    selectedImages,
    currentClass,
    currentClassIndex,
    allClasses,
    imageClassifications,
    unknownImages,
    step1Data,
    processClassificationsAndTrain,
    t,
  ]);

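  // Ask the backend to generate candidate training images: crops of the
  // configured camera areas for state models, or crops of the chosen object
  // label for object models.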
  const generateExamples = useCallback(async () => {
    setIsGenerating(true);

    try {
      if (step1Data.modelType === "state") {
        // For state models, use cameras and crop areas
        if (!step2Data?.cameraAreas || step2Data.cameraAreas.length === 0) {
          toast.error(t("wizard.step3.errors.noCameras"));
          setIsGenerating(false);
          return;
        }

        const cameras: { [key: string]: [number, number, number, number] } = {};
        step2Data.cameraAreas.forEach((area) => {
          cameras[area.camera] = area.crop;
        });

        await axios.post("/classification/generate_examples/state", {
          model_name: step1Data.modelName,
          cameras,
        });
      } else {
        // For object models, use label
        if (!step1Data.objectLabel) {
          toast.error(t("wizard.step3.errors.noObjectLabel"));
          setIsGenerating(false);
          return;
        }

        // For now, use all enabled cameras
        // TODO: In the future, we might want to let users select specific cameras
        await axios.post("/classification/generate_examples/object", {
          model_name: step1Data.modelName,
          label: step1Data.objectLabel,
        });
      }

      setHasGenerated(true);
      toast.success(t("wizard.step3.generateSuccess"));

      await refreshTrainImages();
    } catch (error) {
      const axiosError = error as {
        response?: { data?: { message?: string; detail?: string } };
        message?: string;
      };
      const errorMessage =
        axiosError.response?.data?.message ||
        axiosError.response?.data?.detail ||
        axiosError.message ||
        "Failed to generate examples";

      toast.error(
        t("wizard.step3.errors.generateFailed", { error: errorMessage }),
      );
    } finally {
      setIsGenerating(false);
    }
  }, [step1Data, step2Data, t, refreshTrainImages]);

  useEffect(() => {
    if (!hasGenerated && !isGenerating) {
      generateExamples();
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, []);

  const handleContinue = useCallback(async () => {
    setIsProcessing(true);
    try {
      await processClassificationsAndTrain(imageClassifications);
    } catch (error) {
      const axiosError = error as {
        response?: { data?: { message?: string; detail?: string } };
        message?: string;
      };
      const errorMessage =
        axiosError.response?.data?.message ||
        axiosError.response?.data?.detail ||
        axiosError.message ||
        "Failed to classify images";

      toast.error(
        t("wizard.step3.errors.classifyFailed", { error: errorMessage }),
      );
      setIsProcessing(false);
    }
  }, [imageClassifications, processClassificationsAndTrain, t]);

  const unclassifiedImages = useMemo(() => {
    if (!unknownImages) return [];
    const images = unknownImages.slice(0, 24);

    // Only filter if we have any classifications
    if (Object.keys(imageClassifications).length === 0) {
      return images;
    }

    return images.filter((img) => !imageClassifications[img]);
  }, [unknownImages, imageClassifications]);

  const allImagesClassified = useMemo(() => {
    return unclassifiedImages.length === 0;
  }, [unclassifiedImages]);

  return (
    <div className="flex flex-col gap-6">
      {isTraining ? (
        <div className="flex flex-col items-center gap-6 py-12">
          <ActivityIndicator className="size-12" />
          <div className="text-center">
            <h3 className="mb-2 text-lg font-medium">
              {t("wizard.step3.training.title")}
            </h3>
            <p className="text-sm text-muted-foreground">
              {t("wizard.step3.training.description")}
            </p>
          </div>
          <Button onClick={onClose} variant="select" className="mt-4">
            {t("button.close", { ns: "common" })}
          </Button>
        </div>
      ) : isGenerating ? (
        <div className="flex h-[50vh] flex-col items-center justify-center gap-4">
          <ActivityIndicator className="size-12" />
          <div className="text-center">
            <h3 className="mb-2 text-lg font-medium">
              {t("wizard.step3.generating.title")}
            </h3>
            <p className="text-sm text-muted-foreground">
              {t("wizard.step3.generating.description")}
            </p>
          </div>
        </div>
      ) : hasGenerated ? (
        <div className="flex flex-col gap-4">
          {!allImagesClassified && (
            <div className="text-center">
              <h3 className="text-lg font-medium">
                {t("wizard.step3.selectImagesPrompt", {
                  className: currentClass,
                })}
              </h3>
              <p className="text-sm text-muted-foreground">
                {t("wizard.step3.selectImagesDescription")}
              </p>
            </div>
          )}
          <div
            className={cn(
              "rounded-lg bg-secondary/30 p-4",
              isMobile && "max-h-[60vh] overflow-y-auto",
            )}
          >
            {!unknownImages || unknownImages.length === 0 ? (
              <div className="flex h-[40vh] flex-col items-center justify-center gap-4">
                <p className="text-muted-foreground">
                  {t("wizard.step3.noImages")}
                </p>
                <Button onClick={generateExamples} variant="select">
                  {t("wizard.step3.retryGenerate")}
                </Button>
              </div>
            ) : allImagesClassified && isProcessing ? (
              <div className="flex h-[40vh] flex-col items-center justify-center gap-4">
                <ActivityIndicator className="size-12" />
                <p className="text-lg font-medium">
                  {t("wizard.step3.classifying")}
                </p>
              </div>
            ) : (
              <div className="grid grid-cols-2 gap-4 sm:grid-cols-6">
                {unclassifiedImages.map((imageName, index) => {
                  const isSelected = selectedImages.has(imageName);
                  return (
                    <div
                      key={imageName}
                      className={cn(
                        "aspect-square cursor-pointer overflow-hidden rounded-lg border-2 bg-background transition-all",
                        isSelected && "border-selected ring-2 ring-selected",
                      )}
                      onClick={() => toggleImageSelection(imageName)}
                    >
                      <img
                        src={`${baseUrl}clips/${step1Data.modelName}/train/${imageName}`}
                        alt={`Example ${index + 1}`}
                        className="h-full w-full object-cover"
                      />
                    </div>
                  );
                })}
              </div>
            )}
          </div>
        </div>
      ) : (
        <div className="flex h-[50vh] flex-col items-center justify-center gap-4">
          <p className="text-sm text-destructive">
            {t("wizard.step3.errors.generationFailed")}
          </p>
          <Button onClick={generateExamples} variant="select">
            {t("wizard.step3.retryGenerate")}
          </Button>
        </div>
      )}

      {!isTraining && (
        <div className="flex flex-col gap-3 pt-3 sm:flex-row sm:justify-end sm:gap-4">
          <Button type="button" onClick={onBack} className="sm:flex-1">
            {t("button.back", { ns: "common" })}
          </Button>
          <Button
            type="button"
            onClick={
              allImagesClassified
                ? handleContinue
                : handleContinueClassification
            }
            variant="select"
            className="flex items-center justify-center gap-2 sm:flex-1"
            disabled={!hasGenerated || isGenerating || isProcessing}
          >
            {isProcessing && <ActivityIndicator className="size-4" />}
            {t("button.continue", { ns: "common" })}
          </Button>
        </div>
      )}
    </div>
  );
}
@ -10,11 +10,14 @@ import {
  CustomClassificationModelConfig,
  FrigateConfig,
} from "@/types/frigateConfig";
import { useMemo, useState } from "react";
import { useEffect, useMemo, useState } from "react";
import { isMobile } from "react-device-detect";
import { useTranslation } from "react-i18next";
import { FaFolderPlus } from "react-icons/fa";
import { MdModelTraining } from "react-icons/md";
import useSWR from "swr";
import Heading from "@/components/ui/heading";
import { useOverlayState } from "@/hooks/use-overlay-state";

const allModelTypes = ["objects", "states"] as const;
type ModelType = (typeof allModelTypes)[number];
@ -26,11 +29,24 @@ export default function ModelSelectionView({
  onClick,
}: ModelSelectionViewProps) {
  const { t } = useTranslation(["views/classificationModel"]);
  const [page, setPage] = useState<ModelType>("objects");
  const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
  const { data: config } = useSWR<FrigateConfig>("config", {
    revalidateOnFocus: false,
  });
  const [page, setPage] = useOverlayState<ModelType>("objects", "objects");
  const [pageToggle, setPageToggle] = useOptimisticState(
    page || "objects",
    setPage,
    100,
  );
  const { data: config, mutate: refreshConfig } = useSWR<FrigateConfig>(
    "config",
    {
      revalidateOnFocus: false,
    },
  );

  // title

  useEffect(() => {
    document.title = t("documentTitle");
  }, [t]);

  // data

@ -64,15 +80,15 @@
    return <ActivityIndicator />;
  }

  if (classificationConfigs.length == 0) {
    return <div>You need to setup a custom model configuration.</div>;
  }

  return (
    <div className="flex size-full flex-col p-2">
      <ClassificationModelWizardDialog
        open={newModel}
        onClose={() => setNewModel(false)}
        defaultModelType={pageToggle === "objects" ? "object" : "state"}
        onClose={() => {
          setNewModel(false);
          refreshConfig();
        }}
      />

      <div className="flex h-12 w-full items-center justify-between">
@ -84,7 +100,6 @@
        value={pageToggle}
        onValueChange={(value: ModelType) => {
          if (value) {
            // Restrict viewer navigation
            setPageToggle(value);
          }
        }}
@ -117,13 +132,46 @@
        </div>
      </div>
      <div className="flex size-full gap-2 p-2">
        {selectedClassificationConfigs.map((config) => (
          <ModelCard
            key={config.name}
            config={config}
            onClick={() => onClick(config)}
        {selectedClassificationConfigs.length === 0 ? (
          <NoModelsView
            onCreateModel={() => setNewModel(true)}
            modelType={pageToggle}
          />
        ))}
        ) : (
          selectedClassificationConfigs.map((config) => (
            <ModelCard
              key={config.name}
              config={config}
              onClick={() => onClick(config)}
            />
          ))
        )}
      </div>
    </div>
  );
}

function NoModelsView({
  onCreateModel,
  modelType,
}: {
  onCreateModel: () => void;
  modelType: ModelType;
}) {
  const { t } = useTranslation(["views/classificationModel"]);
  const typeKey = modelType === "objects" ? "object" : "state";

  return (
    <div className="flex size-full items-center justify-center">
      <div className="flex flex-col items-center gap-2">
        <MdModelTraining className="size-8" />
        <Heading as="h4">{t(`noModels.${typeKey}.title`)}</Heading>
        <div className="mb-3 text-center text-secondary-foreground">
          {t(`noModels.${typeKey}.description`)}
        </div>
        <Button size="sm" variant="select" onClick={onCreateModel}>
          {t(`noModels.${typeKey}.buttonText`)}
        </Button>
      </div>
    </div>
  );
@ -139,13 +187,17 @@ function ModelCard({ config, onClick }: ModelCardProps) {
  }>(`classification/${config.name}/dataset`, { revalidateOnFocus: false });

  const coverImage = useMemo(() => {
    if (!dataset?.length) {
    if (!dataset) {
      return undefined;
    }

    const keys = Object.keys(dataset).filter((key) => key != "none");
    const selectedKey = keys[0];

    if (!dataset[selectedKey]) {
      return undefined;
    }

    return {
      name: selectedKey,
      img: dataset[selectedKey][0],

@ -642,6 +642,7 @@ function DatasetGrid({
          filepath: `clips/${modelName}/dataset/${categoryName}/${image}`,
          name: "",
        }}
        showArea={false}
        selected={selectedImages.includes(image)}
        i18nLibrary="views/classificationModel"
        onClick={(data, _) => onClickImages([data.filename], true)}