From b34a51903f43094c28e185234cea885888c415a1 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Wed, 23 Oct 2024 08:08:43 -0600 Subject: [PATCH] Increase requirements for face to be set --- frigate/embeddings/maintainer.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index d0f351233..5556c8811 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -32,6 +32,7 @@ from .embeddings import Embeddings logger = logging.getLogger(__name__) +REQUIRED_FACES = 2 MAX_THUMBNAILS = 10 @@ -65,6 +66,7 @@ class EmbeddingMaintainer(threading.Thread): self.config.semantic_search.face_recognition.enabled ) self.requires_face_detection = "face" not in self.config.model.all_attributes + self.detected_faces: dict[str, float] = {} # create communication for updating event descriptions self.requestor = InterProcessRequestor() @@ -276,25 +278,28 @@ class EmbeddingMaintainer(threading.Thread): def _search_face(self, query_embedding: bytes) -> list: """Search for the face most closely matching the embedding.""" - sql_query = """ + sql_query = f""" SELECT id, distance FROM vec_faces WHERE face_embedding MATCH ? - AND k = 10 ORDER BY distance + AND k = {REQUIRED_FACES} ORDER BY distance """ return self.embeddings.db.execute_sql(sql_query, [query_embedding]).fetchall() def _process_face(self, obj_data: dict[str, any], frame: np.ndarray) -> None: """Look for faces in image.""" + id = obj_data["id"] + # don't run for non person objects if obj_data.get("label") != "person": logger.debug("Not a processing face for non person object.") return - # don't overwrite sub label for objects that have one - if obj_data.get("sub_label"): + # don't overwrite sub label for objects that have a sub label + # that is not a face + if obj_data.get("sub_label") and id not in self.detected_faces: logger.debug( f"Not processing face due to existing sub label: {obj_data.get('sub_label')}." ) @@ -348,18 +353,25 @@ class EmbeddingMaintainer(threading.Thread): best_faces = self._search_face(query_embedding) logger.debug(f"Detected best faces for person as: {best_faces}") - if not best_faces: + if not best_faces or len(best_faces) < REQUIRED_FACES: return sub_label = str(best_faces[0][0]).split("-")[0] - score = 1.0 - best_faces[0][1] + avg_score = 0 - if score < self.config.semantic_search.face_recognition.threshold: - return None + for face in best_faces: + score = 1.0 - face[1] + if score < self.config.semantic_search.face_recognition.threshold: + return None + + avg_score += score + + avg_score = avg_score / REQUIRED_FACES + self.detected_faces[id] = avg_score requests.post( - f"{FRIGATE_LOCALHOST}/api/events/{obj_data['id']}/sub_label", - json={"subLabel": sub_label, "subLabelScore": score}, + f"{FRIGATE_LOCALHOST}/api/events/{id}/sub_label", + json={"subLabel": sub_label, "subLabelScore": avg_score}, ) def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]: