mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-17 00:25:23 +03:00
Improve naming and handling of embeddings
This commit is contained in:
parent
2381f7a754
commit
2713928f7b
@ -123,73 +123,97 @@ class Embeddings:
|
|||||||
device="GPU" if config.model_size == "large" else "CPU",
|
device="GPU" if config.model_size == "large" else "CPU",
|
||||||
)
|
)
|
||||||
|
|
||||||
def upsert_thumbnail(self, event_id: str, thumbnail: bytes) -> ndarray:
|
def embed_thumbnail(
|
||||||
|
self, event_id: str, thumbnail: bytes, upsert: bool = True
|
||||||
|
) -> ndarray:
|
||||||
|
"""Embed thumbnail and optionally insert into DB.
|
||||||
|
|
||||||
|
@param: event_id in Events DB
|
||||||
|
@param: thumbnail bytes in jpg format
|
||||||
|
@param: upsert If embedding should be upserted into vec DB
|
||||||
|
"""
|
||||||
# Convert thumbnail bytes to PIL Image
|
# Convert thumbnail bytes to PIL Image
|
||||||
embedding = self.vision_embedding([thumbnail])[0]
|
embedding = self.vision_embedding([thumbnail])[0]
|
||||||
|
|
||||||
self.db.execute_sql(
|
if upsert:
|
||||||
"""
|
self.db.execute_sql(
|
||||||
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
|
"""
|
||||||
VALUES(?, ?)
|
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
|
||||||
""",
|
VALUES(?, ?)
|
||||||
(event_id, serialize(embedding)),
|
""",
|
||||||
)
|
(event_id, serialize(embedding)),
|
||||||
|
)
|
||||||
|
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
def batch_upsert_thumbnail(self, event_thumbs: dict[str, bytes]) -> list[ndarray]:
|
def batch_embed_thumbnail(
|
||||||
|
self, event_thumbs: dict[str, bytes], upsert: bool = True
|
||||||
|
) -> list[ndarray]:
|
||||||
|
"""Embed thumbnails and optionally insert into DB.
|
||||||
|
|
||||||
|
@param: event_thumbs Map of Event IDs in DB to thumbnail bytes in jpg format
|
||||||
|
@param: upsert If embedding should be upserted into vec DB
|
||||||
|
"""
|
||||||
ids = list(event_thumbs.keys())
|
ids = list(event_thumbs.keys())
|
||||||
embeddings = self.vision_embedding(list(event_thumbs.values()))
|
embeddings = self.vision_embedding(list(event_thumbs.values()))
|
||||||
|
|
||||||
items = []
|
if upsert:
|
||||||
|
items = []
|
||||||
|
|
||||||
for i in range(len(ids)):
|
for i in range(len(ids)):
|
||||||
items.append(ids[i])
|
items.append(ids[i])
|
||||||
items.append(serialize(embeddings[i]))
|
items.append(serialize(embeddings[i]))
|
||||||
|
|
||||||
|
self.db.execute_sql(
|
||||||
|
"""
|
||||||
|
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
|
||||||
|
VALUES {}
|
||||||
|
""".format(", ".join(["(?, ?)"] * len(ids))),
|
||||||
|
items,
|
||||||
|
)
|
||||||
|
|
||||||
self.db.execute_sql(
|
|
||||||
"""
|
|
||||||
INSERT OR REPLACE INTO vec_thumbnails(id, thumbnail_embedding)
|
|
||||||
VALUES {}
|
|
||||||
""".format(", ".join(["(?, ?)"] * len(ids))),
|
|
||||||
items,
|
|
||||||
)
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
def upsert_description(self, event_id: str, description: str) -> ndarray:
|
def embed_description(
|
||||||
|
self, event_id: str, description: str, upsert: bool = True
|
||||||
|
) -> ndarray:
|
||||||
embedding = self.text_embedding([description])[0]
|
embedding = self.text_embedding([description])[0]
|
||||||
self.db.execute_sql(
|
|
||||||
"""
|
if upsert:
|
||||||
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
|
self.db.execute_sql(
|
||||||
VALUES(?, ?)
|
"""
|
||||||
""",
|
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
|
||||||
(event_id, serialize(embedding)),
|
VALUES(?, ?)
|
||||||
)
|
""",
|
||||||
|
(event_id, serialize(embedding)),
|
||||||
|
)
|
||||||
|
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
def batch_upsert_description(self, event_descriptions: dict[str, str]) -> ndarray:
|
def batch_embed_description(
|
||||||
|
self, event_descriptions: dict[str, str], upsert: bool = True
|
||||||
|
) -> ndarray:
|
||||||
# upsert embeddings one by one to avoid token limit
|
# upsert embeddings one by one to avoid token limit
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
|
||||||
for desc in event_descriptions.values():
|
for desc in event_descriptions.values():
|
||||||
embeddings.append(self.text_embedding([desc])[0])
|
embeddings.append(self.text_embedding([desc])[0])
|
||||||
|
|
||||||
ids = list(event_descriptions.keys())
|
if upsert:
|
||||||
|
ids = list(event_descriptions.keys())
|
||||||
|
items = []
|
||||||
|
|
||||||
items = []
|
for i in range(len(ids)):
|
||||||
|
items.append(ids[i])
|
||||||
|
items.append(serialize(embeddings[i]))
|
||||||
|
|
||||||
for i in range(len(ids)):
|
self.db.execute_sql(
|
||||||
items.append(ids[i])
|
"""
|
||||||
items.append(serialize(embeddings[i]))
|
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
|
||||||
|
VALUES {}
|
||||||
self.db.execute_sql(
|
""".format(", ".join(["(?, ?)"] * len(ids))),
|
||||||
"""
|
items,
|
||||||
INSERT OR REPLACE INTO vec_descriptions(id, description_embedding)
|
)
|
||||||
VALUES {}
|
|
||||||
""".format(", ".join(["(?, ?)"] * len(ids))),
|
|
||||||
items,
|
|
||||||
)
|
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
@ -256,10 +280,10 @@ class Embeddings:
|
|||||||
totals["processed_objects"] += 1
|
totals["processed_objects"] += 1
|
||||||
|
|
||||||
# run batch embedding
|
# run batch embedding
|
||||||
self.batch_upsert_thumbnail(batch_thumbs)
|
self.batch_embed_thumbnail(batch_thumbs)
|
||||||
|
|
||||||
if batch_descs:
|
if batch_descs:
|
||||||
self.batch_upsert_description(batch_descs)
|
self.batch_embed_description(batch_descs)
|
||||||
|
|
||||||
# report progress every batch so we don't spam the logs
|
# report progress every batch so we don't spam the logs
|
||||||
progress = (totals["processed_objects"] / total_events) * 100
|
progress = (totals["processed_objects"] / total_events) * 100
|
||||||
|
|||||||
@ -86,7 +86,7 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
try:
|
try:
|
||||||
if topic == EmbeddingsRequestEnum.embed_description.value:
|
if topic == EmbeddingsRequestEnum.embed_description.value:
|
||||||
return serialize(
|
return serialize(
|
||||||
self.embeddings.upsert_description(
|
self.embeddings.embed_description(
|
||||||
data["id"], data["description"]
|
data["id"], data["description"]
|
||||||
),
|
),
|
||||||
pack=False,
|
pack=False,
|
||||||
@ -94,7 +94,7 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
|
elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
|
||||||
thumbnail = base64.b64decode(data["thumbnail"])
|
thumbnail = base64.b64decode(data["thumbnail"])
|
||||||
return serialize(
|
return serialize(
|
||||||
self.embeddings.upsert_thumbnail(data["id"], thumbnail),
|
self.embeddings.embed_thumbnail(data["id"], thumbnail),
|
||||||
pack=False,
|
pack=False,
|
||||||
)
|
)
|
||||||
elif topic == EmbeddingsRequestEnum.generate_search.value:
|
elif topic == EmbeddingsRequestEnum.generate_search.value:
|
||||||
@ -270,7 +270,7 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
|
|
||||||
def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None:
|
def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None:
|
||||||
"""Embed the thumbnail for an event."""
|
"""Embed the thumbnail for an event."""
|
||||||
self.embeddings.upsert_thumbnail(event_id, thumbnail)
|
self.embeddings.embed_thumbnail(event_id, thumbnail)
|
||||||
|
|
||||||
def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None:
|
def _embed_description(self, event: Event, thumbnails: list[bytes]) -> None:
|
||||||
"""Embed the description for an event."""
|
"""Embed the description for an event."""
|
||||||
@ -290,8 +290,8 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
{"id": event.id, "description": description},
|
{"id": event.id, "description": description},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Encode the description
|
# Embed the description
|
||||||
self.embeddings.upsert_description(event.id, description)
|
self.embeddings.embed_description(event.id, description)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Generated description for %s (%d images): %s",
|
"Generated description for %s (%d images): %s",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user