Handle zmq error and empty data

Nicolas Mowen 2024-10-10 11:43:47 -06:00
parent d5eab6c794
commit 6287253563
6 changed files with 60 additions and 37 deletions
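The core change is defensive handling on both ends of the embeddings IPC channel: the requestor catches zmq.ZMQError and returns an empty string instead of raising, and callers treat an empty reply as "no result" rather than serializing it. Below is a minimal sketch of that pattern; it assumes a zmq.REQ socket connected to the IPC endpoint shown in the diff (the requestor's socket setup is not part of this commit), and the standalone send_data function only mirrors the method further down.

import zmq

context = zmq.Context()
socket = context.socket(zmq.REQ)  # assumption: the requestor side uses a REQ socket
socket.connect("ipc:///tmp/cache/embeddings")

def send_data(topic: str, data) -> str:
    """Send a request and wait for the reply, returning "" if ZMQ fails."""
    try:
        socket.send_json((topic, data))
        return socket.recv_json()
    except zmq.ZMQError:
        # Broken or busy socket: hand back an empty sentinel instead of raising.
        return ""

# Caller side: an empty reply means "no embedding", so bail out early
# rather than passing empty data to serialize().
reply = send_data("generate_search", "person near the front door")
if not reply:
    results = []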

View File

@@ -1,5 +1,6 @@
 """Facilitates communication between processes."""
 
+import logging
 from enum import Enum
 from typing import Callable
 
@@ -7,6 +8,8 @@ import zmq
 
 SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"
 
+logger = logging.getLogger(__name__)
+
 
 class EmbeddingsRequestEnum(Enum):
     embed_description = "embed_description"
@@ -22,7 +25,9 @@ class EmbeddingsResponder:
 
     def check_for_request(self, process: Callable) -> None:
         while True:  # load all messages that are queued
+            logger.debug("Checking for embeddings requests")
             has_message, _, _ = zmq.select([self.socket], [], [], 0.1)
+            logger.debug(f"has a request? {has_message}")
 
             if not has_message:
                 break
@@ -54,8 +59,11 @@ class EmbeddingsRequestor:
 
     def send_data(self, topic: str, data: any) -> str:
        """Sends data and then waits for reply."""
-        self.socket.send_json((topic, data))
-        return self.socket.recv_json()
+        try:
+            self.socket.send_json((topic, data))
+            return self.socket.recv_json()
+        except zmq.ZMQError:
+            return ""
 
     def stop(self) -> None:
         self.socket.close()

View File

@@ -39,7 +39,7 @@ class EventMetadataSubscriber(Subscriber):
         super().__init__(topic)
 
     def check_for_update(
-        self, timeout: float = None
+        self, timeout: float = 1
     ) -> Optional[tuple[EventMetadataTypeEnum, str, RegenerateDescriptionEnum]]:
         return super().check_for_update(timeout)
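Related to the empty-data handling, the subscriber's default timeout moves from None to 1 second, which lets the maintainer loop further down drop its explicit timeout=0.1 arguments without ever blocking indefinitely in check_for_update(). The base Subscriber class is not part of this diff; the snippet below is only a rough illustration, under that assumption, of what a timeout-bounded poll can look like with pyzmq.

import zmq

def check_for_update(socket: zmq.Socket, timeout: float = 1):
    """Return the next message if one arrives within `timeout` seconds, else None."""
    poller = zmq.Poller()
    poller.register(socket, zmq.POLLIN)
    if poller.poll(timeout * 1000):  # poll() takes milliseconds
        return socket.recv_json()
    return None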

View File

@@ -114,19 +114,25 @@ class EmbeddingsContext:
                 query_embedding = row[0]
             else:
                 # If no embedding found, generate it and return it
-                query_embedding = serialize(
-                    self.requestor.send_data(
-                        EmbeddingsRequestEnum.embed_thumbnail.value,
-                        {"id": str(query.id), "thumbnail": str(query.thumbnail)},
-                    )
+                data = self.requestor.send_data(
+                    EmbeddingsRequestEnum.embed_thumbnail.value,
+                    {"id": str(query.id), "thumbnail": str(query.thumbnail)},
                 )
+
+                if not data:
+                    return []
+
+                query_embedding = serialize(data)
         else:
-            query_embedding = serialize(
-                self.requestor.send_data(
-                    EmbeddingsRequestEnum.generate_search.value, query
-                )
+            data = self.requestor.send_data(
+                EmbeddingsRequestEnum.generate_search.value, query
             )
+
+            if not data:
+                return []
+
+            query_embedding = serialize(data)
 
         sql_query = """
             SELECT
                 id,
@@ -155,12 +161,15 @@ class EmbeddingsContext:
     def search_description(
         self, query_text: str, event_ids: list[str] = None
     ) -> list[tuple[str, float]]:
-        query_embedding = serialize(
-            self.requestor.send_data(
-                EmbeddingsRequestEnum.generate_search.value, query_text
-            )
+        data = self.requestor.send_data(
+            EmbeddingsRequestEnum.generate_search.value, query_text
         )
+
+        if not data:
+            return []
+
+        query_embedding = serialize(data)
 
         # Prepare the base SQL query
         sql_query = """
             SELECT

View File

@@ -59,6 +59,7 @@ class GenericONNXEmbedding:
         self.feature_extractor = None
         self.session = None
 
+        print("starting model download")
         self.downloader = ModelDownloader(
             model_name=self.model_name,
             download_path=self.download_path,
@@ -70,6 +71,7 @@ class GenericONNXEmbedding:
 
     def _download_model(self, path: str):
         try:
+            print("beginning model download process")
             file_name = os.path.basename(path)
 
             if file_name in self.download_urls:
                 ModelDownloader.download_from_url(self.download_urls[file_name], path)
@@ -107,10 +109,11 @@ class GenericONNXEmbedding:
                 self.tokenizer = self._load_tokenizer()
             else:
                 self.feature_extractor = self._load_feature_extractor()
+            print("creating onnx session")
             self.session = self._load_model(
                 os.path.join(self.download_path, self.model_file)
             )
-            logger.debug("successfully loaded model.")
+            print("successfully loaded model.")
 
     def _load_tokenizer(self):
         tokenizer_path = os.path.join(f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer")
@@ -127,15 +130,16 @@ class GenericONNXEmbedding:
         )
 
     def _load_model(self, path: str) -> Optional[ort.InferenceSession]:
+        print(f"checking if path exists {path}")
         if os.path.exists(path):
-            logger.debug(
+            print(
                 f"loading ORT session with providers {self.providers} and options {self.provider_options}"
             )
             return ort.InferenceSession(
                 path, providers=self.providers, provider_options=self.provider_options
             )
         else:
-            logger.warning(f"{self.model_name} model file {path} not found.")
+            print(f"{self.model_name} model file {path} not found.")
             return None
 
     def _process_image(self, image):
@@ -149,6 +153,7 @@ class GenericONNXEmbedding:
     def __call__(
         self, inputs: Union[List[str], List[Image.Image], List[str]]
     ) -> List[np.ndarray]:
+        print("beginning call for onnx embedding")
        self._load_model_and_tokenizer()
 
        if self.session is None or (

View File

@@ -41,8 +41,7 @@ class EmbeddingMaintainer(threading.Thread):
         config: FrigateConfig,
         stop_event: MpEvent,
     ) -> None:
-        threading.Thread.__init__(self)
-        self.name = "embeddings_maintainer"
+        super().__init__(name="embeddings_maintainer")
         self.config = config
         self.embeddings = Embeddings(config.semantic_search, db)
         self.event_subscriber = EventUpdateSubscriber()
@@ -61,6 +60,7 @@ class EmbeddingMaintainer(threading.Thread):
     def run(self) -> None:
         """Maintain a SQLite-vec database for semantic search."""
         while not self.stop_event.is_set():
+            print("Doing another embeddings loop.")
             self._process_requests()
             self._process_updates()
             self._process_finalized()
@@ -77,9 +77,7 @@ class EmbeddingMaintainer(threading.Thread):
         """Process embeddings requests"""
 
         def handle_request(topic: str, data: str) -> str:
-            logger.debug(
-                f"Handling embeddings request of type {topic} with data {data}"
-            )
+            print(f"Handling embeddings request of type {topic} with data {data}")
 
             try:
                 if topic == EmbeddingsRequestEnum.embed_description.value:
@@ -106,7 +104,7 @@ class EmbeddingMaintainer(threading.Thread):
 
     def _process_updates(self) -> None:
         """Process event updates"""
-        update = self.event_subscriber.check_for_update(timeout=0.1)
+        update = self.event_subscriber.check_for_update()
 
         if update is None:
             return
@@ -116,7 +114,7 @@ class EmbeddingMaintainer(threading.Thread):
         if not camera or source_type != EventTypeEnum.tracked_object:
             return
 
-        logger.debug(f"Processing object update of type {source_type} on {camera}")
+        print(f"Processing object update of type {source_type} on {camera}")
 
         camera_config = self.config.cameras[camera]
         if data["id"] not in self.tracked_events:
             self.tracked_events[data["id"]] = []
@@ -124,29 +122,33 @@ class EmbeddingMaintainer(threading.Thread):
         # Create our own thumbnail based on the bounding box and the frame time
         try:
             frame_id = f"{camera}{data['frame_time']}"
+            print("trying to get frame from manager")
             yuv_frame = self.frame_manager.get(frame_id, camera_config.frame_shape_yuv)
+            print(f"got frame from manager and it is valid {yuv_frame is not None}")
 
             if yuv_frame is not None:
                 data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
                 self.tracked_events[data["id"]].append(data)
                 self.frame_manager.close(frame_id)
             else:
-                logger.debug(
+                print(
                     f"Unable to create embedding for thumbnail from {camera} because frame is missing."
                 )
         except FileNotFoundError:
             pass
 
+        print("Finished processing object update")
+
     def _process_finalized(self) -> None:
         """Process the end of an event."""
         while True:
-            ended = self.event_end_subscriber.check_for_update(timeout=0.1)
+            ended = self.event_end_subscriber.check_for_update()
 
             if ended == None:
                 break
 
             event_id, camera, updated_db = ended
-            logger.debug(
+            print(
                 f"Processing finalized event for {camera} which updated the db: {updated_db}"
             )
 
             camera_config = self.config.cameras[camera]
@@ -180,7 +182,7 @@ class EmbeddingMaintainer(threading.Thread):
                 or set(event.zones) & set(camera_config.genai.required_zones)
             )
         ):
-            logger.debug(
+            print(
                 f"Description generation for {event}, has_snapshot: {event.has_snapshot}"
             )
             if event.has_snapshot and camera_config.genai.use_snapshot:
@@ -235,14 +237,12 @@ class EmbeddingMaintainer(threading.Thread):
 
     def _process_event_metadata(self):
         # Check for regenerate description requests
-        (topic, event_id, source) = self.event_metadata_subscriber.check_for_update(
-            timeout=0.1
-        )
+        (topic, event_id, source) = self.event_metadata_subscriber.check_for_update()
 
         if topic is None:
             return
 
-        logger.debug(f"Handling event metadata for id {event_id} and source {source}")
+        print(f"Handling event metadata for id {event_id} and source {source}")
 
         if event_id:
             self.handle_regenerate_description(event_id, source)
@@ -276,7 +276,7 @@ class EmbeddingMaintainer(threading.Thread):
         )
 
         if not description:
-            logger.debug("Failed to generate description for %s", event.id)
+            print("Failed to generate description for %s", event.id)
             return
 
         # fire and forget description update
@@ -288,7 +288,7 @@ class EmbeddingMaintainer(threading.Thread):
 
         # Encode the description
         self.embeddings.upsert_description(event.id, description)
-        logger.debug(
+        print(
             "Generated description for %s (%d images): %s",
             event.id,
             len(thumbnails),
@@ -309,7 +309,7 @@ class EmbeddingMaintainer(threading.Thread):
 
         thumbnail = base64.b64decode(event.thumbnail)
 
-        logger.debug(
+        print(
             f"Trying {source} regeneration for {event}, has_snapshot: {event.has_snapshot}"
         )

View File

@@ -120,4 +120,5 @@ class ModelDownloader:
             logger.info(f"Downloading complete: {url}")
 
     def wait_for_download(self):
+        print("waiting for model download")
         self.download_complete.wait()