Mirror of https://github.com/blakeblackshear/frigate.git (synced 2025-12-06 13:34:13 +03:00)
Improve train state by showing number of images to classify and adding tooltip
This commit is contained in:
parent a374a60756
commit b62de79c39
@@ -37,6 +37,8 @@ from frigate.models import Event
 from frigate.util.classification import (
     collect_object_classification_examples,
     collect_state_classification_examples,
+    get_dataset_image_count,
+    read_training_metadata,
 )
 from frigate.util.file import get_event_snapshot

@@ -564,23 +566,54 @@ def get_classification_dataset(name: str):
     dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")

     if not os.path.exists(dataset_dir):
-        return JSONResponse(status_code=200, content={})
+        return JSONResponse(
+            status_code=200, content={"categories": {}, "training_metadata": None}
+        )

-    for name in os.listdir(dataset_dir):
-        category_dir = os.path.join(dataset_dir, name)
+    for category_name in os.listdir(dataset_dir):
+        category_dir = os.path.join(dataset_dir, category_name)

         if not os.path.isdir(category_dir):
             continue

-        dataset_dict[name] = []
+        dataset_dict[category_name] = []

         for file in filter(
             lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
             os.listdir(category_dir),
         ):
-            dataset_dict[name].append(file)
+            dataset_dict[category_name].append(file)

-    return JSONResponse(status_code=200, content=dataset_dict)
+    # Get training metadata
+    metadata = read_training_metadata(sanitize_filename(name))
+    current_image_count = get_dataset_image_count(sanitize_filename(name))
+
+    if metadata is None:
+        training_metadata = {
+            "has_trained": False,
+            "last_training_date": None,
+            "last_training_image_count": 0,
+            "current_image_count": current_image_count,
+            "new_images_count": current_image_count,
+        }
+    else:
+        last_training_count = metadata.get("last_training_image_count", 0)
+        new_images_count = max(0, current_image_count - last_training_count)
+        training_metadata = {
+            "has_trained": True,
+            "last_training_date": metadata.get("last_training_date"),
+            "last_training_image_count": last_training_count,
+            "current_image_count": current_image_count,
+            "new_images_count": new_images_count,
+        }
+
+    return JSONResponse(
+        status_code=200,
+        content={
+            "categories": dataset_dict,
+            "training_metadata": training_metadata,
+        },
+    )


 @router.get(
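With this change, a GET to classification/{name}/dataset no longer returns a bare mapping of category names to image files; the mapping moves under a "categories" key and is accompanied by a "training_metadata" object. As a rough illustration of the new response body (category names, file names, and counts are invented for the example):

    {
      "categories": {
        "cat": ["0001.webp", "0002.webp"],
        "dog": ["0003.webp"]
      },
      "training_metadata": {
        "has_trained": true,
        "last_training_date": "2025-01-01T12:00:00",
        "last_training_image_count": 2,
        "current_image_count": 3,
        "new_images_count": 1
      }
    }

The frontend change further below reads exactly these fields to decide whether the train button is enabled, what count to show next to it, and which tooltip to display.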
@@ -1,5 +1,7 @@
 """Util for classification models."""

+import datetime
+import json
 import logging
 import os
 import random
@@ -27,10 +29,96 @@ from frigate.util.process import FrigateProcess
 BATCH_SIZE = 16
 EPOCHS = 50
 LEARNING_RATE = 0.001
+TRAINING_METADATA_FILE = ".training_metadata.json"

 logger = logging.getLogger(__name__)


+def write_training_metadata(model_name: str, image_count: int) -> None:
+    """
+    Write training metadata to a hidden file in the model's clips directory.
+
+    Args:
+        model_name: Name of the classification model
+        image_count: Number of images used in training
+    """
+    clips_model_dir = os.path.join(CLIPS_DIR, model_name)
+    os.makedirs(clips_model_dir, exist_ok=True)
+
+    metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
+    metadata = {
+        "last_training_date": datetime.datetime.now().isoformat(),
+        "last_training_image_count": image_count,
+    }
+
+    try:
+        with open(metadata_path, "w") as f:
+            json.dump(metadata, f, indent=2)
+        logger.info(f"Wrote training metadata for {model_name}: {image_count} images")
+    except Exception as e:
+        logger.error(f"Failed to write training metadata for {model_name}: {e}")
+
+
+def read_training_metadata(model_name: str) -> dict[str, any] | None:
+    """
+    Read training metadata from the hidden file in the model's clips directory.
+
+    Args:
+        model_name: Name of the classification model
+
+    Returns:
+        Dictionary with last_training_date and last_training_image_count, or None if not found
+    """
+    clips_model_dir = os.path.join(CLIPS_DIR, model_name)
+    metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
+
+    if not os.path.exists(metadata_path):
+        return None
+
+    try:
+        with open(metadata_path, "r") as f:
+            metadata = json.load(f)
+        return metadata
+    except Exception as e:
+        logger.error(f"Failed to read training metadata for {model_name}: {e}")
+        return None
+
+
+def get_dataset_image_count(model_name: str) -> int:
+    """
+    Count the total number of images in the model's dataset directory.
+
+    Args:
+        model_name: Name of the classification model
+
+    Returns:
+        Total count of images across all categories
+    """
+    dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
+
+    if not os.path.exists(dataset_dir):
+        return 0
+
+    total_count = 0
+    try:
+        for category in os.listdir(dataset_dir):
+            category_dir = os.path.join(dataset_dir, category)
+            if not os.path.isdir(category_dir):
+                continue
+
+            image_files = [
+                f
+                for f in os.listdir(category_dir)
+                if f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))
+            ]
+            total_count += len(image_files)
+    except Exception as e:
+        logger.error(f"Failed to count dataset images for {model_name}: {e}")
+        return 0
+
+    return total_count
+
+
 class ClassificationTrainingProcess(FrigateProcess):
     def __init__(self, model_name: str) -> None:
         super().__init__(
@@ -145,6 +233,10 @@ class ClassificationTrainingProcess(FrigateProcess):
         with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
             f.write(tflite_model)

+        # write training metadata with image count
+        dataset_image_count = get_dataset_image_count(self.model_name)
+        write_training_metadata(self.model_name, dataset_image_count)
+

 def kickoff_model_training(
     embeddingRequestor: EmbeddingsRequestor, model_name: str
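Taken together, these helpers persist a small JSON sidecar (the hidden .training_metadata.json in the model's clips directory) after each training run, and the dataset endpoint compares it against the current dataset size. A minimal sketch of that flow, assuming a model named "my_model" (the model name and counts are placeholders; the logic simply mirrors the endpoint change above):

    # Sketch only: mirrors get_classification_dataset's use of the new helpers.
    from frigate.util.classification import (
        get_dataset_image_count,
        read_training_metadata,
    )

    metadata = read_training_metadata("my_model")
    # e.g. {"last_training_date": "2025-01-01T12:00:00", "last_training_image_count": 120}

    current_image_count = get_dataset_image_count("my_model")  # e.g. 135
    last_count = (metadata or {}).get("last_training_image_count", 0)

    # Images classified since the last training run; this drives the count
    # shown on the train button and the "no new images" tooltip in the UI.
    new_images_count = max(0, current_image_count - last_count)  # 15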
@@ -13,6 +13,11 @@
     "deleteModels": "Delete Models",
     "editModel": "Edit Model"
   },
+  "tooltip": {
+    "trainingInProgress": "Model is currently training",
+    "noNewImages": "No new images to train. Classify more images in the dataset first.",
+    "modelNotReady": "Model is not ready for training"
+  },
   "toast": {
     "success": {
       "deletedCategory": "Deleted Class",
@@ -102,6 +102,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
         position: "top-center",
       });
       setWasTraining(false);
+      refreshDataset();
     }
     // only refresh when modelState changes
     // eslint-disable-next-line react-hooks/exhaustive-deps
@@ -112,10 +113,20 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
   const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
     `classification/${model.name}/train`,
   );
-  const { data: dataset, mutate: refreshDataset } = useSWR<{
-    [id: string]: string[];
+  const { data: datasetResponse, mutate: refreshDataset } = useSWR<{
+    categories: { [id: string]: string[] };
+    training_metadata: {
+      has_trained: boolean;
+      last_training_date: string | null;
+      last_training_image_count: number;
+      current_image_count: number;
+      new_images_count: number;
+    } | null;
   }>(`classification/${model.name}/dataset`);

+  const dataset = datasetResponse?.categories || {};
+  const trainingMetadata = datasetResponse?.training_metadata;
+
   const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();

   const refreshAll = useCallback(() => {
@@ -421,19 +432,48 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
             filterValues={{ classes: Object.keys(dataset || {}) }}
             onUpdateFilter={setTrainFilter}
           />
+          <Tooltip>
+            <TooltipTrigger asChild>
               <Button
                 className="flex justify-center gap-2"
                 onClick={trainModel}
                 variant="select"
-                disabled={modelState != "complete"}
+                disabled={
+                  modelState != "complete" ||
+                  (trainingMetadata?.new_images_count ?? 0) === 0
+                }
               >
                 {modelState == "training" ? (
                   <ActivityIndicator size={20} />
                 ) : (
                   <HiSparkles className="text-white" />
                 )}
-                {isDesktop && t("button.trainModel")}
+                {isDesktop && (
+                  <>
+                    {t("button.trainModel")}
+                    {trainingMetadata?.new_images_count !== undefined &&
+                      trainingMetadata.new_images_count > 0 && (
+                        <span className="text-sm text-selected-foreground">
+                          ({trainingMetadata.new_images_count})
+                        </span>
+                      )}
+                  </>
+                )}
               </Button>
+            </TooltipTrigger>
+            {((trainingMetadata?.new_images_count ?? 0) === 0 ||
+              modelState != "complete") && (
+              <TooltipPortal>
+                <TooltipContent>
+                  {modelState == "training"
+                    ? t("tooltip.trainingInProgress")
+                    : trainingMetadata?.new_images_count === 0
+                      ? t("tooltip.noNewImages")
+                      : t("tooltip.modelNotReady")}
+                </TooltipContent>
+              </TooltipPortal>
+            )}
+          </Tooltip>
         </div>
       )}
     </div>