Use GPU for embeddings

Nicolas Mowen 2024-10-09 15:05:43 -06:00
parent 4f33f7283c
commit feb7be41cb

@@ -18,6 +18,7 @@ from transformers.utils.logging import disable_progress_bar
 from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
 from frigate.types import ModelStatusTypesEnum
 from frigate.util.downloader import ModelDownloader
+from frigate.util.model import get_ort_providers
 
 warnings.filterwarnings(
     "ignore",
@@ -40,8 +41,8 @@ class GenericONNXEmbedding:
         download_urls: Dict[str, str],
         embedding_function: Callable[[List[np.ndarray]], np.ndarray],
         model_type: str,
-        preferred_providers: List[str] = ["CPUExecutionProvider"],
         tokenizer_file: Optional[str] = None,
+        force_cpu: bool = False,
     ):
         self.model_name = model_name
         self.model_file = model_file
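
With this change, callers no longer hand GenericONNXEmbedding an ONNX Runtime provider list; they only choose whether to force CPU inference. A hypothetical instantiation under the new signature (the model name, file, URL, and embedding function below are placeholders, not Frigate's real values):

import numpy as np

embedding = GenericONNXEmbedding(
    model_name="example-model",  # placeholder, not a real Frigate model
    model_file="model.onnx",
    download_urls={"model.onnx": "https://example.com/model.onnx"},
    embedding_function=lambda outputs: np.stack(outputs),
    model_type="vision",
    force_cpu=False,  # set True to pin inference to CPUExecutionProvider
)
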
@@ -49,7 +50,7 @@ class GenericONNXEmbedding:
         self.download_urls = download_urls
         self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
-        self.preferred_providers = preferred_providers
+        self.providers, self.provider_options = get_ort_providers(force_cpu=force_cpu)
         self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
         self.tokenizer = None
         self.feature_extractor = None
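
The diff does not include get_ort_providers itself, only its import from frigate.util.model. A minimal sketch of what such a helper could look like, assuming it probes onnxruntime for available execution providers and pairs each provider with an options dict (the CUDA option shown is illustrative, not Frigate's actual selection logic):

from typing import Any, Dict, List, Tuple

import onnxruntime as ort


def get_ort_providers(
    force_cpu: bool = False,
) -> Tuple[List[str], List[Dict[str, Any]]]:
    # Fall back to CPU when GPU support is missing or explicitly disabled.
    if force_cpu or "CUDAExecutionProvider" not in ort.get_available_providers():
        return ["CPUExecutionProvider"], [{}]

    # Prefer CUDA, keeping CPU as a fallback; provider_options is a list of
    # per-provider dicts aligned index-for-index with the providers list.
    return (
        ["CUDAExecutionProvider", "CPUExecutionProvider"],
        [{"arena_extend_strategy": "kSameAsRequested"}, {}],
    )
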
@@ -105,8 +106,7 @@
         else:
             self.feature_extractor = self._load_feature_extractor()
         self.session = self._load_model(
-            os.path.join(self.download_path, self.model_file),
-            self.preferred_providers,
+            os.path.join(self.download_path, self.model_file)
         )
 
     def _load_tokenizer(self):
@@ -123,7 +123,7 @@
             f"{MODEL_CACHE_DIR}/{self.model_name}",
         )
 
-    def _load_model(self, path: str, providers: List[str]):
+    def _load_model(self, path: str):
         if os.path.exists(path):
-            return ort.InferenceSession(path, providers=providers)
+            return ort.InferenceSession(path, providers=self.providers, provider_options=self.provider_options)
         else:
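
For reference, the stored (providers, provider_options) pair maps directly onto onnxruntime.InferenceSession's keyword arguments. A short usage sketch mirroring the new _load_model call ("model.onnx" is a placeholder path, and the options shown are examples):

import onnxruntime as ort

providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
provider_options = [{"device_id": 0}, {}]  # options[i] configures providers[i]

session = ort.InferenceSession(
    "model.onnx", providers=providers, provider_options=provider_options
)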