extend the SqliteQueueDatabase class and use peewee db.execute_sql

This commit is contained in:
Josh Hawkins 2024-10-04 15:40:53 -05:00
parent 5181ea7b3d
commit 1b7f469daf
5 changed files with 93 additions and 104 deletions

View File

@ -384,8 +384,6 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
context: EmbeddingsContext = request.app.embeddings
logger.info(f"context: {context.embeddings}, conn: {context.embeddings.conn}")
selected_columns = [
Event.id,
Event.camera,
@ -503,6 +501,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
if "thumbnail" in search_types:
thumb_results = context.embeddings.search_thumbnail(query, limit)
logger.info(f"thumb results: {thumb_results}")
if "description" in search_types:
desc_results = context.embeddings.search_description(query, limit)

View File

@ -9,11 +9,9 @@ from multiprocessing.synchronize import Event as MpEvent
from typing import Any, Optional
import psutil
import sqlite_vec
import uvicorn
from peewee_migrate import Router
from playhouse.sqlite_ext import SqliteExtDatabase
from playhouse.sqliteq import SqliteQueueDatabase
import frigate.util as util
from frigate.api.auth import hash_password
@ -41,6 +39,7 @@ from frigate.const import (
RECORD_DIR,
)
from frigate.embeddings import EmbeddingsContext, manage_embeddings
from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase
from frigate.events.audio import AudioProcessor
from frigate.events.cleanup import EventCleanup
from frigate.events.external import ExternalEventProcessor
@ -240,7 +239,7 @@ class FrigateApp:
def bind_database(self) -> None:
"""Bind db to the main process."""
# NOTE: all db accessing processes need to be created before the db can be bound to the main process
self.db = SqliteQueueDatabase(
self.db = SqliteVecQueueDatabase(
self.config.database.path,
pragmas={
"auto_vacuum": "FULL", # Does not defragment database
@ -250,6 +249,7 @@ class FrigateApp:
timeout=max(
60, 10 * len([c for c in self.config.cameras.values() if c.enabled])
),
load_vec_extension=self.config.semantic_search.enabled,
)
models = [
Event,
@ -264,14 +264,6 @@ class FrigateApp:
]
self.db.bind(models)
if self.config.semantic_search.enabled:
# use existing db connection to load sqlite_vec extension
conn = self.db.connection()
conn.enable_load_extension(True)
sqlite_vec.load(conn)
conn.enable_load_extension(False)
logger.info(f"main connection: {self.db}")
def check_db_data_migrations(self) -> None:
# check if vacuum needs to be run
if not os.path.exists(f"{CONFIG_DIR}/.exports"):
@ -284,12 +276,9 @@ class FrigateApp:
migrate_exports(self.config.ffmpeg, list(self.config.cameras.keys()))
def init_embeddings_client(self) -> None:
if not self.config.semantic_search.enabled:
self.embeddings = None
return
# Create a client for other processes to use
self.embeddings = EmbeddingsContext(self.db)
if self.config.semantic_search.enabled:
# Create a client for other processes to use
self.embeddings = EmbeddingsContext(self.db)
def init_external_event_processor(self) -> None:
self.external_event_processor = ExternalEventProcessor(self.config)

View File

@ -9,12 +9,11 @@ import threading
from types import FrameType
from typing import Optional
import sqlite_vec
from playhouse.sqliteq import SqliteQueueDatabase
from setproctitle import setproctitle
from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR
from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event
from frigate.util.services import listen
@ -43,7 +42,7 @@ def manage_embeddings(config: FrigateConfig) -> None:
listen()
# Configure Frigate DB
db = SqliteQueueDatabase(
db = SqliteVecQueueDatabase(
config.database.path,
pragmas={
"auto_vacuum": "FULL", # Does not defragment database
@ -51,15 +50,11 @@ def manage_embeddings(config: FrigateConfig) -> None:
"synchronous": "NORMAL", # Safe when using WAL https://www.sqlite.org/pragma.html#pragma_synchronous
},
timeout=max(60, 10 * len([c for c in config.cameras.values() if c.enabled])),
load_vec_extension=True,
)
models = [Event]
db.bind(models)
conn = db.connection()
conn.enable_load_extension(True)
sqlite_vec.load(conn)
conn.enable_load_extension(False)
embeddings = Embeddings(db)
# Check if we need to re-index events
@ -75,9 +70,9 @@ def manage_embeddings(config: FrigateConfig) -> None:
class EmbeddingsContext:
def __init__(self, db: SqliteQueueDatabase):
def __init__(self, db: SqliteVecQueueDatabase):
self.db = db
self.embeddings = Embeddings(db)
self.embeddings = Embeddings(self.db)
self.thumb_stats = ZScoreNormalization()
self.desc_stats = ZScoreNormalization()

View File

@ -5,12 +5,12 @@ import io
import logging
import struct
import time
from typing import List, Tuple
from typing import List, Tuple, Union
from PIL import Image
from playhouse.shortcuts import model_to_dict
from playhouse.sqliteq import SqliteQueueDatabase
from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event
from .functions.clip import ClipEmbedding
@ -63,10 +63,10 @@ def deserialize(bytes_data: bytes) -> List[float]:
class Embeddings:
"""SQLite-vec embeddings database."""
def __init__(self, db: SqliteQueueDatabase) -> None:
self.conn = db.connection() # Store the database connection instance
def __init__(self, db: SqliteVecQueueDatabase) -> None:
self.db = db
# create tables if they don't exist
# Create tables if they don't exist
self._create_tables()
self.clip_embedding = ClipEmbedding(model="ViT-B/32")
@ -76,7 +76,7 @@ class Embeddings:
def _create_tables(self):
# Create vec0 virtual table for thumbnail embeddings
self.conn.execute("""
self.db.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
id TEXT PRIMARY KEY,
thumbnail_embedding FLOAT[512]
@ -84,7 +84,7 @@ class Embeddings:
""")
# Create vec0 virtual table for description embeddings
self.conn.execute("""
self.db.execute_sql("""
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
id TEXT PRIMARY KEY,
description_embedding FLOAT[384]
@ -97,79 +97,65 @@ class Embeddings:
# Generate embedding using CLIP
embedding = self.clip_embedding([image])[0]
# sqlite_vec virtual tables don't support upsert, check if event_id exists
cursor = self.conn.execute(
"SELECT 1 FROM vec_thumbnails WHERE id = ?", (event_id,)
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
VALUES(?, ?)
""",
(event_id, serialize(embedding)),
)
row = cursor.fetchone()
if row is None:
# Insert if the event_id does not exist
self.conn.execute(
"INSERT INTO vec_thumbnails(id, thumbnail_embedding) VALUES(?, ?)",
[event_id, serialize(embedding)],
)
else:
# Update if the event_id already exists
self.conn.execute(
"UPDATE vec_thumbnails SET thumbnail_embedding = ? WHERE id = ?",
[serialize(embedding), event_id],
)
def upsert_description(self, event_id: str, description: str):
# Generate embedding using MiniLM
embedding = self.minilm_embedding([description])[0]
# sqlite_vec virtual tables don't support upsert, check if event_id exists
cursor = self.conn.execute(
"SELECT 1 FROM vec_descriptions WHERE id = ?", (event_id,)
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
VALUES(?, ?)
""",
(event_id, serialize(embedding)),
)
row = cursor.fetchone()
if row is None:
# Insert if the event_id does not exist
self.conn.execute(
"INSERT INTO vec_descriptions(id, description_embedding) VALUES(?, ?)",
[event_id, serialize(embedding)],
)
else:
# Update if the event_id already exists
self.conn.execute(
"UPDATE vec_descriptions SET description_embedding = ? WHERE id = ?",
[serialize(embedding), event_id],
)
def delete_thumbnail(self, event_ids: List[str]) -> None:
ids = ", ".join("?" for _ in event_ids)
self.conn.execute(
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", tuple(event_ids)
ids = ",".join(["?" for _ in event_ids])
self.db.execute_sql(
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
)
def delete_description(self, event_ids: List[str]) -> None:
ids = ", ".join("?" for _ in event_ids)
self.conn.execute(
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", tuple(event_ids)
ids = ",".join(["?" for _ in event_ids])
self.db.execute_sql(
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
)
def search_thumbnail(self, event_id: str, limit=10) -> List[Tuple[str, float]]:
# check if it's already embedded
cursor = self.conn.execute(
"SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?", (event_id,)
)
row = cursor.fetchone()
if row:
query_embedding = deserialize(row[0])
else:
# If not embedded, fetch the thumbnail from the Event table and embed it
event = Event.get_by_id(event_id)
thumbnail = base64.b64decode(event.thumbnail)
image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
query_embedding = self.clip_embedding([image])[0]
self.upsert_thumbnail(event_id, thumbnail)
def search_thumbnail(
self, query: Union[Event, str], limit=10
) -> List[Tuple[str, float]]:
if isinstance(query, Event):
cursor = self.db.execute_sql(
"""
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
""",
[query.id],
)
cursor = self.conn.execute(
row = cursor.fetchone() if cursor else None
if row:
query_embedding = deserialize(
row[0]
) # Deserialize the thumbnail embedding
else:
# If no embedding found, generate it
thumbnail = base64.b64decode(query.thumbnail)
self.upsert_thumbnail(query.id, thumbnail)
image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
query = self.clip_embedding([image])[0]
query_embedding = self.clip_embedding([query])[0]
results = self.db.execute_sql(
"""
SELECT
vec_thumbnails.id,
@ -178,14 +164,15 @@ class Embeddings:
WHERE thumbnail_embedding MATCH ?
AND k = ?
ORDER BY distance
""",
[serialize(query_embedding), limit],
)
return cursor.fetchall()
""",
(serialize(query_embedding), limit),
).fetchall()
return results
def search_description(self, query_text: str, limit=10) -> List[Tuple[str, float]]:
query_embedding = self.minilm_embedding([query_text])[0]
cursor = self.conn.execute(
results = self.db.execute_sql(
"""
SELECT
vec_descriptions.id,
@ -194,13 +181,13 @@ class Embeddings:
WHERE description_embedding MATCH ?
AND k = ?
ORDER BY distance
""",
[serialize(query_embedding), limit],
)
return cursor.fetchall()
""",
(serialize(query_embedding), limit),
).fetchall()
return results
def reindex(self) -> None:
"""Reindex all event embeddings."""
logger.info("Indexing event embeddings...")
st = time.time()

View File

@ -0,0 +1,19 @@
import sqlite_vec
from playhouse.sqliteq import SqliteQueueDatabase
class SqliteVecQueueDatabase(SqliteQueueDatabase):
def __init__(self, *args, load_vec_extension=False, **kwargs):
super().__init__(*args, **kwargs)
self.load_vec_extension = load_vec_extension
def _connect(self, *args, **kwargs):
conn = super()._connect(*args, **kwargs)
if self.load_vec_extension:
self._load_vec_extension(conn)
return conn
def _load_vec_extension(self, conn):
conn.enable_load_extension(True)
sqlite_vec.load(conn)
conn.enable_load_extension(False)