remove chroma and revamp Embeddings class for sqlite_vec

This commit is contained in:
Josh Hawkins 2024-10-04 13:20:24 -05:00
parent 139c8652a9
commit 654efe6be1

View File

@ -1,35 +1,18 @@
"""ChromaDB embeddings database.""" """SQLite-vec embeddings database."""
import base64 import base64
import io import io
import logging import logging
import sys import struct
import time import time
from typing import List, Tuple
import numpy as np
from PIL import Image from PIL import Image
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from playhouse.sqliteq import SqliteQueueDatabase
from frigate.models import Event from frigate.models import Event
# Squelch posthog logging
logging.getLogger("chromadb.telemetry.product.posthog").setLevel(logging.CRITICAL)
# Hot-swap the sqlite3 module for Chroma compatibility
try:
from chromadb import Collection
from chromadb import HttpClient as ChromaClient
from chromadb.config import Settings
from .functions.clip import ClipEmbedding
from .functions.minilm_l6_v2 import MiniLMEmbedding
except RuntimeError:
__import__("pysqlite3")
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
from chromadb import Collection
from chromadb import HttpClient as ChromaClient
from chromadb.config import Settings
from .functions.clip import ClipEmbedding from .functions.clip import ClipEmbedding
from .functions.minilm_l6_v2 import MiniLMEmbedding from .functions.minilm_l6_v2 import MiniLMEmbedding
@ -67,34 +50,158 @@ def get_metadata(event: Event) -> dict:
) )
def serialize(vector: List[float]) -> bytes:
    """Pack a sequence of floats into a compact raw-bytes blob.

    Each element is written with the ``struct`` ``"f"`` format (32-bit
    float, native byte order), so the result is ``4 * len(vector)`` bytes.
    An empty sequence yields ``b""``.

    Raises:
        struct.error: if an element cannot be packed as a float.
        OverflowError: if an element is too large for a 32-bit float.
    """
    return struct.pack(f"{len(vector)}f", *vector)
def deserialize(bytes_data: bytes) -> List[float]:
    """Unpack a raw-bytes blob of 32-bit floats into a list of floats.

    This is the inverse of packing with the ``struct`` ``"f"`` format in
    native byte order. An empty blob yields ``[]``.

    Raises:
        struct.error: if the blob length is not a whole number of floats.
    """
    # Derive the element width from the format instead of hard-coding 4.
    count = len(bytes_data) // struct.calcsize("f")
    return list(struct.unpack(f"{count}f", bytes_data))
class Embeddings: class Embeddings:
"""ChromaDB embeddings database.""" """SQLite-vec embeddings database."""
def __init__(self) -> None: def __init__(self, db: SqliteQueueDatabase) -> None:
self.client: ChromaClient = ChromaClient( self.conn = db.connection() # Store the database connection instance
host="127.0.0.1",
settings=Settings(anonymized_telemetry=False), # create tables if they don't exist
self._create_tables()
self.clip_embedding = ClipEmbedding(model="ViT-B/32")
self.minilm_embedding = MiniLMEmbedding(
preferred_providers=["CPUExecutionProvider"],
) )
@property def _create_tables(self):
def thumbnail(self) -> Collection: # Create vec0 virtual table for thumbnail embeddings
return self.client.get_or_create_collection( self.conn.execute("""
name="event_thumbnail", embedding_function=ClipEmbedding() CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
id TEXT PRIMARY KEY,
thumbnail_embedding FLOAT[512]
);
""")
# Create vec0 virtual table for description embeddings
self.conn.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
id TEXT PRIMARY KEY,
description_embedding FLOAT[384]
);
""")
def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
    """Embed a thumbnail with CLIP and store the vector under ``event_id``.

    The raw image bytes are decoded to an RGB image, embedded, and written
    to ``vec_thumbnails``: updated when a row for ``event_id`` already
    exists, inserted otherwise.
    """
    rgb_image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
    vector = self.clip_embedding([rgb_image])[0]

    # sqlite_vec virtual tables don't support upsert, so probe for the id
    # first and pick the matching statement.
    existing = self.conn.execute(
        "SELECT 1 FROM vec_thumbnails WHERE id = ?", (event_id,)
    ).fetchone()

    if existing is not None:
        self.conn.execute(
            "UPDATE vec_thumbnails SET thumbnail_embedding = ? WHERE id = ?",
            [serialize(vector), event_id],
        )
        return

    self.conn.execute(
        "INSERT INTO vec_thumbnails(id, thumbnail_embedding) VALUES(?, ?)",
        [event_id, serialize(vector)],
    )
@property def upsert_description(self, event_id: str, description: str):
def description(self) -> Collection: # Generate embedding using MiniLM
return self.client.get_or_create_collection( embedding = self.minilm_embedding([description])[0]
name="event_description",
embedding_function=MiniLMEmbedding( # sqlite_vec virtual tables don't support upsert, check if event_id exists
preferred_providers=["CPUExecutionProvider"] cursor = self.conn.execute(
), "SELECT 1 FROM vec_descriptions WHERE id = ?", (event_id,)
) )
row = cursor.fetchone()
if row is None:
# Insert if the event_id does not exist
self.conn.execute(
"INSERT INTO vec_descriptions(id, description_embedding) VALUES(?, ?)",
[event_id, serialize(embedding)],
)
else:
# Update if the event_id already exists
self.conn.execute(
"UPDATE vec_descriptions SET description_embedding = ? WHERE id = ?",
[serialize(embedding), event_id],
)
def delete_thumbnail(self, event_ids: List[str]) -> None:
    """Delete the thumbnail embeddings stored for the given event ids.

    Ids that have no row are ignored by the DELETE. An empty list is a
    no-op: no statement is sent rather than a ``DELETE ... IN ()``.
    """
    if not event_ids:
        return

    # One "?" placeholder per id keeps the statement fully parameterized.
    placeholders = ", ".join("?" * len(event_ids))
    self.conn.execute(
        f"DELETE FROM vec_thumbnails WHERE id IN ({placeholders})",
        tuple(event_ids),
    )
def delete_description(self, event_ids: List[str]) -> None:
    """Delete the description embeddings stored for the given event ids.

    Ids that have no row are ignored by the DELETE. An empty list is a
    no-op: no statement is sent rather than a ``DELETE ... IN ()``.
    """
    if not event_ids:
        return

    # One "?" placeholder per id keeps the statement fully parameterized.
    placeholders = ", ".join("?" * len(event_ids))
    self.conn.execute(
        f"DELETE FROM vec_descriptions WHERE id IN ({placeholders})",
        tuple(event_ids),
    )
def search_thumbnail(self, event_id: str, limit=10) -> List[Tuple[str, float]]:
    """Find stored thumbnails closest to the thumbnail of ``event_id``.

    The query vector is the embedding stored for ``event_id`` in
    ``vec_thumbnails``. When no row exists yet, the thumbnail is loaded
    from the Event table, embedded and stored through ``upsert_thumbnail``,
    and the stored vector is read back. The embedding model therefore runs
    once, not once for the query and again for the upsert.

    Args:
        event_id: id of the event whose thumbnail is the query.
        limit: value bound to the ``k`` constraint of the search.

    Returns:
        The ``(id, distance)`` rows of the search, ordered by distance.
    """
    lookup_sql = "SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?"
    row = self.conn.execute(lookup_sql, (event_id,)).fetchone()

    if row is None:
        # Not embedded yet: embed and store it once, then reuse that vector.
        event = Event.get_by_id(event_id)
        self.upsert_thumbnail(event_id, base64.b64decode(event.thumbnail))
        row = self.conn.execute(lookup_sql, (event_id,)).fetchone()

    # The stored blob is already in the packed form the MATCH parameter
    # takes, so it is bound as-is instead of being unpacked and repacked.
    query_blob = row[0]

    cursor = self.conn.execute(
        """
        SELECT
            vec_thumbnails.id,
            distance
        FROM vec_thumbnails
        WHERE thumbnail_embedding MATCH ?
            AND k = ?
        ORDER BY distance
        """,
        [query_blob, limit],
    )
    return cursor.fetchall()
def search_description(self, query_text: str, limit=10) -> List[Tuple[str, float]]:
    """Find stored descriptions closest to ``query_text``.

    The text is embedded with the MiniLM embedding function, packed with
    ``serialize`` and matched against ``vec_descriptions``. ``limit`` is
    bound to the ``k`` constraint of the search.

    Returns:
        The ``(id, distance)`` rows of the search, ordered by distance.
    """
    embeddings = self.minilm_embedding([query_text])
    params = [serialize(embeddings[0]), limit]

    knn_sql = (
        "\n"
        "SELECT\n"
        "vec_descriptions.id,\n"
        "distance\n"
        "FROM vec_descriptions\n"
        "WHERE description_embedding MATCH ?\n"
        "AND k = ?\n"
        "ORDER BY distance\n"
    )
    return self.conn.execute(knn_sql, params).fetchall()
def reindex(self) -> None: def reindex(self) -> None:
"""Reindex all event embeddings.""" """Reindex all event embeddings."""
logger.info("Indexing event embeddings...") logger.info("Indexing event embeddings...")
self.client.reset()
st = time.time() st = time.time()
totals = { totals = {
@ -115,37 +222,14 @@ class Embeddings:
) )
while len(events) > 0: while len(events) > 0:
thumbnails = {"ids": [], "images": [], "metadatas": []}
descriptions = {"ids": [], "documents": [], "metadatas": []}
event: Event event: Event
for event in events: for event in events:
metadata = get_metadata(event)
thumbnail = base64.b64decode(event.thumbnail) thumbnail = base64.b64decode(event.thumbnail)
img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB")) self.upsert_thumbnail(event.id, thumbnail)
thumbnails["ids"].append(event.id) totals["thumb"] += 1
thumbnails["images"].append(img)
thumbnails["metadatas"].append(metadata)
if description := event.data.get("description", "").strip(): if description := event.data.get("description", "").strip():
descriptions["ids"].append(event.id) totals["desc"] += 1
descriptions["documents"].append(description) self.upsert_description(event.id, description)
descriptions["metadatas"].append(metadata)
if len(thumbnails["ids"]) > 0:
totals["thumb"] += len(thumbnails["ids"])
self.thumbnail.upsert(
images=thumbnails["images"],
metadatas=thumbnails["metadatas"],
ids=thumbnails["ids"],
)
if len(descriptions["ids"]) > 0:
totals["desc"] += len(descriptions["ids"])
self.description.upsert(
documents=descriptions["documents"],
metadatas=descriptions["metadatas"],
ids=descriptions["ids"],
)
current_page += 1 current_page += 1
events = ( events = (