Improve naming and handling of embeddings

This commit is contained in:
Nicolas Mowen 2024-10-21 15:36:32 -06:00
parent 2381f7a754
commit 2713928f7b
2 changed files with 73 additions and 49 deletions

View File

@ -123,73 +123,97 @@ class Embeddings:
device="GPU" if config.model_size == "large" else "CPU", device="GPU" if config.model_size == "large" else "CPU",
) )
def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray: def embed_thumbnail(
self, event_id: str, thumbnail: bytes, upsert: bool = True
) -> ndarray:
"""Embed thumbnail and optionally insert into DB.
@param: event_id in Events DB
@param: thumbnail bytes in jpg format
@param: upsert If embedding should be upserted into vec DB
"""
# Convert thumbnail bytes to PIL Image # Convert thumbnail bytes to PIL Image
embedding = self.vision_embedding([thumbnail])[0] embedding = self.vision_embedding([thumbnail])[0]
self.db.execute_sql( if upsert:
""" self.db.execute_sql(
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) """
VALUES(?, ?) INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
""", VALUES(?, ?)
(event_id, serialize(embedding)), """,
) (event_id, serialize(embedding)),
)
return embedding return embedding
def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]: def batch_embed_thumbnail(
self, event_thumbs: dict[str, bytes], upsert: bool = True
) -> list[ndarray]:
"""Embed thumbnails and optionally insert into DB.
@param: event_thumbs Map of Event IDs in DB to thumbnail bytes in jpg format
@param: upsert If embedding should be upserted into vec DB
"""
ids = list(event_thumbs.keys()) ids = list(event_thumbs.keys())
embeddings = self.vision_embedding(list(event_thumbs.values())) embeddings = self.vision_embedding(list(event_thumbs.values()))
items = [] if upsert:
items = []
for i in range(len(ids)): for i in range(len(ids)):
items.append(ids[i]) items.append(ids[i])
items.append(serialize(embeddings[i])) items.append(serialize(embeddings[i]))
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
VALUES {}
""".format(", ".join(["(?, ?)"] * len(ids))),
items,
)
self.db.execute_sql(
"""
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
VALUES {}
""".format(", ".join(["(?, ?)"] * len(ids))),
items,
)
return embeddings return embeddings
def upsert_description(self, event_id: str, description: str) -> ndarray: def embed_description(
self, event_id: str, description: str, upsert: bool = True
) -> ndarray:
embedding = self.text_embedding([description])[0] embedding = self.text_embedding([description])[0]
self.db.execute_sql(
""" if upsert:
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) self.db.execute_sql(
VALUES(?, ?) """
""", INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
(event_id, serialize(embedding)), VALUES(?, ?)
) """,
(event_id, serialize(embedding)),
)
return embedding return embedding
def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray: def batch_embed_description(
self, event_descriptions: dict[str, str], upsert: bool = True
) -> ndarray:
# upsert embeddings one by one to avoid token limit # upsert embeddings one by one to avoid token limit
embeddings = [] embeddings = []
for desc in event_descriptions.values(): for desc in event_descriptions.values():
embeddings.append(self.text_embedding([desc])[0]) embeddings.append(self.text_embedding([desc])[0])
ids = list(event_descriptions.keys()) if upsert:
ids = list(event_descriptions.keys())
items = []
items = [] for i in range(len(ids)):
items.append(ids[i])
items.append(serialize(embeddings[i]))
for i in range(len(ids)): self.db.execute_sql(
items.append(ids[i]) """
items.append(serialize(embeddings[i])) INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
VALUES {}
self.db.execute_sql( """.format(", ".join(["(?, ?)"] * len(ids))),
""" items,
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) )
VALUES {}
""".format(", ".join(["(?, ?)"] * len(ids))),
items,
)
return embeddings return embeddings
@ -256,10 +280,10 @@ class Embeddings:
totals["processed_objects"] += 1 totals["processed_objects"] += 1
# run batch embedding # run batch embedding
self.batch_upsert_thumbnail(batch_thumbs) self.batch_embed_thumbnail(batch_thumbs)
if batch_descs: if batch_descs:
self.batch_upsert_description(batch_descs) self.batch_embed_description(batch_descs)
# report progress every batch so we don't spam the logs # report progress every batch so we don't spam the logs
progress = (totals["processed_objects"] / total_events) * 100 progress = (totals["processed_objects"] / total_events) * 100

View File

@ -86,7 +86,7 @@ class EmbeddingMaintainer(threading.Thread):
try: try:
if topic == EmbeddingsRequestEnum.embed_description.value: if topic == EmbeddingsRequestEnum.embed_description.value:
return serialize( return serialize(
self.embeddings.upsert_description( self.embeddings.embed_description(
data["id"], data["description"] data["id"], data["description"]
), ),
pack=False, pack=False,
@ -94,7 +94,7 @@ class EmbeddingMaintainer(threading.Thread):
elif topic == EmbeddingsRequestEnum.embed_thumbnail.value: elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
thumbnail = base64.b64decode(data["thumbnail"]) thumbnail = base64.b64decode(data["thumbnail"])
return serialize( return serialize(
self.embeddings.upsert_thumbnail(data["id"], thumbnail), self.embeddings.embed_thumbnail(data["id"], thumbnail),
pack=False, pack=False,
) )
elif topic == EmbeddingsRequestEnum.generate_search.value: elif topic == EmbeddingsRequestEnum.generate_search.value:
@ -270,7 +270,7 @@ class EmbeddingMaintainer(threading.Thread):
def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None: def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None:
"""Embed the thumbnail for an event.""" """Embed the thumbnail for an event."""
self.embeddings.upsert_thumbnail(event_id, thumbnail) self.embeddings.embed_thumbnail(event_id, thumbnail)
def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None: def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None:
"""Embed the description for an event.""" """Embed the description for an event."""
@ -290,8 +290,8 @@ class EmbeddingMaintainer(threading.Thread):
{"id": event.id, "description": description}, {"id": event.id, "description": description},
) )
# Encode the description # Embed the description
self.embeddings.upsert_description(event.id, description) self.embeddings.embed_description(event.id, description)
logger.debug( logger.debug(
"Generated description for %s (%d images): %s", "Generated description for %s (%d images): %s",