load sqlite_vec in embeddings manager

This commit is contained in:
Josh Hawkins 2024-10-04 13:20:02 -05:00
parent affa7bdc77
commit 139c8652a9

View File

@@ -1,13 +1,15 @@
"""ChromaDB embeddings database.""" """SQLite-vec embeddings database."""
import json import json
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import os
import signal import signal
import threading import threading
from types import FrameType from types import FrameType
from typing import Optional from typing import Optional
import sqlite_vec
from playhouse.sqliteq import SqliteQueueDatabase from playhouse.sqliteq import SqliteQueueDatabase
from setproctitle import setproctitle from setproctitle import setproctitle
@@ -53,13 +55,19 @@ def manage_embeddings(config: FrigateConfig) -> None:
models = [Event] models = [Event]
db.bind(models) db.bind(models)
embeddings = Embeddings() conn = db.connection()
conn.enable_load_extension(True)
sqlite_vec.load(conn)
conn.enable_load_extension(False)
embeddings = Embeddings(db)
# Check if we need to re-index events # Check if we need to re-index events
if config.semantic_search.reindex: if config.semantic_search.reindex:
embeddings.reindex() embeddings.reindex()
maintainer = EmbeddingMaintainer( maintainer = EmbeddingMaintainer(
db,
config, config,
stop_event, stop_event,
) )
@@ -67,14 +75,17 @@ def manage_embeddings(config: FrigateConfig) -> None:
class EmbeddingsContext: class EmbeddingsContext:
def __init__(self): def __init__(self, db: SqliteQueueDatabase):
self.embeddings = Embeddings() self.db = db
self.embeddings = Embeddings(db)
self.thumb_stats = ZScoreNormalization() self.thumb_stats = ZScoreNormalization()
self.desc_stats = ZScoreNormalization() self.desc_stats = ZScoreNormalization()
logger.info(f"Initializing db: {self.db}")
# load stats from disk # load stats from disk
try: try:
with open(f"{CONFIG_DIR}/.search_stats.json", "r") as f: with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "r") as f:
data = json.loads(f.read()) data = json.loads(f.read())
self.thumb_stats.from_dict(data["thumb_stats"]) self.thumb_stats.from_dict(data["thumb_stats"])
self.desc_stats.from_dict(data["desc_stats"]) self.desc_stats.from_dict(data["desc_stats"])
@@ -87,5 +98,5 @@ class EmbeddingsContext:
"thumb_stats": self.thumb_stats.to_dict(), "thumb_stats": self.thumb_stats.to_dict(),
"desc_stats": self.desc_stats.to_dict(), "desc_stats": self.desc_stats.to_dict(),
} }
with open(f"{CONFIG_DIR}/.search_stats.json", "w") as f: with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f:
f.write(json.dumps(contents)) json.dump(contents, f)