mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-17 16:44:29 +03:00
Add support for detecting faces during registration
This commit is contained in:
parent
3a570e21d2
commit
40c83c69b9
@ -103,7 +103,7 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
def _process_requests(self) -> None:
|
def _process_requests(self) -> None:
|
||||||
"""Process embeddings requests"""
|
"""Process embeddings requests"""
|
||||||
|
|
||||||
def _handle_request(topic: str, data: str) -> str:
|
def _handle_request(topic: str, data: dict[str, any]) -> str:
|
||||||
try:
|
try:
|
||||||
if topic == EmbeddingsRequestEnum.embed_description.value:
|
if topic == EmbeddingsRequestEnum.embed_description.value:
|
||||||
return serialize(
|
return serialize(
|
||||||
@ -123,12 +123,34 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
self.embeddings.text_embedding([data])[0], pack=False
|
self.embeddings.text_embedding([data])[0], pack=False
|
||||||
)
|
)
|
||||||
elif topic == EmbeddingsRequestEnum.register_face.value:
|
elif topic == EmbeddingsRequestEnum.register_face.value:
|
||||||
self.embeddings.embed_face(
|
if data.get("cropped"):
|
||||||
data["face_name"],
|
self.embeddings.embed_face(
|
||||||
base64.b64decode(data["image"]),
|
data["face_name"],
|
||||||
upsert=True,
|
base64.b64decode(data["image"]),
|
||||||
)
|
upsert=True,
|
||||||
return None
|
)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
img = cv2.imdecode(
|
||||||
|
np.frombuffer(
|
||||||
|
base64.b64decode(data["image"]), dtype=np.uint8
|
||||||
|
),
|
||||||
|
cv2.IMREAD_COLOR,
|
||||||
|
)
|
||||||
|
face_box = self._detect_face(img)
|
||||||
|
|
||||||
|
if not face_box:
|
||||||
|
return False
|
||||||
|
|
||||||
|
face = img[face_box[1] : face_box[3], face_box[0] : face_box[2]]
|
||||||
|
ret, webp = cv2.imencode(
|
||||||
|
".webp", face, [int(cv2.IMWRITE_WEBP_QUALITY), 100]
|
||||||
|
)
|
||||||
|
self.embeddings.embed_face(
|
||||||
|
data["face_name"], webp.tobytes(), upsert=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unable to handle embeddings request {e}")
|
logger.error(f"Unable to handle embeddings request {e}")
|
||||||
|
|
||||||
@ -302,6 +324,29 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
"""
|
"""
|
||||||
return self.embeddings.db.execute_sql(sql_query, [query_embedding]).fetchall()
|
return self.embeddings.db.execute_sql(sql_query, [query_embedding]).fetchall()
|
||||||
|
|
||||||
|
def _detect_face(self, input: np.ndarray) -> tuple[int, int, int, int]:
    """Detect the largest face in the input image.

    Args:
        input: Image array. ``input.shape[1]`` and ``input.shape[0]`` are
            handed to the detector as the input width and height.

    Returns:
        The detected face with the largest area as an
        ``(x_min, y_min, x_max, y_max)`` tuple of plain ints, or ``None``
        when the detector reports no faces.
    """
    self.face_detector.setInputSize((input.shape[1], input.shape[0]))
    faces = self.face_detector.detect(input)

    if faces[1] is None:
        return None

    face = None

    for potential_face in faces[1]:
        # Use a signed cast and plain Python ints: an unsigned cast would
        # wrap a negative coordinate to a huge positive value (so clamping
        # to 0 would never trigger) and fixed-width arithmetic on the box
        # could overflow.
        raw_x, raw_y, raw_w, raw_h = (
            int(value) for value in potential_face[0:4].astype(np.int32)
        )

        # Clamp the top-left corner to the image origin; the opposite
        # corner is derived from the clamped corner plus width / height.
        x = max(raw_x, 0)
        y = max(raw_y, 0)
        bbox = (x, y, x + raw_w, y + raw_h)

        # Keep only the largest candidate seen so far.
        if face is None or area(bbox) > area(face):
            face = bbox

    return face
|
||||||
|
|
||||||
def _process_face(self, obj_data: dict[str, any], frame: np.ndarray) -> None:
|
def _process_face(self, obj_data: dict[str, any], frame: np.ndarray) -> None:
|
||||||
"""Look for faces in image."""
|
"""Look for faces in image."""
|
||||||
id = obj_data["id"]
|
id = obj_data["id"]
|
||||||
@ -331,27 +376,12 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
|
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
|
||||||
left, top, right, bottom = person_box
|
left, top, right, bottom = person_box
|
||||||
person = rgb[top:bottom, left:right]
|
person = rgb[top:bottom, left:right]
|
||||||
|
face = self._detect_face(person)
|
||||||
|
|
||||||
self.face_detector.setInputSize((right - left, bottom - top))
|
if not face:
|
||||||
faces = self.face_detector.detect(person)
|
|
||||||
|
|
||||||
if faces[1] is None:
|
|
||||||
logger.debug("Detected no faces for person object.")
|
logger.debug("Detected no faces for person object.")
|
||||||
return
|
return
|
||||||
|
|
||||||
face = None
|
|
||||||
|
|
||||||
for _, potential_face in enumerate(faces[1]):
|
|
||||||
raw_bbox = potential_face[0:4].astype(np.int8)
|
|
||||||
x = max(raw_bbox[0], 0)
|
|
||||||
y = max(raw_bbox[1], 0)
|
|
||||||
w = raw_bbox[2]
|
|
||||||
h = raw_bbox[3]
|
|
||||||
bbox = (x, y, x + w, y + h)
|
|
||||||
|
|
||||||
if face is None or area(bbox) > area(face):
|
|
||||||
face = bbox
|
|
||||||
|
|
||||||
face_frame = person[face[1] : face[3], face[0] : face[2]]
|
face_frame = person[face[1] : face[3], face[0] : face[2]]
|
||||||
face_frame = cv2.cvtColor(face_frame, cv2.COLOR_RGB2BGR)
|
face_frame = cv2.cvtColor(face_frame, cv2.COLOR_RGB2BGR)
|
||||||
else:
|
else:
|
||||||
@ -384,7 +414,7 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
face_box[1] : face_box[3], face_box[0] : face_box[2]
|
face_box[1] : face_box[3], face_box[0] : face_box[2]
|
||||||
]
|
]
|
||||||
|
|
||||||
ret, jpg = cv2.imencode(
|
ret, webp = cv2.imencode(
|
||||||
".webp", face_frame, [int(cv2.IMWRITE_WEBP_QUALITY), 100]
|
".webp", face_frame, [int(cv2.IMWRITE_WEBP_QUALITY), 100]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -392,7 +422,7 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
logger.debug("Not processing face due to error creating cropped image.")
|
logger.debug("Not processing face due to error creating cropped image.")
|
||||||
return
|
return
|
||||||
|
|
||||||
embedding = self.embeddings.embed_face("unknown", jpg.tobytes(), upsert=False)
|
embedding = self.embeddings.embed_face("unknown", webp.tobytes(), upsert=False)
|
||||||
query_embedding = serialize(embedding)
|
query_embedding = serialize(embedding)
|
||||||
best_faces = self._search_face(query_embedding)
|
best_faces = self._search_face(query_embedding)
|
||||||
logger.debug(f"Detected best faces for person as: {best_faces}")
|
logger.debug(f"Detected best faces for person as: {best_faces}")
|
||||||
@ -409,7 +439,7 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
|
|
||||||
if face[0].split("-")[0] != sub_label:
|
if face[0].split("-")[0] != sub_label:
|
||||||
logger.debug("Detected multiple faces, result is not valid.")
|
logger.debug("Detected multiple faces, result is not valid.")
|
||||||
return None
|
return
|
||||||
|
|
||||||
avg_score += score
|
avg_score += score
|
||||||
|
|
||||||
@ -421,7 +451,7 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"Recognized face score {avg_score} is less than threshold ({self.config.face_recognition.threshold}) / previous face score ({self.detected_faces.get(id)})."
|
f"Recognized face score {avg_score} is less than threshold ({self.config.face_recognition.threshold}) / previous face score ({self.detected_faces.get(id)})."
|
||||||
)
|
)
|
||||||
return None
|
return
|
||||||
|
|
||||||
resp = requests.post(
|
resp = requests.post(
|
||||||
f"{FRIGATE_LOCALHOST}/api/events/{id}/sub_label",
|
f"{FRIGATE_LOCALHOST}/api/events/{id}/sub_label",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user