mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-05-03 12:07:40 +03:00
Get model updates working
This commit is contained in:
parent
74a09ed489
commit
afa8889de7
@ -29,6 +29,7 @@ class LoggerConfig(FrigateBaseModel):
|
|||||||
logging.getLogger().setLevel(self.default.value.upper())
|
logging.getLogger().setLevel(self.default.value.upper())
|
||||||
|
|
||||||
log_levels = {
|
log_levels = {
|
||||||
|
"absl": LogLevel.error,
|
||||||
"httpx": LogLevel.error,
|
"httpx": LogLevel.error,
|
||||||
"werkzeug": LogLevel.error,
|
"werkzeug": LogLevel.error,
|
||||||
"ws4py": LogLevel.error,
|
"ws4py": LogLevel.error,
|
||||||
|
|||||||
@ -74,9 +74,10 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
|||||||
UPDATE_MODEL_STATE,
|
UPDATE_MODEL_STATE,
|
||||||
{
|
{
|
||||||
"model": self.model_config.name,
|
"model": self.model_config.name,
|
||||||
"state": ModelStatusTypesEnum.training,
|
"state": ModelStatusTypesEnum.complete,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
logger.info(f"Successfully loaded updated model for {self.model_config.name}")
|
||||||
|
|
||||||
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
|
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
|
||||||
camera = frame_data.get("camera")
|
camera = frame_data.get("camera")
|
||||||
@ -221,9 +222,10 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
UPDATE_MODEL_STATE,
|
UPDATE_MODEL_STATE,
|
||||||
{
|
{
|
||||||
"model": self.model_config.name,
|
"model": self.model_config.name,
|
||||||
"state": ModelStatusTypesEnum.training,
|
"state": ModelStatusTypesEnum.complete,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
logger.info(f"Successfully loaded updated model for {self.model_config.name}")
|
||||||
|
|
||||||
def process_frame(self, obj_data, frame):
|
def process_frame(self, obj_data, frame):
|
||||||
if obj_data["label"] not in self.model_config.object_config.objects:
|
if obj_data["label"] not in self.model_config.object_config.objects:
|
||||||
|
|||||||
@ -294,7 +294,7 @@ class EmbeddingsContext:
|
|||||||
|
|
||||||
def start_classification_training(self, model_name: str) -> dict[str, Any]:
|
def start_classification_training(self, model_name: str) -> dict[str, Any]:
|
||||||
return self.requestor.send_data(
|
return self.requestor.send_data(
|
||||||
EmbeddingsRequestEnum.train_classification, {"model_name": model_name}
|
EmbeddingsRequestEnum.train_classification.value, {"model_name": model_name}
|
||||||
)
|
)
|
||||||
|
|
||||||
def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]:
|
def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]:
|
||||||
|
|||||||
@ -73,7 +73,9 @@ export type ModelState =
|
|||||||
| "not_downloaded"
|
| "not_downloaded"
|
||||||
| "downloading"
|
| "downloading"
|
||||||
| "downloaded"
|
| "downloaded"
|
||||||
| "error";
|
| "error"
|
||||||
|
| "training"
|
||||||
|
| "complete";
|
||||||
|
|
||||||
export type EmbeddingsReindexProgressType = {
|
export type EmbeddingsReindexProgressType = {
|
||||||
thumbnails: number;
|
thumbnails: number;
|
||||||
|
|||||||
@ -45,6 +45,9 @@ import { toast } from "sonner";
|
|||||||
import useSWR from "swr";
|
import useSWR from "swr";
|
||||||
import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
|
import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
|
||||||
import { TbCategoryPlus } from "react-icons/tb";
|
import { TbCategoryPlus } from "react-icons/tb";
|
||||||
|
import { useModelState } from "@/api/ws";
|
||||||
|
import { ModelState } from "@/types/ws";
|
||||||
|
import ActivityIndicator from "@/components/indicators/activity-indicator";
|
||||||
|
|
||||||
type ModelTrainingViewProps = {
|
type ModelTrainingViewProps = {
|
||||||
model: CustomClassificationModelConfig;
|
model: CustomClassificationModelConfig;
|
||||||
@ -54,6 +57,17 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
const [page, setPage] = useState<string>("train");
|
const [page, setPage] = useState<string>("train");
|
||||||
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
|
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
|
||||||
|
|
||||||
|
// model state
|
||||||
|
|
||||||
|
const { payload: lastModelState } = useModelState(model.name, true);
|
||||||
|
const modelState = useMemo<ModelState>(() => {
|
||||||
|
if (!lastModelState || lastModelState == "downloaded") {
|
||||||
|
return "complete";
|
||||||
|
}
|
||||||
|
|
||||||
|
return lastModelState;
|
||||||
|
}, [lastModelState]);
|
||||||
|
|
||||||
// dataset
|
// dataset
|
||||||
|
|
||||||
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
|
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
|
||||||
@ -274,7 +288,14 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
|||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<Button onClick={trainModel}>Train Model</Button>
|
<Button
|
||||||
|
className="flex justify-center gap-2"
|
||||||
|
onClick={trainModel}
|
||||||
|
disabled={modelState != "complete"}
|
||||||
|
>
|
||||||
|
Train Model
|
||||||
|
{modelState == "training" && <ActivityIndicator size={20} />}
|
||||||
|
</Button>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
{pageToggle == "train" ? (
|
{pageToggle == "train" ? (
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user