This commit is contained in:
Nicolas Mowen 2025-05-29 11:38:11 -06:00
parent a0a422df36
commit 4a5e6d3f97
3 changed files with 45 additions and 6 deletions

View File

@ -7,7 +7,7 @@ import shutil
from typing import Any from typing import Any
import cv2 import cv2
from fastapi import APIRouter, Depends, Request, UploadFile from fastapi import APIRouter, BackgroundTasks, Depends, Request, UploadFile
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pathvalidate import sanitize_filename from pathvalidate import sanitize_filename
from peewee import DoesNotExist from peewee import DoesNotExist
@ -19,10 +19,12 @@ from frigate.api.defs.request.classification_body import (
RenameFaceBody, RenameFaceBody,
) )
from frigate.api.defs.tags import Tags from frigate.api.defs.tags import Tags
from frigate.config import FrigateConfig
from frigate.config.camera import DetectConfig from frigate.config.camera import DetectConfig
from frigate.const import FACE_DIR from frigate.const import FACE_DIR, MODEL_CACHE_DIR
from frigate.embeddings import EmbeddingsContext from frigate.embeddings import EmbeddingsContext
from frigate.models import Event from frigate.models import Event
from frigate.util.classification import train_classification_model
from frigate.util.path import get_event_snapshot from frigate.util.path import get_event_snapshot
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -424,3 +426,32 @@ def transcribe_audio(request: Request, body: AudioTranscriptionBody):
}, },
status_code=500, status_code=500,
) )
# custom classification training
@router.post("/classification/{name}/train")
async def train_configured_model(
request: Request, name: str, background_tasks: BackgroundTasks
):
config: FrigateConfig = request.app.frigate_config
if name not in config.classification.custom:
return JSONResponse(
content=(
{
"success": False,
"message": f"{name} is not a known classification model.",
}
),
status_code=404,
)
background_tasks.add_task(
train_classification_model, os.path.join(MODEL_CACHE_DIR, name)
)
return JSONResponse(
content={"success": True, "message": "Started classification model training."},
status_code=200,
)

View File

@ -129,7 +129,11 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
score = round(probs[best_id], 2) score = round(probs[best_id], 2)
write_classification_attempt( write_classification_attempt(
self.train_dir, frame, now, self.labelmap[best_id], score self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
now,
self.labelmap[best_id],
score,
) )
if score >= camera_config.threshold: if score >= camera_config.threshold:
@ -214,7 +218,11 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
previous_score = self.detected_objects.get(obj_data["id"], 0.0) previous_score = self.detected_objects.get(obj_data["id"], 0.0)
write_classification_attempt( write_classification_attempt(
self.train_dir, frame, now, self.labelmap[best_id], score self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
now,
self.labelmap[best_id],
score,
) )
if score <= previous_score: if score <= previous_score:

View File

@ -150,10 +150,10 @@ class EmbeddingMaintainer(threading.Thread):
) )
) )
for name, model_config in self.config.classification.custom.items(): for model_config in self.config.classification.custom.values():
self.realtime_processors.append( self.realtime_processors.append(
CustomStateClassificationProcessor( CustomStateClassificationProcessor(
self.config, model_config, name, self.requestor, self.metrics self.config, model_config, self.requestor, self.metrics
) )
if model_config.state_config != None if model_config.state_config != None
else CustomObjectClassificationProcessor( else CustomObjectClassificationProcessor(