diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py
index 47379cacb..0b8488763 100644
--- a/frigate/embeddings/embeddings.py
+++ b/frigate/embeddings/embeddings.py
@@ -96,6 +96,7 @@ class Embeddings:
             },
             embedding_function=jina_text_embedding_function,
             model_type="text",
+            requestor=self.requestor,
             device="CPU",
         )

@@ -108,6 +109,7 @@ class Embeddings:
             },
             embedding_function=jina_vision_embedding_function,
             model_type="vision",
+            requestor=self.requestor,
             device=self.config.device,
         )

diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py
index ebfd36959..25486282d 100644
--- a/frigate/embeddings/functions/onnx.py
+++ b/frigate/embeddings/functions/onnx.py
@@ -15,6 +15,7 @@
 from PIL import Image
 from transformers import AutoFeatureExtractor, AutoTokenizer
 from transformers.utils.logging import disable_progress_bar
+from frigate.comms.inter_process import InterProcessRequestor
 from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
 from frigate.types import ModelStatusTypesEnum
 from frigate.util.downloader import ModelDownloader
@@ -41,12 +42,14 @@ class GenericONNXEmbedding:
         download_urls: Dict[str, str],
         embedding_function: Callable[[List[np.ndarray]], np.ndarray],
         model_type: str,
+        requestor: InterProcessRequestor,
         tokenizer_file: Optional[str] = None,
         device: str = "AUTO",
     ):
         self.model_name = model_name
         self.model_file = model_file
         self.tokenizer_file = tokenizer_file
+        self.requestor = requestor
         self.download_urls = download_urls
         self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
@@ -69,11 +72,21 @@ class GenericONNXEmbedding:
                 download_path=self.download_path,
                 file_names=list(self.download_urls.keys())
                 + ([self.tokenizer_file] if self.tokenizer_file else []),
+                requestor=self.requestor,
                 download_func=self._download_model,
             )
             self.downloader.ensure_model_files()
         else:
             self.downloader = None
+            for file_name in self.download_urls.keys():
+                self.requestor.send_data(
+                    UPDATE_MODEL_STATE,
+                    {
+                        "model": f"{self.model_name}-{file_name}",
+                        "state": ModelStatusTypesEnum.downloaded,
+                    },
+                )
+            self._load_model_and_tokenizer()

     def _download_model(self, path: str):
diff --git a/frigate/util/downloader.py b/frigate/util/downloader.py
index cab714fa0..d2604c8b3 100644
--- a/frigate/util/downloader.py
+++ b/frigate/util/downloader.py
@@ -44,6 +44,7 @@ class ModelDownloader:
         download_path: str,
         file_names: List[str],
         download_func: Callable[[str], None],
+        requestor: InterProcessRequestor,
         silent: bool = False,
     ):
         self.model_name = model_name
@@ -51,7 +52,7 @@ class ModelDownloader:
         self.file_names = file_names
         self.download_func = download_func
         self.silent = silent
-        self.requestor = InterProcessRequestor()
+        self.requestor = requestor
         self.download_thread = None
         self.download_complete = threading.Event()
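
For context, a minimal sketch (not taken from this diff) of how the shared requestor is expected to be wired: the owning process creates one InterProcessRequestor and hands it to both GenericONNXEmbedding and ModelDownloader, so every model-state update flows through a single channel instead of each downloader opening its own. Only the constructor signatures and the send_data/ensure_model_files calls shown above are assumed; the model name, file name, path, and download_func below are placeholders for illustration.

# Sketch only: assumes the signatures introduced in this diff.
from frigate.comms.inter_process import InterProcessRequestor
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader

requestor = InterProcessRequestor()  # created once by the owning process

downloader = ModelDownloader(
    model_name="example-model",                         # placeholder name
    download_path=f"{MODEL_CACHE_DIR}/example-model",    # placeholder path
    file_names=["model.onnx"],                           # placeholder file list
    download_func=lambda path: None,                     # real callers pass _download_model
    requestor=requestor,                                 # shared requestor, per this diff
)
downloader.ensure_model_files()

# When the files are already cached, the embedding class reports the state itself
# over the same requestor instead of constructing a downloader:
requestor.send_data(
    UPDATE_MODEL_STATE,
    {
        "model": "example-model-model.onnx",
        "state": ModelStatusTypesEnum.downloaded,
    },
)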