cleanup embeddings inferences

Nicolas Mowen 2025-01-04 15:20:48 -07:00
parent 3faadb633d
commit ccf8848143
2 changed files with 36 additions and 3 deletions


@@ -1,6 +1,7 @@
"""SQLite-vec embeddings database."""

import base64
+import datetime
import logging
import os
import time
@@ -21,6 +22,7 @@ from frigate.types import ModelStatusTypesEnum
from frigate.util.builtin import serialize

from .functions.onnx import GenericONNXEmbedding, ModelTypeEnum
+from .types import EmbeddingsMetrics

logger = logging.getLogger(__name__)
@@ -59,9 +61,15 @@ def get_metadata(event: Event) -> dict:
class Embeddings:
    """SQLite-vec embeddings database."""

-    def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase) -> None:
+    def __init__(
+        self,
+        config: FrigateConfig,
+        db: SqliteVecQueueDatabase,
+        metrics: EmbeddingsMetrics,
+    ) -> None:
        self.config = config
        self.db = db
+        self.metrics = metrics
        self.requestor = InterProcessRequestor()

        # Create tables if they don't exist
@@ -173,6 +181,7 @@ class Embeddings:
        @param: thumbnail bytes in jpg format
        @param: upsert If embedding should be upserted into vec DB
        """
+        start = datetime.datetime.now().timestamp()
        # Convert thumbnail bytes to PIL Image
        embedding = self.vision_embedding([thumbnail])[0]
@@ -185,6 +194,11 @@
                (event_id, serialize(embedding)),
            )

+        duration = datetime.datetime.now().timestamp() - start
+        self.metrics.image_embeddings_fps.value = (
+            self.metrics.image_embeddings_fps.value * 9 + duration
+        ) / 10
+
        return embedding

    def batch_embed_thumbnail(
@@ -195,6 +209,7 @@ class Embeddings:
        @param: event_thumbs Map of Event IDs in DB to thumbnail bytes in jpg format
        @param: upsert If embedding should be upserted into vec DB
        """
+        start = datetime.datetime.now().timestamp()
        ids = list(event_thumbs.keys())
        embeddings = self.vision_embedding(list(event_thumbs.values()))
@@ -213,11 +228,17 @@
                items,
            )

+        duration = datetime.datetime.now().timestamp() - start
+        self.metrics.text_embeddings_sps.value = (
+            self.metrics.text_embeddings_sps.value * 9 + (duration / len(ids))
+        ) / 10
+
        return embeddings

    def embed_description(
        self, event_id: str, description: str, upsert: bool = True
    ) -> ndarray:
+        start = datetime.datetime.now().timestamp()
        embedding = self.text_embedding([description])[0]

        if upsert:
@@ -229,11 +250,17 @@
                (event_id, serialize(embedding)),
            )

+        duration = datetime.datetime.now().timestamp() - start
+        self.metrics.text_embeddings_sps.value = (
+            self.metrics.text_embeddings_sps.value * 9 + duration
+        ) / 10
+
        return embedding

    def batch_embed_description(
        self, event_descriptions: dict[str, str], upsert: bool = True
    ) -> ndarray:
+        start = datetime.datetime.now().timestamp()
        # upsert embeddings one by one to avoid token limit
        embeddings = []
@@ -256,6 +283,11 @@
                items,
            )

+        duration = datetime.datetime.now().timestamp() - start
+        self.metrics.text_embeddings_sps.value = (
+            self.metrics.text_embeddings_sps.value * 9 + (duration / len(ids))
+        ) / 10
+
        return embeddings

    def reindex(self) -> None:

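The timing added above feeds a simple decaying average: each measured inference duration is blended into the stored metric with a 1/10 weight, so the value tracks recent performance without jumping on a single slow call; in the batch methods the sample is duration / len(ids), i.e. the average time per item in the batch. A minimal sketch of that update outside of Frigate, assuming a plain local attribute in place of the shared .value used by the real metrics object:

# Sketch of the 9:1 decaying average applied to the embedding timing metrics.
# EmaTimer is illustrative only; the commit stores these values on a shared
# EmbeddingsMetrics object rather than a local class.
import datetime
import time


class EmaTimer:
    def __init__(self, initial: float = 0.0) -> None:
        self.value = initial

    def record(self, duration: float) -> None:
        # Keep 90% of the previous estimate and blend in 10% of the new sample.
        self.value = (self.value * 9 + duration) / 10


timer = EmaTimer()
for _ in range(3):
    start = datetime.datetime.now().timestamp()
    time.sleep(0.05)  # stand-in for a vision/text embedding inference
    timer.record(datetime.datetime.now().timestamp() - start)
print(f"smoothed inference time: {timer.value:.3f}s")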

@@ -62,7 +62,7 @@ class EmbeddingMaintainer(threading.Thread):
        super().__init__(name="embeddings_maintainer")
        self.config = config
        self.metrics = metrics
-        self.embeddings = Embeddings(config, db)
+        self.embeddings = Embeddings(config, db, metrics)

        # Check if we need to re-index events
        if config.semantic_search.reindex:
@@ -139,7 +139,8 @@ class EmbeddingMaintainer(threading.Thread):
            )
        elif topic == EmbeddingsRequestEnum.generate_search.value:
            return serialize(
-                self.embeddings.text_embedding([data])[0], pack=False
+                self.embeddings.embed_description("", data, upsert=False),
+                pack=False,
            )
        elif topic == EmbeddingsRequestEnum.register_face.value:
            if not self.face_recognition_enabled:
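On the maintainer side the same metrics object is now threaded into Embeddings, and generate_search requests go through embed_description("", data, upsert=False) rather than calling text_embedding directly, presumably so ad-hoc search queries are timed by the same text-embedding metric. For context, here is a rough sketch of how a shared metrics container of this shape could be wired up, assuming multiprocessing-backed values; the field names mirror the diff, but the class layout is an assumption, not Frigate's actual EmbeddingsMetrics:

# Illustrative stand-in for a cross-process metrics container; the real
# EmbeddingsMetrics imported from .types may be structured differently.
import multiprocessing as mp


class EmbeddingsMetricsSketch:
    def __init__(self) -> None:
        # Shared doubles so the embeddings process and any stats consumer
        # read the same smoothed timings.
        self.image_embeddings_fps = mp.Value("d", 0.0)
        self.text_embeddings_sps = mp.Value("d", 0.0)


metrics = EmbeddingsMetricsSketch()
# Same 9:1 blend as the commit, applied to a 0.25s sample duration.
with metrics.image_embeddings_fps.get_lock():
    metrics.image_embeddings_fps.value = (
        metrics.image_embeddings_fps.value * 9 + 0.25
    ) / 10
print(metrics.image_embeddings_fps.value)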