mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-27 17:17:40 +03:00
face recognition: use configured device
This commit is contained in:
parent
b8db85ec94
commit
14bf6dd7bb
@ -269,7 +269,7 @@ class ArcFaceRecognizer(FaceRecognizer):
|
|||||||
def __init__(self, config: FrigateConfig):
|
def __init__(self, config: FrigateConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.mean_embs: dict[int, np.ndarray] = {}
|
self.mean_embs: dict[int, np.ndarray] = {}
|
||||||
self.face_embedder: ArcfaceEmbedding = ArcfaceEmbedding()
|
self.face_embedder: ArcfaceEmbedding = ArcfaceEmbedding(config)
|
||||||
self.model_builder_queue: queue.Queue | None = None
|
self.model_builder_queue: queue.Queue | None = None
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from frigate.util.downloader import ModelDownloader
|
|||||||
|
|
||||||
from .base_embedding import BaseEmbedding
|
from .base_embedding import BaseEmbedding
|
||||||
from .runner import ONNXModelRunner
|
from .runner import ONNXModelRunner
|
||||||
|
from ...config import FrigateConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from tflite_runtime.interpreter import Interpreter
|
from tflite_runtime.interpreter import Interpreter
|
||||||
@ -111,7 +112,7 @@ class FaceNetEmbedding(BaseEmbedding):
|
|||||||
|
|
||||||
|
|
||||||
class ArcfaceEmbedding(BaseEmbedding):
|
class ArcfaceEmbedding(BaseEmbedding):
|
||||||
def __init__(self):
|
def __init__(self, config: FrigateConfig):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_name="facedet",
|
model_name="facedet",
|
||||||
model_file="arcface.onnx",
|
model_file="arcface.onnx",
|
||||||
@ -119,6 +120,7 @@ class ArcfaceEmbedding(BaseEmbedding):
|
|||||||
"arcface.onnx": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/arcface.onnx",
|
"arcface.onnx": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/arcface.onnx",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.config = config
|
||||||
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.feature_extractor = None
|
self.feature_extractor = None
|
||||||
@ -148,7 +150,7 @@ class ArcfaceEmbedding(BaseEmbedding):
|
|||||||
|
|
||||||
self.runner = ONNXModelRunner(
|
self.runner = ONNXModelRunner(
|
||||||
os.path.join(self.download_path, self.model_file),
|
os.path.join(self.download_path, self.model_file),
|
||||||
"GPU",
|
device=self.config.face_recognition.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _preprocess_inputs(self, raw_inputs):
|
def _preprocess_inputs(self, raw_inputs):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user