mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-05-04 04:27:42 +03:00
Force GPU for large embedding model
This commit is contained in:
parent
cac02b5b56
commit
7d775def31
@ -23,10 +23,7 @@ FACENET_INPUT_SIZE = 160
|
|||||||
|
|
||||||
|
|
||||||
class FaceNetEmbedding(BaseEmbedding):
|
class FaceNetEmbedding(BaseEmbedding):
|
||||||
def __init__(
|
def __init__(self):
|
||||||
self,
|
|
||||||
device: str = "AUTO",
|
|
||||||
):
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_name="facedet",
|
model_name="facedet",
|
||||||
model_file="facenet.tflite",
|
model_file="facenet.tflite",
|
||||||
@ -34,7 +31,6 @@ class FaceNetEmbedding(BaseEmbedding):
|
|||||||
"facenet.tflite": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/facenet.tflite",
|
"facenet.tflite": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/facenet.tflite",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.device = device
|
|
||||||
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.feature_extractor = None
|
self.feature_extractor = None
|
||||||
@ -113,10 +109,7 @@ class FaceNetEmbedding(BaseEmbedding):
|
|||||||
|
|
||||||
|
|
||||||
class ArcfaceEmbedding(BaseEmbedding):
|
class ArcfaceEmbedding(BaseEmbedding):
|
||||||
def __init__(
|
def __init__(self):
|
||||||
self,
|
|
||||||
device: str = "AUTO",
|
|
||||||
):
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_name="facedet",
|
model_name="facedet",
|
||||||
model_file="arcface.onnx",
|
model_file="arcface.onnx",
|
||||||
@ -124,7 +117,6 @@ class ArcfaceEmbedding(BaseEmbedding):
|
|||||||
"arcface.onnx": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/arcface.onnx",
|
"arcface.onnx": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/arcface.onnx",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.device = device
|
|
||||||
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.feature_extractor = None
|
self.feature_extractor = None
|
||||||
@ -154,7 +146,7 @@ class ArcfaceEmbedding(BaseEmbedding):
|
|||||||
|
|
||||||
self.runner = ONNXModelRunner(
|
self.runner = ONNXModelRunner(
|
||||||
os.path.join(self.download_path, self.model_file),
|
os.path.join(self.download_path, self.model_file),
|
||||||
self.device,
|
"GPU",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _preprocess_inputs(self, raw_inputs):
|
def _preprocess_inputs(self, raw_inputs):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user