From cb5b982b61398a5707b7904f3bcde37ed3447ff7 Mon Sep 17 00:00:00 2001 From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com> Date: Fri, 4 Oct 2024 13:23:06 -0500 Subject: [PATCH] migrate embedding maintainer from chroma to sqlite_vec --- frigate/embeddings/maintainer.py | 44 ++++++++++---------------------- 1 file changed, 14 insertions(+), 30 deletions(-) diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index 3c3d956c8..4cb6a3bca 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -1,7 +1,6 @@ -"""Maintain embeddings in Chroma.""" +"""Maintain embeddings in SQLite-vec.""" import base64 -import io import logging import os import threading @@ -11,7 +10,7 @@ from typing import Optional import cv2 import numpy as np from peewee import DoesNotExist -from PIL import Image +from playhouse.sqliteq import SqliteQueueDatabase from frigate.comms.event_metadata_updater import ( EventMetadataSubscriber, @@ -26,7 +25,7 @@ from frigate.genai import get_genai_client from frigate.models import Event from frigate.util.image import SharedMemoryFrameManager, calculate_region -from .embeddings import Embeddings, get_metadata +from .embeddings import Embeddings logger = logging.getLogger(__name__) @@ -36,13 +35,14 @@ class EmbeddingMaintainer(threading.Thread): def __init__( self, + db: SqliteQueueDatabase, config: FrigateConfig, stop_event: MpEvent, ) -> None: threading.Thread.__init__(self) self.name = "embeddings_maintainer" self.config = config - self.embeddings = Embeddings() + self.embeddings = Embeddings(db) self.event_subscriber = EventUpdateSubscriber() self.event_end_subscriber = EventEndSubscriber() self.event_metadata_subscriber = EventMetadataSubscriber( @@ -56,7 +56,7 @@ class EmbeddingMaintainer(threading.Thread): self.genai_client = get_genai_client(config.genai) def run(self) -> None: - """Maintain a Chroma vector database for semantic search.""" + """Maintain a SQLite-vec database for semantic search.""" while not self.stop_event.is_set(): self._process_updates() self._process_finalized() @@ -117,12 +117,11 @@ class EmbeddingMaintainer(threading.Thread): if event.data.get("type") != "object": continue - # Extract valid event metadata - metadata = get_metadata(event) + # Extract valid thumbnail thumbnail = base64.b64decode(event.thumbnail) # Embed the thumbnail - self._embed_thumbnail(event_id, thumbnail, metadata) + self._embed_thumbnail(event_id, thumbnail) if ( camera_config.genai.enabled @@ -183,7 +182,6 @@ class EmbeddingMaintainer(threading.Thread): args=( event, embed_image, - metadata, ), ).start() @@ -219,25 +217,16 @@ class EmbeddingMaintainer(threading.Thread): return None - def _embed_thumbnail(self, event_id: str, thumbnail: bytes, metadata: dict) -> None: + def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None: """Embed the thumbnail for an event.""" + self.embeddings.upsert_thumbnail(event_id, thumbnail) - # Encode the thumbnail - img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB")) - self.embeddings.thumbnail.upsert( - images=[img], - metadatas=[metadata], - ids=[event_id], - ) - - def _embed_description( - self, event: Event, thumbnails: list[bytes], metadata: dict - ) -> None: + def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None: """Embed the description for an event.""" camera_config = self.config.cameras[event.camera] description = self.genai_client.generate_description( - camera_config, thumbnails, metadata + camera_config, thumbnails, event.label ) if not description: @@ -251,11 +240,7 @@ class EmbeddingMaintainer(threading.Thread): ) # Encode the description - self.embeddings.description.upsert( - documents=[description], - metadatas=[metadata], - ids=[event.id], - ) + self.embeddings.upsert_description(event.id, description) logger.debug( "Generated description for %s (%d images): %s", @@ -276,7 +261,6 @@ class EmbeddingMaintainer(threading.Thread): logger.error(f"GenAI not enabled for camera {event.camera}") return - metadata = get_metadata(event) thumbnail = base64.b64decode(event.thumbnail) logger.debug( @@ -315,4 +299,4 @@ class EmbeddingMaintainer(threading.Thread): ) ) - self._embed_description(event, embed_image, metadata) + self._embed_description(event, embed_image)