Get model updates working

This commit is contained in:
Nicolas Mowen 2025-06-05 07:44:49 -06:00
parent 74a09ed489
commit afa8889de7
5 changed files with 31 additions and 5 deletions

View File

@ -29,6 +29,7 @@ class LoggerConfig(FrigateBaseModel):
logging.getLogger().setLevel(self.default.value.upper()) logging.getLogger().setLevel(self.default.value.upper())
log_levels = { log_levels = {
"absl": LogLevel.error,
"httpx": LogLevel.error, "httpx": LogLevel.error,
"werkzeug": LogLevel.error, "werkzeug": LogLevel.error,
"ws4py": LogLevel.error, "ws4py": LogLevel.error,

View File

@ -74,9 +74,10 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
UPDATE_MODEL_STATE, UPDATE_MODEL_STATE,
{ {
"model": self.model_config.name, "model": self.model_config.name,
"state": ModelStatusTypesEnum.training, "state": ModelStatusTypesEnum.complete,
}, },
) )
logger.info(f"Successfully loaded updated model for {self.model_config.name}")
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray): def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
camera = frame_data.get("camera") camera = frame_data.get("camera")
@ -221,9 +222,10 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
UPDATE_MODEL_STATE, UPDATE_MODEL_STATE,
{ {
"model": self.model_config.name, "model": self.model_config.name,
"state": ModelStatusTypesEnum.training, "state": ModelStatusTypesEnum.complete,
}, },
) )
logger.info(f"Successfully loaded updated model for {self.model_config.name}")
def process_frame(self, obj_data, frame): def process_frame(self, obj_data, frame):
if obj_data["label"] not in self.model_config.object_config.objects: if obj_data["label"] not in self.model_config.object_config.objects:

View File

@ -294,7 +294,7 @@ class EmbeddingsContext:
def start_classification_training(self, model_name: str) -> dict[str, Any]: def start_classification_training(self, model_name: str) -> dict[str, Any]:
return self.requestor.send_data( return self.requestor.send_data(
EmbeddingsRequestEnum.train_classification, {"model_name": model_name} EmbeddingsRequestEnum.train_classification.value, {"model_name": model_name}
) )
def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]: def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]:

View File

@ -73,7 +73,9 @@ export type ModelState =
| "not_downloaded" | "not_downloaded"
| "downloading" | "downloading"
| "downloaded" | "downloaded"
| "error"; | "error"
| "training"
| "complete";
export type EmbeddingsReindexProgressType = { export type EmbeddingsReindexProgressType = {
thumbnails: number; thumbnails: number;

View File

@ -45,6 +45,9 @@ import { toast } from "sonner";
import useSWR from "swr"; import useSWR from "swr";
import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog"; import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
import { TbCategoryPlus } from "react-icons/tb"; import { TbCategoryPlus } from "react-icons/tb";
import { useModelState } from "@/api/ws";
import { ModelState } from "@/types/ws";
import ActivityIndicator from "@/components/indicators/activity-indicator";
type ModelTrainingViewProps = { type ModelTrainingViewProps = {
model: CustomClassificationModelConfig; model: CustomClassificationModelConfig;
@ -54,6 +57,17 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
const [page, setPage] = useState<string>("train"); const [page, setPage] = useState<string>("train");
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100); const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
// model state
const { payload: lastModelState } = useModelState(model.name, true);
const modelState = useMemo<ModelState>(() => {
if (!lastModelState || lastModelState == "downloaded") {
return "complete";
}
return lastModelState;
}, [lastModelState]);
// dataset // dataset
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>( const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
@ -274,7 +288,14 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
</Button> </Button>
</div> </div>
) : ( ) : (
<Button onClick={trainModel}>Train Model</Button> <Button
className="flex justify-center gap-2"
onClick={trainModel}
disabled={modelState != "complete"}
>
Train Model
{modelState == "training" && <ActivityIndicator size={20} />}
</Button>
)} )}
</div> </div>
{pageToggle == "train" ? ( {pageToggle == "train" ? (