From dca1776334cdbfab203027547fc1e52f03afa644 Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Thu, 10 Oct 2024 09:07:03 -0600 Subject: [PATCH] Handle serialization --- frigate/comms/embeddings_updater.py | 2 +- frigate/embeddings/__init__.py | 39 +++++++++++++---------------- frigate/embeddings/maintainer.py | 21 ++++++++++------ frigate/util/builtin.py | 7 ++++-- 4 files changed, 36 insertions(+), 33 deletions(-) diff --git a/frigate/comms/embeddings_updater.py b/frigate/comms/embeddings_updater.py index d7b30232b..8a7617630 100644 --- a/frigate/comms/embeddings_updater.py +++ b/frigate/comms/embeddings_updater.py @@ -52,7 +52,7 @@ class EmbeddingsRequestor: self.socket = self.context.socket(zmq.REQ) self.socket.connect(SOCKET_REP_REQ) - def send_data(self, topic: str, data: any) -> any: + def send_data(self, topic: str, data: any) -> str: """Sends data and then waits for reply.""" self.socket.send_json((topic, data)) return self.socket.recv_json() diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py index 688ebcbd1..f8a1232d2 100644 --- a/frigate/embeddings/__init__.py +++ b/frigate/embeddings/__init__.py @@ -1,6 +1,5 @@ """SQLite-vec embeddings database.""" -import base64 import json import logging import multiprocessing as mp @@ -112,18 +111,20 @@ class EmbeddingsContext: row = cursor.fetchone() if cursor else None if row: - query_embedding = deserialize( - row[0] - ) # Deserialize the thumbnail embedding + query_embedding = row[0] else: # If no embedding found, generate it and return it - query_embedding = self.requestor.send_data( - EmbeddingsRequestEnum.embed_thumbnail, - {"id": query.id, "thumbnail": query.thumbnail}, + query_embedding = serialize( + self.requestor.send_data( + EmbeddingsRequestEnum.embed_thumbnail.value, + {"id": query.id, "thumbnail": query.thumbnail}, + ) ) else: - query_embedding = self.requestor.send_data( - EmbeddingsRequestEnum.generate_search, query + query_embedding = serialize( + self.requestor.send_data( + EmbeddingsRequestEnum.generate_search.value, query + ) ) sql_query = """ @@ -145,11 +146,7 @@ class EmbeddingsContext: # when it's implemented, we can use cosine similarity sql_query += " ORDER BY distance" - parameters = ( - [serialize(query_embedding)] + event_ids - if event_ids - else [serialize(query_embedding)] - ) + parameters = [query_embedding] + event_ids if event_ids else [query_embedding] results = self.db.execute_sql(sql_query, parameters).fetchall() @@ -158,8 +155,10 @@ class EmbeddingsContext: def search_description( self, query_text: str, event_ids: list[str] = None ) -> list[tuple[str, float]]: - query_embedding = self.requestor.send_data( - EmbeddingsRequestEnum.generate_search, query_text + query_embedding = serialize( + self.requestor.send_data( + EmbeddingsRequestEnum.generate_search.value, query_text + ) ) # Prepare the base SQL query @@ -182,11 +181,7 @@ class EmbeddingsContext: # when it's implemented, we can use cosine similarity sql_query += " ORDER BY distance" - parameters = ( - [serialize(query_embedding)] + event_ids - if event_ids - else [serialize(query_embedding)] - ) + parameters = [query_embedding] + event_ids if event_ids else [query_embedding] results = self.db.execute_sql(sql_query, parameters).fetchall() @@ -194,6 +189,6 @@ class EmbeddingsContext: def update_description(self, event_id: str, description: str) -> None: self.requestor.send_data( - EmbeddingsRequestEnum.embed_description, + EmbeddingsRequestEnum.embed_description.value, {"id": event_id, "description": description}, ) diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index c8b2375e6..68c3e3686 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -24,6 +24,7 @@ from frigate.const import CLIPS_DIR, UPDATE_EVENT_DESCRIPTION from frigate.events.types import EventTypeEnum from frigate.genai import get_genai_client from frigate.models import Event +from frigate.util.builtin import serialize from frigate.util.image import SharedMemoryFrameManager, calculate_region from .embeddings import Embeddings @@ -75,16 +76,20 @@ class EmbeddingMaintainer(threading.Thread): def _process_requests(self) -> None: """Process embeddings requests""" - def handle_request(topic: str, data: str) -> any: - if topic == EmbeddingsRequestEnum.embed_description: - return self.embeddings.upsert_description( - data["id"], data["description"] + def handle_request(topic: str, data: str) -> str: + if topic == EmbeddingsRequestEnum.embed_description.value: + return serialize( + self.embeddings.upsert_description(data["id"], data["description"]), + pack=False, ) - elif topic == EmbeddingsRequestEnum.embed_thumbnail: + elif topic == EmbeddingsRequestEnum.embed_thumbnail.value: thumbnail = base64.b64decode(data["thumbnail"]) - return self.embeddings.upsert_thumbnail(data["id"], thumbnail) - elif topic == EmbeddingsRequestEnum.generate_search: - return self.embeddings.text_embedding([data])[0] + return serialize( + self.embeddings.upsert_thumbnail(data["id"], thumbnail), + pack=False, + ) + elif topic == EmbeddingsRequestEnum.generate_search.value: + return serialize(self.embeddings.text_embedding([data])[0], pack=False) self.embeddings_responder.check_for_request(handle_request) diff --git a/frigate/util/builtin.py b/frigate/util/builtin.py index e10f83647..a28eab013 100644 --- a/frigate/util/builtin.py +++ b/frigate/util/builtin.py @@ -345,7 +345,7 @@ def generate_color_palette(n): return colors -def serialize(vector: Union[list[float], np.ndarray, float]) -> bytes: +def serialize(vector: Union[list[float], np.ndarray, float], pack: bool = True) -> bytes: """Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format""" if isinstance(vector, np.ndarray): # Convert numpy array to list of floats @@ -359,7 +359,10 @@ def serialize(vector: Union[list[float], np.ndarray, float]) -> bytes: ) try: - return struct.pack("%sf" % len(vector), *vector) + if pack: + return struct.pack("%sf" % len(vector), *vector) + else: + return vector except struct.error as e: raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")