This commit is contained in:
Nicolas Mowen 2024-10-10 15:21:04 -06:00
parent 72126e515b
commit 4eb6cd3042
4 changed files with 18 additions and 20 deletions

View File

@ -1,6 +1,5 @@
"""Facilitates communication between processes."""
import logging
from enum import Enum
from typing import Callable
@ -8,8 +7,6 @@ import zmq
SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"
logger = logging.getLogger(__name__)
class EmbeddingsRequestEnum(Enum):
embed_description = "embed_description"
@ -25,9 +22,7 @@ class EmbeddingsResponder:
def check_for_request(self, process: Callable) -> None:
while True: # load all messages that are queued
print("Checking for embeddings requests")
has_message, _, _ = zmq.select([self.socket], [], [], 0.1)
print(f"has a request? {has_message}")
if not has_message:
break

View File

@ -112,7 +112,6 @@ class Embeddings:
requestor=self.requestor,
device=self.config.device,
)
print("completed embeddings init")
def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
# Convert thumbnail bytes to PIL Image

View File

@ -79,11 +79,14 @@ class GenericONNXEmbedding:
self.downloader.ensure_model_files()
else:
self.downloader = None
ModelDownloader.mark_files_downloaded(
self.requestor, self.model_name, files_names
ModelDownloader.mark_files_state(
self.requestor,
self.model_name,
files_names,
ModelStatusTypesEnum.downloaded,
)
self._load_model_and_tokenizer()
print(f"models are already downloaded for {self.model_name}")
logger.debug(f"models are already downloaded for {self.model_name}")
def _download_model(self, path: str):
try:

View File

@ -57,14 +57,12 @@ class ModelDownloader:
self.download_complete = threading.Event()
def ensure_model_files(self):
for file in self.file_names:
self.requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": f"{self.model_name}-{file}",
"state": ModelStatusTypesEnum.downloading,
},
)
self.mark_files_state(
self.requestor,
self.model_name,
self.file_names,
ModelStatusTypesEnum.downloading,
)
self.download_thread = threading.Thread(
target=self._download_models,
name=f"_download_model_{self.model_name}",
@ -121,15 +119,18 @@ class ModelDownloader:
logger.info(f"Downloading complete: {url}")
@staticmethod
def mark_files_downloaded(
requestor: InterProcessRequestor, model_name: str, files: list[str]
def mark_files_state(
requestor: InterProcessRequestor,
model_name: str,
files: list[str],
state: ModelStatusTypesEnum,
) -> None:
for file_name in files:
requestor.send_data(
UPDATE_MODEL_STATE,
{
"model": f"{model_name}-{file_name}",
"state": ModelStatusTypesEnum.downloaded,
"state": state,
},
)