Implement model training via ZMQ and add model states to represent training

This commit is contained in:
Nicolas Mowen 2025-06-05 07:31:00 -06:00
parent 85d721eb6b
commit 74a09ed489
5 changed files with 85 additions and 14 deletions

View File

@ -7,7 +7,7 @@ import shutil
from typing import Any from typing import Any
import cv2 import cv2
from fastapi import APIRouter, BackgroundTasks, Depends, Request, UploadFile from fastapi import APIRouter, Depends, Request, UploadFile
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pathvalidate import sanitize_filename from pathvalidate import sanitize_filename
from peewee import DoesNotExist from peewee import DoesNotExist
@ -24,7 +24,6 @@ from frigate.config.camera import DetectConfig
from frigate.const import CLIPS_DIR, FACE_DIR from frigate.const import CLIPS_DIR, FACE_DIR
from frigate.embeddings import EmbeddingsContext from frigate.embeddings import EmbeddingsContext
from frigate.models import Event from frigate.models import Event
from frigate.util.classification import train_classification_model
from frigate.util.path import get_event_snapshot from frigate.util.path import get_event_snapshot
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -476,9 +475,7 @@ def get_classification_images(name: str):
@router.post("/classification/{name}/train") @router.post("/classification/{name}/train")
async def train_configured_model( async def train_configured_model(request: Request, name: str):
request: Request, name: str, background_tasks: BackgroundTasks
):
config: FrigateConfig = request.app.frigate_config config: FrigateConfig = request.app.frigate_config
if name not in config.classification.custom: if name not in config.classification.custom:
@ -492,7 +489,8 @@ async def train_configured_model(
status_code=404, status_code=404,
) )
background_tasks.add_task(train_classification_model, name) context: EmbeddingsContext = request.app.embeddings
context.start_classification_training(name)
return JSONResponse( return JSONResponse(
content={"success": True, "message": "Started classification model training."}, content={"success": True, "message": "Started classification model training."},
status_code=200, status_code=200,

View File

@ -9,16 +9,22 @@ SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"
class EmbeddingsRequestEnum(Enum): class EmbeddingsRequestEnum(Enum):
# audio
transcribe_audio = "transcribe_audio"
# custom classification
train_classification = "train_classification"
# face
clear_face_classifier = "clear_face_classifier" clear_face_classifier = "clear_face_classifier"
embed_description = "embed_description"
embed_thumbnail = "embed_thumbnail"
generate_search = "generate_search"
recognize_face = "recognize_face" recognize_face = "recognize_face"
register_face = "register_face" register_face = "register_face"
reprocess_face = "reprocess_face" reprocess_face = "reprocess_face"
reprocess_plate = "reprocess_plate" # semantic search
embed_description = "embed_description"
embed_thumbnail = "embed_thumbnail"
generate_search = "generate_search"
reindex = "reindex" reindex = "reindex"
transcribe_audio = "transcribe_audio" # LPR
reprocess_plate = "reprocess_plate"
class EmbeddingsResponder: class EmbeddingsResponder:

View File

@ -3,11 +3,13 @@
import datetime import datetime
import logging import logging
import os import os
import threading
from typing import Any from typing import Any
import cv2 import cv2
import numpy as np import numpy as np
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum
from frigate.comms.event_metadata_updater import ( from frigate.comms.event_metadata_updater import (
EventMetadataPublisher, EventMetadataPublisher,
EventMetadataTypeEnum, EventMetadataTypeEnum,
@ -15,8 +17,10 @@ from frigate.comms.event_metadata_updater import (
from frigate.comms.inter_process import InterProcessRequestor from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FrigateConfig from frigate.config import FrigateConfig
from frigate.config.classification import CustomClassificationConfig from frigate.config.classification import CustomClassificationConfig
from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR, UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum
from frigate.util.builtin import load_labels from frigate.util.builtin import load_labels
from frigate.util.classification import train_classification_model
from frigate.util.object import box_overlaps, calculate_region from frigate.util.object import box_overlaps, calculate_region
from ..types import DataProcessorMetrics from ..types import DataProcessorMetrics
@ -63,6 +67,17 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
prefill=0, prefill=0,
) )
def __retrain_model(self) -> None:
train_classification_model(self.model_config.name)
self.__build_detector()
self.requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": self.model_config.name,
"state": ModelStatusTypesEnum.complete,
},
)
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray): def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
camera = frame_data.get("camera") camera = frame_data.get("camera")
@ -143,7 +158,24 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
) )
def handle_request(self, topic, request_data): def handle_request(self, topic, request_data):
return None if topic == EmbeddingsRequestEnum.train_classification.value:
if request_data.get("model_name") == self.model_config.name:
self.requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": self.model_config.name,
"state": ModelStatusTypesEnum.training,
},
)
threading.Thread(target=self.__retrain_model).start()
return {
"success": True,
"message": f"Began training {self.model_config.name} model.",
}
else:
return None
else:
return None
def expire_object(self, object_id, camera): def expire_object(self, object_id, camera):
pass pass
@ -182,6 +214,17 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
prefill=0, prefill=0,
) )
def __retrain_model(self) -> None:
train_classification_model(self.model_config.name)
self.__build_detector()
self.requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": self.model_config.name,
"state": ModelStatusTypesEnum.complete,
},
)
def process_frame(self, obj_data, frame): def process_frame(self, obj_data, frame):
if obj_data["label"] not in self.model_config.object_config.objects: if obj_data["label"] not in self.model_config.object_config.objects:
return return
@ -236,7 +279,24 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
self.detected_objects[obj_data["id"]] = score self.detected_objects[obj_data["id"]] = score
def handle_request(self, topic, request_data): def handle_request(self, topic, request_data):
return None if topic == EmbeddingsRequestEnum.train_classification.value:
if request_data.get("model_name") == self.model_config.name:
self.requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": self.model_config.name,
"state": ModelStatusTypesEnum.training,
},
)
threading.Thread(target=self.__retrain_model).start()
return {
"success": True,
"message": f"Began training {self.model_config.name} model.",
}
else:
return None
else:
return None
def expire_object(self, object_id, camera): def expire_object(self, object_id, camera):
if object_id in self.detected_objects: if object_id in self.detected_objects:

View File

@ -292,6 +292,11 @@ class EmbeddingsContext:
def reindex_embeddings(self) -> dict[str, Any]: def reindex_embeddings(self) -> dict[str, Any]:
return self.requestor.send_data(EmbeddingsRequestEnum.reindex.value, {}) return self.requestor.send_data(EmbeddingsRequestEnum.reindex.value, {})
def start_classification_training(self, model_name: str) -> dict[str, Any]:
return self.requestor.send_data(
EmbeddingsRequestEnum.train_classification.value, {"model_name": model_name}
)
def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]: def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]:
return self.requestor.send_data( return self.requestor.send_data(
EmbeddingsRequestEnum.transcribe_audio.value, {"event": event} EmbeddingsRequestEnum.transcribe_audio.value, {"event": event}

View File

@ -21,6 +21,8 @@ class ModelStatusTypesEnum(str, Enum):
downloading = "downloading" downloading = "downloading"
downloaded = "downloaded" downloaded = "downloaded"
error = "error" error = "error"
training = "training"
complete = "complete"
class TrackedObjectUpdateTypesEnum(str, Enum): class TrackedObjectUpdateTypesEnum(str, Enum):