mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 07:35:27 +03:00
Restructure embeddings
This commit is contained in:
parent
a2ca18a714
commit
2142a39b3c
@ -472,7 +472,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
thumb_result = context.embeddings.search_thumbnail(search_event)
|
||||
thumb_result = context.search_thumbnail(search_event)
|
||||
thumb_ids = dict(
|
||||
zip(
|
||||
[result[0] for result in thumb_result],
|
||||
@ -487,7 +487,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
||||
search_types = search_type.split(",")
|
||||
|
||||
if "thumbnail" in search_types:
|
||||
thumb_result = context.embeddings.search_thumbnail(query)
|
||||
thumb_result = context.search_thumbnail(query)
|
||||
thumb_ids = dict(
|
||||
zip(
|
||||
[result[0] for result in thumb_result],
|
||||
@ -504,7 +504,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
||||
)
|
||||
|
||||
if "description" in search_types:
|
||||
desc_result = context.embeddings.search_description(query)
|
||||
desc_result = context.search_description(query)
|
||||
desc_ids = dict(
|
||||
zip(
|
||||
[result[0] for result in desc_result],
|
||||
@ -1033,8 +1033,8 @@ def delete_event(request: Request, event_id: str):
|
||||
# If semantic search is enabled, update the index
|
||||
if request.app.frigate_config.semantic_search.enabled:
|
||||
context: EmbeddingsContext = request.app.embeddings
|
||||
context.embeddings.delete_thumbnail(id=[event_id])
|
||||
context.embeddings.delete_description(id=[event_id])
|
||||
context.db.delete_embeddings_thumbnail(id=[event_id])
|
||||
context.db.delete_embeddings_description(id=[event_id])
|
||||
return JSONResponse(
|
||||
content=({"success": True, "message": "Event " + event_id + " deleted"}),
|
||||
status_code=200,
|
||||
|
||||
@ -20,3 +20,15 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
|
||||
conn.enable_load_extension(True)
|
||||
conn.load_extension(self.sqlite_vec_path)
|
||||
conn.enable_load_extension(False)
|
||||
|
||||
def delete_embeddings_thumbnail(self, event_ids: list[str]) -> None:
|
||||
ids = ",".join(["?" for _ in event_ids])
|
||||
self.execute_sql(
|
||||
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
|
||||
)
|
||||
|
||||
def delete_embeddings_description(self, event_ids: list[str]) -> None:
|
||||
ids = ",".join(["?" for _ in event_ids])
|
||||
self.execute_sql(
|
||||
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
|
||||
)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""SQLite-vec embeddings database."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
@ -7,7 +8,7 @@ import os
|
||||
import signal
|
||||
import threading
|
||||
from types import FrameType
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from setproctitle import setproctitle
|
||||
|
||||
@ -15,6 +16,7 @@ from frigate.config import FrigateConfig
|
||||
from frigate.const import CONFIG_DIR
|
||||
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
|
||||
from frigate.models import Event
|
||||
from frigate.util.builtin import deserialize, serialize
|
||||
from frigate.util.services import listen
|
||||
|
||||
from .embeddings import Embeddings
|
||||
@ -71,7 +73,7 @@ def manage_embeddings(config: FrigateConfig) -> None:
|
||||
|
||||
class EmbeddingsContext:
|
||||
def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase):
|
||||
self.embeddings = Embeddings(config.semantic_search, db)
|
||||
self.db = db
|
||||
self.thumb_stats = ZScoreNormalization()
|
||||
self.desc_stats = ZScoreNormalization()
|
||||
|
||||
@ -92,3 +94,91 @@ class EmbeddingsContext:
|
||||
}
|
||||
with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f:
|
||||
json.dump(contents, f)
|
||||
|
||||
def search_thumbnail(
|
||||
self, query: Union[Event, str], event_ids: list[str] = None
|
||||
) -> list[tuple[str, float]]:
|
||||
if query.__class__ == Event:
|
||||
cursor = self.db.execute_sql(
|
||||
"""
|
||||
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
|
||||
""",
|
||||
[query.id],
|
||||
)
|
||||
|
||||
row = cursor.fetchone() if cursor else None
|
||||
|
||||
if row:
|
||||
query_embedding = deserialize(
|
||||
row[0]
|
||||
) # Deserialize the thumbnail embedding
|
||||
else:
|
||||
# If no embedding found, generate it and return it
|
||||
thumbnail = base64.b64decode(query.thumbnail)
|
||||
query_embedding = self.upsert_thumbnail(query.id, thumbnail)
|
||||
else:
|
||||
query_embedding = self.text_embedding([query])[0]
|
||||
|
||||
sql_query = """
|
||||
SELECT
|
||||
id,
|
||||
distance
|
||||
FROM vec_thumbnails
|
||||
WHERE thumbnail_embedding MATCH ?
|
||||
AND k = 100
|
||||
"""
|
||||
|
||||
# Add the IN clause if event_ids is provided and not empty
|
||||
# this is the only filter supported by sqlite-vec as of 0.1.3
|
||||
# but it seems to be broken in this version
|
||||
if event_ids:
|
||||
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
|
||||
|
||||
# order by distance DESC is not implemented in this version of sqlite-vec
|
||||
# when it's implemented, we can use cosine similarity
|
||||
sql_query += " ORDER BY distance"
|
||||
|
||||
parameters = (
|
||||
[serialize(query_embedding)] + event_ids
|
||||
if event_ids
|
||||
else [serialize(query_embedding)]
|
||||
)
|
||||
|
||||
results = self.db.execute_sql(sql_query, parameters).fetchall()
|
||||
|
||||
return results
|
||||
|
||||
def search_description(
|
||||
self, query_text: str, event_ids: list[str] = None
|
||||
) -> list[tuple[str, float]]:
|
||||
query_embedding = self.text_embedding([query_text])[0]
|
||||
|
||||
# Prepare the base SQL query
|
||||
sql_query = """
|
||||
SELECT
|
||||
id,
|
||||
distance
|
||||
FROM vec_descriptions
|
||||
WHERE description_embedding MATCH ?
|
||||
AND k = 100
|
||||
"""
|
||||
|
||||
# Add the IN clause if event_ids is provided and not empty
|
||||
# this is the only filter supported by sqlite-vec as of 0.1.3
|
||||
# but it seems to be broken in this version
|
||||
if event_ids:
|
||||
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
|
||||
|
||||
# order by distance DESC is not implemented in this version of sqlite-vec
|
||||
# when it's implemented, we can use cosine similarity
|
||||
sql_query += " ORDER BY distance"
|
||||
|
||||
parameters = (
|
||||
[serialize(query_embedding)] + event_ids
|
||||
if event_ids
|
||||
else [serialize(query_embedding)]
|
||||
)
|
||||
|
||||
results = self.db.execute_sql(sql_query, parameters).fetchall()
|
||||
|
||||
return results
|
||||
|
||||
@ -3,11 +3,8 @@
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import struct
|
||||
import time
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
|
||||
@ -17,6 +14,7 @@ from frigate.const import UPDATE_MODEL_STATE
|
||||
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
|
||||
from frigate.models import Event
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.builtin import serialize
|
||||
|
||||
from .functions.onnx import GenericONNXEmbedding
|
||||
|
||||
@ -54,30 +52,6 @@ def get_metadata(event: Event) -> dict:
|
||||
)
|
||||
|
||||
|
||||
def serialize(vector: Union[List[float], np.ndarray, float]) -> bytes:
|
||||
"""Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
|
||||
if isinstance(vector, np.ndarray):
|
||||
# Convert numpy array to list of floats
|
||||
vector = vector.flatten().tolist()
|
||||
elif isinstance(vector, (float, np.float32, np.float64)):
|
||||
# Handle single float values
|
||||
vector = [vector]
|
||||
elif not isinstance(vector, list):
|
||||
raise TypeError(
|
||||
f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
|
||||
)
|
||||
|
||||
try:
|
||||
return struct.pack("%sf" % len(vector), *vector)
|
||||
except struct.error as e:
|
||||
raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")
|
||||
|
||||
|
||||
def deserialize(bytes_data: bytes) -> List[float]:
|
||||
"""Deserializes a compact "raw bytes" format into a list of floats"""
|
||||
return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))
|
||||
|
||||
|
||||
class Embeddings:
|
||||
"""SQLite-vec embeddings database."""
|
||||
|
||||
@ -190,106 +164,6 @@ class Embeddings:
|
||||
|
||||
return embedding
|
||||
|
||||
def delete_thumbnail(self, event_ids: List[str]) -> None:
|
||||
ids = ",".join(["?" for _ in event_ids])
|
||||
self.db.execute_sql(
|
||||
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
|
||||
)
|
||||
|
||||
def delete_description(self, event_ids: List[str]) -> None:
|
||||
ids = ",".join(["?" for _ in event_ids])
|
||||
self.db.execute_sql(
|
||||
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
|
||||
)
|
||||
|
||||
def search_thumbnail(
|
||||
self, query: Union[Event, str], event_ids: List[str] = None
|
||||
) -> List[Tuple[str, float]]:
|
||||
if query.__class__ == Event:
|
||||
cursor = self.db.execute_sql(
|
||||
"""
|
||||
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
|
||||
""",
|
||||
[query.id],
|
||||
)
|
||||
|
||||
row = cursor.fetchone() if cursor else None
|
||||
|
||||
if row:
|
||||
query_embedding = deserialize(
|
||||
row[0]
|
||||
) # Deserialize the thumbnail embedding
|
||||
else:
|
||||
# If no embedding found, generate it and return it
|
||||
thumbnail = base64.b64decode(query.thumbnail)
|
||||
query_embedding = self.upsert_thumbnail(query.id, thumbnail)
|
||||
else:
|
||||
query_embedding = self.text_embedding([query])[0]
|
||||
|
||||
sql_query = """
|
||||
SELECT
|
||||
id,
|
||||
distance
|
||||
FROM vec_thumbnails
|
||||
WHERE thumbnail_embedding MATCH ?
|
||||
AND k = 100
|
||||
"""
|
||||
|
||||
# Add the IN clause if event_ids is provided and not empty
|
||||
# this is the only filter supported by sqlite-vec as of 0.1.3
|
||||
# but it seems to be broken in this version
|
||||
if event_ids:
|
||||
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
|
||||
|
||||
# order by distance DESC is not implemented in this version of sqlite-vec
|
||||
# when it's implemented, we can use cosine similarity
|
||||
sql_query += " ORDER BY distance"
|
||||
|
||||
parameters = (
|
||||
[serialize(query_embedding)] + event_ids
|
||||
if event_ids
|
||||
else [serialize(query_embedding)]
|
||||
)
|
||||
|
||||
results = self.db.execute_sql(sql_query, parameters).fetchall()
|
||||
|
||||
return results
|
||||
|
||||
def search_description(
|
||||
self, query_text: str, event_ids: List[str] = None
|
||||
) -> List[Tuple[str, float]]:
|
||||
query_embedding = self.text_embedding([query_text])[0]
|
||||
|
||||
# Prepare the base SQL query
|
||||
sql_query = """
|
||||
SELECT
|
||||
id,
|
||||
distance
|
||||
FROM vec_descriptions
|
||||
WHERE description_embedding MATCH ?
|
||||
AND k = 100
|
||||
"""
|
||||
|
||||
# Add the IN clause if event_ids is provided and not empty
|
||||
# this is the only filter supported by sqlite-vec as of 0.1.3
|
||||
# but it seems to be broken in this version
|
||||
if event_ids:
|
||||
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
|
||||
|
||||
# order by distance DESC is not implemented in this version of sqlite-vec
|
||||
# when it's implemented, we can use cosine similarity
|
||||
sql_query += " ORDER BY distance"
|
||||
|
||||
parameters = (
|
||||
[serialize(query_embedding)] + event_ids
|
||||
if event_ids
|
||||
else [serialize(query_embedding)]
|
||||
)
|
||||
|
||||
results = self.db.execute_sql(sql_query, parameters).fetchall()
|
||||
|
||||
return results
|
||||
|
||||
def reindex(self) -> None:
|
||||
logger.info("Indexing event embeddings...")
|
||||
|
||||
|
||||
@ -8,10 +8,11 @@ import multiprocessing as mp
|
||||
import queue
|
||||
import re
|
||||
import shlex
|
||||
import struct
|
||||
import urllib.parse
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Tuple
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pytz
|
||||
@ -342,3 +343,27 @@ def generate_color_palette(n):
|
||||
colors.append(interpolate(color1, color2, factor))
|
||||
|
||||
return colors
|
||||
|
||||
|
||||
def serialize(vector: Union[list[float], np.ndarray, float]) -> bytes:
|
||||
"""Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
|
||||
if isinstance(vector, np.ndarray):
|
||||
# Convert numpy array to list of floats
|
||||
vector = vector.flatten().tolist()
|
||||
elif isinstance(vector, (float, np.float32, np.float64)):
|
||||
# Handle single float values
|
||||
vector = [vector]
|
||||
elif not isinstance(vector, list):
|
||||
raise TypeError(
|
||||
f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
|
||||
)
|
||||
|
||||
try:
|
||||
return struct.pack("%sf" % len(vector), *vector)
|
||||
except struct.error as e:
|
||||
raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")
|
||||
|
||||
|
||||
def deserialize(bytes_data: bytes) -> list[float]:
|
||||
"""Deserializes a compact "raw bytes" format into a list of floats"""
|
||||
return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user