Restructure embeddings

This commit is contained in:
Nicolas Mowen 2024-10-10 07:39:33 -06:00
parent a2ca18a714
commit 2142a39b3c
5 changed files with 136 additions and 135 deletions

View File

@ -472,7 +472,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
status_code=404,
)
thumb_result = context.embeddings.search_thumbnail(search_event)
thumb_result = context.search_thumbnail(search_event)
thumb_ids = dict(
zip(
[result[0] for result in thumb_result],
@ -487,7 +487,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
search_types = search_type.split(",")
if "thumbnail" in search_types:
thumb_result = context.embeddings.search_thumbnail(query)
thumb_result = context.search_thumbnail(query)
thumb_ids = dict(
zip(
[result[0] for result in thumb_result],
@ -504,7 +504,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
)
if "description" in search_types:
desc_result = context.embeddings.search_description(query)
desc_result = context.search_description(query)
desc_ids = dict(
zip(
[result[0] for result in desc_result],
@ -1033,8 +1033,8 @@ def delete_event(request: Request, event_id: str):
# If semantic search is enabled, update the index
if request.app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = request.app.embeddings
context.embeddings.delete_thumbnail(id=[event_id])
context.embeddings.delete_description(id=[event_id])
context.db.delete_embeddings_thumbnail(id=[event_id])
context.db.delete_embeddings_description(id=[event_id])
return JSONResponse(
content=({"success": True, "message": "Event " + event_id + " deleted"}),
status_code=200,

View File

@ -20,3 +20,15 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
conn.enable_load_extension(True)
conn.load_extension(self.sqlite_vec_path)
conn.enable_load_extension(False)
def delete_embeddings_thumbnail(self, event_ids: list[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.execute_sql(
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
)
def delete_embeddings_description(self, event_ids: list[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.execute_sql(
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
)

View File

@ -1,5 +1,6 @@
"""SQLite-vec embeddings database."""
import base64
import json
import logging
import multiprocessing as mp
@ -7,7 +8,7 @@ import os
import signal
import threading
from types import FrameType
from typing import Optional
from typing import Optional, Union
from setproctitle import setproctitle
@ -15,6 +16,7 @@ from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event
from frigate.util.builtin import deserialize, serialize
from frigate.util.services import listen
from .embeddings import Embeddings
@ -71,7 +73,7 @@ def manage_embeddings(config: FrigateConfig) -> None:
class EmbeddingsContext:
def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase):
self.embeddings = Embeddings(config.semantic_search, db)
self.db = db
self.thumb_stats = ZScoreNormalization()
self.desc_stats = ZScoreNormalization()
@ -92,3 +94,91 @@ class EmbeddingsContext:
}
with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f:
json.dump(contents, f)
def search_thumbnail(
self, query: Union[Event, str], event_ids: list[str] = None
) -> list[tuple[str, float]]:
if query.__class__ == Event:
cursor = self.db.execute_sql(
"""
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
""",
[query.id],
)
row = cursor.fetchone() if cursor else None
if row:
query_embedding = deserialize(
row[0]
) # Deserialize the thumbnail embedding
else:
# If no embedding found, generate it and return it
thumbnail = base64.b64decode(query.thumbnail)
query_embedding = self.upsert_thumbnail(query.id, thumbnail)
else:
query_embedding = self.text_embedding([query])[0]
sql_query = """
SELECT
id,
distance
FROM vec_thumbnails
WHERE thumbnail_embedding MATCH ?
AND k = 100
"""
# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = (
[serialize(query_embedding)] + event_ids
if event_ids
else [serialize(query_embedding)]
)
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results
def search_description(
self, query_text: str, event_ids: list[str] = None
) -> list[tuple[str, float]]:
query_embedding = self.text_embedding([query_text])[0]
# Prepare the base SQL query
sql_query = """
SELECT
id,
distance
FROM vec_descriptions
WHERE description_embedding MATCH ?
AND k = 100
"""
# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = (
[serialize(query_embedding)] + event_ids
if event_ids
else [serialize(query_embedding)]
)
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results

View File

@ -3,11 +3,8 @@
import base64
import io
import logging
import struct
import time
from typing import List, Tuple, Union
import numpy as np
from PIL import Image
from playhouse.shortcuts import model_to_dict
@ -17,6 +14,7 @@ from frigate.const import UPDATE_MODEL_STATE
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event
from frigate.types import ModelStatusTypesEnum
from frigate.util.builtin import serialize
from .functions.onnx import GenericONNXEmbedding
@ -54,30 +52,6 @@ def get_metadata(event: Event) -> dict:
)
def serialize(vector: Union[List[float], np.ndarray, float]) -> bytes:
"""Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
if isinstance(vector, np.ndarray):
# Convert numpy array to list of floats
vector = vector.flatten().tolist()
elif isinstance(vector, (float, np.float32, np.float64)):
# Handle single float values
vector = [vector]
elif not isinstance(vector, list):
raise TypeError(
f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
)
try:
return struct.pack("%sf" % len(vector), *vector)
except struct.error as e:
raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")
def deserialize(bytes_data: bytes) -> List[float]:
"""Deserializes a compact "raw bytes" format into a list of floats"""
return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))
class Embeddings:
"""SQLite-vec embeddings database."""
@ -190,106 +164,6 @@ class Embeddings:
return embedding
def delete_thumbnail(self, event_ids: List[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.db.execute_sql(
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
)
def delete_description(self, event_ids: List[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.db.execute_sql(
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
)
def search_thumbnail(
self, query: Union[Event, str], event_ids: List[str] = None
) -> List[Tuple[str, float]]:
if query.__class__ == Event:
cursor = self.db.execute_sql(
"""
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
""",
[query.id],
)
row = cursor.fetchone() if cursor else None
if row:
query_embedding = deserialize(
row[0]
) # Deserialize the thumbnail embedding
else:
# If no embedding found, generate it and return it
thumbnail = base64.b64decode(query.thumbnail)
query_embedding = self.upsert_thumbnail(query.id, thumbnail)
else:
query_embedding = self.text_embedding([query])[0]
sql_query = """
SELECT
id,
distance
FROM vec_thumbnails
WHERE thumbnail_embedding MATCH ?
AND k = 100
"""
# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = (
[serialize(query_embedding)] + event_ids
if event_ids
else [serialize(query_embedding)]
)
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results
def search_description(
self, query_text: str, event_ids: List[str] = None
) -> List[Tuple[str, float]]:
query_embedding = self.text_embedding([query_text])[0]
# Prepare the base SQL query
sql_query = """
SELECT
id,
distance
FROM vec_descriptions
WHERE description_embedding MATCH ?
AND k = 100
"""
# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = (
[serialize(query_embedding)] + event_ids
if event_ids
else [serialize(query_embedding)]
)
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results
def reindex(self) -> None:
logger.info("Indexing event embeddings...")

View File

@ -8,10 +8,11 @@ import multiprocessing as mp
import queue
import re
import shlex
import struct
import urllib.parse
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional, Tuple
from typing import Any, Optional, Tuple, Union
import numpy as np
import pytz
@ -342,3 +343,27 @@ def generate_color_palette(n):
colors.append(interpolate(color1, color2, factor))
return colors
def serialize(vector: Union[list[float], np.ndarray, float]) -> bytes:
"""Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
if isinstance(vector, np.ndarray):
# Convert numpy array to list of floats
vector = vector.flatten().tolist()
elif isinstance(vector, (float, np.float32, np.float64)):
# Handle single float values
vector = [vector]
elif not isinstance(vector, list):
raise TypeError(
f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
)
try:
return struct.pack("%sf" % len(vector), *vector)
except struct.error as e:
raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")
def deserialize(bytes_data: bytes) -> list[float]:
"""Deserializes a compact "raw bytes" format into a list of floats"""
return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))