mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 07:35:27 +03:00
remove chroma and revamp Embeddings class for sqlite_vec
This commit is contained in:
parent
139c8652a9
commit
654efe6be1
@ -1,37 +1,20 @@
|
|||||||
"""ChromaDB embeddings database."""
|
"""SQLite-vec embeddings database."""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import struct
|
||||||
import time
|
import time
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from playhouse.shortcuts import model_to_dict
|
from playhouse.shortcuts import model_to_dict
|
||||||
|
from playhouse.sqliteq import SqliteQueueDatabase
|
||||||
|
|
||||||
from frigate.models import Event
|
from frigate.models import Event
|
||||||
|
|
||||||
# Squelch posthog logging
|
from .functions.clip import ClipEmbedding
|
||||||
logging.getLogger("chromadb.telemetry.product.posthog").setLevel(logging.CRITICAL)
|
from .functions.minilm_l6_v2 import MiniLMEmbedding
|
||||||
|
|
||||||
# Hot-swap the sqlite3 module for Chroma compatibility
|
|
||||||
try:
|
|
||||||
from chromadb import Collection
|
|
||||||
from chromadb import HttpClient as ChromaClient
|
|
||||||
from chromadb.config import Settings
|
|
||||||
|
|
||||||
from .functions.clip import ClipEmbedding
|
|
||||||
from .functions.minilm_l6_v2 import MiniLMEmbedding
|
|
||||||
except RuntimeError:
|
|
||||||
__import__("pysqlite3")
|
|
||||||
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
|
|
||||||
from chromadb import Collection
|
|
||||||
from chromadb import HttpClient as ChromaClient
|
|
||||||
from chromadb.config import Settings
|
|
||||||
|
|
||||||
from .functions.clip import ClipEmbedding
|
|
||||||
from .functions.minilm_l6_v2 import MiniLMEmbedding
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -67,34 +50,158 @@ def get_metadata(event: Event) -> dict:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def serialize(vector: List[float]) -> bytes:
    """Pack a sequence of floats into a contiguous blob of 32-bit floats.

    Each element is written with the ``struct`` ``f`` format code in the
    platform's native byte order, so the result is ``4 * len(vector)`` bytes.
    """
    layout = f"{len(vector)}f"
    return struct.pack(layout, *vector)
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize(bytes_data: bytes) -> List[float]:
    """Unpack a blob of native-order 32-bit floats into a list of floats.

    The element count is derived from the blob length (4 bytes per float);
    ``struct.unpack`` raises ``struct.error`` if the length does not match.
    """
    float_count = len(bytes_data) // 4
    values = struct.unpack(f"{float_count}f", bytes_data)
    return list(values)
|
||||||
|
|
||||||
|
|
||||||
class Embeddings:
|
class Embeddings:
|
||||||
"""ChromaDB embeddings database."""
|
"""SQLite-vec embeddings database."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, db: SqliteQueueDatabase) -> None:
|
||||||
self.client: ChromaClient = ChromaClient(
|
self.conn = db.connection() # Store the database connection instance
|
||||||
host="127.0.0.1",
|
|
||||||
settings=Settings(anonymized_telemetry=False),
|
# create tables if they don't exist
|
||||||
|
self._create_tables()
|
||||||
|
|
||||||
|
self.clip_embedding = ClipEmbedding(model="ViT-B/32")
|
||||||
|
self.minilm_embedding = MiniLMEmbedding(
|
||||||
|
preferred_providers=["CPUExecutionProvider"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
def _create_tables(self):
|
||||||
def thumbnail(self) -> Collection:
|
# Create vec0 virtual table for thumbnail embeddings
|
||||||
return self.client.get_or_create_collection(
|
self.conn.execute("""
|
||||||
name="event_thumbnail", embedding_function=ClipEmbedding()
|
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
thumbnail_embedding FLOAT[512]
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Create vec0 virtual table for description embeddings
|
||||||
|
self.conn.execute("""
|
||||||
|
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
description_embedding FLOAT[384]
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
    """Embed an event thumbnail with CLIP and store it in ``vec_thumbnails``.

    Args:
        event_id: Key of the row to create or overwrite.
        thumbnail: Encoded image bytes; they are opened with PIL and
            converted to RGB before being embedded.
    """
    rgb_image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
    embedding = self.clip_embedding([rgb_image])[0]

    # sqlite_vec virtual tables don't support upsert, so look the id up first
    # and pick INSERT or UPDATE accordingly.
    already_stored = (
        self.conn.execute(
            "SELECT 1 FROM vec_thumbnails WHERE id = ?", (event_id,)
        ).fetchone()
        is not None
    )

    if already_stored:
        self.conn.execute(
            "UPDATE vec_thumbnails SET thumbnail_embedding = ? WHERE id = ?",
            [serialize(embedding), event_id],
        )
        return

    self.conn.execute(
        "INSERT INTO vec_thumbnails(id, thumbnail_embedding) VALUES(?, ?)",
        [event_id, serialize(embedding)],
    )
|
||||||
|
|
||||||
|
def upsert_description(self, event_id: str, description: str):
    """Embed an event description with MiniLM and store it in ``vec_descriptions``.

    Args:
        event_id: Key of the row to create or overwrite.
        description: Text to embed.
    """
    embedding = self.minilm_embedding([description])[0]

    # sqlite_vec virtual tables don't support upsert, so look the id up first
    # and pick INSERT or UPDATE accordingly.
    already_stored = (
        self.conn.execute(
            "SELECT 1 FROM vec_descriptions WHERE id = ?", (event_id,)
        ).fetchone()
        is not None
    )

    if already_stored:
        self.conn.execute(
            "UPDATE vec_descriptions SET description_embedding = ? WHERE id = ?",
            [serialize(embedding), event_id],
        )
        return

    self.conn.execute(
        "INSERT INTO vec_descriptions(id, description_embedding) VALUES(?, ?)",
        [event_id, serialize(embedding)],
    )
|
||||||
|
|
||||||
|
def delete_thumbnail(self, event_ids: List[str]) -> None:
|
||||||
|
ids = ", ".join("?" for _ in event_ids)
|
||||||
|
|
||||||
|
self.conn.execute(
|
||||||
|
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", tuple(event_ids)
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
def delete_description(self, event_ids: List[str]) -> None:
|
||||||
def description(self) -> Collection:
|
ids = ", ".join("?" for _ in event_ids)
|
||||||
return self.client.get_or_create_collection(
|
|
||||||
name="event_description",
|
self.conn.execute(
|
||||||
embedding_function=MiniLMEmbedding(
|
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", tuple(event_ids)
|
||||||
preferred_providers=["CPUExecutionProvider"]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def search_thumbnail(self, event_id: str, limit=10) -> List[Tuple[str, float]]:
    """Return the events whose thumbnail embeddings are nearest to an event's.

    If the event has no stored embedding yet, its thumbnail is loaded from
    the ``Event`` table, embedded and stored via ``upsert_thumbnail``, and the
    stored vector is then read back. The thumbnail is therefore embedded only
    once on a miss, instead of once for the query and once for the insert.

    Args:
        event_id: Id of the event whose thumbnail is the query.
        limit: Value bound to the ``k`` constraint of the vec0 query.

    Returns:
        ``(id, distance)`` rows ordered by ascending distance.

    Raises:
        RuntimeError: If the embedding is still missing after it was stored.
    """
    lookup_sql = "SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?"

    # check if it's already embedded
    row = self.conn.execute(lookup_sql, (event_id,)).fetchone()

    if row is None:
        # Not embedded yet: fetch the thumbnail from the Event table, let
        # upsert_thumbnail embed and store it, then reuse the stored vector.
        event = Event.get_by_id(event_id)
        thumbnail = base64.b64decode(event.thumbnail)
        self.upsert_thumbnail(event_id, thumbnail)

        row = self.conn.execute(lookup_sql, (event_id,)).fetchone()
        if row is None:
            raise RuntimeError(
                f"Thumbnail embedding for event {event_id} was not stored"
            )

    query_embedding = deserialize(row[0])

    cursor = self.conn.execute(
        """
        SELECT
            vec_thumbnails.id,
            distance
        FROM vec_thumbnails
        WHERE thumbnail_embedding MATCH ?
            AND k = ?
        ORDER BY distance
        """,
        [serialize(query_embedding), limit],
    )
    return cursor.fetchall()
|
||||||
|
|
||||||
|
def search_description(self, query_text: str, limit=10) -> List[Tuple[str, float]]:
    """Return the events whose description embeddings are nearest to a text query.

    Args:
        query_text: Free text; it is embedded with MiniLM to form the query.
        limit: Value bound to the ``k`` constraint of the vec0 query.

    Returns:
        ``(id, distance)`` rows ordered by ascending distance.
    """
    text_vector = self.minilm_embedding([query_text])[0]
    knn_sql = """
        SELECT
            vec_descriptions.id,
            distance
        FROM vec_descriptions
        WHERE description_embedding MATCH ?
            AND k = ?
        ORDER BY distance
        """
    matches = self.conn.execute(knn_sql, [serialize(text_vector), limit])
    return matches.fetchall()
|
||||||
|
|
||||||
def reindex(self) -> None:
|
def reindex(self) -> None:
|
||||||
"""Reindex all event embeddings."""
|
"""Reindex all event embeddings."""
|
||||||
logger.info("Indexing event embeddings...")
|
logger.info("Indexing event embeddings...")
|
||||||
self.client.reset()
|
|
||||||
|
|
||||||
st = time.time()
|
st = time.time()
|
||||||
totals = {
|
totals = {
|
||||||
@ -115,37 +222,14 @@ class Embeddings:
|
|||||||
)
|
)
|
||||||
|
|
||||||
while len(events) > 0:
|
while len(events) > 0:
|
||||||
thumbnails = {"ids": [], "images": [], "metadatas": []}
|
|
||||||
descriptions = {"ids": [], "documents": [], "metadatas": []}
|
|
||||||
|
|
||||||
event: Event
|
event: Event
|
||||||
for event in events:
|
for event in events:
|
||||||
metadata = get_metadata(event)
|
|
||||||
thumbnail = base64.b64decode(event.thumbnail)
|
thumbnail = base64.b64decode(event.thumbnail)
|
||||||
img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
|
self.upsert_thumbnail(event.id, thumbnail)
|
||||||
thumbnails["ids"].append(event.id)
|
totals["thumb"] += 1
|
||||||
thumbnails["images"].append(img)
|
|
||||||
thumbnails["metadatas"].append(metadata)
|
|
||||||
if description := event.data.get("description", "").strip():
|
if description := event.data.get("description", "").strip():
|
||||||
descriptions["ids"].append(event.id)
|
totals["desc"] += 1
|
||||||
descriptions["documents"].append(description)
|
self.upsert_description(event.id, description)
|
||||||
descriptions["metadatas"].append(metadata)
|
|
||||||
|
|
||||||
if len(thumbnails["ids"]) > 0:
|
|
||||||
totals["thumb"] += len(thumbnails["ids"])
|
|
||||||
self.thumbnail.upsert(
|
|
||||||
images=thumbnails["images"],
|
|
||||||
metadatas=thumbnails["metadatas"],
|
|
||||||
ids=thumbnails["ids"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(descriptions["ids"]) > 0:
|
|
||||||
totals["desc"] += len(descriptions["ids"])
|
|
||||||
self.description.upsert(
|
|
||||||
documents=descriptions["documents"],
|
|
||||||
metadatas=descriptions["metadatas"],
|
|
||||||
ids=descriptions["ids"],
|
|
||||||
)
|
|
||||||
|
|
||||||
current_page += 1
|
current_page += 1
|
||||||
events = (
|
events = (
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user