Get model updates working

This commit is contained in:
Nicolas Mowen 2025-06-05 07:44:49 -06:00
parent 74a09ed489
commit afa8889de7
5 changed files with 31 additions and 5 deletions

View File

@ -29,6 +29,7 @@ class LoggerConfig(FrigateBaseModel):
logging.getLogger().setLevel(self.default.value.upper())
log_levels = {
"absl": LogLevel.error,
"httpx": LogLevel.error,
"werkzeug": LogLevel.error,
"ws4py": LogLevel.error,

View File

@ -74,9 +74,10 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
UPDATE_MODEL_STATE,
{
"model": self.model_config.name,
"state": ModelStatusTypesEnum.training,
"state": ModelStatusTypesEnum.complete,
},
)
logger.info(f"Successfully loaded updated model for {self.model_config.name}")
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
camera = frame_data.get("camera")
@ -221,9 +222,10 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
UPDATE_MODEL_STATE,
{
"model": self.model_config.name,
"state": ModelStatusTypesEnum.training,
"state": ModelStatusTypesEnum.complete,
},
)
logger.info(f"Successfully loaded updated model for {self.model_config.name}")
def process_frame(self, obj_data, frame):
if obj_data["label"] not in self.model_config.object_config.objects:

View File

@ -294,7 +294,7 @@ class EmbeddingsContext:
def start_classification_training(self, model_name: str) -> dict[str, Any]:
return self.requestor.send_data(
EmbeddingsRequestEnum.train_classification, {"model_name": model_name}
EmbeddingsRequestEnum.train_classification.value, {"model_name": model_name}
)
def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]:

View File

@ -73,7 +73,9 @@ export type ModelState =
| "not_downloaded"
| "downloading"
| "downloaded"
| "error";
| "error"
| "training"
| "complete";
export type EmbeddingsReindexProgressType = {
thumbnails: number;

View File

@ -45,6 +45,9 @@ import { toast } from "sonner";
import useSWR from "swr";
import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
import { TbCategoryPlus } from "react-icons/tb";
import { useModelState } from "@/api/ws";
import { ModelState } from "@/types/ws";
import ActivityIndicator from "@/components/indicators/activity-indicator";
type ModelTrainingViewProps = {
model: CustomClassificationModelConfig;
@ -54,6 +57,17 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
const [page, setPage] = useState<string>("train");
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
// model state
const { payload: lastModelState } = useModelState(model.name, true);
const modelState = useMemo<ModelState>(() => {
if (!lastModelState || lastModelState == "downloaded") {
return "complete";
}
return lastModelState;
}, [lastModelState]);
// dataset
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
@ -274,7 +288,14 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
</Button>
</div>
) : (
<Button onClick={trainModel}>Train Model</Button>
<Button
className="flex justify-center gap-2"
onClick={trainModel}
disabled={modelState != "complete"}
>
Train Model
{modelState == "training" && <ActivityIndicator size={20} />}
</Button>
)}
</div>
{pageToggle == "train" ? (