mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-12-06 13:34:13 +03:00
Improve classification (#20863)
Some checks are pending
CI / AMD64 Build (push) Waiting to run
CI / ARM Build (push) Waiting to run
CI / Jetson Jetpack 6 (push) Waiting to run
CI / AMD64 Extra Build (push) Blocked by required conditions
CI / ARM Extra Build (push) Blocked by required conditions
CI / Synaptics Build (push) Blocked by required conditions
CI / Assemble and push default build (push) Blocked by required conditions
Some checks are pending
CI / AMD64 Build (push) Waiting to run
CI / ARM Build (push) Waiting to run
CI / Jetson Jetpack 6 (push) Waiting to run
CI / AMD64 Extra Build (push) Blocked by required conditions
CI / ARM Extra Build (push) Blocked by required conditions
CI / Synaptics Build (push) Blocked by required conditions
CI / Assemble and push default build (push) Blocked by required conditions
This commit is contained in:
parent
a374a60756
commit
99a363c047
@ -37,6 +37,8 @@ from frigate.models import Event
|
|||||||
from frigate.util.classification import (
|
from frigate.util.classification import (
|
||||||
collect_object_classification_examples,
|
collect_object_classification_examples,
|
||||||
collect_state_classification_examples,
|
collect_state_classification_examples,
|
||||||
|
get_dataset_image_count,
|
||||||
|
read_training_metadata,
|
||||||
)
|
)
|
||||||
from frigate.util.file import get_event_snapshot
|
from frigate.util.file import get_event_snapshot
|
||||||
|
|
||||||
@ -564,23 +566,54 @@ def get_classification_dataset(name: str):
|
|||||||
dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")
|
dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")
|
||||||
|
|
||||||
if not os.path.exists(dataset_dir):
|
if not os.path.exists(dataset_dir):
|
||||||
return JSONResponse(status_code=200, content={})
|
return JSONResponse(
|
||||||
|
status_code=200, content={"categories": {}, "training_metadata": None}
|
||||||
|
)
|
||||||
|
|
||||||
for name in os.listdir(dataset_dir):
|
for category_name in os.listdir(dataset_dir):
|
||||||
category_dir = os.path.join(dataset_dir, name)
|
category_dir = os.path.join(dataset_dir, category_name)
|
||||||
|
|
||||||
if not os.path.isdir(category_dir):
|
if not os.path.isdir(category_dir):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
dataset_dict[name] = []
|
dataset_dict[category_name] = []
|
||||||
|
|
||||||
for file in filter(
|
for file in filter(
|
||||||
lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
|
lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
|
||||||
os.listdir(category_dir),
|
os.listdir(category_dir),
|
||||||
):
|
):
|
||||||
dataset_dict[name].append(file)
|
dataset_dict[category_name].append(file)
|
||||||
|
|
||||||
return JSONResponse(status_code=200, content=dataset_dict)
|
# Get training metadata
|
||||||
|
metadata = read_training_metadata(sanitize_filename(name))
|
||||||
|
current_image_count = get_dataset_image_count(sanitize_filename(name))
|
||||||
|
|
||||||
|
if metadata is None:
|
||||||
|
training_metadata = {
|
||||||
|
"has_trained": False,
|
||||||
|
"last_training_date": None,
|
||||||
|
"last_training_image_count": 0,
|
||||||
|
"current_image_count": current_image_count,
|
||||||
|
"new_images_count": current_image_count,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
last_training_count = metadata.get("last_training_image_count", 0)
|
||||||
|
new_images_count = max(0, current_image_count - last_training_count)
|
||||||
|
training_metadata = {
|
||||||
|
"has_trained": True,
|
||||||
|
"last_training_date": metadata.get("last_training_date"),
|
||||||
|
"last_training_image_count": last_training_count,
|
||||||
|
"current_image_count": current_image_count,
|
||||||
|
"new_images_count": new_images_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content={
|
||||||
|
"categories": dataset_dict,
|
||||||
|
"training_metadata": training_metadata,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
|
|||||||
@ -23,6 +23,7 @@ class ModelStatusTypesEnum(str, Enum):
|
|||||||
error = "error"
|
error = "error"
|
||||||
training = "training"
|
training = "training"
|
||||||
complete = "complete"
|
complete = "complete"
|
||||||
|
failed = "failed"
|
||||||
|
|
||||||
|
|
||||||
class TrackedObjectUpdateTypesEnum(str, Enum):
|
class TrackedObjectUpdateTypesEnum(str, Enum):
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
"""Util for classification models."""
|
"""Util for classification models."""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@ -27,10 +29,96 @@ from frigate.util.process import FrigateProcess
|
|||||||
BATCH_SIZE = 16
|
BATCH_SIZE = 16
|
||||||
EPOCHS = 50
|
EPOCHS = 50
|
||||||
LEARNING_RATE = 0.001
|
LEARNING_RATE = 0.001
|
||||||
|
TRAINING_METADATA_FILE = ".training_metadata.json"
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def write_training_metadata(model_name: str, image_count: int) -> None:
|
||||||
|
"""
|
||||||
|
Write training metadata to a hidden file in the model's clips directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the classification model
|
||||||
|
image_count: Number of images used in training
|
||||||
|
"""
|
||||||
|
clips_model_dir = os.path.join(CLIPS_DIR, model_name)
|
||||||
|
os.makedirs(clips_model_dir, exist_ok=True)
|
||||||
|
|
||||||
|
metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
|
||||||
|
metadata = {
|
||||||
|
"last_training_date": datetime.datetime.now().isoformat(),
|
||||||
|
"last_training_image_count": image_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(metadata_path, "w") as f:
|
||||||
|
json.dump(metadata, f, indent=2)
|
||||||
|
logger.info(f"Wrote training metadata for {model_name}: {image_count} images")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to write training metadata for {model_name}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def read_training_metadata(model_name: str) -> dict[str, any] | None:
|
||||||
|
"""
|
||||||
|
Read training metadata from the hidden file in the model's clips directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the classification model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with last_training_date and last_training_image_count, or None if not found
|
||||||
|
"""
|
||||||
|
clips_model_dir = os.path.join(CLIPS_DIR, model_name)
|
||||||
|
metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
|
||||||
|
|
||||||
|
if not os.path.exists(metadata_path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(metadata_path, "r") as f:
|
||||||
|
metadata = json.load(f)
|
||||||
|
return metadata
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to read training metadata for {model_name}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset_image_count(model_name: str) -> int:
|
||||||
|
"""
|
||||||
|
Count the total number of images in the model's dataset directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the classification model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total count of images across all categories
|
||||||
|
"""
|
||||||
|
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
|
||||||
|
|
||||||
|
if not os.path.exists(dataset_dir):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
total_count = 0
|
||||||
|
try:
|
||||||
|
for category in os.listdir(dataset_dir):
|
||||||
|
category_dir = os.path.join(dataset_dir, category)
|
||||||
|
if not os.path.isdir(category_dir):
|
||||||
|
continue
|
||||||
|
|
||||||
|
image_files = [
|
||||||
|
f
|
||||||
|
for f in os.listdir(category_dir)
|
||||||
|
if f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))
|
||||||
|
]
|
||||||
|
total_count += len(image_files)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to count dataset images for {model_name}: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return total_count
|
||||||
|
|
||||||
|
|
||||||
class ClassificationTrainingProcess(FrigateProcess):
|
class ClassificationTrainingProcess(FrigateProcess):
|
||||||
def __init__(self, model_name: str) -> None:
|
def __init__(self, model_name: str) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -42,7 +130,8 @@ class ClassificationTrainingProcess(FrigateProcess):
|
|||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
self.pre_run_setup()
|
self.pre_run_setup()
|
||||||
self.__train_classification_model()
|
success = self.__train_classification_model()
|
||||||
|
exit(0 if success else 1)
|
||||||
|
|
||||||
def __generate_representative_dataset_factory(self, dataset_dir: str):
|
def __generate_representative_dataset_factory(self, dataset_dir: str):
|
||||||
def generate_representative_dataset():
|
def generate_representative_dataset():
|
||||||
@ -65,85 +154,117 @@ class ClassificationTrainingProcess(FrigateProcess):
|
|||||||
@redirect_output_to_logger(logger, logging.DEBUG)
|
@redirect_output_to_logger(logger, logging.DEBUG)
|
||||||
def __train_classification_model(self) -> bool:
|
def __train_classification_model(self) -> bool:
|
||||||
"""Train a classification model."""
|
"""Train a classification model."""
|
||||||
|
try:
|
||||||
|
# import in the function so that tensorflow is not initialized multiple times
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.keras import layers, models, optimizers
|
||||||
|
from tensorflow.keras.applications import MobileNetV2
|
||||||
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||||
|
|
||||||
# import in the function so that tensorflow is not initialized multiple times
|
dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
|
||||||
import tensorflow as tf
|
model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||||
from tensorflow.keras import layers, models, optimizers
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
from tensorflow.keras.applications import MobileNetV2
|
|
||||||
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
|
||||||
|
|
||||||
logger.info(f"Kicking off classification training for {self.model_name}.")
|
num_classes = len(
|
||||||
dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
|
[
|
||||||
model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
d
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
for d in os.listdir(dataset_dir)
|
||||||
num_classes = len(
|
if os.path.isdir(os.path.join(dataset_dir, d))
|
||||||
[
|
]
|
||||||
d
|
)
|
||||||
for d in os.listdir(dataset_dir)
|
|
||||||
if os.path.isdir(os.path.join(dataset_dir, d))
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Start with imagenet base model with 35% of channels in each layer
|
if num_classes < 2:
|
||||||
base_model = MobileNetV2(
|
logger.error(
|
||||||
input_shape=(224, 224, 3),
|
f"Training failed for {self.model_name}: Need at least 2 classes, found {num_classes}"
|
||||||
include_top=False,
|
)
|
||||||
weights="imagenet",
|
return False
|
||||||
alpha=0.35,
|
|
||||||
)
|
|
||||||
base_model.trainable = False # Freeze pre-trained layers
|
|
||||||
|
|
||||||
model = models.Sequential(
|
# Start with imagenet base model with 35% of channels in each layer
|
||||||
[
|
base_model = MobileNetV2(
|
||||||
base_model,
|
input_shape=(224, 224, 3),
|
||||||
layers.GlobalAveragePooling2D(),
|
include_top=False,
|
||||||
layers.Dense(128, activation="relu"),
|
weights="imagenet",
|
||||||
layers.Dropout(0.3),
|
alpha=0.35,
|
||||||
layers.Dense(num_classes, activation="softmax"),
|
)
|
||||||
]
|
base_model.trainable = False # Freeze pre-trained layers
|
||||||
)
|
|
||||||
|
|
||||||
model.compile(
|
model = models.Sequential(
|
||||||
optimizer=optimizers.Adam(learning_rate=LEARNING_RATE),
|
[
|
||||||
loss="categorical_crossentropy",
|
base_model,
|
||||||
metrics=["accuracy"],
|
layers.GlobalAveragePooling2D(),
|
||||||
)
|
layers.Dense(128, activation="relu"),
|
||||||
|
layers.Dropout(0.3),
|
||||||
|
layers.Dense(num_classes, activation="softmax"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
# create training set
|
model.compile(
|
||||||
datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2)
|
optimizer=optimizers.Adam(learning_rate=LEARNING_RATE),
|
||||||
train_gen = datagen.flow_from_directory(
|
loss="categorical_crossentropy",
|
||||||
dataset_dir,
|
metrics=["accuracy"],
|
||||||
target_size=(224, 224),
|
)
|
||||||
batch_size=BATCH_SIZE,
|
|
||||||
class_mode="categorical",
|
|
||||||
subset="training",
|
|
||||||
)
|
|
||||||
|
|
||||||
# write labelmap
|
# create training set
|
||||||
class_indices = train_gen.class_indices
|
datagen = ImageDataGenerator(rescale=1.0 / 255, validation_split=0.2)
|
||||||
index_to_class = {v: k for k, v in class_indices.items()}
|
train_gen = datagen.flow_from_directory(
|
||||||
sorted_classes = [index_to_class[i] for i in range(len(index_to_class))]
|
dataset_dir,
|
||||||
with open(os.path.join(model_dir, "labelmap.txt"), "w") as f:
|
target_size=(224, 224),
|
||||||
for class_name in sorted_classes:
|
batch_size=BATCH_SIZE,
|
||||||
f.write(f"{class_name}\n")
|
class_mode="categorical",
|
||||||
|
subset="training",
|
||||||
|
)
|
||||||
|
|
||||||
# train the model
|
total_images = train_gen.samples
|
||||||
model.fit(train_gen, epochs=EPOCHS, verbose=0)
|
logger.debug(
|
||||||
|
f"Training {self.model_name}: {total_images} images across {num_classes} classes"
|
||||||
|
)
|
||||||
|
|
||||||
# convert model to tflite
|
# write labelmap
|
||||||
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
class_indices = train_gen.class_indices
|
||||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
index_to_class = {v: k for k, v in class_indices.items()}
|
||||||
converter.representative_dataset = (
|
sorted_classes = [index_to_class[i] for i in range(len(index_to_class))]
|
||||||
self.__generate_representative_dataset_factory(dataset_dir)
|
with open(os.path.join(model_dir, "labelmap.txt"), "w") as f:
|
||||||
)
|
for class_name in sorted_classes:
|
||||||
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
f.write(f"{class_name}\n")
|
||||||
converter.inference_input_type = tf.uint8
|
|
||||||
converter.inference_output_type = tf.uint8
|
|
||||||
tflite_model = converter.convert()
|
|
||||||
|
|
||||||
# write model
|
# train the model
|
||||||
with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
|
logger.debug(f"Training {self.model_name} for {EPOCHS} epochs...")
|
||||||
f.write(tflite_model)
|
model.fit(train_gen, epochs=EPOCHS, verbose=0)
|
||||||
|
logger.debug(f"Converting {self.model_name} to TFLite...")
|
||||||
|
|
||||||
|
# convert model to tflite
|
||||||
|
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
||||||
|
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||||
|
converter.representative_dataset = (
|
||||||
|
self.__generate_representative_dataset_factory(dataset_dir)
|
||||||
|
)
|
||||||
|
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
||||||
|
converter.inference_input_type = tf.uint8
|
||||||
|
converter.inference_output_type = tf.uint8
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
|
||||||
|
# write model
|
||||||
|
model_path = os.path.join(model_dir, "model.tflite")
|
||||||
|
with open(model_path, "wb") as f:
|
||||||
|
f.write(tflite_model)
|
||||||
|
|
||||||
|
# verify model file was written successfully
|
||||||
|
if not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
|
||||||
|
logger.error(
|
||||||
|
f"Training failed for {self.model_name}: Model file was not created or is empty"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# write training metadata with image count
|
||||||
|
dataset_image_count = get_dataset_image_count(self.model_name)
|
||||||
|
write_training_metadata(self.model_name, dataset_image_count)
|
||||||
|
|
||||||
|
logger.info(f"Finished training {self.model_name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Training failed for {self.model_name}: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def kickoff_model_training(
|
def kickoff_model_training(
|
||||||
@ -165,18 +286,36 @@ def kickoff_model_training(
|
|||||||
training_process.start()
|
training_process.start()
|
||||||
training_process.join()
|
training_process.join()
|
||||||
|
|
||||||
# reload model and mark training as complete
|
# check if training succeeded by examining the exit code
|
||||||
embeddingRequestor.send_data(
|
training_success = training_process.exitcode == 0
|
||||||
EmbeddingsRequestEnum.reload_classification_model.value,
|
|
||||||
{"model_name": model_name},
|
if training_success:
|
||||||
)
|
# reload model and mark training as complete
|
||||||
requestor.send_data(
|
embeddingRequestor.send_data(
|
||||||
UPDATE_MODEL_STATE,
|
EmbeddingsRequestEnum.reload_classification_model.value,
|
||||||
{
|
{"model_name": model_name},
|
||||||
"model": model_name,
|
)
|
||||||
"state": ModelStatusTypesEnum.complete,
|
requestor.send_data(
|
||||||
},
|
UPDATE_MODEL_STATE,
|
||||||
)
|
{
|
||||||
|
"model": model_name,
|
||||||
|
"state": ModelStatusTypesEnum.complete,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Training subprocess failed for {model_name} (exit code: {training_process.exitcode})"
|
||||||
|
)
|
||||||
|
# mark training as failed so UI shows error state
|
||||||
|
# don't reload the model since it failed
|
||||||
|
requestor.send_data(
|
||||||
|
UPDATE_MODEL_STATE,
|
||||||
|
{
|
||||||
|
"model": model_name,
|
||||||
|
"state": ModelStatusTypesEnum.failed,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
requestor.stop()
|
requestor.stop()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -13,6 +13,11 @@
|
|||||||
"deleteModels": "Delete Models",
|
"deleteModels": "Delete Models",
|
||||||
"editModel": "Edit Model"
|
"editModel": "Edit Model"
|
||||||
},
|
},
|
||||||
|
"tooltip": {
|
||||||
|
"trainingInProgress": "Model is currently training",
|
||||||
|
"noNewImages": "No new images to train. Classify more images in the dataset first.",
|
||||||
|
"modelNotReady": "Model is not ready for training"
|
||||||
|
},
|
||||||
"toast": {
|
"toast": {
|
||||||
"success": {
|
"success": {
|
||||||
"deletedCategory": "Deleted Class",
|
"deletedCategory": "Deleted Class",
|
||||||
@ -30,7 +35,8 @@
|
|||||||
"deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
|
"deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
|
||||||
"deleteModelFailed": "Failed to delete model: {{errorMessage}}",
|
"deleteModelFailed": "Failed to delete model: {{errorMessage}}",
|
||||||
"categorizeFailed": "Failed to categorize image: {{errorMessage}}",
|
"categorizeFailed": "Failed to categorize image: {{errorMessage}}",
|
||||||
"trainingFailed": "Failed to start model training: {{errorMessage}}",
|
"trainingFailed": "Model training failed. Check Frigate logs for details.",
|
||||||
|
"trainingFailedToStart": "Failed to start model training: {{errorMessage}}",
|
||||||
"updateModelFailed": "Failed to update model: {{errorMessage}}",
|
"updateModelFailed": "Failed to update model: {{errorMessage}}",
|
||||||
"renameCategoryFailed": "Failed to rename class: {{errorMessage}}"
|
"renameCategoryFailed": "Failed to rename class: {{errorMessage}}"
|
||||||
}
|
}
|
||||||
@ -143,6 +149,8 @@
|
|||||||
"step3": {
|
"step3": {
|
||||||
"selectImagesPrompt": "Select all images with: {{className}}",
|
"selectImagesPrompt": "Select all images with: {{className}}",
|
||||||
"selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.",
|
"selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.",
|
||||||
|
"allImagesRequired_one": "Please classify all images. {{count}} image remaining.",
|
||||||
|
"allImagesRequired_other": "Please classify all images. {{count}} images remaining.",
|
||||||
"generating": {
|
"generating": {
|
||||||
"title": "Generating Sample Images",
|
"title": "Generating Sample Images",
|
||||||
"description": "Frigate is pulling representative images from your recordings. This may take a moment..."
|
"description": "Frigate is pulling representative images from your recordings. This may take a moment..."
|
||||||
|
|||||||
@ -10,6 +10,12 @@ import useSWR from "swr";
|
|||||||
import { baseUrl } from "@/api/baseUrl";
|
import { baseUrl } from "@/api/baseUrl";
|
||||||
import { isMobile } from "react-device-detect";
|
import { isMobile } from "react-device-detect";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
import {
|
||||||
|
Tooltip,
|
||||||
|
TooltipContent,
|
||||||
|
TooltipTrigger,
|
||||||
|
} from "@/components/ui/tooltip";
|
||||||
|
import { TooltipPortal } from "@radix-ui/react-tooltip";
|
||||||
|
|
||||||
export type Step3FormData = {
|
export type Step3FormData = {
|
||||||
examplesGenerated: boolean;
|
examplesGenerated: boolean;
|
||||||
@ -317,6 +323,19 @@ export default function Step3ChooseExamples({
|
|||||||
return unclassifiedImages.length === 0;
|
return unclassifiedImages.length === 0;
|
||||||
}, [unclassifiedImages]);
|
}, [unclassifiedImages]);
|
||||||
|
|
||||||
|
// For state models on the last class, require all images to be classified
|
||||||
|
const isLastClass = currentClassIndex === allClasses.length - 1;
|
||||||
|
const canProceed = useMemo(() => {
|
||||||
|
if (
|
||||||
|
step1Data.modelType === "state" &&
|
||||||
|
isLastClass &&
|
||||||
|
!allImagesClassified
|
||||||
|
) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}, [step1Data.modelType, isLastClass, allImagesClassified]);
|
||||||
|
|
||||||
const handleBack = useCallback(() => {
|
const handleBack = useCallback(() => {
|
||||||
if (currentClassIndex > 0) {
|
if (currentClassIndex > 0) {
|
||||||
const previousClass = allClasses[currentClassIndex - 1];
|
const previousClass = allClasses[currentClassIndex - 1];
|
||||||
@ -438,20 +457,35 @@ export default function Step3ChooseExamples({
|
|||||||
<Button type="button" onClick={handleBack} className="sm:flex-1">
|
<Button type="button" onClick={handleBack} className="sm:flex-1">
|
||||||
{t("button.back", { ns: "common" })}
|
{t("button.back", { ns: "common" })}
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Tooltip>
|
||||||
type="button"
|
<TooltipTrigger asChild>
|
||||||
onClick={
|
<Button
|
||||||
allImagesClassified
|
type="button"
|
||||||
? handleContinue
|
onClick={
|
||||||
: handleContinueClassification
|
allImagesClassified
|
||||||
}
|
? handleContinue
|
||||||
variant="select"
|
: handleContinueClassification
|
||||||
className="flex items-center justify-center gap-2 sm:flex-1"
|
}
|
||||||
disabled={!hasGenerated || isGenerating || isProcessing}
|
variant="select"
|
||||||
>
|
className="flex items-center justify-center gap-2 sm:flex-1"
|
||||||
{isProcessing && <ActivityIndicator className="size-4" />}
|
disabled={
|
||||||
{t("button.continue", { ns: "common" })}
|
!hasGenerated || isGenerating || isProcessing || !canProceed
|
||||||
</Button>
|
}
|
||||||
|
>
|
||||||
|
{isProcessing && <ActivityIndicator className="size-4" />}
|
||||||
|
{t("button.continue", { ns: "common" })}
|
||||||
|
</Button>
|
||||||
|
</TooltipTrigger>
|
||||||
|
{!canProceed && (
|
||||||
|
<TooltipPortal>
|
||||||
|
<TooltipContent>
|
||||||
|
{t("wizard.step3.allImagesRequired", {
|
||||||
|
count: unclassifiedImages.length,
|
||||||
|
})}
|
||||||
|
</TooltipContent>
|
||||||
|
</TooltipPortal>
|
||||||
|
)}
|
||||||
|
</Tooltip>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@ -87,7 +87,8 @@ export type ModelState =
|
|||||||
| "downloaded"
|
| "downloaded"
|
||||||
| "error"
|
| "error"
|
||||||
| "training"
|
| "training"
|
||||||
| "complete";
|
| "complete"
|
||||||
|
| "failed";
|
||||||
|
|
||||||
export type EmbeddingsReindexProgressType = {
|
export type EmbeddingsReindexProgressType = {
|
||||||
thumbnails: number;
|
thumbnails: number;
|
||||||
|
|||||||
@ -102,6 +102,12 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
position: "top-center",
|
position: "top-center",
|
||||||
});
|
});
|
||||||
setWasTraining(false);
|
setWasTraining(false);
|
||||||
|
refreshDataset();
|
||||||
|
} else if (modelState == "failed") {
|
||||||
|
toast.error(t("toast.error.trainingFailed"), {
|
||||||
|
position: "top-center",
|
||||||
|
});
|
||||||
|
setWasTraining(false);
|
||||||
}
|
}
|
||||||
// only refresh when modelState changes
|
// only refresh when modelState changes
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
@ -112,10 +118,20 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
|
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
|
||||||
`classification/${model.name}/train`,
|
`classification/${model.name}/train`,
|
||||||
);
|
);
|
||||||
const { data: dataset, mutate: refreshDataset } = useSWR<{
|
const { data: datasetResponse, mutate: refreshDataset } = useSWR<{
|
||||||
[id: string]: string[];
|
categories: { [id: string]: string[] };
|
||||||
|
training_metadata: {
|
||||||
|
has_trained: boolean;
|
||||||
|
last_training_date: string | null;
|
||||||
|
last_training_image_count: number;
|
||||||
|
current_image_count: number;
|
||||||
|
new_images_count: number;
|
||||||
|
} | null;
|
||||||
}>(`classification/${model.name}/dataset`);
|
}>(`classification/${model.name}/dataset`);
|
||||||
|
|
||||||
|
const dataset = datasetResponse?.categories || {};
|
||||||
|
const trainingMetadata = datasetResponse?.training_metadata;
|
||||||
|
|
||||||
const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();
|
const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();
|
||||||
|
|
||||||
const refreshAll = useCallback(() => {
|
const refreshAll = useCallback(() => {
|
||||||
@ -177,7 +193,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
error.response?.data?.detail ||
|
error.response?.data?.detail ||
|
||||||
"Unknown error";
|
"Unknown error";
|
||||||
|
|
||||||
toast.error(t("toast.error.trainingFailed", { errorMessage }), {
|
toast.error(t("toast.error.trainingFailedToStart", { errorMessage }), {
|
||||||
position: "top-center",
|
position: "top-center",
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@ -421,19 +437,48 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
filterValues={{ classes: Object.keys(dataset || {}) }}
|
filterValues={{ classes: Object.keys(dataset || {}) }}
|
||||||
onUpdateFilter={setTrainFilter}
|
onUpdateFilter={setTrainFilter}
|
||||||
/>
|
/>
|
||||||
<Button
|
<Tooltip>
|
||||||
className="flex justify-center gap-2"
|
<TooltipTrigger asChild>
|
||||||
onClick={trainModel}
|
<Button
|
||||||
variant="select"
|
className="flex justify-center gap-2"
|
||||||
disabled={modelState != "complete"}
|
onClick={trainModel}
|
||||||
>
|
variant={modelState == "failed" ? "destructive" : "select"}
|
||||||
{modelState == "training" ? (
|
disabled={
|
||||||
<ActivityIndicator size={20} />
|
(modelState != "complete" && modelState != "failed") ||
|
||||||
) : (
|
(trainingMetadata?.new_images_count ?? 0) === 0
|
||||||
<HiSparkles className="text-white" />
|
}
|
||||||
|
>
|
||||||
|
{modelState == "training" ? (
|
||||||
|
<ActivityIndicator size={20} />
|
||||||
|
) : (
|
||||||
|
<HiSparkles className="text-white" />
|
||||||
|
)}
|
||||||
|
{isDesktop && (
|
||||||
|
<>
|
||||||
|
{t("button.trainModel")}
|
||||||
|
{trainingMetadata?.new_images_count !== undefined &&
|
||||||
|
trainingMetadata.new_images_count > 0 && (
|
||||||
|
<span className="text-sm text-selected-foreground">
|
||||||
|
({trainingMetadata.new_images_count})
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
</TooltipTrigger>
|
||||||
|
{((trainingMetadata?.new_images_count ?? 0) === 0 ||
|
||||||
|
(modelState != "complete" && modelState != "failed")) && (
|
||||||
|
<TooltipPortal>
|
||||||
|
<TooltipContent>
|
||||||
|
{modelState == "training"
|
||||||
|
? t("tooltip.trainingInProgress")
|
||||||
|
: trainingMetadata?.new_images_count === 0
|
||||||
|
? t("tooltip.noNewImages")
|
||||||
|
: t("tooltip.modelNotReady")}
|
||||||
|
</TooltipContent>
|
||||||
|
</TooltipPortal>
|
||||||
)}
|
)}
|
||||||
{isDesktop && t("button.trainModel")}
|
</Tooltip>
|
||||||
</Button>
|
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user