Improve naming and handling of embeddings

This commit is contained in:
Nicolas Mowen 2024-10-21 15:36:32 -06:00
parent 2381f7a754
commit 2713928f7b
2 changed files with 73 additions and 49 deletions

View File

@ -123,73 +123,97 @@ class Embeddings:
device="GPU" if config.model_size == "large" else "CPU",
)
def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray:
def embed_thumbnail(
self, event_id: str, thumbnail: bytes, upsert: bool = True
) -> ndarray:
"""Embed thumbnail and optionally insert into DB.
@param: event_id in Events DB
@param: thumbnail bytes in jpg format
@param: upsert If embedding should be upserted into vec DB
"""
# Convert thumbnail bytes to PIL Image
embedding = self.vision_embedding([thumbnail])[0]
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
VALUES(?, ?)
""",
(event_id, serialize(embedding)),
)
if upsert:
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
VALUES(?, ?)
""",
(event_id, serialize(embedding)),
)
return embedding
def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]:
def batch_embed_thumbnail(
self, event_thumbs: dict[str, bytes], upsert: bool = True
) -> list[ndarray]:
"""Embed thumbnails and optionally insert into DB.
@param: event_thumbs Map of Event IDs in DB to thumbnail bytes in jpg format
@param: upsert If embedding should be upserted into vec DB
"""
ids = list(event_thumbs.keys())
embeddings = self.vision_embedding(list(event_thumbs.values()))
items = []
if upsert:
items = []
for i in range(len(ids)):
items.append(ids[i])
items.append(serialize(embeddings[i]))
for i in range(len(ids)):
items.append(ids[i])
items.append(serialize(embeddings[i]))
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
VALUES {}
""".format(", ".join(["(?, ?)"] * len(ids))),
items,
)
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
VALUES {}
""".format(", ".join(["(?, ?)"] * len(ids))),
items,
)
return embeddings
def upsert_description(self, event_id: str, description: str) -> ndarray:
def embed_description(
self, event_id: str, description: str, upsert: bool = True
) -> ndarray:
embedding = self.text_embedding([description])[0]
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
VALUES(?, ?)
""",
(event_id, serialize(embedding)),
)
if upsert:
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
VALUES(?, ?)
""",
(event_id, serialize(embedding)),
)
return embedding
def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray:
def batch_embed_description(
self, event_descriptions: dict[str, str], upsert: bool = True
) -> ndarray:
# upsert embeddings one by one to avoid token limit
embeddings = []
for desc in event_descriptions.values():
embeddings.append(self.text_embedding([desc])[0])
ids = list(event_descriptions.keys())
if upsert:
ids = list(event_descriptions.keys())
items = []
items = []
for i in range(len(ids)):
items.append(ids[i])
items.append(serialize(embeddings[i]))
for i in range(len(ids)):
items.append(ids[i])
items.append(serialize(embeddings[i]))
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
VALUES {}
""".format(", ".join(["(?, ?)"] * len(ids))),
items,
)
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
VALUES {}
""".format(", ".join(["(?, ?)"] * len(ids))),
items,
)
return embeddings
@ -256,10 +280,10 @@ class Embeddings:
totals["processed_objects"] += 1
# run batch embedding
self.batch_upsert_thumbnail(batch_thumbs)
self.batch_embed_thumbnail(batch_thumbs)
if batch_descs:
self.batch_upsert_description(batch_descs)
self.batch_embed_description(batch_descs)
# report progress every batch so we don't spam the logs
progress = (totals["processed_objects"] / total_events) * 100

View File

@ -86,7 +86,7 @@ class EmbeddingMaintainer(threading.Thread):
try:
if topic == EmbeddingsRequestEnum.embed_description.value:
return serialize(
self.embeddings.upsert_description(
self.embeddings.embed_description(
data["id"], data["description"]
),
pack=False,
@ -94,7 +94,7 @@ class EmbeddingMaintainer(threading.Thread):
elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
thumbnail = base64.b64decode(data["thumbnail"])
return serialize(
self.embeddings.upsert_thumbnail(data["id"], thumbnail),
self.embeddings.embed_thumbnail(data["id"], thumbnail),
pack=False,
)
elif topic == EmbeddingsRequestEnum.generate_search.value:
@ -270,7 +270,7 @@ class EmbeddingMaintainer(threading.Thread):
def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None:
"""Embed the thumbnail for an event."""
self.embeddings.upsert_thumbnail(event_id, thumbnail)
self.embeddings.embed_thumbnail(event_id, thumbnail)
def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None:
"""Embed the description for an event."""
@ -290,8 +290,8 @@ class EmbeddingMaintainer(threading.Thread):
{"id": event.id, "description": description},
)
# Encode the description
self.embeddings.upsert_description(event.id, description)
# Embed the description
self.embeddings.embed_description(event.id, description)
logger.debug(
"Generated description for %s (%d images): %s",