remove chroma and revamp Embeddings class for sqlite_vec

This commit is contained in:
Josh Hawkins 2024-10-04 13:20:24 -05:00
parent 139c8652a9
commit 654efe6be1

View File

@ -1,37 +1,20 @@
"""ChromaDB embeddings database."""
"""SQLite-vec embeddings database."""
import base64
import io
import logging
import sys
import struct
import time
from typing import List, Tuple
import numpy as np
from PIL import Image
from playhouse.shortcuts import model_to_dict
from playhouse.sqliteq import SqliteQueueDatabase
from frigate.models import Event
# Squelch posthog logging
logging.getLogger("chromadb.telemetry.product.posthog").setLevel(logging.CRITICAL)
# Hot-swap the sqlite3 module for Chroma compatibility
try:
from chromadb import Collection
from chromadb import HttpClient as ChromaClient
from chromadb.config import Settings
from .functions.clip import ClipEmbedding
from .functions.minilm_l6_v2 import MiniLMEmbedding
except RuntimeError:
__import__("pysqlite3")
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
from chromadb import Collection
from chromadb import HttpClient as ChromaClient
from chromadb.config import Settings
from .functions.clip import ClipEmbedding
from .functions.minilm_l6_v2 import MiniLMEmbedding
from .functions.clip import ClipEmbedding
from .functions.minilm_l6_v2 import MiniLMEmbedding
logger = logging.getLogger(__name__)
@ -67,34 +50,158 @@ def get_metadata(event: Event) -> dict:
)
def serialize(vector: List[float]) -> bytes:
"""Serializes a list of floats into a compact "raw bytes" format"""
return struct.pack("%sf" % len(vector), *vector)
def deserialize(bytes_data: bytes) -> List[float]:
"""Deserializes a compact "raw bytes" format into a list of floats"""
return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))
class Embeddings:
"""ChromaDB embeddings database."""
"""SQLite-vec embeddings database."""
def __init__(self) -> None:
self.client: ChromaClient = ChromaClient(
host="127.0.0.1",
settings=Settings(anonymized_telemetry=False),
def __init__(self, db: SqliteQueueDatabase) -> None:
self.conn = db.connection() # Store the database connection instance
# create tables if they don't exist
self._create_tables()
self.clip_embedding = ClipEmbedding(model="ViT-B/32")
self.minilm_embedding = MiniLMEmbedding(
preferred_providers=["CPUExecutionProvider"],
)
@property
def thumbnail(self) -> Collection:
return self.client.get_or_create_collection(
name="event_thumbnail", embedding_function=ClipEmbedding()
def _create_tables(self):
# Create vec0 virtual table for thumbnail embeddings
self.conn.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
id TEXT PRIMARY KEY,
thumbnail_embedding FLOAT[512]
);
""")
# Create vec0 virtual table for description embeddings
self.conn.execute("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
id TEXT PRIMARY KEY,
description_embedding FLOAT[384]
);
""")
def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
# Convert thumbnail bytes to PIL Image
image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
# Generate embedding using CLIP
embedding = self.clip_embedding([image])[0]
# sqlite_vec virtual tables don't support upsert, check if event_id exists
cursor = self.conn.execute(
"SELECT 1 FROM vec_thumbnails WHERE id = ?", (event_id,)
)
row = cursor.fetchone()
if row is None:
# Insert if the event_id does not exist
self.conn.execute(
"INSERT INTO vec_thumbnails(id, thumbnail_embedding) VALUES(?, ?)",
[event_id, serialize(embedding)],
)
else:
# Update if the event_id already exists
self.conn.execute(
"UPDATE vec_thumbnails SET thumbnail_embedding = ? WHERE id = ?",
[serialize(embedding), event_id],
)
@property
def description(self) -> Collection:
return self.client.get_or_create_collection(
name="event_description",
embedding_function=MiniLMEmbedding(
preferred_providers=["CPUExecutionProvider"]
),
def upsert_description(self, event_id: str, description: str):
# Generate embedding using MiniLM
embedding = self.minilm_embedding([description])[0]
# sqlite_vec virtual tables don't support upsert, check if event_id exists
cursor = self.conn.execute(
"SELECT 1 FROM vec_descriptions WHERE id = ?", (event_id,)
)
row = cursor.fetchone()
if row is None:
# Insert if the event_id does not exist
self.conn.execute(
"INSERT INTO vec_descriptions(id, description_embedding) VALUES(?, ?)",
[event_id, serialize(embedding)],
)
else:
# Update if the event_id already exists
self.conn.execute(
"UPDATE vec_descriptions SET description_embedding = ? WHERE id = ?",
[serialize(embedding), event_id],
)
def delete_thumbnail(self, event_ids: List[str]) -> None:
ids = ", ".join("?" for _ in event_ids)
self.conn.execute(
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", tuple(event_ids)
)
def delete_description(self, event_ids: List[str]) -> None:
ids = ", ".join("?" for _ in event_ids)
self.conn.execute(
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", tuple(event_ids)
)
def search_thumbnail(self, event_id: str, limit=10) -> List[Tuple[str, float]]:
# check if it's already embedded
cursor = self.conn.execute(
"SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?", (event_id,)
)
row = cursor.fetchone()
if row:
query_embedding = deserialize(row[0])
else:
# If not embedded, fetch the thumbnail from the Event table and embed it
event = Event.get_by_id(event_id)
thumbnail = base64.b64decode(event.thumbnail)
image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
query_embedding = self.clip_embedding([image])[0]
self.upsert_thumbnail(event_id, thumbnail)
cursor = self.conn.execute(
"""
SELECT
vec_thumbnails.id,
distance
FROM vec_thumbnails
WHERE thumbnail_embedding MATCH ?
AND k = ?
ORDER BY distance
""",
[serialize(query_embedding), limit],
)
return cursor.fetchall()
def search_description(self, query_text: str, limit=10) -> List[Tuple[str, float]]:
query_embedding = self.minilm_embedding([query_text])[0]
cursor = self.conn.execute(
"""
SELECT
vec_descriptions.id,
distance
FROM vec_descriptions
WHERE description_embedding MATCH ?
AND k = ?
ORDER BY distance
""",
[serialize(query_embedding), limit],
)
return cursor.fetchall()
def reindex(self) -> None:
"""Reindex all event embeddings."""
logger.info("Indexing event embeddings...")
self.client.reset()
st = time.time()
totals = {
@ -115,37 +222,14 @@ class Embeddings:
)
while len(events) > 0:
thumbnails = {"ids": [], "images": [], "metadatas": []}
descriptions = {"ids": [], "documents": [], "metadatas": []}
event: Event
for event in events:
metadata = get_metadata(event)
thumbnail = base64.b64decode(event.thumbnail)
img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
thumbnails["ids"].append(event.id)
thumbnails["images"].append(img)
thumbnails["metadatas"].append(metadata)
self.upsert_thumbnail(event.id, thumbnail)
totals["thumb"] += 1
if description := event.data.get("description", "").strip():
descriptions["ids"].append(event.id)
descriptions["documents"].append(description)
descriptions["metadatas"].append(metadata)
if len(thumbnails["ids"]) > 0:
totals["thumb"] += len(thumbnails["ids"])
self.thumbnail.upsert(
images=thumbnails["images"],
metadatas=thumbnails["metadatas"],
ids=thumbnails["ids"],
)
if len(descriptions["ids"]) > 0:
totals["desc"] += len(descriptions["ids"])
self.description.upsert(
documents=descriptions["documents"],
metadatas=descriptions["metadatas"],
ids=descriptions["ids"],
)
totals["desc"] += 1
self.upsert_description(event.id, description)
current_page += 1
events = (