cleanup embeddings inferences

Nicolas Mowen 2025-01-04 15:20:48 -07:00
parent 3faadb633d
commit ccf8848143
2 changed files with 36 additions and 3 deletions


@@ -1,6 +1,7 @@
"""SQLite-vec embeddings database."""
import base64
import datetime
import logging
import os
import time
@@ -21,6 +22,7 @@ from frigate.types import ModelStatusTypesEnum
from frigate.util.builtin import serialize
from .functions.onnx import GenericONNXEmbedding, ModelTypeEnum
from .types import EmbeddingsMetrics
logger = logging.getLogger(__name__)
@@ -59,9 +61,15 @@ def get_metadata(event: Event) -> dict:
class Embeddings:
    """SQLite-vec embeddings database."""
    def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase) -> None:
    def __init__(
        self,
        config: FrigateConfig,
        db: SqliteVecQueueDatabase,
        metrics: EmbeddingsMetrics,
    ) -> None:
        self.config = config
        self.db = db
        self.metrics = metrics
        self.requestor = InterProcessRequestor()
        # Create tables if they don't exist
@@ -173,6 +181,7 @@ class Embeddings:
        @param: thumbnail bytes in jpg format
        @param: upsert If embedding should be upserted into vec DB
        """
        start = datetime.datetime.now().timestamp()
        # Convert thumbnail bytes to PIL Image
        embedding = self.vision_embedding([thumbnail])[0]
@@ -185,6 +194,11 @@ class Embeddings:
                (event_id, serialize(embedding)),
            )
        duration = datetime.datetime.now().timestamp() - start
        self.metrics.image_embeddings_fps.value = (
            self.metrics.image_embeddings_fps.value * 9 + duration
        ) / 10
        return embedding
    def batch_embed_thumbnail(
@@ -195,6 +209,7 @@ class Embeddings:
        @param: event_thumbs Map of Event IDs in DB to thumbnail bytes in jpg format
        @param: upsert If embedding should be upserted into vec DB
        """
        start = datetime.datetime.now().timestamp()
        ids = list(event_thumbs.keys())
        embeddings = self.vision_embedding(list(event_thumbs.values()))
@@ -213,11 +228,17 @@ class Embeddings:
                items,
            )
        duration = datetime.datetime.now().timestamp() - start
        self.metrics.text_embeddings_sps.value = (
            self.metrics.text_embeddings_sps.value * 9 + (duration / len(ids))
        ) / 10
        return embeddings
    def embed_description(
        self, event_id: str, description: str, upsert: bool = True
    ) -> ndarray:
        start = datetime.datetime.now().timestamp()
        embedding = self.text_embedding([description])[0]
        if upsert:
@@ -229,11 +250,17 @@ class Embeddings:
                (event_id, serialize(embedding)),
            )
        duration = datetime.datetime.now().timestamp() - start
        self.metrics.text_embeddings_sps.value = (
            self.metrics.text_embeddings_sps.value * 9 + duration
        ) / 10
        return embedding
    def batch_embed_description(
        self, event_descriptions: dict[str, str], upsert: bool = True
    ) -> ndarray:
        start = datetime.datetime.now().timestamp()
        # upsert embeddings one by one to avoid token limit
        embeddings = []
@@ -256,6 +283,11 @@ class Embeddings:
                items,
            )
        duration = datetime.datetime.now().timestamp() - start
        self.metrics.text_embeddings_sps.value = (
            self.metrics.text_embeddings_sps.value * 9 + (duration / len(ids))
        ) / 10
        return embeddings
    def reindex(self) -> None:
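
The duration bookkeeping added above feeds a simple exponential moving average: each new inference time is blended into the stored value with a 9:1 weighting toward the previous average, and the batch paths divide the total duration by the number of items first so the metric tracks per-item time. A minimal standalone sketch of that update rule, assuming a metrics holder with shared multiprocessing values similar to EmbeddingsMetrics (the container and field setup below are illustrative, not the exact Frigate definitions):

# Standalone sketch of the smoothing used in this commit; InferenceMetrics
# is a hypothetical stand-in for the EmbeddingsMetrics container.
import datetime
import multiprocessing as mp


class InferenceMetrics:
    def __init__(self) -> None:
        # shared doubles so another process (e.g. stats reporting) can read them
        self.image_embeddings_fps = mp.Value("d", 0.0)
        self.text_embeddings_sps = mp.Value("d", 0.0)


def update_average(metric, duration: float) -> None:
    # exponential moving average: keep 9 parts of the old value,
    # blend in 1 part of the newest sample
    metric.value = (metric.value * 9 + duration) / 10


# usage: time one embedding call and fold the duration into the average
metrics = InferenceMetrics()
start = datetime.datetime.now().timestamp()
# ... run the embedding inference here ...
update_average(metrics.image_embeddings_fps, datetime.datetime.now().timestamp() - start)
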


@@ -62,7 +62,7 @@ class EmbeddingMaintainer(threading.Thread):
        super().__init__(name="embeddings_maintainer")
        self.config = config
        self.metrics = metrics
        self.embeddings = Embeddings(config, db)
        self.embeddings = Embeddings(config, db, metrics)
        # Check if we need to re-index events
        if config.semantic_search.reindex:
@@ -139,7 +139,8 @@
                    )
                elif topic == EmbeddingsRequestEnum.generate_search.value:
                    return serialize(
                        self.embeddings.text_embedding([data])[0], pack=False
                        self.embeddings.embed_description("", data, upsert=False),
                        pack=False,
                    )
                elif topic == EmbeddingsRequestEnum.register_face.value:
                    if not self.face_recognition_enabled:
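
With the second hunk, generate_search requests no longer call text_embedding directly; they reuse embed_description with an empty event id and upsert=False, so search queries are timed by the same text-embedding metric without being written to the vector DB. A rough sketch of that call path under the same assumption, with a stubbed-out embedding model (the class and stub below are illustrative):

# Rough sketch of the generate_search path after this change; the stub
# text_embedding model and EmbeddingsSketch class are illustrative only.
import numpy as np


class EmbeddingsSketch:
    def text_embedding(self, texts: list[str]) -> list[np.ndarray]:
        # stand-in for the real text embedding model
        return [np.zeros(768, dtype=np.float32) for _ in texts]

    def embed_description(
        self, event_id: str, description: str, upsert: bool = True
    ) -> np.ndarray:
        embedding = self.text_embedding([description])[0]
        if upsert:
            pass  # write (event_id, embedding) to the vec DB, omitted here
        # the real method also updates text_embeddings_sps around this call
        return embedding


# a search query is embedded and measured, but never stored
query = EmbeddingsSketch().embed_description("", "person near the front door", upsert=False)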