diff --git a/frigate/api/classification.py b/frigate/api/classification.py index 87de52884..a2aec6898 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -595,9 +595,13 @@ def get_classification_dataset(name: str): "last_training_image_count": 0, "current_image_count": current_image_count, "new_images_count": current_image_count, + "dataset_changed": current_image_count > 0, } else: last_training_count = metadata.get("last_training_image_count", 0) + # Dataset has changed if count is different (either added or deleted images) + dataset_changed = current_image_count != last_training_count + # Only show positive count for new images (ignore deletions in the count display) new_images_count = max(0, current_image_count - last_training_count) training_metadata = { "has_trained": True, @@ -605,6 +609,7 @@ def get_classification_dataset(name: str): "last_training_image_count": last_training_count, "current_image_count": current_image_count, "new_images_count": new_images_count, + "dataset_changed": dataset_changed, } return JSONResponse( @@ -948,31 +953,29 @@ async def generate_object_examples(request: Request, body: GenerateObjectExample dependencies=[Depends(require_role(["admin"]))], summary="Delete a classification model", description="""Deletes a specific classification model and all its associated data. - The name must exist in the classification models. Returns a success message or an error if the name is invalid.""", + Works even if the model is not in the config (e.g., partially created during wizard). + Returns a success message.""", ) def delete_classification_model(request: Request, name: str): - config: FrigateConfig = request.app.frigate_config - - if name not in config.classification.custom: - return JSONResponse( - content=( - { - "success": False, - "message": f"{name} is not a known classification model.", - } - ), - status_code=404, - ) + sanitized_name = sanitize_filename(name) # Delete the classification model's data directory in clips - data_dir = os.path.join(CLIPS_DIR, sanitize_filename(name)) + data_dir = os.path.join(CLIPS_DIR, sanitized_name) if os.path.exists(data_dir): - shutil.rmtree(data_dir) + try: + shutil.rmtree(data_dir) + logger.info(f"Deleted classification data directory for {name}") + except Exception as e: + logger.debug(f"Failed to delete data directory for {name}: {e}") # Delete the classification model's files in model_cache - model_dir = os.path.join(MODEL_CACHE_DIR, sanitize_filename(name)) + model_dir = os.path.join(MODEL_CACHE_DIR, sanitized_name) if os.path.exists(model_dir): - shutil.rmtree(model_dir) + try: + shutil.rmtree(model_dir) + logger.info(f"Deleted classification model directory for {name}") + except Exception as e: + logger.debug(f"Failed to delete model directory for {name}: {e}") return JSONResponse( content=( diff --git a/frigate/config/camera/camera.py b/frigate/config/camera/camera.py index 967a69427..0f2b1c8be 100644 --- a/frigate/config/camera/camera.py +++ b/frigate/config/camera/camera.py @@ -177,6 +177,12 @@ class CameraConfig(FrigateBaseModel): def ffmpeg_cmds(self) -> list[dict[str, list[str]]]: return self._ffmpeg_cmds + def get_formatted_name(self) -> str: + """Return the friendly name if set, otherwise return a formatted version of the camera name.""" + if self.friendly_name: + return self.friendly_name + return self.name.replace("_", " ").title() if self.name else "" + def create_ffmpeg_cmds(self): if "_ffmpeg_cmds" in self: return diff --git a/frigate/config/camera/zone.py b/frigate/config/camera/zone.py index 530ba1cf9..7df1a1f25 100644 --- a/frigate/config/camera/zone.py +++ b/frigate/config/camera/zone.py @@ -56,6 +56,12 @@ class ZoneConfig(BaseModel): def contour(self) -> np.ndarray: return self._contour + def get_formatted_name(self, zone_name: str) -> str: + """Return the friendly name if set, otherwise return a formatted version of the zone name.""" + if self.friendly_name: + return self.friendly_name + return zone_name.replace("_", " ").title() + @field_validator("objects", mode="before") @classmethod def validate_objects(cls, v): diff --git a/frigate/data_processing/common/audio_transcription/model.py b/frigate/data_processing/common/audio_transcription/model.py index 0fe5ddb5c..82472ad62 100644 --- a/frigate/data_processing/common/audio_transcription/model.py +++ b/frigate/data_processing/common/audio_transcription/model.py @@ -4,7 +4,6 @@ import logging import os import sherpa_onnx -from faster_whisper.utils import download_model from frigate.comms.inter_process import InterProcessRequestor from frigate.const import MODEL_CACHE_DIR @@ -25,6 +24,9 @@ class AudioTranscriptionModelRunner: if model_size == "large": # use the Whisper download function instead of our own + # Import dynamically to avoid crashes on systems without AVX support + from faster_whisper.utils import download_model + logger.debug("Downloading Whisper audio transcription model") download_model( size_or_id="small" if device == "cuda" else "tiny", diff --git a/frigate/data_processing/post/audio_transcription.py b/frigate/data_processing/post/audio_transcription.py index 066287707..870c34068 100644 --- a/frigate/data_processing/post/audio_transcription.py +++ b/frigate/data_processing/post/audio_transcription.py @@ -6,7 +6,6 @@ import threading import time from typing import Optional -from faster_whisper import WhisperModel from peewee import DoesNotExist from frigate.comms.inter_process import InterProcessRequestor @@ -51,6 +50,9 @@ class AudioTranscriptionPostProcessor(PostProcessorApi): def __build_recognizer(self) -> None: try: + # Import dynamically to avoid crashes on systems without AVX support + from faster_whisper import WhisperModel + self.recognizer = WhisperModel( model_size_or_path="small", device="cuda" diff --git a/frigate/data_processing/post/review_descriptions.py b/frigate/data_processing/post/review_descriptions.py index ffb7b7a25..9691ac8fd 100644 --- a/frigate/data_processing/post/review_descriptions.py +++ b/frigate/data_processing/post/review_descriptions.py @@ -16,6 +16,7 @@ from peewee import DoesNotExist from frigate.comms.embeddings_updater import EmbeddingsRequestEnum from frigate.comms.inter_process import InterProcessRequestor from frigate.config import FrigateConfig +from frigate.config.camera import CameraConfig from frigate.config.camera.review import GenAIReviewConfig, ImageSourceEnum from frigate.const import CACHE_DIR, CLIPS_DIR, UPDATE_REVIEW_DESCRIPTION from frigate.data_processing.types import PostProcessDataEnum @@ -30,6 +31,7 @@ from ..types import DataProcessorMetrics logger = logging.getLogger(__name__) RECORDING_BUFFER_EXTENSION_PERCENT = 0.10 +MIN_RECORDING_DURATION = 10 class ReviewDescriptionProcessor(PostProcessorApi): @@ -130,7 +132,17 @@ class ReviewDescriptionProcessor(PostProcessorApi): if image_source == ImageSourceEnum.recordings: duration = final_data["end_time"] - final_data["start_time"] - buffer_extension = duration * RECORDING_BUFFER_EXTENSION_PERCENT + buffer_extension = min( + 10, max(2, duration * RECORDING_BUFFER_EXTENSION_PERCENT) + ) + + # Ensure minimum total duration for short review items + # This provides better context for brief events + total_duration = duration + (2 * buffer_extension) + if total_duration < MIN_RECORDING_DURATION: + # Expand buffer to reach minimum duration, still respecting max of 10s per side + additional_buffer_per_side = (MIN_RECORDING_DURATION - duration) / 2 + buffer_extension = min(10, additional_buffer_per_side) thumbs = self.get_recording_frames( camera, @@ -182,7 +194,7 @@ class ReviewDescriptionProcessor(PostProcessorApi): self.requestor, self.genai_client, self.review_desc_speed, - camera, + camera_config, final_data, thumbs, camera_config.review.genai, @@ -411,7 +423,7 @@ def run_analysis( requestor: InterProcessRequestor, genai_client: GenAIClient, review_inference_speed: InferenceSpeed, - camera: str, + camera_config: CameraConfig, final_data: dict[str, str], thumbs: list[bytes], genai_config: GenAIReviewConfig, @@ -419,10 +431,19 @@ def run_analysis( attribute_labels: list[str], ) -> None: start = datetime.datetime.now().timestamp() + + # Format zone names using zone config friendly names if available + formatted_zones = [] + for zone_name in final_data["data"]["zones"]: + if zone_name in camera_config.zones: + formatted_zones.append( + camera_config.zones[zone_name].get_formatted_name(zone_name) + ) + analytics_data = { "id": final_data["id"], - "camera": camera, - "zones": final_data["data"]["zones"], + "camera": camera_config.get_formatted_name(), + "zones": formatted_zones, "start": datetime.datetime.fromtimestamp(final_data["start_time"]).strftime( "%A, %I:%M %p" ), diff --git a/frigate/detectors/detection_runners.py b/frigate/detectors/detection_runners.py index 6eb3a32fc..80d4e0487 100644 --- a/frigate/detectors/detection_runners.py +++ b/frigate/detectors/detection_runners.py @@ -394,7 +394,11 @@ class OpenVINOModelRunner(BaseModelRunner): self.infer_request.set_input_tensor(input_index, input_tensor) # Run inference - self.infer_request.infer() + try: + self.infer_request.infer() + except Exception as e: + logger.error(f"Error during OpenVINO inference: {e}") + return [] # Get all output tensors outputs = [] diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py index 5689511a8..01d011ae2 100644 --- a/frigate/embeddings/embeddings.py +++ b/frigate/embeddings/embeddings.py @@ -472,7 +472,7 @@ class Embeddings: ) thumbnail_missing = True except DoesNotExist: - logger.warning( + logger.debug( f"Event ID {trigger.data} for trigger {trigger_name} does not exist." ) continue diff --git a/frigate/genai/__init__.py b/frigate/genai/__init__.py index 881e63b97..dd42fc6dd 100644 --- a/frigate/genai/__init__.py +++ b/frigate/genai/__init__.py @@ -51,8 +51,7 @@ class GenAIClient: def get_concern_prompt() -> str: if concerns: concern_list = "\n - ".join(concerns) - return f""" -- `other_concerns` (list of strings): Include a list of any of the following concerns that are occurring: + return f"""- `other_concerns` (list of strings): Include a list of any of the following concerns that are occurring: - {concern_list}""" else: return "" @@ -70,7 +69,7 @@ class GenAIClient: return "\n- (No objects detected)" context_prompt = f""" -Your task is to analyze the sequence of images ({len(thumbnails)} total) taken in chronological order from the perspective of the {review_data["camera"].replace("_", " ")} security camera. +Your task is to analyze the sequence of images ({len(thumbnails)} total) taken in chronological order from the perspective of the {review_data["camera"]} security camera. ## Normal Activity Patterns for This Property @@ -110,7 +109,7 @@ Your response MUST be a flat JSON object with: - Frame 1 = earliest, Frame {len(thumbnails)} = latest - Activity started at {review_data["start"]} and lasted {review_data["duration"]} seconds -- Zones involved: {", ".join(z.replace("_", " ").title() for z in review_data["zones"]) or "None"} +- Zones involved: {", ".join(review_data["zones"]) if review_data["zones"] else "None"} ## Objects in Scene diff --git a/web/public/locales/en/views/classificationModel.json b/web/public/locales/en/views/classificationModel.json index 2bae0c0ce..f8aef1b8f 100644 --- a/web/public/locales/en/views/classificationModel.json +++ b/web/public/locales/en/views/classificationModel.json @@ -16,6 +16,7 @@ "tooltip": { "trainingInProgress": "Model is currently training", "noNewImages": "No new images to train. Classify more images in the dataset first.", + "noChanges": "No changes to the dataset since last training.", "modelNotReady": "Model is not ready for training" }, "toast": { @@ -43,7 +44,9 @@ }, "deleteCategory": { "title": "Delete Class", - "desc": "Are you sure you want to delete the class {{name}}? This will permanently delete all associated images and require re-training the model." + "desc": "Are you sure you want to delete the class {{name}}? This will permanently delete all associated images and require re-training the model.", + "minClassesTitle": "Cannot Delete Class", + "minClassesDesc": "A classification model must have at least 2 classes. Add another class before deleting this one." }, "deleteModel": { "title": "Delete Classification Model", diff --git a/web/src/components/classification/ClassificationModelEditDialog.tsx b/web/src/components/classification/ClassificationModelEditDialog.tsx index ff80a1a29..c47765d76 100644 --- a/web/src/components/classification/ClassificationModelEditDialog.tsx +++ b/web/src/components/classification/ClassificationModelEditDialog.tsx @@ -28,6 +28,7 @@ import { CustomClassificationModelConfig, FrigateConfig, } from "@/types/frigateConfig"; +import { ClassificationDatasetResponse } from "@/types/classification"; import { getTranslatedLabel } from "@/utils/i18n"; import { zodResolver } from "@hookform/resolvers/zod"; import axios from "axios"; @@ -140,16 +141,19 @@ export default function ClassificationModelEditDialog({ }); // Fetch dataset to get current classes for state models - const { data: dataset } = useSWR<{ - [id: string]: string[]; - }>(isStateModel ? `classification/${model.name}/dataset` : null, { - revalidateOnFocus: false, - }); + const { data: dataset } = useSWR( + isStateModel ? `classification/${model.name}/dataset` : null, + { + revalidateOnFocus: false, + }, + ); // Update form with classes from dataset when loaded useEffect(() => { - if (isStateModel && dataset) { - const classes = Object.keys(dataset).filter((key) => key !== "none"); + if (isStateModel && dataset?.categories) { + const classes = Object.keys(dataset.categories).filter( + (key) => key !== "none", + ); if (classes.length > 0) { (form as ReturnType>).setValue( "classes", diff --git a/web/src/components/classification/ClassificationModelWizardDialog.tsx b/web/src/components/classification/ClassificationModelWizardDialog.tsx index e67a95f89..06bf1f850 100644 --- a/web/src/components/classification/ClassificationModelWizardDialog.tsx +++ b/web/src/components/classification/ClassificationModelWizardDialog.tsx @@ -15,6 +15,7 @@ import Step3ChooseExamples, { } from "./wizard/Step3ChooseExamples"; import { cn } from "@/lib/utils"; import { isDesktop } from "react-device-detect"; +import axios from "axios"; const OBJECT_STEPS = [ "wizard.steps.nameAndDefine", @@ -120,7 +121,18 @@ export default function ClassificationModelWizardDialog({ dispatch({ type: "PREVIOUS_STEP" }); }; - const handleCancel = () => { + const handleCancel = async () => { + // Clean up any generated training images if we're cancelling from Step 3 + if (wizardState.step1Data && wizardState.step3Data?.examplesGenerated) { + try { + await axios.delete( + `/classification/${wizardState.step1Data.modelName}`, + ); + } catch (error) { + // Silently fail - user is already cancelling + } + } + dispatch({ type: "RESET" }); onClose(); }; diff --git a/web/src/components/classification/wizard/Step3ChooseExamples.tsx b/web/src/components/classification/wizard/Step3ChooseExamples.tsx index f638c01e3..e4c157526 100644 --- a/web/src/components/classification/wizard/Step3ChooseExamples.tsx +++ b/web/src/components/classification/wizard/Step3ChooseExamples.tsx @@ -165,18 +165,15 @@ export default function Step3ChooseExamples({ const isLastClass = currentClassIndex === allClasses.length - 1; if (isLastClass) { - // Assign remaining unclassified images - unknownImages.slice(0, 24).forEach((imageName) => { - if (!newClassifications[imageName]) { - // For state models with 2 classes, assign to the last class - // For object models, assign to "none" - if (step1Data.modelType === "state" && allClasses.length === 2) { - newClassifications[imageName] = allClasses[allClasses.length - 1]; - } else { + // For object models, assign remaining unclassified images to "none" + // For state models, this should never happen since we require all images to be classified + if (step1Data.modelType !== "state") { + unknownImages.slice(0, 24).forEach((imageName) => { + if (!newClassifications[imageName]) { newClassifications[imageName] = "none"; } - } - }); + }); + } // All done, trigger training immediately setImageClassifications(newClassifications); @@ -316,8 +313,15 @@ export default function Step3ChooseExamples({ return images; } - return images.filter((img) => !imageClassifications[img]); - }, [unknownImages, imageClassifications]); + // If we're viewing a previous class (going back), show images for that class + // Otherwise show only unclassified images + const currentClassInView = allClasses[currentClassIndex]; + return images.filter((img) => { + const imgClass = imageClassifications[img]; + // Show if: unclassified OR classified with current class we're viewing + return !imgClass || imgClass === currentClassInView; + }); + }, [unknownImages, imageClassifications, allClasses, currentClassIndex]); const allImagesClassified = useMemo(() => { return unclassifiedImages.length === 0; @@ -326,15 +330,26 @@ export default function Step3ChooseExamples({ // For state models on the last class, require all images to be classified const isLastClass = currentClassIndex === allClasses.length - 1; const canProceed = useMemo(() => { - if ( - step1Data.modelType === "state" && - isLastClass && - !allImagesClassified - ) { - return false; + if (step1Data.modelType === "state" && isLastClass) { + // Check if all 24 images will be classified after current selections are applied + const totalImages = unknownImages.slice(0, 24).length; + + // Count images that will be classified (either already classified or currently selected) + const allImages = unknownImages.slice(0, 24); + const willBeClassified = allImages.filter((img) => { + return imageClassifications[img] || selectedImages.has(img); + }).length; + + return willBeClassified >= totalImages; } return true; - }, [step1Data.modelType, isLastClass, allImagesClassified]); + }, [ + step1Data.modelType, + isLastClass, + unknownImages, + imageClassifications, + selectedImages, + ]); const handleBack = useCallback(() => { if (currentClassIndex > 0) { diff --git a/web/src/components/overlay/ImageShadowOverlay.tsx b/web/src/components/overlay/ImageShadowOverlay.tsx index 85791eec1..4f822572d 100644 --- a/web/src/components/overlay/ImageShadowOverlay.tsx +++ b/web/src/components/overlay/ImageShadowOverlay.tsx @@ -12,13 +12,13 @@ export function ImageShadowOverlay({ <>
diff --git a/web/src/components/player/BirdseyeLivePlayer.tsx b/web/src/components/player/BirdseyeLivePlayer.tsx index f94e9aca2..3dcd6afe7 100644 --- a/web/src/components/player/BirdseyeLivePlayer.tsx +++ b/web/src/components/player/BirdseyeLivePlayer.tsx @@ -77,7 +77,10 @@ export default function BirdseyeLivePlayer({ )} onClick={onClick} > - +
{player}
diff --git a/web/src/components/player/LivePlayer.tsx b/web/src/components/player/LivePlayer.tsx index 3e7dcde00..9500688f5 100644 --- a/web/src/components/player/LivePlayer.tsx +++ b/web/src/components/player/LivePlayer.tsx @@ -331,7 +331,10 @@ export default function LivePlayer({ > {cameraEnabled && ((showStillWithoutActivity && !liveReady) || liveReady) && ( - + )} {player} {cameraEnabled && diff --git a/web/src/context/detail-stream-context.tsx b/web/src/context/detail-stream-context.tsx index 57971f7ac..67c06f981 100644 --- a/web/src/context/detail-stream-context.tsx +++ b/web/src/context/detail-stream-context.tsx @@ -1,4 +1,10 @@ -import React, { createContext, useContext, useState, useEffect } from "react"; +import React, { + createContext, + useContext, + useState, + useEffect, + useRef, +} from "react"; import { FrigateConfig } from "@/types/frigateConfig"; import useSWR from "swr"; @@ -36,6 +42,23 @@ export function DetailStreamProvider({ () => initialSelectedObjectIds ?? [], ); + // When the parent provides a new initialSelectedObjectIds (for example + // when navigating between search results) update the selection so children + // like `ObjectTrackOverlay` receive the new ids immediately. We only + // perform this update when the incoming value actually changes. + useEffect(() => { + if ( + initialSelectedObjectIds && + (initialSelectedObjectIds.length !== selectedObjectIds.length || + initialSelectedObjectIds.some((v, i) => selectedObjectIds[i] !== v)) + ) { + setSelectedObjectIds(initialSelectedObjectIds); + } + // Intentionally include selectedObjectIds to compare previous value and + // avoid overwriting user interactions unless the incoming prop changed. + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [initialSelectedObjectIds]); + const toggleObjectSelection = (id: string | undefined) => { if (id === undefined) { setSelectedObjectIds([]); @@ -63,10 +86,33 @@ export function DetailStreamProvider({ setAnnotationOffset(cfgOffset); }, [config, camera]); - // Clear selected objects when exiting detail mode or changing cameras + // Clear selected objects when exiting detail mode or when the camera + // changes for providers that are not initialized with an explicit + // `initialSelectedObjectIds` (e.g., the RecordingView). For providers + // that receive `initialSelectedObjectIds` (like SearchDetailDialog) we + // avoid clearing on camera change to prevent a race with children that + // immediately set selection when mounting. + const prevCameraRef = useRef(undefined); useEffect(() => { - setSelectedObjectIds([]); - }, [isDetailMode, camera]); + // Always clear when leaving detail mode + if (!isDetailMode) { + setSelectedObjectIds([]); + prevCameraRef.current = camera; + return; + } + + // If camera changed and the parent did not provide initialSelectedObjectIds, + // clear selection to preserve previous behavior. + if ( + prevCameraRef.current !== undefined && + prevCameraRef.current !== camera && + initialSelectedObjectIds === undefined + ) { + setSelectedObjectIds([]); + } + + prevCameraRef.current = camera; + }, [isDetailMode, camera, initialSelectedObjectIds]); const value: DetailStreamContextType = { selectedObjectIds, diff --git a/web/src/types/classification.ts b/web/src/types/classification.ts index 092021342..10c130459 100644 --- a/web/src/types/classification.ts +++ b/web/src/types/classification.ts @@ -20,3 +20,17 @@ export type ClassificationThreshold = { recognition: number; unknown: number; }; + +export type ClassificationDatasetResponse = { + categories: { + [id: string]: string[]; + }; + training_metadata: { + has_trained: boolean; + last_training_date: string | null; + last_training_image_count: number; + current_image_count: number; + new_images_count: number; + dataset_changed: boolean; + } | null; +}; diff --git a/web/src/views/classification/ModelSelectionView.tsx b/web/src/views/classification/ModelSelectionView.tsx index c5e65e0e5..e72d2b6c1 100644 --- a/web/src/views/classification/ModelSelectionView.tsx +++ b/web/src/views/classification/ModelSelectionView.tsx @@ -11,6 +11,7 @@ import { CustomClassificationModelConfig, FrigateConfig, } from "@/types/frigateConfig"; +import { ClassificationDatasetResponse } from "@/types/classification"; import { useCallback, useEffect, useMemo, useState } from "react"; import { useTranslation } from "react-i18next"; import { FaFolderPlus } from "react-icons/fa"; @@ -209,9 +210,10 @@ type ModelCardProps = { function ModelCard({ config, onClick, onUpdate, onDelete }: ModelCardProps) { const { t } = useTranslation(["views/classificationModel"]); - const { data: dataset } = useSWR<{ - [id: string]: string[]; - }>(`classification/${config.name}/dataset`, { revalidateOnFocus: false }); + const { data: dataset } = useSWR( + `classification/${config.name}/dataset`, + { revalidateOnFocus: false }, + ); const [deleteDialogOpen, setDeleteDialogOpen] = useState(false); const [editDialogOpen, setEditDialogOpen] = useState(false); @@ -260,20 +262,25 @@ function ModelCard({ config, onClick, onUpdate, onDelete }: ModelCardProps) { }, []); const coverImage = useMemo(() => { - if (!dataset) { + if (!dataset || !dataset.categories) { return undefined; } - const keys = Object.keys(dataset).filter((key) => key != "none"); - const selectedKey = keys[0]; + const keys = Object.keys(dataset.categories).filter((key) => key != "none"); + if (keys.length === 0) { + return undefined; + } - if (!dataset[selectedKey]) { + const selectedKey = keys[0]; + const images = dataset.categories[selectedKey]; + + if (!images || images.length === 0) { return undefined; } return { name: selectedKey, - img: dataset[selectedKey][0], + img: images[0], }; }, [dataset]); @@ -317,11 +324,19 @@ function ModelCard({ config, onClick, onUpdate, onDelete }: ModelCardProps) { )} onClick={onClick} > - - + {coverImage ? ( + <> + + + + ) : ( +
+ +
+ )}
{config.name}
diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx index 6a3e680f9..53328e0e2 100644 --- a/web/src/views/classification/ModelTrainingView.tsx +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -59,7 +59,11 @@ import { useNavigate } from "react-router-dom"; import { IoMdArrowRoundBack } from "react-icons/io"; import TrainFilterDialog from "@/components/overlay/dialog/TrainFilterDialog"; import useApiFilter from "@/hooks/use-api-filter"; -import { ClassificationItemData, TrainFilter } from "@/types/classification"; +import { + ClassificationDatasetResponse, + ClassificationItemData, + TrainFilter, +} from "@/types/classification"; import { ClassificationCard, GroupedClassificationCard, @@ -118,16 +122,10 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { const { data: trainImages, mutate: refreshTrain } = useSWR( `classification/${model.name}/train`, ); - const { data: datasetResponse, mutate: refreshDataset } = useSWR<{ - categories: { [id: string]: string[] }; - training_metadata: { - has_trained: boolean; - last_training_date: string | null; - last_training_image_count: number; - current_image_count: number; - new_images_count: number; - } | null; - }>(`classification/${model.name}/dataset`); + const { data: datasetResponse, mutate: refreshDataset } = + useSWR( + `classification/${model.name}/dataset`, + ); const dataset = datasetResponse?.categories || {}; const trainingMetadata = datasetResponse?.training_metadata; @@ -264,10 +262,11 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { ); } + // Always refresh dataset to update the categories list + refreshDataset(); + if (pageToggle == "train") { refreshTrain(); - } else { - refreshDataset(); } } }) @@ -445,7 +444,7 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { variant={modelState == "failed" ? "destructive" : "select"} disabled={ (modelState != "complete" && modelState != "failed") || - (trainingMetadata?.new_images_count ?? 0) === 0 + !trainingMetadata?.dataset_changed } > {modelState == "training" ? ( @@ -466,14 +465,14 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { )} - {((trainingMetadata?.new_images_count ?? 0) === 0 || + {(!trainingMetadata?.dataset_changed || (modelState != "complete" && modelState != "failed")) && ( {modelState == "training" ? t("tooltip.trainingInProgress") - : trainingMetadata?.new_images_count === 0 - ? t("tooltip.noNewImages") + : !trainingMetadata?.dataset_changed + ? t("tooltip.noChanges") : t("tooltip.modelNotReady")} @@ -571,27 +570,44 @@ function LibrarySelector({ > - {t("deleteCategory.title")} + + {Object.keys(dataset).length <= 2 + ? t("deleteCategory.minClassesTitle") + : t("deleteCategory.title")} + - {t("deleteCategory.desc", { name: confirmDelete })} + {Object.keys(dataset).length <= 2 + ? t("deleteCategory.minClassesDesc") + : t("deleteCategory.desc", { name: confirmDelete })}
- - + {Object.keys(dataset).length <= 2 ? ( + + ) : ( + <> + + + + )}