diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py index 8707f6f37..5351912fd 100644 --- a/frigate/embeddings/embeddings.py +++ b/frigate/embeddings/embeddings.py @@ -123,73 +123,97 @@ class Embeddings: device="GPU" if config.model_size == "large" else "CPU", ) - def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray: + def embed_thumbnail( + self, event_id: str, thumbnail: bytes, upsert: bool = True + ) -> ndarray: + """Embed thumbnail and optionally insert into DB. + + @param: event_id in Events DB + @param: thumbnail bytes in jpg format + @param: upsert If embedding should be upserted into vec DB + """ # Convert thumbnail bytes to PIL Image embedding = self.vision_embedding([thumbnail])[0] - self.db.execute_sql( - """ - INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) - VALUES(?, ?) - """, - (event_id, serialize(embedding)), - ) + if upsert: + self.db.execute_sql( + """ + INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) + VALUES(?, ?) + """, + (event_id, serialize(embedding)), + ) return embedding - def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]: + def batch_embed_thumbnail( + self, event_thumbs: dict[str, bytes], upsert: bool = True + ) -> list[ndarray]: + """Embed thumbnails and optionally insert into DB. + + @param: event_thumbs Map of Event IDs in DB to thumbnail bytes in jpg format + @param: upsert If embedding should be upserted into vec DB + """ ids = list(event_thumbs.keys()) embeddings = self.vision_embedding(list(event_thumbs.values())) - items = [] + if upsert: + items = [] - for i in range(len(ids)): - items.append(ids[i]) - items.append(serialize(embeddings[i])) + for i in range(len(ids)): + items.append(ids[i]) + items.append(serialize(embeddings[i])) + + self.db.execute_sql( + """ + INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) + VALUES {} + """.format(", ".join(["(?, ?)"] * len(ids))), + items, + ) - self.db.execute_sql( - """ - INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) - VALUES {} - """.format(", ".join(["(?, ?)"] * len(ids))), - items, - ) return embeddings - def upsert_description(self, event_id: str, description: str) -> ndarray: + def embed_description( + self, event_id: str, description: str, upsert: bool = True + ) -> ndarray: embedding = self.text_embedding([description])[0] - self.db.execute_sql( - """ - INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) - VALUES(?, ?) - """, - (event_id, serialize(embedding)), - ) + + if upsert: + self.db.execute_sql( + """ + INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) + VALUES(?, ?) + """, + (event_id, serialize(embedding)), + ) return embedding - def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray: + def batch_embed_description( + self, event_descriptions: dict[str, str], upsert: bool = True + ) -> ndarray: # upsert embeddings one by one to avoid token limit embeddings = [] for desc in event_descriptions.values(): embeddings.append(self.text_embedding([desc])[0]) - ids = list(event_descriptions.keys()) + if upsert: + ids = list(event_descriptions.keys()) + items = [] - items = [] + for i in range(len(ids)): + items.append(ids[i]) + items.append(serialize(embeddings[i])) - for i in range(len(ids)): - items.append(ids[i]) - items.append(serialize(embeddings[i])) - - self.db.execute_sql( - """ - INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) - VALUES {} - """.format(", ".join(["(?, ?)"] * len(ids))), - items, - ) + self.db.execute_sql( + """ + INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) + VALUES {} + """.format(", ".join(["(?, ?)"] * len(ids))), + items, + ) return embeddings @@ -256,10 +280,10 @@ class Embeddings: totals["processed_objects"] += 1 # run batch embedding - self.batch_upsert_thumbnail(batch_thumbs) + self.batch_embed_thumbnail(batch_thumbs) if batch_descs: - self.batch_upsert_description(batch_descs) + self.batch_embed_description(batch_descs) # report progress every batch so we don't spam the logs progress = (totals["processed_objects"] / total_events) * 100 diff --git a/frigate/embeddings/maintainer.py b/frigate/embeddings/maintainer.py index 7ce63e7f8..1578a0fe3 100644 --- a/frigate/embeddings/maintainer.py +++ b/frigate/embeddings/maintainer.py @@ -86,7 +86,7 @@ class EmbeddingMaintainer(threading.Thread): try: if topic == EmbeddingsRequestEnum.embed_description.value: return serialize( - self.embeddings.upsert_description( + self.embeddings.embed_description( data["id"], data["description"] ), pack=False, @@ -94,7 +94,7 @@ class EmbeddingMaintainer(threading.Thread): elif topic == EmbeddingsRequestEnum.embed_thumbnail.value: thumbnail = base64.b64decode(data["thumbnail"]) return serialize( - self.embeddings.upsert_thumbnail(data["id"], thumbnail), + self.embeddings.embed_thumbnail(data["id"], thumbnail), pack=False, ) elif topic == EmbeddingsRequestEnum.generate_search.value: @@ -270,7 +270,7 @@ class EmbeddingMaintainer(threading.Thread): def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None: """Embed the thumbnail for an event.""" - self.embeddings.upsert_thumbnail(event_id, thumbnail) + self.embeddings.embed_thumbnail(event_id, thumbnail) def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None: """Embed the description for an event.""" @@ -290,8 +290,8 @@ class EmbeddingMaintainer(threading.Thread): {"id": event.id, "description": description}, ) - # Encode the description - self.embeddings.upsert_description(event.id, description) + # Embed the description + self.embeddings.embed_description(event.id, description) logger.debug( "Generated description for %s (%d images): %s",