From 52a2bbc489c8a2fe5ba2e7bf2a796aeafed64d22 Mon Sep 17 00:00:00 2001 From: Josh Hawkins <32435876+hawkeye217@users.noreply.github.com> Date: Mon, 7 Oct 2024 08:05:59 -0500 Subject: [PATCH] improve model downloading and add status screen --- frigate/app.py | 4 +- frigate/comms/dispatcher.py | 11 ++ frigate/const.py | 1 + frigate/embeddings/__init__.py | 3 +- frigate/embeddings/embeddings.py | 24 ++- frigate/embeddings/functions/clip.py | 165 +++++++++++++------ frigate/embeddings/functions/minilm_l6_v2.py | 128 +++++++------- frigate/types.py | 8 + frigate/util/downloader.py | 119 +++++++++++++ web/src/api/ws.tsx | 36 ++++ web/src/pages/Explore.tsx | 128 +++++++++++--- web/src/types/ws.ts | 6 + 12 files changed, 492 insertions(+), 141 deletions(-) create mode 100644 frigate/util/downloader.py diff --git a/frigate/app.py b/frigate/app.py index e370150bb..65272f2be 100644 --- a/frigate/app.py +++ b/frigate/app.py @@ -590,13 +590,13 @@ class FrigateApp: self.init_onvif() self.init_recording_manager() self.init_review_segment_manager() - self.init_embeddings_manager() self.init_go2rtc() self.bind_database() self.check_db_data_migrations() - self.init_embeddings_client() self.init_inter_process_communicator() self.init_dispatcher() + self.init_embeddings_manager() + self.init_embeddings_client() self.start_detectors() self.start_video_output_processor() self.start_ptz_autotracker() diff --git a/frigate/comms/dispatcher.py b/frigate/comms/dispatcher.py index a987f6a38..1605d645a 100644 --- a/frigate/comms/dispatcher.py +++ b/frigate/comms/dispatcher.py @@ -16,10 +16,12 @@ from frigate.const import ( REQUEST_REGION_GRID, UPDATE_CAMERA_ACTIVITY, UPDATE_EVENT_DESCRIPTION, + UPDATE_MODEL_STATE, UPSERT_REVIEW_SEGMENT, ) from frigate.models import Event, Previews, Recordings, ReviewSegment from frigate.ptz.onvif import OnvifCommandEnum, OnvifController +from frigate.types import ModelStatusTypesEnum from frigate.util.object import get_camera_regions_grid from frigate.util.services import restart_frigate @@ -83,6 +85,7 @@ class Dispatcher: comm.subscribe(self._receive) self.camera_activity = {} + self.model_state = {} def _receive(self, topic: str, payload: str) -> Optional[Any]: """Handle receiving of payload from communicators.""" @@ -144,6 +147,14 @@ class Dispatcher: "event_update", json.dumps({"id": event.id, "description": event.data["description"]}), ) + elif topic == UPDATE_MODEL_STATE: + model = payload["model"] + state = payload["state"] + self.model_state[model] = ModelStatusTypesEnum[state] + self.publish("model_state", json.dumps(self.model_state)) + elif topic == "modelState": + model_state = self.model_state.copy() + self.publish("model_state", json.dumps(model_state)) elif topic == "onConnect": camera_status = self.camera_activity.copy() diff --git a/frigate/const.py b/frigate/const.py index b37ca662e..e8e841f4f 100644 --- a/frigate/const.py +++ b/frigate/const.py @@ -84,6 +84,7 @@ UPSERT_REVIEW_SEGMENT = "upsert_review_segment" CLEAR_ONGOING_REVIEW_SEGMENTS = "clear_ongoing_review_segments" UPDATE_CAMERA_ACTIVITY = "update_camera_activity" UPDATE_EVENT_DESCRIPTION = "update_event_description" +UPDATE_MODEL_STATE = "update_model_state" # Stats Values diff --git a/frigate/embeddings/__init__.py b/frigate/embeddings/__init__.py index 12fca7caf..381d95ed1 100644 --- a/frigate/embeddings/__init__.py +++ b/frigate/embeddings/__init__.py @@ -71,8 +71,7 @@ def manage_embeddings(config: FrigateConfig) -> None: class EmbeddingsContext: def __init__(self, db: SqliteVecQueueDatabase): - self.db = 
db - self.embeddings = Embeddings(self.db) + self.embeddings = Embeddings(db) self.thumb_stats = ZScoreNormalization() self.desc_stats = ZScoreNormalization() diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py index 2378ea64c..c763bf304 100644 --- a/frigate/embeddings/embeddings.py +++ b/frigate/embeddings/embeddings.py @@ -10,8 +10,11 @@ from typing import List, Tuple, Union from PIL import Image from playhouse.shortcuts import model_to_dict +from frigate.comms.inter_process import InterProcessRequestor +from frigate.const import UPDATE_MODEL_STATE from frigate.db.sqlitevecq import SqliteVecQueueDatabase from frigate.models import Event +from frigate.types import ModelStatusTypesEnum from .functions.clip import ClipEmbedding from .functions.minilm_l6_v2 import MiniLMEmbedding @@ -65,11 +68,30 @@ class Embeddings: def __init__(self, db: SqliteVecQueueDatabase) -> None: self.db = db + self.requestor = InterProcessRequestor() # Create tables if they don't exist self._create_tables() - self.clip_embedding = ClipEmbedding(model="ViT-B/32") + models = [ + "sentence-transformers/all-MiniLM-L6-v2-model.onnx", + "sentence-transformers/all-MiniLM-L6-v2-tokenizer", + "clip-clip_image_model_vitb32.onnx", + "clip-clip_text_model_vitb32.onnx", + ] + + for model in models: + self.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": model, + "state": ModelStatusTypesEnum.not_downloaded, + }, + ) + + self.clip_embedding = ClipEmbedding( + preferred_providers=["CPUExecutionProvider"] + ) self.minilm_embedding = MiniLMEmbedding( preferred_providers=["CPUExecutionProvider"], ) diff --git a/frigate/embeddings/functions/clip.py b/frigate/embeddings/functions/clip.py index 55cdb3b47..a997bcb6f 100644 --- a/frigate/embeddings/functions/clip.py +++ b/frigate/embeddings/functions/clip.py @@ -1,30 +1,59 @@ -"""CLIP Embeddings for Frigate.""" - -import errno import logging import os -from pathlib import Path -from typing import List, Union +from typing import List, Optional, Union import numpy as np import onnxruntime as ort -import requests -from onnx_clip import OnnxClip +from onnx_clip import OnnxClip, Preprocessor, Tokenizer from PIL import Image -from frigate.const import MODEL_CACHE_DIR +from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE +from frigate.types import ModelStatusTypesEnum +from frigate.util.downloader import ModelDownloader + +logger = logging.getLogger(__name__) class Clip(OnnxClip): - """Override load models to download to cache directory.""" + """Override load models to use pre-downloaded models from cache directory.""" + + def __init__( + self, + model: str = "ViT-B/32", + batch_size: Optional[int] = None, + providers: List[str] = ["CPUExecutionProvider"], + ): + """ + Instantiates the model and required encoding classes. + + Args: + model: The model to utilize. Currently ViT-B/32 and RN50 are + allowed. + batch_size: If set, splits the lists in `get_image_embeddings` + and `get_text_embeddings` into batches of this size before + passing them to the model. The embeddings are then concatenated + back together before being returned. This is necessary when + passing large amounts of data (perhaps ~100 or more). + """ + allowed_models = ["ViT-B/32", "RN50"] + if model not in allowed_models: + raise ValueError(f"`model` must be in {allowed_models}. 
Got {model}.") + if model == "ViT-B/32": + self.embedding_size = 512 + elif model == "RN50": + self.embedding_size = 1024 + self.image_model, self.text_model = self._load_models(model, providers) + self._tokenizer = Tokenizer() + self._preprocessor = Preprocessor() + self._batch_size = batch_size @staticmethod def _load_models( model: str, - silent: bool, + providers: List[str], ) -> tuple[ort.InferenceSession, ort.InferenceSession]: """ - These models are a part of the container. Treat as as such. + Load models from cache directory. """ if model == "ViT-B/32": IMAGE_MODEL_FILE = "clip_image_model_vitb32.onnx" @@ -38,58 +67,92 @@ class Clip(OnnxClip): models = [] for model_file in [IMAGE_MODEL_FILE, TEXT_MODEL_FILE]: path = os.path.join(MODEL_CACHE_DIR, "clip", model_file) - models.append(Clip._load_model(path, silent)) + models.append(Clip._load_model(path, providers)) return models[0], models[1] @staticmethod - def _load_model(path: str, silent: bool): - providers = ["CPUExecutionProvider"] - - try: - if os.path.exists(path): - return ort.InferenceSession(path, providers=providers) - else: - raise FileNotFoundError( - errno.ENOENT, - os.strerror(errno.ENOENT), - path, - ) - except Exception: - s3_url = f"https://lakera-clip.s3.eu-west-1.amazonaws.com/{os.path.basename(path)}" - if not silent: - logging.info( - f"The model file ({path}) doesn't exist " - f"or it is invalid. Downloading it from the public S3 " - f"bucket: {s3_url}." # noqa: E501 - ) - - # Download from S3 - # Saving to a temporary file first to avoid corrupting the file - temporary_filename = Path(path).with_name(os.path.basename(path) + ".part") - - # Create any missing directories in the path - temporary_filename.parent.mkdir(parents=True, exist_ok=True) - - with requests.get(s3_url, stream=True) as r: - r.raise_for_status() - with open(temporary_filename, "wb") as f: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) - f.flush() - # Finally move the temporary file to the correct location - temporary_filename.rename(path) + def _load_model(path: str, providers: List[str]): + if os.path.exists(path): return ort.InferenceSession(path, providers=providers) + else: + logger.warning(f"CLIP model file {path} not found.") + return None class ClipEmbedding: """Embedding function for CLIP model.""" - def __init__(self, model: str = "ViT-B/32"): - """Initialize CLIP Embedding function.""" - self.model = Clip(model) + def __init__( + self, + model: str = "ViT-B/32", + silent: bool = False, + preferred_providers: List[str] = ["CPUExecutionProvider"], + ): + self.model_name = model + self.silent = silent + self.preferred_providers = preferred_providers + self.model_files = self._get_model_files() + self.model = None + + self.downloader = ModelDownloader( + model_name="clip", + download_path=os.path.join(MODEL_CACHE_DIR, "clip"), + file_names=self.model_files, + download_func=self._download_model, + silent=self.silent, + ) + self.downloader.ensure_model_files() + + def _get_model_files(self): + if self.model_name == "ViT-B/32": + return ["clip_image_model_vitb32.onnx", "clip_text_model_vitb32.onnx"] + elif self.model_name == "RN50": + return ["clip_image_model_rn50.onnx", "clip_text_model_rn50.onnx"] + else: + raise ValueError( + f"Unexpected model {self.model_name}. No `.onnx` file found." 
+ ) + + def _download_model(self, path: str): + s3_url = ( + f"https://lakera-clip.s3.eu-west-1.amazonaws.com/{os.path.basename(path)}" + ) + try: + ModelDownloader.download_from_url(s3_url, path, self.silent) + self.downloader.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": f"{self.model_name}-{os.path.basename(path)}", + "state": ModelStatusTypesEnum.downloaded, + }, + ) + except Exception: + self.downloader.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": f"{self.model_name}-{os.path.basename(path)}", + "state": ModelStatusTypesEnum.error, + }, + ) + + def _load_model(self): + if self.model is None: + self.downloader.wait_for_download() + self.model = Clip(self.model_name, providers=self.preferred_providers) def __call__(self, input: Union[List[str], List[Image.Image]]) -> List[np.ndarray]: + self._load_model() + if ( + self.model is None + or self.model.image_model is None + or self.model.text_model is None + ): + logger.info( + "CLIP model is not fully loaded. Please wait for the download to complete." + ) + return [] + embeddings = [] for item in input: if isinstance(item, Image.Image): diff --git a/frigate/embeddings/functions/minilm_l6_v2.py b/frigate/embeddings/functions/minilm_l6_v2.py index 6a0e2d5ef..a3a8b45b3 100644 --- a/frigate/embeddings/functions/minilm_l6_v2.py +++ b/frigate/embeddings/functions/minilm_l6_v2.py @@ -1,21 +1,20 @@ -"""Embedding function for ONNX MiniLM-L6 model.""" - -import errno import logging import os -from pathlib import Path from typing import List import numpy as np import onnxruntime as ort -import requests # importing this without pytorch or others causes a warning # https://github.com/huggingface/transformers/issues/27214 # suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1 from transformers import AutoTokenizer -from frigate.const import MODEL_CACHE_DIR +from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE +from frigate.types import ModelStatusTypesEnum +from frigate.util.downloader import ModelDownloader + +logger = logging.getLogger(__name__) class MiniLMEmbedding: @@ -26,86 +25,83 @@ class MiniLMEmbedding: IMAGE_MODEL_FILE = "model.onnx" TOKENIZER_FILE = "tokenizer" - def __init__(self, preferred_providers=None): - """Initialize MiniLM Embedding function.""" - self.tokenizer = self._load_tokenizer() + def __init__(self, preferred_providers=["CPUExecutionProvider"]): + self.preferred_providers = preferred_providers + self.tokenizer = None + self.session = None - model_path = os.path.join(self.DOWNLOAD_PATH, self.IMAGE_MODEL_FILE) - if not os.path.exists(model_path): - self._download_model() + self.downloader = ModelDownloader( + model_name=self.MODEL_NAME, + download_path=self.DOWNLOAD_PATH, + file_names=[self.IMAGE_MODEL_FILE, self.TOKENIZER_FILE], + download_func=self._download_model, + ) + self.downloader.ensure_model_files() - if preferred_providers is None: - preferred_providers = ["CPUExecutionProvider"] + def _download_model(self, path: str): + try: + if os.path.basename(path) == self.IMAGE_MODEL_FILE: + s3_url = f"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/{self.IMAGE_MODEL_FILE}" + ModelDownloader.download_from_url(s3_url, path) + elif os.path.basename(path) == self.TOKENIZER_FILE: + logger.info("Downloading MiniLM tokenizer") + tokenizer = AutoTokenizer.from_pretrained( + self.MODEL_NAME, clean_up_tokenization_spaces=False + ) + tokenizer.save_pretrained(path) - self.session = self._load_model(model_path) + self.downloader.requestor.send_data( + 
UPDATE_MODEL_STATE, + { + "model": f"{self.MODEL_NAME}-{os.path.basename(path)}", + "state": ModelStatusTypesEnum.downloaded, + }, + ) + except Exception: + self.downloader.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": f"{self.MODEL_NAME}-{os.path.basename(path)}", + "state": ModelStatusTypesEnum.error, + }, + ) + + def _load_model_and_tokenizer(self): + if self.tokenizer is None or self.session is None: + self.downloader.wait_for_download() + self.tokenizer = self._load_tokenizer() + self.session = self._load_model( + os.path.join(self.DOWNLOAD_PATH, self.IMAGE_MODEL_FILE), + self.preferred_providers, + ) def _load_tokenizer(self): - """Load the tokenizer from the local path or download it if not available.""" tokenizer_path = os.path.join(self.DOWNLOAD_PATH, self.TOKENIZER_FILE) - if os.path.exists(tokenizer_path): - return AutoTokenizer.from_pretrained(tokenizer_path) - else: - return AutoTokenizer.from_pretrained(self.MODEL_NAME) - - def _download_model(self): - """Download the ONNX model and tokenizer from a remote source if they don't exist.""" - logging.info(f"Downloading {self.MODEL_NAME} ONNX model and tokenizer...") - - # Download the tokenizer - tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME) - os.makedirs(self.DOWNLOAD_PATH, exist_ok=True) - tokenizer.save_pretrained(os.path.join(self.DOWNLOAD_PATH, self.TOKENIZER_FILE)) - - # Download the ONNX model - s3_url = f"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/{self.IMAGE_MODEL_FILE}" - model_path = os.path.join(self.DOWNLOAD_PATH, self.IMAGE_MODEL_FILE) - self._download_from_url(s3_url, model_path) - - logging.info(f"Model and tokenizer saved to {self.DOWNLOAD_PATH}") - - def _download_from_url(self, url: str, save_path: str): - """Download a file from a URL and save it to a specified path.""" - temporary_filename = Path(save_path).with_name( - os.path.basename(save_path) + ".part" + return AutoTokenizer.from_pretrained( + tokenizer_path, clean_up_tokenization_spaces=False ) - temporary_filename.parent.mkdir(parents=True, exist_ok=True) - with requests.get(url, stream=True, allow_redirects=True) as r: - # if the content type is HTML, it's not the actual model file - if "text/html" in r.headers.get("Content-Type", ""): - raise ValueError( - f"Expected an ONNX file but received HTML from the URL: {url}" - ) - # Ensure the download is successful - r.raise_for_status() - - # Write the model to a temporary file first - with open(temporary_filename, "wb") as f: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) - - temporary_filename.rename(save_path) - - def _load_model(self, path: str): - """Load the ONNX model from a given path.""" - providers = ["CPUExecutionProvider"] + def _load_model(self, path: str, providers: List[str]): if os.path.exists(path): return ort.InferenceSession(path, providers=providers) else: - raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), path) + logger.warning(f"MiniLM model file {path} not found.") + return None def __call__(self, texts: List[str]) -> List[np.ndarray]: - """Generate embeddings for the given texts.""" + self._load_model_and_tokenizer() + + if self.session is None or self.tokenizer is None: + logger.error("MiniLM model or tokenizer is not loaded.") + return [] + inputs = self.tokenizer( texts, padding=True, truncation=True, return_tensors="np" ) - input_names = [input.name for input in self.session.get_inputs()] onnx_inputs = {name: inputs[name] for name in input_names if name in inputs} - # Run inference 
outputs = self.session.run(None, onnx_inputs) - embeddings = outputs[0].mean(axis=1) return [embedding for embedding in embeddings] diff --git a/frigate/types.py b/frigate/types.py index 21f55e502..3e6ad46cc 100644 --- a/frigate/types.py +++ b/frigate/types.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import TypedDict from frigate.camera import CameraMetrics @@ -11,3 +12,10 @@ class StatsTrackingTypes(TypedDict): latest_frigate_version: str last_updated: int processes: dict[str, int] + + +class ModelStatusTypesEnum(str, Enum): + not_downloaded = "not_downloaded" + downloading = "downloading" + downloaded = "downloaded" + error = "error" diff --git a/frigate/util/downloader.py b/frigate/util/downloader.py new file mode 100644 index 000000000..65b45425a --- /dev/null +++ b/frigate/util/downloader.py @@ -0,0 +1,119 @@ +import logging +import os +import threading +import time +from pathlib import Path +from typing import Callable, List + +import requests + +from frigate.comms.inter_process import InterProcessRequestor +from frigate.const import UPDATE_MODEL_STATE +from frigate.types import ModelStatusTypesEnum + +logger = logging.getLogger(__name__) + + +class FileLock: + def __init__(self, path): + self.path = path + self.lock_file = f"{path}.lock" + + def acquire(self): + parent_dir = os.path.dirname(self.lock_file) + os.makedirs(parent_dir, exist_ok=True) + + while True: + try: + with open(self.lock_file, "x"): + return + except FileExistsError: + time.sleep(0.1) + + def release(self): + try: + os.remove(self.lock_file) + except FileNotFoundError: + pass + + +class ModelDownloader: + def __init__( + self, + model_name: str, + download_path: str, + file_names: List[str], + download_func: Callable[[str], None], + silent: bool = False, + ): + self.model_name = model_name + self.download_path = download_path + self.file_names = file_names + self.download_func = download_func + self.silent = silent + self.requestor = InterProcessRequestor() + self.download_thread = None + self.download_complete = threading.Event() + + def ensure_model_files(self): + for file in self.file_names: + self.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": f"{self.model_name}-{file}", + "state": ModelStatusTypesEnum.downloading, + }, + ) + self.download_thread = threading.Thread(target=self._download_models) + self.download_thread.start() + + def _download_models(self): + for file_name in self.file_names: + path = os.path.join(self.download_path, file_name) + lock = FileLock(path) + + if not os.path.exists(path): + lock.acquire() + try: + if not os.path.exists(path): + self.download_func(path) + finally: + lock.release() + + self.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": f"{self.model_name}-{file_name}", + "state": ModelStatusTypesEnum.downloaded, + }, + ) + + self.download_complete.set() + + @staticmethod + def download_from_url(url: str, save_path: str, silent: bool = False): + temporary_filename = Path(save_path).with_name( + os.path.basename(save_path) + ".part" + ) + temporary_filename.parent.mkdir(parents=True, exist_ok=True) + + if not silent: + logger.info(f"Downloading model file from: {url}") + + try: + with requests.get(url, stream=True, allow_redirects=True) as r: + r.raise_for_status() + with open(temporary_filename, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + + temporary_filename.rename(save_path) + except Exception as e: + logger.error(f"Error downloading model: {str(e)}") + raise + + if not silent: + logger.info(f"Downloading complete: 
{url}") + + def wait_for_download(self): + self.download_complete.wait() diff --git a/web/src/api/ws.tsx b/web/src/api/ws.tsx index 79d2bd3b4..a78722b66 100644 --- a/web/src/api/ws.tsx +++ b/web/src/api/ws.tsx @@ -5,6 +5,7 @@ import { FrigateCameraState, FrigateEvent, FrigateReview, + ModelState, ToggleableSetting, } from "@/types/ws"; import { FrigateStats } from "@/types/stats"; @@ -266,6 +267,41 @@ export function useInitialCameraState( return { payload: data ? data[camera] : undefined }; } +export function useModelState( + model: string, + revalidateOnFocus: boolean = true, +): { payload: ModelState } { + const { + value: { payload }, + send: sendCommand, + } = useWs("model_state", "modelState"); + + const data = useDeepMemo(JSON.parse(payload as string)); + + useEffect(() => { + let listener = undefined; + if (revalidateOnFocus) { + sendCommand("modelState"); + listener = () => { + if (document.visibilityState == "visible") { + sendCommand("modelState"); + } + }; + addEventListener("visibilitychange", listener); + } + + return () => { + if (listener) { + removeEventListener("visibilitychange", listener); + } + }; + // we know that these deps are correct + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [revalidateOnFocus]); + + return { payload: data ? data[model] : undefined }; +} + export function useMotionActivity(camera: string): { payload: string } { const { value: { payload }, diff --git a/web/src/pages/Explore.tsx b/web/src/pages/Explore.tsx index 4af6e1f19..7fc0f9286 100644 --- a/web/src/pages/Explore.tsx +++ b/web/src/pages/Explore.tsx @@ -1,11 +1,15 @@ -import { useEventUpdate } from "@/api/ws"; +import { useEventUpdate, useModelState } from "@/api/ws"; +import ActivityIndicator from "@/components/indicators/activity-indicator"; import { useApiFilterArgs } from "@/hooks/use-api-filter"; import { useTimezone } from "@/hooks/use-date-utils"; import { FrigateConfig } from "@/types/frigateConfig"; import { SearchFilter, SearchQuery, SearchResult } from "@/types/search"; +import { ModelState } from "@/types/ws"; import SearchView from "@/views/search/SearchView"; import { useCallback, useEffect, useMemo, useState } from "react"; +import { LuCheck, LuExternalLink, LuX } from "react-icons/lu"; import { TbExclamationCircle } from "react-icons/tb"; +import { Link } from "react-router-dom"; import useSWR from "swr"; import useSWRInfinite from "swr/infinite"; @@ -111,14 +115,10 @@ export default function Explore() { // paging - // usually slow only on first run while downloading models - const [isSlowLoading, setIsSlowLoading] = useState(false); - const getKey = ( pageIndex: number, previousPageData: SearchResult[] | null, ): SearchQuery => { - if (isSlowLoading && !similaritySearch) return null; if (previousPageData && !previousPageData.length) return null; // reached the end if (!searchQuery) return null; @@ -143,12 +143,6 @@ export default function Explore() { revalidateFirstPage: true, revalidateOnFocus: true, revalidateAll: false, - onLoadingSlow: () => { - if (!similaritySearch) { - setIsSlowLoading(true); - } - }, - loadingTimeout: 15000, }); const searchResults = useMemo( @@ -188,17 +182,113 @@ export default function Explore() { // eslint-disable-next-line react-hooks/exhaustive-deps }, [eventUpdate]); + // model states + + const { payload: minilmModelState } = useModelState( + "sentence-transformers/all-MiniLM-L6-v2-model.onnx", + ); + const { payload: minilmTokenizerState } = useModelState( + "sentence-transformers/all-MiniLM-L6-v2-tokenizer", + ); + const { 
payload: clipImageModelState } = useModelState(
+    "clip-clip_image_model_vitb32.onnx",
+  );
+  const { payload: clipTextModelState } = useModelState(
+    "clip-clip_text_model_vitb32.onnx",
+  );
+
+  const allModelsLoaded = useMemo(() => {
+    return (
+      minilmModelState === "downloaded" &&
+      minilmTokenizerState === "downloaded" &&
+      clipImageModelState === "downloaded" &&
+      clipTextModelState === "downloaded"
+    );
+  }, [
+    minilmModelState,
+    minilmTokenizerState,
+    clipImageModelState,
+    clipTextModelState,
+  ]);
+
+  const renderModelStateIcon = (modelState: ModelState) => {
+    if (modelState === "downloading") {
+      return <ActivityIndicator className="size-5" />;
+    }
+    if (modelState === "downloaded") {
+      return <LuCheck className="size-5 text-success" />;
+    }
+    if (modelState === "not_downloaded" || modelState === "error") {
+      return <LuX className="size-5 text-danger" />;
+    }
+    return null;
+  };
+
+  if (
+    !minilmModelState ||
+    !minilmTokenizerState ||
+    !clipImageModelState ||
+    !clipTextModelState
+  ) {
+    return (
+      <ActivityIndicator className="absolute left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2" />
+    );
+  }
+
   return (
     <>
-      {isSlowLoading && !similaritySearch ? (
+      {!allModelsLoaded ? (
-        <div className="absolute left-1/2 top-1/2 flex -translate-x-1/2 -translate-y-1/2 flex-col items-center justify-center space-y-2">
-          <TbExclamationCircle className="mb-3 size-10" />
-          <div>Search Unavailable</div>
-          <div className="max-w-96 text-center">
-            If this is your first time using Search, be patient while Frigate
-            downloads the necessary embeddings models. Check Frigate logs.
-          </div>
-        </div>
+        <div className="absolute left-1/2 top-1/2 flex -translate-x-1/2 -translate-y-1/2 flex-col items-center justify-center space-y-3">
+          <div className="my-5 flex flex-col items-center gap-2 text-xl">
+            <TbExclamationCircle className="mb-3 size-10" />
+            <div>Search Unavailable</div>
+          </div>
+          <div className="max-w-96 text-center">
+            Frigate is downloading the necessary embeddings models to support
+            semantic searching. This may take several minutes depending on the
+            speed of your network connection.
+          </div>
+          <div className="flex w-96 flex-col gap-2 py-5">
+            <div className="flex flex-row items-center justify-center gap-2">
+              {renderModelStateIcon(clipImageModelState)}
+              CLIP image model
+            </div>
+            <div className="flex flex-row items-center justify-center gap-2">
+              {renderModelStateIcon(clipTextModelState)}
+              CLIP text model
+            </div>
+            <div className="flex flex-row items-center justify-center gap-2">
+              {renderModelStateIcon(minilmModelState)}
+              MiniLM sentence model
+            </div>
+            <div className="flex flex-row items-center justify-center gap-2">
+              {renderModelStateIcon(minilmTokenizerState)}
+              MiniLM tokenizer
+            </div>
+          </div>
+          {(minilmModelState === "error" ||
+            clipImageModelState === "error" ||
+            clipTextModelState === "error") && (
+            <div className="my-3 max-w-96 text-center text-danger">
+              An error has occurred. Check Frigate logs.
+            </div>
+          )}
+          <div className="max-w-96 text-center">
+            You may want to reindex the embeddings of your tracked objects
+            once the models are downloaded.
+          </div>
+          <div className="flex items-center text-primary-variant">
+            <Link
+              to="https://docs.frigate.video/configuration/semantic_search"
+              target="_blank"
+              rel="noopener noreferrer"
+              className="inline"
+            >
+              Read the documentation{" "}
+              <LuExternalLink className="ml-2 inline-flex size-3" />
+            </Link>
+          </div>
+        </div>
       ) : (
         <SearchView

diff --git a/web/src/types/ws.ts b/web/src/types/ws.ts
index 0fae44b07..a8211d269 100644
--- a/web/src/types/ws.ts
+++ b/web/src/types/ws.ts
@@ -56,4 +56,10 @@ export interface FrigateCameraState {
   objects: ObjectType[];
 }
 
+export type ModelState =
+  | "not_downloaded"
+  | "downloading"
+  | "downloaded"
+  | "error";
+
 export type ToggleableSetting = "ON" | "OFF";
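
Reviewer note: taken together, the new pieces form one pattern. `Embeddings` announces every model as `not_downloaded` over IPC, each embedding wrapper hands its files to a `ModelDownloader` (which re-announces `downloading`/`downloaded`/`error` while a background thread fetches them, serialized across processes by `FileLock`), and inference lazily blocks on `wait_for_download()` at first call while the dispatcher caches the latest per-model state for the web UI's `modelState` queries. Below is a minimal sketch of a consumer of that API, using only the `ModelDownloader` interface added in this patch; `MyEmbedding`, the model/file names, and the URL are illustrative placeholders, not Frigate code.

    import os

    from frigate.const import MODEL_CACHE_DIR
    from frigate.util.downloader import ModelDownloader


    class MyEmbedding:
        """Hypothetical embedding wrapper following the pattern in this patch."""

        MODEL_NAME = "my-model"    # illustrative
        MODEL_FILE = "model.onnx"  # illustrative

        def __init__(self):
            self.session = None
            # Constructor returns immediately: ensure_model_files() publishes
            # "downloading" for each file and fetches missing ones on a
            # background thread.
            self.downloader = ModelDownloader(
                model_name=self.MODEL_NAME,
                download_path=os.path.join(MODEL_CACHE_DIR, self.MODEL_NAME),
                file_names=[self.MODEL_FILE],
                download_func=self._download_model,
            )
            self.downloader.ensure_model_files()

        def _download_model(self, path: str):
            # Invoked once per missing file while the FileLock is held.
            ModelDownloader.download_from_url(
                "https://example.com/my-model.onnx", path  # illustrative URL
            )

        def __call__(self, texts):
            # Lazy-load: block on the download only at first inference, so
            # application startup is never held up by the network.
            if self.session is None:
                self.downloader.wait_for_download()
                # ...create the ONNX Runtime session from the cached file...
            # ...run inference with self.session...

This mirrors how `ClipEmbedding` and `MiniLMEmbedding` use the downloader above, and explains the startup reordering in `frigate/app.py`: the embeddings manager now starts after the inter-process communicator and dispatcher so its state updates have somewhere to go.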