Restructure embeddings

This commit is contained in:
Nicolas Mowen 2024-10-10 07:39:33 -06:00
parent a2ca18a714
commit 2142a39b3c
5 changed files with 136 additions and 135 deletions

View File

@ -472,7 +472,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
status_code=404, status_code=404,
) )
thumb_result = context.embeddings.search_thumbnail(search_event) thumb_result = context.search_thumbnail(search_event)
thumb_ids = dict( thumb_ids = dict(
zip( zip(
[result[0] for result in thumb_result], [result[0] for result in thumb_result],
@ -487,7 +487,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
search_types = search_type.split(",") search_types = search_type.split(",")
if "thumbnail" in search_types: if "thumbnail" in search_types:
thumb_result = context.embeddings.search_thumbnail(query) thumb_result = context.search_thumbnail(query)
thumb_ids = dict( thumb_ids = dict(
zip( zip(
[result[0] for result in thumb_result], [result[0] for result in thumb_result],
@ -504,7 +504,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
) )
if "description" in search_types: if "description" in search_types:
desc_result = context.embeddings.search_description(query) desc_result = context.search_description(query)
desc_ids = dict( desc_ids = dict(
zip( zip(
[result[0] for result in desc_result], [result[0] for result in desc_result],
@ -1033,8 +1033,8 @@ def delete_event(request: Request, event_id: str):
# If semantic search is enabled, update the index # If semantic search is enabled, update the index
if request.app.frigate_config.semantic_search.enabled: if request.app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = request.app.embeddings context: EmbeddingsContext = request.app.embeddings
context.embeddings.delete_thumbnail(id=[event_id]) context.db.delete_embeddings_thumbnail(id=[event_id])
context.embeddings.delete_description(id=[event_id]) context.db.delete_embeddings_description(id=[event_id])
return JSONResponse( return JSONResponse(
content=({"success": True, "message": "Event " + event_id + " deleted"}), content=({"success": True, "message": "Event " + event_id + " deleted"}),
status_code=200, status_code=200,

View File

@ -20,3 +20,15 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
conn.enable_load_extension(True) conn.enable_load_extension(True)
conn.load_extension(self.sqlite_vec_path) conn.load_extension(self.sqlite_vec_path)
conn.enable_load_extension(False) conn.enable_load_extension(False)
def delete_embeddings_thumbnail(self, event_ids: list[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.execute_sql(
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
)
def delete_embeddings_description(self, event_ids: list[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.execute_sql(
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
)

View File

@ -1,5 +1,6 @@
"""SQLite-vec embeddings database.""" """SQLite-vec embeddings database."""
import base64
import json import json
import logging import logging
import multiprocessing as mp import multiprocessing as mp
@ -7,7 +8,7 @@ import os
import signal import signal
import threading import threading
from types import FrameType from types import FrameType
from typing import Optional from typing import Optional, Union
from setproctitle import setproctitle from setproctitle import setproctitle
@ -15,6 +16,7 @@ from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR from frigate.const import CONFIG_DIR
from frigate.db.sqlitevecq import SqliteVecQueueDatabase from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event from frigate.models import Event
from frigate.util.builtin import deserialize, serialize
from frigate.util.services import listen from frigate.util.services import listen
from .embeddings import Embeddings from .embeddings import Embeddings
@ -71,7 +73,7 @@ def manage_embeddings(config: FrigateConfig) -> None:
class EmbeddingsContext: class EmbeddingsContext:
def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase): def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase):
self.embeddings = Embeddings(config.semantic_search, db) self.db = db
self.thumb_stats = ZScoreNormalization() self.thumb_stats = ZScoreNormalization()
self.desc_stats = ZScoreNormalization() self.desc_stats = ZScoreNormalization()
@ -92,3 +94,91 @@ class EmbeddingsContext:
} }
with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f: with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f:
json.dump(contents, f) json.dump(contents, f)
def search_thumbnail(
self, query: Union[Event, str], event_ids: list[str] = None
) -> list[tuple[str, float]]:
if query.__class__ == Event:
cursor = self.db.execute_sql(
"""
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
""",
[query.id],
)
row = cursor.fetchone() if cursor else None
if row:
query_embedding = deserialize(
row[0]
) # Deserialize the thumbnail embedding
else:
# If no embedding found, generate it and return it
thumbnail = base64.b64decode(query.thumbnail)
query_embedding = self.upsert_thumbnail(query.id, thumbnail)
else:
query_embedding = self.text_embedding([query])[0]
sql_query = """
SELECT
id,
distance
FROM vec_thumbnails
WHERE thumbnail_embedding MATCH ?
AND k = 100
"""
# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = (
[serialize(query_embedding)] + event_ids
if event_ids
else [serialize(query_embedding)]
)
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results
def search_description(
self, query_text: str, event_ids: list[str] = None
) -> list[tuple[str, float]]:
query_embedding = self.text_embedding([query_text])[0]
# Prepare the base SQL query
sql_query = """
SELECT
id,
distance
FROM vec_descriptions
WHERE description_embedding MATCH ?
AND k = 100
"""
# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = (
[serialize(query_embedding)] + event_ids
if event_ids
else [serialize(query_embedding)]
)
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results

View File

@ -3,11 +3,8 @@
import base64 import base64
import io import io
import logging import logging
import struct
import time import time
from typing import List, Tuple, Union
import numpy as np
from PIL import Image from PIL import Image
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
@ -17,6 +14,7 @@ from frigate.const import UPDATE_MODEL_STATE
from frigate.db.sqlitevecq import SqliteVecQueueDatabase from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event from frigate.models import Event
from frigate.types import ModelStatusTypesEnum from frigate.types import ModelStatusTypesEnum
from frigate.util.builtin import serialize
from .functions.onnx import GenericONNXEmbedding from .functions.onnx import GenericONNXEmbedding
@ -54,30 +52,6 @@ def get_metadata(event: Event) -> dict:
) )
def serialize(vector: Union[List[float], np.ndarray, float]) -> bytes:
"""Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
if isinstance(vector, np.ndarray):
# Convert numpy array to list of floats
vector = vector.flatten().tolist()
elif isinstance(vector, (float, np.float32, np.float64)):
# Handle single float values
vector = [vector]
elif not isinstance(vector, list):
raise TypeError(
f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
)
try:
return struct.pack("%sf" % len(vector), *vector)
except struct.error as e:
raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")
def deserialize(bytes_data: bytes) -> List[float]:
"""Deserializes a compact "raw bytes" format into a list of floats"""
return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))
class Embeddings: class Embeddings:
"""SQLite-vec embeddings database.""" """SQLite-vec embeddings database."""
@ -190,106 +164,6 @@ class Embeddings:
return embedding return embedding
def delete_thumbnail(self, event_ids: List[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.db.execute_sql(
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
)
def delete_description(self, event_ids: List[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.db.execute_sql(
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
)
def search_thumbnail(
self, query: Union[Event, str], event_ids: List[str] = None
) -> List[Tuple[str, float]]:
if query.__class__ == Event:
cursor = self.db.execute_sql(
"""
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
""",
[query.id],
)
row = cursor.fetchone() if cursor else None
if row:
query_embedding = deserialize(
row[0]
) # Deserialize the thumbnail embedding
else:
# If no embedding found, generate it and return it
thumbnail = base64.b64decode(query.thumbnail)
query_embedding = self.upsert_thumbnail(query.id, thumbnail)
else:
query_embedding = self.text_embedding([query])[0]
sql_query = """
SELECT
id,
distance
FROM vec_thumbnails
WHERE thumbnail_embedding MATCH ?
AND k = 100
"""
# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = (
[serialize(query_embedding)] + event_ids
if event_ids
else [serialize(query_embedding)]
)
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results
def search_description(
self, query_text: str, event_ids: List[str] = None
) -> List[Tuple[str, float]]:
query_embedding = self.text_embedding([query_text])[0]
# Prepare the base SQL query
sql_query = """
SELECT
id,
distance
FROM vec_descriptions
WHERE description_embedding MATCH ?
AND k = 100
"""
# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = (
[serialize(query_embedding)] + event_ids
if event_ids
else [serialize(query_embedding)]
)
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results
def reindex(self) -> None: def reindex(self) -> None:
logger.info("Indexing event embeddings...") logger.info("Indexing event embeddings...")

View File

@ -8,10 +8,11 @@ import multiprocessing as mp
import queue import queue
import re import re
import shlex import shlex
import struct
import urllib.parse import urllib.parse
from collections.abc import Mapping from collections.abc import Mapping
from pathlib import Path from pathlib import Path
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple, Union
import numpy as np import numpy as np
import pytz import pytz
@ -342,3 +343,27 @@ def generate_color_palette(n):
colors.append(interpolate(color1, color2, factor)) colors.append(interpolate(color1, color2, factor))
return colors return colors
def serialize(vector: Union[list[float], np.ndarray, float]) -> bytes:
"""Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
if isinstance(vector, np.ndarray):
# Convert numpy array to list of floats
vector = vector.flatten().tolist()
elif isinstance(vector, (float, np.float32, np.float64)):
# Handle single float values
vector = [vector]
elif not isinstance(vector, list):
raise TypeError(
f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
)
try:
return struct.pack("%sf" % len(vector), *vector)
except struct.error as e:
raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")
def deserialize(bytes_data: bytes) -> list[float]:
"""Deserializes a compact "raw bytes" format into a list of floats"""
return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))