Handle serialization

This commit is contained in:
Nicolas Mowen 2024-10-10 09:07:03 -06:00
parent f8f1852b0c
commit dca1776334
4 changed files with 36 additions and 33 deletions

View File

@ -52,7 +52,7 @@ class EmbeddingsRequestor:
self.socket = self.context.socket(zmq.REQ)
self.socket.connect(SOCKET_REP_REQ)
def send_data(self, topic: str, data: any) -> any:
def send_data(self, topic: str, data: any) -> str:
"""Sends data and then waits for reply."""
self.socket.send_json((topic, data))
return self.socket.recv_json()

View File

@ -1,6 +1,5 @@
"""SQLite-vec embeddings database."""
import base64
import json
import logging
import multiprocessing as mp
@ -112,18 +111,20 @@ class EmbeddingsContext:
row = cursor.fetchone() if cursor else None
if row:
query_embedding = deserialize(
row[0]
) # Deserialize the thumbnail embedding
query_embedding = row[0]
else:
# If no embedding found, generate it and return it
query_embedding = self.requestor.send_data(
EmbeddingsRequestEnum.embed_thumbnail,
{"id": query.id, "thumbnail": query.thumbnail},
query_embedding = serialize(
self.requestor.send_data(
EmbeddingsRequestEnum.embed_thumbnail.value,
{"id": query.id, "thumbnail": query.thumbnail},
)
)
else:
query_embedding = self.requestor.send_data(
EmbeddingsRequestEnum.generate_search, query
query_embedding = serialize(
self.requestor.send_data(
EmbeddingsRequestEnum.generate_search.value, query
)
)
sql_query = """
@ -145,11 +146,7 @@ class EmbeddingsContext:
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = (
[serialize(query_embedding)] + event_ids
if event_ids
else [serialize(query_embedding)]
)
parameters = [query_embedding] + event_ids if event_ids else [query_embedding]
results = self.db.execute_sql(sql_query, parameters).fetchall()
@ -158,8 +155,10 @@ class EmbeddingsContext:
def search_description(
self, query_text: str, event_ids: list[str] = None
) -> list[tuple[str, float]]:
query_embedding = self.requestor.send_data(
EmbeddingsRequestEnum.generate_search, query_text
query_embedding = serialize(
self.requestor.send_data(
EmbeddingsRequestEnum.generate_search.value, query_text
)
)
# Prepare the base SQL query
@ -182,11 +181,7 @@ class EmbeddingsContext:
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = (
[serialize(query_embedding)] + event_ids
if event_ids
else [serialize(query_embedding)]
)
parameters = [query_embedding] + event_ids if event_ids else [query_embedding]
results = self.db.execute_sql(sql_query, parameters).fetchall()
@ -194,6 +189,6 @@ class EmbeddingsContext:
def update_description(self, event_id: str, description: str) -> None:
self.requestor.send_data(
EmbeddingsRequestEnum.embed_description,
EmbeddingsRequestEnum.embed_description.value,
{"id": event_id, "description": description},
)

View File

@ -24,6 +24,7 @@ from frigate.const import CLIPS_DIR, UPDATE_EVENT_DESCRIPTION
from frigate.events.types import EventTypeEnum
from frigate.genai import get_genai_client
from frigate.models import Event
from frigate.util.builtin import serialize
from frigate.util.image import SharedMemoryFrameManager, calculate_region
from .embeddings import Embeddings
@ -75,16 +76,20 @@ class EmbeddingMaintainer(threading.Thread):
def _process_requests(self) -> None:
"""Process embeddings requests"""
def handle_request(topic: str, data: str) -> any:
if topic == EmbeddingsRequestEnum.embed_description:
return self.embeddings.upsert_description(
data["id"], data["description"]
def handle_request(topic: str, data: str) -> str:
if topic == EmbeddingsRequestEnum.embed_description.value:
return serialize(
self.embeddings.upsert_description(data["id"], data["description"]),
pack=False,
)
elif topic == EmbeddingsRequestEnum.embed_thumbnail:
elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
thumbnail = base64.b64decode(data["thumbnail"])
return self.embeddings.upsert_thumbnail(data["id"], thumbnail)
elif topic == EmbeddingsRequestEnum.generate_search:
return self.embeddings.text_embedding([data])[0]
return serialize(
self.embeddings.upsert_thumbnail(data["id"], thumbnail),
pack=False,
)
elif topic == EmbeddingsRequestEnum.generate_search.value:
return serialize(self.embeddings.text_embedding([data])[0], pack=False)
self.embeddings_responder.check_for_request(handle_request)

View File

@ -345,7 +345,7 @@ def generate_color_palette(n):
return colors
def serialize(vector: Union[list[float], np.ndarray, float]) -> bytes:
def serialize(vector: Union[list[float], np.ndarray, float], pack: bool = True) -> bytes:
"""Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
if isinstance(vector, np.ndarray):
# Convert numpy array to list of floats
@ -359,7 +359,10 @@ def serialize(vector: Union[list[float], np.ndarray, float]) -> bytes:
)
try:
return struct.pack("%sf" % len(vector), *vector)
if pack:
return struct.pack("%sf" % len(vector), *vector)
else:
return vector
except struct.error as e:
raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")