migrate embedding maintainer from chroma to sqlite_vec

This commit is contained in:
Josh Hawkins 2024-10-04 13:23:06 -05:00
parent f5aceece73
commit cb5b982b61

View File

@ -1,7 +1,6 @@
"""Maintain embeddings in Chroma."""
"""Maintain embeddings in SQLite-vec."""
import base64
import io
import logging
import os
import threading
@ -11,7 +10,7 @@ from typing import Optional
import cv2
import numpy as np
from peewee import DoesNotExist
from PIL import Image
from playhouse.sqliteq import SqliteQueueDatabase
from frigate.comms.event_metadata_updater import (
EventMetadataSubscriber,
@ -26,7 +25,7 @@ from frigate.genai import get_genai_client
from frigate.models import Event
from frigate.util.image import SharedMemoryFrameManager, calculate_region
from .embeddings import Embeddings, get_metadata
from .embeddings import Embeddings
logger = logging.getLogger(__name__)
@ -36,13 +35,14 @@ class EmbeddingMaintainer(threading.Thread):
def __init__(
self,
db: SqliteQueueDatabase,
config: FrigateConfig,
stop_event: MpEvent,
) -> None:
threading.Thread.__init__(self)
self.name = "embeddings_maintainer"
self.config = config
self.embeddings = Embeddings()
self.embeddings = Embeddings(db)
self.event_subscriber = EventUpdateSubscriber()
self.event_end_subscriber = EventEndSubscriber()
self.event_metadata_subscriber = EventMetadataSubscriber(
@ -56,7 +56,7 @@ class EmbeddingMaintainer(threading.Thread):
self.genai_client = get_genai_client(config.genai)
def run(self) -> None:
"""Maintain a Chroma vector database for semantic search."""
"""Maintain a SQLite-vec database for semantic search."""
while not self.stop_event.is_set():
self._process_updates()
self._process_finalized()
@ -117,12 +117,11 @@ class EmbeddingMaintainer(threading.Thread):
if event.data.get("type") != "object":
continue
# Extract valid event metadata
metadata = get_metadata(event)
# Extract valid thumbnail
thumbnail = base64.b64decode(event.thumbnail)
# Embed the thumbnail
self._embed_thumbnail(event_id, thumbnail, metadata)
self._embed_thumbnail(event_id, thumbnail)
if (
camera_config.genai.enabled
@ -183,7 +182,6 @@ class EmbeddingMaintainer(threading.Thread):
args=(
event,
embed_image,
metadata,
),
).start()
@ -219,25 +217,16 @@ class EmbeddingMaintainer(threading.Thread):
return None
def _embed_thumbnail(self, event_id: str, thumbnail: bytes, metadata: dict) -> None:
def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None:
"""Embed the thumbnail for an event."""
self.embeddings.upsert_thumbnail(event_id, thumbnail)
# Encode the thumbnail
img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
self.embeddings.thumbnail.upsert(
images=[img],
metadatas=[metadata],
ids=[event_id],
)
def _embed_description(
self, event: Event, thumbnails: list[bytes], metadata: dict
) -> None:
def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None:
"""Embed the description for an event."""
camera_config = self.config.cameras[event.camera]
description = self.genai_client.generate_description(
camera_config, thumbnails, metadata
camera_config, thumbnails, event.label
)
if not description:
@ -251,11 +240,7 @@ class EmbeddingMaintainer(threading.Thread):
)
# Encode the description
self.embeddings.description.upsert(
documents=[description],
metadatas=[metadata],
ids=[event.id],
)
self.embeddings.upsert_description(event.id, description)
logger.debug(
"Generated description for %s (%d images): %s",
@ -276,7 +261,6 @@ class EmbeddingMaintainer(threading.Thread):
logger.error(f"GenAI not enabled for camera {event.camera}")
return
metadata = get_metadata(event)
thumbnail = base64.b64decode(event.thumbnail)
logger.debug(
@ -315,4 +299,4 @@ class EmbeddingMaintainer(threading.Thread):
)
)
self._embed_description(event, embed_image, metadata)
self._embed_description(event, embed_image)