From 654efe6be19908e3efabf8d18d1bc5ed2bca9f78 Mon Sep 17 00:00:00 2001 From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com> Date: Fri, 4 Oct 2024 13:20:24 -0500 Subject: [PATCH] remove chroma and revamp Embeddings class for sqlite_vec --- frigate/embeddings/embeddings.py | 218 +++++++++++++++++++++---------- 1 file changed, 151 insertions(+), 67 deletions(-) diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py index 540764c1b..52c429025 100644 --- a/frigate/embeddings/embeddings.py +++ b/frigate/embeddings/embeddings.py @@ -1,37 +1,20 @@ -"""ChromaDB embeddings database.""" +"""SQLite-vec embeddings database.""" import base64 import io import logging -import sys +import struct import time +from typing import List, Tuple -import numpy as np from PIL import Image from playhouse.shortcuts import model_to_dict +from playhouse.sqliteq import SqliteQueueDatabase from frigate.models import Event -# Squelch posthog logging -logging.getLogger("chromadb.telemetry.product.posthog").setLevel(logging.CRITICAL) - -# Hot-swap the sqlite3 module for Chroma compatibility -try: - from chromadb import Collection - from chromadb import HttpClient as ChromaClient - from chromadb.config import Settings - - from .functions.clip import ClipEmbedding - from .functions.minilm_l6_v2 import MiniLMEmbedding -except RuntimeError: - __import__("pysqlite3") - sys.modules["sqlite3"] = sys.modules.pop("pysqlite3") - from chromadb import Collection - from chromadb import HttpClient as ChromaClient - from chromadb.config import Settings - - from .functions.clip import ClipEmbedding - from .functions.minilm_l6_v2 import MiniLMEmbedding +from .functions.clip import ClipEmbedding +from .functions.minilm_l6_v2 import MiniLMEmbedding logger = logging.getLogger(__name__) @@ -67,34 +50,158 @@ def get_metadata(event: Event) -> dict: ) +def serialize(vector: List[float]) -> bytes: + """Serializes a list of floats into a compact "raw bytes" format""" + return struct.pack("%sf" % len(vector), *vector) + + +def deserialize(bytes_data: bytes) -> List[float]: + """Deserializes a compact "raw bytes" format into a list of floats""" + return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data)) + + class Embeddings: - """ChromaDB embeddings database.""" + """SQLite-vec embeddings database.""" - def __init__(self) -> None: - self.client: ChromaClient = ChromaClient( - host="127.0.0.1", - settings=Settings(anonymized_telemetry=False), + def __init__(self, db: SqliteQueueDatabase) -> None: + self.conn = db.connection() # Store the database connection instance + + # create tables if they don't exist + self._create_tables() + + self.clip_embedding = ClipEmbedding(model="ViT-B/32") + self.minilm_embedding = MiniLMEmbedding( + preferred_providers=["CPUExecutionProvider"], ) - @property - def thumbnail(self) -> Collection: - return self.client.get_or_create_collection( - name="event_thumbnail", embedding_function=ClipEmbedding() + def _create_tables(self): + # Create vec0 virtual table for thumbnail embeddings + self.conn.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0( + id TEXT PRIMARY KEY, + thumbnail_embedding FLOAT[512] + ); + """) + + # Create vec0 virtual table for description embeddings + self.conn.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0( + id TEXT PRIMARY KEY, + description_embedding FLOAT[384] + ); + """) + + def upsert_thumbnail(self, event_id: str, thumbnail: bytes): + # Convert thumbnail bytes to PIL Image + image = Image.open(io.BytesIO(thumbnail)).convert("RGB") + # Generate embedding using CLIP + embedding = self.clip_embedding([image])[0] + + # sqlite_vec virtual tables don't support upsert, check if event_id exists + cursor = self.conn.execute( + "SELECT 1 FROM vec_thumbnails WHERE id = ?", (event_id,) + ) + row = cursor.fetchone() + + if row is None: + # Insert if the event_id does not exist + self.conn.execute( + "INSERT INTO vec_thumbnails(id, thumbnail_embedding) VALUES(?, ?)", + [event_id, serialize(embedding)], + ) + else: + # Update if the event_id already exists + self.conn.execute( + "UPDATE vec_thumbnails SET thumbnail_embedding = ? WHERE id = ?", + [serialize(embedding), event_id], + ) + + def upsert_description(self, event_id: str, description: str): + # Generate embedding using MiniLM + embedding = self.minilm_embedding([description])[0] + + # sqlite_vec virtual tables don't support upsert, check if event_id exists + cursor = self.conn.execute( + "SELECT 1 FROM vec_descriptions WHERE id = ?", (event_id,) + ) + row = cursor.fetchone() + + if row is None: + # Insert if the event_id does not exist + self.conn.execute( + "INSERT INTO vec_descriptions(id, description_embedding) VALUES(?, ?)", + [event_id, serialize(embedding)], + ) + else: + # Update if the event_id already exists + self.conn.execute( + "UPDATE vec_descriptions SET description_embedding = ? WHERE id = ?", + [serialize(embedding), event_id], + ) + + def delete_thumbnail(self, event_ids: List[str]) -> None: + ids = ", ".join("?" for _ in event_ids) + + self.conn.execute( + f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", tuple(event_ids) ) - @property - def description(self) -> Collection: - return self.client.get_or_create_collection( - name="event_description", - embedding_function=MiniLMEmbedding( - preferred_providers=["CPUExecutionProvider"] - ), + def delete_description(self, event_ids: List[str]) -> None: + ids = ", ".join("?" for _ in event_ids) + + self.conn.execute( + f"DELETE FROM vec_descriptions WHERE id IN ({ids})", tuple(event_ids) ) + def search_thumbnail(self, event_id: str, limit=10) -> List[Tuple[str, float]]: + # check if it's already embedded + cursor = self.conn.execute( + "SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?", (event_id,) + ) + row = cursor.fetchone() + if row: + query_embedding = deserialize(row[0]) + else: + # If not embedded, fetch the thumbnail from the Event table and embed it + event = Event.get_by_id(event_id) + thumbnail = base64.b64decode(event.thumbnail) + image = Image.open(io.BytesIO(thumbnail)).convert("RGB") + query_embedding = self.clip_embedding([image])[0] + self.upsert_thumbnail(event_id, thumbnail) + + cursor = self.conn.execute( + """ + SELECT + vec_thumbnails.id, + distance + FROM vec_thumbnails + WHERE thumbnail_embedding MATCH ? + AND k = ? + ORDER BY distance + """, + [serialize(query_embedding), limit], + ) + return cursor.fetchall() + + def search_description(self, query_text: str, limit=10) -> List[Tuple[str, float]]: + query_embedding = self.minilm_embedding([query_text])[0] + cursor = self.conn.execute( + """ + SELECT + vec_descriptions.id, + distance + FROM vec_descriptions + WHERE description_embedding MATCH ? + AND k = ? + ORDER BY distance + """, + [serialize(query_embedding), limit], + ) + return cursor.fetchall() + def reindex(self) -> None: """Reindex all event embeddings.""" logger.info("Indexing event embeddings...") - self.client.reset() st = time.time() totals = { @@ -115,37 +222,14 @@ class Embeddings: ) while len(events) > 0: - thumbnails = {"ids": [], "images": [], "metadatas": []} - descriptions = {"ids": [], "documents": [], "metadatas": []} - event: Event for event in events: - metadata = get_metadata(event) thumbnail = base64.b64decode(event.thumbnail) - img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB")) - thumbnails["ids"].append(event.id) - thumbnails["images"].append(img) - thumbnails["metadatas"].append(metadata) + self.upsert_thumbnail(event.id, thumbnail) + totals["thumb"] += 1 if description := event.data.get("description", "").strip(): - descriptions["ids"].append(event.id) - descriptions["documents"].append(description) - descriptions["metadatas"].append(metadata) - - if len(thumbnails["ids"]) > 0: - totals["thumb"] += len(thumbnails["ids"]) - self.thumbnail.upsert( - images=thumbnails["images"], - metadatas=thumbnails["metadatas"], - ids=thumbnails["ids"], - ) - - if len(descriptions["ids"]) > 0: - totals["desc"] += len(descriptions["ids"]) - self.description.upsert( - documents=descriptions["documents"], - metadatas=descriptions["metadatas"], - ids=descriptions["ids"], - ) + totals["desc"] += 1 + self.upsert_description(event.id, description) current_page += 1 events = (