mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-05-01 11:07:41 +03:00
Implement model training via ZMQ and add model states to represent training
This commit is contained in:
parent
85d721eb6b
commit
74a09ed489
@ -7,7 +7,7 @@ import shutil
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Request, UploadFile
|
||||
from fastapi import APIRouter, Depends, Request, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
from pathvalidate import sanitize_filename
|
||||
from peewee import DoesNotExist
|
||||
@ -24,7 +24,6 @@ from frigate.config.camera import DetectConfig
|
||||
from frigate.const import CLIPS_DIR, FACE_DIR
|
||||
from frigate.embeddings import EmbeddingsContext
|
||||
from frigate.models import Event
|
||||
from frigate.util.classification import train_classification_model
|
||||
from frigate.util.path import get_event_snapshot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -476,9 +475,7 @@ def get_classification_images(name: str):
|
||||
|
||||
|
||||
@router.post("/classification/{name}/train")
|
||||
async def train_configured_model(
|
||||
request: Request, name: str, background_tasks: BackgroundTasks
|
||||
):
|
||||
async def train_configured_model(request: Request, name: str):
|
||||
config: FrigateConfig = request.app.frigate_config
|
||||
|
||||
if name not in config.classification.custom:
|
||||
@ -492,7 +489,8 @@ async def train_configured_model(
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
background_tasks.add_task(train_classification_model, name)
|
||||
context: EmbeddingsContext = request.app.embeddings
|
||||
context.start_classification_training(name)
|
||||
return JSONResponse(
|
||||
content={"success": True, "message": "Started classification model training."},
|
||||
status_code=200,
|
||||
|
||||
@ -9,16 +9,22 @@ SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"
|
||||
|
||||
|
||||
class EmbeddingsRequestEnum(Enum):
|
||||
# audio
|
||||
transcribe_audio = "transcribe_audio"
|
||||
# custom classification
|
||||
train_classification = "train_classification"
|
||||
# face
|
||||
clear_face_classifier = "clear_face_classifier"
|
||||
embed_description = "embed_description"
|
||||
embed_thumbnail = "embed_thumbnail"
|
||||
generate_search = "generate_search"
|
||||
recognize_face = "recognize_face"
|
||||
register_face = "register_face"
|
||||
reprocess_face = "reprocess_face"
|
||||
reprocess_plate = "reprocess_plate"
|
||||
# semantic search
|
||||
embed_description = "embed_description"
|
||||
embed_thumbnail = "embed_thumbnail"
|
||||
generate_search = "generate_search"
|
||||
reindex = "reindex"
|
||||
transcribe_audio = "transcribe_audio"
|
||||
# LPR
|
||||
reprocess_plate = "reprocess_plate"
|
||||
|
||||
|
||||
class EmbeddingsResponder:
|
||||
|
||||
@ -3,11 +3,13 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum
|
||||
from frigate.comms.event_metadata_updater import (
|
||||
EventMetadataPublisher,
|
||||
EventMetadataTypeEnum,
|
||||
@ -15,8 +17,10 @@ from frigate.comms.event_metadata_updater import (
|
||||
from frigate.comms.inter_process import InterProcessRequestor
|
||||
from frigate.config import FrigateConfig
|
||||
from frigate.config.classification import CustomClassificationConfig
|
||||
from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
|
||||
from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR, UPDATE_MODEL_STATE
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.builtin import load_labels
|
||||
from frigate.util.classification import train_classification_model
|
||||
from frigate.util.object import box_overlaps, calculate_region
|
||||
|
||||
from ..types import DataProcessorMetrics
|
||||
@ -63,6 +67,17 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
prefill=0,
|
||||
)
|
||||
|
||||
def __retrain_model(self) -> None:
|
||||
train_classification_model(self.model_config.name)
|
||||
self.__build_detector()
|
||||
self.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": self.model_config.name,
|
||||
"state": ModelStatusTypesEnum.training,
|
||||
},
|
||||
)
|
||||
|
||||
def process_frame(self, frame_data: dict[str, Any], frame: np.ndarray):
|
||||
camera = frame_data.get("camera")
|
||||
|
||||
@ -143,7 +158,24 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
)
|
||||
|
||||
def handle_request(self, topic, request_data):
|
||||
return None
|
||||
if topic == EmbeddingsRequestEnum.train_classification.value:
|
||||
if request_data.get("model_name") == self.model_config.name:
|
||||
self.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": self.model_config.name,
|
||||
"state": ModelStatusTypesEnum.training,
|
||||
},
|
||||
)
|
||||
threading.Thread(target=self.__retrain_model).start()
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Began training {self.model_config.name} model.",
|
||||
}
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
def expire_object(self, object_id, camera):
|
||||
pass
|
||||
@ -182,6 +214,17 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
prefill=0,
|
||||
)
|
||||
|
||||
def __retrain_model(self) -> None:
|
||||
train_classification_model(self.model_config.name)
|
||||
self.__build_detector()
|
||||
self.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": self.model_config.name,
|
||||
"state": ModelStatusTypesEnum.training,
|
||||
},
|
||||
)
|
||||
|
||||
def process_frame(self, obj_data, frame):
|
||||
if obj_data["label"] not in self.model_config.object_config.objects:
|
||||
return
|
||||
@ -236,7 +279,24 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
self.detected_objects[obj_data["id"]] = score
|
||||
|
||||
def handle_request(self, topic, request_data):
|
||||
return None
|
||||
if topic == EmbeddingsRequestEnum.train_classification.value:
|
||||
if request_data.get("model_name") == self.model_config.name:
|
||||
self.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": self.model_config.name,
|
||||
"state": ModelStatusTypesEnum.training,
|
||||
},
|
||||
)
|
||||
threading.Thread(target=self.__retrain_model).start()
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Began training {self.model_config.name} model.",
|
||||
}
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
def expire_object(self, object_id, camera):
|
||||
if object_id in self.detected_objects:
|
||||
|
||||
@ -292,6 +292,11 @@ class EmbeddingsContext:
|
||||
def reindex_embeddings(self) -> dict[str, Any]:
|
||||
return self.requestor.send_data(EmbeddingsRequestEnum.reindex.value, {})
|
||||
|
||||
def start_classification_training(self, model_name: str) -> dict[str, Any]:
|
||||
return self.requestor.send_data(
|
||||
EmbeddingsRequestEnum.train_classification, {"model_name": model_name}
|
||||
)
|
||||
|
||||
def transcribe_audio(self, event: dict[str, any]) -> dict[str, any]:
|
||||
return self.requestor.send_data(
|
||||
EmbeddingsRequestEnum.transcribe_audio.value, {"event": event}
|
||||
|
||||
@ -21,6 +21,8 @@ class ModelStatusTypesEnum(str, Enum):
|
||||
downloading = "downloading"
|
||||
downloaded = "downloaded"
|
||||
error = "error"
|
||||
training = "training"
|
||||
complete = "complete"
|
||||
|
||||
|
||||
class TrackedObjectUpdateTypesEnum(str, Enum):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user