extend the SqliteQueueDatabase class and use peewee db.execute_sql

Josh Hawkins 2024-10-04 15:40:53 -05:00
parent 5181ea7b3d
commit 1b7f469daf
5 changed files with 93 additions and 104 deletions

frigate/api/event.py

@@ -384,8 +384,6 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
     context: EmbeddingsContext = request.app.embeddings

-    logger.info(f"context: {context.embeddings}, conn: {context.embeddings.conn}")
-
     selected_columns = [
         Event.id,
         Event.camera,
@@ -503,6 +501,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
     if "thumbnail" in search_types:
         thumb_results = context.embeddings.search_thumbnail(query, limit)
+        logger.info(f"thumb results: {thumb_results}")

     if "description" in search_types:
         desc_results = context.embeddings.search_description(query, limit)

frigate/app.py

@@ -9,11 +9,9 @@ from multiprocessing.synchronize import Event as MpEvent
 from typing import Any, Optional

 import psutil
-import sqlite_vec
 import uvicorn
 from peewee_migrate import Router
 from playhouse.sqlite_ext import SqliteExtDatabase
-from playhouse.sqliteq import SqliteQueueDatabase

 import frigate.util as util
 from frigate.api.auth import hash_password
@@ -41,6 +39,7 @@ from frigate.const import (
     RECORD_DIR,
 )
 from frigate.embeddings import EmbeddingsContext, manage_embeddings
+from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase
 from frigate.events.audio import AudioProcessor
 from frigate.events.cleanup import EventCleanup
 from frigate.events.external import ExternalEventProcessor
@@ -240,7 +239,7 @@ class FrigateApp:
     def bind_database(self) -> None:
         """Bind db to the main process."""
         # NOTE: all db accessing processes need to be created before the db can be bound to the main process
-        self.db = SqliteQueueDatabase(
+        self.db = SqliteVecQueueDatabase(
             self.config.database.path,
             pragmas={
                 "auto_vacuum": "FULL",  # Does not defragment database
@@ -250,6 +249,7 @@
             timeout=max(
                 60, 10 * len([c for c in self.config.cameras.values() if c.enabled])
             ),
+            load_vec_extension=self.config.semantic_search.enabled,
         )
         models = [
             Event,
@@ -264,14 +264,6 @@
         ]
         self.db.bind(models)

-        if self.config.semantic_search.enabled:
-            # use existing db connection to load sqlite_vec extension
-            conn = self.db.connection()
-            conn.enable_load_extension(True)
-            sqlite_vec.load(conn)
-            conn.enable_load_extension(False)
-            logger.info(f"main connection: {self.db}")
-
     def check_db_data_migrations(self) -> None:
         # check if vacuum needs to be run
         if not os.path.exists(f"{CONFIG_DIR}/.exports"):
@@ -284,10 +276,7 @@
             migrate_exports(self.config.ffmpeg, list(self.config.cameras.keys()))

     def init_embeddings_client(self) -> None:
-        if not self.config.semantic_search.enabled:
-            self.embeddings = None
-            return
-
-        # Create a client for other processes to use
-        self.embeddings = EmbeddingsContext(self.db)
+        if self.config.semantic_search.enabled:
+            # Create a client for other processes to use
+            self.embeddings = EmbeddingsContext(self.db)

frigate/embeddings/__init__.py

@@ -9,12 +9,11 @@ import threading
 from types import FrameType
 from typing import Optional

-import sqlite_vec
-from playhouse.sqliteq import SqliteQueueDatabase
 from setproctitle import setproctitle

 from frigate.config import FrigateConfig
 from frigate.const import CONFIG_DIR
+from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase
 from frigate.models import Event
 from frigate.util.services import listen
@@ -43,7 +42,7 @@ def manage_embeddings(config: FrigateConfig) -> None:
     listen()

     # Configure Frigate DB
-    db = SqliteQueueDatabase(
+    db = SqliteVecQueueDatabase(
         config.database.path,
         pragmas={
             "auto_vacuum": "FULL",  # Does not defragment database
@@ -51,15 +50,11 @@
             "synchronous": "NORMAL",  # Safe when using WAL https://www.sqlite.org/pragma.html#pragma_synchronous
         },
         timeout=max(60, 10 * len([c for c in config.cameras.values() if c.enabled])),
+        load_vec_extension=True,
     )
     models = [Event]
     db.bind(models)

-    conn = db.connection()
-    conn.enable_load_extension(True)
-    sqlite_vec.load(conn)
-    conn.enable_load_extension(False)
-
     embeddings = Embeddings(db)

     # Check if we need to re-index events
@@ -75,9 +70,9 @@ def manage_embeddings(config: FrigateConfig) -> None:
 class EmbeddingsContext:
-    def __init__(self, db: SqliteQueueDatabase):
+    def __init__(self, db: SqliteVecQueueDatabase):
         self.db = db
-        self.embeddings = Embeddings(db)
+        self.embeddings = Embeddings(self.db)
         self.thumb_stats = ZScoreNormalization()
         self.desc_stats = ZScoreNormalization()

frigate/embeddings/embeddings.py

@@ -5,12 +5,12 @@ import io
 import logging
 import struct
 import time
-from typing import List, Tuple
+from typing import List, Tuple, Union

 from PIL import Image
 from playhouse.shortcuts import model_to_dict
-from playhouse.sqliteq import SqliteQueueDatabase

+from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase
 from frigate.models import Event

 from .functions.clip import ClipEmbedding
@@ -63,10 +63,10 @@ def deserialize(bytes_data: bytes) -> List[float]:
 class Embeddings:
     """SQLite-vec embeddings database."""

-    def __init__(self, db: SqliteQueueDatabase) -> None:
-        self.conn = db.connection()  # Store the database connection instance
+    def __init__(self, db: SqliteVecQueueDatabase) -> None:
+        self.db = db

-        # create tables if they don't exist
+        # Create tables if they don't exist
         self._create_tables()

         self.clip_embedding = ClipEmbedding(model="ViT-B/32")
@@ -76,7 +76,7 @@ class Embeddings:
     def _create_tables(self):
         # Create vec0 virtual table for thumbnail embeddings
-        self.conn.execute("""
+        self.db.execute_sql("""
             CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
                 id TEXT PRIMARY KEY,
                 thumbnail_embedding FLOAT[512]
@@ -84,7 +84,7 @@ class Embeddings:
         """)

         # Create vec0 virtual table for description embeddings
-        self.conn.execute("""
+        self.db.execute_sql("""
             CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
                 id TEXT PRIMARY KEY,
                 description_embedding FLOAT[384]
@@ -97,79 +97,65 @@ class Embeddings:
         # Generate embedding using CLIP
         embedding = self.clip_embedding([image])[0]

-        # sqlite_vec virtual tables don't support upsert, check if event_id exists
-        cursor = self.conn.execute(
-            "SELECT 1 FROM vec_thumbnails WHERE id = ?", (event_id,)
-        )
-        row = cursor.fetchone()
-
-        if row is None:
-            # Insert if the event_id does not exist
-            self.conn.execute(
-                "INSERT INTO vec_thumbnails(id, thumbnail_embedding) VALUES(?, ?)",
-                [event_id, serialize(embedding)],
-            )
-        else:
-            # Update if the event_id already exists
-            self.conn.execute(
-                "UPDATE vec_thumbnails SET thumbnail_embedding = ? WHERE id = ?",
-                [serialize(embedding), event_id],
-            )
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
+            VALUES(?, ?)
+            """,
+            (event_id, serialize(embedding)),
+        )

     def upsert_description(self, event_id: str, description: str):
         # Generate embedding using MiniLM
         embedding = self.minilm_embedding([description])[0]

-        # sqlite_vec virtual tables don't support upsert, check if event_id exists
-        cursor = self.conn.execute(
-            "SELECT 1 FROM vec_descriptions WHERE id = ?", (event_id,)
-        )
-        row = cursor.fetchone()
-
-        if row is None:
-            # Insert if the event_id does not exist
-            self.conn.execute(
-                "INSERT INTO vec_descriptions(id, description_embedding) VALUES(?, ?)",
-                [event_id, serialize(embedding)],
-            )
-        else:
-            # Update if the event_id already exists
-            self.conn.execute(
-                "UPDATE vec_descriptions SET description_embedding = ? WHERE id = ?",
-                [serialize(embedding), event_id],
-            )
+        self.db.execute_sql(
+            """
+            INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
+            VALUES(?, ?)
+            """,
+            (event_id, serialize(embedding)),
+        )

     def delete_thumbnail(self, event_ids: List[str]) -> None:
-        ids = ", ".join("?" for _ in event_ids)
-
-        self.conn.execute(
-            f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", tuple(event_ids)
+        ids = ",".join(["?" for _ in event_ids])
+        self.db.execute_sql(
+            f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
         )

     def delete_description(self, event_ids: List[str]) -> None:
-        ids = ", ".join("?" for _ in event_ids)
-
-        self.conn.execute(
-            f"DELETE FROM vec_descriptions WHERE id IN ({ids})", tuple(event_ids)
+        ids = ",".join(["?" for _ in event_ids])
+        self.db.execute_sql(
+            f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
         )

-    def search_thumbnail(self, event_id: str, limit=10) -> List[Tuple[str, float]]:
-        # check if it's already embedded
-        cursor = self.conn.execute(
-            "SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?", (event_id,)
-        )
-        row = cursor.fetchone()
-        if row:
-            query_embedding = deserialize(row[0])
-        else:
-            # If not embedded, fetch the thumbnail from the Event table and embed it
-            event = Event.get_by_id(event_id)
-            thumbnail = base64.b64decode(event.thumbnail)
-            image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
-            query_embedding = self.clip_embedding([image])[0]
-            self.upsert_thumbnail(event_id, thumbnail)
-
-        cursor = self.conn.execute(
+    def search_thumbnail(
+        self, query: Union[Event, str], limit=10
+    ) -> List[Tuple[str, float]]:
+        if isinstance(query, Event):
+            cursor = self.db.execute_sql(
+                """
+                SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
+                """,
+                [query.id],
+            )
+
+            row = cursor.fetchone() if cursor else None
+
+            if row:
+                query_embedding = deserialize(
+                    row[0]
+                )  # Deserialize the thumbnail embedding
+            else:
+                # If no embedding found, generate it
+                thumbnail = base64.b64decode(query.thumbnail)
+                self.upsert_thumbnail(query.id, thumbnail)
+                image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
+                query = self.clip_embedding([image])[0]
+
+        query_embedding = self.clip_embedding([query])[0]
+
+        results = self.db.execute_sql(
             """
             SELECT
                 vec_thumbnails.id,
@@ -179,13 +165,14 @@ class Embeddings:
                 AND k = ?
             ORDER BY distance
             """,
-            [serialize(query_embedding), limit],
-        )
-        return cursor.fetchall()
+            (serialize(query_embedding), limit),
+        ).fetchall()
+
+        return results

     def search_description(self, query_text: str, limit=10) -> List[Tuple[str, float]]:
         query_embedding = self.minilm_embedding([query_text])[0]
-        cursor = self.conn.execute(
+        results = self.db.execute_sql(
             """
             SELECT
                 vec_descriptions.id,
@@ -195,12 +182,12 @@ class Embeddings:
                 AND k = ?
             ORDER BY distance
             """,
-            [serialize(query_embedding), limit],
-        )
-        return cursor.fetchall()
+            (serialize(query_embedding), limit),
+        ).fetchall()
+
+        return results

     def reindex(self) -> None:
-        """Reindex all event embeddings."""
         logger.info("Indexing event embeddings...")

         st = time.time()

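The hunks above keep calling serialize() and deserialize() to convert embeddings to and from the BLOB format stored in the vec0 tables; this commit does not touch those helpers. Based on the struct import at the top of embeddings.py and the deserialize signature visible in the hunk context, they presumably look roughly like the sketch below (serialize's exact signature is an assumption, not copied from the repo):

import struct
from typing import List


def serialize(vector: List[float]) -> bytes:
    """Pack a list of floats into a compact float32 blob (4 bytes per value)."""
    return struct.pack("%sf" % len(vector), *vector)


def deserialize(bytes_data: bytes) -> List[float]:
    """Unpack a float32 blob from a vec0 column back into a list of floats."""
    return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))
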
frigate/embeddings/sqlitevecq.py

@@ -0,0 +1,19 @@
+import sqlite_vec
+from playhouse.sqliteq import SqliteQueueDatabase
+
+
+class SqliteVecQueueDatabase(SqliteQueueDatabase):
+    def __init__(self, *args, load_vec_extension=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.load_vec_extension = load_vec_extension
+
+    def _connect(self, *args, **kwargs):
+        conn = super()._connect(*args, **kwargs)
+        if self.load_vec_extension:
+            self._load_vec_extension(conn)
+        return conn
+
+    def _load_vec_extension(self, conn):
+        conn.enable_load_extension(True)
+        sqlite_vec.load(conn)
+        conn.enable_load_extension(False)
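
For orientation (not part of the diff), a minimal sketch of how the subclass is meant to be used, mirroring the bind_database() and manage_embeddings() changes above; the database path and pragmas here are illustrative placeholders, not Frigate's real settings:

from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event

# Every sqlite3 connection the queue database opens goes through _connect(),
# so sqlite-vec is loaded on the writer thread's connection as well as any
# others, instead of being patched onto a single stored connection.
db = SqliteVecQueueDatabase(
    "/config/frigate.db",  # illustrative path
    pragmas={"journal_mode": "wal"},  # illustrative pragmas
    timeout=60,
    load_vec_extension=True,
)
db.bind([Event])

# peewee's execute_sql() returns a cursor-like result, so callers can chain
# fetchone()/fetchall() without ever holding a raw connection object.
print(db.execute_sql("SELECT vec_version()").fetchone())

Loading the extension inside _connect() is what lets the Embeddings class drop its stored self.conn and route every vec0 statement through db.execute_sql() instead.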