blakeblackshear/frigate (mirror of https://github.com/blakeblackshear/frigate.git)
Commit 6287253563 (parent d5eab6c794): Handle zmq error and empty data
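
The gist of the change, pulled out of the diff below as a minimal sketch (not the Frigate source): the requestor catches zmq.ZMQError and returns an empty string, and callers treat empty data as "no result" instead of passing it to serialize(). The socket setup and the embed_or_skip helper here are illustrative assumptions; only the send_data body and the if-not-data guard come from the commit.

import zmq

SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"


class Requestor:
    def __init__(self) -> None:
        # Assumed setup: the real EmbeddingsRequestor's __init__ is not part of this diff.
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)
        self.socket.connect(SOCKET_REP_REQ)

    def send_data(self, topic: str, data: any) -> str:
        """Sends data and then waits for reply."""
        try:
            self.socket.send_json((topic, data))
            return self.socket.recv_json()
        except zmq.ZMQError:
            # A dead or unresponsive responder no longer raises into the caller.
            return ""


def embed_or_skip(requestor: Requestor, query_text: str) -> list:
    # Hypothetical caller mirroring the EmbeddingsContext change below: an empty
    # reply means the responder was unavailable, so return an empty result set
    # rather than serializing bad data.
    data = requestor.send_data("generate_search", query_text)
    if not data:
        return []
    return [data]
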
@@ -1,5 +1,6 @@
 """Facilitates communication between processes."""
 
+import logging
 from enum import Enum
 from typing import Callable
 

@@ -7,6 +8,8 @@ import zmq
 
 SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"
 
+logger = logging.getLogger(__name__)
+
 
 class EmbeddingsRequestEnum(Enum):
     embed_description = "embed_description"

@@ -22,7 +25,9 @@ class EmbeddingsResponder:
 
     def check_for_request(self, process: Callable) -> None:
         while True:  # load all messages that are queued
+            logger.debug("Checking for embeddings requests")
             has_message, _, _ = zmq.select([self.socket], [], [], 0.1)
+            logger.debug(f"has a request? {has_message}")
 
             if not has_message:
                 break

@@ -54,8 +59,11 @@ class EmbeddingsRequestor:
 
     def send_data(self, topic: str, data: any) -> str:
         """Sends data and then waits for reply."""
-        self.socket.send_json((topic, data))
-        return self.socket.recv_json()
+        try:
+            self.socket.send_json((topic, data))
+            return self.socket.recv_json()
+        except zmq.ZMQError:
+            return ""
 
     def stop(self) -> None:
         self.socket.close()

@@ -39,7 +39,7 @@ class EventMetadataSubscriber(Subscriber):
         super().__init__(topic)
 
     def check_for_update(
-        self, timeout: float = None
+        self, timeout: float = 1
     ) -> Optional[tuple[EventMetadataTypeEnum, str, RegenerateDescriptionEnum]]:
         return super().check_for_update(timeout)
 

@@ -114,18 +114,24 @@ class EmbeddingsContext:
                 query_embedding = row[0]
             else:
                 # If no embedding found, generate it and return it
-                query_embedding = serialize(
-                    self.requestor.send_data(
-                        EmbeddingsRequestEnum.embed_thumbnail.value,
-                        {"id": str(query.id), "thumbnail": str(query.thumbnail)},
-                    )
-                )
+                data = self.requestor.send_data(
+                    EmbeddingsRequestEnum.embed_thumbnail.value,
+                    {"id": str(query.id), "thumbnail": str(query.thumbnail)},
+                )
+
+                if not data:
+                    return []
+
+                query_embedding = serialize(data)
         else:
-            query_embedding = serialize(
-                self.requestor.send_data(
-                    EmbeddingsRequestEnum.generate_search.value, query
-                )
-            )
+            data = self.requestor.send_data(
+                EmbeddingsRequestEnum.generate_search.value, query
+            )
+
+            if not data:
+                return []
+
+            query_embedding = serialize(data)
 
         sql_query = """
             SELECT

@@ -155,11 +161,14 @@ class EmbeddingsContext:
     def search_description(
         self, query_text: str, event_ids: list[str] = None
     ) -> list[tuple[str, float]]:
-        query_embedding = serialize(
-            self.requestor.send_data(
-                EmbeddingsRequestEnum.generate_search.value, query_text
-            )
-        )
+        data = self.requestor.send_data(
+            EmbeddingsRequestEnum.generate_search.value, query_text
+        )
+
+        if not data:
+            return []
+
+        query_embedding = serialize(data)
 
         # Prepare the base SQL query
         sql_query = """

@@ -59,6 +59,7 @@ class GenericONNXEmbedding:
         self.feature_extractor = None
         self.session = None
 
+        print("starting model download")
         self.downloader = ModelDownloader(
             model_name=self.model_name,
             download_path=self.download_path,

@@ -70,6 +71,7 @@ class GenericONNXEmbedding:
 
     def _download_model(self, path: str):
         try:
+            print("beginning model download process")
             file_name = os.path.basename(path)
             if file_name in self.download_urls:
                 ModelDownloader.download_from_url(self.download_urls[file_name], path)

@@ -107,10 +109,11 @@ class GenericONNXEmbedding:
             self.tokenizer = self._load_tokenizer()
         else:
             self.feature_extractor = self._load_feature_extractor()
+        print("creating onnx session")
         self.session = self._load_model(
             os.path.join(self.download_path, self.model_file)
         )
-        logger.debug("successfully loaded model.")
+        print("successfully loaded model.")
 
     def _load_tokenizer(self):
         tokenizer_path = os.path.join(f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer")

@@ -127,15 +130,16 @@ class GenericONNXEmbedding:
         )
 
     def _load_model(self, path: str) -> Optional[ort.InferenceSession]:
+        print(f"checking if path exists {path}")
         if os.path.exists(path):
-            logger.debug(
+            print(
                 f"loading ORT session with providers {self.providers} and options {self.provider_options}"
             )
             return ort.InferenceSession(
                 path, providers=self.providers, provider_options=self.provider_options
             )
         else:
-            logger.warning(f"{self.model_name} model file {path} not found.")
+            print(f"{self.model_name} model file {path} not found.")
             return None
 
     def _process_image(self, image):

@@ -149,6 +153,7 @@ class GenericONNXEmbedding:
     def __call__(
         self, inputs: Union[List[str], List[Image.Image], List[str]]
    ) -> List[np.ndarray]:
+        print("beginning call for onnx embedding")
         self._load_model_and_tokenizer()
 
         if self.session is None or (

@@ -41,8 +41,7 @@ class EmbeddingMaintainer(threading.Thread):
         config: FrigateConfig,
         stop_event: MpEvent,
     ) -> None:
-        threading.Thread.__init__(self)
-        self.name = "embeddings_maintainer"
+        super().__init__(name="embeddings_maintainer")
         self.config = config
         self.embeddings = Embeddings(config.semantic_search, db)
         self.event_subscriber = EventUpdateSubscriber()

@@ -61,6 +60,7 @@ class EmbeddingMaintainer(threading.Thread):
     def run(self) -> None:
         """Maintain a SQLite-vec database for semantic search."""
         while not self.stop_event.is_set():
+            print("Doing another embeddings loop.")
             self._process_requests()
             self._process_updates()
             self._process_finalized()

@@ -77,9 +77,7 @@ class EmbeddingMaintainer(threading.Thread):
         """Process embeddings requests"""
 
         def handle_request(topic: str, data: str) -> str:
-            logger.debug(
-                f"Handling embeddings request of type {topic} with data {data}"
-            )
+            print(f"Handling embeddings request of type {topic} with data {data}")
 
             try:
                 if topic == EmbeddingsRequestEnum.embed_description.value:

@@ -106,7 +104,7 @@ class EmbeddingMaintainer(threading.Thread):
 
     def _process_updates(self) -> None:
         """Process event updates"""
-        update = self.event_subscriber.check_for_update(timeout=0.1)
+        update = self.event_subscriber.check_for_update()
 
         if update is None:
             return

@@ -116,7 +114,7 @@ class EmbeddingMaintainer(threading.Thread):
         if not camera or source_type != EventTypeEnum.tracked_object:
             return
 
-        logger.debug(f"Processing object update of type {source_type} on {camera}")
+        print(f"Processing object update of type {source_type} on {camera}")
         camera_config = self.config.cameras[camera]
         if data["id"] not in self.tracked_events:
             self.tracked_events[data["id"]] = []

@@ -124,29 +122,33 @@ class EmbeddingMaintainer(threading.Thread):
         # Create our own thumbnail based on the bounding box and the frame time
         try:
             frame_id = f"{camera}{data['frame_time']}"
+            print("trying to get frame from manager")
             yuv_frame = self.frame_manager.get(frame_id, camera_config.frame_shape_yuv)
+            print(f"got frame from manager and it is valid {yuv_frame is not None}")
 
             if yuv_frame is not None:
                 data["thumbnail"] = self._create_thumbnail(yuv_frame, data["box"])
                 self.tracked_events[data["id"]].append(data)
                 self.frame_manager.close(frame_id)
             else:
-                logger.debug(
+                print(
                     f"Unable to create embedding for thumbnail from {camera} because frame is missing."
                 )
         except FileNotFoundError:
             pass
 
+        print("Finished processing object update")
+
     def _process_finalized(self) -> None:
         """Process the end of an event."""
         while True:
-            ended = self.event_end_subscriber.check_for_update(timeout=0.1)
+            ended = self.event_end_subscriber.check_for_update()
 
             if ended == None:
                 break
 
             event_id, camera, updated_db = ended
-            logger.debug(
+            print(
                 f"Processing finalized event for {camera} which updated the db: {updated_db}"
             )
             camera_config = self.config.cameras[camera]

@@ -180,7 +182,7 @@ class EmbeddingMaintainer(threading.Thread):
                     or set(event.zones) & set(camera_config.genai.required_zones)
                 )
             ):
-                logger.debug(
+                print(
                     f"Description generation for {event}, has_snapshot: {event.has_snapshot}"
                 )
                 if event.has_snapshot and camera_config.genai.use_snapshot:

@@ -235,14 +237,12 @@ class EmbeddingMaintainer(threading.Thread):
 
     def _process_event_metadata(self):
         # Check for regenerate description requests
-        (topic, event_id, source) = self.event_metadata_subscriber.check_for_update(
-            timeout=0.1
-        )
+        (topic, event_id, source) = self.event_metadata_subscriber.check_for_update()
 
         if topic is None:
             return
 
-        logger.debug(f"Handling event metadata for id {event_id} and source {source}")
+        print(f"Handling event metadata for id {event_id} and source {source}")
 
         if event_id:
             self.handle_regenerate_description(event_id, source)

@@ -276,7 +276,7 @@ class EmbeddingMaintainer(threading.Thread):
         )
 
         if not description:
-            logger.debug("Failed to generate description for %s", event.id)
+            print("Failed to generate description for %s", event.id)
             return
 
         # fire and forget description update

@@ -288,7 +288,7 @@ class EmbeddingMaintainer(threading.Thread):
         # Encode the description
         self.embeddings.upsert_description(event.id, description)
 
-        logger.debug(
+        print(
             "Generated description for %s (%d images): %s",
             event.id,
             len(thumbnails),

@@ -309,7 +309,7 @@ class EmbeddingMaintainer(threading.Thread):
 
         thumbnail = base64.b64decode(event.thumbnail)
 
-        logger.debug(
+        print(
             f"Trying {source} regeneration for {event}, has_snapshot: {event.has_snapshot}"
         )
 

@@ -120,4 +120,5 @@ class ModelDownloader:
         logger.info(f"Downloading complete: {url}")
 
     def wait_for_download(self):
+        print("waiting for model download")
         self.download_complete.wait()
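
For reference, the check_for_request hunk near the top of this diff instruments a REP-side polling loop. Below is a minimal sketch of that loop, assuming the usual setup: the socket creation and the receive/reply step are illustrative guesses, since only the zmq.select polling with its 0.1 s timeout appears in the commit.

import zmq
from typing import Callable

SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"


class Responder:
    def __init__(self) -> None:
        # Assumed setup: the real EmbeddingsResponder's __init__ is not part of this diff.
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REP)
        self.socket.bind(SOCKET_REP_REQ)

    def check_for_request(self, process: Callable) -> None:
        while True:  # load all messages that are queued
            # Poll with a short timeout so the loop never blocks the caller.
            has_message, _, _ = zmq.select([self.socket], [], [], 0.1)

            if not has_message:
                break

            # Assumed request/reply step: answer each queued request in turn.
            topic, data = self.socket.recv_json()
            self.socket.send_json(process(topic, data))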