Force GPU for large embedding model

This commit is contained in:
Nicolas Mowen 2025-05-08 13:34:47 -06:00
parent cac02b5b56
commit 7d775def31

View File

@ -23,10 +23,7 @@ FACENET_INPUT_SIZE = 160
class FaceNetEmbedding(BaseEmbedding):
def __init__(
self,
device: str = "AUTO",
):
def __init__(self):
super().__init__(
model_name="facedet",
model_file="facenet.tflite",
@ -34,7 +31,6 @@ class FaceNetEmbedding(BaseEmbedding):
"facenet.tflite": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/facenet.tflite",
},
)
self.device = device
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
self.tokenizer = None
self.feature_extractor = None
@ -113,10 +109,7 @@ class FaceNetEmbedding(BaseEmbedding):
class ArcfaceEmbedding(BaseEmbedding):
def __init__(
self,
device: str = "AUTO",
):
def __init__(self):
super().__init__(
model_name="facedet",
model_file="arcface.onnx",
@ -124,7 +117,6 @@ class ArcfaceEmbedding(BaseEmbedding):
"arcface.onnx": "https://github.com/NickM-27/facenet-onnx/releases/download/v1.0/arcface.onnx",
},
)
self.device = device
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
self.tokenizer = None
self.feature_extractor = None
@ -154,7 +146,7 @@ class ArcfaceEmbedding(BaseEmbedding):
self.runner = ONNXModelRunner(
os.path.join(self.download_path, self.model_file),
self.device,
"GPU",
)
def _preprocess_inputs(self, raw_inputs):