Add support for detecting faces during registration

This commit is contained in:
Nicolas Mowen 2024-10-23 11:25:55 -06:00
parent 3a570e21d2
commit 40c83c69b9

View File

@ -103,7 +103,7 @@ class EmbeddingMaintainer(threading.Thread):
def _process_requests(self) -> None:
"""Process embeddings requests"""
def _handle_request(topic: str, data: str) -> str:
def _handle_request(topic: str, data: dict[str, any]) -> str:
try:
if topic == EmbeddingsRequestEnum.embed_description.value:
return serialize(
@ -123,12 +123,34 @@ class EmbeddingMaintainer(threading.Thread):
self.embeddings.text_embedding([data])[0], pack=False
)
elif topic == EmbeddingsRequestEnum.register_face.value:
self.embeddings.embed_face(
data["face_name"],
base64.b64decode(data["image"]),
upsert=True,
)
return None
if data.get("cropped"):
self.embeddings.embed_face(
data["face_name"],
base64.b64decode(data["image"]),
upsert=True,
)
return True
else:
img = cv2.imdecode(
np.frombuffer(
base64.b64decode(data["image"]), dtype=np.uint8
),
cv2.IMREAD_COLOR,
)
face_box = self._detect_face(img)
if not face_box:
return False
face = img[face_box[1] : face_box[3], face_box[0] : face_box[2]]
ret, webp = cv2.imencode(
".webp", face, [int(cv2.IMWRITE_WEBP_QUALITY), 100]
)
self.embeddings.embed_face(
data["face_name"], webp.tobytes(), upsert=True
)
return False
except Exception as e:
logger.error(f"Unable to handle embeddings request {e}")
@ -302,6 +324,29 @@ class EmbeddingMaintainer(threading.Thread):
"""
return self.embeddings.db.execute_sql(sql_query, [query_embedding]).fetchall()
def _detect_face(self, input: np.ndarray) -> tuple[int, int, int, int]:
"""Detect faces in input image."""
self.face_detector.setInputSize((input.shape[1], input.shape[0]))
faces = self.face_detector.detect(input)
if faces[1] is None:
return None
face = None
for _, potential_face in enumerate(faces[1]):
raw_bbox = potential_face[0:4].astype(np.uint16)
x: int = max(raw_bbox[0], 0)
y: int = max(raw_bbox[1], 0)
w: int = raw_bbox[2]
h: int = raw_bbox[3]
bbox = (x, y, x + w, y + h)
if face is None or area(bbox) > area(face):
face = bbox
return face
def _process_face(self, obj_data: dict[str, any], frame: np.ndarray) -> None:
"""Look for faces in image."""
id = obj_data["id"]
@ -331,27 +376,12 @@ class EmbeddingMaintainer(threading.Thread):
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
left, top, right, bottom = person_box
person = rgb[top:bottom, left:right]
face = self._detect_face(person)
self.face_detector.setInputSize((right - left, bottom - top))
faces = self.face_detector.detect(person)
if faces[1] is None:
if not face:
logger.debug("Detected no faces for person object.")
return
face = None
for _, potential_face in enumerate(faces[1]):
raw_bbox = potential_face[0:4].astype(np.int8)
x = max(raw_bbox[0], 0)
y = max(raw_bbox[1], 0)
w = raw_bbox[2]
h = raw_bbox[3]
bbox = (x, y, x + w, y + h)
if face is None or area(bbox) > area(face):
face = bbox
face_frame = person[face[1] : face[3], face[0] : face[2]]
face_frame = cv2.cvtColor(face_frame, cv2.COLOR_RGB2BGR)
else:
@ -384,7 +414,7 @@ class EmbeddingMaintainer(threading.Thread):
face_box[1] : face_box[3], face_box[0] : face_box[2]
]
ret, jpg = cv2.imencode(
ret, webp = cv2.imencode(
".webp", face_frame, [int(cv2.IMWRITE_WEBP_QUALITY), 100]
)
@ -392,7 +422,7 @@ class EmbeddingMaintainer(threading.Thread):
logger.debug("Not processing face due to error creating cropped image.")
return
embedding = self.embeddings.embed_face("unknown", jpg.tobytes(), upsert=False)
embedding = self.embeddings.embed_face("unknown", webp.tobytes(), upsert=False)
query_embedding = serialize(embedding)
best_faces = self._search_face(query_embedding)
logger.debug(f"Detected best faces for person as: {best_faces}")
@ -409,7 +439,7 @@ class EmbeddingMaintainer(threading.Thread):
if face[0].split("-")[0] != sub_label:
logger.debug("Detected multiple faces, result is not valid.")
return None
return
avg_score += score
@ -421,7 +451,7 @@ class EmbeddingMaintainer(threading.Thread):
logger.debug(
f"Recognized face score {avg_score} is less than threshold ({self.config.face_recognition.threshold}) / previous face score ({self.detected_faces.get(id)})."
)
return None
return
resp = requests.post(
f"{FRIGATE_LOCALHOST}/api/events/{id}/sub_label",