From 6287253563172faea204d1569908e9c8ca468b39 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Thu, 10 Oct 2024 11:43:47 -0600 Subject: [PATCH] Handle zmq error and empty data --- frigate/comms/embeddings_updater.py | 12 +++++++-- frigate/comms/event_metadata_updater.py | 2 +- frigate/embeddings/__init__.py | 35 +++++++++++++++--------- frigate/embeddings/functions/onnx.py | 11 +++++--- frigate/embeddings/maintainer.py | 36 ++++++++++++------------- frigate/util/downloader.py | 1 + 6 files changed, 60 insertions(+), 37 deletions(-) diff --git a/frigate/comms/embeddings_updater.py b/frigate/comms/embeddings_updater.py index acb11924f..6fcaf3903 100644 --- a/frigate/comms/embeddings_updater.py +++ b/frigate/comms/embeddings_updater.py @@ -1,5 +1,6 @@ """Facilitates communication between processes.""" +import logging from enum import Enum from typing import Callable @@ -7,6 +8,8 @@ import zmq SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings" +logger = logging.getLogger(__name__) + class EmbeddingsRequestEnum(Enum): embed_description = "embed_description" @@ -22,7 +25,9 @@ class EmbeddingsResponder: def check_for_request(self, process: Callable) -> None: while True: # load all messages that are queued + logger.debug("Checking for embeddings requests") has_message, _, _ = zmq.select([self.socket], [], [], 0.1) + logger.debug(f"has a request? {has_message}") if not has_message: break @@ -54,8 +59,11 @@ class EmbeddingsRequestor: def send_data(self, topic: str, data: any) -> str: """Sends data and then waits for reply.""" - self.socket.send_json((topic, data)) - return self.socket.recv_json() + try: + self.socket.send_json((topic, data)) + return self.socket.recv_json() + except zmq.ZMQError: + return "" def stop(self) -> None: self.socket.close() diff --git a/frigate/comms/event_metadata_updater.py b/frigate/comms/event_metadata_updater.py index aeede6d8e..87e1889ce 100644 --- a/frigate/comms/event_metadata_updater.py +++ b/frigate/comms/event_metadata_updater.py @@ -39,7 +39,7 @@ class EventMetadataSubscriber(Subscriber): super().__init__(topic) def check_for_update( - self, timeout: float = None + self, timeout: float = 1 ) -> Optional[tuple[EventMetadataTypeEnum, str, RegenerateDescriptionEnum]]: return super().check_for_update(timeout) diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py index 8540bdf8e..2bc10f130 100644 --- a/frigate/embeddings/__init__.py +++ b/frigate/embeddings/__init__.py @@ -114,19 +114,25 @@ class EmbeddingsContext: query_embedding = row[0] else: # If no embedding found, generate it and return it - query_embedding = serialize( - self.requestor.send_data( - EmbeddingsRequestEnum.embed_thumbnail.value, - {"id": str(query.id), "thumbnail": str(query.thumbnail)}, - ) + data = self.requestor.send_data( + EmbeddingsRequestEnum.embed_thumbnail.value, + {"id": str(query.id), "thumbnail": str(query.thumbnail)}, ) + + if not data: + return [] + + query_embedding = serialize(data) else: - query_embedding = serialize( - self.requestor.send_data( - EmbeddingsRequestEnum.generate_search.value, query - ) + data = self.requestor.send_data( + EmbeddingsRequestEnum.generate_search.value, query ) + if not data: + return [] + + query_embedding = serialize(data) + sql_query = """ SELECT id, @@ -155,12 +161,15 @@ class EmbeddingsContext: def search_description( self, query_text: str, event_ids: list[str] = None ) -> list[tuple[str, float]]: - query_embedding = serialize( - self.requestor.send_data( - EmbeddingsRequestEnum.generate_search.value, query_text - ) + data = self.requestor.send_data( + EmbeddingsRequestEnum.generate_search.value, query_text ) + if not data: + return [] + + query_embedding = serialize(data) + # Prepare the base SQL query sql_query = """ SELECT diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py index 448f9f75b..46d557a07 100644 --- a/frigate/embeddings/functions/onnx.py +++ b/frigate/embeddings/functions/onnx.py @@ -59,6 +59,7 @@ class GenericONNXEmbedding: self.feature_extractor = None self.session = None + print("starting model download") self.downloader = ModelDownloader( model_name=self.model_name, download_path=self.download_path, @@ -70,6 +71,7 @@ class GenericONNXEmbedding: def _download_model(self, path: str): try: + print("beginning model download process") file_name = os.path.basename(path) if file_name in self.download_urls: ModelDownloader.download_from_url(self.download_urls[file_name], path) @@ -107,10 +109,11 @@ class GenericONNXEmbedding: self.tokenizer = self._load_tokenizer() else: self.feature_extractor = self._load_feature_extractor() + print("creating onnx session") self.session = self._load_model( os.path.join(self.download_path, self.model_file) ) - logger.debug("successfully loaded model.") + print("successfully loaded model.") def _load_tokenizer(self): tokenizer_path = os.path.join(f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer") @@ -127,15 +130,16 @@ class GenericONNXEmbedding: ) def _load_model(self, path: str) -> Optional[ort.InferenceSession]: + print(f"checking if path exists {path}") if os.path.exists(path): - logger.debug( + print( f"loading ORT session with providers {self.providers} and options {self.provider_options}" ) return ort.InferenceSession( path, providers=self.providers, provider_options=self.provider_options ) else: - logger.warning(f"{self.model_name} model file {path} not found.") + print(f"{self.model_name} model file {path} not found.") return None def _process_image(self, image): @@ -149,6 +153,7 @@ class GenericONNXEmbedding: def __call__( self, inputs: Union[List[str], List[Image.Image], List[str]] ) -> List[np.ndarray]: + print("beginning call for onnx embedding") self._load_model_and_tokenizer() if self.session is None or ( diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index c9d539815..d6195f5b6 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -41,8 +41,7 @@ class EmbeddingMaintainer(threading.Thread): config: FrigateConfig, stop_event: MpEvent, ) -> None: - threading.Thread.__init__(self) - self.name = "embeddings_maintainer" + super().__init__(name="embeddings_maintainer") self.config = config self.embeddings = Embeddings(config.semantic_search, db) self.event_subscriber = EventUpdateSubscriber() @@ -61,6 +60,7 @@ class EmbeddingMaintainer(threading.Thread): def run(self) -> None: """Maintain a SQLite-vec database for semantic search.""" while not self.stop_event.is_set(): + print("Doing another embeddings loop.") self._process_requests() self._process_updates() self._process_finalized() @@ -77,9 +77,7 @@ class EmbeddingMaintainer(threading.Thread): """Process embeddings requests""" def handle_request(topic: str, data: str) -> str: - logger.debug( - f"Handling embeddings request of type {topic} with data {data}" - ) + print(f"Handling embeddings request of type {topic} with data {data}") try: if topic == EmbeddingsRequestEnum.embed_description.value: @@ -106,7 +104,7 @@ class EmbeddingMaintainer(threading.Thread): def _process_updates(self) -> None: """Process event updates""" - update = self.event_subscriber.check_for_update(timeout=0.1) + update = self.event_subscriber.check_for_update() if update is None: return @@ -116,7 +114,7 @@ class EmbeddingMaintainer(threading.Thread): if not camera or source_type != EventTypeEnum.tracked_object: return - logger.debug(f"Processing object update of type {source_type} on {camera}") + print(f"Processing object update of type {source_type} on {camera}") camera_config = self.config.cameras[camera] if data["id"] not in self.tracked_events: self.tracked_events[data["id"]] = [] @@ -124,29 +122,33 @@ class EmbeddingMaintainer(threading.Thread): # Create our own thumbnail based on the bounding box and the frame time try: frame_id = f"{camera}{data['frame_time']}" + print("trying to get frame from manager") yuv_frame = self.frame_manager.get(frame_id, camera_config.frame_shape_yuv) + print(f"got frame from manager and it is valid {yuv_frame is not None}") if yuv_frame is not None: data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"]) self.tracked_events[data["id"]].append(data) self.frame_manager.close(frame_id) else: - logger.debug( + print( f"Unable to create embedding for thumbnail from {camera} because frame is missing." ) except FileNotFoundError: pass + print("Finished processing object update") + def _process_finalized(self) -> None: """Process the end of an event.""" while True: - ended = self.event_end_subscriber.check_for_update(timeout=0.1) + ended = self.event_end_subscriber.check_for_update() if ended == None: break event_id, camera, updated_db = ended - logger.debug( + print( f"Processing finalized event for {camera} which updated the db: {updated_db}" ) camera_config = self.config.cameras[camera] @@ -180,7 +182,7 @@ class EmbeddingMaintainer(threading.Thread): or set(event.zones) & set(camera_config.genai.required_zones) ) ): - logger.debug( + print( f"Description generation for {event}, has_snapshot: {event.has_snapshot}" ) if event.has_snapshot and camera_config.genai.use_snapshot: @@ -235,14 +237,12 @@ class EmbeddingMaintainer(threading.Thread): def _process_event_metadata(self): # Check for regenerate description requests - (topic, event_id, source) = self.event_metadata_subscriber.check_for_update( - timeout=0.1 - ) + (topic, event_id, source) = self.event_metadata_subscriber.check_for_update() if topic is None: return - logger.debug(f"Handling event metadata for id {event_id} and source {source}") + print(f"Handling event metadata for id {event_id} and source {source}") if event_id: self.handle_regenerate_description(event_id, source) @@ -276,7 +276,7 @@ class EmbeddingMaintainer(threading.Thread): ) if not description: - logger.debug("Failed to generate description for %s", event.id) + print("Failed to generate description for %s", event.id) return # fire and forget description update @@ -288,7 +288,7 @@ class EmbeddingMaintainer(threading.Thread): # Encode the description self.embeddings.upsert_description(event.id, description) - logger.debug( + print( "Generated description for %s (%d images): %s", event.id, len(thumbnails), @@ -309,7 +309,7 @@ class EmbeddingMaintainer(threading.Thread): thumbnail = base64.b64decode(event.thumbnail) - logger.debug( + print( f"Trying {source} regeneration for {event}, has_snapshot: {event.has_snapshot}" ) diff --git a/frigate/util/downloader.py b/frigate/util/downloader.py index 642dc7c8f..cab714fa0 100644 --- a/frigate/util/downloader.py +++ b/frigate/util/downloader.py @@ -120,4 +120,5 @@ class ModelDownloader: logger.info(f"Downloading complete: {url}") def wait_for_download(self): + print("waiting for model download") self.download_complete.wait()