From 52a2bbc489c8a2fe5ba2e7bf2a796aeafed64d22 Mon Sep 17 00:00:00 2001
From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com>
Date: Mon, 7 Oct 2024 08:05:59 -0500
Subject: [PATCH] improve model downloading and add status screen
---
frigate/app.py | 4 +-
frigate/comms/dispatcher.py | 11 ++
frigate/const.py | 1 +
frigate/embeddings/__init__.py | 3 +-
frigate/embeddings/embeddings.py | 24 ++-
frigate/embeddings/functions/clip.py | 165 +++++++++++++------
frigate/embeddings/functions/minilm_l6_v2.py | 128 +++++++-------
frigate/types.py | 8 +
frigate/util/downloader.py | 119 +++++++++++++
web/src/api/ws.tsx | 36 ++++
web/src/pages/Explore.tsx | 128 +++++++++++---
web/src/types/ws.ts | 6 +
12 files changed, 492 insertions(+), 141 deletions(-)
create mode 100644 frigate/util/downloader.py
diff --git a/frigate/app.py b/frigate/app.py
index e370150bb..65272f2be 100644
--- a/frigate/app.py
+++ b/frigate/app.py
@@ -590,13 +590,13 @@ class FrigateApp:
self.init_onvif()
self.init_recording_manager()
self.init_review_segment_manager()
- self.init_embeddings_manager()
self.init_go2rtc()
self.bind_database()
self.check_db_data_migrations()
- self.init_embeddings_client()
self.init_inter_process_communicator()
self.init_dispatcher()
+ self.init_embeddings_manager()
+ self.init_embeddings_client()
self.start_detectors()
self.start_video_output_processor()
self.start_ptz_autotracker()
diff --git a/frigate/comms/dispatcher.py b/frigate/comms/dispatcher.py
index a987f6a38..1605d645a 100644
--- a/frigate/comms/dispatcher.py
+++ b/frigate/comms/dispatcher.py
@@ -16,10 +16,12 @@ from frigate.const import (
REQUEST_REGION_GRID,
UPDATE_CAMERA_ACTIVITY,
UPDATE_EVENT_DESCRIPTION,
+ UPDATE_MODEL_STATE,
UPSERT_REVIEW_SEGMENT,
)
from frigate.models import Event, Previews, Recordings, ReviewSegment
from frigate.ptz.onvif import OnvifCommandEnum, OnvifController
+from frigate.types import ModelStatusTypesEnum
from frigate.util.object import get_camera_regions_grid
from frigate.util.services import restart_frigate
@@ -83,6 +85,7 @@ class Dispatcher:
comm.subscribe(self._receive)
self.camera_activity = {}
+ self.model_state = {}
def _receive(self, topic: str, payload: str) -> Optional[Any]:
"""Handle receiving of payload from communicators."""
@@ -144,6 +147,14 @@ class Dispatcher:
"event_update",
json.dumps({"id": event.id, "description": event.data["description"]}),
)
+ elif topic == UPDATE_MODEL_STATE:
+ model = payload["model"]
+ state = payload["state"]
+ self.model_state[model] = ModelStatusTypesEnum[state]
+ self.publish("model_state", json.dumps(self.model_state))
+ elif topic == "modelState":
+ model_state = self.model_state.copy()
+ self.publish("model_state", json.dumps(model_state))
elif topic == "onConnect":
camera_status = self.camera_activity.copy()
diff --git a/frigate/const.py b/frigate/const.py
index b37ca662e..e8e841f4f 100644
--- a/frigate/const.py
+++ b/frigate/const.py
@@ -84,6 +84,7 @@ UPSERT_REVIEW_SEGMENT = "upsert_review_segment"
CLEAR_ONGOING_REVIEW_SEGMENTS = "clear_ongoing_review_segments"
UPDATE_CAMERA_ACTIVITY = "update_camera_activity"
UPDATE_EVENT_DESCRIPTION = "update_event_description"
+UPDATE_MODEL_STATE = "update_model_state"
# Stats Values
diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py
index 12fca7caf..381d95ed1 100644
--- a/frigate/embeddings/__init__.py
+++ b/frigate/embeddings/__init__.py
@@ -71,8 +71,7 @@ def manage_embeddings(config: FrigateConfig) -> None:
class EmbeddingsContext:
def __init__(self, db: SqliteVecQueueDatabase):
- self.db = db
- self.embeddings = Embeddings(self.db)
+ self.embeddings = Embeddings(db)
self.thumb_stats = ZScoreNormalization()
self.desc_stats = ZScoreNormalization()
diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py
index 2378ea64c..c763bf304 100644
--- a/frigate/embeddings/embeddings.py
+++ b/frigate/embeddings/embeddings.py
@@ -10,8 +10,11 @@ from typing import List, Tuple, Union
from PIL import Image
from playhouse.shortcuts import model_to_dict
+from frigate.comms.inter_process import InterProcessRequestor
+from frigate.const import UPDATE_MODEL_STATE
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event
+from frigate.types import ModelStatusTypesEnum
from .functions.clip import ClipEmbedding
from .functions.minilm_l6_v2 import MiniLMEmbedding
@@ -65,11 +68,30 @@ class Embeddings:
def __init__(self, db: SqliteVecQueueDatabase) -> None:
self.db = db
+ self.requestor = InterProcessRequestor()
# Create tables if they don't exist
self._create_tables()
- self.clip_embedding = ClipEmbedding(model="ViT-B/32")
+ models = [
+ "sentence-transformers/all-MiniLM-L6-v2-model.onnx",
+ "sentence-transformers/all-MiniLM-L6-v2-tokenizer",
+ "clip-clip_image_model_vitb32.onnx",
+ "clip-clip_text_model_vitb32.onnx",
+ ]
+
+ for model in models:
+ self.requestor.send_data(
+ UPDATE_MODEL_STATE,
+ {
+ "model": model,
+ "state": ModelStatusTypesEnum.not_downloaded,
+ },
+ )
+
+ self.clip_embedding = ClipEmbedding(
+ preferred_providers=["CPUExecutionProvider"]
+ )
self.minilm_embedding = MiniLMEmbedding(
preferred_providers=["CPUExecutionProvider"],
)
diff --git a/frigate/embeddings/functions/clip.py b/frigate/embeddings/functions/clip.py
index 55cdb3b47..a997bcb6f 100644
--- a/frigate/embeddings/functions/clip.py
+++ b/frigate/embeddings/functions/clip.py
@@ -1,30 +1,59 @@
-"""CLIP Embeddings for Frigate."""
-
-import errno
import logging
import os
-from pathlib import Path
-from typing import List, Union
+from typing import List, Optional, Union
import numpy as np
import onnxruntime as ort
-import requests
-from onnx_clip import OnnxClip
+from onnx_clip import OnnxClip, Preprocessor, Tokenizer
from PIL import Image
-from frigate.const import MODEL_CACHE_DIR
+from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
+from frigate.types import ModelStatusTypesEnum
+from frigate.util.downloader import ModelDownloader
+
+logger = logging.getLogger(__name__)
class Clip(OnnxClip):
- """Override load models to download to cache directory."""
+ """Override load models to use pre-downloaded models from cache directory."""
+
+ def __init__(
+ self,
+ model: str = "ViT-B/32",
+ batch_size: Optional[int] = None,
+ providers: List[str] = ["CPUExecutionProvider"],
+ ):
+ """
+ Instantiates the model and required encoding classes.
+
+ Args:
+ model: The model to utilize. Currently ViT-B/32 and RN50 are
+ allowed.
+ batch_size: If set, splits the lists in `get_image_embeddings`
+ and `get_text_embeddings` into batches of this size before
+ passing them to the model. The embeddings are then concatenated
+ back together before being returned. This is necessary when
+ passing large amounts of data (perhaps ~100 or more).
+ """
+ allowed_models = ["ViT-B/32", "RN50"]
+ if model not in allowed_models:
+ raise ValueError(f"`model` must be in {allowed_models}. Got {model}.")
+ if model == "ViT-B/32":
+ self.embedding_size = 512
+ elif model == "RN50":
+ self.embedding_size = 1024
+ self.image_model, self.text_model = self._load_models(model, providers)
+ self._tokenizer = Tokenizer()
+ self._preprocessor = Preprocessor()
+ self._batch_size = batch_size
@staticmethod
def _load_models(
model: str,
- silent: bool,
+ providers: List[str],
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
"""
- These models are a part of the container. Treat as as such.
+ Load models from cache directory.
"""
if model == "ViT-B/32":
IMAGE_MODEL_FILE = "clip_image_model_vitb32.onnx"
@@ -38,58 +67,92 @@ class Clip(OnnxClip):
models = []
for model_file in [IMAGE_MODEL_FILE, TEXT_MODEL_FILE]:
path = os.path.join(MODEL_CACHE_DIR, "clip", model_file)
- models.append(Clip._load_model(path, silent))
+ models.append(Clip._load_model(path, providers))
return models[0], models[1]
@staticmethod
- def _load_model(path: str, silent: bool):
- providers = ["CPUExecutionProvider"]
-
- try:
- if os.path.exists(path):
- return ort.InferenceSession(path, providers=providers)
- else:
- raise FileNotFoundError(
- errno.ENOENT,
- os.strerror(errno.ENOENT),
- path,
- )
- except Exception:
- s3_url = f"https://lakera-clip.s3.eu-west-1.amazonaws.com/{os.path.basename(path)}"
- if not silent:
- logging.info(
- f"The model file ({path}) doesn't exist "
- f"or it is invalid. Downloading it from the public S3 "
- f"bucket: {s3_url}." # noqa: E501
- )
-
- # Download from S3
- # Saving to a temporary file first to avoid corrupting the file
- temporary_filename = Path(path).with_name(os.path.basename(path) + ".part")
-
- # Create any missing directories in the path
- temporary_filename.parent.mkdir(parents=True, exist_ok=True)
-
- with requests.get(s3_url, stream=True) as r:
- r.raise_for_status()
- with open(temporary_filename, "wb") as f:
- for chunk in r.iter_content(chunk_size=8192):
- f.write(chunk)
- f.flush()
- # Finally move the temporary file to the correct location
- temporary_filename.rename(path)
+ def _load_model(path: str, providers: List[str]):
+ if os.path.exists(path):
return ort.InferenceSession(path, providers=providers)
+ else:
+ logger.warning(f"CLIP model file {path} not found.")
+ return None
class ClipEmbedding:
"""Embedding function for CLIP model."""
- def __init__(self, model: str = "ViT-B/32"):
- """Initialize CLIP Embedding function."""
- self.model = Clip(model)
+ def __init__(
+ self,
+ model: str = "ViT-B/32",
+ silent: bool = False,
+ preferred_providers: List[str] = ["CPUExecutionProvider"],
+ ):
+ self.model_name = model
+ self.silent = silent
+ self.preferred_providers = preferred_providers
+ self.model_files = self._get_model_files()
+ self.model = None
+
+ self.downloader = ModelDownloader(
+ model_name="clip",
+ download_path=os.path.join(MODEL_CACHE_DIR, "clip"),
+ file_names=self.model_files,
+ download_func=self._download_model,
+ silent=self.silent,
+ )
+ self.downloader.ensure_model_files()
+
+ def _get_model_files(self):
+ if self.model_name == "ViT-B/32":
+ return ["clip_image_model_vitb32.onnx", "clip_text_model_vitb32.onnx"]
+ elif self.model_name == "RN50":
+ return ["clip_image_model_rn50.onnx", "clip_text_model_rn50.onnx"]
+ else:
+ raise ValueError(
+ f"Unexpected model {self.model_name}. No `.onnx` file found."
+ )
+
+ def _download_model(self, path: str):
+ s3_url = (
+ f"https://lakera-clip.s3.eu-west-1.amazonaws.com/{os.path.basename(path)}"
+ )
+ try:
+ ModelDownloader.download_from_url(s3_url, path, self.silent)
+ self.downloader.requestor.send_data(
+ UPDATE_MODEL_STATE,
+ {
+ "model": f"{self.model_name}-{os.path.basename(path)}",
+ "state": ModelStatusTypesEnum.downloaded,
+ },
+ )
+ except Exception:
+ self.downloader.requestor.send_data(
+ UPDATE_MODEL_STATE,
+ {
+ "model": f"{self.model_name}-{os.path.basename(path)}",
+ "state": ModelStatusTypesEnum.error,
+ },
+ )
+
+ def _load_model(self):
+ if self.model is None:
+ self.downloader.wait_for_download()
+ self.model = Clip(self.model_name, providers=self.preferred_providers)
def __call__(self, input: Union[List[str], List[Image.Image]]) -> List[np.ndarray]:
+ self._load_model()
+ if (
+ self.model is None
+ or self.model.image_model is None
+ or self.model.text_model is None
+ ):
+ logger.info(
+ "CLIP model is not fully loaded. Please wait for the download to complete."
+ )
+ return []
+
embeddings = []
for item in input:
if isinstance(item, Image.Image):
diff --git a/frigate/embeddings/functions/minilm_l6_v2.py b/frigate/embeddings/functions/minilm_l6_v2.py
index 6a0e2d5ef..a3a8b45b3 100644
--- a/frigate/embeddings/functions/minilm_l6_v2.py
+++ b/frigate/embeddings/functions/minilm_l6_v2.py
@@ -1,21 +1,20 @@
-"""Embedding function for ONNX MiniLM-L6 model."""
-
-import errno
import logging
import os
-from pathlib import Path
from typing import List
import numpy as np
import onnxruntime as ort
-import requests
# importing this without pytorch or others causes a warning
# https://github.com/huggingface/transformers/issues/27214
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
from transformers import AutoTokenizer
-from frigate.const import MODEL_CACHE_DIR
+from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
+from frigate.types import ModelStatusTypesEnum
+from frigate.util.downloader import ModelDownloader
+
+logger = logging.getLogger(__name__)
class MiniLMEmbedding:
@@ -26,86 +25,83 @@ class MiniLMEmbedding:
IMAGE_MODEL_FILE = "model.onnx"
TOKENIZER_FILE = "tokenizer"
- def __init__(self, preferred_providers=None):
- """Initialize MiniLM Embedding function."""
- self.tokenizer = self._load_tokenizer()
+ def __init__(self, preferred_providers=["CPUExecutionProvider"]):
+ self.preferred_providers = preferred_providers
+ self.tokenizer = None
+ self.session = None
- model_path = os.path.join(self.DOWNLOAD_PATH, self.IMAGE_MODEL_FILE)
- if not os.path.exists(model_path):
- self._download_model()
+ self.downloader = ModelDownloader(
+ model_name=self.MODEL_NAME,
+ download_path=self.DOWNLOAD_PATH,
+ file_names=[self.IMAGE_MODEL_FILE, self.TOKENIZER_FILE],
+ download_func=self._download_model,
+ )
+ self.downloader.ensure_model_files()
- if preferred_providers is None:
- preferred_providers = ["CPUExecutionProvider"]
+ def _download_model(self, path: str):
+ try:
+ if os.path.basename(path) == self.IMAGE_MODEL_FILE:
+ s3_url = f"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/{self.IMAGE_MODEL_FILE}"
+ ModelDownloader.download_from_url(s3_url, path)
+ elif os.path.basename(path) == self.TOKENIZER_FILE:
+ logger.info("Downloading MiniLM tokenizer")
+ tokenizer = AutoTokenizer.from_pretrained(
+ self.MODEL_NAME, clean_up_tokenization_spaces=False
+ )
+ tokenizer.save_pretrained(path)
- self.session = self._load_model(model_path)
+ self.downloader.requestor.send_data(
+ UPDATE_MODEL_STATE,
+ {
+ "model": f"{self.MODEL_NAME}-{os.path.basename(path)}",
+ "state": ModelStatusTypesEnum.downloaded,
+ },
+ )
+ except Exception:
+ self.downloader.requestor.send_data(
+ UPDATE_MODEL_STATE,
+ {
+ "model": f"{self.MODEL_NAME}-{os.path.basename(path)}",
+ "state": ModelStatusTypesEnum.error,
+ },
+ )
+
+ def _load_model_and_tokenizer(self):
+ if self.tokenizer is None or self.session is None:
+ self.downloader.wait_for_download()
+ self.tokenizer = self._load_tokenizer()
+ self.session = self._load_model(
+ os.path.join(self.DOWNLOAD_PATH, self.IMAGE_MODEL_FILE),
+ self.preferred_providers,
+ )
def _load_tokenizer(self):
- """Load the tokenizer from the local path or download it if not available."""
tokenizer_path = os.path.join(self.DOWNLOAD_PATH, self.TOKENIZER_FILE)
- if os.path.exists(tokenizer_path):
- return AutoTokenizer.from_pretrained(tokenizer_path)
- else:
- return AutoTokenizer.from_pretrained(self.MODEL_NAME)
-
- def _download_model(self):
- """Download the ONNX model and tokenizer from a remote source if they don't exist."""
- logging.info(f"Downloading {self.MODEL_NAME} ONNX model and tokenizer...")
-
- # Download the tokenizer
- tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
- os.makedirs(self.DOWNLOAD_PATH, exist_ok=True)
- tokenizer.save_pretrained(os.path.join(self.DOWNLOAD_PATH, self.TOKENIZER_FILE))
-
- # Download the ONNX model
- s3_url = f"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/{self.IMAGE_MODEL_FILE}"
- model_path = os.path.join(self.DOWNLOAD_PATH, self.IMAGE_MODEL_FILE)
- self._download_from_url(s3_url, model_path)
-
- logging.info(f"Model and tokenizer saved to {self.DOWNLOAD_PATH}")
-
- def _download_from_url(self, url: str, save_path: str):
- """Download a file from a URL and save it to a specified path."""
- temporary_filename = Path(save_path).with_name(
- os.path.basename(save_path) + ".part"
+ return AutoTokenizer.from_pretrained(
+ tokenizer_path, clean_up_tokenization_spaces=False
)
- temporary_filename.parent.mkdir(parents=True, exist_ok=True)
- with requests.get(url, stream=True, allow_redirects=True) as r:
- # if the content type is HTML, it's not the actual model file
- if "text/html" in r.headers.get("Content-Type", ""):
- raise ValueError(
- f"Expected an ONNX file but received HTML from the URL: {url}"
- )
- # Ensure the download is successful
- r.raise_for_status()
-
- # Write the model to a temporary file first
- with open(temporary_filename, "wb") as f:
- for chunk in r.iter_content(chunk_size=8192):
- f.write(chunk)
-
- temporary_filename.rename(save_path)
-
- def _load_model(self, path: str):
- """Load the ONNX model from a given path."""
- providers = ["CPUExecutionProvider"]
+ def _load_model(self, path: str, providers: List[str]):
if os.path.exists(path):
return ort.InferenceSession(path, providers=providers)
else:
- raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), path)
+ logger.warning(f"MiniLM model file {path} not found.")
+ return None
def __call__(self, texts: List[str]) -> List[np.ndarray]:
- """Generate embeddings for the given texts."""
+ self._load_model_and_tokenizer()
+
+ if self.session is None or self.tokenizer is None:
+ logger.error("MiniLM model or tokenizer is not loaded.")
+ return []
+
inputs = self.tokenizer(
texts, padding=True, truncation=True, return_tensors="np"
)
-
input_names = [input.name for input in self.session.get_inputs()]
onnx_inputs = {name: inputs[name] for name in input_names if name in inputs}
- # Run inference
outputs = self.session.run(None, onnx_inputs)
-
embeddings = outputs[0].mean(axis=1)
return [embedding for embedding in embeddings]
diff --git a/frigate/types.py b/frigate/types.py
index 21f55e502..3e6ad46cc 100644
--- a/frigate/types.py
+++ b/frigate/types.py
@@ -1,3 +1,4 @@
+from enum import Enum
from typing import TypedDict
from frigate.camera import CameraMetrics
@@ -11,3 +12,10 @@ class StatsTrackingTypes(TypedDict):
latest_frigate_version: str
last_updated: int
processes: dict[str, int]
+
+
+class ModelStatusTypesEnum(str, Enum):
+ not_downloaded = "not_downloaded"
+ downloading = "downloading"
+ downloaded = "downloaded"
+ error = "error"
diff --git a/frigate/util/downloader.py b/frigate/util/downloader.py
new file mode 100644
index 000000000..65b45425a
--- /dev/null
+++ b/frigate/util/downloader.py
@@ -0,0 +1,119 @@
+import logging
+import os
+import threading
+import time
+from pathlib import Path
+from typing import Callable, List
+
+import requests
+
+from frigate.comms.inter_process import InterProcessRequestor
+from frigate.const import UPDATE_MODEL_STATE
+from frigate.types import ModelStatusTypesEnum
+
+logger = logging.getLogger(__name__)
+
+
+class FileLock:
+ def __init__(self, path):
+ self.path = path
+ self.lock_file = f"{path}.lock"
+
+ def acquire(self):
+ parent_dir = os.path.dirname(self.lock_file)
+ os.makedirs(parent_dir, exist_ok=True)
+
+ while True:
+ try:
+ with open(self.lock_file, "x"):
+ return
+ except FileExistsError:
+ time.sleep(0.1)
+
+ def release(self):
+ try:
+ os.remove(self.lock_file)
+ except FileNotFoundError:
+ pass
+
+
+class ModelDownloader:
+ def __init__(
+ self,
+ model_name: str,
+ download_path: str,
+ file_names: List[str],
+ download_func: Callable[[str], None],
+ silent: bool = False,
+ ):
+ self.model_name = model_name
+ self.download_path = download_path
+ self.file_names = file_names
+ self.download_func = download_func
+ self.silent = silent
+ self.requestor = InterProcessRequestor()
+ self.download_thread = None
+ self.download_complete = threading.Event()
+
+ def ensure_model_files(self):
+ for file in self.file_names:
+ self.requestor.send_data(
+ UPDATE_MODEL_STATE,
+ {
+ "model": f"{self.model_name}-{file}",
+ "state": ModelStatusTypesEnum.downloading,
+ },
+ )
+ self.download_thread = threading.Thread(target=self._download_models)
+ self.download_thread.start()
+
+ def _download_models(self):
+ for file_name in self.file_names:
+ path = os.path.join(self.download_path, file_name)
+ lock = FileLock(path)
+
+ if not os.path.exists(path):
+ lock.acquire()
+ try:
+ if not os.path.exists(path):
+ self.download_func(path)
+ finally:
+ lock.release()
+
+ self.requestor.send_data(
+ UPDATE_MODEL_STATE,
+ {
+ "model": f"{self.model_name}-{file_name}",
+ "state": ModelStatusTypesEnum.downloaded,
+ },
+ )
+
+ self.download_complete.set()
+
+ @staticmethod
+ def download_from_url(url: str, save_path: str, silent: bool = False):
+ temporary_filename = Path(save_path).with_name(
+ os.path.basename(save_path) + ".part"
+ )
+ temporary_filename.parent.mkdir(parents=True, exist_ok=True)
+
+ if not silent:
+ logger.info(f"Downloading model file from: {url}")
+
+ try:
+ with requests.get(url, stream=True, allow_redirects=True) as r:
+ r.raise_for_status()
+ with open(temporary_filename, "wb") as f:
+ for chunk in r.iter_content(chunk_size=8192):
+ f.write(chunk)
+
+ temporary_filename.rename(save_path)
+ except Exception as e:
+ logger.error(f"Error downloading model: {str(e)}")
+ raise
+
+ if not silent:
+ logger.info(f"Downloading complete: {url}")
+
+ def wait_for_download(self):
+ self.download_complete.wait()
diff --git a/web/src/api/ws.tsx b/web/src/api/ws.tsx
index 79d2bd3b4..a78722b66 100644
--- a/web/src/api/ws.tsx
+++ b/web/src/api/ws.tsx
@@ -5,6 +5,7 @@ import {
FrigateCameraState,
FrigateEvent,
FrigateReview,
+ ModelState,
ToggleableSetting,
} from "@/types/ws";
import { FrigateStats } from "@/types/stats";
@@ -266,6 +267,41 @@ export function useInitialCameraState(
return { payload: data ? data[camera] : undefined };
}
+export function useModelState(
+ model: string,
+ revalidateOnFocus: boolean = true,
+): { payload: ModelState } {
+ const {
+ value: { payload },
+ send: sendCommand,
+ } = useWs("model_state", "modelState");
+
+ const data = useDeepMemo(JSON.parse(payload as string));
+
+ useEffect(() => {
+ let listener = undefined;
+ if (revalidateOnFocus) {
+ sendCommand("modelState");
+ listener = () => {
+ if (document.visibilityState == "visible") {
+ sendCommand("modelState");
+ }
+ };
+ addEventListener("visibilitychange", listener);
+ }
+
+ return () => {
+ if (listener) {
+ removeEventListener("visibilitychange", listener);
+ }
+ };
+ // we know that these deps are correct
+ // eslint-disable-next-line react-hooks/exhaustive-deps
+ }, [revalidateOnFocus]);
+
+ return { payload: data ? data[model] : undefined };
+}
+
export function useMotionActivity(camera: string): { payload: string } {
const {
value: { payload },
diff --git a/web/src/pages/Explore.tsx b/web/src/pages/Explore.tsx
index 4af6e1f19..7fc0f9286 100644
--- a/web/src/pages/Explore.tsx
+++ b/web/src/pages/Explore.tsx
@@ -1,11 +1,15 @@
-import { useEventUpdate } from "@/api/ws";
+import { useEventUpdate, useModelState } from "@/api/ws";
+import ActivityIndicator from "@/components/indicators/activity-indicator";
import { useApiFilterArgs } from "@/hooks/use-api-filter";
import { useTimezone } from "@/hooks/use-date-utils";
import { FrigateConfig } from "@/types/frigateConfig";
import { SearchFilter, SearchQuery, SearchResult } from "@/types/search";
+import { ModelState } from "@/types/ws";
import SearchView from "@/views/search/SearchView";
import { useCallback, useEffect, useMemo, useState } from "react";
+import { LuCheck, LuExternalLink, LuX } from "react-icons/lu";
import { TbExclamationCircle } from "react-icons/tb";
+import { Link } from "react-router-dom";
import useSWR from "swr";
import useSWRInfinite from "swr/infinite";
@@ -111,14 +115,10 @@ export default function Explore() {
// paging
- // usually slow only on first run while downloading models
- const [isSlowLoading, setIsSlowLoading] = useState(false);
-
const getKey = (
pageIndex: number,
previousPageData: SearchResult[] | null,
): SearchQuery => {
- if (isSlowLoading && !similaritySearch) return null;
if (previousPageData && !previousPageData.length) return null; // reached the end
if (!searchQuery) return null;
@@ -143,12 +143,6 @@ export default function Explore() {
revalidateFirstPage: true,
revalidateOnFocus: true,
revalidateAll: false,
- onLoadingSlow: () => {
- if (!similaritySearch) {
- setIsSlowLoading(true);
- }
- },
- loadingTimeout: 15000,
});
const searchResults = useMemo(
@@ -188,17 +182,113 @@ export default function Explore() {
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [eventUpdate]);
+ // model states
+
+ const { payload: minilmModelState } = useModelState(
+ "sentence-transformers/all-MiniLM-L6-v2-model.onnx",
+ );
+ const { payload: minilmTokenizerState } = useModelState(
+ "sentence-transformers/all-MiniLM-L6-v2-tokenizer",
+ );
+ const { payload: clipImageModelState } = useModelState(
+ "clip-clip_image_model_vitb32.onnx",
+ );
+ const { payload: clipTextModelState } = useModelState(
+ "clip-clip_text_model_vitb32.onnx",
+ );
+
+ const allModelsLoaded = useMemo(() => {
+ return (
+ minilmModelState === "downloaded" &&
+ minilmTokenizerState === "downloaded" &&
+ clipImageModelState === "downloaded" &&
+ clipTextModelState === "downloaded"
+ );
+ }, [
+ minilmModelState,
+ minilmTokenizerState,
+ clipImageModelState,
+ clipTextModelState,
+ ]);
+
+ const renderModelStateIcon = (modelState: ModelState) => {
+ if (modelState === "downloading") {
+ return
Search Unavailable
-- If this is your first time using Search, be patient while Frigate - downloads the necessary embeddings models. Check Frigate logs. -
+