diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py
index e7fa2c4f2..ca7d09238 100644
--- a/frigate/embeddings/maintainer.py
+++ b/frigate/embeddings/maintainer.py
@@ -103,7 +103,7 @@ class EmbeddingMaintainer(threading.Thread):
     def _process_requests(self) -> None:
         """Process embeddings requests"""
 
-        def _handle_request(topic: str, data: str) -> str:
+        def _handle_request(topic: str, data: dict[str, any]) -> str:
             try:
                 if topic == EmbeddingsRequestEnum.embed_description.value:
                     return serialize(
@@ -123,12 +123,34 @@
                         self.embeddings.text_embedding([data])[0], pack=False
                     )
                 elif topic == EmbeddingsRequestEnum.register_face.value:
-                    self.embeddings.embed_face(
-                        data["face_name"],
-                        base64.b64decode(data["image"]),
-                        upsert=True,
-                    )
-                    return None
+                    if data.get("cropped"):
+                        self.embeddings.embed_face(
+                            data["face_name"],
+                            base64.b64decode(data["image"]),
+                            upsert=True,
+                        )
+                        return True
+                    else:
+                        img = cv2.imdecode(
+                            np.frombuffer(
+                                base64.b64decode(data["image"]), dtype=np.uint8
+                            ),
+                            cv2.IMREAD_COLOR,
+                        )
+                        face_box = self._detect_face(img)
+
+                        if not face_box:
+                            return False
+
+                        face = img[face_box[1] : face_box[3], face_box[0] : face_box[2]]
+                        ret, webp = cv2.imencode(
+                            ".webp", face, [int(cv2.IMWRITE_WEBP_QUALITY), 100]
+                        )
+                        self.embeddings.embed_face(
+                            data["face_name"], webp.tobytes(), upsert=True
+                        )
+
+                        return True
             except Exception as e:
                 logger.error(f"Unable to handle embeddings request {e}")
 
@@ -302,6 +324,29 @@
         """
         return self.embeddings.db.execute_sql(sql_query, [query_embedding]).fetchall()
 
+    def _detect_face(self, input: np.ndarray) -> tuple[int, int, int, int]:
+        """Detect faces in input image."""
+        self.face_detector.setInputSize((input.shape[1], input.shape[0]))
+        faces = self.face_detector.detect(input)
+
+        if faces[1] is None:
+            return None
+
+        face = None
+
+        for _, potential_face in enumerate(faces[1]):
+            raw_bbox = potential_face[0:4].astype(np.int32)
+            x: int = max(raw_bbox[0], 0)
+            y: int = max(raw_bbox[1], 0)
+            w: int = raw_bbox[2]
+            h: int = raw_bbox[3]
+            bbox = (x, y, x + w, y + h)
+
+            if face is None or area(bbox) > area(face):
+                face = bbox
+
+        return face
+
     def _process_face(self, obj_data: dict[str, any], frame: np.ndarray) -> None:
         """Look for faces in image."""
         id = obj_data["id"]
@@ -331,27 +376,12 @@
             rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
             left, top, right, bottom = person_box
             person = rgb[top:bottom, left:right]
+            face = self._detect_face(person)
 
-            self.face_detector.setInputSize((right - left, bottom - top))
-            faces = self.face_detector.detect(person)
-
-            if faces[1] is None:
+            if not face:
                 logger.debug("Detected no faces for person object.")
                 return
 
-            face = None
-
-            for _, potential_face in enumerate(faces[1]):
-                raw_bbox = potential_face[0:4].astype(np.int8)
-                x = max(raw_bbox[0], 0)
-                y = max(raw_bbox[1], 0)
-                w = raw_bbox[2]
-                h = raw_bbox[3]
-                bbox = (x, y, x + w, y + h)
-
-                if face is None or area(bbox) > area(face):
-                    face = bbox
-
             face_frame = person[face[1] : face[3], face[0] : face[2]]
             face_frame = cv2.cvtColor(face_frame, cv2.COLOR_RGB2BGR)
         else:
@@ -384,7 +414,7 @@
                 face_box[1] : face_box[3], face_box[0] : face_box[2]
             ]
 
-        ret, jpg = cv2.imencode(
+        ret, webp = cv2.imencode(
             ".webp", face_frame, [int(cv2.IMWRITE_WEBP_QUALITY), 100]
         )
 
@@ -392,7 +422,7 @@
             logger.debug("Not processing face due to error creating cropped image.")
             return
 
-        embedding = self.embeddings.embed_face("unknown", jpg.tobytes(), upsert=False)
+        embedding = self.embeddings.embed_face("unknown", webp.tobytes(), upsert=False)
         query_embedding = serialize(embedding)
         best_faces = self._search_face(query_embedding)
         logger.debug(f"Detected best faces for person as: {best_faces}")
@@ -409,7 +439,7 @@
 
             if face[0].split("-")[0] != sub_label:
                 logger.debug("Detected multiple faces, result is not valid.")
-                return None
+                return
 
             avg_score += score
 
@@ -421,7 +451,7 @@
             logger.debug(
                 f"Recognized face score {avg_score} is less than threshold ({self.config.face_recognition.threshold}) / previous face score ({self.detected_faces.get(id)})."
             )
-            return None
+            return
 
         resp = requests.post(
             f"{FRIGATE_LOCALHOST}/api/events/{id}/sub_label",