mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 15:45:27 +03:00
extend the SqliteQueueDatabase class and use peewee db.execute_sql
This commit is contained in:
parent
5181ea7b3d
commit
1b7f469daf
@ -384,8 +384,6 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
||||
|
||||
context: EmbeddingsContext = request.app.embeddings
|
||||
|
||||
logger.info(f"context: {context.embeddings}, conn: {context.embeddings.conn}")
|
||||
|
||||
selected_columns = [
|
||||
Event.id,
|
||||
Event.camera,
|
||||
@ -503,6 +501,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
||||
|
||||
if "thumbnail" in search_types:
|
||||
thumb_results = context.embeddings.search_thumbnail(query, limit)
|
||||
logger.info(f"thumb results: {thumb_results}")
|
||||
|
||||
if "description" in search_types:
|
||||
desc_results = context.embeddings.search_description(query, limit)
|
||||
|
||||
@ -9,11 +9,9 @@ from multiprocessing.synchronize import Event as MpEvent
|
||||
from typing import Any, Optional
|
||||
|
||||
import psutil
|
||||
import sqlite_vec
|
||||
import uvicorn
|
||||
from peewee_migrate import Router
|
||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||
from playhouse.sqliteq import SqliteQueueDatabase
|
||||
|
||||
import frigate.util as util
|
||||
from frigate.api.auth import hash_password
|
||||
@ -41,6 +39,7 @@ from frigate.const import (
|
||||
RECORD_DIR,
|
||||
)
|
||||
from frigate.embeddings import EmbeddingsContext, manage_embeddings
|
||||
from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase
|
||||
from frigate.events.audio import AudioProcessor
|
||||
from frigate.events.cleanup import EventCleanup
|
||||
from frigate.events.external import ExternalEventProcessor
|
||||
@ -240,7 +239,7 @@ class FrigateApp:
|
||||
def bind_database(self) -> None:
|
||||
"""Bind db to the main process."""
|
||||
# NOTE: all db accessing processes need to be created before the db can be bound to the main process
|
||||
self.db = SqliteQueueDatabase(
|
||||
self.db = SqliteVecQueueDatabase(
|
||||
self.config.database.path,
|
||||
pragmas={
|
||||
"auto_vacuum": "FULL", # Does not defragment database
|
||||
@ -250,6 +249,7 @@ class FrigateApp:
|
||||
timeout=max(
|
||||
60, 10 * len([c for c in self.config.cameras.values() if c.enabled])
|
||||
),
|
||||
load_vec_extension=self.config.semantic_search.enabled,
|
||||
)
|
||||
models = [
|
||||
Event,
|
||||
@ -264,14 +264,6 @@ class FrigateApp:
|
||||
]
|
||||
self.db.bind(models)
|
||||
|
||||
if self.config.semantic_search.enabled:
|
||||
# use existing db connection to load sqlite_vec extension
|
||||
conn = self.db.connection()
|
||||
conn.enable_load_extension(True)
|
||||
sqlite_vec.load(conn)
|
||||
conn.enable_load_extension(False)
|
||||
logger.info(f"main connection: {self.db}")
|
||||
|
||||
def check_db_data_migrations(self) -> None:
|
||||
# check if vacuum needs to be run
|
||||
if not os.path.exists(f"{CONFIG_DIR}/.exports"):
|
||||
@ -284,12 +276,9 @@ class FrigateApp:
|
||||
migrate_exports(self.config.ffmpeg, list(self.config.cameras.keys()))
|
||||
|
||||
def init_embeddings_client(self) -> None:
|
||||
if not self.config.semantic_search.enabled:
|
||||
self.embeddings = None
|
||||
return
|
||||
|
||||
# Create a client for other processes to use
|
||||
self.embeddings = EmbeddingsContext(self.db)
|
||||
if self.config.semantic_search.enabled:
|
||||
# Create a client for other processes to use
|
||||
self.embeddings = EmbeddingsContext(self.db)
|
||||
|
||||
def init_external_event_processor(self) -> None:
|
||||
self.external_event_processor = ExternalEventProcessor(self.config)
|
||||
|
||||
@ -9,12 +9,11 @@ import threading
|
||||
from types import FrameType
|
||||
from typing import Optional
|
||||
|
||||
import sqlite_vec
|
||||
from playhouse.sqliteq import SqliteQueueDatabase
|
||||
from setproctitle import setproctitle
|
||||
|
||||
from frigate.config import FrigateConfig
|
||||
from frigate.const import CONFIG_DIR
|
||||
from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase
|
||||
from frigate.models import Event
|
||||
from frigate.util.services import listen
|
||||
|
||||
@ -43,7 +42,7 @@ def manage_embeddings(config: FrigateConfig) -> None:
|
||||
listen()
|
||||
|
||||
# Configure Frigate DB
|
||||
db = SqliteQueueDatabase(
|
||||
db = SqliteVecQueueDatabase(
|
||||
config.database.path,
|
||||
pragmas={
|
||||
"auto_vacuum": "FULL", # Does not defragment database
|
||||
@ -51,15 +50,11 @@ def manage_embeddings(config: FrigateConfig) -> None:
|
||||
"synchronous": "NORMAL", # Safe when using WAL https://www.sqlite.org/pragma.html#pragma_synchronous
|
||||
},
|
||||
timeout=max(60, 10 * len([c for c in config.cameras.values() if c.enabled])),
|
||||
load_vec_extension=True,
|
||||
)
|
||||
models = [Event]
|
||||
db.bind(models)
|
||||
|
||||
conn = db.connection()
|
||||
conn.enable_load_extension(True)
|
||||
sqlite_vec.load(conn)
|
||||
conn.enable_load_extension(False)
|
||||
|
||||
embeddings = Embeddings(db)
|
||||
|
||||
# Check if we need to re-index events
|
||||
@ -75,9 +70,9 @@ def manage_embeddings(config: FrigateConfig) -> None:
|
||||
|
||||
|
||||
class EmbeddingsContext:
|
||||
def __init__(self, db: SqliteQueueDatabase):
|
||||
def __init__(self, db: SqliteVecQueueDatabase):
|
||||
self.db = db
|
||||
self.embeddings = Embeddings(db)
|
||||
self.embeddings = Embeddings(self.db)
|
||||
self.thumb_stats = ZScoreNormalization()
|
||||
self.desc_stats = ZScoreNormalization()
|
||||
|
||||
|
||||
@ -5,12 +5,12 @@ import io
|
||||
import logging
|
||||
import struct
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from PIL import Image
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
from playhouse.sqliteq import SqliteQueueDatabase
|
||||
|
||||
from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase
|
||||
from frigate.models import Event
|
||||
|
||||
from .functions.clip import ClipEmbedding
|
||||
@ -63,10 +63,10 @@ def deserialize(bytes_data: bytes) -> List[float]:
|
||||
class Embeddings:
|
||||
"""SQLite-vec embeddings database."""
|
||||
|
||||
def __init__(self, db: SqliteQueueDatabase) -> None:
|
||||
self.conn = db.connection() # Store the database connection instance
|
||||
def __init__(self, db: SqliteVecQueueDatabase) -> None:
|
||||
self.db = db
|
||||
|
||||
# create tables if they don't exist
|
||||
# Create tables if they don't exist
|
||||
self._create_tables()
|
||||
|
||||
self.clip_embedding = ClipEmbedding(model="ViT-B/32")
|
||||
@ -76,7 +76,7 @@ class Embeddings:
|
||||
|
||||
def _create_tables(self):
|
||||
# Create vec0 virtual table for thumbnail embeddings
|
||||
self.conn.execute("""
|
||||
self.db.execute_sql("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
|
||||
id TEXT PRIMARY KEY,
|
||||
thumbnail_embedding FLOAT[512]
|
||||
@ -84,7 +84,7 @@ class Embeddings:
|
||||
""")
|
||||
|
||||
# Create vec0 virtual table for description embeddings
|
||||
self.conn.execute("""
|
||||
self.db.execute_sql("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
|
||||
id TEXT PRIMARY KEY,
|
||||
description_embedding FLOAT[384]
|
||||
@ -97,79 +97,65 @@ class Embeddings:
|
||||
# Generate embedding using CLIP
|
||||
embedding = self.clip_embedding([image])[0]
|
||||
|
||||
# sqlite_vec virtual tables don't support upsert, check if event_id exists
|
||||
cursor = self.conn.execute(
|
||||
"SELECT 1 FROM vec_thumbnails WHERE id = ?", (event_id,)
|
||||
self.db.execute_sql(
|
||||
"""
|
||||
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
|
||||
VALUES(?, ?)
|
||||
""",
|
||||
(event_id, serialize(embedding)),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row is None:
|
||||
# Insert if the event_id does not exist
|
||||
self.conn.execute(
|
||||
"INSERT INTO vec_thumbnails(id, thumbnail_embedding) VALUES(?, ?)",
|
||||
[event_id, serialize(embedding)],
|
||||
)
|
||||
else:
|
||||
# Update if the event_id already exists
|
||||
self.conn.execute(
|
||||
"UPDATE vec_thumbnails SET thumbnail_embedding = ? WHERE id = ?",
|
||||
[serialize(embedding), event_id],
|
||||
)
|
||||
|
||||
def upsert_description(self, event_id: str, description: str):
|
||||
# Generate embedding using MiniLM
|
||||
embedding = self.minilm_embedding([description])[0]
|
||||
|
||||
# sqlite_vec virtual tables don't support upsert, check if event_id exists
|
||||
cursor = self.conn.execute(
|
||||
"SELECT 1 FROM vec_descriptions WHERE id = ?", (event_id,)
|
||||
self.db.execute_sql(
|
||||
"""
|
||||
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
|
||||
VALUES(?, ?)
|
||||
""",
|
||||
(event_id, serialize(embedding)),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row is None:
|
||||
# Insert if the event_id does not exist
|
||||
self.conn.execute(
|
||||
"INSERT INTO vec_descriptions(id, description_embedding) VALUES(?, ?)",
|
||||
[event_id, serialize(embedding)],
|
||||
)
|
||||
else:
|
||||
# Update if the event_id already exists
|
||||
self.conn.execute(
|
||||
"UPDATE vec_descriptions SET description_embedding = ? WHERE id = ?",
|
||||
[serialize(embedding), event_id],
|
||||
)
|
||||
|
||||
def delete_thumbnail(self, event_ids: List[str]) -> None:
|
||||
ids = ", ".join("?" for _ in event_ids)
|
||||
|
||||
self.conn.execute(
|
||||
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", tuple(event_ids)
|
||||
ids = ",".join(["?" for _ in event_ids])
|
||||
self.db.execute_sql(
|
||||
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
|
||||
)
|
||||
|
||||
def delete_description(self, event_ids: List[str]) -> None:
|
||||
ids = ", ".join("?" for _ in event_ids)
|
||||
|
||||
self.conn.execute(
|
||||
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", tuple(event_ids)
|
||||
ids = ",".join(["?" for _ in event_ids])
|
||||
self.db.execute_sql(
|
||||
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
|
||||
)
|
||||
|
||||
def search_thumbnail(self, event_id: str, limit=10) -> List[Tuple[str, float]]:
|
||||
# check if it's already embedded
|
||||
cursor = self.conn.execute(
|
||||
"SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?", (event_id,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
query_embedding = deserialize(row[0])
|
||||
else:
|
||||
# If not embedded, fetch the thumbnail from the Event table and embed it
|
||||
event = Event.get_by_id(event_id)
|
||||
thumbnail = base64.b64decode(event.thumbnail)
|
||||
image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
|
||||
query_embedding = self.clip_embedding([image])[0]
|
||||
self.upsert_thumbnail(event_id, thumbnail)
|
||||
def search_thumbnail(
|
||||
self, query: Union[Event, str], limit=10
|
||||
) -> List[Tuple[str, float]]:
|
||||
if isinstance(query, Event):
|
||||
cursor = self.db.execute_sql(
|
||||
"""
|
||||
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
|
||||
""",
|
||||
[query.id],
|
||||
)
|
||||
|
||||
cursor = self.conn.execute(
|
||||
row = cursor.fetchone() if cursor else None
|
||||
|
||||
if row:
|
||||
query_embedding = deserialize(
|
||||
row[0]
|
||||
) # Deserialize the thumbnail embedding
|
||||
else:
|
||||
# If no embedding found, generate it
|
||||
thumbnail = base64.b64decode(query.thumbnail)
|
||||
self.upsert_thumbnail(query.id, thumbnail)
|
||||
image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
|
||||
query = self.clip_embedding([image])[0]
|
||||
|
||||
query_embedding = self.clip_embedding([query])[0]
|
||||
|
||||
results = self.db.execute_sql(
|
||||
"""
|
||||
SELECT
|
||||
vec_thumbnails.id,
|
||||
@ -178,14 +164,15 @@ class Embeddings:
|
||||
WHERE thumbnail_embedding MATCH ?
|
||||
AND k = ?
|
||||
ORDER BY distance
|
||||
""",
|
||||
[serialize(query_embedding), limit],
|
||||
)
|
||||
return cursor.fetchall()
|
||||
""",
|
||||
(serialize(query_embedding), limit),
|
||||
).fetchall()
|
||||
|
||||
return results
|
||||
|
||||
def search_description(self, query_text: str, limit=10) -> List[Tuple[str, float]]:
|
||||
query_embedding = self.minilm_embedding([query_text])[0]
|
||||
cursor = self.conn.execute(
|
||||
results = self.db.execute_sql(
|
||||
"""
|
||||
SELECT
|
||||
vec_descriptions.id,
|
||||
@ -194,13 +181,13 @@ class Embeddings:
|
||||
WHERE description_embedding MATCH ?
|
||||
AND k = ?
|
||||
ORDER BY distance
|
||||
""",
|
||||
[serialize(query_embedding), limit],
|
||||
)
|
||||
return cursor.fetchall()
|
||||
""",
|
||||
(serialize(query_embedding), limit),
|
||||
).fetchall()
|
||||
|
||||
return results
|
||||
|
||||
def reindex(self) -> None:
|
||||
"""Reindex all event embeddings."""
|
||||
logger.info("Indexing event embeddings...")
|
||||
|
||||
st = time.time()
|
||||
|
||||
19
frigate/embeddings/sqlitevecq.py
Normal file
19
frigate/embeddings/sqlitevecq.py
Normal file
@ -0,0 +1,19 @@
|
||||
import sqlite_vec
|
||||
from playhouse.sqliteq import SqliteQueueDatabase
|
||||
|
||||
|
||||
class SqliteVecQueueDatabase(SqliteQueueDatabase):
    """SqliteQueueDatabase that can load the sqlite-vec extension on connect.

    When ``load_vec_extension`` is true, every connection created through
    ``_connect`` has the sqlite-vec extension loaded before it is returned.
    """

    def __init__(self, *args, load_vec_extension=False, **kwargs):
        """Create the database.

        Args:
            *args: Positional arguments forwarded to SqliteQueueDatabase.
            load_vec_extension: Load sqlite-vec on each new connection.
            **kwargs: Keyword arguments forwarded to SqliteQueueDatabase.
        """
        # Must be set BEFORE the base constructor runs: SqliteQueueDatabase
        # can start its writer thread during __init__, and that thread opens
        # a connection via _connect(), which reads this attribute. Assigning
        # it afterwards leaves a window for an AttributeError in that thread.
        self.load_vec_extension = load_vec_extension
        super().__init__(*args, **kwargs)

    def _connect(self, *args, **kwargs):
        """Open a connection and, if enabled, load sqlite-vec into it."""
        conn = super()._connect(*args, **kwargs)
        if self.load_vec_extension:
            self._load_vec_extension(conn)
        return conn

    def _load_vec_extension(self, conn):
        """Load the sqlite-vec extension into ``conn``.

        Extension loading is enabled only for the duration of the load and
        is switched off again even if loading fails; any exception raised by
        ``sqlite_vec.load`` still propagates to the caller.
        """
        conn.enable_load_extension(True)
        try:
            sqlite_vec.load(conn)
        finally:
            # Never leave arbitrary extension loading enabled on the handle.
            conn.enable_load_extension(False)
|
||||
Loading…
Reference in New Issue
Block a user