Mirror of https://github.com/blakeblackshear/frigate.git
Init models immediately

parent ff36b8b88c
commit 395cf4e33f
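In short: an InterProcessRequestor is now passed from the Embeddings class into each GenericONNXEmbedding and on into ModelDownloader (which previously created its own), and when the model files are already cached each file is reported as downloaded and the model and tokenizer are loaded immediately during initialization rather than being deferred. Hedged sketches of these patterns follow the relevant hunks below.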
@@ -96,6 +96,7 @@ class Embeddings:
             },
             embedding_function=jina_text_embedding_function,
             model_type="text",
+            requestor=self.requestor,
             device="CPU",
         )

@@ -108,6 +109,7 @@ class Embeddings:
             },
             embedding_function=jina_vision_embedding_function,
             model_type="vision",
+            requestor=self.requestor,
             device=self.config.device,
         )
         print("completed embeddings init")
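Both the text and the vision wrapper now receive the requestor that the Embeddings instance already owns, so every component reports model state over one shared channel instead of opening its own. A simplified, hypothetical composition sketch in Python (class names and the "AUTO" device string are assumptions; only the injected requestor, the model_type values, and device="CPU" mirror the diff):

class FakeRequestor:
    """Stand-in for frigate's InterProcessRequestor."""

    def send_data(self, topic: str, payload: dict) -> None:
        print(f"{topic}: {payload}")


class EmbeddingModel:
    def __init__(self, model_type: str, device: str, requestor: FakeRequestor) -> None:
        self.model_type = model_type
        self.device = device
        self.requestor = requestor  # injected, never constructed here


class EmbeddingsOwner:
    def __init__(self) -> None:
        self.requestor = FakeRequestor()
        # Mirrors the two call sites above: text pinned to CPU, vision on the
        # configured device ("AUTO" here is a placeholder).
        self.text_embedding = EmbeddingModel("text", "CPU", self.requestor)
        self.vision_embedding = EmbeddingModel("vision", "AUTO", self.requestor)


EmbeddingsOwner()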
@@ -15,6 +15,7 @@ from PIL import Image
 from transformers import AutoFeatureExtractor, AutoTokenizer
 from transformers.utils.logging import disable_progress_bar

+from frigate.comms.inter_process import InterProcessRequestor
 from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
 from frigate.types import ModelStatusTypesEnum
 from frigate.util.downloader import ModelDownloader
@@ -41,12 +42,14 @@ class GenericONNXEmbedding:
         download_urls: Dict[str, str],
         embedding_function: Callable[[List[np.ndarray]], np.ndarray],
         model_type: str,
+        requestor: InterProcessRequestor,
         tokenizer_file: Optional[str] = None,
         device: str = "AUTO",
     ):
         self.model_name = model_name
         self.model_file = model_file
         self.tokenizer_file = tokenizer_file
+        self.requestor = requestor
         self.download_urls = download_urls
         self.embedding_function = embedding_function
         self.model_type = model_type  # 'text' or 'vision'
@@ -69,11 +72,21 @@ class GenericONNXEmbedding:
                 download_path=self.download_path,
                 file_names=list(self.download_urls.keys())
                 + ([self.tokenizer_file] if self.tokenizer_file else []),
+                requestor=self.requestor,
                 download_func=self._download_model,
             )
             self.downloader.ensure_model_files()
         else:
             self.downloader = None
+            for file_name in self.download_urls.keys():
+                self.requestor.send_data(
+                    UPDATE_MODEL_STATE,
+                    {
+                        "model": f"{self.model_name}-{file_name}",
+                        "state": ModelStatusTypesEnum.downloaded,
+                    },
+                )
+            self._load_model_and_tokenizer()
             print("models are already downloaded")

     def _download_model(self, path: str):
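This else branch is what the commit title refers to: when every file is already in the cache, a downloaded state is published for each one through the shared requestor and _load_model_and_tokenizer() runs inside __init__, so the model is ready before the first embedding request. A small self-contained sketch of the reporting loop (the topic string, status enum, and model name are placeholders; the message shape follows the diff):

from enum import Enum


UPDATE_MODEL_STATE = "update_model_state"  # placeholder topic name


class ModelStatus(Enum):
    downloading = "downloading"
    downloaded = "downloaded"


class StubRequestor:
    def send_data(self, topic: str, payload: dict) -> None:
        print(f"{topic}: {payload}")


def report_cached_files(model_name: str, download_urls: dict, requestor: StubRequestor) -> None:
    # One state update per model file, mirroring the loop added above.
    for file_name in download_urls.keys():
        requestor.send_data(
            UPDATE_MODEL_STATE,
            {"model": f"{model_name}-{file_name}", "state": ModelStatus.downloaded},
        )


report_cached_files(
    "example-model",
    {"model.onnx": "https://example.com/model.onnx"},
    StubRequestor(),
)

Presumably the trade-off is a slightly slower startup in exchange for the first embedding request not paying the model-load cost.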
@@ -44,6 +44,7 @@ class ModelDownloader:
         download_path: str,
         file_names: List[str],
         download_func: Callable[[str], None],
+        requestor: InterProcessRequestor,
         silent: bool = False,
     ):
         self.model_name = model_name
@@ -51,7 +52,7 @@ class ModelDownloader:
         self.file_names = file_names
         self.download_func = download_func
         self.silent = silent
-        self.requestor = InterProcessRequestor()
+        self.requestor = requestor
         self.download_thread = None
         self.download_complete = threading.Event()

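ModelDownloader previously built its own InterProcessRequestor; it now takes the caller's instance, which keeps a single IPC client per process and also makes the class easy to drive with a test double. A hedged sketch of the post-change constructor shape, with a fake requestor (names other than the fields shown in the hunks above are simplified):

import threading
from typing import Callable, List


class FakeRequestor:
    """Test double for the injected requestor; records what was sent."""

    def __init__(self) -> None:
        self.messages: list = []

    def send_data(self, topic: str, payload: dict) -> None:
        self.messages.append((topic, payload))


class Downloader:
    # After the change: the requestor is injected instead of constructed here.
    def __init__(
        self,
        model_name: str,
        file_names: List[str],
        download_func: Callable[[str], None],
        requestor: FakeRequestor,
        silent: bool = False,
    ) -> None:
        self.model_name = model_name
        self.file_names = file_names
        self.download_func = download_func
        self.silent = silent
        self.requestor = requestor
        self.download_thread = None
        self.download_complete = threading.Event()


Downloader("example-model", ["model.onnx"], lambda path: None, FakeRequestor())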