Mirror of https://github.com/blakeblackshear/frigate.git
Handle zmq error and empty data
Commit 6287253563 (parent d5eab6c794)
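
Every hunk below applies one of two related fixes: the ZMQ request/reply round trip is wrapped in try/except so a zmq.ZMQError yields an empty string instead of an unhandled exception, and callers treat an empty reply as "no result" rather than passing it on to serialize(). A minimal sketch of the pattern, using illustrative names rather than Frigate's actual classes:

import zmq


def send_data(socket: zmq.Socket, topic: str, data) -> object:
    """Send a (topic, data) pair and wait for the JSON reply; return "" on failure."""
    try:
        socket.send_json((topic, data))
        return socket.recv_json()
    except zmq.ZMQError:
        # Mirrors the new EmbeddingsRequestor.send_data behavior.
        return ""


def search(socket: zmq.Socket, query: str) -> list:
    data = send_data(socket, "generate_search", query)
    if not data:
        # Empty reply (error or no payload): return no results instead of
        # handing bad data to serialize() and the SQL query.
        return []
    return [data]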
@@ -1,5 +1,6 @@
 """Facilitates communication between processes."""
 
+import logging
 from enum import Enum
 from typing import Callable
 
@@ -7,6 +8,8 @@ import zmq
 
 SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"
 
+logger = logging.getLogger(__name__)
+
 
 class EmbeddingsRequestEnum(Enum):
     embed_description = "embed_description"
@@ -22,7 +25,9 @@ class EmbeddingsResponder:
 
     def check_for_request(self, process: Callable) -> None:
         while True:  # load all messages that are queued
+            logger.debug("Checking for embeddings requests")
             has_message, _, _ = zmq.select([self.socket], [], [], 0.1)
+            logger.debug(f"has a request? {has_message}")
 
             if not has_message:
                 break
@@ -54,8 +59,11 @@ class EmbeddingsRequestor:
 
     def send_data(self, topic: str, data: any) -> str:
         """Sends data and then waits for reply."""
-        self.socket.send_json((topic, data))
-        return self.socket.recv_json()
+        try:
+            self.socket.send_json((topic, data))
+            return self.socket.recv_json()
+        except zmq.ZMQError:
+            return ""
 
     def stop(self) -> None:
         self.socket.close()
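
One way the bare send_json/recv_json pair above can raise is the REQ socket's strict send/recv alternation: a second send before the reply is consumed fails with zmq.ZMQError (EFSM). A small self-contained sketch of that behavior, assuming a throwaway inproc endpoint rather than Frigate's ipc:///tmp/cache/embeddings socket:

import zmq

ctx = zmq.Context.instance()

rep = ctx.socket(zmq.REP)
rep.bind("inproc://demo")  # illustrative endpoint only

req = ctx.socket(zmq.REQ)
req.connect("inproc://demo")

req.send_json(("embed_description", {"id": "1"}))
try:
    # A REQ socket must recv before it may send again; this raises zmq.ZMQError.
    req.send_json(("embed_description", {"id": "2"}))
except zmq.ZMQError as e:
    print(f"second send rejected: {e}")

# The original round trip still completes once the reply is consumed.
print(rep.recv_json())
rep.send_json("ok")
print(req.recv_json())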
@@ -39,7 +39,7 @@ class EventMetadataSubscriber(Subscriber):
         super().__init__(topic)
 
     def check_for_update(
-        self, timeout: float = None
+        self, timeout: float = 1
     ) -> Optional[tuple[EventMetadataTypeEnum, str, RegenerateDescriptionEnum]]:
         return super().check_for_update(timeout)
 
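
With the default timeout changed from None to 1 second, the call sites further down drop their explicit timeout=0.1 and simply block for up to a second per check. A rough sketch of what a timeout-bounded check can look like with zmq.Poller; this is an assumption about the base Subscriber, not its actual implementation:

import zmq


def check_for_update(socket: zmq.Socket, timeout: float = 1):
    """Return the next published message, or None if nothing arrives in `timeout` seconds."""
    poller = zmq.Poller()
    poller.register(socket, zmq.POLLIN)
    # Poller.poll() takes milliseconds; None would block indefinitely.
    events = dict(poller.poll(timeout * 1000 if timeout is not None else None))
    if socket in events:
        return socket.recv_json()
    return None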
@@ -114,19 +114,25 @@ class EmbeddingsContext:
                 query_embedding = row[0]
             else:
                 # If no embedding found, generate it and return it
-                query_embedding = serialize(
-                    self.requestor.send_data(
-                        EmbeddingsRequestEnum.embed_thumbnail.value,
-                        {"id": str(query.id), "thumbnail": str(query.thumbnail)},
-                    )
-                )
+                data = self.requestor.send_data(
+                    EmbeddingsRequestEnum.embed_thumbnail.value,
+                    {"id": str(query.id), "thumbnail": str(query.thumbnail)},
+                )
+
+                if not data:
+                    return []
+
+                query_embedding = serialize(data)
         else:
-            query_embedding = serialize(
-                self.requestor.send_data(
-                    EmbeddingsRequestEnum.generate_search.value, query
-                )
-            )
+            data = self.requestor.send_data(
+                EmbeddingsRequestEnum.generate_search.value, query
+            )
+
+            if not data:
+                return []
+
+            query_embedding = serialize(data)
 
         sql_query = """
             SELECT
                 id,
@@ -155,12 +161,15 @@ class EmbeddingsContext:
     def search_description(
         self, query_text: str, event_ids: list[str] = None
     ) -> list[tuple[str, float]]:
-        query_embedding = serialize(
-            self.requestor.send_data(
-                EmbeddingsRequestEnum.generate_search.value, query_text
-            )
-        )
+        data = self.requestor.send_data(
+            EmbeddingsRequestEnum.generate_search.value, query_text
+        )
+
+        if not data:
+            return []
+
+        query_embedding = serialize(data)
 
         # Prepare the base SQL query
         sql_query = """
             SELECT
@@ -59,6 +59,7 @@ class GenericONNXEmbedding:
         self.feature_extractor = None
         self.session = None
 
+        print("starting model download")
         self.downloader = ModelDownloader(
             model_name=self.model_name,
             download_path=self.download_path,
@@ -70,6 +71,7 @@ class GenericONNXEmbedding:
 
     def _download_model(self, path: str):
         try:
+            print("beginning model download process")
             file_name = os.path.basename(path)
             if file_name in self.download_urls:
                 ModelDownloader.download_from_url(self.download_urls[file_name], path)
@@ -107,10 +109,11 @@ class GenericONNXEmbedding:
                 self.tokenizer = self._load_tokenizer()
             else:
                 self.feature_extractor = self._load_feature_extractor()
+            print("creating onnx session")
             self.session = self._load_model(
                 os.path.join(self.download_path, self.model_file)
             )
-            logger.debug("successfully loaded model.")
+            print("successfully loaded model.")
 
     def _load_tokenizer(self):
         tokenizer_path = os.path.join(f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer")
@@ -127,15 +130,16 @@ class GenericONNXEmbedding:
         )
 
     def _load_model(self, path: str) -> Optional[ort.InferenceSession]:
+        print(f"checking if path exists {path}")
         if os.path.exists(path):
-            logger.debug(
+            print(
                 f"loading ORT session with providers {self.providers} and options {self.provider_options}"
             )
             return ort.InferenceSession(
                 path, providers=self.providers, provider_options=self.provider_options
             )
         else:
-            logger.warning(f"{self.model_name} model file {path} not found.")
+            print(f"{self.model_name} model file {path} not found.")
             return None
 
     def _process_image(self, image):
@@ -149,6 +153,7 @@ class GenericONNXEmbedding:
     def __call__(
         self, inputs: Union[List[str], List[Image.Image], List[str]]
     ) -> List[np.ndarray]:
+        print("beginning call for onnx embedding")
         self._load_model_and_tokenizer()
 
         if self.session is None or (
@@ -41,8 +41,7 @@ class EmbeddingMaintainer(threading.Thread):
         config: FrigateConfig,
         stop_event: MpEvent,
     ) -> None:
-        threading.Thread.__init__(self)
-        self.name = "embeddings_maintainer"
+        super().__init__(name="embeddings_maintainer")
        self.config = config
         self.embeddings = Embeddings(config.semantic_search, db)
         self.event_subscriber = EventUpdateSubscriber()
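
The constructor change above is behavior-preserving: passing name= through super() sets the same thread name that the two removed lines set by hand. A minimal illustration with a stripped-down stand-in for the real class:

import threading


class Maintainer(threading.Thread):
    def __init__(self) -> None:
        # One call replaces threading.Thread.__init__(self) followed by
        # self.name = "embeddings_maintainer".
        super().__init__(name="embeddings_maintainer")


print(Maintainer().name)  # embeddings_maintainer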
@@ -61,6 +60,7 @@ class EmbeddingMaintainer(threading.Thread):
     def run(self) -> None:
         """Maintain a SQLite-vec database for semantic search."""
         while not self.stop_event.is_set():
+            print("Doing another embeddings loop.")
             self._process_requests()
             self._process_updates()
             self._process_finalized()
@@ -77,9 +77,7 @@ class EmbeddingMaintainer(threading.Thread):
         """Process embeddings requests"""
 
         def handle_request(topic: str, data: str) -> str:
-            logger.debug(
-                f"Handling embeddings request of type {topic} with data {data}"
-            )
+            print(f"Handling embeddings request of type {topic} with data {data}")
 
             try:
                 if topic == EmbeddingsRequestEnum.embed_description.value:
@@ -106,7 +104,7 @@ class EmbeddingMaintainer(threading.Thread):
 
     def _process_updates(self) -> None:
         """Process event updates"""
-        update = self.event_subscriber.check_for_update(timeout=0.1)
+        update = self.event_subscriber.check_for_update()
 
         if update is None:
             return
@@ -116,7 +114,7 @@ class EmbeddingMaintainer(threading.Thread):
         if not camera or source_type != EventTypeEnum.tracked_object:
             return
 
-        logger.debug(f"Processing object update of type {source_type} on {camera}")
+        print(f"Processing object update of type {source_type} on {camera}")
         camera_config = self.config.cameras[camera]
         if data["id"] not in self.tracked_events:
             self.tracked_events[data["id"]] = []
@@ -124,29 +122,33 @@ class EmbeddingMaintainer(threading.Thread):
         # Create our own thumbnail based on the bounding box and the frame time
         try:
             frame_id = f"{camera}{data['frame_time']}"
+            print("trying to get frame from manager")
             yuv_frame = self.frame_manager.get(frame_id, camera_config.frame_shape_yuv)
+            print(f"got frame from manager and it is valid {yuv_frame is not None}")
 
             if yuv_frame is not None:
                 data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
                 self.tracked_events[data["id"]].append(data)
                 self.frame_manager.close(frame_id)
             else:
-                logger.debug(
+                print(
                     f"Unable to create embedding for thumbnail from {camera} because frame is missing."
                 )
         except FileNotFoundError:
             pass
 
+        print("Finished processing object update")
+
     def _process_finalized(self) -> None:
         """Process the end of an event."""
         while True:
-            ended = self.event_end_subscriber.check_for_update(timeout=0.1)
+            ended = self.event_end_subscriber.check_for_update()
 
             if ended == None:
                 break
 
             event_id, camera, updated_db = ended
-            logger.debug(
+            print(
                 f"Processing finalized event for {camera} which updated the db: {updated_db}"
             )
             camera_config = self.config.cameras[camera]
@@ -180,7 +182,7 @@ class EmbeddingMaintainer(threading.Thread):
                 or set(event.zones) & set(camera_config.genai.required_zones)
             )
         ):
-            logger.debug(
+            print(
                 f"Description generation for {event}, has_snapshot: {event.has_snapshot}"
             )
             if event.has_snapshot and camera_config.genai.use_snapshot:
@@ -235,14 +237,12 @@ class EmbeddingMaintainer(threading.Thread):
 
     def _process_event_metadata(self):
         # Check for regenerate description requests
-        (topic, event_id, source) = self.event_metadata_subscriber.check_for_update(
-            timeout=0.1
-        )
+        (topic, event_id, source) = self.event_metadata_subscriber.check_for_update()
 
         if topic is None:
             return
 
-        logger.debug(f"Handling event metadata for id {event_id} and source {source}")
+        print(f"Handling event metadata for id {event_id} and source {source}")
 
         if event_id:
             self.handle_regenerate_description(event_id, source)
@@ -276,7 +276,7 @@ class EmbeddingMaintainer(threading.Thread):
         )
 
         if not description:
-            logger.debug("Failed to generate description for %s", event.id)
+            print("Failed to generate description for %s", event.id)
             return
 
         # fire and forget description update
@@ -288,7 +288,7 @@ class EmbeddingMaintainer(threading.Thread):
         # Encode the description
         self.embeddings.upsert_description(event.id, description)
 
-        logger.debug(
+        print(
             "Generated description for %s (%d images): %s",
             event.id,
             len(thumbnails),
@@ -309,7 +309,7 @@ class EmbeddingMaintainer(threading.Thread):
 
         thumbnail = base64.b64decode(event.thumbnail)
 
-        logger.debug(
+        print(
             f"Trying {source} regeneration for {event}, has_snapshot: {event.has_snapshot}"
         )
 
@@ -120,4 +120,5 @@ class ModelDownloader:
         logger.info(f"Downloading complete: {url}")
 
     def wait_for_download(self):
+        print("waiting for model download")
        self.download_complete.wait()
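
wait_for_download blocks on download_complete, which is presumably a threading.Event that the download worker sets once all files are in place. A minimal sketch of that pattern with hypothetical names, not the real ModelDownloader internals:

import threading
import time


class Downloader:
    def __init__(self) -> None:
        self.download_complete = threading.Event()
        threading.Thread(target=self._download, daemon=True).start()

    def _download(self) -> None:
        time.sleep(0.1)  # stand-in for the real download work
        self.download_complete.set()  # wake anyone blocked in wait_for_download()

    def wait_for_download(self) -> None:
        print("waiting for model download")
        self.download_complete.wait()  # blocks until the event is set


d = Downloader()
d.wait_for_download()
print("model ready")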