diff --git a/frigate/api/event.py b/frigate/api/event.py index 9457d0148..7d4802355 100644 --- a/frigate/api/event.py +++ b/frigate/api/event.py @@ -384,8 +384,6 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) context: EmbeddingsContext = request.app.embeddings - logger.info(f"context: {context.embeddings}, conn: {context.embeddings.conn}") - selected_columns = [ Event.id, Event.camera, @@ -503,6 +501,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends()) if "thumbnail" in search_types: thumb_results = context.embeddings.search_thumbnail(query, limit) + logger.info(f"thumb results: {thumb_results}") if "description" in search_types: desc_results = context.embeddings.search_description(query, limit) diff --git a/frigate/app.py b/frigate/app.py index daebae8bb..255e0c1a9 100644 --- a/frigate/app.py +++ b/frigate/app.py @@ -9,11 +9,9 @@ from multiprocessing.synchronize import Event as MpEvent from typing import Any, Optional import psutil -import sqlite_vec import uvicorn from peewee_migrate import Router from playhouse.sqlite_ext import SqliteExtDatabase -from playhouse.sqliteq import SqliteQueueDatabase import frigate.util as util from frigate.api.auth import hash_password @@ -41,6 +39,7 @@ from frigate.const import ( RECORD_DIR, ) from frigate.embeddings import EmbeddingsContext, manage_embeddings +from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase from frigate.events.audio import AudioProcessor from frigate.events.cleanup import EventCleanup from frigate.events.external import ExternalEventProcessor @@ -240,7 +239,7 @@ class FrigateApp: def bind_database(self) -> None: """Bind db to the main process.""" # NOTE: all db accessing processes need to be created before the db can be bound to the main process - self.db = SqliteQueueDatabase( + self.db = SqliteVecQueueDatabase( self.config.database.path, pragmas={ "auto_vacuum": "FULL", # Does not defragment database @@ -250,6 +249,7 @@ class FrigateApp: timeout=max( 60, 10 * len([c for c in self.config.cameras.values() if c.enabled]) ), + load_vec_extension=self.config.semantic_search.enabled, ) models = [ Event, @@ -264,14 +264,6 @@ class FrigateApp: ] self.db.bind(models) - if self.config.semantic_search.enabled: - # use existing db connection to load sqlite_vec extension - conn = self.db.connection() - conn.enable_load_extension(True) - sqlite_vec.load(conn) - conn.enable_load_extension(False) - logger.info(f"main connection: {self.db}") - def check_db_data_migrations(self) -> None: # check if vacuum needs to be run if not os.path.exists(f"{CONFIG_DIR}/.exports"): @@ -284,12 +276,9 @@ class FrigateApp: migrate_exports(self.config.ffmpeg, list(self.config.cameras.keys())) def init_embeddings_client(self) -> None: - if not self.config.semantic_search.enabled: - self.embeddings = None - return - - # Create a client for other processes to use - self.embeddings = EmbeddingsContext(self.db) + if self.config.semantic_search.enabled: + # Create a client for other processes to use + self.embeddings = EmbeddingsContext(self.db) def init_external_event_processor(self) -> None: self.external_event_processor = ExternalEventProcessor(self.config) diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py index 00b02b1ac..970060eb4 100644 --- a/frigate/embeddings/__init__.py +++ b/frigate/embeddings/__init__.py @@ -9,12 +9,11 @@ import threading from types import FrameType from typing import Optional -import sqlite_vec -from playhouse.sqliteq import SqliteQueueDatabase from setproctitle import setproctitle from frigate.config import FrigateConfig from frigate.const import CONFIG_DIR +from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase from frigate.models import Event from frigate.util.services import listen @@ -43,7 +42,7 @@ def manage_embeddings(config: FrigateConfig) -> None: listen() # Configure Frigate DB - db = SqliteQueueDatabase( + db = SqliteVecQueueDatabase( config.database.path, pragmas={ "auto_vacuum": "FULL", # Does not defragment database @@ -51,15 +50,11 @@ def manage_embeddings(config: FrigateConfig) -> None: "synchronous": "NORMAL", # Safe when using WAL https://www.sqlite.org/pragma.html#pragma_synchronous }, timeout=max(60, 10 * len([c for c in config.cameras.values() if c.enabled])), + load_vec_extension=True, ) models = [Event] db.bind(models) - conn = db.connection() - conn.enable_load_extension(True) - sqlite_vec.load(conn) - conn.enable_load_extension(False) - embeddings = Embeddings(db) # Check if we need to re-index events @@ -75,9 +70,9 @@ def manage_embeddings(config: FrigateConfig) -> None: class EmbeddingsContext: - def __init__(self, db: SqliteQueueDatabase): + def __init__(self, db: SqliteVecQueueDatabase): self.db = db - self.embeddings = Embeddings(db) + self.embeddings = Embeddings(self.db) self.thumb_stats = ZScoreNormalization() self.desc_stats = ZScoreNormalization() diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py index 52c429025..a1f0b9686 100644 --- a/frigate/embeddings/embeddings.py +++ b/frigate/embeddings/embeddings.py @@ -5,12 +5,12 @@ import io import logging import struct import time -from typing import List, Tuple +from typing import List, Tuple, Union from PIL import Image from playhouse.shortcuts import model_to_dict -from playhouse.sqliteq import SqliteQueueDatabase +from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase from frigate.models import Event from .functions.clip import ClipEmbedding @@ -63,10 +63,10 @@ def deserialize(bytes_data: bytes) -> List[float]: class Embeddings: """SQLite-vec embeddings database.""" - def __init__(self, db: SqliteQueueDatabase) -> None: - self.conn = db.connection() # Store the database connection instance + def __init__(self, db: SqliteVecQueueDatabase) -> None: + self.db = db - # create tables if they don't exist + # Create tables if they don't exist self._create_tables() self.clip_embedding = ClipEmbedding(model="ViT-B/32") @@ -76,7 +76,7 @@ class Embeddings: def _create_tables(self): # Create vec0 virtual table for thumbnail embeddings - self.conn.execute(""" + self.db.execute_sql(""" CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0( id TEXT PRIMARY KEY, thumbnail_embedding FLOAT[512] @@ -84,7 +84,7 @@ class Embeddings: """) # Create vec0 virtual table for description embeddings - self.conn.execute(""" + self.db.execute_sql(""" CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0( id TEXT PRIMARY KEY, description_embedding FLOAT[384] @@ -97,79 +97,65 @@ class Embeddings: # Generate embedding using CLIP embedding = self.clip_embedding([image])[0] - # sqlite_vec virtual tables don't support upsert, check if event_id exists - cursor = self.conn.execute( - "SELECT 1 FROM vec_thumbnails WHERE id = ?", (event_id,) + self.db.execute_sql( + """ + INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) + VALUES(?, ?) + """, + (event_id, serialize(embedding)), ) - row = cursor.fetchone() - - if row is None: - # Insert if the event_id does not exist - self.conn.execute( - "INSERT INTO vec_thumbnails(id, thumbnail_embedding) VALUES(?, ?)", - [event_id, serialize(embedding)], - ) - else: - # Update if the event_id already exists - self.conn.execute( - "UPDATE vec_thumbnails SET thumbnail_embedding = ? WHERE id = ?", - [serialize(embedding), event_id], - ) def upsert_description(self, event_id: str, description: str): # Generate embedding using MiniLM embedding = self.minilm_embedding([description])[0] - # sqlite_vec virtual tables don't support upsert, check if event_id exists - cursor = self.conn.execute( - "SELECT 1 FROM vec_descriptions WHERE id = ?", (event_id,) + self.db.execute_sql( + """ + INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) + VALUES(?, ?) + """, + (event_id, serialize(embedding)), ) - row = cursor.fetchone() - - if row is None: - # Insert if the event_id does not exist - self.conn.execute( - "INSERT INTO vec_descriptions(id, description_embedding) VALUES(?, ?)", - [event_id, serialize(embedding)], - ) - else: - # Update if the event_id already exists - self.conn.execute( - "UPDATE vec_descriptions SET description_embedding = ? WHERE id = ?", - [serialize(embedding), event_id], - ) def delete_thumbnail(self, event_ids: List[str]) -> None: - ids = ", ".join("?" for _ in event_ids) - - self.conn.execute( - f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", tuple(event_ids) + ids = ",".join(["?" for _ in event_ids]) + self.db.execute_sql( + f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids ) def delete_description(self, event_ids: List[str]) -> None: - ids = ", ".join("?" for _ in event_ids) - - self.conn.execute( - f"DELETE FROM vec_descriptions WHERE id IN ({ids})", tuple(event_ids) + ids = ",".join(["?" for _ in event_ids]) + self.db.execute_sql( + f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids ) - def search_thumbnail(self, event_id: str, limit=10) -> List[Tuple[str, float]]: - # check if it's already embedded - cursor = self.conn.execute( - "SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?", (event_id,) - ) - row = cursor.fetchone() - if row: - query_embedding = deserialize(row[0]) - else: - # If not embedded, fetch the thumbnail from the Event table and embed it - event = Event.get_by_id(event_id) - thumbnail = base64.b64decode(event.thumbnail) - image = Image.open(io.BytesIO(thumbnail)).convert("RGB") - query_embedding = self.clip_embedding([image])[0] - self.upsert_thumbnail(event_id, thumbnail) + def search_thumbnail( + self, query: Union[Event, str], limit=10 + ) -> List[Tuple[str, float]]: + if isinstance(query, Event): + cursor = self.db.execute_sql( + """ + SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ? + """, + [query.id], + ) - cursor = self.conn.execute( + row = cursor.fetchone() if cursor else None + + if row: + query_embedding = deserialize( + row[0] + ) # Deserialize the thumbnail embedding + else: + # If no embedding found, generate it + thumbnail = base64.b64decode(query.thumbnail) + self.upsert_thumbnail(query.id, thumbnail) + image = Image.open(io.BytesIO(thumbnail)).convert("RGB") + query = self.clip_embedding([image])[0] + + query_embedding = self.clip_embedding([query])[0] + + results = self.db.execute_sql( """ SELECT vec_thumbnails.id, @@ -178,14 +164,15 @@ class Embeddings: WHERE thumbnail_embedding MATCH ? AND k = ? ORDER BY distance - """, - [serialize(query_embedding), limit], - ) - return cursor.fetchall() + """, + (serialize(query_embedding), limit), + ).fetchall() + + return results def search_description(self, query_text: str, limit=10) -> List[Tuple[str, float]]: query_embedding = self.minilm_embedding([query_text])[0] - cursor = self.conn.execute( + results = self.db.execute_sql( """ SELECT vec_descriptions.id, @@ -194,13 +181,13 @@ class Embeddings: WHERE description_embedding MATCH ? AND k = ? ORDER BY distance - """, - [serialize(query_embedding), limit], - ) - return cursor.fetchall() + """, + (serialize(query_embedding), limit), + ).fetchall() + + return results def reindex(self) -> None: - """Reindex all event embeddings.""" logger.info("Indexing event embeddings...") st = time.time() diff --git a/frigate/embeddings/sqlitevecq.py b/frigate/embeddings/sqlitevecq.py new file mode 100644 index 000000000..205364010 --- /dev/null +++ b/frigate/embeddings/sqlitevecq.py @@ -0,0 +1,19 @@ +import sqlite_vec +from playhouse.sqliteq import SqliteQueueDatabase + + +class SqliteVecQueueDatabase(SqliteQueueDatabase): + def __init__(self, *args, load_vec_extension=False, **kwargs): + super().__init__(*args, **kwargs) + self.load_vec_extension = load_vec_extension + + def _connect(self, *args, **kwargs): + conn = super()._connect(*args, **kwargs) + if self.load_vec_extension: + self._load_vec_extension(conn) + return conn + + def _load_vec_extension(self, conn): + conn.enable_load_extension(True) + sqlite_vec.load(conn) + conn.enable_load_extension(False)