diff --git a/frigate/api/classification.py b/frigate/api/classification.py index 81112933c..f2c6ac06b 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -7,7 +7,7 @@ import shutil from typing import Any import cv2 -from fastapi import APIRouter, Depends, Request, UploadFile +from fastapi import APIRouter, BackgroundTasks, Depends, Request, UploadFile from fastapi.responses import JSONResponse from pathvalidate import sanitize_filename from peewee import DoesNotExist @@ -19,10 +19,12 @@ from frigate.api.defs.request.classification_body import ( RenameFaceBody, ) from frigate.api.defs.tags import Tags +from frigate.config import FrigateConfig from frigate.config.camera import DetectConfig -from frigate.const import FACE_DIR +from frigate.const import FACE_DIR, MODEL_CACHE_DIR from frigate.embeddings import EmbeddingsContext from frigate.models import Event +from frigate.util.classification import train_classification_model from frigate.util.path import get_event_snapshot logger = logging.getLogger(__name__) @@ -424,3 +426,32 @@ def transcribe_audio(request: Request, body: AudioTranscriptionBody): }, status_code=500, ) + + +# custom classification training + + +@router.post("/classification/{name}/train") +async def train_configured_model( + request: Request, name: str, background_tasks: BackgroundTasks +): + config: FrigateConfig = request.app.frigate_config + + if name not in config.classification.custom: + return JSONResponse( + content=( + { + "success": False, + "message": f"{name} is not a known classification model.", + } + ), + status_code=404, + ) + + background_tasks.add_task( + train_classification_model, os.path.join(MODEL_CACHE_DIR, name) + ) + return JSONResponse( + content={"success": True, "message": "Started classification model training."}, + status_code=200, + ) diff --git a/frigate/data_processing/real_time/custom_classification.py b/frigate/data_processing/real_time/custom_classification.py index 8bdf64033..4348afce6 100644 --- a/frigate/data_processing/real_time/custom_classification.py +++ b/frigate/data_processing/real_time/custom_classification.py @@ -129,7 +129,11 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi): score = round(probs[best_id], 2) write_classification_attempt( - self.train_dir, frame, now, self.labelmap[best_id], score + self.train_dir, + cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), + now, + self.labelmap[best_id], + score, ) if score >= camera_config.threshold: @@ -214,7 +218,11 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi): previous_score = self.detected_objects.get(obj_data["id"], 0.0) write_classification_attempt( - self.train_dir, frame, now, self.labelmap[best_id], score + self.train_dir, + cv2.cvtColor(frame, cv2.COLOR_RGB2BGR), + now, + self.labelmap[best_id], + score, ) if score <= previous_score: diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index 25601f014..9a2378221 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -150,10 +150,10 @@ class EmbeddingMaintainer(threading.Thread): ) ) - for name, model_config in self.config.classification.custom.items(): + for model_config in self.config.classification.custom.values(): self.realtime_processors.append( CustomStateClassificationProcessor( - self.config, model_config, name, self.requestor, self.metrics + self.config, model_config, self.requestor, self.metrics ) if model_config.state_config != None else CustomObjectClassificationProcessor(