mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-12-06 13:34:13 +03:00
Improve classification (#20863)
Some checks are pending
CI / AMD64 Build (push) Waiting to run
CI / ARM Build (push) Waiting to run
CI / Jetson Jetpack 6 (push) Waiting to run
CI / AMD64 Extra Build (push) Blocked by required conditions
CI / ARM Extra Build (push) Blocked by required conditions
CI / Synaptics Build (push) Blocked by required conditions
CI / Assemble and push default build (push) Blocked by required conditions
Some checks are pending
CI / AMD64 Build (push) Waiting to run
CI / ARM Build (push) Waiting to run
CI / Jetson Jetpack 6 (push) Waiting to run
CI / AMD64 Extra Build (push) Blocked by required conditions
CI / ARM Extra Build (push) Blocked by required conditions
CI / Synaptics Build (push) Blocked by required conditions
CI / Assemble and push default build (push) Blocked by required conditions
This commit is contained in:
parent
a374a60756
commit
99a363c047
@ -37,6 +37,8 @@ from frigate.models import Event
|
|||||||
from frigate.util.classification import (
|
from frigate.util.classification import (
|
||||||
collect_object_classification_examples,
|
collect_object_classification_examples,
|
||||||
collect_state_classification_examples,
|
collect_state_classification_examples,
|
||||||
|
get_dataset_image_count,
|
||||||
|
read_training_metadata,
|
||||||
)
|
)
|
||||||
from frigate.util.file import get_event_snapshot
|
from frigate.util.file import get_event_snapshot
|
||||||
|
|
||||||
@ -564,23 +566,54 @@ def get_classification_dataset(name: str):
|
|||||||
dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")
|
dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")
|
||||||
|
|
||||||
if not os.path.exists(dataset_dir):
|
if not os.path.exists(dataset_dir):
|
||||||
return JSONResponse(status_code=200, content={})
|
return JSONResponse(
|
||||||
|
status_code=200, content={"categories": {}, "training_metadata": None}
|
||||||
|
)
|
||||||
|
|
||||||
for name in os.listdir(dataset_dir):
|
for category_name in os.listdir(dataset_dir):
|
||||||
category_dir = os.path.join(dataset_dir, name)
|
category_dir = os.path.join(dataset_dir, category_name)
|
||||||
|
|
||||||
if not os.path.isdir(category_dir):
|
if not os.path.isdir(category_dir):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
dataset_dict[name] = []
|
dataset_dict[category_name] = []
|
||||||
|
|
||||||
for file in filter(
|
for file in filter(
|
||||||
lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
|
lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
|
||||||
os.listdir(category_dir),
|
os.listdir(category_dir),
|
||||||
):
|
):
|
||||||
dataset_dict[name].append(file)
|
dataset_dict[category_name].append(file)
|
||||||
|
|
||||||
return JSONResponse(status_code=200, content=dataset_dict)
|
# Get training metadata
|
||||||
|
metadata = read_training_metadata(sanitize_filename(name))
|
||||||
|
current_image_count = get_dataset_image_count(sanitize_filename(name))
|
||||||
|
|
||||||
|
if metadata is None:
|
||||||
|
training_metadata = {
|
||||||
|
"has_trained": False,
|
||||||
|
"last_training_date": None,
|
||||||
|
"last_training_image_count": 0,
|
||||||
|
"current_image_count": current_image_count,
|
||||||
|
"new_images_count": current_image_count,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
last_training_count = metadata.get("last_training_image_count", 0)
|
||||||
|
new_images_count = max(0, current_image_count - last_training_count)
|
||||||
|
training_metadata = {
|
||||||
|
"has_trained": True,
|
||||||
|
"last_training_date": metadata.get("last_training_date"),
|
||||||
|
"last_training_image_count": last_training_count,
|
||||||
|
"current_image_count": current_image_count,
|
||||||
|
"new_images_count": new_images_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content={
|
||||||
|
"categories": dataset_dict,
|
||||||
|
"training_metadata": training_metadata,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
|
|||||||
@ -23,6 +23,7 @@ class ModelStatusTypesEnum(str, Enum):
|
|||||||
error = "error"
|
error = "error"
|
||||||
training = "training"
|
training = "training"
|
||||||
complete = "complete"
|
complete = "complete"
|
||||||
|
failed = "failed"
|
||||||
|
|
||||||
|
|
||||||
class TrackedObjectUpdateTypesEnum(str, Enum):
|
class TrackedObjectUpdateTypesEnum(str, Enum):
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
"""Util for classification models."""
|
"""Util for classification models."""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@ -27,10 +29,96 @@ from frigate.util.process import FrigateProcess
|
|||||||
BATCH_SIZE = 16
|
BATCH_SIZE = 16
|
||||||
EPOCHS = 50
|
EPOCHS = 50
|
||||||
LEARNING_RATE = 0.001
|
LEARNING_RATE = 0.001
|
||||||
|
TRAINING_METADATA_FILE = ".training_metadata.json"
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def write_training_metadata(model_name: str, image_count: int) -> None:
|
||||||
|
"""
|
||||||
|
Write training metadata to a hidden file in the model's clips directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the classification model
|
||||||
|
image_count: Number of images used in training
|
||||||
|
"""
|
||||||
|
clips_model_dir = os.path.join(CLIPS_DIR, model_name)
|
||||||
|
os.makedirs(clips_model_dir, exist_ok=True)
|
||||||
|
|
||||||
|
metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
|
||||||
|
metadata = {
|
||||||
|
"last_training_date": datetime.datetime.now().isoformat(),
|
||||||
|
"last_training_image_count": image_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(metadata_path, "w") as f:
|
||||||
|
json.dump(metadata, f, indent=2)
|
||||||
|
logger.info(f"Wrote training metadata for {model_name}: {image_count} images")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to write training metadata for {model_name}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def read_training_metadata(model_name: str) -> dict[str, any] | None:
|
||||||
|
"""
|
||||||
|
Read training metadata from the hidden file in the model's clips directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the classification model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with last_training_date and last_training_image_count, or None if not found
|
||||||
|
"""
|
||||||
|
clips_model_dir = os.path.join(CLIPS_DIR, model_name)
|
||||||
|
metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
|
||||||
|
|
||||||
|
if not os.path.exists(metadata_path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(metadata_path, "r") as f:
|
||||||
|
metadata = json.load(f)
|
||||||
|
return metadata
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to read training metadata for {model_name}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset_image_count(model_name: str) -> int:
|
||||||
|
"""
|
||||||
|
Count the total number of images in the model's dataset directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the classification model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total count of images across all categories
|
||||||
|
"""
|
||||||
|
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
|
||||||
|
|
||||||
|
if not os.path.exists(dataset_dir):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
total_count = 0
|
||||||
|
try:
|
||||||
|
for category in os.listdir(dataset_dir):
|
||||||
|
category_dir = os.path.join(dataset_dir, category)
|
||||||
|
if not os.path.isdir(category_dir):
|
||||||
|
continue
|
||||||
|
|
||||||
|
image_files = [
|
||||||
|
f
|
||||||
|
for f in os.listdir(category_dir)
|
||||||
|
if f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))
|
||||||
|
]
|
||||||
|
total_count += len(image_files)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to count dataset images for {model_name}: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return total_count
|
||||||
|
|
||||||
|
|
||||||
class ClassificationTrainingProcess(FrigateProcess):
|
class ClassificationTrainingProcess(FrigateProcess):
|
||||||
def __init__(self, model_name: str) -> None:
|
def __init__(self, model_name: str) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -42,7 +130,8 @@ class ClassificationTrainingProcess(FrigateProcess):
|
|||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
self.pre_run_setup()
|
self.pre_run_setup()
|
||||||
self.__train_classification_model()
|
success = self.__train_classification_model()
|
||||||
|
exit(0 if success else 1)
|
||||||
|
|
||||||
def __generate_representative_dataset_factory(self, dataset_dir: str):
|
def __generate_representative_dataset_factory(self, dataset_dir: str):
|
||||||
def generate_representative_dataset():
|
def generate_representative_dataset():
|
||||||
@ -65,17 +154,17 @@ class ClassificationTrainingProcess(FrigateProcess):
|
|||||||
@redirect_output_to_logger(logger, logging.DEBUG)
|
@redirect_output_to_logger(logger, logging.DEBUG)
|
||||||
def __train_classification_model(self) -> bool:
|
def __train_classification_model(self) -> bool:
|
||||||
"""Train a classification model."""
|
"""Train a classification model."""
|
||||||
|
try:
|
||||||
# import in the function so that tensorflow is not initialized multiple times
|
# import in the function so that tensorflow is not initialized multiple times
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.keras import layers, models, optimizers
|
from tensorflow.keras import layers, models, optimizers
|
||||||
from tensorflow.keras.applications import MobileNetV2
|
from tensorflow.keras.applications import MobileNetV2
|
||||||
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||||
|
|
||||||
logger.info(f"Kicking off classification training for {self.model_name}.")
|
|
||||||
dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
|
dataset_dir = os.path.join(CLIPS_DIR, self.model_name, "dataset")
|
||||||
model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
model_dir = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
|
||||||
num_classes = len(
|
num_classes = len(
|
||||||
[
|
[
|
||||||
d
|
d
|
||||||
@ -84,6 +173,12 @@ class ClassificationTrainingProcess(FrigateProcess):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if num_classes < 2:
|
||||||
|
logger.error(
|
||||||
|
f"Training failed for {self.model_name}: Need at least 2 classes, found {num_classes}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
# Start with imagenet base model with 35% of channels in each layer
|
# Start with imagenet base model with 35% of channels in each layer
|
||||||
base_model = MobileNetV2(
|
base_model = MobileNetV2(
|
||||||
input_shape=(224, 224, 3),
|
input_shape=(224, 224, 3),
|
||||||
@ -119,6 +214,11 @@ class ClassificationTrainingProcess(FrigateProcess):
|
|||||||
subset="training",
|
subset="training",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
total_images = train_gen.samples
|
||||||
|
logger.debug(
|
||||||
|
f"Training {self.model_name}: {total_images} images across {num_classes} classes"
|
||||||
|
)
|
||||||
|
|
||||||
# write labelmap
|
# write labelmap
|
||||||
class_indices = train_gen.class_indices
|
class_indices = train_gen.class_indices
|
||||||
index_to_class = {v: k for k, v in class_indices.items()}
|
index_to_class = {v: k for k, v in class_indices.items()}
|
||||||
@ -128,7 +228,9 @@ class ClassificationTrainingProcess(FrigateProcess):
|
|||||||
f.write(f"{class_name}\n")
|
f.write(f"{class_name}\n")
|
||||||
|
|
||||||
# train the model
|
# train the model
|
||||||
|
logger.debug(f"Training {self.model_name} for {EPOCHS} epochs...")
|
||||||
model.fit(train_gen, epochs=EPOCHS, verbose=0)
|
model.fit(train_gen, epochs=EPOCHS, verbose=0)
|
||||||
|
logger.debug(f"Converting {self.model_name} to TFLite...")
|
||||||
|
|
||||||
# convert model to tflite
|
# convert model to tflite
|
||||||
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
||||||
@ -142,9 +244,28 @@ class ClassificationTrainingProcess(FrigateProcess):
|
|||||||
tflite_model = converter.convert()
|
tflite_model = converter.convert()
|
||||||
|
|
||||||
# write model
|
# write model
|
||||||
with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
|
model_path = os.path.join(model_dir, "model.tflite")
|
||||||
|
with open(model_path, "wb") as f:
|
||||||
f.write(tflite_model)
|
f.write(tflite_model)
|
||||||
|
|
||||||
|
# verify model file was written successfully
|
||||||
|
if not os.path.exists(model_path) or os.path.getsize(model_path) == 0:
|
||||||
|
logger.error(
|
||||||
|
f"Training failed for {self.model_name}: Model file was not created or is empty"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# write training metadata with image count
|
||||||
|
dataset_image_count = get_dataset_image_count(self.model_name)
|
||||||
|
write_training_metadata(self.model_name, dataset_image_count)
|
||||||
|
|
||||||
|
logger.info(f"Finished training {self.model_name}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Training failed for {self.model_name}: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def kickoff_model_training(
|
def kickoff_model_training(
|
||||||
embeddingRequestor: EmbeddingsRequestor, model_name: str
|
embeddingRequestor: EmbeddingsRequestor, model_name: str
|
||||||
@ -165,6 +286,10 @@ def kickoff_model_training(
|
|||||||
training_process.start()
|
training_process.start()
|
||||||
training_process.join()
|
training_process.join()
|
||||||
|
|
||||||
|
# check if training succeeded by examining the exit code
|
||||||
|
training_success = training_process.exitcode == 0
|
||||||
|
|
||||||
|
if training_success:
|
||||||
# reload model and mark training as complete
|
# reload model and mark training as complete
|
||||||
embeddingRequestor.send_data(
|
embeddingRequestor.send_data(
|
||||||
EmbeddingsRequestEnum.reload_classification_model.value,
|
EmbeddingsRequestEnum.reload_classification_model.value,
|
||||||
@ -177,6 +302,20 @@ def kickoff_model_training(
|
|||||||
"state": ModelStatusTypesEnum.complete,
|
"state": ModelStatusTypesEnum.complete,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"Training subprocess failed for {model_name} (exit code: {training_process.exitcode})"
|
||||||
|
)
|
||||||
|
# mark training as failed so UI shows error state
|
||||||
|
# don't reload the model since it failed
|
||||||
|
requestor.send_data(
|
||||||
|
UPDATE_MODEL_STATE,
|
||||||
|
{
|
||||||
|
"model": model_name,
|
||||||
|
"state": ModelStatusTypesEnum.failed,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
requestor.stop()
|
requestor.stop()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -13,6 +13,11 @@
|
|||||||
"deleteModels": "Delete Models",
|
"deleteModels": "Delete Models",
|
||||||
"editModel": "Edit Model"
|
"editModel": "Edit Model"
|
||||||
},
|
},
|
||||||
|
"tooltip": {
|
||||||
|
"trainingInProgress": "Model is currently training",
|
||||||
|
"noNewImages": "No new images to train. Classify more images in the dataset first.",
|
||||||
|
"modelNotReady": "Model is not ready for training"
|
||||||
|
},
|
||||||
"toast": {
|
"toast": {
|
||||||
"success": {
|
"success": {
|
||||||
"deletedCategory": "Deleted Class",
|
"deletedCategory": "Deleted Class",
|
||||||
@ -30,7 +35,8 @@
|
|||||||
"deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
|
"deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
|
||||||
"deleteModelFailed": "Failed to delete model: {{errorMessage}}",
|
"deleteModelFailed": "Failed to delete model: {{errorMessage}}",
|
||||||
"categorizeFailed": "Failed to categorize image: {{errorMessage}}",
|
"categorizeFailed": "Failed to categorize image: {{errorMessage}}",
|
||||||
"trainingFailed": "Failed to start model training: {{errorMessage}}",
|
"trainingFailed": "Model training failed. Check Frigate logs for details.",
|
||||||
|
"trainingFailedToStart": "Failed to start model training: {{errorMessage}}",
|
||||||
"updateModelFailed": "Failed to update model: {{errorMessage}}",
|
"updateModelFailed": "Failed to update model: {{errorMessage}}",
|
||||||
"renameCategoryFailed": "Failed to rename class: {{errorMessage}}"
|
"renameCategoryFailed": "Failed to rename class: {{errorMessage}}"
|
||||||
}
|
}
|
||||||
@ -143,6 +149,8 @@
|
|||||||
"step3": {
|
"step3": {
|
||||||
"selectImagesPrompt": "Select all images with: {{className}}",
|
"selectImagesPrompt": "Select all images with: {{className}}",
|
||||||
"selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.",
|
"selectImagesDescription": "Click on images to select them. Click Continue when you're done with this class.",
|
||||||
|
"allImagesRequired_one": "Please classify all images. {{count}} image remaining.",
|
||||||
|
"allImagesRequired_other": "Please classify all images. {{count}} images remaining.",
|
||||||
"generating": {
|
"generating": {
|
||||||
"title": "Generating Sample Images",
|
"title": "Generating Sample Images",
|
||||||
"description": "Frigate is pulling representative images from your recordings. This may take a moment..."
|
"description": "Frigate is pulling representative images from your recordings. This may take a moment..."
|
||||||
|
|||||||
@ -10,6 +10,12 @@ import useSWR from "swr";
|
|||||||
import { baseUrl } from "@/api/baseUrl";
|
import { baseUrl } from "@/api/baseUrl";
|
||||||
import { isMobile } from "react-device-detect";
|
import { isMobile } from "react-device-detect";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
import {
|
||||||
|
Tooltip,
|
||||||
|
TooltipContent,
|
||||||
|
TooltipTrigger,
|
||||||
|
} from "@/components/ui/tooltip";
|
||||||
|
import { TooltipPortal } from "@radix-ui/react-tooltip";
|
||||||
|
|
||||||
export type Step3FormData = {
|
export type Step3FormData = {
|
||||||
examplesGenerated: boolean;
|
examplesGenerated: boolean;
|
||||||
@ -317,6 +323,19 @@ export default function Step3ChooseExamples({
|
|||||||
return unclassifiedImages.length === 0;
|
return unclassifiedImages.length === 0;
|
||||||
}, [unclassifiedImages]);
|
}, [unclassifiedImages]);
|
||||||
|
|
||||||
|
// For state models on the last class, require all images to be classified
|
||||||
|
const isLastClass = currentClassIndex === allClasses.length - 1;
|
||||||
|
const canProceed = useMemo(() => {
|
||||||
|
if (
|
||||||
|
step1Data.modelType === "state" &&
|
||||||
|
isLastClass &&
|
||||||
|
!allImagesClassified
|
||||||
|
) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}, [step1Data.modelType, isLastClass, allImagesClassified]);
|
||||||
|
|
||||||
const handleBack = useCallback(() => {
|
const handleBack = useCallback(() => {
|
||||||
if (currentClassIndex > 0) {
|
if (currentClassIndex > 0) {
|
||||||
const previousClass = allClasses[currentClassIndex - 1];
|
const previousClass = allClasses[currentClassIndex - 1];
|
||||||
@ -438,6 +457,8 @@ export default function Step3ChooseExamples({
|
|||||||
<Button type="button" onClick={handleBack} className="sm:flex-1">
|
<Button type="button" onClick={handleBack} className="sm:flex-1">
|
||||||
{t("button.back", { ns: "common" })}
|
{t("button.back", { ns: "common" })}
|
||||||
</Button>
|
</Button>
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
<Button
|
<Button
|
||||||
type="button"
|
type="button"
|
||||||
onClick={
|
onClick={
|
||||||
@ -447,11 +468,24 @@ export default function Step3ChooseExamples({
|
|||||||
}
|
}
|
||||||
variant="select"
|
variant="select"
|
||||||
className="flex items-center justify-center gap-2 sm:flex-1"
|
className="flex items-center justify-center gap-2 sm:flex-1"
|
||||||
disabled={!hasGenerated || isGenerating || isProcessing}
|
disabled={
|
||||||
|
!hasGenerated || isGenerating || isProcessing || !canProceed
|
||||||
|
}
|
||||||
>
|
>
|
||||||
{isProcessing && <ActivityIndicator className="size-4" />}
|
{isProcessing && <ActivityIndicator className="size-4" />}
|
||||||
{t("button.continue", { ns: "common" })}
|
{t("button.continue", { ns: "common" })}
|
||||||
</Button>
|
</Button>
|
||||||
|
</TooltipTrigger>
|
||||||
|
{!canProceed && (
|
||||||
|
<TooltipPortal>
|
||||||
|
<TooltipContent>
|
||||||
|
{t("wizard.step3.allImagesRequired", {
|
||||||
|
count: unclassifiedImages.length,
|
||||||
|
})}
|
||||||
|
</TooltipContent>
|
||||||
|
</TooltipPortal>
|
||||||
|
)}
|
||||||
|
</Tooltip>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@ -87,7 +87,8 @@ export type ModelState =
|
|||||||
| "downloaded"
|
| "downloaded"
|
||||||
| "error"
|
| "error"
|
||||||
| "training"
|
| "training"
|
||||||
| "complete";
|
| "complete"
|
||||||
|
| "failed";
|
||||||
|
|
||||||
export type EmbeddingsReindexProgressType = {
|
export type EmbeddingsReindexProgressType = {
|
||||||
thumbnails: number;
|
thumbnails: number;
|
||||||
|
|||||||
@ -102,6 +102,12 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
position: "top-center",
|
position: "top-center",
|
||||||
});
|
});
|
||||||
setWasTraining(false);
|
setWasTraining(false);
|
||||||
|
refreshDataset();
|
||||||
|
} else if (modelState == "failed") {
|
||||||
|
toast.error(t("toast.error.trainingFailed"), {
|
||||||
|
position: "top-center",
|
||||||
|
});
|
||||||
|
setWasTraining(false);
|
||||||
}
|
}
|
||||||
// only refresh when modelState changes
|
// only refresh when modelState changes
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
@ -112,10 +118,20 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
|
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
|
||||||
`classification/${model.name}/train`,
|
`classification/${model.name}/train`,
|
||||||
);
|
);
|
||||||
const { data: dataset, mutate: refreshDataset } = useSWR<{
|
const { data: datasetResponse, mutate: refreshDataset } = useSWR<{
|
||||||
[id: string]: string[];
|
categories: { [id: string]: string[] };
|
||||||
|
training_metadata: {
|
||||||
|
has_trained: boolean;
|
||||||
|
last_training_date: string | null;
|
||||||
|
last_training_image_count: number;
|
||||||
|
current_image_count: number;
|
||||||
|
new_images_count: number;
|
||||||
|
} | null;
|
||||||
}>(`classification/${model.name}/dataset`);
|
}>(`classification/${model.name}/dataset`);
|
||||||
|
|
||||||
|
const dataset = datasetResponse?.categories || {};
|
||||||
|
const trainingMetadata = datasetResponse?.training_metadata;
|
||||||
|
|
||||||
const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();
|
const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();
|
||||||
|
|
||||||
const refreshAll = useCallback(() => {
|
const refreshAll = useCallback(() => {
|
||||||
@ -177,7 +193,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
error.response?.data?.detail ||
|
error.response?.data?.detail ||
|
||||||
"Unknown error";
|
"Unknown error";
|
||||||
|
|
||||||
toast.error(t("toast.error.trainingFailed", { errorMessage }), {
|
toast.error(t("toast.error.trainingFailedToStart", { errorMessage }), {
|
||||||
position: "top-center",
|
position: "top-center",
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@ -421,19 +437,48 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
filterValues={{ classes: Object.keys(dataset || {}) }}
|
filterValues={{ classes: Object.keys(dataset || {}) }}
|
||||||
onUpdateFilter={setTrainFilter}
|
onUpdateFilter={setTrainFilter}
|
||||||
/>
|
/>
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
<Button
|
<Button
|
||||||
className="flex justify-center gap-2"
|
className="flex justify-center gap-2"
|
||||||
onClick={trainModel}
|
onClick={trainModel}
|
||||||
variant="select"
|
variant={modelState == "failed" ? "destructive" : "select"}
|
||||||
disabled={modelState != "complete"}
|
disabled={
|
||||||
|
(modelState != "complete" && modelState != "failed") ||
|
||||||
|
(trainingMetadata?.new_images_count ?? 0) === 0
|
||||||
|
}
|
||||||
>
|
>
|
||||||
{modelState == "training" ? (
|
{modelState == "training" ? (
|
||||||
<ActivityIndicator size={20} />
|
<ActivityIndicator size={20} />
|
||||||
) : (
|
) : (
|
||||||
<HiSparkles className="text-white" />
|
<HiSparkles className="text-white" />
|
||||||
)}
|
)}
|
||||||
{isDesktop && t("button.trainModel")}
|
{isDesktop && (
|
||||||
|
<>
|
||||||
|
{t("button.trainModel")}
|
||||||
|
{trainingMetadata?.new_images_count !== undefined &&
|
||||||
|
trainingMetadata.new_images_count > 0 && (
|
||||||
|
<span className="text-sm text-selected-foreground">
|
||||||
|
({trainingMetadata.new_images_count})
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</Button>
|
</Button>
|
||||||
|
</TooltipTrigger>
|
||||||
|
{((trainingMetadata?.new_images_count ?? 0) === 0 ||
|
||||||
|
(modelState != "complete" && modelState != "failed")) && (
|
||||||
|
<TooltipPortal>
|
||||||
|
<TooltipContent>
|
||||||
|
{modelState == "training"
|
||||||
|
? t("tooltip.trainingInProgress")
|
||||||
|
: trainingMetadata?.new_images_count === 0
|
||||||
|
? t("tooltip.noNewImages")
|
||||||
|
: t("tooltip.modelNotReady")}
|
||||||
|
</TooltipContent>
|
||||||
|
</TooltipPortal>
|
||||||
|
)}
|
||||||
|
</Tooltip>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user