diff --git a/frigate/api/classification.py b/frigate/api/classification.py index 85b604379..df804f34a 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -6,6 +6,7 @@ import random import shutil import string +import cv2 from fastapi import APIRouter, Depends, Request, UploadFile from fastapi.responses import JSONResponse from pathvalidate import sanitize_filename @@ -14,9 +15,11 @@ from playhouse.shortcuts import model_to_dict from frigate.api.auth import require_role from frigate.api.defs.tags import Tags +from frigate.config.camera import DetectConfig from frigate.const import FACE_DIR from frigate.embeddings import EmbeddingsContext from frigate.models import Event +from frigate.util.path import get_event_snapshot logger = logging.getLogger(__name__) @@ -87,16 +90,27 @@ def train_face(request: Request, name: str, body: dict = None): ) json: dict[str, any] = body or {} - training_file = os.path.join( - FACE_DIR, f"train/{sanitize_filename(json.get('training_file', ''))}" - ) + training_file_name = sanitize_filename(json.get("training_file", "")) + training_file = os.path.join(FACE_DIR, f"train/{training_file_name}") + event_id = json.get("event_id") - if not training_file or not os.path.isfile(training_file): + if not training_file_name and not event_id: return JSONResponse( content=( { "success": False, - "message": f"Invalid filename or no file exists: {training_file}", + "message": "A training file or event_id must be passed.", + } + ), + status_code=400, + ) + + if training_file_name and not os.path.isfile(training_file): + return JSONResponse( + content=( + { + "success": False, + "message": f"Invalid filename or no file exists: {training_file_name}", } ), status_code=404, @@ -106,7 +120,36 @@ def train_face(request: Request, name: str, body: dict = None): rand_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=6)) new_name = f"{sanitized_name}-{rand_id}.webp" new_file = os.path.join(FACE_DIR, f"{sanitized_name}/{new_name}") - shutil.move(training_file, new_file) + + if training_file_name: + shutil.move(training_file, new_file) + else: + try: + event: Event = Event.get(Event.id == event_id) + except DoesNotExist: + return JSONResponse( + content=( + { + "success": False, + "message": f"Invalid event_id or no event exists: {event_id}", + } + ), + status_code=404, + ) + + snapshot = get_event_snapshot(event) + face_box = event.data["attributes"][0]["box"] + detect_config: DetectConfig = request.app.frigate_config.cameras[ + event.camera + ].detect + + # crop onto the face box minus the bounding box itself + x1 = int(face_box[0] * detect_config.width) + 2 + y1 = int(face_box[1] * detect_config.height) + 2 + x2 = x1 + int(face_box[2] * detect_config.width) - 4 + y2 = y1 + int(face_box[3] * detect_config.height) - 4 + face = snapshot[y1:y2, x1:x2] + cv2.imwrite(new_file, face) context: EmbeddingsContext = request.app.embeddings context.clear_face_classifier() @@ -115,7 +158,7 @@ def train_face(request: Request, name: str, body: dict = None): content=( { "success": True, - "message": f"Successfully saved {training_file} as {new_name}.", + "message": f"Successfully saved {training_file_name} as {new_name}.", } ), status_code=200, diff --git a/frigate/util/path.py b/frigate/util/path.py index dbe51abe5..565f5a357 100644 --- a/frigate/util/path.py +++ b/frigate/util/path.py @@ -4,6 +4,9 @@ import base64 import os from pathlib import Path +import cv2 +from numpy import ndarray + from frigate.const import CLIPS_DIR, THUMB_DIR from frigate.models import Event @@ -21,6 +24,11 @@ def get_event_thumbnail_bytes(event: Event) -> bytes | None: return None +def get_event_snapshot(event: Event) -> ndarray: + media_name = f"{event.camera}-{event.id}" + return cv2.imread(f"{os.path.join(CLIPS_DIR, media_name)}.jpg") + + ### Deletion diff --git a/web/src/components/overlay/detail/SearchDetailDialog.tsx b/web/src/components/overlay/detail/SearchDetailDialog.tsx index b8c230178..b22eb9a4c 100644 --- a/web/src/components/overlay/detail/SearchDetailDialog.tsx +++ b/web/src/components/overlay/detail/SearchDetailDialog.tsx @@ -564,7 +564,7 @@ function ObjectDetailsTab({ return false; } - return search.data.attributes.find((attr) => attr.label == "face"); + return search.data.attributes?.find((attr) => attr.label == "face"); }, [config, search]); const { data: faceData } = useSWR(hasFace ? "faces" : null); diff --git a/web/src/types/search.ts b/web/src/types/search.ts index af8e61674..2a57385f7 100644 --- a/web/src/types/search.ts +++ b/web/src/types/search.ts @@ -50,7 +50,7 @@ export type SearchResult = { score: number; sub_label_score?: number; region: number[]; - attributes: [{ box: number[]; label: string; score: number }]; + attributes?: [{ box: number[]; label: string; score: number }]; box: number[]; area: number; ratio: number;