diff --git a/frigate/api/classification.py b/frigate/api/classification.py
index 5e5320813..ec53a47af 100644
--- a/frigate/api/classification.py
+++ b/frigate/api/classification.py
@@ -435,7 +435,7 @@ def transcribe_audio(request: Request, body: AudioTranscriptionBody):
def get_classification_dataset(name: str):
dataset_dict: dict[str, list[str]] = {}
- dataset_dir = os.path.join(MODEL_CACHE_DIR, f"{sanitize_filename(name)}/dataset")
+ dataset_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "dataset")
if not os.path.exists(dataset_dir):
return JSONResponse(status_code=200, content={})
@@ -459,7 +459,7 @@ def get_classification_dataset(name: str):
@router.get("/classification/{name}/train")
def get_classification_images(name: str):
- train_dir = os.path.join(CLIPS_DIR, sanitize_filename(name))
+ train_dir = os.path.join(CLIPS_DIR, sanitize_filename(name), "train")
if not os.path.exists(train_dir):
return JSONResponse(status_code=200, content=[])
@@ -492,9 +492,7 @@ async def train_configured_model(
status_code=404,
)
- background_tasks.add_task(
- train_classification_model, os.path.join(MODEL_CACHE_DIR, name)
- )
+ background_tasks.add_task(train_classification_model, name)
return JSONResponse(
content={"success": True, "message": "Started classification model training."},
status_code=200,
diff --git a/web/src/views/classification/ModelSelectionView.tsx b/web/src/views/classification/ModelSelectionView.tsx
index 502a53251..fdc83dc82 100644
--- a/web/src/views/classification/ModelSelectionView.tsx
+++ b/web/src/views/classification/ModelSelectionView.tsx
@@ -38,12 +38,10 @@ export default function ModelSelectionView({
{classificationConfigs.map((config) => (
onClick(config)}
onContextMenu={() => {
diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx
index c7bd0e18e..1a76b5c5e 100644
--- a/web/src/views/classification/ModelTrainingView.tsx
+++ b/web/src/views/classification/ModelTrainingView.tsx
@@ -40,8 +40,12 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
// dataset
- const { data: trainImages } = useSWR(`classification/${model.name}/train`);
- const { data: dataset } = useSWR(`classification/${model.name}/dataset`);
+ const { data: trainImages } = useSWR
(
+ `classification/${model.name}/train`,
+ );
+ const { data: dataset } = useSWR<{ [id: string]: string[] }>(
+ `classification/${model.name}/dataset`,
+ );
// actions
@@ -54,8 +58,8 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
{}}
onRename={() => {}}
@@ -65,13 +69,18 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
{pageToggle == "train" ? (
{}}
onDelete={() => {}}
/>
) : (
-
+ {}}
+ />
)}
);
@@ -259,51 +268,65 @@ function LibrarySelector({
}
type DatasetGridProps = {
- name: string;
+ modelName: string;
+ categoryName: string;
images: string[];
- onDelete: (name: string, ids: string[]) => void;
+ onDelete: (modelName: string, categoryName: string, ids: string[]) => void;
};
-function DatasetGrid({ name, images, onDelete }: DatasetGridProps) {
+function DatasetGrid({
+ modelName,
+ categoryName,
+ images,
+ onDelete,
+}: DatasetGridProps) {
const { t } = useTranslation(["views/classificationModel"]);
return (
-
-
-

-
-
-
-
-
{name}
+
+ {images.map((image) => (
+
{
+ //e.stopPropagation();
+ //onClickImages([data.raw], e.ctrlKey || e.metaKey);
+ }}
+ >
+
+
-
-
-
- {
- e.stopPropagation();
- onDelete(name, images);
- }}
- />
-
- {t("button.deleteFaceAttempts")}
-
+
+
+
+
+
+ {
+ e.stopPropagation();
+ onDelete(modelName, categoryName, [image]);
+ }}
+ />
+
+
+ {t("button.deleteClassificationAttempts")}
+
+
+
+
-
+ ))}
);
}
@@ -339,9 +362,10 @@ function TrainGrid({
);
return (
-
- {trainData.map((data) => (
+
+ {trainData?.map((data) => (