mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 15:45:27 +03:00
Get matching face embeddings
This commit is contained in:
parent
a10a49a85c
commit
6b2ffc4c06
@ -263,8 +263,26 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
if event_id:
|
if event_id:
|
||||||
self.handle_regenerate_description(event_id, source)
|
self.handle_regenerate_description(event_id, source)
|
||||||
|
|
||||||
|
def _search_face(self, query_embedding: bytes) -> list:
|
||||||
|
"""Search for the face most closely matching the embedding."""
|
||||||
|
sql_query = """
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
distance
|
||||||
|
FROM vec_faces
|
||||||
|
WHERE face_embedding MATCH ?
|
||||||
|
AND k = 10 ORDER BY distance
|
||||||
|
"""
|
||||||
|
logger.info("doing a search")
|
||||||
|
results = self.embeddings.db.execute_sql(sql_query, [query_embedding]).fetchall()
|
||||||
|
logger.info(f"the search results are {results}")
|
||||||
|
|
||||||
def _process_face(self, obj_data: dict[str, any], frame: np.ndarray) -> None:
|
def _process_face(self, obj_data: dict[str, any], frame: np.ndarray) -> None:
|
||||||
"""Look for faces in image."""
|
"""Look for faces in image."""
|
||||||
|
# don't run for non person objects
|
||||||
|
if obj_data.get("label") != "person":
|
||||||
|
return
|
||||||
|
|
||||||
# don't overwrite sub label for objects that have one
|
# don't overwrite sub label for objects that have one
|
||||||
if obj_data.get("sub_label"):
|
if obj_data.get("sub_label"):
|
||||||
return
|
return
|
||||||
@ -275,7 +293,12 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
# TODO run cv2 face detection
|
# TODO run cv2 face detection
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
# don't run for object without attributes
|
||||||
|
if not obj_data.get("current_attributes"):
|
||||||
|
return
|
||||||
|
|
||||||
for attr in obj_data.get("current_attributes", []):
|
for attr in obj_data.get("current_attributes", []):
|
||||||
|
logger.info(f"attribute is {attr}")
|
||||||
if attr.get("label") != "face":
|
if attr.get("label") != "face":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -308,6 +331,8 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
embedding = self.embeddings.embed_face("nick", jpg.tobytes(), upsert=False)
|
embedding = self.embeddings.embed_face("nick", jpg.tobytes(), upsert=False)
|
||||||
|
query_embedding = serialize(embedding)
|
||||||
|
best_faces = self._search_face(query_embedding)
|
||||||
|
|
||||||
# TODO compare embedding to faces in embeddings DB to fine cosine similarity
|
# TODO compare embedding to faces in embeddings DB to fine cosine similarity
|
||||||
# TODO check against threshold and min score to see if best face qualifies
|
# TODO check against threshold and min score to see if best face qualifies
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user