From f4eef74fdfbed8fba16decd1cc663f778c73399d Mon Sep 17 00:00:00 2001 From: Nicolas Mowen Date: Thu, 10 Oct 2024 06:32:32 -0600 Subject: [PATCH] Add device config for semantic search --- frigate/app.py | 2 +- frigate/config/semantic_search.py | 1 + frigate/embeddings/__init__.py | 6 +++--- frigate/embeddings/embeddings.py | 9 +++++++-- frigate/embeddings/functions/onnx.py | 4 ++-- frigate/embeddings/maintainer.py | 2 +- frigate/events/cleanup.py | 2 +- 7 files changed, 16 insertions(+), 10 deletions(-) diff --git a/frigate/app.py b/frigate/app.py index 1f652ecb2..253bebf89 100644 --- a/frigate/app.py +++ b/frigate/app.py @@ -276,7 +276,7 @@ class FrigateApp: def init_embeddings_client(self) -> None: if self.config.semantic_search.enabled: # Create a client for other processes to use - self.embeddings = EmbeddingsContext(self.db) + self.embeddings = EmbeddingsContext(self.config, self.db) def init_external_event_processor(self) -> None: self.external_event_processor = ExternalEventProcessor(self.config) diff --git a/frigate/config/semantic_search.py b/frigate/config/semantic_search.py index a2274e041..ecdcd12d1 100644 --- a/frigate/config/semantic_search.py +++ b/frigate/config/semantic_search.py @@ -12,3 +12,4 @@ class SemanticSearchConfig(FrigateBaseModel): reindex: Optional[bool] = Field( default=False, title="Reindex all detections on startup." ) + device: str = Field(default="AUTO", title="Device Type") diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py index 381d95ed1..1c384e90f 100644 --- a/frigate/embeddings/__init__.py +++ b/frigate/embeddings/__init__.py @@ -55,7 +55,7 @@ def manage_embeddings(config: FrigateConfig) -> None: models = [Event] db.bind(models) - embeddings = Embeddings(db) + embeddings = Embeddings(config.semantic_search, db) # Check if we need to re-index events if config.semantic_search.reindex: @@ -70,8 +70,8 @@ def manage_embeddings(config: FrigateConfig) -> None: class EmbeddingsContext: - def __init__(self, db: SqliteVecQueueDatabase): - self.embeddings = Embeddings(db) + def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase): + self.embeddings = Embeddings(config.semantic_search, db) self.thumb_stats = ZScoreNormalization() self.desc_stats = ZScoreNormalization() diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py index 35d76dece..bb2565048 100644 --- a/frigate/embeddings/embeddings.py +++ b/frigate/embeddings/embeddings.py @@ -12,6 +12,7 @@ from PIL import Image from playhouse.shortcuts import model_to_dict from frigate.comms.inter_process import InterProcessRequestor +from frigate.config.semantic_search import SemanticSearchConfig from frigate.const import UPDATE_MODEL_STATE from frigate.db.sqlitevecq import SqliteVecQueueDatabase from frigate.models import Event @@ -80,7 +81,10 @@ def deserialize(bytes_data: bytes) -> List[float]: class Embeddings: """SQLite-vec embeddings database.""" - def __init__(self, db: SqliteVecQueueDatabase) -> None: + def __init__( + self, config: SemanticSearchConfig, db: SqliteVecQueueDatabase + ) -> None: + self.config = config self.db = db self.requestor = InterProcessRequestor() @@ -118,7 +122,7 @@ class Embeddings: }, embedding_function=jina_text_embedding_function, model_type="text", - force_cpu=True, + device="CPU", ) self.vision_embedding = GenericONNXEmbedding( @@ -130,6 +134,7 @@ class Embeddings: }, embedding_function=jina_vision_embedding_function, model_type="vision", + device=self.config.device, ) def _create_tables(self): diff --git a/frigate/embeddings/functions/onnx.py b/frigate/embeddings/functions/onnx.py index b5f15f391..08901b6a2 100644 --- a/frigate/embeddings/functions/onnx.py +++ b/frigate/embeddings/functions/onnx.py @@ -42,7 +42,7 @@ class GenericONNXEmbedding: embedding_function: Callable[[List[np.ndarray]], np.ndarray], model_type: str, tokenizer_file: Optional[str] = None, - force_cpu: bool = False, + device: str = "AUTO", ): self.model_name = model_name self.model_file = model_file @@ -51,7 +51,7 @@ class GenericONNXEmbedding: self.embedding_function = embedding_function self.model_type = model_type # 'text' or 'vision' self.providers, self.provider_options = get_ort_providers( - force_cpu=force_cpu, requires_fp16=True + force_cpu=device == "CPU", requires_fp16=True, openvino_device=device ) self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name) diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index 4cb6a3bca..c95cb5050 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -42,7 +42,7 @@ class EmbeddingMaintainer(threading.Thread): threading.Thread.__init__(self) self.name = "embeddings_maintainer" self.config = config - self.embeddings = Embeddings(db) + self.embeddings = Embeddings(config.semantic_search, db) self.event_subscriber = EventUpdateSubscriber() self.event_end_subscriber = EventEndSubscriber() self.event_metadata_subscriber = EventMetadataSubscriber( diff --git a/frigate/events/cleanup.py b/frigate/events/cleanup.py index 74f4a59ac..828b295b4 100644 --- a/frigate/events/cleanup.py +++ b/frigate/events/cleanup.py @@ -36,7 +36,7 @@ class EventCleanup(threading.Thread): self.camera_labels: dict[str, dict[str, any]] = {} if self.config.semantic_search.enabled: - self.embeddings = Embeddings(self.db) + self.embeddings = Embeddings(self.config.semantic_search, self.db) def get_removed_camera_labels(self) -> list[Event]: """Get a list of distinct labels for removed cameras."""