more readable loops

This commit is contained in:
Josh Hawkins 2024-10-13 16:14:10 -05:00
parent 7a458f6c2b
commit 5374c18e71

View File

@ -145,15 +145,19 @@ class Embeddings:
] ]
ids = list(event_thumbs.keys()) ids = list(event_thumbs.keys())
embeddings = self.vision_embedding(images) embeddings = self.vision_embedding(images)
items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
flat_items = [item for sublist in items for item in sublist] items = []
for i in range(len(ids)):
items.append(ids[i])
items.append(serialize(embeddings[i]))
self.db.execute_sql( self.db.execute_sql(
""" """
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding) INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
VALUES {} VALUES {}
""".format(", ".join(["(?, ?)"] * len(items))), """.format(", ".join(["(?, ?)"] * len(items))),
flat_items, items,
) )
return embeddings return embeddings
@ -172,15 +176,19 @@ class Embeddings:
def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray: def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray:
embeddings = self.text_embedding(list(event_descriptions.values())) embeddings = self.text_embedding(list(event_descriptions.values()))
ids = list(event_descriptions.keys()) ids = list(event_descriptions.keys())
items = [(ids[i], serialize(embeddings[i])) for i in range(len(ids))]
flat_items = [item for sublist in items for item in sublist] items = []
for i in range(len(ids)):
items.append(ids[i])
items.append(serialize(embeddings[i]))
self.db.execute_sql( self.db.execute_sql(
""" """
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding) INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
VALUES {} VALUES {}
""".format(", ".join(["(?, ?)"] * len(items))), """.format(", ".join(["(?, ?)"] * len(items))),
flat_items, items,
) )
return embeddings return embeddings