Use ZMQ to proxy embeddings requests

This commit is contained in:
Nicolas Mowen 2024-10-10 08:05:28 -06:00
parent 2142a39b3c
commit f8f1852b0c
5 changed files with 108 additions and 11 deletions

View File

@ -944,9 +944,9 @@ def set_description(
# If semantic search is enabled, update the index
if request.app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = request.app.embeddings
context.embeddings.upsert_description(
event_id=event_id,
description=new_description,
context.update_description(
event_id,
new_description,
)
response_message = (

View File

@ -276,7 +276,7 @@ class FrigateApp:
def init_embeddings_client(self) -> None:
if self.config.semantic_search.enabled:
# Create a client for other processes to use
self.embeddings = EmbeddingsContext(self.config, self.db)
self.embeddings = EmbeddingsContext(self.db)
def init_external_event_processor(self) -> None:
self.external_event_processor = ExternalEventProcessor(self.config)
@ -699,7 +699,7 @@ class FrigateApp:
# Save embeddings stats to disk
if self.embeddings:
self.embeddings.save_stats()
self.embeddings.stop()
# Stop Communicators
self.inter_process_communicator.stop()

View File

@ -0,0 +1,62 @@
"""Facilitates communication between processes."""
from enum import Enum
from typing import Callable
import zmq
# IPC endpoint shared by both ends of the request/reply channel:
# EmbeddingsResponder binds to it and EmbeddingsRequestor connects to it.
SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"
class EmbeddingsRequestEnum(str, Enum):
    """Topics understood by the embeddings request/reply channel.

    Members mix in ``str`` so they survive the JSON round trip used by the
    sockets: ``json`` encodes a member as its plain string value, and the
    string decoded on the receiving side compares equal to the member.
    A plain ``Enum`` member is neither JSON-serializable nor equal to its
    string value.
    """

    embed_description = "embed_description"
    embed_thumbnail = "embed_thumbnail"
    generate_search = "generate_search"
class EmbeddingsResponder:
    """Reply side of the embeddings request channel (ZMQ ``REP`` socket)."""

    def __init__(self) -> None:
        """Create the REP socket and bind it to ``SOCKET_REP_REQ``."""
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)
        self.socket.bind(SOCKET_REP_REQ)

    def check_for_request(self, process: Callable) -> None:
        """Answer every request that is currently queued.

        Each request is decoded as a ``(topic, value)`` pair and handed to
        ``process(topic, value)``; its return value is sent back as JSON,
        with ``[]`` standing in for ``None``. The loop ends when
        ``zmq.select`` reports no readable message within its timeout of 1,
        or when a ZMQ error occurs.

        Raises:
            Exception: anything other than ``zmq.ZMQError`` raised while
                decoding, processing or encoding a request is re-raised
                after an empty reply has been sent.
        """
        while True:  # load all messages that are queued
            has_message, _, _ = zmq.select([self.socket], [], [], 1)

            if not has_message:
                break

            try:
                (topic, value) = self.socket.recv_json(flags=zmq.NOBLOCK)
                response = process(topic, value)
                self.socket.send_json([] if response is None else response)
            except zmq.ZMQError:
                break
            except Exception:
                # A REP socket must answer the request it received before it
                # can receive another one, and the REQ peer blocks until that
                # answer arrives. Send an empty reply so the caller is
                # released, then let the failure surface as before.
                self.socket.send_json([])
                raise

    def stop(self) -> None:
        """Close the socket and destroy the ZMQ context."""
        self.socket.close()
        self.context.destroy()
class EmbeddingsRequestor:
    """Simplifies sending data to EmbeddingsResponder and getting a reply."""

    def __init__(self) -> None:
        """Create the REQ socket and connect it to ``SOCKET_REP_REQ``."""
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)
        self.socket.connect(SOCKET_REP_REQ)

    def send_data(self, topic: str, data: any) -> any:
        """Send ``(topic, data)`` as JSON and block until the reply arrives.

        Args:
            topic: request topic; either a plain string or an ``Enum``
                member, in which case its ``value`` is what goes on the wire.
            data: JSON-serializable payload for the request.

        Returns:
            The JSON-decoded reply.
        """
        # Enum members are not JSON-serializable, so send the underlying
        # value; the responder receives a plain string either way.
        if isinstance(topic, Enum):
            topic = topic.value

        self.socket.send_json((topic, data))
        return self.socket.recv_json()

    def stop(self) -> None:
        """Close the socket and destroy the ZMQ context."""
        self.socket.close()
        self.context.destroy()

View File

@ -12,6 +12,7 @@ from typing import Optional, Union
from setproctitle import setproctitle
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsRequestor
from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
@ -72,10 +73,11 @@ def manage_embeddings(config: FrigateConfig) -> None:
class EmbeddingsContext:
def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase):
def __init__(self, db: SqliteVecQueueDatabase):
self.db = db
self.thumb_stats = ZScoreNormalization()
self.desc_stats = ZScoreNormalization()
self.requestor = EmbeddingsRequestor()
# load stats from disk
try:
@ -86,7 +88,7 @@ class EmbeddingsContext:
except FileNotFoundError:
pass
def save_stats(self):
def stop(self):
"""Write the stats to disk as JSON on exit."""
contents = {
"thumb_stats": self.thumb_stats.to_dict(),
@ -94,6 +96,7 @@ class EmbeddingsContext:
}
with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f:
json.dump(contents, f)
self.requestor.stop()
def search_thumbnail(
self, query: Union[Event, str], event_ids: list[str] = None
@ -114,10 +117,14 @@ class EmbeddingsContext:
) # Deserialize the thumbnail embedding
else:
# If no embedding found, generate it and return it
thumbnail = base64.b64decode(query.thumbnail)
query_embedding = self.upsert_thumbnail(query.id, thumbnail)
query_embedding = self.requestor.send_data(
EmbeddingsRequestEnum.embed_thumbnail,
{"id": query.id, "thumbnail": query.thumbnail},
)
else:
query_embedding = self.text_embedding([query])[0]
query_embedding = self.requestor.send_data(
EmbeddingsRequestEnum.generate_search, query
)
sql_query = """
SELECT
@ -151,7 +158,9 @@ class EmbeddingsContext:
def search_description(
self, query_text: str, event_ids: list[str] = None
) -> list[tuple[str, float]]:
query_embedding = self.text_embedding([query_text])[0]
query_embedding = self.requestor.send_data(
EmbeddingsRequestEnum.generate_search, query_text
)
# Prepare the base SQL query
sql_query = """
@ -182,3 +191,9 @@ class EmbeddingsContext:
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results
def update_description(self, event_id: str, description: str) -> None:
    """Send an ``embed_description`` request for the event via the requestor."""
    payload = {"id": event_id, "description": description}
    self.requestor.send_data(EmbeddingsRequestEnum.embed_description, payload)

View File

@ -12,6 +12,7 @@ import numpy as np
from peewee import DoesNotExist
from playhouse.sqliteq import SqliteQueueDatabase
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsResponder
from frigate.comms.event_metadata_updater import (
EventMetadataSubscriber,
EventMetadataTypeEnum,
@ -48,6 +49,7 @@ class EmbeddingMaintainer(threading.Thread):
self.event_metadata_subscriber = EventMetadataSubscriber(
EventMetadataTypeEnum.regenerate_description
)
self.embeddings_responder = EmbeddingsResponder()
self.frame_manager = SharedMemoryFrameManager()
# create communication for updating event descriptions
self.requestor = InterProcessRequestor()
@ -58,6 +60,7 @@ class EmbeddingMaintainer(threading.Thread):
def run(self) -> None:
"""Maintain a SQLite-vec database for semantic search."""
while not self.stop_event.is_set():
self._process_requests()
self._process_updates()
self._process_finalized()
self._process_event_metadata()
@ -65,9 +68,26 @@ class EmbeddingMaintainer(threading.Thread):
self.event_subscriber.stop()
self.event_end_subscriber.stop()
self.event_metadata_subscriber.stop()
self.embeddings_responder.stop()
self.requestor.stop()
logger.info("Exiting embeddings maintenance...")
def _process_requests(self) -> None:
    """Process embeddings requests"""

    def handle_request(topic: str, data: any) -> any:
        """Dispatch one decoded request to the matching embeddings call.

        ``data`` is indexed by ``"id"`` / ``"description"`` / ``"thumbnail"``
        for the embed topics and passed through as the query text for
        ``generate_search``. Unknown topics yield ``None``.
        """
        # The topic arrives as a plain string, so compare against the enum
        # values; a plain Enum member never equals its string value.
        # NOTE(review): the return value is handed back to the responder;
        # confirm these calls return something it can serialize.
        if topic == EmbeddingsRequestEnum.embed_description.value:
            return self.embeddings.upsert_description(
                data["id"], data["description"]
            )
        elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
            thumbnail = base64.b64decode(data["thumbnail"])
            return self.embeddings.upsert_thumbnail(data["id"], thumbnail)
        elif topic == EmbeddingsRequestEnum.generate_search.value:
            return self.embeddings.text_embedding([data])[0]

        return None

    self.embeddings_responder.check_for_request(handle_request)
def _process_updates(self) -> None:
"""Process event updates"""
update = self.event_subscriber.check_for_update()