Remove unneeded embeddings creations

This commit is contained in:
Nicolas Mowen 2024-10-10 13:26:49 -06:00
parent 8f5e05fc71
commit 157b1771ce
6 changed files with 45 additions and 46 deletions

View File

@ -28,3 +28,26 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
def delete_embeddings_description(self, event_ids: list[str]) -> None:
    """Delete the ``vec_descriptions`` rows whose id appears in *event_ids*.

    One ``?`` placeholder is generated per id, so the ids are passed to
    ``execute_sql`` as bound parameters instead of being interpolated into
    the SQL text.
    """
    placeholders = ",".join("?" for _ in event_ids)
    query = f"DELETE FROM vec_descriptions WHERE id IN ({placeholders})"
    self.execute_sql(query, event_ids)
def drop_embeddings_tables(self) -> None:
    """Drop both embeddings tables, tolerating tables that are already gone.

    Issues one statement per table through ``self.execute_sql``: first
    ``vec_descriptions``, then ``vec_thumbnails``.
    """
    # IF EXISTS makes the call idempotent: a plain DROP TABLE raises when the
    # table is missing, so a retry after a partial drop would fail on the
    # table that was already removed.
    self.execute_sql("\nDROP TABLE IF EXISTS vec_descriptions;\n")
    self.execute_sql("\nDROP TABLE IF EXISTS vec_thumbnails;\n")
def create_embeddings_tables(self) -> None:
    """Create the vec0 virtual tables for embeddings if they are missing.

    Two statements are sent through ``self.execute_sql``, in this order:
    ``vec_thumbnails`` (column ``thumbnail_embedding``) and then
    ``vec_descriptions`` (column ``description_embedding``). Both declare a
    ``TEXT`` primary key ``id`` and a ``FLOAT[768]`` embedding column.
    """
    thumbnails_ddl = (
        "\n"
        "CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(\n"
        "id TEXT PRIMARY KEY,\n"
        "thumbnail_embedding FLOAT[768]\n"
        ");\n"
    )
    descriptions_ddl = (
        "\n"
        "CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(\n"
        "id TEXT PRIMARY KEY,\n"
        "description_embedding FLOAT[768]\n"
        ");\n"
    )
    for ddl in (thumbnails_ddl, descriptions_ddl):
        self.execute_sql(ddl)

View File

@ -19,7 +19,6 @@ from frigate.models import Event
from frigate.util.builtin import serialize from frigate.util.builtin import serialize
from frigate.util.services import listen from frigate.util.services import listen
from .embeddings import Embeddings
from .maintainer import EmbeddingMaintainer from .maintainer import EmbeddingMaintainer
from .util import ZScoreNormalization from .util import ZScoreNormalization
@ -57,12 +56,7 @@ def manage_embeddings(config: FrigateConfig) -> None:
models = [Event] models = [Event]
db.bind(models) db.bind(models)
embeddings = Embeddings(config.semantic_search, db) print("creating embedding maintainer")
# Check if we need to re-index events
if config.semantic_search.reindex:
embeddings.reindex()
maintainer = EmbeddingMaintainer( maintainer = EmbeddingMaintainer(
db, db,
config, config,

View File

@ -63,7 +63,7 @@ class Embeddings:
self.requestor = InterProcessRequestor() self.requestor = InterProcessRequestor()
# Create tables if they don't exist # Create tables if they don't exist
self._create_tables() self.db.create_embeddings_tables()
models = [ models = [
"jinaai/jina-clip-v1-text_model_fp16.onnx", "jinaai/jina-clip-v1-text_model_fp16.onnx",
@ -110,31 +110,7 @@ class Embeddings:
model_type="vision", model_type="vision",
device=self.config.device, device=self.config.device,
) )
print("completed embeddings init")
def _create_tables(self):
    """Create the thumbnail and description vec0 tables if they are missing.

    Both statements go through ``self.db.execute_sql``; the thumbnail table
    is created first, then the description table.
    """
    # vec0 virtual table holding thumbnail embeddings
    thumbnails_ddl = (
        "\n"
        "CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(\n"
        "id TEXT PRIMARY KEY,\n"
        "thumbnail_embedding FLOAT[768]\n"
        ");\n"
    )
    # vec0 virtual table holding description embeddings
    descriptions_ddl = (
        "\n"
        "CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(\n"
        "id TEXT PRIMARY KEY,\n"
        "description_embedding FLOAT[768]\n"
        ");\n"
    )
    for ddl in (thumbnails_ddl, descriptions_ddl):
        self.db.execute_sql(ddl)
def _drop_tables(self):
    """Drop both embeddings tables, tolerating tables that are already gone.

    Issues one statement per table through ``self.db.execute_sql``: first
    ``vec_descriptions``, then ``vec_thumbnails``.
    """
    # IF EXISTS makes the call idempotent: a plain DROP TABLE raises when the
    # table is missing, so a retry after a partial drop would fail on the
    # table that was already removed.
    self.db.execute_sql("\nDROP TABLE IF EXISTS vec_descriptions;\n")
    self.db.execute_sql("\nDROP TABLE IF EXISTS vec_thumbnails;\n")
def upsert_thumbnail(self, event_id: str, thumbnail: bytes): def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
# Convert thumbnail bytes to PIL Image # Convert thumbnail bytes to PIL Image
@ -153,7 +129,6 @@ class Embeddings:
def upsert_description(self, event_id: str, description: str): def upsert_description(self, event_id: str, description: str):
embedding = self.text_embedding([description])[0] embedding = self.text_embedding([description])[0]
self.db.execute_sql( self.db.execute_sql(
""" """
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
@ -167,9 +142,9 @@ class Embeddings:
def reindex(self) -> None: def reindex(self) -> None:
logger.info("Indexing tracked object embeddings...") logger.info("Indexing tracked object embeddings...")
self._drop_tables() self.db.drop_embeddings_tables()
logger.debug("Dropped embeddings tables.") logger.debug("Dropped embeddings tables.")
self._create_tables() self.db.create_embeddings_tables()
logger.debug("Created embeddings tables.") logger.debug("Created embeddings tables.")
st = time.time() st = time.time()

View File

@ -59,7 +59,10 @@ class GenericONNXEmbedding:
self.feature_extractor = None self.feature_extractor = None
self.session = None self.session = None
if not all(os.path.exists(os.path.join(self.download_path, n)) for n in self.download_urls.keys()): if not all(
os.path.exists(os.path.join(self.download_path, n))
for n in self.download_urls.keys()
):
print("starting model download") print("starting model download")
self.downloader = ModelDownloader( self.downloader = ModelDownloader(
model_name=self.model_name, model_name=self.model_name,

View File

@ -43,7 +43,14 @@ class EmbeddingMaintainer(threading.Thread):
) -> None: ) -> None:
super().__init__(name="embeddings_maintainer") super().__init__(name="embeddings_maintainer")
self.config = config self.config = config
print("creating embeddings")
self.embeddings = Embeddings(config.semantic_search, db) self.embeddings = Embeddings(config.semantic_search, db)
print("finished creating embeddings")
# Check if we need to re-index events
if config.semantic_search.reindex:
self.embeddings.reindex()
self.event_subscriber = EventUpdateSubscriber() self.event_subscriber = EventUpdateSubscriber()
self.event_end_subscriber = EventEndSubscriber() self.event_end_subscriber = EventEndSubscriber()
self.event_metadata_subscriber = EventMetadataSubscriber( self.event_metadata_subscriber = EventMetadataSubscriber(
@ -56,6 +63,7 @@ class EmbeddingMaintainer(threading.Thread):
self.stop_event = stop_event self.stop_event = stop_event
self.tracked_events = {} self.tracked_events = {}
self.genai_client = get_genai_client(config.genai) self.genai_client = get_genai_client(config.genai)
print("finished embed maintainer setup")
def run(self) -> None: def run(self) -> None:
"""Maintain a SQLite-vec database for semantic search.""" """Maintain a SQLite-vec database for semantic search."""

View File

@ -11,7 +11,6 @@ from pathlib import Path
from frigate.config import FrigateConfig from frigate.config import FrigateConfig
from frigate.const import CLIPS_DIR from frigate.const import CLIPS_DIR
from frigate.db.sqlitevecq import SqliteVecQueueDatabase from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.embeddings.embeddings import Embeddings
from frigate.models import Event, Timeline from frigate.models import Event, Timeline
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,9 +33,6 @@ class EventCleanup(threading.Thread):
self.removed_camera_labels: list[str] = None self.removed_camera_labels: list[str] = None
self.camera_labels: dict[str, dict[str, any]] = {} self.camera_labels: dict[str, dict[str, any]] = {}
if self.config.semantic_search.enabled:
self.embeddings = Embeddings(self.config.semantic_search, self.db)
def get_removed_camera_labels(self) -> list[Event]: def get_removed_camera_labels(self) -> list[Event]:
"""Get a list of distinct labels for removed cameras.""" """Get a list of distinct labels for removed cameras."""
if self.removed_camera_labels is None: if self.removed_camera_labels is None: