diff --git a/frigate/api/classification.py b/frigate/api/classification.py index 5e5320813..ec53a47af 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -435,7 +435,7 @@ def transcribe_audio(request: Request, body: AudioTranscriptionBody): def get_classification_dataset(name: str): dataset_dict: dict[str, list[str]] = {} - dataset_dir = os.path.join(MODEL_CACHE_DIR, f"{sanitize_filename(name)}/dataset") + dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset") if not os.path.exists(dataset_dir): return JSONResponse(status_code=200, content={}) @@ -459,7 +459,7 @@ def get_classification_dataset(name: str): @router.get("/classification/{name}/train") def get_classification_images(name: str): - train_dir = os.path.join(CLIPS_DIR, sanitize_filename(name)) + train_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "train") if not os.path.exists(train_dir): return JSONResponse(status_code=200, content=[]) @@ -492,9 +492,7 @@ async def train_configured_model( status_code=404, ) - background_tasks.add_task( - train_classification_model, os.path.join(MODEL_CACHE_DIR, name) - ) + background_tasks.add_task(train_classification_model, name) return JSONResponse( content={"success": True, "message": "Started classification model training."}, status_code=200, diff --git a/web/src/views/classification/ModelSelectionView.tsx b/web/src/views/classification/ModelSelectionView.tsx index 502a53251..fdc83dc82 100644 --- a/web/src/views/classification/ModelSelectionView.tsx +++ b/web/src/views/classification/ModelSelectionView.tsx @@ -38,12 +38,10 @@ export default function ModelSelectionView({
{classificationConfigs.map((config) => (
onClick(config)} onContextMenu={() => { diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx index c7bd0e18e..1a76b5c5e 100644 --- a/web/src/views/classification/ModelTrainingView.tsx +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -40,8 +40,12 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { // dataset - const { data: trainImages } = useSWR(`classification/${model.name}/train`); - const { data: dataset } = useSWR(`classification/${model.name}/dataset`); + const { data: trainImages } = useSWR( + `classification/${model.name}/train`, + ); + const { data: dataset } = useSWR<{ [id: string]: string[] }>( + `classification/${model.name}/dataset`, + ); // actions @@ -54,8 +58,8 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
{}} onRename={() => {}} @@ -65,13 +69,18 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { {pageToggle == "train" ? ( {}} onDelete={() => {}} /> ) : ( - + {}} + /> )}
); @@ -259,51 +268,65 @@ function LibrarySelector({ } type DatasetGridProps = { - name: string; + modelName: string; + categoryName: string; images: string[]; - onDelete: (name: string, ids: string[]) => void; + onDelete: (modelName: string, categoryName: string, ids: string[]) => void; }; -function DatasetGrid({ name, images, onDelete }: DatasetGridProps) { +function DatasetGrid({ + modelName, + categoryName, + images, + onDelete, +}: DatasetGridProps) { const { t } = useTranslation(["views/classificationModel"]); return ( -
-
- -
-
-
-
-
{name}
+
+ {images.map((image) => ( +
{ + //e.stopPropagation(); + //onClickImages([data.raw], e.ctrlKey || e.metaKey); + }} + > +
+
-
- - - { - e.stopPropagation(); - onDelete(name, images); - }} - /> - - {t("button.deleteFaceAttempts")} - +
+
+
+ + + { + e.stopPropagation(); + onDelete(modelName, categoryName, [image]); + }} + /> + + + {t("button.deleteClassificationAttempts")} + + +
+
-
+ ))}
); } @@ -339,9 +362,10 @@ function TrainGrid({ ); return ( -
- {trainData.map((data) => ( +
+ {trainData?.map((data) => (