mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 15:45:27 +03:00
Improve naming and handling of embeddings
This commit is contained in:
parent
2381f7a754
commit
2713928f7b
@ -123,73 +123,97 @@ class Embeddings:
|
||||
device="GPU" if config.model_size == "large" else "CPU",
|
||||
)
|
||||
|
||||
def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray:
|
||||
def embed_thumbnail(
|
||||
self, event_id: str, thumbnail: bytes, upsert: bool = True
|
||||
) -> ndarray:
|
||||
"""Embed thumbnail and optionally insert into DB.
|
||||
|
||||
@param: event_id in Events DB
|
||||
@param: thumbnail bytes in jpg format
|
||||
@param: upsert If embedding should be upserted into vec DB
|
||||
"""
|
||||
# Convert thumbnail bytes to PIL Image
|
||||
embedding = self.vision_embedding([thumbnail])[0]
|
||||
|
||||
self.db.execute_sql(
|
||||
"""
|
||||
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
|
||||
VALUES(?, ?)
|
||||
""",
|
||||
(event_id, serialize(embedding)),
|
||||
)
|
||||
if upsert:
|
||||
self.db.execute_sql(
|
||||
"""
|
||||
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
|
||||
VALUES(?, ?)
|
||||
""",
|
||||
(event_id, serialize(embedding)),
|
||||
)
|
||||
|
||||
return embedding
|
||||
|
||||
def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]:
|
||||
def batch_embed_thumbnail(
|
||||
self, event_thumbs: dict[str, bytes], upsert: bool = True
|
||||
) -> list[ndarray]:
|
||||
"""Embed thumbnails and optionally insert into DB.
|
||||
|
||||
@param: event_thumbs Map of Event IDs in DB to thumbnail bytes in jpg format
|
||||
@param: upsert If embedding should be upserted into vec DB
|
||||
"""
|
||||
ids = list(event_thumbs.keys())
|
||||
embeddings = self.vision_embedding(list(event_thumbs.values()))
|
||||
|
||||
items = []
|
||||
if upsert:
|
||||
items = []
|
||||
|
||||
for i in range(len(ids)):
|
||||
items.append(ids[i])
|
||||
items.append(serialize(embeddings[i]))
|
||||
for i in range(len(ids)):
|
||||
items.append(ids[i])
|
||||
items.append(serialize(embeddings[i]))
|
||||
|
||||
self.db.execute_sql(
|
||||
"""
|
||||
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
|
||||
VALUES {}
|
||||
""".format(", ".join(["(?, ?)"] * len(ids))),
|
||||
items,
|
||||
)
|
||||
|
||||
self.db.execute_sql(
|
||||
"""
|
||||
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
|
||||
VALUES {}
|
||||
""".format(", ".join(["(?, ?)"] * len(ids))),
|
||||
items,
|
||||
)
|
||||
return embeddings
|
||||
|
||||
def upsert_description(self, event_id: str, description: str) -> ndarray:
|
||||
def embed_description(
|
||||
self, event_id: str, description: str, upsert: bool = True
|
||||
) -> ndarray:
|
||||
embedding = self.text_embedding([description])[0]
|
||||
self.db.execute_sql(
|
||||
"""
|
||||
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
|
||||
VALUES(?, ?)
|
||||
""",
|
||||
(event_id, serialize(embedding)),
|
||||
)
|
||||
|
||||
if upsert:
|
||||
self.db.execute_sql(
|
||||
"""
|
||||
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
|
||||
VALUES(?, ?)
|
||||
""",
|
||||
(event_id, serialize(embedding)),
|
||||
)
|
||||
|
||||
return embedding
|
||||
|
||||
def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray:
|
||||
def batch_embed_description(
|
||||
self, event_descriptions: dict[str, str], upsert: bool = True
|
||||
) -> ndarray:
|
||||
# upsert embeddings one by one to avoid token limit
|
||||
embeddings = []
|
||||
|
||||
for desc in event_descriptions.values():
|
||||
embeddings.append(self.text_embedding([desc])[0])
|
||||
|
||||
ids = list(event_descriptions.keys())
|
||||
if upsert:
|
||||
ids = list(event_descriptions.keys())
|
||||
items = []
|
||||
|
||||
items = []
|
||||
for i in range(len(ids)):
|
||||
items.append(ids[i])
|
||||
items.append(serialize(embeddings[i]))
|
||||
|
||||
for i in range(len(ids)):
|
||||
items.append(ids[i])
|
||||
items.append(serialize(embeddings[i]))
|
||||
|
||||
self.db.execute_sql(
|
||||
"""
|
||||
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
|
||||
VALUES {}
|
||||
""".format(", ".join(["(?, ?)"] * len(ids))),
|
||||
items,
|
||||
)
|
||||
self.db.execute_sql(
|
||||
"""
|
||||
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
|
||||
VALUES {}
|
||||
""".format(", ".join(["(?, ?)"] * len(ids))),
|
||||
items,
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
@ -256,10 +280,10 @@ class Embeddings:
|
||||
totals["processed_objects"] += 1
|
||||
|
||||
# run batch embedding
|
||||
self.batch_upsert_thumbnail(batch_thumbs)
|
||||
self.batch_embed_thumbnail(batch_thumbs)
|
||||
|
||||
if batch_descs:
|
||||
self.batch_upsert_description(batch_descs)
|
||||
self.batch_embed_description(batch_descs)
|
||||
|
||||
# report progress every batch so we don't spam the logs
|
||||
progress = (totals["processed_objects"] / total_events) * 100
|
||||
|
||||
@ -86,7 +86,7 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
try:
|
||||
if topic == EmbeddingsRequestEnum.embed_description.value:
|
||||
return serialize(
|
||||
self.embeddings.upsert_description(
|
||||
self.embeddings.embed_description(
|
||||
data["id"], data["description"]
|
||||
),
|
||||
pack=False,
|
||||
@ -94,7 +94,7 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
|
||||
thumbnail = base64.b64decode(data["thumbnail"])
|
||||
return serialize(
|
||||
self.embeddings.upsert_thumbnail(data["id"], thumbnail),
|
||||
self.embeddings.embed_thumbnail(data["id"], thumbnail),
|
||||
pack=False,
|
||||
)
|
||||
elif topic == EmbeddingsRequestEnum.generate_search.value:
|
||||
@ -270,7 +270,7 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
|
||||
def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None:
|
||||
"""Embed the thumbnail for an event."""
|
||||
self.embeddings.upsert_thumbnail(event_id, thumbnail)
|
||||
self.embeddings.embed_thumbnail(event_id, thumbnail)
|
||||
|
||||
def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None:
|
||||
"""Embed the description for an event."""
|
||||
@ -290,8 +290,8 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
{"id": event.id, "description": description},
|
||||
)
|
||||
|
||||
# Encode the description
|
||||
self.embeddings.upsert_description(event.id, description)
|
||||
# Embed the description
|
||||
self.embeddings.embed_description(event.id, description)
|
||||
|
||||
logger.debug(
|
||||
"Generated description for %s (%d images): %s",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user