diff --git a/.github/workflows/ax.yml b/.github/workflows/ax.yml new file mode 100644 index 000000000..82721eb3c --- /dev/null +++ b/.github/workflows/ax.yml @@ -0,0 +1,143 @@ +name: AXERA + +on: + workflow_dispatch: + +concurrency: + group: ${{ github.ref }} + cancel-in-progress: true + +env: + PYTHON_VERSION: 3.9 + +jobs: + x86_axcl_builds: + runs-on: ubuntu-22.04 + name: x86_AXCL Build + steps: + - name: Check out code + uses: actions/checkout@v4 + with: + persist-credentials: false + + - name: Set x86_AXCL_TAG + run: echo "x86_AXCL_TAG=x86-axcl-${GITHUB_SHA:0:7}" >> $GITHUB_ENV + + - name: Set Version + run: make version + + - name: Build + uses: docker/bake-action@v6 + with: + source: . + push: false + targets: x86-axcl + files: docker/axcl/x86-axcl.hcl + no-cache: true + set: | + x86-axcl.tags=frigate:${{ env.x86_AXCL_TAG }} + + - name: Clean up disk space + run: | + docker system prune -f + + - name: Save Docker image as tar file + run: | + docker save frigate:${{ env.x86_AXCL_TAG }} -o frigate-${{ env.x86_AXCL_TAG }}.tar + ls -lh frigate-${{ env.x86_AXCL_TAG }}.tar + + - name: Upload Docker image artifact + uses: actions/upload-artifact@v4 + with: + name: x86-axcl-docker-image + path: frigate-${{ env.x86_AXCL_TAG }}.tar + retention-days: 7 + + rk_axcl_builds: + runs-on: ubuntu-22.04-arm + name: rk_AXCL Build + steps: + - name: Check out code + uses: actions/checkout@v4 + with: + persist-credentials: false + + - name: Set RK_AXCL_TAG + run: echo "RK_AXCL_TAG=rk-axcl-${GITHUB_SHA:0:7}" >> $GITHUB_ENV + + - name: Set Version + run: make version + + - name: Build + uses: docker/bake-action@v6 + with: + source: . + push: false + targets: rk-axcl + files: | + docker/rockchip/rk.hcl + docker/axcl/rk-axcl.hcl + no-cache: true + set: | + rk-axcl.tags=frigate:${{ env.RK_AXCL_TAG }} + + - name: Clean up disk space + run: | + docker system prune -f + + - name: Save Docker image as tar file + run: | + docker save frigate:${{ env.RK_AXCL_TAG }} -o frigate-${{ env.RK_AXCL_TAG }}.tar + ls -lh frigate-${{ env.RK_AXCL_TAG }}.tar + + - name: Upload Docker image artifact + uses: actions/upload-artifact@v4 + with: + name: rk-axcl-docker-image + path: frigate-${{ env.RK_AXCL_TAG }}.tar + retention-days: 7 + + + rpi_axcl_builds: + runs-on: ubuntu-22.04-arm + name: RPi_AXCL Build + steps: + - name: Check out code + uses: actions/checkout@v4 + with: + persist-credentials: false + + - name: Set RPi_AXCL_TAG + run: echo "RPi_AXCL_TAG=rpi-axcl-${GITHUB_SHA:0:7}" >> $GITHUB_ENV + + - name: Set Version + run: make version + + - name: Build + uses: docker/bake-action@v6 + with: + source: . + push: false + targets: rpi-axcl + files: | + docker/rpi/rpi.hcl + docker/axcl/rpi-axcl.hcl + no-cache: true + set: | + rpi-axcl.tags=frigate:${{ env.RPi_AXCL_TAG }} + + - name: Clean up disk space + run: | + docker system prune -f + + - name: Save Docker image as tar file + run: | + docker save frigate:${{ env.RPi_AXCL_TAG }} -o frigate-${{ env.RPi_AXCL_TAG }}.tar + ls -lh frigate-${{ env.RPi_AXCL_TAG }}.tar + + - name: Upload Docker image artifact + uses: actions/upload-artifact@v4 + with: + name: rpi-axcl-docker-image + path: frigate-${{ env.RPi_AXCL_TAG }}.tar + retention-days: 7 diff --git a/docker/axcl/Dockerfile b/docker/axcl/Dockerfile index 83271bce8..e67046055 100644 --- a/docker/axcl/Dockerfile +++ b/docker/axcl/Dockerfile @@ -6,7 +6,6 @@ ARG DEBIAN_FRONTEND=noninteractive # Globally set pip break-system-packages option to avoid having to specify it every time ARG PIP_BREAK_SYSTEM_PACKAGES=1 - FROM frigate AS frigate-axcl ARG TARGETARCH ARG PIP_BREAK_SYSTEM_PACKAGES @@ -16,35 +15,6 @@ RUN wget https://github.com/AXERA-TECH/pyaxengine/releases/download/0.1.3.rc1/ax RUN pip3 install -i https://mirrors.aliyun.com/pypi/simple/ /axengine-0.1.3-py3-none-any.whl \ && rm /axengine-0.1.3-py3-none-any.whl -# Install axcl -RUN if [ "$TARGETARCH" = "amd64" ]; then \ - echo "Installing x86_64 version of axcl"; \ - wget https://github.com/ivanshi1108/assets/releases/download/v0.16.2/axcl_host_x86_64_V3.6.5_20250908154509_NO4973.deb -O /axcl.deb; \ - else \ - echo "Installing aarch64 version of axcl"; \ - wget https://github.com/ivanshi1108/assets/releases/download/v0.16.2/axcl_host_aarch64_V3.6.5_20250908154509_NO4973.deb -O /axcl.deb; \ - fi - -RUN mkdir /unpack_axcl && \ - dpkg-deb -x /axcl.deb /unpack_axcl && \ - cp -R /unpack_axcl/usr/bin/axcl /usr/bin/ && \ - cp -R /unpack_axcl/usr/lib/axcl /usr/lib/ && \ - rm -rf /unpack_axcl /axcl.deb - - -# Install axcl ffmpeg -RUN mkdir -p /usr/lib/ffmpeg/axcl - -RUN if [ "$TARGETARCH" = "amd64" ]; then \ - wget https://github.com/ivanshi1108/assets/releases/download/v0.16.2/ffmpeg-x64 -O /usr/lib/ffmpeg/axcl/ffmpeg && \ - wget https://github.com/ivanshi1108/assets/releases/download/v0.16.2/ffprobe-x64 -O /usr/lib/ffmpeg/axcl/ffprobe; \ - else \ - wget https://github.com/ivanshi1108/assets/releases/download/v0.16.2/ffmpeg-aarch64 -O /usr/lib/ffmpeg/axcl/ffmpeg && \ - wget https://github.com/ivanshi1108/assets/releases/download/v0.16.2/ffprobe-aarch64 -O /usr/lib/ffmpeg/axcl/ffprobe; \ - fi - -RUN chmod +x /usr/lib/ffmpeg/axcl/ffmpeg /usr/lib/ffmpeg/axcl/ffprobe - # Set ldconfig path RUN echo "/usr/lib/axcl" > /etc/ld.so.conf.d/ax.conf diff --git a/docker/axcl/rk-axcl.hcl b/docker/axcl/rk-axcl.hcl new file mode 100644 index 000000000..eea2bd93d --- /dev/null +++ b/docker/axcl/rk-axcl.hcl @@ -0,0 +1,7 @@ +target rk-axcl { + dockerfile = "docker/axcl/Dockerfile" + contexts = { + frigate = "target:rk", + } + platforms = ["linux/arm64"] +} \ No newline at end of file diff --git a/docker/axcl/rpi-axcl.hcl b/docker/axcl/rpi-axcl.hcl new file mode 100644 index 000000000..72cdc71c0 --- /dev/null +++ b/docker/axcl/rpi-axcl.hcl @@ -0,0 +1,7 @@ +target rpi-axcl { + dockerfile = "docker/axcl/Dockerfile" + contexts = { + frigate = "target:rpi", + } + platforms = ["linux/arm64"] +} \ No newline at end of file diff --git a/docker/axcl/user_installation.sh b/docker/axcl/user_installation.sh index e053a5faf..4a36a99e1 100755 --- a/docker/axcl/user_installation.sh +++ b/docker/axcl/user_installation.sh @@ -1,14 +1,25 @@ #!/bin/bash +set -e + +# Function to clean up on error +cleanup() { + echo "Cleaning up temporary files..." + rm -f "$deb_file" +} + +trap cleanup ERR +trap 'echo "Script interrupted by user (Ctrl+C)"; cleanup; exit 130' INT + # Update package list and install dependencies +echo "Updating package list and installing dependencies..." sudo apt-get update sudo apt-get install -y build-essential cmake git wget pciutils kmod udev # Check if gcc-12 is needed +echo "Checking GCC version..." current_gcc_version=$(gcc --version | head -n1 | awk '{print $NF}') -gcc_major_version=$(echo $current_gcc_version | cut -d'.' -f1) - -if [[ $gcc_major_version -lt 12 ]]; then +if ! dpkg --compare-versions "$current_gcc_version" ge "12" 2>/dev/null; then echo "Current GCC version ($current_gcc_version) is lower than 12, installing gcc-12..." sudo apt-get install -y gcc-12 sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 12 @@ -18,26 +29,37 @@ else fi # Determine architecture +echo "Determining system architecture..." arch=$(uname -m) download_url="" if [[ $arch == "x86_64" ]]; then - download_url="https://github.com/ivanshi1108/assets/releases/download/v0.16.2/axcl_host_x86_64_V3.6.5_20250908154509_NO4973.deb" - deb_file="axcl_host_x86_64_V3.6.5_20250908154509_NO4973.deb" + download_url="https://github.com/ivanshi1108/assets/releases/download/v0.17/axcl_host_x86_64_V3.10.2_20251111020143_NO5046.deb" + deb_file="axcl.deb" elif [[ $arch == "aarch64" ]]; then - download_url="https://github.com/ivanshi1108/assets/releases/download/v0.16.2/axcl_host_aarch64_V3.6.5_20250908154509_NO4973.deb" - deb_file="axcl_host_aarch64_V3.6.5_20250908154509_NO4973.deb" + download_url="https://github.com/ivanshi1108/assets/releases/download/v0.17/axcl_host_aarch64_V3.10.2_20251111020143_NO5046.deb" + deb_file="axcl.deb" else echo "Unsupported architecture: $arch" exit 1 fi +# Check for required Linux headers before downloading +echo "Checking for required Linux headers..." +kernel_version=$(uname -r) +if dpkg -l | grep -q "linux-headers-${kernel_version}" || [ -d "/lib/modules/${kernel_version}/build" ]; then + echo "Linux headers or kernel modules directory found for kernel ${kernel_version}/build." +else + echo "Linux headers for kernel ${kernel_version} not found. Please install them first: sudo apt-get install linux-headers-${kernel_version}" + exit 1 +fi + # Download AXCL driver echo "Downloading AXCL driver for $arch..." -wget "$download_url" -O "$deb_file" +wget --timeout=30 --tries=3 "$download_url" -O "$deb_file" if [ $? -ne 0 ]; then - echo "Failed to download AXCL driver" + echo "Failed to download AXCL driver after retries" exit 1 fi @@ -51,7 +73,7 @@ if [ $? -ne 0 ]; then sudo dpkg -i "$deb_file" if [ $? -ne 0 ]; then - echo "AXCL driver installation failed" + echo "AXCL driver installation failed after dependency fix" exit 1 fi fi @@ -80,4 +102,9 @@ if command -v axcl-smi &> /dev/null; then else echo "axcl-smi command not found. AXCL driver installation may have failed." exit 1 -fi \ No newline at end of file +fi + +# Clean up +echo "Cleaning up temporary files..." +rm -f "$deb_file" +echo "Installation script completed." \ No newline at end of file diff --git a/docker/axcl/x86-axcl.hcl b/docker/axcl/x86-axcl.hcl new file mode 100644 index 000000000..78546be1a --- /dev/null +++ b/docker/axcl/x86-axcl.hcl @@ -0,0 +1,13 @@ +target frigate { + dockerfile = "docker/main/Dockerfile" + platforms = ["linux/amd64"] + target = "frigate" +} + +target x86-axcl { + dockerfile = "docker/axcl/Dockerfile" + contexts = { + frigate = "target:frigate", + } + platforms = ["linux/amd64"] +} \ No newline at end of file diff --git a/frigate/config/classification.py b/frigate/config/classification.py index fb8e3de29..9d5b16561 100644 --- a/frigate/config/classification.py +++ b/frigate/config/classification.py @@ -19,6 +19,7 @@ __all__ = [ class SemanticSearchModelEnum(str, Enum): jinav1 = "jinav1" jinav2 = "jinav2" + ax_jinav2 = "ax_jinav2" class EnrichmentsDeviceEnum(str, Enum): diff --git a/frigate/embeddings/embeddings.py b/frigate/embeddings/embeddings.py index 8d7bcd235..835986a58 100644 --- a/frigate/embeddings/embeddings.py +++ b/frigate/embeddings/embeddings.py @@ -30,6 +30,7 @@ from frigate.util.file import get_event_thumbnail_bytes from .onnx.jina_v1_embedding import JinaV1ImageEmbedding, JinaV1TextEmbedding from .onnx.jina_v2_embedding import JinaV2Embedding +from .onnx.jina_v2_embedding_ax import AXJinaV2Embedding logger = logging.getLogger(__name__) @@ -118,6 +119,18 @@ class Embeddings: self.vision_embedding = lambda input_data: self.embedding( input_data, embedding_type="vision" ) + elif self.config.semantic_search.model == SemanticSearchModelEnum.ax_jinav2: + # AXJinaV2Embedding instance for both text and vision + self.embedding = AXJinaV2Embedding( + model_size=self.config.semantic_search.model_size, + requestor=self.requestor, + ) + self.text_embedding = lambda input_data: self.embedding( + input_data, embedding_type="text" + ) + self.vision_embedding = lambda input_data: self.embedding( + input_data, embedding_type="vision" + ) else: # Default to jinav1 self.text_embedding = JinaV1TextEmbedding( model_size=config.semantic_search.model_size, diff --git a/frigate/embeddings/onnx/jina_v2_embedding_ax.py b/frigate/embeddings/onnx/jina_v2_embedding_ax.py new file mode 100644 index 000000000..1d39ce014 --- /dev/null +++ b/frigate/embeddings/onnx/jina_v2_embedding_ax.py @@ -0,0 +1,281 @@ +"""AX JinaV2 Embeddings.""" + +import io +import logging +import os +import threading +from typing import Any + +import numpy as np +from PIL import Image +from transformers import AutoTokenizer +from transformers.utils.logging import disable_progress_bar, set_verbosity_error + +from frigate.const import MODEL_CACHE_DIR +from frigate.embeddings.onnx.base_embedding import BaseEmbedding +from frigate.comms.inter_process import InterProcessRequestor +from frigate.util.downloader import ModelDownloader +from frigate.types import ModelStatusTypesEnum +from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE + +import axengine as axe + +# disables the progress bar and download logging for downloading tokenizers and image processors +disable_progress_bar() +set_verbosity_error() +logger = logging.getLogger(__name__) + + +class AXClipRunner: + def __init__(self, image_encoder_path: str, text_encoder_path: str): + self.image_encoder_path = image_encoder_path + self.text_encoder_path = text_encoder_path + self.image_encoder_runner = axe.InferenceSession(image_encoder_path) + self.text_encoder_runner = axe.InferenceSession(text_encoder_path) + + for input in self.image_encoder_runner.get_inputs(): + logger.info(f"{input.name} {input.shape} {input.dtype}") + + for output in self.image_encoder_runner.get_outputs(): + logger.info(f"{output.name} {output.shape} {output.dtype}") + + for input in self.text_encoder_runner.get_inputs(): + logger.info(f"{input.name} {input.shape} {input.dtype}") + + for output in self.text_encoder_runner.get_outputs(): + logger.info(f"{output.name} {output.shape} {output.dtype}") + + def run(self, onnx_inputs): + text_embeddings = [] + image_embeddings = [] + if "input_ids" in onnx_inputs: + for input_ids in onnx_inputs["input_ids"]: + input_ids = input_ids.reshape(1, -1) + text_embeddings.append( + self.text_encoder_runner.run(None, {"inputs_id": input_ids})[0][0] + ) + if "pixel_values" in onnx_inputs: + for pixel_values in onnx_inputs["pixel_values"]: + if len(pixel_values.shape) == 3: + pixel_values = pixel_values[None, ...] + image_embeddings.append( + self.image_encoder_runner.run(None, {"pixel_values": pixel_values})[ + 0 + ][0] + ) + return np.array(text_embeddings), np.array(image_embeddings) + + +class AXJinaV2Embedding(BaseEmbedding): + def __init__( + self, + model_size: str, + requestor: InterProcessRequestor, + device: str = "AUTO", + embedding_type: str = None, + ): + HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") + super().__init__( + model_name="AXERA-TECH/jina-clip-v2", + model_file=None, + download_urls={ + "image_encoder.axmodel": f"{HF_ENDPOINT}/AXERA-TECH/jina-clip-v2/resolve/main/image_encoder.axmodel", + "text_encoder.axmodel": f"{HF_ENDPOINT}/AXERA-TECH/jina-clip-v2/resolve/main/text_encoder.axmodel", + }, + ) + + self.tokenizer_source = "jinaai/jina-clip-v2" + self.tokenizer_file = "tokenizer" + self.embedding_type = embedding_type + self.requestor = requestor + self.model_size = model_size + self.device = device + self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name) + self.tokenizer = None + self.image_processor = None + self.runner = None + self.mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32) + self.std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32) + + # Lock to prevent concurrent calls (text and vision share this instance) + self._call_lock = threading.Lock() + + # download the model and tokenizer + files_names = list(self.download_urls.keys()) + [self.tokenizer_file] + if not all( + os.path.exists(os.path.join(self.download_path, n)) for n in files_names + ): + logger.debug(f"starting model download for {self.model_name}") + self.downloader = ModelDownloader( + model_name=self.model_name, + download_path=self.download_path, + file_names=files_names, + download_func=self._download_model, + ) + self.downloader.ensure_model_files() + # Avoid lazy loading in worker threads: block until downloads complete + # and load the model on the main thread during initialization. + self._load_model_and_utils() + else: + self.downloader = None + ModelDownloader.mark_files_state( + self.requestor, + self.model_name, + files_names, + ModelStatusTypesEnum.downloaded, + ) + self._load_model_and_utils() + logger.debug(f"models are already downloaded for {self.model_name}") + + def _download_model(self, path: str): + try: + file_name = os.path.basename(path) + + if file_name in self.download_urls: + ModelDownloader.download_from_url(self.download_urls[file_name], path) + elif file_name == self.tokenizer_file: + tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_source, + trust_remote_code=True, + cache_dir=os.path.join( + MODEL_CACHE_DIR, self.model_name, "tokenizer" + ), + clean_up_tokenization_spaces=True, + ) + tokenizer.save_pretrained(path) + self.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": f"{self.model_name}-{file_name}", + "state": ModelStatusTypesEnum.downloaded, + }, + ) + except Exception: + self.requestor.send_data( + UPDATE_MODEL_STATE, + { + "model": f"{self.model_name}-{file_name}", + "state": ModelStatusTypesEnum.error, + }, + ) + + def _load_model_and_utils(self): + if self.runner is None: + if self.downloader: + self.downloader.wait_for_download() + + self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_source, + cache_dir=os.path.join(MODEL_CACHE_DIR, self.model_name, "tokenizer"), + trust_remote_code=True, + clean_up_tokenization_spaces=True, + ) + + self.runner = AXClipRunner( + os.path.join(self.download_path, "image_encoder.axmodel"), + os.path.join(self.download_path, "text_encoder.axmodel"), + ) + + def _preprocess_image(self, image_data: bytes | Image.Image): + """ + Manually preprocess a single image from bytes or PIL.Image to (3, 512, 512). + """ + if isinstance(image_data, bytes): + image = Image.open(io.BytesIO(image_data)) + else: + image = image_data + + if image.mode != "RGB": + image = image.convert("RGB") + + image = image.resize((512, 512), Image.Resampling.LANCZOS) + + # Convert to numpy array, normalize to [0, 1], and transpose to (channels, height, width) + image_array = np.array(image, dtype=np.float32) / 255.0 + # Normalize using mean and std + image_array = (image_array - self.mean) / self.std + + image_array = np.transpose(image_array, (2, 0, 1)) # (H, W, C) -> (C, H, W) + + return image_array + + def _preprocess_inputs(self, raw_inputs): + """ + Preprocess inputs into a list of real input tensors (no dummies). + - For text: Returns list of input_ids. + - For vision: Returns list of pixel_values. + """ + if not isinstance(raw_inputs, list): + raw_inputs = [raw_inputs] + + processed = [] + if self.embedding_type == "text": + for text in raw_inputs: + input_ids = self.tokenizer( + [text], return_tensors="np", padding="max_length", max_length=50 + )["input_ids"] + input_ids = input_ids.astype(np.int32) + processed.append(input_ids) + elif self.embedding_type == "vision": + for img in raw_inputs: + pixel_values = self._preprocess_image(img) + processed.append( + pixel_values[np.newaxis, ...] + ) # Add batch dim: (1, 3, 512, 512) + else: + raise ValueError( + f"Invalid embedding_type: {self.embedding_type}. Must be 'text' or 'vision'." + ) + return processed + + def _postprocess_outputs(self, outputs): + """ + Process ONNX model outputs, truncating each embedding in the array to truncate_dim. + - outputs: NumPy array of embeddings. + - Returns: List of truncated embeddings. + """ + # size of vector in database + truncate_dim = 768 + + # jina v2 defaults to 1024 and uses Matryoshka representation, so + # truncating only causes an extremely minor decrease in retrieval accuracy + if outputs.shape[-1] > truncate_dim: + outputs = outputs[..., :truncate_dim] + + return outputs + + def __call__( + self, inputs: list[str] | list[Image.Image] | list[str], embedding_type=None + ): + # Lock the entire call to prevent race conditions when text and vision + # embeddings are called concurrently from different threads + with self._call_lock: + self.embedding_type = embedding_type + if not self.embedding_type: + raise ValueError( + "embedding_type must be specified either in __init__ or __call__" + ) + + self._load_model_and_utils() + processed = self._preprocess_inputs(inputs) + + # Prepare ONNX inputs with matching batch sizes + onnx_inputs = {} + if self.embedding_type == "text": + onnx_inputs["input_ids"] = np.stack([x[0] for x in processed]) + elif self.embedding_type == "vision": + onnx_inputs["pixel_values"] = np.stack([x[0] for x in processed]) + else: + raise ValueError("Invalid embedding type") + + # Run inference + text_embeddings, image_embeddings = self.runner.run(onnx_inputs) + if self.embedding_type == "text": + embeddings = text_embeddings # text embeddings + elif self.embedding_type == "vision": + embeddings = image_embeddings # image embeddings + else: + raise ValueError("Invalid embedding type") + + embeddings = self._postprocess_outputs(embeddings) + return [embedding for embedding in embeddings] diff --git a/web/src/pages/Explore.tsx b/web/src/pages/Explore.tsx index 8f50e982e..4ff0a2020 100644 --- a/web/src/pages/Explore.tsx +++ b/web/src/pages/Explore.tsx @@ -292,10 +292,13 @@ export default function Explore() { const modelVersion = config?.semantic_search.model || "jinav1"; const modelSize = config?.semantic_search.model_size || "small"; + const isAxJinaV2 = modelVersion === "ax_jinav2"; // Text model state const { payload: textModelState } = useModelState( - modelVersion === "jinav1" + isAxJinaV2 + ? "AXERA-TECH/jina-clip-v2-text_encoder.axmodel" + : modelVersion === "jinav1" ? "jinaai/jina-clip-v1-text_model_fp16.onnx" : modelSize === "large" ? "jinaai/jina-clip-v2-model_fp16.onnx" @@ -304,14 +307,18 @@ export default function Explore() { // Tokenizer state const { payload: textTokenizerState } = useModelState( - modelVersion === "jinav1" + isAxJinaV2 + ? "AXERA-TECH/jina-clip-v2-tokenizer" + : modelVersion === "jinav1" ? "jinaai/jina-clip-v1-tokenizer" : "jinaai/jina-clip-v2-tokenizer", ); // Vision model state (same as text model for jinav2) const visionModelFile = - modelVersion === "jinav1" + isAxJinaV2 + ? "AXERA-TECH/jina-clip-v2-image_encoder.axmodel" + : modelVersion === "jinav1" ? modelSize === "large" ? "jinaai/jina-clip-v1-vision_model_fp16.onnx" : "jinaai/jina-clip-v1-vision_model_quantized.onnx" @@ -321,13 +328,49 @@ export default function Explore() { const { payload: visionModelState } = useModelState(visionModelFile); // Preprocessor/feature extractor state - const { payload: visionFeatureExtractorState } = useModelState( + const { payload: visionFeatureExtractorStateRaw } = useModelState( modelVersion === "jinav1" ? "jinaai/jina-clip-v1-preprocessor_config.json" : "jinaai/jina-clip-v2-preprocessor_config.json", ); + + const visionFeatureExtractorState = useMemo(() => { + if (isAxJinaV2) { + return visionModelState ?? "downloading"; + } + return visionFeatureExtractorStateRaw; + }, [isAxJinaV2, visionModelState, visionFeatureExtractorStateRaw]); + + const effectiveTextModelState = useMemo(() => { + if (isAxJinaV2) { + return textModelState ?? "downloading"; + } + return textModelState; + }, [isAxJinaV2, textModelState]); + + const effectiveTextTokenizerState = useMemo(() => { + if (isAxJinaV2) { + return textTokenizerState ?? "downloading"; + } + return textTokenizerState; + }, [isAxJinaV2, textTokenizerState]); + + const effectiveVisionModelState = useMemo(() => { + if (isAxJinaV2) { + return visionModelState ?? "downloading"; + } + return visionModelState; + }, [isAxJinaV2, visionModelState]); + const allModelsLoaded = useMemo(() => { + if (isAxJinaV2) { + return ( + effectiveTextModelState === "downloaded" && + effectiveTextTokenizerState === "downloaded" && + effectiveVisionModelState === "downloaded" + ); + } return ( textModelState === "downloaded" && textTokenizerState === "downloaded" && @@ -335,6 +378,10 @@ export default function Explore() { visionFeatureExtractorState === "downloaded" ); }, [ + isAxJinaV2, + effectiveTextModelState, + effectiveTextTokenizerState, + effectiveVisionModelState, textModelState, textTokenizerState, visionModelState, @@ -358,10 +405,10 @@ export default function Explore() { !defaultViewLoaded || (config?.semantic_search.enabled && (!reindexState || - !textModelState || - !textTokenizerState || - !visionModelState || - !visionFeatureExtractorState)) + !(isAxJinaV2 ? effectiveTextModelState : textModelState) || + !(isAxJinaV2 ? effectiveTextTokenizerState : textTokenizerState) || + !(isAxJinaV2 ? effectiveVisionModelState : visionModelState) || + (!isAxJinaV2 && !visionFeatureExtractorState))) ) { return ( diff --git a/web/src/types/frigateConfig.ts b/web/src/types/frigateConfig.ts index 94c9ba6e9..369160319 100644 --- a/web/src/types/frigateConfig.ts +++ b/web/src/types/frigateConfig.ts @@ -28,7 +28,7 @@ export interface FaceRecognitionConfig { recognition_threshold: number; } -export type SearchModel = "jinav1" | "jinav2"; +export type SearchModel = "jinav1" | "jinav2" | "ax_jinav2"; export type SearchModelSize = "small" | "large"; export interface CameraConfig {