diff --git a/frigate/config/logger.py b/frigate/config/logger.py index e6e1c06d3..0e49f7422 100644 --- a/frigate/config/logger.py +++ b/frigate/config/logger.py @@ -29,6 +29,7 @@ class LoggerConfig(FrigateBaseModel): logging.getLogger().setLevel(self.default.value.upper()) log_levels = { + "absl": LogLevel.error, "httpx": LogLevel.error, "werkzeug": LogLevel.error, "ws4py": LogLevel.error, diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 4f3cec71e..df4baf70b 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -74,9 +74,10 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): UPDATE_MODEL_STATE, { "model": self.model_config.name, - "state": ModelStatusTypesEnum.training, + "state": ModelStatusTypesEnum.complete, }, ) + logger.info(f"Successfully loaded updated model for {self.model_config.name}") def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray): camera = frame_data.get("camera") @@ -221,9 +222,10 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): UPDATE_MODEL_STATE, { "model": self.model_config.name, - "state": ModelStatusTypesEnum.training, + "state": ModelStatusTypesEnum.complete, }, ) + logger.info(f"Successfully loaded updated model for {self.model_config.name}") def process_frame(self, obj_data, frame): if obj_data["label"] not in self.model_config.object_config.objects: diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py index ad4a58825..650eefc81 100644 --- a/frigate/embeddings/__init__.py +++ b/frigate/embeddings/__init__.py @@ -294,7 +294,7 @@ class EmbeddingsContext: def start_classification_training(self, model_name: str) -> dict[str, Any]: return self.requestor.send_data( - EmbeddingsRequestEnum.train_classification, {"model_name": model_name} + EmbeddingsRequestEnum.train_classification.value, {"model_name": model_name} ) def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]: diff --git a/web/src/types/ws.ts b/web/src/types/ws.ts index d1e810494..06ec9ae1d 100644 --- a/web/src/types/ws.ts +++ b/web/src/types/ws.ts @@ -73,7 +73,9 @@ export type ModelState = | "not_downloaded" | "downloading" | "downloaded" - | "error"; + | "error" + | "training" + | "complete"; export type EmbeddingsReindexProgressType = { thumbnails: number; diff --git a/web/src/views/classification/ModelTrainingView.tsx b/web/src/views/classification/ModelTrainingView.tsx index 53ef7fa66..197768a95 100644 --- a/web/src/views/classification/ModelTrainingView.tsx +++ b/web/src/views/classification/ModelTrainingView.tsx @@ -45,6 +45,9 @@ import { toast } from "sonner"; import useSWR from "swr"; import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog"; import { TbCategoryPlus } from "react-icons/tb"; +import { useModelState } from "@/api/ws"; +import { ModelState } from "@/types/ws"; +import ActivityIndicator from "@/components/indicators/activity-indicator"; type ModelTrainingViewProps = { model: CustomClassificationModelConfig; @@ -54,6 +57,17 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { const [page, setPage] = useState("train"); const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100); + // model state + + const { payload: lastModelState } = useModelState(model.name, true); + const modelState = useMemo(() => { + if (!lastModelState || lastModelState == "downloaded") { + return "complete"; + } + + return lastModelState; + }, [lastModelState]); + // dataset const { data: trainImages, mutate: refreshTrain } = useSWR( @@ -274,7 +288,14 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) { ) : ( - + )} {pageToggle == "train" ? (