mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 07:35:27 +03:00
Handle serialization
This commit is contained in:
parent
f8f1852b0c
commit
dca1776334
@ -52,7 +52,7 @@ class EmbeddingsRequestor:
|
|||||||
self.socket = self.context.socket(zmq.REQ)
|
self.socket = self.context.socket(zmq.REQ)
|
||||||
self.socket.connect(SOCKET_REP_REQ)
|
self.socket.connect(SOCKET_REP_REQ)
|
||||||
|
|
||||||
def send_data(self, topic: str, data: any) -> any:
|
def send_data(self, topic: str, data: any) -> str:
|
||||||
"""Sends data and then waits for reply."""
|
"""Sends data and then waits for reply."""
|
||||||
self.socket.send_json((topic, data))
|
self.socket.send_json((topic, data))
|
||||||
return self.socket.recv_json()
|
return self.socket.recv_json()
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
"""SQLite-vec embeddings database."""
|
"""SQLite-vec embeddings database."""
|
||||||
|
|
||||||
import base64
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
@ -112,18 +111,20 @@ class EmbeddingsContext:
|
|||||||
row = cursor.fetchone() if cursor else None
|
row = cursor.fetchone() if cursor else None
|
||||||
|
|
||||||
if row:
|
if row:
|
||||||
query_embedding = deserialize(
|
query_embedding = row[0]
|
||||||
row[0]
|
|
||||||
) # Deserialize the thumbnail embedding
|
|
||||||
else:
|
else:
|
||||||
# If no embedding found, generate it and return it
|
# If no embedding found, generate it and return it
|
||||||
query_embedding = self.requestor.send_data(
|
query_embedding = serialize(
|
||||||
EmbeddingsRequestEnum.embed_thumbnail,
|
self.requestor.send_data(
|
||||||
{"id": query.id, "thumbnail": query.thumbnail},
|
EmbeddingsRequestEnum.embed_thumbnail.value,
|
||||||
|
{"id": query.id, "thumbnail": query.thumbnail},
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
query_embedding = self.requestor.send_data(
|
query_embedding = serialize(
|
||||||
EmbeddingsRequestEnum.generate_search, query
|
self.requestor.send_data(
|
||||||
|
EmbeddingsRequestEnum.generate_search.value, query
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """
|
||||||
@ -145,11 +146,7 @@ class EmbeddingsContext:
|
|||||||
# when it's implemented, we can use cosine similarity
|
# when it's implemented, we can use cosine similarity
|
||||||
sql_query += " ORDER BY distance"
|
sql_query += " ORDER BY distance"
|
||||||
|
|
||||||
parameters = (
|
parameters = [query_embedding] + event_ids if event_ids else [query_embedding]
|
||||||
[serialize(query_embedding)] + event_ids
|
|
||||||
if event_ids
|
|
||||||
else [serialize(query_embedding)]
|
|
||||||
)
|
|
||||||
|
|
||||||
results = self.db.execute_sql(sql_query, parameters).fetchall()
|
results = self.db.execute_sql(sql_query, parameters).fetchall()
|
||||||
|
|
||||||
@ -158,8 +155,10 @@ class EmbeddingsContext:
|
|||||||
def search_description(
|
def search_description(
|
||||||
self, query_text: str, event_ids: list[str] = None
|
self, query_text: str, event_ids: list[str] = None
|
||||||
) -> list[tuple[str, float]]:
|
) -> list[tuple[str, float]]:
|
||||||
query_embedding = self.requestor.send_data(
|
query_embedding = serialize(
|
||||||
EmbeddingsRequestEnum.generate_search, query_text
|
self.requestor.send_data(
|
||||||
|
EmbeddingsRequestEnum.generate_search.value, query_text
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare the base SQL query
|
# Prepare the base SQL query
|
||||||
@ -182,11 +181,7 @@ class EmbeddingsContext:
|
|||||||
# when it's implemented, we can use cosine similarity
|
# when it's implemented, we can use cosine similarity
|
||||||
sql_query += " ORDER BY distance"
|
sql_query += " ORDER BY distance"
|
||||||
|
|
||||||
parameters = (
|
parameters = [query_embedding] + event_ids if event_ids else [query_embedding]
|
||||||
[serialize(query_embedding)] + event_ids
|
|
||||||
if event_ids
|
|
||||||
else [serialize(query_embedding)]
|
|
||||||
)
|
|
||||||
|
|
||||||
results = self.db.execute_sql(sql_query, parameters).fetchall()
|
results = self.db.execute_sql(sql_query, parameters).fetchall()
|
||||||
|
|
||||||
@ -194,6 +189,6 @@ class EmbeddingsContext:
|
|||||||
|
|
||||||
def update_description(self, event_id: str, description: str) -> None:
|
def update_description(self, event_id: str, description: str) -> None:
|
||||||
self.requestor.send_data(
|
self.requestor.send_data(
|
||||||
EmbeddingsRequestEnum.embed_description,
|
EmbeddingsRequestEnum.embed_description.value,
|
||||||
{"id": event_id, "description": description},
|
{"id": event_id, "description": description},
|
||||||
)
|
)
|
||||||
|
|||||||
@ -24,6 +24,7 @@ from frigate.const import CLIPS_DIR, UPDATE_EVENT_DESCRIPTION
|
|||||||
from frigate.events.types import EventTypeEnum
|
from frigate.events.types import EventTypeEnum
|
||||||
from frigate.genai import get_genai_client
|
from frigate.genai import get_genai_client
|
||||||
from frigate.models import Event
|
from frigate.models import Event
|
||||||
|
from frigate.util.builtin import serialize
|
||||||
from frigate.util.image import SharedMemoryFrameManager, calculate_region
|
from frigate.util.image import SharedMemoryFrameManager, calculate_region
|
||||||
|
|
||||||
from .embeddings import Embeddings
|
from .embeddings import Embeddings
|
||||||
@ -75,16 +76,20 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
def _process_requests(self) -> None:
|
def _process_requests(self) -> None:
|
||||||
"""Process embeddings requests"""
|
"""Process embeddings requests"""
|
||||||
|
|
||||||
def handle_request(topic: str, data: str) -> any:
|
def handle_request(topic: str, data: str) -> str:
|
||||||
if topic == EmbeddingsRequestEnum.embed_description:
|
if topic == EmbeddingsRequestEnum.embed_description.value:
|
||||||
return self.embeddings.upsert_description(
|
return serialize(
|
||||||
data["id"], data["description"]
|
self.embeddings.upsert_description(data["id"], data["description"]),
|
||||||
|
pack=False,
|
||||||
)
|
)
|
||||||
elif topic == EmbeddingsRequestEnum.embed_thumbnail:
|
elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
|
||||||
thumbnail = base64.b64decode(data["thumbnail"])
|
thumbnail = base64.b64decode(data["thumbnail"])
|
||||||
return self.embeddings.upsert_thumbnail(data["id"], thumbnail)
|
return serialize(
|
||||||
elif topic == EmbeddingsRequestEnum.generate_search:
|
self.embeddings.upsert_thumbnail(data["id"], thumbnail),
|
||||||
return self.embeddings.text_embedding([data])[0]
|
pack=False,
|
||||||
|
)
|
||||||
|
elif topic == EmbeddingsRequestEnum.generate_search.value:
|
||||||
|
return serialize(self.embeddings.text_embedding([data])[0], pack=False)
|
||||||
|
|
||||||
self.embeddings_responder.check_for_request(handle_request)
|
self.embeddings_responder.check_for_request(handle_request)
|
||||||
|
|
||||||
|
|||||||
@ -345,7 +345,7 @@ def generate_color_palette(n):
|
|||||||
return colors
|
return colors
|
||||||
|
|
||||||
|
|
||||||
def serialize(vector: Union[list[float], np.ndarray, float]) -> bytes:
|
def serialize(vector: Union[list[float], np.ndarray, float], pack: bool = True) -> bytes:
|
||||||
"""Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
|
"""Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
|
||||||
if isinstance(vector, np.ndarray):
|
if isinstance(vector, np.ndarray):
|
||||||
# Convert numpy array to list of floats
|
# Convert numpy array to list of floats
|
||||||
@ -359,7 +359,10 @@ def serialize(vector: Union[list[float], np.ndarray, float]) -> bytes:
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return struct.pack("%sf" % len(vector), *vector)
|
if pack:
|
||||||
|
return struct.pack("%sf" % len(vector), *vector)
|
||||||
|
else:
|
||||||
|
return vector
|
||||||
except struct.error as e:
|
except struct.error as e:
|
||||||
raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")
|
raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user