mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-17 16:44:29 +03:00
Use SVC to classify faces
This commit is contained in:
parent
d353450f1c
commit
509a9d7863
@ -29,12 +29,12 @@ from frigate.genai import get_genai_client
|
||||
from frigate.models import Event
|
||||
from frigate.util.builtin import serialize
|
||||
from frigate.util.image import SharedMemoryFrameManager, area, calculate_region
|
||||
from frigate.util.model import FaceClassificationModel
|
||||
|
||||
from .embeddings import Embeddings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REQUIRED_FACES = 2
|
||||
MAX_THUMBNAILS = 10
|
||||
|
||||
|
||||
@ -67,6 +67,9 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
self.face_recognition_enabled = self.config.face_recognition.enabled
|
||||
self.requires_face_detection = "face" not in self.config.objects.all_objects
|
||||
self.detected_faces: dict[str, float] = {}
|
||||
self.face_classifier = (
|
||||
FaceClassificationModel(db) if self.face_recognition_enabled else None
|
||||
)
|
||||
|
||||
# create communication for updating event descriptions
|
||||
self.requestor = InterProcessRequestor()
|
||||
@ -336,18 +339,6 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
if event_id:
|
||||
self.handle_regenerate_description(event_id, source)
|
||||
|
||||
def _search_face(self, query_embedding: bytes) -> list[tuple[str, float]]:
    """Look up the stored faces closest to a serialized face embedding.

    Queries the ``vec_faces`` table with a ``MATCH`` on ``face_embedding``,
    capped at ``REQUIRED_FACES`` rows and ordered by ascending distance.

    Args:
        query_embedding: The serialized embedding bound to the ``MATCH``
            placeholder.

    Returns:
        The fetched ``(id, distance)`` rows.
    """
    nearest_faces_sql = f"""
        SELECT
            id,
            distance
        FROM vec_faces
        WHERE face_embedding MATCH ?
            AND k = {REQUIRED_FACES} ORDER BY distance
    """
    cursor = self.embeddings.db.execute_sql(nearest_faces_sql, [query_embedding])
    return cursor.fetchall()
|
||||
|
||||
def _detect_face(self, input: np.ndarray) -> tuple[int, int, int, int]:
|
||||
"""Detect faces in input image."""
|
||||
self.face_detector.setInputSize((input.shape[1], input.shape[0]))
|
||||
@ -462,33 +453,22 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
return
|
||||
|
||||
embedding = self.embeddings.embed_face("nick", webp.tobytes(), upsert=True)
|
||||
query_embedding = serialize(embedding)
|
||||
best_faces = self._search_face(query_embedding)
|
||||
logger.debug(f"Detected best faces for person as: {best_faces}")
|
||||
res = self.face_classifier.classify_face(embedding)
|
||||
|
||||
if not best_faces or len(best_faces) < REQUIRED_FACES:
|
||||
logger.debug(f"{len(best_faces)} < {REQUIRED_FACES} min required faces.")
|
||||
if not res:
|
||||
return
|
||||
|
||||
sub_label = str(best_faces[0][0]).split("-")[0]
|
||||
avg_score = 0
|
||||
sub_label, score = res
|
||||
|
||||
for face in best_faces:
|
||||
score = 1.0 - face[1]
|
||||
logger.debug(
|
||||
f"Detected best face for person as: {sub_label} with score {score}"
|
||||
)
|
||||
|
||||
if face[0].split("-")[0] != sub_label:
|
||||
logger.debug("Detected multiple faces, result is not valid.")
|
||||
return
|
||||
|
||||
avg_score += score
|
||||
|
||||
avg_score = round(avg_score / REQUIRED_FACES, 2)
|
||||
|
||||
if avg_score < self.config.face_recognition.threshold or (
|
||||
id in self.detected_faces and avg_score <= self.detected_faces[id]
|
||||
if score < self.config.face_recognition.threshold or (
|
||||
id in self.detected_faces and score <= self.detected_faces[id]
|
||||
):
|
||||
logger.debug(
|
||||
f"Recognized face score {avg_score} is less than threshold ({self.config.face_recognition.threshold}) / previous face score ({self.detected_faces.get(id)})."
|
||||
f"Recognized face score {score} is less than threshold ({self.config.face_recognition.threshold}) / previous face score ({self.detected_faces.get(id)})."
|
||||
)
|
||||
return
|
||||
|
||||
@ -497,12 +477,12 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
json={
|
||||
"camera": obj_data.get("camera"),
|
||||
"subLabel": sub_label,
|
||||
"subLabelScore": avg_score,
|
||||
"subLabelScore": score,
|
||||
},
|
||||
)
|
||||
|
||||
if resp.status_code == 200:
|
||||
self.detected_faces[id] = avg_score
|
||||
self.detected_faces[id] = score
|
||||
|
||||
def _detect_license_plate(self, input: np.ndarray) -> tuple[int, int, int, int]:
|
||||
"""Return the dimensions of the input image as [x, y, width, height]."""
|
||||
|
||||
@ -2,9 +2,15 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from playhouse.sqliteq import SqliteQueueDatabase
|
||||
from sklearn.preprocessing import LabelEncoder, Normalizer
|
||||
from sklearn.svm import SVC
|
||||
|
||||
from frigate.util.builtin import deserialize
|
||||
|
||||
try:
|
||||
import openvino as ov
|
||||
@ -145,3 +151,39 @@ class ONNXModelRunner:
|
||||
return [infer_request.get_output_tensor().data]
|
||||
elif self.type == "ort":
|
||||
return self.ort.run(None, input)
|
||||
|
||||
|
||||
class FaceClassificationModel:
    """Classify face embeddings with a linear SVC trained on stored faces.

    The classifier is built lazily from the rows of the ``vec_faces`` table
    the first time a face is classified, and again whenever
    ``rebuild_classifier=True`` is passed to :meth:`classify_face`.
    """

    def __init__(self, db: SqliteQueueDatabase):
        """Store the database handle; no model is trained until first use.

        Args:
            db: Database exposing ``execute_sql`` over the ``vec_faces`` table.
        """
        self.db = db
        # Both are set together by __build_classifier(); None means "no
        # usable model is available yet".
        self.labeler: Optional[LabelEncoder] = None
        self.classifier: Optional[SVC] = None

    def __build_classifier(self) -> None:
        """Train the label encoder and SVC from every row in ``vec_faces``.

        Leaves ``self.classifier`` and ``self.labeler`` as ``None`` when
        fewer than two distinct names are stored, because ``SVC.fit``
        cannot be trained on fewer than two classes.
        """
        # Reset first so a rebuild that cannot train never leaves a stale
        # classifier paired with a mismatched labeler.
        self.labeler = None
        self.classifier = None

        faces: list[tuple[str, bytes]] = self.db.execute_sql(
            "SELECT id, face_embedding FROM vec_faces"
        ).fetchall()

        # The label is the part of the row id before the first "-".
        names = [face[0].split("-")[0] for face in faces]

        if len(set(names)) < 2:
            return

        embeddings = np.array([deserialize(face[1]) for face in faces])
        norms = Normalizer(norm="l2").transform(embeddings)

        labeler = LabelEncoder()
        labels = labeler.fit_transform(names)

        classifier = SVC(kernel="linear", probability=True)
        classifier.fit(norms, labels)

        # Publish both only after training succeeded.
        self.labeler = labeler
        self.classifier = classifier

    def classify_face(
        self, embedding: np.ndarray, rebuild_classifier: bool = False
    ) -> Optional[tuple[str, float]]:
        """Predict the name for a face embedding.

        Args:
            embedding: Face embedding vector to classify.
            rebuild_classifier: When true, retrain from the database before
                classifying.

        Returns:
            ``(name, probability)`` with the probability rounded to two
            decimals, or ``None`` when no classifier could be trained.
        """
        if self.classifier is None or rebuild_classifier:
            self.__build_classifier()

        if self.classifier is None or self.labeler is None:
            return None

        # The SVC was fitted on L2-normalized embeddings, so the query has
        # to be projected into the same space before prediction.
        sample = Normalizer(norm="l2").transform([embedding])

        res = self.classifier.predict(sample)

        # Check the size, not truthiness: an encoded label of 0 is a valid
        # prediction but makes a one-element array falsy.
        if len(res) == 0:
            return None

        label = int(res[0])
        probabilities = self.classifier.predict_proba(sample)[0]

        # predict_proba columns follow classifier.classes_, which for
        # LabelEncoder output is 0..n-1, so the label doubles as the index.
        return (
            self.labeler.inverse_transform([label])[0],
            round(float(probabilities[label]), 2),
        )
|
||||
|
||||
Loading…
Reference in New Issue
Block a user