From f8f1852b0cdf40b43884b3e42f92a26b215cdd69 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Thu, 10 Oct 2024 08:05:28 -0600 Subject: [PATCH] Use ZMQ to proxy embeddings requests --- frigate/api/event.py | 6 +-- frigate/app.py | 4 +- frigate/comms/embeddings_updater.py | 62 +++++++++++++++++++++++++++++ frigate/embeddings/__init__.py | 27 ++++++++++--- frigate/embeddings/maintainer.py | 20 ++++++++++ 5 files changed, 108 insertions(+), 11 deletions(-) create mode 100644 frigate/comms/embeddings_updater.py diff --git a/frigate/api/event.py b/frigate/api/event.py index f7ebf16b9..3be37539d 100644 --- a/frigate/api/event.py +++ b/frigate/api/event.py @@ -944,9 +944,9 @@ def set_description( # If semantic search is enabled, update the index if request.app.frigate_config.semantic_search.enabled: context: EmbeddingsContext = request.app.embeddings - context.embeddings.upsert_description( - event_id=event_id, - description=new_description, + context.update_description( + event_id, + new_description, ) response_message = ( diff --git a/frigate/app.py b/frigate/app.py index 253bebf89..1fcf91551 100644 --- a/frigate/app.py +++ b/frigate/app.py @@ -276,7 +276,7 @@ class FrigateApp: def init_embeddings_client(self) -> None: if self.config.semantic_search.enabled: # Create a client for other processes to use - self.embeddings = EmbeddingsContext(self.config, self.db) + self.embeddings = EmbeddingsContext(self.db) def init_external_event_processor(self) -> None: self.external_event_processor = ExternalEventProcessor(self.config) @@ -699,7 +699,7 @@ class FrigateApp: # Save embeddings stats to disk if self.embeddings: - self.embeddings.save_stats() + self.embeddings.stop() # Stop Communicators self.inter_process_communicator.stop() diff --git a/frigate/comms/embeddings_updater.py b/frigate/comms/embeddings_updater.py new file mode 100644 index 000000000..d7b30232b --- /dev/null +++ b/frigate/comms/embeddings_updater.py @@ -0,0 +1,62 @@ +"""Facilitates communication between processes.""" + +from enum import Enum +from typing import Callable + +import zmq + +SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings" + + +class EmbeddingsRequestEnum(Enum): + embed_description = "embed_description" + embed_thumbnail = "embed_thumbnail" + generate_search = "generate_search" + + +class EmbeddingsResponder: + def __init__(self) -> None: + self.context = zmq.Context() + self.socket = self.context.socket(zmq.REP) + self.socket.bind(SOCKET_REP_REQ) + + def check_for_request(self, process: Callable) -> None: + while True: # load all messages that are queued + has_message, _, _ = zmq.select([self.socket], [], [], 1) + + if not has_message: + break + + try: + (topic, value) = self.socket.recv_json(flags=zmq.NOBLOCK) + + response = process(topic, value) + + if response is not None: + self.socket.send_json(response) + else: + self.socket.send_json([]) + except zmq.ZMQError: + break + + def stop(self) -> None: + self.socket.close() + self.context.destroy() + + +class EmbeddingsRequestor: + """Simplifies sending data to EmbeddingsResponder and getting a reply.""" + + def __init__(self) -> None: + self.context = zmq.Context() + self.socket = self.context.socket(zmq.REQ) + self.socket.connect(SOCKET_REP_REQ) + + def send_data(self, topic: str, data: any) -> any: + """Sends data and then waits for reply.""" + self.socket.send_json((topic, data)) + return self.socket.recv_json() + + def stop(self) -> None: + self.socket.close() + self.context.destroy() diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py index 216614857..688ebcbd1 100644 --- a/frigate/embeddings/__init__.py +++ b/frigate/embeddings/__init__.py @@ -12,6 +12,7 @@ from typing import Optional, Union from setproctitle import setproctitle +from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsRequestor from frigate.config import FrigateConfig from frigate.const import CONFIG_DIR from frigate.db.sqlitevecq import SqliteVecQueueDatabase @@ -72,10 +73,11 @@ def manage_embeddings(config: FrigateConfig) -> None: class EmbeddingsContext: - def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase): + def __init__(self, db: SqliteVecQueueDatabase): self.db = db self.thumb_stats = ZScoreNormalization() self.desc_stats = ZScoreNormalization() + self.requestor = EmbeddingsRequestor() # load stats from disk try: @@ -86,7 +88,7 @@ class EmbeddingsContext: except FileNotFoundError: pass - def save_stats(self): + def stop(self): """Write the stats to disk as JSON on exit.""" contents = { "thumb_stats": self.thumb_stats.to_dict(), @@ -94,6 +96,7 @@ class EmbeddingsContext: } with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f: json.dump(contents, f) + self.requestor.stop() def search_thumbnail( self, query: Union[Event, str], event_ids: list[str] = None @@ -114,10 +117,14 @@ class EmbeddingsContext: ) # Deserialize the thumbnail embedding else: # If no embedding found, generate it and return it - thumbnail = base64.b64decode(query.thumbnail) - query_embedding = self.upsert_thumbnail(query.id, thumbnail) + query_embedding = self.requestor.send_data( + EmbeddingsRequestEnum.embed_thumbnail, + {"id": query.id, "thumbnail": query.thumbnail}, + ) else: - query_embedding = self.text_embedding([query])[0] + query_embedding = self.requestor.send_data( + EmbeddingsRequestEnum.generate_search, query + ) sql_query = """ SELECT @@ -151,7 +158,9 @@ class EmbeddingsContext: def search_description( self, query_text: str, event_ids: list[str] = None ) -> list[tuple[str, float]]: - query_embedding = self.text_embedding([query_text])[0] + query_embedding = self.requestor.send_data( + EmbeddingsRequestEnum.generate_search, query_text + ) # Prepare the base SQL query sql_query = """ @@ -182,3 +191,9 @@ class EmbeddingsContext: results = self.db.execute_sql(sql_query, parameters).fetchall() return results + + def update_description(self, event_id: str, description: str) -> None: + self.requestor.send_data( + EmbeddingsRequestEnum.embed_description, + {"id": event_id, "description": description}, + ) diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index c95cb5050..c8b2375e6 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -12,6 +12,7 @@ import numpy as np from peewee import DoesNotExist from playhouse.sqliteq import SqliteQueueDatabase +from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsResponder from frigate.comms.event_metadata_updater import ( EventMetadataSubscriber, EventMetadataTypeEnum, @@ -48,6 +49,7 @@ class EmbeddingMaintainer(threading.Thread): self.event_metadata_subscriber = EventMetadataSubscriber( EventMetadataTypeEnum.regenerate_description ) + self.embeddings_responder = EmbeddingsResponder() self.frame_manager = SharedMemoryFrameManager() # create communication for updating event descriptions self.requestor = InterProcessRequestor() @@ -58,6 +60,7 @@ class EmbeddingMaintainer(threading.Thread): def run(self) -> None: """Maintain a SQLite-vec database for semantic search.""" while not self.stop_event.is_set(): + self._process_requests() self._process_updates() self._process_finalized() self._process_event_metadata() @@ -65,9 +68,26 @@ class EmbeddingMaintainer(threading.Thread): self.event_subscriber.stop() self.event_end_subscriber.stop() self.event_metadata_subscriber.stop() + self.embeddings_responder.stop() self.requestor.stop() logger.info("Exiting embeddings maintenance...") + def _process_requests(self) -> None: + """Process embeddings requests""" + + def handle_request(topic: str, data: str) -> any: + if topic == EmbeddingsRequestEnum.embed_description: + return self.embeddings.upsert_description( + data["id"], data["description"] + ) + elif topic == EmbeddingsRequestEnum.embed_thumbnail: + thumbnail = base64.b64decode(data["thumbnail"]) + return self.embeddings.upsert_thumbnail(data["id"], thumbnail) + elif topic == EmbeddingsRequestEnum.generate_search: + return self.embeddings.text_embedding([data])[0] + + self.embeddings_responder.check_for_request(handle_request) + def _process_updates(self) -> None: """Process event updates""" update = self.event_subscriber.check_for_update()