This commit is contained in:
Nicolas Mowen 2025-05-29 11:38:11 -06:00
parent a0a422df36
commit 4a5e6d3f97
3 changed files with 45 additions and 6 deletions

View File

@ -7,7 +7,7 @@ import shutil
from typing import Any
import cv2
from fastapi import APIRouter, Depends, Request, UploadFile
from fastapi import APIRouter, BackgroundTasks, Depends, Request, UploadFile
from fastapi.responses import JSONResponse
from pathvalidate import sanitize_filename
from peewee import DoesNotExist
@ -19,10 +19,12 @@ from frigate.api.defs.request.classification_body import (
RenameFaceBody,
)
from frigate.api.defs.tags import Tags
from frigate.config import FrigateConfig
from frigate.config.camera import DetectConfig
from frigate.const import FACE_DIR
from frigate.const import FACE_DIR, MODEL_CACHE_DIR
from frigate.embeddings import EmbeddingsContext
from frigate.models import Event
from frigate.util.classification import train_classification_model
from frigate.util.path import get_event_snapshot
logger = logging.getLogger(__name__)
@ -424,3 +426,32 @@ def transcribe_audio(request: Request, body: AudioTranscriptionBody):
},
status_code=500,
)
# custom classification training
@router.post("/classification/{name}/train")
async def train_configured_model(
request: Request, name: str, background_tasks: BackgroundTasks
):
config: FrigateConfig = request.app.frigate_config
if name not in config.classification.custom:
return JSONResponse(
content=(
{
"success": False,
"message": f"{name} is not a known classification model.",
}
),
status_code=404,
)
background_tasks.add_task(
train_classification_model, os.path.join(MODEL_CACHE_DIR, name)
)
return JSONResponse(
content={"success": True, "message": "Started classification model training."},
status_code=200,
)

View File

@ -129,7 +129,11 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
score = round(probs[best_id], 2)
write_classification_attempt(
self.train_dir, frame, now, self.labelmap[best_id], score
self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
now,
self.labelmap[best_id],
score,
)
if score >= camera_config.threshold:
@ -214,7 +218,11 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
previous_score = self.detected_objects.get(obj_data["id"], 0.0)
write_classification_attempt(
self.train_dir, frame, now, self.labelmap[best_id], score
self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
now,
self.labelmap[best_id],
score,
)
if score <= previous_score:

View File

@ -150,10 +150,10 @@ class EmbeddingMaintainer(threading.Thread):
)
)
for name, model_config in self.config.classification.custom.items():
for model_config in self.config.classification.custom.values():
self.realtime_processors.append(
CustomStateClassificationProcessor(
self.config, model_config, name, self.requestor, self.metrics
self.config, model_config, self.requestor, self.metrics
)
if model_config.state_config != None
else CustomObjectClassificationProcessor(