face recognition: use configured device

This commit is contained in:
baudneo 2025-07-31 22:34:49 -06:00
parent b8db85ec94
commit 14bf6dd7bb
No known key found for this signature in database
GPG Key ID: 51445F2ED08EBC7F
2 changed files with 7 additions and 5 deletions

View File

@ -269,7 +269,7 @@ class ArcFaceRecognizer(FaceRecognizer):
def __init__(self, config: FrigateConfig):
super().__init__(config)
self.mean_embs: dict[int, np.ndarray] = {}
self.face_embedder: ArcfaceEmbedding = ArcfaceEmbedding()
self.face_embedder: ArcfaceEmbedding = ArcfaceEmbedding(config)
self.model_builder_queue: queue.Queue | None = None
def clear(self) -> None:
@ -370,4 +370,4 @@ class ArcFaceRecognizer(FaceRecognizer):
score = confidence
label = name
return label, round(score - blur_reduction, 2)
return label, round(score - blur_reduction, 2)

View File

@ -11,6 +11,7 @@ from frigate.util.downloader import ModelDownloader
from .base_embedding import BaseEmbedding
from .runner import ONNXModelRunner
from ...config import FrigateConfig
try:
from tflite_runtime.interpreter import Interpreter
@ -111,7 +112,7 @@ class FaceNetEmbedding(BaseEmbedding):
class ArcfaceEmbedding(BaseEmbedding):
def __init__(self):
def __init__(self, config: FrigateConfig):
super().__init__(
model_name="facedet",
model_file="arcface.onnx",
@ -119,6 +120,7 @@ class ArcfaceEmbedding(BaseEmbedding):
"arcface.onnx": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/arcface.onnx",
},
)
self.config = config
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
self.tokenizer = None
self.feature_extractor = None
@ -148,7 +150,7 @@ class ArcfaceEmbedding(BaseEmbedding):
self.runner = ONNXModelRunner(
os.path.join(self.download_path, self.model_file),
"GPU",
device=self.config.face_recognition.device,
)
def _preprocess_inputs(self, raw_inputs):
@ -184,4 +186,4 @@ class ArcfaceEmbedding(BaseEmbedding):
frame = np.transpose(frame, (2, 0, 1))
frame = np.expand_dims(frame, axis=0)
return [{"data": frame}]
return [{"data": frame}]