mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-15 07:35:27 +03:00
Use ZMQ to proxy embeddings requests
This commit is contained in:
parent
2142a39b3c
commit
f8f1852b0c
@ -944,9 +944,9 @@ def set_description(
|
|||||||
# If semantic search is enabled, update the index
|
# If semantic search is enabled, update the index
|
||||||
if request.app.frigate_config.semantic_search.enabled:
|
if request.app.frigate_config.semantic_search.enabled:
|
||||||
context: EmbeddingsContext = request.app.embeddings
|
context: EmbeddingsContext = request.app.embeddings
|
||||||
context.embeddings.upsert_description(
|
context.update_description(
|
||||||
event_id=event_id,
|
event_id,
|
||||||
description=new_description,
|
new_description,
|
||||||
)
|
)
|
||||||
|
|
||||||
response_message = (
|
response_message = (
|
||||||
|
|||||||
@ -276,7 +276,7 @@ class FrigateApp:
|
|||||||
def init_embeddings_client(self) -> None:
|
def init_embeddings_client(self) -> None:
|
||||||
if self.config.semantic_search.enabled:
|
if self.config.semantic_search.enabled:
|
||||||
# Create a client for other processes to use
|
# Create a client for other processes to use
|
||||||
self.embeddings = EmbeddingsContext(self.config, self.db)
|
self.embeddings = EmbeddingsContext(self.db)
|
||||||
|
|
||||||
def init_external_event_processor(self) -> None:
|
def init_external_event_processor(self) -> None:
|
||||||
self.external_event_processor = ExternalEventProcessor(self.config)
|
self.external_event_processor = ExternalEventProcessor(self.config)
|
||||||
@ -699,7 +699,7 @@ class FrigateApp:
|
|||||||
|
|
||||||
# Save embeddings stats to disk
|
# Save embeddings stats to disk
|
||||||
if self.embeddings:
|
if self.embeddings:
|
||||||
self.embeddings.save_stats()
|
self.embeddings.stop()
|
||||||
|
|
||||||
# Stop Communicators
|
# Stop Communicators
|
||||||
self.inter_process_communicator.stop()
|
self.inter_process_communicator.stop()
|
||||||
|
|||||||
62
frigate/comms/embeddings_updater.py
Normal file
62
frigate/comms/embeddings_updater.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
"""Facilitates communication between processes."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsRequestEnum(Enum):
|
||||||
|
embed_description = "embed_description"
|
||||||
|
embed_thumbnail = "embed_thumbnail"
|
||||||
|
generate_search = "generate_search"
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsResponder:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.socket = self.context.socket(zmq.REP)
|
||||||
|
self.socket.bind(SOCKET_REP_REQ)
|
||||||
|
|
||||||
|
def check_for_request(self, process: Callable) -> None:
|
||||||
|
while True: # load all messages that are queued
|
||||||
|
has_message, _, _ = zmq.select([self.socket], [], [], 1)
|
||||||
|
|
||||||
|
if not has_message:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
(topic, value) = self.socket.recv_json(flags=zmq.NOBLOCK)
|
||||||
|
|
||||||
|
response = process(topic, value)
|
||||||
|
|
||||||
|
if response is not None:
|
||||||
|
self.socket.send_json(response)
|
||||||
|
else:
|
||||||
|
self.socket.send_json([])
|
||||||
|
except zmq.ZMQError:
|
||||||
|
break
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
self.socket.close()
|
||||||
|
self.context.destroy()
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsRequestor:
|
||||||
|
"""Simplifies sending data to EmbeddingsResponder and getting a reply."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.context = zmq.Context()
|
||||||
|
self.socket = self.context.socket(zmq.REQ)
|
||||||
|
self.socket.connect(SOCKET_REP_REQ)
|
||||||
|
|
||||||
|
def send_data(self, topic: str, data: any) -> any:
|
||||||
|
"""Sends data and then waits for reply."""
|
||||||
|
self.socket.send_json((topic, data))
|
||||||
|
return self.socket.recv_json()
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
self.socket.close()
|
||||||
|
self.context.destroy()
|
||||||
@ -12,6 +12,7 @@ from typing import Optional, Union
|
|||||||
|
|
||||||
from setproctitle import setproctitle
|
from setproctitle import setproctitle
|
||||||
|
|
||||||
|
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsRequestor
|
||||||
from frigate.config import FrigateConfig
|
from frigate.config import FrigateConfig
|
||||||
from frigate.const import CONFIG_DIR
|
from frigate.const import CONFIG_DIR
|
||||||
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
|
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
|
||||||
@ -72,10 +73,11 @@ def manage_embeddings(config: FrigateConfig) -> None:
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingsContext:
|
class EmbeddingsContext:
|
||||||
def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase):
|
def __init__(self, db: SqliteVecQueueDatabase):
|
||||||
self.db = db
|
self.db = db
|
||||||
self.thumb_stats = ZScoreNormalization()
|
self.thumb_stats = ZScoreNormalization()
|
||||||
self.desc_stats = ZScoreNormalization()
|
self.desc_stats = ZScoreNormalization()
|
||||||
|
self.requestor = EmbeddingsRequestor()
|
||||||
|
|
||||||
# load stats from disk
|
# load stats from disk
|
||||||
try:
|
try:
|
||||||
@ -86,7 +88,7 @@ class EmbeddingsContext:
|
|||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def save_stats(self):
|
def stop(self):
|
||||||
"""Write the stats to disk as JSON on exit."""
|
"""Write the stats to disk as JSON on exit."""
|
||||||
contents = {
|
contents = {
|
||||||
"thumb_stats": self.thumb_stats.to_dict(),
|
"thumb_stats": self.thumb_stats.to_dict(),
|
||||||
@ -94,6 +96,7 @@ class EmbeddingsContext:
|
|||||||
}
|
}
|
||||||
with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f:
|
with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f:
|
||||||
json.dump(contents, f)
|
json.dump(contents, f)
|
||||||
|
self.requestor.stop()
|
||||||
|
|
||||||
def search_thumbnail(
|
def search_thumbnail(
|
||||||
self, query: Union[Event, str], event_ids: list[str] = None
|
self, query: Union[Event, str], event_ids: list[str] = None
|
||||||
@ -114,10 +117,14 @@ class EmbeddingsContext:
|
|||||||
) # Deserialize the thumbnail embedding
|
) # Deserialize the thumbnail embedding
|
||||||
else:
|
else:
|
||||||
# If no embedding found, generate it and return it
|
# If no embedding found, generate it and return it
|
||||||
thumbnail = base64.b64decode(query.thumbnail)
|
query_embedding = self.requestor.send_data(
|
||||||
query_embedding = self.upsert_thumbnail(query.id, thumbnail)
|
EmbeddingsRequestEnum.embed_thumbnail,
|
||||||
|
{"id": query.id, "thumbnail": query.thumbnail},
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
query_embedding = self.text_embedding([query])[0]
|
query_embedding = self.requestor.send_data(
|
||||||
|
EmbeddingsRequestEnum.generate_search, query
|
||||||
|
)
|
||||||
|
|
||||||
sql_query = """
|
sql_query = """
|
||||||
SELECT
|
SELECT
|
||||||
@ -151,7 +158,9 @@ class EmbeddingsContext:
|
|||||||
def search_description(
|
def search_description(
|
||||||
self, query_text: str, event_ids: list[str] = None
|
self, query_text: str, event_ids: list[str] = None
|
||||||
) -> list[tuple[str, float]]:
|
) -> list[tuple[str, float]]:
|
||||||
query_embedding = self.text_embedding([query_text])[0]
|
query_embedding = self.requestor.send_data(
|
||||||
|
EmbeddingsRequestEnum.generate_search, query_text
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare the base SQL query
|
# Prepare the base SQL query
|
||||||
sql_query = """
|
sql_query = """
|
||||||
@ -182,3 +191,9 @@ class EmbeddingsContext:
|
|||||||
results = self.db.execute_sql(sql_query, parameters).fetchall()
|
results = self.db.execute_sql(sql_query, parameters).fetchall()
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def update_description(self, event_id: str, description: str) -> None:
|
||||||
|
self.requestor.send_data(
|
||||||
|
EmbeddingsRequestEnum.embed_description,
|
||||||
|
{"id": event_id, "description": description},
|
||||||
|
)
|
||||||
|
|||||||
@ -12,6 +12,7 @@ import numpy as np
|
|||||||
from peewee import DoesNotExist
|
from peewee import DoesNotExist
|
||||||
from playhouse.sqliteq import SqliteQueueDatabase
|
from playhouse.sqliteq import SqliteQueueDatabase
|
||||||
|
|
||||||
|
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsResponder
|
||||||
from frigate.comms.event_metadata_updater import (
|
from frigate.comms.event_metadata_updater import (
|
||||||
EventMetadataSubscriber,
|
EventMetadataSubscriber,
|
||||||
EventMetadataTypeEnum,
|
EventMetadataTypeEnum,
|
||||||
@ -48,6 +49,7 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
self.event_metadata_subscriber = EventMetadataSubscriber(
|
self.event_metadata_subscriber = EventMetadataSubscriber(
|
||||||
EventMetadataTypeEnum.regenerate_description
|
EventMetadataTypeEnum.regenerate_description
|
||||||
)
|
)
|
||||||
|
self.embeddings_responder = EmbeddingsResponder()
|
||||||
self.frame_manager = SharedMemoryFrameManager()
|
self.frame_manager = SharedMemoryFrameManager()
|
||||||
# create communication for updating event descriptions
|
# create communication for updating event descriptions
|
||||||
self.requestor = InterProcessRequestor()
|
self.requestor = InterProcessRequestor()
|
||||||
@ -58,6 +60,7 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
"""Maintain a SQLite-vec database for semantic search."""
|
"""Maintain a SQLite-vec database for semantic search."""
|
||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
|
self._process_requests()
|
||||||
self._process_updates()
|
self._process_updates()
|
||||||
self._process_finalized()
|
self._process_finalized()
|
||||||
self._process_event_metadata()
|
self._process_event_metadata()
|
||||||
@ -65,9 +68,26 @@ class EmbeddingMaintainer(threading.Thread):
|
|||||||
self.event_subscriber.stop()
|
self.event_subscriber.stop()
|
||||||
self.event_end_subscriber.stop()
|
self.event_end_subscriber.stop()
|
||||||
self.event_metadata_subscriber.stop()
|
self.event_metadata_subscriber.stop()
|
||||||
|
self.embeddings_responder.stop()
|
||||||
self.requestor.stop()
|
self.requestor.stop()
|
||||||
logger.info("Exiting embeddings maintenance...")
|
logger.info("Exiting embeddings maintenance...")
|
||||||
|
|
||||||
|
def _process_requests(self) -> None:
|
||||||
|
"""Process embeddings requests"""
|
||||||
|
|
||||||
|
def handle_request(topic: str, data: str) -> any:
|
||||||
|
if topic == EmbeddingsRequestEnum.embed_description:
|
||||||
|
return self.embeddings.upsert_description(
|
||||||
|
data["id"], data["description"]
|
||||||
|
)
|
||||||
|
elif topic == EmbeddingsRequestEnum.embed_thumbnail:
|
||||||
|
thumbnail = base64.b64decode(data["thumbnail"])
|
||||||
|
return self.embeddings.upsert_thumbnail(data["id"], thumbnail)
|
||||||
|
elif topic == EmbeddingsRequestEnum.generate_search:
|
||||||
|
return self.embeddings.text_embedding([data])[0]
|
||||||
|
|
||||||
|
self.embeddings_responder.check_for_request(handle_request)
|
||||||
|
|
||||||
def _process_updates(self) -> None:
|
def _process_updates(self) -> None:
|
||||||
"""Process event updates"""
|
"""Process event updates"""
|
||||||
update = self.event_subscriber.check_for_update()
|
update = self.event_subscriber.check_for_update()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user