diff --git a/frigate/api/classification.py b/frigate/api/classification.py index f5acc437c..1fc17a08f 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -7,7 +7,7 @@ import shutil from typing import Any import cv2 -from fastapi import APIRouter, BackgroundTasks, Depends, Request, UploadFile +from fastapi import APIRouter, Depends, Request, UploadFile from fastapi.responses import JSONResponse from pathvalidate import sanitize_filename from peewee import DoesNotExist @@ -24,7 +24,6 @@ from frigate.config.camera import DetectConfig from frigate.const import CLIPS_DIR, FACE_DIR from frigate.embeddings import EmbeddingsContext from frigate.models import Event -from frigate.util.classification import train_classification_model from frigate.util.path import get_event_snapshot logger = logging.getLogger(__name__) @@ -476,9 +475,7 @@ def get_classification_images(name: str): @router.post("/classification/{name}/train") -async def train_configured_model( - request: Request, name: str, background_tasks: BackgroundTasks -): +async def train_configured_model(request: Request, name: str): config: FrigateConfig = request.app.frigate_config if name not in config.classification.custom: @@ -492,7 +489,8 @@ async def train_configured_model( status_code=404, ) - background_tasks.add_task(train_classification_model, name) + context: EmbeddingsContext = request.app.embeddings + context.start_classification_training(name) return JSONResponse( content={"success": True, "message": "Started classification model training."}, status_code=200, diff --git a/frigate/comms/embeddings_updater.py b/frigate/comms/embeddings_updater.py index 00bc88b3d..5edb9e77d 100644 --- a/frigate/comms/embeddings_updater.py +++ b/frigate/comms/embeddings_updater.py @@ -9,16 +9,22 @@ SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings" class EmbeddingsRequestEnum(Enum): + # audio + transcribe_audio = "transcribe_audio" + # custom classification + train_classification = "train_classification" + # face clear_face_classifier = "clear_face_classifier" - embed_description = "embed_description" - embed_thumbnail = "embed_thumbnail" - generate_search = "generate_search" recognize_face = "recognize_face" register_face = "register_face" reprocess_face = "reprocess_face" - reprocess_plate = "reprocess_plate" + # semantic search + embed_description = "embed_description" + embed_thumbnail = "embed_thumbnail" + generate_search = "generate_search" reindex = "reindex" - transcribe_audio = "transcribe_audio" + # LPR + reprocess_plate = "reprocess_plate" class EmbeddingsResponder: diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 0e254ab0d..4f3cec71e 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -3,11 +3,13 @@ import datetime import logging import os +import threading from typing import Any import cv2 import numpy as np +from frigate.comms.embeddings_updater import EmbeddingsRequestEnum from frigate.comms.event_metadata_updater import ( EventMetadataPublisher, EventMetadataTypeEnum, @@ -15,8 +17,10 @@ from frigate.comms.event_metadata_updater import ( from frigate.comms.inter_process import InterProcessRequestor from frigate.config import FrigateConfig from frigate.config.classification import CustomClassificationConfig -from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR +from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR, UPDATE_MODEL_STATE +from frigate.types import ModelStatusTypesEnum from frigate.util.builtin import load_labels +from frigate.util.classification import train_classification_model from frigate.util.object import box_overlaps, calculate_region from ..types import DataProcessorMetrics @@ -63,6 +67,17 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): prefill=0, ) + def __retrain_model(self) -> None: + train_classification_model(self.model_config.name) + self.__build_detector() + self.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": self.model_config.name, + "state": ModelStatusTypesEnum.training, + }, + ) + def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray): camera = frame_data.get("camera") @@ -143,7 +158,24 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): ) def handle_request(self, topic, request_data): - return None + if topic == EmbeddingsRequestEnum.train_classification.value: + if request_data.get("model_name") == self.model_config.name: + self.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": self.model_config.name, + "state": ModelStatusTypesEnum.training, + }, + ) + threading.Thread(target=self.__retrain_model).start() + return { + "success": True, + "message": f"Began training {self.model_config.name} model.", + } + else: + return None + else: + return None def expire_object(self, object_id, camera): pass @@ -182,6 +214,17 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): prefill=0, ) + def __retrain_model(self) -> None: + train_classification_model(self.model_config.name) + self.__build_detector() + self.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": self.model_config.name, + "state": ModelStatusTypesEnum.training, + }, + ) + def process_frame(self, obj_data, frame): if obj_data["label"] not in self.model_config.object_config.objects: return @@ -236,7 +279,24 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): self.detected_objects[obj_data["id"]] = score def handle_request(self, topic, request_data): - return None + if topic == EmbeddingsRequestEnum.train_classification.value: + if request_data.get("model_name") == self.model_config.name: + self.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": self.model_config.name, + "state": ModelStatusTypesEnum.training, + }, + ) + threading.Thread(target=self.__retrain_model).start() + return { + "success": True, + "message": f"Began training {self.model_config.name} model.", + } + else: + return None + else: + return None def expire_object(self, object_id, camera): if object_id in self.detected_objects: diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py index bc1887e2c..ad4a58825 100644 --- a/frigate/embeddings/__init__.py +++ b/frigate/embeddings/__init__.py @@ -292,6 +292,11 @@ class EmbeddingsContext: def reindex_embeddings(self) -> dict[str, Any]: return self.requestor.send_data(EmbeddingsRequestEnum.reindex.value, {}) + def start_classification_training(self, model_name: str) -> dict[str, Any]: + return self.requestor.send_data( + EmbeddingsRequestEnum.train_classification, {"model_name": model_name} + ) + def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]: return self.requestor.send_data( EmbeddingsRequestEnum.transcribe_audio.value, {"event": event} diff --git a/frigate/types.py b/frigate/types.py index ee48cc02b..a9e27ba90 100644 --- a/frigate/types.py +++ b/frigate/types.py @@ -21,6 +21,8 @@ class ModelStatusTypesEnum(str, Enum): downloading = "downloading" downloaded = "downloaded" error = "error" + training = "training" + complete = "complete" class TrackedObjectUpdateTypesEnum(str, Enum):