From e4210a7eb8b09b562264e605927536cfd8aebc5e Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Tue, 22 Oct 2024 07:46:24 -0600 Subject: [PATCH] Implement face uploading API --- docker/main/requirements-wheels.txt | 2 ++ frigate/api/classification.py | 27 ++++++++++++++++++++++++--- frigate/comms/embeddings_updater.py | 1 + frigate/embeddings/__init__.py | 7 +++++++ frigate/embeddings/maintainer.py | 7 +++++++ 5 files changed, 41 insertions(+), 3 deletions(-) diff --git a/docker/main/requirements-wheels.txt b/docker/main/requirements-wheels.txt index 02dd62795..d0d53608a 100644 --- a/docker/main/requirements-wheels.txt +++ b/docker/main/requirements-wheels.txt @@ -8,6 +8,8 @@ imutils == 0.5.* joserfc == 1.0.* pathvalidate == 3.2.* markupsafe == 2.1.* +python-multipart == 0.0.12 +# General mypy == 1.6.1 numpy == 1.26.* onvif_zeep == 0.2.12 diff --git a/frigate/api/classification.py b/frigate/api/classification.py index f2fff8983..5c4b0e6c2 100644 --- a/frigate/api/classification.py +++ b/frigate/api/classification.py @@ -2,9 +2,11 @@ import logging -from fastapi import APIRouter +from fastapi import APIRouter, Request, UploadFile +from fastapi.responses import JSONResponse from frigate.api.defs.tags import Tags +from frigate.embeddings import EmbeddingsContext logger = logging.getLogger(__name__) @@ -12,5 +14,24 @@ router = APIRouter(tags=[Tags.events]) @router.get("/faces") -def get_faces() -> None: - return None +def get_faces(): + return JSONResponse(content={"message": "there are faces"}) + + +@router.post("/faces/{name}") +async def register_face(request: Request, name: str, file: UploadFile): + #if not file.content_type.startswith("image"): + # return JSONResponse( + # status_code=400, + # content={ + # "success": False, + # "message": "Only an image can be used to register a face.", + # }, + # ) + + context: EmbeddingsContext = request.app.embeddings + context.register_face(name, await file.read()) + return JSONResponse( + status_code=200, + content={"success": True, "message": "Successfully registered face."}, + ) diff --git a/frigate/comms/embeddings_updater.py b/frigate/comms/embeddings_updater.py index 9a13525f8..728c58211 100644 --- a/frigate/comms/embeddings_updater.py +++ b/frigate/comms/embeddings_updater.py @@ -12,6 +12,7 @@ class EmbeddingsRequestEnum(Enum): embed_description = "embed_description" embed_thumbnail = "embed_thumbnail" generate_search = "generate_search" + register_face = "register_face" class EmbeddingsResponder: diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py index 7f2e1a10c..0c4628272 100644 --- a/frigate/embeddings/__init__.py +++ b/frigate/embeddings/__init__.py @@ -1,5 +1,6 @@ """SQLite-vec embeddings database.""" +import base64 import json import logging import multiprocessing as mp @@ -189,6 +190,12 @@ class EmbeddingsContext: return results + def register_face(self, face_name: str, image_data: bytes) -> None: + self.requestor.send_data( + EmbeddingsRequestEnum.register_face.value, + {"face_name": face_name, "image": base64.b64encode(image_data).decode("ASCII")}, + ) + def update_description(self, event_id: str, description: str) -> None: self.requestor.send_data( EmbeddingsRequestEnum.embed_description.value, diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index a198be79c..ac6d14d63 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -109,6 +109,13 @@ class EmbeddingMaintainer(threading.Thread): return serialize( self.embeddings.text_embedding([data])[0], pack=False ) + elif topic == EmbeddingsRequestEnum.register_face.value: + self.embeddings.embed_face( + data["face_name"], + base64.b64decode(data["image"]), + upsert=True, + ) + return None except Exception as e: logger.error(f"Unable to handle embeddings request {e}")