Add device config for semantic search

This commit is contained in:
Nicolas Mowen 2024-10-10 06:32:32 -06:00
parent 1e6ee1a636
commit f4eef74fdf
7 changed files with 16 additions and 10 deletions

View File

@@ -276,7 +276,7 @@ class FrigateApp:
def init_embeddings_client(self) -> None: def init_embeddings_client(self) -> None:
if self.config.semantic_search.enabled: if self.config.semantic_search.enabled:
# Create a client for other processes to use # Create a client for other processes to use
self.embeddings = EmbeddingsContext(self.db) self.embeddings = EmbeddingsContext(self.config, self.db)
def init_external_event_processor(self) -> None: def init_external_event_processor(self) -> None:
self.external_event_processor = ExternalEventProcessor(self.config) self.external_event_processor = ExternalEventProcessor(self.config)

View File

@@ -12,3 +12,4 @@ class SemanticSearchConfig(FrigateBaseModel):
reindex: Optional[bool] = Field( reindex: Optional[bool] = Field(
default=False, title="Reindex all detections on startup." default=False, title="Reindex all detections on startup."
) )
device: str = Field(default="AUTO", title="Device Type")

View File

@@ -55,7 +55,7 @@ def manage_embeddings(config: FrigateConfig) -> None:
models = [Event] models = [Event]
db.bind(models) db.bind(models)
embeddings = Embeddings(db) embeddings = Embeddings(config.semantic_search, db)
# Check if we need to re-index events # Check if we need to re-index events
if config.semantic_search.reindex: if config.semantic_search.reindex:
@@ -70,8 +70,8 @@ def manage_embeddings(config: FrigateConfig) -> None:
class EmbeddingsContext: class EmbeddingsContext:
def __init__(self, db: SqliteVecQueueDatabase): def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase):
self.embeddings = Embeddings(db) self.embeddings = Embeddings(config.semantic_search, db)
self.thumb_stats = ZScoreNormalization() self.thumb_stats = ZScoreNormalization()
self.desc_stats = ZScoreNormalization() self.desc_stats = ZScoreNormalization()

View File

@@ -12,6 +12,7 @@ from PIL import Image
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from frigate.comms.inter_process import InterProcessRequestor from frigate.comms.inter_process import InterProcessRequestor
from frigate.config.semantic_search import SemanticSearchConfig
from frigate.const import UPDATE_MODEL_STATE from frigate.const import UPDATE_MODEL_STATE
from frigate.db.sqlitevecq import SqliteVecQueueDatabase from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event from frigate.models import Event
@@ -80,7 +81,10 @@ def deserialize(bytes_data: bytes) -> List[float]:
class Embeddings: class Embeddings:
"""SQLite-vec embeddings database.""" """SQLite-vec embeddings database."""
def __init__(self, db: SqliteVecQueueDatabase) -> None: def __init__(
self, config: SemanticSearchConfig, db: SqliteVecQueueDatabase
) -> None:
self.config = config
self.db = db self.db = db
self.requestor = InterProcessRequestor() self.requestor = InterProcessRequestor()
@@ -118,7 +122,7 @@ class Embeddings:
}, },
embedding_function=jina_text_embedding_function, embedding_function=jina_text_embedding_function,
model_type="text", model_type="text",
force_cpu=True, device="CPU",
) )
self.vision_embedding = GenericONNXEmbedding( self.vision_embedding = GenericONNXEmbedding(
@@ -130,6 +134,7 @@ class Embeddings:
}, },
embedding_function=jina_vision_embedding_function, embedding_function=jina_vision_embedding_function,
model_type="vision", model_type="vision",
device=self.config.device,
) )
def _create_tables(self): def _create_tables(self):

View File

@@ -42,7 +42,7 @@ class GenericONNXEmbedding:
embedding_function: Callable[[List[np.ndarray]], np.ndarray], embedding_function: Callable[[List[np.ndarray]], np.ndarray],
model_type: str, model_type: str,
tokenizer_file: Optional[str] = None, tokenizer_file: Optional[str] = None,
force_cpu: bool = False, device: str = "AUTO",
): ):
self.model_name = model_name self.model_name = model_name
self.model_file = model_file self.model_file = model_file
@@ -51,7 +51,7 @@ class GenericONNXEmbedding:
self.embedding_function = embedding_function self.embedding_function = embedding_function
self.model_type = model_type # 'text' or 'vision' self.model_type = model_type # 'text' or 'vision'
self.providers, self.provider_options = get_ort_providers( self.providers, self.provider_options = get_ort_providers(
force_cpu=force_cpu, requires_fp16=True force_cpu=device == "CPU", requires_fp16=True, openvino_device=device
) )
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name) self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)

View File

@@ -42,7 +42,7 @@ class EmbeddingMaintainer(threading.Thread):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.name = "embeddings_maintainer" self.name = "embeddings_maintainer"
self.config = config self.config = config
self.embeddings = Embeddings(db) self.embeddings = Embeddings(config.semantic_search, db)
self.event_subscriber = EventUpdateSubscriber() self.event_subscriber = EventUpdateSubscriber()
self.event_end_subscriber = EventEndSubscriber() self.event_end_subscriber = EventEndSubscriber()
self.event_metadata_subscriber = EventMetadataSubscriber( self.event_metadata_subscriber = EventMetadataSubscriber(

View File

@@ -36,7 +36,7 @@ class EventCleanup(threading.Thread):
self.camera_labels: dict[str, dict[str, any]] = {} self.camera_labels: dict[str, dict[str, any]] = {}
if self.config.semantic_search.enabled: if self.config.semantic_search.enabled:
self.embeddings = Embeddings(self.db) self.embeddings = Embeddings(self.config.semantic_search, self.db)
def get_removed_camera_labels(self) -> list[Event]: def get_removed_camera_labels(self) -> list[Event]:
"""Get a list of distinct labels for removed cameras.""" """Get a list of distinct labels for removed cameras."""