mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-12-06 13:34:13 +03:00
Improve train state by showing number of images to classify and adding tooltip
This commit is contained in:
parent
a374a60756
commit
b62de79c39
@ -37,6 +37,8 @@ from frigate.models import Event
|
|||||||
from frigate.util.classification import (
|
from frigate.util.classification import (
|
||||||
collect_object_classification_examples,
|
collect_object_classification_examples,
|
||||||
collect_state_classification_examples,
|
collect_state_classification_examples,
|
||||||
|
get_dataset_image_count,
|
||||||
|
read_training_metadata,
|
||||||
)
|
)
|
||||||
from frigate.util.file import get_event_snapshot
|
from frigate.util.file import get_event_snapshot
|
||||||
|
|
||||||
@ -564,23 +566,54 @@ def get_classification_dataset(name: str):
|
|||||||
dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")
|
dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")
|
||||||
|
|
||||||
if not os.path.exists(dataset_dir):
|
if not os.path.exists(dataset_dir):
|
||||||
return JSONResponse(status_code=200, content={})
|
return JSONResponse(
|
||||||
|
status_code=200, content={"categories": {}, "training_metadata": None}
|
||||||
|
)
|
||||||
|
|
||||||
for name in os.listdir(dataset_dir):
|
for category_name in os.listdir(dataset_dir):
|
||||||
category_dir = os.path.join(dataset_dir, name)
|
category_dir = os.path.join(dataset_dir, category_name)
|
||||||
|
|
||||||
if not os.path.isdir(category_dir):
|
if not os.path.isdir(category_dir):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
dataset_dict[name] = []
|
dataset_dict[category_name] = []
|
||||||
|
|
||||||
for file in filter(
|
for file in filter(
|
||||||
lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
|
lambda f: (f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))),
|
||||||
os.listdir(category_dir),
|
os.listdir(category_dir),
|
||||||
):
|
):
|
||||||
dataset_dict[name].append(file)
|
dataset_dict[category_name].append(file)
|
||||||
|
|
||||||
return JSONResponse(status_code=200, content=dataset_dict)
|
# Get training metadata
|
||||||
|
metadata = read_training_metadata(sanitize_filename(name))
|
||||||
|
current_image_count = get_dataset_image_count(sanitize_filename(name))
|
||||||
|
|
||||||
|
if metadata is None:
|
||||||
|
training_metadata = {
|
||||||
|
"has_trained": False,
|
||||||
|
"last_training_date": None,
|
||||||
|
"last_training_image_count": 0,
|
||||||
|
"current_image_count": current_image_count,
|
||||||
|
"new_images_count": current_image_count,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
last_training_count = metadata.get("last_training_image_count", 0)
|
||||||
|
new_images_count = max(0, current_image_count - last_training_count)
|
||||||
|
training_metadata = {
|
||||||
|
"has_trained": True,
|
||||||
|
"last_training_date": metadata.get("last_training_date"),
|
||||||
|
"last_training_image_count": last_training_count,
|
||||||
|
"current_image_count": current_image_count,
|
||||||
|
"new_images_count": new_images_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content={
|
||||||
|
"categories": dataset_dict,
|
||||||
|
"training_metadata": training_metadata,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
"""Util for classification models."""
|
"""Util for classification models."""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@ -27,10 +29,96 @@ from frigate.util.process import FrigateProcess
|
|||||||
BATCH_SIZE = 16
|
BATCH_SIZE = 16
|
||||||
EPOCHS = 50
|
EPOCHS = 50
|
||||||
LEARNING_RATE = 0.001
|
LEARNING_RATE = 0.001
|
||||||
|
TRAINING_METADATA_FILE = ".training_metadata.json"
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def write_training_metadata(model_name: str, image_count: int) -> None:
|
||||||
|
"""
|
||||||
|
Write training metadata to a hidden file in the model's clips directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the classification model
|
||||||
|
image_count: Number of images used in training
|
||||||
|
"""
|
||||||
|
clips_model_dir = os.path.join(CLIPS_DIR, model_name)
|
||||||
|
os.makedirs(clips_model_dir, exist_ok=True)
|
||||||
|
|
||||||
|
metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
|
||||||
|
metadata = {
|
||||||
|
"last_training_date": datetime.datetime.now().isoformat(),
|
||||||
|
"last_training_image_count": image_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(metadata_path, "w") as f:
|
||||||
|
json.dump(metadata, f, indent=2)
|
||||||
|
logger.info(f"Wrote training metadata for {model_name}: {image_count} images")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to write training metadata for {model_name}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def read_training_metadata(model_name: str) -> dict[str, any] | None:
|
||||||
|
"""
|
||||||
|
Read training metadata from the hidden file in the model's clips directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the classification model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with last_training_date and last_training_image_count, or None if not found
|
||||||
|
"""
|
||||||
|
clips_model_dir = os.path.join(CLIPS_DIR, model_name)
|
||||||
|
metadata_path = os.path.join(clips_model_dir, TRAINING_METADATA_FILE)
|
||||||
|
|
||||||
|
if not os.path.exists(metadata_path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(metadata_path, "r") as f:
|
||||||
|
metadata = json.load(f)
|
||||||
|
return metadata
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to read training metadata for {model_name}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset_image_count(model_name: str) -> int:
|
||||||
|
"""
|
||||||
|
Count the total number of images in the model's dataset directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the classification model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total count of images across all categories
|
||||||
|
"""
|
||||||
|
dataset_dir = os.path.join(CLIPS_DIR, model_name, "dataset")
|
||||||
|
|
||||||
|
if not os.path.exists(dataset_dir):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
total_count = 0
|
||||||
|
try:
|
||||||
|
for category in os.listdir(dataset_dir):
|
||||||
|
category_dir = os.path.join(dataset_dir, category)
|
||||||
|
if not os.path.isdir(category_dir):
|
||||||
|
continue
|
||||||
|
|
||||||
|
image_files = [
|
||||||
|
f
|
||||||
|
for f in os.listdir(category_dir)
|
||||||
|
if f.lower().endswith((".webp", ".png", ".jpg", ".jpeg"))
|
||||||
|
]
|
||||||
|
total_count += len(image_files)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to count dataset images for {model_name}: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return total_count
|
||||||
|
|
||||||
|
|
||||||
class ClassificationTrainingProcess(FrigateProcess):
|
class ClassificationTrainingProcess(FrigateProcess):
|
||||||
def __init__(self, model_name: str) -> None:
|
def __init__(self, model_name: str) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -145,6 +233,10 @@ class ClassificationTrainingProcess(FrigateProcess):
|
|||||||
with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
|
with open(os.path.join(model_dir, "model.tflite"), "wb") as f:
|
||||||
f.write(tflite_model)
|
f.write(tflite_model)
|
||||||
|
|
||||||
|
# write training metadata with image count
|
||||||
|
dataset_image_count = get_dataset_image_count(self.model_name)
|
||||||
|
write_training_metadata(self.model_name, dataset_image_count)
|
||||||
|
|
||||||
|
|
||||||
def kickoff_model_training(
|
def kickoff_model_training(
|
||||||
embeddingRequestor: EmbeddingsRequestor, model_name: str
|
embeddingRequestor: EmbeddingsRequestor, model_name: str
|
||||||
|
|||||||
@ -13,6 +13,11 @@
|
|||||||
"deleteModels": "Delete Models",
|
"deleteModels": "Delete Models",
|
||||||
"editModel": "Edit Model"
|
"editModel": "Edit Model"
|
||||||
},
|
},
|
||||||
|
"tooltip": {
|
||||||
|
"trainingInProgress": "Model is currently training",
|
||||||
|
"noNewImages": "No new images to train. Classify more images in the dataset first.",
|
||||||
|
"modelNotReady": "Model is not ready for training"
|
||||||
|
},
|
||||||
"toast": {
|
"toast": {
|
||||||
"success": {
|
"success": {
|
||||||
"deletedCategory": "Deleted Class",
|
"deletedCategory": "Deleted Class",
|
||||||
|
|||||||
@ -102,6 +102,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
position: "top-center",
|
position: "top-center",
|
||||||
});
|
});
|
||||||
setWasTraining(false);
|
setWasTraining(false);
|
||||||
|
refreshDataset();
|
||||||
}
|
}
|
||||||
// only refresh when modelState changes
|
// only refresh when modelState changes
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
@ -112,10 +113,20 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
|
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
|
||||||
`classification/${model.name}/train`,
|
`classification/${model.name}/train`,
|
||||||
);
|
);
|
||||||
const { data: dataset, mutate: refreshDataset } = useSWR<{
|
const { data: datasetResponse, mutate: refreshDataset } = useSWR<{
|
||||||
[id: string]: string[];
|
categories: { [id: string]: string[] };
|
||||||
|
training_metadata: {
|
||||||
|
has_trained: boolean;
|
||||||
|
last_training_date: string | null;
|
||||||
|
last_training_image_count: number;
|
||||||
|
current_image_count: number;
|
||||||
|
new_images_count: number;
|
||||||
|
} | null;
|
||||||
}>(`classification/${model.name}/dataset`);
|
}>(`classification/${model.name}/dataset`);
|
||||||
|
|
||||||
|
const dataset = datasetResponse?.categories || {};
|
||||||
|
const trainingMetadata = datasetResponse?.training_metadata;
|
||||||
|
|
||||||
const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();
|
const [trainFilter, setTrainFilter] = useApiFilter<TrainFilter>();
|
||||||
|
|
||||||
const refreshAll = useCallback(() => {
|
const refreshAll = useCallback(() => {
|
||||||
@ -421,19 +432,48 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
filterValues={{ classes: Object.keys(dataset || {}) }}
|
filterValues={{ classes: Object.keys(dataset || {}) }}
|
||||||
onUpdateFilter={setTrainFilter}
|
onUpdateFilter={setTrainFilter}
|
||||||
/>
|
/>
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
<Button
|
<Button
|
||||||
className="flex justify-center gap-2"
|
className="flex justify-center gap-2"
|
||||||
onClick={trainModel}
|
onClick={trainModel}
|
||||||
variant="select"
|
variant="select"
|
||||||
disabled={modelState != "complete"}
|
disabled={
|
||||||
|
modelState != "complete" ||
|
||||||
|
(trainingMetadata?.new_images_count ?? 0) === 0
|
||||||
|
}
|
||||||
>
|
>
|
||||||
{modelState == "training" ? (
|
{modelState == "training" ? (
|
||||||
<ActivityIndicator size={20} />
|
<ActivityIndicator size={20} />
|
||||||
) : (
|
) : (
|
||||||
<HiSparkles className="text-white" />
|
<HiSparkles className="text-white" />
|
||||||
)}
|
)}
|
||||||
{isDesktop && t("button.trainModel")}
|
{isDesktop && (
|
||||||
|
<>
|
||||||
|
{t("button.trainModel")}
|
||||||
|
{trainingMetadata?.new_images_count !== undefined &&
|
||||||
|
trainingMetadata.new_images_count > 0 && (
|
||||||
|
<span className="text-sm text-selected-foreground">
|
||||||
|
({trainingMetadata.new_images_count})
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</Button>
|
</Button>
|
||||||
|
</TooltipTrigger>
|
||||||
|
{((trainingMetadata?.new_images_count ?? 0) === 0 ||
|
||||||
|
modelState != "complete") && (
|
||||||
|
<TooltipPortal>
|
||||||
|
<TooltipContent>
|
||||||
|
{modelState == "training"
|
||||||
|
? t("tooltip.trainingInProgress")
|
||||||
|
: trainingMetadata?.new_images_count === 0
|
||||||
|
? t("tooltip.noNewImages")
|
||||||
|
: t("tooltip.modelNotReady")}
|
||||||
|
</TooltipContent>
|
||||||
|
</TooltipPortal>
|
||||||
|
)}
|
||||||
|
</Tooltip>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user