migrate embedding maintainer from chroma to sqlite_vec

This commit is contained in:
Josh Hawkins 2024-10-04 13:23:06 -05:00
parent f5aceece73
commit cb5b982b61

View File

@ -1,7 +1,6 @@
"""Maintain embeddings in Chroma.""" """Maintain embeddings in SQLite-vec."""
import base64 import base64
import io
import logging import logging
import os import os
import threading import threading
@ -11,7 +10,7 @@ from typing import Optional
import cv2 import cv2
import numpy as np import numpy as np
from peewee import DoesNotExist from peewee import DoesNotExist
from PIL import Image from playhouse.sqliteq import SqliteQueueDatabase
from frigate.comms.event_metadata_updater import ( from frigate.comms.event_metadata_updater import (
EventMetadataSubscriber, EventMetadataSubscriber,
@ -26,7 +25,7 @@ from frigate.genai import get_genai_client
from frigate.models import Event from frigate.models import Event
from frigate.util.image import SharedMemoryFrameManager, calculate_region from frigate.util.image import SharedMemoryFrameManager, calculate_region
from .embeddings import Embeddings, get_metadata from .embeddings import Embeddings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,13 +35,14 @@ class EmbeddingMaintainer(threading.Thread):
def __init__( def __init__(
self, self,
db: SqliteQueueDatabase,
config: FrigateConfig, config: FrigateConfig,
stop_event: MpEvent, stop_event: MpEvent,
) -> None: ) -> None:
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.name = "embeddings_maintainer" self.name = "embeddings_maintainer"
self.config = config self.config = config
self.embeddings = Embeddings() self.embeddings = Embeddings(db)
self.event_subscriber = EventUpdateSubscriber() self.event_subscriber = EventUpdateSubscriber()
self.event_end_subscriber = EventEndSubscriber() self.event_end_subscriber = EventEndSubscriber()
self.event_metadata_subscriber = EventMetadataSubscriber( self.event_metadata_subscriber = EventMetadataSubscriber(
@ -56,7 +56,7 @@ class EmbeddingMaintainer(threading.Thread):
self.genai_client = get_genai_client(config.genai) self.genai_client = get_genai_client(config.genai)
def run(self) -> None: def run(self) -> None:
"""Maintain a Chroma vector database for semantic search.""" """Maintain a SQLite-vec database for semantic search."""
while not self.stop_event.is_set(): while not self.stop_event.is_set():
self._process_updates() self._process_updates()
self._process_finalized() self._process_finalized()
@ -117,12 +117,11 @@ class EmbeddingMaintainer(threading.Thread):
if event.data.get("type") != "object": if event.data.get("type") != "object":
continue continue
# Extract valid event metadata # Extract valid thumbnail
metadata = get_metadata(event)
thumbnail = base64.b64decode(event.thumbnail) thumbnail = base64.b64decode(event.thumbnail)
# Embed the thumbnail # Embed the thumbnail
self._embed_thumbnail(event_id, thumbnail, metadata) self._embed_thumbnail(event_id, thumbnail)
if ( if (
camera_config.genai.enabled camera_config.genai.enabled
@ -183,7 +182,6 @@ class EmbeddingMaintainer(threading.Thread):
args=( args=(
event, event,
embed_image, embed_image,
metadata,
), ),
).start() ).start()
@ -219,25 +217,16 @@ class EmbeddingMaintainer(threading.Thread):
return None return None
def _embed_thumbnail(self, event_id: str, thumbnail: bytes, metadata: dict) -> None: def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None:
"""Embed the thumbnail for an event.""" """Embed the thumbnail for an event."""
self.embeddings.upsert_thumbnail(event_id, thumbnail)
# Encode the thumbnail def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None:
img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
self.embeddings.thumbnail.upsert(
images=[img],
metadatas=[metadata],
ids=[event_id],
)
def _embed_description(
self, event: Event, thumbnails: list[bytes], metadata: dict
) -> None:
"""Embed the description for an event.""" """Embed the description for an event."""
camera_config = self.config.cameras[event.camera] camera_config = self.config.cameras[event.camera]
description = self.genai_client.generate_description( description = self.genai_client.generate_description(
camera_config, thumbnails, metadata camera_config, thumbnails, event.label
) )
if not description: if not description:
@ -251,11 +240,7 @@ class EmbeddingMaintainer(threading.Thread):
) )
# Encode the description # Encode the description
self.embeddings.description.upsert( self.embeddings.upsert_description(event.id, description)
documents=[description],
metadatas=[metadata],
ids=[event.id],
)
logger.debug( logger.debug(
"Generated description for %s (%d images): %s", "Generated description for %s (%d images): %s",
@ -276,7 +261,6 @@ class EmbeddingMaintainer(threading.Thread):
logger.error(f"GenAI not enabled for camera {event.camera}") logger.error(f"GenAI not enabled for camera {event.camera}")
return return
metadata = get_metadata(event)
thumbnail = base64.b64decode(event.thumbnail) thumbnail = base64.b64decode(event.thumbnail)
logger.debug( logger.debug(
@ -315,4 +299,4 @@ class EmbeddingMaintainer(threading.Thread):
) )
) )
self._embed_description(event, embed_image, metadata) self._embed_description(event, embed_image)