Mirror of https://github.com/blakeblackshear/frigate.git (synced 2026-02-15 15:45:27 +03:00)
Use GPU for embeddings
commit feb7be41cb (parent 4f33f7283c)
@@ -18,6 +18,7 @@ from transformers.utils.logging import disable_progress_bar
 from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
 from frigate.types import ModelStatusTypesEnum
 from frigate.util.downloader import ModelDownloader
+from frigate.util.model import get_ort_providers

 warnings.filterwarnings(
     "ignore",
@@ -40,8 +41,8 @@ class GenericONNXEmbedding:
         download_urls: Dict[str, str],
         embedding_function: Callable[[List[np.ndarray]], np.ndarray],
         model_type: str,
-        preferred_providers: List[str] = ["CPUExecutionProvider"],
         tokenizer_file: Optional[str] = None,
+        force_cpu: bool = False,
     ):
         self.model_name = model_name
         self.model_file = model_file
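Note: after this change, callers no longer pass an ONNX Runtime provider list; they opt out of GPU inference with the new force_cpu flag. A minimal caller sketch; every argument value below is illustrative and assumed, not taken from this diff:

import numpy as np

embedding = GenericONNXEmbedding(
    model_name="some-model",              # assumption: any name under MODEL_CACHE_DIR
    model_file="model.onnx",              # assumption
    download_urls={"model.onnx": "https://example.com/model.onnx"},  # assumption
    embedding_function=lambda outputs: np.array(outputs),            # assumption
    model_type="text",
    force_cpu=False,  # new flag: False lets get_ort_providers() prefer a GPU provider
)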
@@ -49,7 +50,7 @@ class GenericONNXEmbedding:
         self.download_urls = download_urls
         self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
-        self.preferred_providers = preferred_providers
+        self.providers, self.provider_options = get_ort_providers(force_cpu=force_cpu)

         self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
         self.tokenizer = None
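get_ort_providers itself is not part of this diff; the assignment above only tells us it returns a (providers, provider_options) pair. A minimal sketch of the kind of selection such a helper performs, under that assumption:

from typing import Any, Dict, List, Tuple

import onnxruntime as ort

def get_ort_providers(force_cpu: bool = False) -> Tuple[List[str], List[Dict[str, Any]]]:
    # Sketch only: prefer a GPU execution provider unless the caller forces CPU.
    if force_cpu:
        return ["CPUExecutionProvider"], [{}]

    providers: List[str] = []
    options: List[Dict[str, Any]] = []

    if "CUDAExecutionProvider" in ort.get_available_providers():
        providers.append("CUDAExecutionProvider")
        options.append({"device_id": 0})  # assumption: first GPU

    # Always keep CPU as the fallback of last resort.
    providers.append("CPUExecutionProvider")
    options.append({})
    return providers, options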
@@ -105,8 +106,7 @@ class GenericONNXEmbedding:
         else:
             self.feature_extractor = self._load_feature_extractor()
             self.session = self._load_model(
-                os.path.join(self.download_path, self.model_file),
-                self.preferred_providers,
+                os.path.join(self.download_path, self.model_file)
             )

     def _load_tokenizer(self):
@@ -123,7 +123,7 @@ class GenericONNXEmbedding:
             f"{MODEL_CACHE_DIR}/{self.model_name}",
         )

-    def _load_model(self, path: str, providers: List[str]):
+    def _load_model(self, path: str):
         if os.path.exists(path):
-            return ort.InferenceSession(path, providers=providers)
+            return ort.InferenceSession(path, providers=self.providers, provider_options=self.provider_options)
         else:
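ONNX Runtime falls back silently when a requested provider is unavailable, so it can be worth confirming what a session actually runs on. A small check using the public get_providers() accessor; the model path is a placeholder assumption:

import onnxruntime as ort

providers, provider_options = get_ort_providers(force_cpu=False)
session = ort.InferenceSession(
    "model.onnx",  # assumption: any local ONNX model file
    providers=providers,
    provider_options=provider_options,
)
# The first entry is the provider ONNX Runtime selected for execution.
print(session.get_providers())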