diff --git a/frigate/config/classification.py b/frigate/config/classification.py
index 9d5b16561..fb8e3de29 100644
--- a/frigate/config/classification.py
+++ b/frigate/config/classification.py
@@ -19,7 +19,6 @@ __all__ = [
 class SemanticSearchModelEnum(str, Enum):
     jinav1 = "jinav1"
     jinav2 = "jinav2"
-    ax_jinav2 = "ax_jinav2"


 class EnrichmentsDeviceEnum(str, Enum):
diff --git a/frigate/detectors/detection_runners.py b/frigate/detectors/detection_runners.py
index fcbb41e66..7565c9a3d 100644
--- a/frigate/detectors/detection_runners.py
+++ b/frigate/detectors/detection_runners.py
@@ -10,6 +10,10 @@ from typing import Any
 import numpy as np
 import onnxruntime as ort

+from frigate.util.axengine_converter import (
+    auto_convert_model as auto_load_axengine_model,
+)
+from frigate.util.axengine_converter import is_axengine_compatible
 from frigate.util.model import get_ort_providers
 from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible

@@ -548,12 +552,135 @@ class RKNNModelRunner(BaseModelRunner):
         pass


+class AXEngineModelRunner(BaseModelRunner):
+    """Run AXEngine models for embeddings."""
+
+    _mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32).reshape(
+        1, 3, 1, 1
+    )
+    _std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32).reshape(
+        1, 3, 1, 1
+    )
+
+    def __init__(self, model_path: str, model_type: str | None = None):
+        self.model_path = model_path
+        self.model_type = model_type
+        self._inference_lock = threading.Lock()
+        self.image_session = None
+        self.text_session = None
+        self.text_pad_token_id = 0
+        self._load_model()
+
+    def _load_model(self):
+        try:
+            import axengine as axe
+            from transformers import AutoTokenizer
+        except ImportError:
+            logger.error("AXEngine is not available")
+            raise ImportError("AXEngine is not available")
+
+        model_dir = os.path.dirname(self.model_path)
+        image_model_path = os.path.join(model_dir, "image_encoder.axmodel")
+        text_model_path = os.path.join(model_dir, "text_encoder.axmodel")
+        tokenizer_path = os.path.join(model_dir, "tokenizer")
+
+        self.image_session = axe.InferenceSession(image_model_path)
+        self.text_session = axe.InferenceSession(text_model_path)
+
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_path,
+                trust_remote_code=True,
+                clean_up_tokenization_spaces=True,
+            )
+            if tokenizer.pad_token_id is not None:
+                self.text_pad_token_id = int(tokenizer.pad_token_id)
+        except Exception:
+            logger.warning(
+                "Failed to load tokenizer from %s for AXEngine padding, using 0",
+                tokenizer_path,
+            )
+
+    def get_input_names(self) -> list[str]:
+        return ["input_ids", "pixel_values"]
+
+    def get_input_width(self) -> int:
+        return 512
+
+    @staticmethod
+    def _has_real_text_inputs(inputs: dict[str, Any]) -> bool:
+        input_ids = inputs.get("input_ids")
+
+        if input_ids is None:
+            return False
+
+        if input_ids.ndim < 2:
+            return False
+
+        return input_ids.shape[-1] != 16 or np.any(input_ids)
+
+    @staticmethod
+    def _has_real_image_inputs(inputs: dict[str, Any]) -> bool:
+        pixel_values = inputs.get("pixel_values")
+
+        return pixel_values is not None and np.any(pixel_values)
+
+    def _prepare_text_inputs(self, input_ids: np.ndarray) -> np.ndarray:
+        padded_input_ids = np.full((1, 50), self.text_pad_token_id, dtype=np.int32)
+        truncated_input_ids = input_ids.reshape(1, -1)[:, :50].astype(np.int32)
+        padded_input_ids[:, : truncated_input_ids.shape[1]] = truncated_input_ids
+        return padded_input_ids
+
+    @classmethod
+    def _prepare_pixel_values(cls, pixel_values: np.ndarray) -> np.ndarray:
+        if len(pixel_values.shape) == 3:
+            pixel_values = pixel_values[None, ...]
+
+        pixel_values = pixel_values.astype(np.float32)
+        return (pixel_values - cls._mean) / cls._std
+
+    def run(self, inputs: dict[str, Any]) -> list[np.ndarray | None]:
+        outputs: list[np.ndarray | None] = [None, None, None, None]
+
+        with self._inference_lock:
+            if self._has_real_text_inputs(inputs):
+                text_embeddings = []
+                for input_ids in inputs["input_ids"]:
+                    text_embeddings.append(
+                        self.text_session.run(
+                            None,
+                            {"inputs_id": self._prepare_text_inputs(input_ids)},
+                        )[0][0]
+                    )
+                outputs[2] = np.array(text_embeddings)
+
+            if self._has_real_image_inputs(inputs):
+                image_embeddings = []
+                for pixel_values in inputs["pixel_values"]:
+                    image_embeddings.append(
+                        self.image_session.run(
+                            None,
+                            {"pixel_values": self._prepare_pixel_values(pixel_values)},
+                        )[0][0]
+                    )
+
+                outputs[3] = np.array(image_embeddings)
+
+        return outputs
+
+
 def get_optimized_runner(
     model_path: str, device: str | None, model_type: str, **kwargs
 ) -> BaseModelRunner:
     """Get an optimized runner for the hardware."""
     device = device or "AUTO"

+    if is_axengine_compatible(model_path, device, model_type):
+        axmodel_path = auto_load_axengine_model(model_path, model_type)
+
+        if axmodel_path:
+            return AXEngineModelRunner(axmodel_path, model_type)
+
     if device != "CPU" and is_rknn_compatible(model_path):
         rknn_path = auto_convert_model(model_path)
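# Reviewer note: a minimal sketch (not part of the patch) of the run() contract
# introduced above, with a hypothetical runner instance and model path. The
# upstream embedding code feeds ONNX-style input dicts that may contain
# zero-filled dummy tensors (e.g. the 16-wide input_ids placeholder); the
# _has_real_* checks drop those, and real results land at outputs[2] (text)
# and outputs[3] (image) to mirror the ONNX model's output layout.
import numpy as np

dummy_text = np.zeros((1, 16), dtype=np.int32)  # width 16 and all zeros -> skipped
real_image = np.random.rand(1, 3, 512, 512).astype(np.float32)  # normalized internally

# runner = AXEngineModelRunner("/config/model_cache/jina-clip-v2/image_encoder.axmodel")
# outputs = runner.run({"input_ids": dummy_text, "pixel_values": real_image})
# outputs[2] is None (dummy text skipped); outputs[3] holds one image embedding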
diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py
index 835986a58..8d7bcd235 100644
--- a/frigate/embeddings/embeddings.py
+++ b/frigate/embeddings/embeddings.py
@@ -30,7 +30,6 @@ from frigate.util.file import get_event_thumbnail_bytes

 from .onnx.jina_v1_embedding import JinaV1ImageEmbedding, JinaV1TextEmbedding
 from .onnx.jina_v2_embedding import JinaV2Embedding
-from .onnx.jina_v2_embedding_ax import AXJinaV2Embedding

 logger = logging.getLogger(__name__)

@@ -119,18 +118,6 @@ class Embeddings:
             self.vision_embedding = lambda input_data: self.embedding(
                 input_data, embedding_type="vision"
             )
-        elif self.config.semantic_search.model == SemanticSearchModelEnum.ax_jinav2:
-            # AXJinaV2Embedding instance for both text and vision
-            self.embedding = AXJinaV2Embedding(
-                model_size=self.config.semantic_search.model_size,
-                requestor=self.requestor,
-            )
-            self.text_embedding = lambda input_data: self.embedding(
-                input_data, embedding_type="text"
-            )
-            self.vision_embedding = lambda input_data: self.embedding(
-                input_data, embedding_type="vision"
-            )
         else:  # Default to jinav1
             self.text_embedding = JinaV1TextEmbedding(
                 model_size=config.semantic_search.model_size,
diff --git a/frigate/embeddings/onnx/jina_v2_embedding.py b/frigate/embeddings/onnx/jina_v2_embedding.py
index 1abd968c9..aa3947943 100644
--- a/frigate/embeddings/onnx/jina_v2_embedding.py
+++ b/frigate/embeddings/onnx/jina_v2_embedding.py
@@ -37,13 +37,18 @@ class JinaV2Embedding(BaseEmbedding):
             "model_fp16.onnx" if model_size == "large" else "model_quantized.onnx"
         )
         HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
+        use_axengine = (device or "").upper() == "AXENGINE"
         super().__init__(
             model_name="jinaai/jina-clip-v2",
             model_file=model_file,
-            download_urls={
-                model_file: f"{HF_ENDPOINT}/jinaai/jina-clip-v2/resolve/main/onnx/{model_file}",
-                "preprocessor_config.json": f"{HF_ENDPOINT}/jinaai/jina-clip-v2/resolve/main/preprocessor_config.json",
-            },
+            download_urls=(
+                {}
+                if use_axengine
+                else {
+                    model_file: f"{HF_ENDPOINT}/jinaai/jina-clip-v2/resolve/main/onnx/{model_file}",
+                    "preprocessor_config.json": f"{HF_ENDPOINT}/jinaai/jina-clip-v2/resolve/main/preprocessor_config.json",
+                }
+            ),
         )
         self.tokenizer_file = "tokenizer"
         self.embedding_type = embedding_type
@@ -59,7 +64,11 @@ class JinaV2Embedding(BaseEmbedding):
         self._call_lock = threading.Lock()

         # download the model and tokenizer
-        files_names = list(self.download_urls.keys()) + [self.tokenizer_file]
+        files_names = (
+            [self.tokenizer_file]
+            if use_axengine
+            else list(self.download_urls.keys()) + [self.tokenizer_file]
+        )
         if not all(
             os.path.exists(os.path.join(self.download_path, n)) for n in files_names
         ):
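# Reviewer note: a standalone sketch of the download gating added above; the
# helper name files_to_download is illustrative, not part of the patch. When
# the configured device is "AXENGINE", JinaV2Embedding no longer fetches the
# ONNX weights or preprocessor_config.json (the .axmodel encoders come from
# frigate/util/axengine_converter.py instead); only the tokenizer remains a
# required local file.
def files_to_download(device: str | None, model_file: str) -> list[str]:
    use_axengine = (device or "").upper() == "AXENGINE"
    if use_axengine:
        return ["tokenizer"]
    return [model_file, "preprocessor_config.json", "tokenizer"]

assert files_to_download("AXENGINE", "model_fp16.onnx") == ["tokenizer"]
assert files_to_download(None, "model_quantized.onnx")[0] == "model_quantized.onnx"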
"preprocessor_config.json": f"{HF_ENDPOINT}/jinaai/jina-clip-v2/resolve/main/preprocessor_config.json", + } + ), ) self.tokenizer_file = "tokenizer" self.embedding_type = embedding_type @@ -59,7 +64,11 @@ class JinaV2Embedding(BaseEmbedding): self._call_lock = threading.Lock() # download the model and tokenizer - files_names = list(self.download_urls.keys()) + [self.tokenizer_file] + files_names = ( + [self.tokenizer_file] + if use_axengine + else list(self.download_urls.keys()) + [self.tokenizer_file] + ) if not all( os.path.exists(os.path.join(self.download_path, n)) for n in files_names ): diff --git a/frigate/embeddings/onnx/jina_v2_embedding_ax.py b/frigate/embeddings/onnx/jina_v2_embedding_ax.py deleted file mode 100644 index 1d39ce014..000000000 --- a/frigate/embeddings/onnx/jina_v2_embedding_ax.py +++ /dev/null @@ -1,281 +0,0 @@ -"""AX JinaV2 Embeddings.""" - -import io -import logging -import os -import threading -from typing import Any - -import numpy as np -from PIL import Image -from transformers import AutoTokenizer -from transformers.utils.logging import disable_progress_bar, set_verbosity_error - -from frigate.const import MODEL_CACHE_DIR -from frigate.embeddings.onnx.base_embedding import BaseEmbedding -from frigate.comms.inter_process import InterProcessRequestor -from frigate.util.downloader import ModelDownloader -from frigate.types import ModelStatusTypesEnum -from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE - -import axengine as axe - -# disables the progress bar and download logging for downloading tokenizers and image processors -disable_progress_bar() -set_verbosity_error() -logger = logging.getLogger(__name__) - - -class AXClipRunner: - def __init__(self, image_encoder_path: str, text_encoder_path: str): - self.image_encoder_path = image_encoder_path - self.text_encoder_path = text_encoder_path - self.image_encoder_runner = axe.InferenceSession(image_encoder_path) - self.text_encoder_runner = axe.InferenceSession(text_encoder_path) - - for input in self.image_encoder_runner.get_inputs(): - logger.info(f"{input.name} {input.shape} {input.dtype}") - - for output in self.image_encoder_runner.get_outputs(): - logger.info(f"{output.name} {output.shape} {output.dtype}") - - for input in self.text_encoder_runner.get_inputs(): - logger.info(f"{input.name} {input.shape} {input.dtype}") - - for output in self.text_encoder_runner.get_outputs(): - logger.info(f"{output.name} {output.shape} {output.dtype}") - - def run(self, onnx_inputs): - text_embeddings = [] - image_embeddings = [] - if "input_ids" in onnx_inputs: - for input_ids in onnx_inputs["input_ids"]: - input_ids = input_ids.reshape(1, -1) - text_embeddings.append( - self.text_encoder_runner.run(None, {"inputs_id": input_ids})[0][0] - ) - if "pixel_values" in onnx_inputs: - for pixel_values in onnx_inputs["pixel_values"]: - if len(pixel_values.shape) == 3: - pixel_values = pixel_values[None, ...] 
- image_embeddings.append( - self.image_encoder_runner.run(None, {"pixel_values": pixel_values})[ - 0 - ][0] - ) - return np.array(text_embeddings), np.array(image_embeddings) - - -class AXJinaV2Embedding(BaseEmbedding): - def __init__( - self, - model_size: str, - requestor: InterProcessRequestor, - device: str = "AUTO", - embedding_type: str = None, - ): - HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") - super().__init__( - model_name="AXERA-TECH/jina-clip-v2", - model_file=None, - download_urls={ - "image_encoder.axmodel": f"{HF_ENDPOINT}/AXERA-TECH/jina-clip-v2/resolve/main/image_encoder.axmodel", - "text_encoder.axmodel": f"{HF_ENDPOINT}/AXERA-TECH/jina-clip-v2/resolve/main/text_encoder.axmodel", - }, - ) - - self.tokenizer_source = "jinaai/jina-clip-v2" - self.tokenizer_file = "tokenizer" - self.embedding_type = embedding_type - self.requestor = requestor - self.model_size = model_size - self.device = device - self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name) - self.tokenizer = None - self.image_processor = None - self.runner = None - self.mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32) - self.std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32) - - # Lock to prevent concurrent calls (text and vision share this instance) - self._call_lock = threading.Lock() - - # download the model and tokenizer - files_names = list(self.download_urls.keys()) + [self.tokenizer_file] - if not all( - os.path.exists(os.path.join(self.download_path, n)) for n in files_names - ): - logger.debug(f"starting model download for {self.model_name}") - self.downloader = ModelDownloader( - model_name=self.model_name, - download_path=self.download_path, - file_names=files_names, - download_func=self._download_model, - ) - self.downloader.ensure_model_files() - # Avoid lazy loading in worker threads: block until downloads complete - # and load the model on the main thread during initialization. 
- self._load_model_and_utils() - else: - self.downloader = None - ModelDownloader.mark_files_state( - self.requestor, - self.model_name, - files_names, - ModelStatusTypesEnum.downloaded, - ) - self._load_model_and_utils() - logger.debug(f"models are already downloaded for {self.model_name}") - - def _download_model(self, path: str): - try: - file_name = os.path.basename(path) - - if file_name in self.download_urls: - ModelDownloader.download_from_url(self.download_urls[file_name], path) - elif file_name == self.tokenizer_file: - tokenizer = AutoTokenizer.from_pretrained( - self.tokenizer_source, - trust_remote_code=True, - cache_dir=os.path.join( - MODEL_CACHE_DIR, self.model_name, "tokenizer" - ), - clean_up_tokenization_spaces=True, - ) - tokenizer.save_pretrained(path) - self.requestor.send_data( - UPDATE_MODEL_STATE, - { - "model": f"{self.model_name}-{file_name}", - "state": ModelStatusTypesEnum.downloaded, - }, - ) - except Exception: - self.requestor.send_data( - UPDATE_MODEL_STATE, - { - "model": f"{self.model_name}-{file_name}", - "state": ModelStatusTypesEnum.error, - }, - ) - - def _load_model_and_utils(self): - if self.runner is None: - if self.downloader: - self.downloader.wait_for_download() - - self.tokenizer = AutoTokenizer.from_pretrained( - self.tokenizer_source, - cache_dir=os.path.join(MODEL_CACHE_DIR, self.model_name, "tokenizer"), - trust_remote_code=True, - clean_up_tokenization_spaces=True, - ) - - self.runner = AXClipRunner( - os.path.join(self.download_path, "image_encoder.axmodel"), - os.path.join(self.download_path, "text_encoder.axmodel"), - ) - - def _preprocess_image(self, image_data: bytes | Image.Image): - """ - Manually preprocess a single image from bytes or PIL.Image to (3, 512, 512). - """ - if isinstance(image_data, bytes): - image = Image.open(io.BytesIO(image_data)) - else: - image = image_data - - if image.mode != "RGB": - image = image.convert("RGB") - - image = image.resize((512, 512), Image.Resampling.LANCZOS) - - # Convert to numpy array, normalize to [0, 1], and transpose to (channels, height, width) - image_array = np.array(image, dtype=np.float32) / 255.0 - # Normalize using mean and std - image_array = (image_array - self.mean) / self.std - - image_array = np.transpose(image_array, (2, 0, 1)) # (H, W, C) -> (C, H, W) - - return image_array - - def _preprocess_inputs(self, raw_inputs): - """ - Preprocess inputs into a list of real input tensors (no dummies). - - For text: Returns list of input_ids. - - For vision: Returns list of pixel_values. - """ - if not isinstance(raw_inputs, list): - raw_inputs = [raw_inputs] - - processed = [] - if self.embedding_type == "text": - for text in raw_inputs: - input_ids = self.tokenizer( - [text], return_tensors="np", padding="max_length", max_length=50 - )["input_ids"] - input_ids = input_ids.astype(np.int32) - processed.append(input_ids) - elif self.embedding_type == "vision": - for img in raw_inputs: - pixel_values = self._preprocess_image(img) - processed.append( - pixel_values[np.newaxis, ...] - ) # Add batch dim: (1, 3, 512, 512) - else: - raise ValueError( - f"Invalid embedding_type: {self.embedding_type}. Must be 'text' or 'vision'." - ) - return processed - - def _postprocess_outputs(self, outputs): - """ - Process ONNX model outputs, truncating each embedding in the array to truncate_dim. - - outputs: NumPy array of embeddings. - - Returns: List of truncated embeddings. 
- """ - # size of vector in database - truncate_dim = 768 - - # jina v2 defaults to 1024 and uses Matryoshka representation, so - # truncating only causes an extremely minor decrease in retrieval accuracy - if outputs.shape[-1] > truncate_dim: - outputs = outputs[..., :truncate_dim] - - return outputs - - def __call__( - self, inputs: list[str] | list[Image.Image] | list[str], embedding_type=None - ): - # Lock the entire call to prevent race conditions when text and vision - # embeddings are called concurrently from different threads - with self._call_lock: - self.embedding_type = embedding_type - if not self.embedding_type: - raise ValueError( - "embedding_type must be specified either in __init__ or __call__" - ) - - self._load_model_and_utils() - processed = self._preprocess_inputs(inputs) - - # Prepare ONNX inputs with matching batch sizes - onnx_inputs = {} - if self.embedding_type == "text": - onnx_inputs["input_ids"] = np.stack([x[0] for x in processed]) - elif self.embedding_type == "vision": - onnx_inputs["pixel_values"] = np.stack([x[0] for x in processed]) - else: - raise ValueError("Invalid embedding type") - - # Run inference - text_embeddings, image_embeddings = self.runner.run(onnx_inputs) - if self.embedding_type == "text": - embeddings = text_embeddings # text embeddings - elif self.embedding_type == "vision": - embeddings = image_embeddings # image embeddings - else: - raise ValueError("Invalid embedding type") - - embeddings = self._postprocess_outputs(embeddings) - return [embedding for embedding in embeddings] diff --git a/frigate/util/axengine_converter.py b/frigate/util/axengine_converter.py new file mode 100644 index 000000000..ab465df4f --- /dev/null +++ b/frigate/util/axengine_converter.py @@ -0,0 +1,190 @@ +"""AXEngine model loading utility for Frigate.""" + +import logging +import os +import time +from pathlib import Path + +from frigate.comms.inter_process import InterProcessRequestor +from frigate.const import UPDATE_MODEL_STATE +from frigate.types import ModelStatusTypesEnum +from frigate.util.downloader import ModelDownloader +from frigate.util.file import FileLock + +logger = logging.getLogger(__name__) + +AXENGINE_JINA_V2_MODEL = "jina_v2" +AXENGINE_JINA_V2_REPO = "AXERA-TECH/jina-clip-v2" + + +def get_axengine_model_type(model_path: str) -> str | None: + if "jina-clip-v2" in str(model_path): + return AXENGINE_JINA_V2_MODEL + + return None + + +def is_axengine_compatible( + model_path: str, device: str | None, model_type: str | None = None +) -> bool: + if (device or "").upper() != "AXENGINE": + return False + + if not model_type: + model_type = get_axengine_model_type(model_path) + + return model_type == AXENGINE_JINA_V2_MODEL + + +def wait_for_download_completion( + image_model_path: Path, + text_model_path: Path, + lock_path: Path, + timeout: int = 300, +) -> bool: + start_time = time.time() + + while time.time() - start_time < timeout: + if image_model_path.exists() and text_model_path.exists(): + return True + + if not lock_path.exists(): + return image_model_path.exists() and text_model_path.exists() + + time.sleep(1) + + logger.warning("Timeout waiting for AXEngine model files: %s", image_model_path) + return False + + +def auto_convert_model(model_path: str, model_type: str | None = None) -> str | None: + """Prepare AXEngine model files and return the image encoder path.""" + if not is_axengine_compatible(model_path, "AXENGINE", model_type): + return None + + model_dir = Path(model_path).parent + ui_model_key = 
f"jinaai/jina-clip-v2-{Path(model_path).name}" + ui_preprocessor_key = "jinaai/jina-clip-v2-preprocessor_config.json" + image_model_path = model_dir / "image_encoder.axmodel" + text_model_path = model_dir / "text_encoder.axmodel" + model_repo = os.environ.get("AXENGINE_JINA_V2_REPO", AXENGINE_JINA_V2_REPO) + hf_endpoint = os.environ.get("HF_ENDPOINT", "https://huggingface.co") + requestor = InterProcessRequestor() + + download_targets = { + "image_encoder.axmodel": f"{hf_endpoint}/{model_repo}/resolve/main/image_encoder.axmodel", + "text_encoder.axmodel": f"{hf_endpoint}/{model_repo}/resolve/main/text_encoder.axmodel", + } + + if image_model_path.exists() and text_model_path.exists(): + requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": ui_preprocessor_key, + "state": ModelStatusTypesEnum.downloaded, + }, + ) + requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": ui_model_key, + "state": ModelStatusTypesEnum.downloaded, + }, + ) + requestor.stop() + return str(image_model_path) + + lock_path = model_dir / ".axengine.download.lock" + lock = FileLock(lock_path, timeout=300, cleanup_stale_on_init=True) + + if lock.acquire(): + try: + requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": ui_preprocessor_key, + "state": ModelStatusTypesEnum.downloaded, + }, + ) + requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": ui_model_key, + "state": ModelStatusTypesEnum.downloading, + }, + ) + + for file_name, url in download_targets.items(): + target_path = model_dir / file_name + if target_path.exists(): + continue + + target_path.parent.mkdir(parents=True, exist_ok=True) + ModelDownloader.download_from_url(url, str(target_path)) + + requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": ui_model_key, + "state": ModelStatusTypesEnum.downloaded, + }, + ) + + return str(image_model_path) + except Exception: + requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": ui_model_key, + "state": ModelStatusTypesEnum.error, + }, + ) + logger.exception( + "Failed to prepare AXEngine model files for %s", model_repo + ) + return None + finally: + requestor.stop() + lock.release() + + logger.info("Another process is preparing AXEngine models, waiting for completion") + requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": ui_preprocessor_key, + "state": ModelStatusTypesEnum.downloaded, + }, + ) + requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": ui_model_key, + "state": ModelStatusTypesEnum.downloading, + }, + ) + requestor.stop() + + if wait_for_download_completion(image_model_path, text_model_path, lock_path): + if image_model_path.exists() and text_model_path.exists(): + requestor = InterProcessRequestor() + requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": ui_model_key, + "state": ModelStatusTypesEnum.downloaded, + }, + ) + requestor.stop() + return str(image_model_path) + + logger.error("Timeout waiting for AXEngine model download lock for %s", model_dir) + requestor = InterProcessRequestor() + requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": ui_model_key, + "state": ModelStatusTypesEnum.error, + }, + ) + requestor.stop() + return None diff --git a/web/src/pages/Explore.tsx b/web/src/pages/Explore.tsx index 4ff0a2020..8f50e982e 100644 --- a/web/src/pages/Explore.tsx +++ b/web/src/pages/Explore.tsx @@ -292,13 +292,10 @@ export default function Explore() { const modelVersion = config?.semantic_search.model || "jinav1"; const modelSize = config?.semantic_search.model_size || "small"; - const isAxJinaV2 = modelVersion === "ax_jinav2"; // Text model state const 
diff --git a/web/src/pages/Explore.tsx b/web/src/pages/Explore.tsx
index 4ff0a2020..8f50e982e 100644
--- a/web/src/pages/Explore.tsx
+++ b/web/src/pages/Explore.tsx
@@ -292,13 +292,10 @@ export default function Explore() {
   const modelVersion = config?.semantic_search.model || "jinav1";
   const modelSize = config?.semantic_search.model_size || "small";
-  const isAxJinaV2 = modelVersion === "ax_jinav2";

   // Text model state
   const { payload: textModelState } = useModelState(
-    isAxJinaV2
-      ? "AXERA-TECH/jina-clip-v2-text_encoder.axmodel"
-      : modelVersion === "jinav1"
+    modelVersion === "jinav1"
       ? "jinaai/jina-clip-v1-text_model_fp16.onnx"
       : modelSize === "large"
         ? "jinaai/jina-clip-v2-model_fp16.onnx"
@@ -307,18 +304,14 @@ export default function Explore() {

   // Tokenizer state
   const { payload: textTokenizerState } = useModelState(
-    isAxJinaV2
-      ? "AXERA-TECH/jina-clip-v2-tokenizer"
-      : modelVersion === "jinav1"
+    modelVersion === "jinav1"
       ? "jinaai/jina-clip-v1-tokenizer"
      : "jinaai/jina-clip-v2-tokenizer",
   );

   // Vision model state (same as text model for jinav2)
   const visionModelFile =
-    isAxJinaV2
-      ? "AXERA-TECH/jina-clip-v2-image_encoder.axmodel"
-      : modelVersion === "jinav1"
+    modelVersion === "jinav1"
       ? modelSize === "large"
         ? "jinaai/jina-clip-v1-vision_model_fp16.onnx"
         : "jinaai/jina-clip-v1-vision_model_quantized.onnx"
@@ -328,49 +321,13 @@ export default function Explore() {
   const { payload: visionModelState } = useModelState(visionModelFile);

   // Preprocessor/feature extractor state
-  const { payload: visionFeatureExtractorStateRaw } = useModelState(
+  const { payload: visionFeatureExtractorState } = useModelState(
     modelVersion === "jinav1"
       ? "jinaai/jina-clip-v1-preprocessor_config.json"
       : "jinaai/jina-clip-v2-preprocessor_config.json",
   );
-
-  const visionFeatureExtractorState = useMemo(() => {
-    if (isAxJinaV2) {
-      return visionModelState ?? "downloading";
-    }
-    return visionFeatureExtractorStateRaw;
-  }, [isAxJinaV2, visionModelState, visionFeatureExtractorStateRaw]);
-
-  const effectiveTextModelState = useMemo(() => {
-    if (isAxJinaV2) {
-      return textModelState ?? "downloading";
-    }
-    return textModelState;
-  }, [isAxJinaV2, textModelState]);
-
-  const effectiveTextTokenizerState = useMemo(() => {
-    if (isAxJinaV2) {
-      return textTokenizerState ?? "downloading";
-    }
-    return textTokenizerState;
-  }, [isAxJinaV2, textTokenizerState]);
-
-  const effectiveVisionModelState = useMemo(() => {
-    if (isAxJinaV2) {
-      return visionModelState ?? "downloading";
-    }
-    return visionModelState;
-  }, [isAxJinaV2, visionModelState]);
-

   const allModelsLoaded = useMemo(() => {
-    if (isAxJinaV2) {
-      return (
-        effectiveTextModelState === "downloaded" &&
-        effectiveTextTokenizerState === "downloaded" &&
-        effectiveVisionModelState === "downloaded"
-      );
-    }
     return (
       textModelState === "downloaded" &&
       textTokenizerState === "downloaded" &&
@@ -378,10 +335,6 @@ export default function Explore() {
       visionFeatureExtractorState === "downloaded"
     );
   }, [
-    isAxJinaV2,
-    effectiveTextModelState,
-    effectiveTextTokenizerState,
-    effectiveVisionModelState,
     textModelState,
     textTokenizerState,
     visionModelState,
@@ -405,10 +358,10 @@ export default function Explore() {
     !defaultViewLoaded ||
     (config?.semantic_search.enabled &&
       (!reindexState ||
-        !(isAxJinaV2 ? effectiveTextModelState : textModelState) ||
-        !(isAxJinaV2 ? effectiveTextTokenizerState : textTokenizerState) ||
-        !(isAxJinaV2 ? effectiveVisionModelState : visionModelState) ||
-        (!isAxJinaV2 && !visionFeatureExtractorState)))
+        !textModelState ||
+        !textTokenizerState ||
+        !visionModelState ||
+        !visionFeatureExtractorState))
   ) {
     return (
diff --git a/web/src/types/frigateConfig.ts b/web/src/types/frigateConfig.ts
index 369160319..94c9ba6e9 100644
--- a/web/src/types/frigateConfig.ts
+++ b/web/src/types/frigateConfig.ts
@@ -28,7 +28,7 @@ export interface FaceRecognitionConfig {
   recognition_threshold: number;
 }

-export type SearchModel = "jinav1" | "jinav2" | "ax_jinav2";
+export type SearchModel = "jinav1" | "jinav2";

 export type SearchModelSize = "small" | "large";

 export interface CameraConfig {