mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 15:45:27 +03:00
extend the SqliteQueueDatabase class and use peewee db.execute_sql
This commit is contained in:
parent
5181ea7b3d
commit
1b7f469daf
@ -384,8 +384,6 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
|||||||
|
|
||||||
context: EmbeddingsContext = request.app.embeddings
|
context: EmbeddingsContext = request.app.embeddings
|
||||||
|
|
||||||
logger.info(f"context: {context.embeddings}, conn: {context.embeddings.conn}")
|
|
||||||
|
|
||||||
selected_columns = [
|
selected_columns = [
|
||||||
Event.id,
|
Event.id,
|
||||||
Event.camera,
|
Event.camera,
|
||||||
@ -503,6 +501,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
|
|||||||
|
|
||||||
if "thumbnail" in search_types:
|
if "thumbnail" in search_types:
|
||||||
thumb_results = context.embeddings.search_thumbnail(query, limit)
|
thumb_results = context.embeddings.search_thumbnail(query, limit)
|
||||||
|
logger.info(f"thumb results: {thumb_results}")
|
||||||
|
|
||||||
if "description" in search_types:
|
if "description" in search_types:
|
||||||
desc_results = context.embeddings.search_description(query, limit)
|
desc_results = context.embeddings.search_description(query, limit)
|
||||||
|
|||||||
@ -9,11 +9,9 @@ from multiprocessing.synchronize import Event as MpEvent
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import sqlite_vec
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from peewee_migrate import Router
|
from peewee_migrate import Router
|
||||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||||
from playhouse.sqliteq import SqliteQueueDatabase
|
|
||||||
|
|
||||||
import frigate.util as util
|
import frigate.util as util
|
||||||
from frigate.api.auth import hash_password
|
from frigate.api.auth import hash_password
|
||||||
@ -41,6 +39,7 @@ from frigate.const import (
|
|||||||
RECORD_DIR,
|
RECORD_DIR,
|
||||||
)
|
)
|
||||||
from frigate.embeddings import EmbeddingsContext, manage_embeddings
|
from frigate.embeddings import EmbeddingsContext, manage_embeddings
|
||||||
|
from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase
|
||||||
from frigate.events.audio import AudioProcessor
|
from frigate.events.audio import AudioProcessor
|
||||||
from frigate.events.cleanup import EventCleanup
|
from frigate.events.cleanup import EventCleanup
|
||||||
from frigate.events.external import ExternalEventProcessor
|
from frigate.events.external import ExternalEventProcessor
|
||||||
@ -240,7 +239,7 @@ class FrigateApp:
|
|||||||
def bind_database(self) -> None:
|
def bind_database(self) -> None:
|
||||||
"""Bind db to the main process."""
|
"""Bind db to the main process."""
|
||||||
# NOTE: all db accessing processes need to be created before the db can be bound to the main process
|
# NOTE: all db accessing processes need to be created before the db can be bound to the main process
|
||||||
self.db = SqliteQueueDatabase(
|
self.db = SqliteVecQueueDatabase(
|
||||||
self.config.database.path,
|
self.config.database.path,
|
||||||
pragmas={
|
pragmas={
|
||||||
"auto_vacuum": "FULL", # Does not defragment database
|
"auto_vacuum": "FULL", # Does not defragment database
|
||||||
@ -250,6 +249,7 @@ class FrigateApp:
|
|||||||
timeout=max(
|
timeout=max(
|
||||||
60, 10 * len([c for c in self.config.cameras.values() if c.enabled])
|
60, 10 * len([c for c in self.config.cameras.values() if c.enabled])
|
||||||
),
|
),
|
||||||
|
load_vec_extension=self.config.semantic_search.enabled,
|
||||||
)
|
)
|
||||||
models = [
|
models = [
|
||||||
Event,
|
Event,
|
||||||
@ -264,14 +264,6 @@ class FrigateApp:
|
|||||||
]
|
]
|
||||||
self.db.bind(models)
|
self.db.bind(models)
|
||||||
|
|
||||||
if self.config.semantic_search.enabled:
|
|
||||||
# use existing db connection to load sqlite_vec extension
|
|
||||||
conn = self.db.connection()
|
|
||||||
conn.enable_load_extension(True)
|
|
||||||
sqlite_vec.load(conn)
|
|
||||||
conn.enable_load_extension(False)
|
|
||||||
logger.info(f"main connection: {self.db}")
|
|
||||||
|
|
||||||
def check_db_data_migrations(self) -> None:
|
def check_db_data_migrations(self) -> None:
|
||||||
# check if vacuum needs to be run
|
# check if vacuum needs to be run
|
||||||
if not os.path.exists(f"{CONFIG_DIR}/.exports"):
|
if not os.path.exists(f"{CONFIG_DIR}/.exports"):
|
||||||
@ -284,12 +276,9 @@ class FrigateApp:
|
|||||||
migrate_exports(self.config.ffmpeg, list(self.config.cameras.keys()))
|
migrate_exports(self.config.ffmpeg, list(self.config.cameras.keys()))
|
||||||
|
|
||||||
def init_embeddings_client(self) -> None:
|
def init_embeddings_client(self) -> None:
|
||||||
if not self.config.semantic_search.enabled:
|
if self.config.semantic_search.enabled:
|
||||||
self.embeddings = None
|
# Create a client for other processes to use
|
||||||
return
|
self.embeddings = EmbeddingsContext(self.db)
|
||||||
|
|
||||||
# Create a client for other processes to use
|
|
||||||
self.embeddings = EmbeddingsContext(self.db)
|
|
||||||
|
|
||||||
def init_external_event_processor(self) -> None:
|
def init_external_event_processor(self) -> None:
|
||||||
self.external_event_processor = ExternalEventProcessor(self.config)
|
self.external_event_processor = ExternalEventProcessor(self.config)
|
||||||
|
|||||||
@ -9,12 +9,11 @@ import threading
|
|||||||
from types import FrameType
|
from types import FrameType
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import sqlite_vec
|
|
||||||
from playhouse.sqliteq import SqliteQueueDatabase
|
|
||||||
from setproctitle import setproctitle
|
from setproctitle import setproctitle
|
||||||
|
|
||||||
from frigate.config import FrigateConfig
|
from frigate.config import FrigateConfig
|
||||||
from frigate.const import CONFIG_DIR
|
from frigate.const import CONFIG_DIR
|
||||||
|
from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase
|
||||||
from frigate.models import Event
|
from frigate.models import Event
|
||||||
from frigate.util.services import listen
|
from frigate.util.services import listen
|
||||||
|
|
||||||
@ -43,7 +42,7 @@ def manage_embeddings(config: FrigateConfig) -> None:
|
|||||||
listen()
|
listen()
|
||||||
|
|
||||||
# Configure Frigate DB
|
# Configure Frigate DB
|
||||||
db = SqliteQueueDatabase(
|
db = SqliteVecQueueDatabase(
|
||||||
config.database.path,
|
config.database.path,
|
||||||
pragmas={
|
pragmas={
|
||||||
"auto_vacuum": "FULL", # Does not defragment database
|
"auto_vacuum": "FULL", # Does not defragment database
|
||||||
@ -51,15 +50,11 @@ def manage_embeddings(config: FrigateConfig) -> None:
|
|||||||
"synchronous": "NORMAL", # Safe when using WAL https://www.sqlite.org/pragma.html#pragma_synchronous
|
"synchronous": "NORMAL", # Safe when using WAL https://www.sqlite.org/pragma.html#pragma_synchronous
|
||||||
},
|
},
|
||||||
timeout=max(60, 10 * len([c for c in config.cameras.values() if c.enabled])),
|
timeout=max(60, 10 * len([c for c in config.cameras.values() if c.enabled])),
|
||||||
|
load_vec_extension=True,
|
||||||
)
|
)
|
||||||
models = [Event]
|
models = [Event]
|
||||||
db.bind(models)
|
db.bind(models)
|
||||||
|
|
||||||
conn = db.connection()
|
|
||||||
conn.enable_load_extension(True)
|
|
||||||
sqlite_vec.load(conn)
|
|
||||||
conn.enable_load_extension(False)
|
|
||||||
|
|
||||||
embeddings = Embeddings(db)
|
embeddings = Embeddings(db)
|
||||||
|
|
||||||
# Check if we need to re-index events
|
# Check if we need to re-index events
|
||||||
@ -75,9 +70,9 @@ def manage_embeddings(config: FrigateConfig) -> None:
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingsContext:
|
class EmbeddingsContext:
|
||||||
def __init__(self, db: SqliteQueueDatabase):
|
def __init__(self, db: SqliteVecQueueDatabase):
|
||||||
self.db = db
|
self.db = db
|
||||||
self.embeddings = Embeddings(db)
|
self.embeddings = Embeddings(self.db)
|
||||||
self.thumb_stats = ZScoreNormalization()
|
self.thumb_stats = ZScoreNormalization()
|
||||||
self.desc_stats = ZScoreNormalization()
|
self.desc_stats = ZScoreNormalization()
|
||||||
|
|
||||||
|
|||||||
@ -5,12 +5,12 @@ import io
|
|||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from playhouse.shortcuts import model_to_dict
|
from playhouse.shortcuts import model_to_dict
|
||||||
from playhouse.sqliteq import SqliteQueueDatabase
|
|
||||||
|
|
||||||
|
from frigate.embeddings.sqlitevecq import SqliteVecQueueDatabase
|
||||||
from frigate.models import Event
|
from frigate.models import Event
|
||||||
|
|
||||||
from .functions.clip import ClipEmbedding
|
from .functions.clip import ClipEmbedding
|
||||||
@ -63,10 +63,10 @@ def deserialize(bytes_data: bytes) -> List[float]:
|
|||||||
class Embeddings:
|
class Embeddings:
|
||||||
"""SQLite-vec embeddings database."""
|
"""SQLite-vec embeddings database."""
|
||||||
|
|
||||||
def __init__(self, db: SqliteQueueDatabase) -> None:
|
def __init__(self, db: SqliteVecQueueDatabase) -> None:
|
||||||
self.conn = db.connection() # Store the database connection instance
|
self.db = db
|
||||||
|
|
||||||
# create tables if they don't exist
|
# Create tables if they don't exist
|
||||||
self._create_tables()
|
self._create_tables()
|
||||||
|
|
||||||
self.clip_embedding = ClipEmbedding(model="ViT-B/32")
|
self.clip_embedding = ClipEmbedding(model="ViT-B/32")
|
||||||
@ -76,7 +76,7 @@ class Embeddings:
|
|||||||
|
|
||||||
def _create_tables(self):
|
def _create_tables(self):
|
||||||
# Create vec0 virtual table for thumbnail embeddings
|
# Create vec0 virtual table for thumbnail embeddings
|
||||||
self.conn.execute("""
|
self.db.execute_sql("""
|
||||||
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
|
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
thumbnail_embedding FLOAT[512]
|
thumbnail_embedding FLOAT[512]
|
||||||
@ -84,7 +84,7 @@ class Embeddings:
|
|||||||
""")
|
""")
|
||||||
|
|
||||||
# Create vec0 virtual table for description embeddings
|
# Create vec0 virtual table for description embeddings
|
||||||
self.conn.execute("""
|
self.db.execute_sql("""
|
||||||
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
|
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
description_embedding FLOAT[384]
|
description_embedding FLOAT[384]
|
||||||
@ -97,79 +97,65 @@ class Embeddings:
|
|||||||
# Generate embedding using CLIP
|
# Generate embedding using CLIP
|
||||||
embedding = self.clip_embedding([image])[0]
|
embedding = self.clip_embedding([image])[0]
|
||||||
|
|
||||||
# sqlite_vec virtual tables don't support upsert, check if event_id exists
|
self.db.execute_sql(
|
||||||
cursor = self.conn.execute(
|
"""
|
||||||
"SELECT 1 FROM vec_thumbnails WHERE id = ?", (event_id,)
|
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
|
||||||
|
VALUES(?, ?)
|
||||||
|
""",
|
||||||
|
(event_id, serialize(embedding)),
|
||||||
)
|
)
|
||||||
row = cursor.fetchone()
|
|
||||||
|
|
||||||
if row is None:
|
|
||||||
# Insert if the event_id does not exist
|
|
||||||
self.conn.execute(
|
|
||||||
"INSERT INTO vec_thumbnails(id, thumbnail_embedding) VALUES(?, ?)",
|
|
||||||
[event_id, serialize(embedding)],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Update if the event_id already exists
|
|
||||||
self.conn.execute(
|
|
||||||
"UPDATE vec_thumbnails SET thumbnail_embedding = ? WHERE id = ?",
|
|
||||||
[serialize(embedding), event_id],
|
|
||||||
)
|
|
||||||
|
|
||||||
def upsert_description(self, event_id: str, description: str):
|
def upsert_description(self, event_id: str, description: str):
|
||||||
# Generate embedding using MiniLM
|
# Generate embedding using MiniLM
|
||||||
embedding = self.minilm_embedding([description])[0]
|
embedding = self.minilm_embedding([description])[0]
|
||||||
|
|
||||||
# sqlite_vec virtual tables don't support upsert, check if event_id exists
|
self.db.execute_sql(
|
||||||
cursor = self.conn.execute(
|
"""
|
||||||
"SELECT 1 FROM vec_descriptions WHERE id = ?", (event_id,)
|
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
|
||||||
|
VALUES(?, ?)
|
||||||
|
""",
|
||||||
|
(event_id, serialize(embedding)),
|
||||||
)
|
)
|
||||||
row = cursor.fetchone()
|
|
||||||
|
|
||||||
if row is None:
|
|
||||||
# Insert if the event_id does not exist
|
|
||||||
self.conn.execute(
|
|
||||||
"INSERT INTO vec_descriptions(id, description_embedding) VALUES(?, ?)",
|
|
||||||
[event_id, serialize(embedding)],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Update if the event_id already exists
|
|
||||||
self.conn.execute(
|
|
||||||
"UPDATE vec_descriptions SET description_embedding = ? WHERE id = ?",
|
|
||||||
[serialize(embedding), event_id],
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete_thumbnail(self, event_ids: List[str]) -> None:
|
def delete_thumbnail(self, event_ids: List[str]) -> None:
|
||||||
ids = ", ".join("?" for _ in event_ids)
|
ids = ",".join(["?" for _ in event_ids])
|
||||||
|
self.db.execute_sql(
|
||||||
self.conn.execute(
|
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
|
||||||
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", tuple(event_ids)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_description(self, event_ids: List[str]) -> None:
|
def delete_description(self, event_ids: List[str]) -> None:
|
||||||
ids = ", ".join("?" for _ in event_ids)
|
ids = ",".join(["?" for _ in event_ids])
|
||||||
|
self.db.execute_sql(
|
||||||
self.conn.execute(
|
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
|
||||||
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", tuple(event_ids)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def search_thumbnail(self, event_id: str, limit=10) -> List[Tuple[str, float]]:
|
def search_thumbnail(
|
||||||
# check if it's already embedded
|
self, query: Union[Event, str], limit=10
|
||||||
cursor = self.conn.execute(
|
) -> List[Tuple[str, float]]:
|
||||||
"SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?", (event_id,)
|
if isinstance(query, Event):
|
||||||
)
|
cursor = self.db.execute_sql(
|
||||||
row = cursor.fetchone()
|
"""
|
||||||
if row:
|
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
|
||||||
query_embedding = deserialize(row[0])
|
""",
|
||||||
else:
|
[query.id],
|
||||||
# If not embedded, fetch the thumbnail from the Event table and embed it
|
)
|
||||||
event = Event.get_by_id(event_id)
|
|
||||||
thumbnail = base64.b64decode(event.thumbnail)
|
|
||||||
image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
|
|
||||||
query_embedding = self.clip_embedding([image])[0]
|
|
||||||
self.upsert_thumbnail(event_id, thumbnail)
|
|
||||||
|
|
||||||
cursor = self.conn.execute(
|
row = cursor.fetchone() if cursor else None
|
||||||
|
|
||||||
|
if row:
|
||||||
|
query_embedding = deserialize(
|
||||||
|
row[0]
|
||||||
|
) # Deserialize the thumbnail embedding
|
||||||
|
else:
|
||||||
|
# If no embedding found, generate it
|
||||||
|
thumbnail = base64.b64decode(query.thumbnail)
|
||||||
|
self.upsert_thumbnail(query.id, thumbnail)
|
||||||
|
image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
|
||||||
|
query = self.clip_embedding([image])[0]
|
||||||
|
|
||||||
|
query_embedding = self.clip_embedding([query])[0]
|
||||||
|
|
||||||
|
results = self.db.execute_sql(
|
||||||
"""
|
"""
|
||||||
SELECT
|
SELECT
|
||||||
vec_thumbnails.id,
|
vec_thumbnails.id,
|
||||||
@ -178,14 +164,15 @@ class Embeddings:
|
|||||||
WHERE thumbnail_embedding MATCH ?
|
WHERE thumbnail_embedding MATCH ?
|
||||||
AND k = ?
|
AND k = ?
|
||||||
ORDER BY distance
|
ORDER BY distance
|
||||||
""",
|
""",
|
||||||
[serialize(query_embedding), limit],
|
(serialize(query_embedding), limit),
|
||||||
)
|
).fetchall()
|
||||||
return cursor.fetchall()
|
|
||||||
|
return results
|
||||||
|
|
||||||
def search_description(self, query_text: str, limit=10) -> List[Tuple[str, float]]:
|
def search_description(self, query_text: str, limit=10) -> List[Tuple[str, float]]:
|
||||||
query_embedding = self.minilm_embedding([query_text])[0]
|
query_embedding = self.minilm_embedding([query_text])[0]
|
||||||
cursor = self.conn.execute(
|
results = self.db.execute_sql(
|
||||||
"""
|
"""
|
||||||
SELECT
|
SELECT
|
||||||
vec_descriptions.id,
|
vec_descriptions.id,
|
||||||
@ -194,13 +181,13 @@ class Embeddings:
|
|||||||
WHERE description_embedding MATCH ?
|
WHERE description_embedding MATCH ?
|
||||||
AND k = ?
|
AND k = ?
|
||||||
ORDER BY distance
|
ORDER BY distance
|
||||||
""",
|
""",
|
||||||
[serialize(query_embedding), limit],
|
(serialize(query_embedding), limit),
|
||||||
)
|
).fetchall()
|
||||||
return cursor.fetchall()
|
|
||||||
|
return results
|
||||||
|
|
||||||
def reindex(self) -> None:
|
def reindex(self) -> None:
|
||||||
"""Reindex all event embeddings."""
|
|
||||||
logger.info("Indexing event embeddings...")
|
logger.info("Indexing event embeddings...")
|
||||||
|
|
||||||
st = time.time()
|
st = time.time()
|
||||||
|
|||||||
19
frigate/embeddings/sqlitevecq.py
Normal file
19
frigate/embeddings/sqlitevecq.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import sqlite_vec
|
||||||
|
from playhouse.sqliteq import SqliteQueueDatabase
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteVecQueueDatabase(SqliteQueueDatabase):
    """Queue-backed SQLite database that can load sqlite-vec on each connection.

    Behaves like ``SqliteQueueDatabase`` but, when ``load_vec_extension`` is
    true, loads the ``sqlite_vec`` extension into every connection that is
    opened through ``_connect``.
    """

    def __init__(self, *args, load_vec_extension=False, **kwargs):
        """Create the database.

        Args:
            *args: Positional arguments forwarded to ``SqliteQueueDatabase``.
            load_vec_extension: When true, load the sqlite-vec extension into
                every new connection.
            **kwargs: Keyword arguments forwarded to ``SqliteQueueDatabase``.
        """
        # Set the flag BEFORE delegating to the parent constructor: the parent
        # may start its writer thread during __init__ (autostart), and that
        # thread opens a connection via _connect(), which reads this attribute.
        # Assigning it afterwards leaves a window for an AttributeError.
        self.load_vec_extension = load_vec_extension
        super().__init__(*args, **kwargs)

    def _connect(self, *args, **kwargs):
        """Open a connection and, if enabled, load sqlite-vec into it.

        Returns:
            The connection object produced by the parent ``_connect``.
        """
        conn = super()._connect(*args, **kwargs)
        if self.load_vec_extension:
            self._load_vec_extension(conn)
        return conn

    def _load_vec_extension(self, conn):
        """Load the sqlite-vec extension into ``conn``.

        Extension loading is switched on only for the duration of the load
        and is always switched back off, even if ``sqlite_vec.load`` raises,
        so a failed load never leaves the connection able to load arbitrary
        extensions.
        """
        conn.enable_load_extension(True)
        try:
            sqlite_vec.load(conn)
        finally:
            conn.enable_load_extension(False)
|
||||||
Loading…
Reference in New Issue
Block a user