Refactor: Reimplement the CLIP embedding support in axengine

This commit is contained in:
shizhicheng 2026-03-07 01:58:11 +08:00
parent f955e6d661
commit 4eae551341
6 changed files with 369 additions and 359 deletions

View File

@ -10,6 +10,10 @@ from typing import Any
import numpy as np import numpy as np
import onnxruntime as ort import onnxruntime as ort
from frigate.util.axengine_converter import (
auto_convert_model as auto_load_axengine_model,
)
from frigate.util.axengine_converter import is_axengine_compatible
from frigate.util.model import get_ort_providers from frigate.util.model import get_ort_providers
from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible
@ -548,12 +552,135 @@ class RKNNModelRunner(BaseModelRunner):
pass pass
class AXEngineModelRunner(BaseModelRunner):
    """Run Jina CLIP v2 embedding models on Axera NPUs via AXEngine.

    Loads the ``image_encoder.axmodel`` and ``text_encoder.axmodel`` files
    that live in the same directory as ``model_path`` and routes text and
    vision batches to the matching session in :meth:`run`.
    """

    # CLIP normalization constants, shaped (1, 3, 1, 1) so they broadcast
    # over NCHW pixel tensors in _prepare_pixel_values.
    _mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32).reshape(
        1, 3, 1, 1
    )
    _std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32).reshape(
        1, 3, 1, 1
    )

    def __init__(self, model_path: str, model_type: str | None = None):
        self.model_path = model_path
        self.model_type = model_type
        # text and vision requests share this runner across threads
        self._inference_lock = threading.Lock()
        self.image_session = None
        self.text_session = None
        self.text_pad_token_id = 0
        self._load_model()

    def _load_model(self) -> None:
        """Create both AXEngine sessions and read the tokenizer pad token id.

        Raises:
            ImportError: if the axengine runtime or transformers is missing.
        """
        try:
            import axengine as axe
            from transformers import AutoTokenizer
        except ImportError as err:
            logger.error("AXEngine is not available")
            # chain the original import failure so the root cause is visible
            raise ImportError("AXEngine is not available") from err

        model_dir = os.path.dirname(self.model_path)
        image_model_path = os.path.join(model_dir, "image_encoder.axmodel")
        text_model_path = os.path.join(model_dir, "text_encoder.axmodel")
        tokenizer_path = os.path.join(model_dir, "tokenizer")

        self.image_session = axe.InferenceSession(image_model_path)
        self.text_session = axe.InferenceSession(text_model_path)

        # Best effort: fall back to pad id 0 when the tokenizer cannot load.
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_path,
                trust_remote_code=True,
                clean_up_tokenization_spaces=True,
            )
            if tokenizer.pad_token_id is not None:
                self.text_pad_token_id = int(tokenizer.pad_token_id)
        except Exception:
            logger.warning(
                "Failed to load tokenizer from %s for AXEngine padding, using 0",
                tokenizer_path,
            )

    def get_input_names(self) -> list[str]:
        return ["input_ids", "pixel_values"]

    def get_input_width(self) -> int:
        # Jina CLIP v2 axmodels take 512x512 images.
        return 512

    @staticmethod
    def _has_real_text_inputs(inputs: dict[str, Any]) -> bool:
        """Return True when ``input_ids`` looks like real tokens.

        An all-zero batch with trailing dimension 16 is treated as a dummy
        probe and rejected (presumably the caller's warm-up input — confirm
        against the embedding consumer).
        """
        input_ids = inputs.get("input_ids")
        if input_ids is None:
            return False
        if input_ids.ndim < 2:
            return False
        return input_ids.shape[-1] != 16 or np.any(input_ids)

    @staticmethod
    def _has_real_image_inputs(inputs: dict[str, Any]) -> bool:
        """Return True when ``pixel_values`` is present and not all zeros."""
        pixel_values = inputs.get("pixel_values")
        return pixel_values is not None and np.any(pixel_values)

    def _prepare_text_inputs(self, input_ids: np.ndarray) -> np.ndarray:
        """Pad/truncate one token sequence to the fixed (1, 50) int32 shape
        the text encoder expects, padding with the tokenizer's pad id."""
        padded_input_ids = np.full((1, 50), self.text_pad_token_id, dtype=np.int32)
        truncated_input_ids = input_ids.reshape(1, -1)[:, :50].astype(np.int32)
        padded_input_ids[:, : truncated_input_ids.shape[1]] = truncated_input_ids
        return padded_input_ids

    @classmethod
    def _prepare_pixel_values(cls, pixel_values: np.ndarray) -> np.ndarray:
        """Add a batch dim if needed and apply CLIP mean/std normalization."""
        if len(pixel_values.shape) == 3:
            pixel_values = pixel_values[None, ...]
        pixel_values = pixel_values.astype(np.float32)
        return (pixel_values - cls._mean) / cls._std

    def run(self, inputs: dict[str, Any]) -> list[np.ndarray | None]:
        """Run text and/or image batches one sample at a time.

        Returns a 4-slot list with text embeddings at index 2 and image
        embeddings at index 3; unused slots stay None. The slot layout
        appears to mirror the ONNX CLIP output ordering used by callers —
        confirm against the embedding consumer.
        """
        outputs: list[np.ndarray | None] = [None, None, None, None]

        with self._inference_lock:
            if self._has_real_text_inputs(inputs):
                text_embeddings = []
                for input_ids in inputs["input_ids"]:
                    text_embeddings.append(
                        self.text_session.run(
                            None,
                            {"inputs_id": self._prepare_text_inputs(input_ids)},
                        )[0][0]
                    )
                outputs[2] = np.array(text_embeddings)

            if self._has_real_image_inputs(inputs):
                image_embeddings = []
                for pixel_values in inputs["pixel_values"]:
                    image_embeddings.append(
                        self.image_session.run(
                            None,
                            {"pixel_values": self._prepare_pixel_values(pixel_values)},
                        )[0][0]
                    )
                outputs[3] = np.array(image_embeddings)

        return outputs
def get_optimized_runner( def get_optimized_runner(
model_path: str, device: str | None, model_type: str, **kwargs model_path: str, device: str | None, model_type: str, **kwargs
) -> BaseModelRunner: ) -> BaseModelRunner:
"""Get an optimized runner for the hardware.""" """Get an optimized runner for the hardware."""
device = device or "AUTO" device = device or "AUTO"
if is_axengine_compatible(model_path, device, model_type):
axmodel_path = auto_load_axengine_model(model_path, model_type)
if axmodel_path:
return AXEngineModelRunner(axmodel_path, model_type)
if device != "CPU" and is_rknn_compatible(model_path): if device != "CPU" and is_rknn_compatible(model_path):
rknn_path = auto_convert_model(model_path) rknn_path = auto_convert_model(model_path)

View File

@ -30,7 +30,6 @@ from frigate.util.file import get_event_thumbnail_bytes
from .onnx.jina_v1_embedding import JinaV1ImageEmbedding, JinaV1TextEmbedding from .onnx.jina_v1_embedding import JinaV1ImageEmbedding, JinaV1TextEmbedding
from .onnx.jina_v2_embedding import JinaV2Embedding from .onnx.jina_v2_embedding import JinaV2Embedding
from .onnx.jina_v2_embedding_ax import AXJinaV2Embedding
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -94,10 +93,6 @@ class Embeddings:
# Create tables if they don't exist # Create tables if they don't exist
self.db.create_embeddings_tables() self.db.create_embeddings_tables()
self.has_axengine = any(
d.type == "axengine" for d in self.config.detectors.values()
)
models = self.get_model_definitions() models = self.get_model_definitions()
for model in models: for model in models:
@ -110,22 +105,13 @@ class Embeddings:
) )
if self.config.semantic_search.model == SemanticSearchModelEnum.jinav2: if self.config.semantic_search.model == SemanticSearchModelEnum.jinav2:
if self.has_axengine: # Single JinaV2Embedding instance for both text and vision
# AXJinaV2Embedding instance for both text and vision on Axera NPU self.embedding = JinaV2Embedding(
self.embedding = AXJinaV2Embedding( model_size=self.config.semantic_search.model_size,
model_size=self.config.semantic_search.model_size, requestor=self.requestor,
requestor=self.requestor, device=config.semantic_search.device
) or ("GPU" if config.semantic_search.model_size == "large" else "CPU"),
else: )
# Single JinaV2Embedding instance for both text and vision
self.embedding = JinaV2Embedding(
model_size=self.config.semantic_search.model_size,
requestor=self.requestor,
device=config.semantic_search.device
or (
"GPU" if config.semantic_search.model_size == "large" else "CPU"
),
)
self.text_embedding = lambda input_data: self.embedding( self.text_embedding = lambda input_data: self.embedding(
input_data, embedding_type="text" input_data, embedding_type="text"
) )
@ -152,20 +138,13 @@ class Embeddings:
def get_model_definitions(self): def get_model_definitions(self):
# Version-specific models # Version-specific models
if self.config.semantic_search.model == SemanticSearchModelEnum.jinav2: if self.config.semantic_search.model == SemanticSearchModelEnum.jinav2:
if self.has_axengine: models = [
models = [ "jinaai/jina-clip-v2-tokenizer",
"AXERA-TECH/jina-clip-v2-text_encoder.axmodel", "jinaai/jina-clip-v2-model_fp16.onnx"
"AXERA-TECH/jina-clip-v2-image_encoder.axmodel", if self.config.semantic_search.model_size == "large"
"AXERA-TECH/jina-clip-v2-tokenizer", else "jinaai/jina-clip-v2-model_quantized.onnx",
] "jinaai/jina-clip-v2-preprocessor_config.json",
else: ]
models = [
"jinaai/jina-clip-v2-tokenizer",
"jinaai/jina-clip-v2-model_fp16.onnx"
if self.config.semantic_search.model_size == "large"
else "jinaai/jina-clip-v2-model_quantized.onnx",
"jinaai/jina-clip-v2-preprocessor_config.json",
]
else: # Default to jinav1 else: # Default to jinav1
models = [ models = [
"jinaai/jina-clip-v1-text_model_fp16.onnx", "jinaai/jina-clip-v1-text_model_fp16.onnx",

View File

@ -37,13 +37,18 @@ class JinaV2Embedding(BaseEmbedding):
"model_fp16.onnx" if model_size == "large" else "model_quantized.onnx" "model_fp16.onnx" if model_size == "large" else "model_quantized.onnx"
) )
HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
use_axengine = (device or "").upper() == "AXENGINE"
super().__init__( super().__init__(
model_name="jinaai/jina-clip-v2", model_name="jinaai/jina-clip-v2",
model_file=model_file, model_file=model_file,
download_urls={ download_urls=(
model_file: f"{HF_ENDPOINT}/jinaai/jina-clip-v2/resolve/main/onnx/{model_file}", {}
"preprocessor_config.json": f"{HF_ENDPOINT}/jinaai/jina-clip-v2/resolve/main/preprocessor_config.json", if use_axengine
}, else {
model_file: f"{HF_ENDPOINT}/jinaai/jina-clip-v2/resolve/main/onnx/{model_file}",
"preprocessor_config.json": f"{HF_ENDPOINT}/jinaai/jina-clip-v2/resolve/main/preprocessor_config.json",
}
),
) )
self.tokenizer_file = "tokenizer" self.tokenizer_file = "tokenizer"
self.embedding_type = embedding_type self.embedding_type = embedding_type
@ -59,7 +64,11 @@ class JinaV2Embedding(BaseEmbedding):
self._call_lock = threading.Lock() self._call_lock = threading.Lock()
# download the model and tokenizer # download the model and tokenizer
files_names = list(self.download_urls.keys()) + [self.tokenizer_file] files_names = (
[self.tokenizer_file]
if use_axengine
else list(self.download_urls.keys()) + [self.tokenizer_file]
)
if not all( if not all(
os.path.exists(os.path.join(self.download_path, n)) for n in files_names os.path.exists(os.path.join(self.download_path, n)) for n in files_names
): ):

View File

@ -1,278 +0,0 @@
"""AX JinaV2 Embeddings."""
import io
import logging
import os
import threading
import axengine as axe
import numpy as np
from PIL import Image
from transformers import AutoTokenizer
from transformers.utils.logging import disable_progress_bar, set_verbosity_error
from frigate.comms.inter_process import InterProcessRequestor
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
from frigate.embeddings.onnx.base_embedding import BaseEmbedding
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader
# disables the progress bar and download logging for downloading tokenizers and image processors
disable_progress_bar()
set_verbosity_error()
logger = logging.getLogger(__name__)
class AXClipRunner:
    """Thin wrapper pairing the AXEngine image and text encoder sessions."""

    def __init__(self, image_encoder_path: str, text_encoder_path: str):
        self.image_encoder_path = image_encoder_path
        self.text_encoder_path = text_encoder_path
        self.image_encoder_runner = axe.InferenceSession(image_encoder_path)
        self.text_encoder_runner = axe.InferenceSession(text_encoder_path)

        # Log the I/O signature of both sessions for debugging.
        for session in (self.image_encoder_runner, self.text_encoder_runner):
            for spec in list(session.get_inputs()) + list(session.get_outputs()):
                logger.info(f"{spec.name} {spec.shape} {spec.dtype}")

    def run(self, onnx_inputs):
        """Run every text and/or image sample through its encoder.

        Returns a (text_embeddings, image_embeddings) pair of numpy arrays;
        an array is empty when the corresponding input key is absent.
        """
        text_vecs = []
        image_vecs = []

        for token_ids in onnx_inputs.get("input_ids", []):
            feed = {"inputs_id": token_ids.reshape(1, -1)}
            text_vecs.append(self.text_encoder_runner.run(None, feed)[0][0])

        for pixels in onnx_inputs.get("pixel_values", []):
            # encoder expects a batch dimension
            if len(pixels.shape) == 3:
                pixels = pixels[None, ...]
            image_vecs.append(
                self.image_encoder_runner.run(None, {"pixel_values": pixels})[0][0]
            )

        return np.array(text_vecs), np.array(image_vecs)
class AXJinaV2Embedding(BaseEmbedding):
    """Jina CLIP v2 text/vision embeddings served by Axera AXEngine models.

    Downloads the pre-compiled ``image_encoder.axmodel`` and
    ``text_encoder.axmodel`` from the AXERA-TECH HuggingFace repo plus the
    upstream jinaai tokenizer, then produces both text and vision embeddings
    through ``__call__``.
    """

    def __init__(
        self,
        model_size: str,
        requestor: InterProcessRequestor,
        device: str = "AUTO",
        embedding_type: str | None = None,
    ):
        HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
        super().__init__(
            model_name="AXERA-TECH/jina-clip-v2",
            model_file=None,
            download_urls={
                "image_encoder.axmodel": f"{HF_ENDPOINT}/AXERA-TECH/jina-clip-v2/resolve/main/image_encoder.axmodel",
                "text_encoder.axmodel": f"{HF_ENDPOINT}/AXERA-TECH/jina-clip-v2/resolve/main/text_encoder.axmodel",
            },
        )
        # tokenizer comes from the original jinaai repo, not AXERA-TECH
        self.tokenizer_source = "jinaai/jina-clip-v2"
        self.tokenizer_file = "tokenizer"
        self.embedding_type = embedding_type
        self.requestor = requestor
        self.model_size = model_size
        self.device = device
        self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
        self.tokenizer = None
        self.image_processor = None
        self.runner = None
        # CLIP image normalization constants used by _preprocess_image
        self.mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
        self.std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)
        # Lock to prevent concurrent calls (text and vision share this instance)
        self._call_lock = threading.Lock()
        # download the model and tokenizer
        files_names = list(self.download_urls.keys()) + [self.tokenizer_file]
        if not all(
            os.path.exists(os.path.join(self.download_path, n)) for n in files_names
        ):
            logger.debug(f"starting model download for {self.model_name}")
            self.downloader = ModelDownloader(
                model_name=self.model_name,
                download_path=self.download_path,
                file_names=files_names,
                download_func=self._download_model,
            )
            self.downloader.ensure_model_files()
            # Avoid lazy loading in worker threads: block until downloads complete
            # and load the model on the main thread during initialization.
            self._load_model_and_utils()
        else:
            self.downloader = None
            ModelDownloader.mark_files_state(
                self.requestor,
                self.model_name,
                files_names,
                ModelStatusTypesEnum.downloaded,
            )
            self._load_model_and_utils()
            logger.debug(f"models are already downloaded for {self.model_name}")

    def _download_model(self, path: str) -> None:
        """Fetch one model file (or export the tokenizer) to *path* and
        publish the resulting per-file download state over the requestor."""
        try:
            file_name = os.path.basename(path)
            if file_name in self.download_urls:
                ModelDownloader.download_from_url(self.download_urls[file_name], path)
            elif file_name == self.tokenizer_file:
                tokenizer = AutoTokenizer.from_pretrained(
                    self.tokenizer_source,
                    trust_remote_code=True,
                    cache_dir=os.path.join(
                        MODEL_CACHE_DIR, self.model_name, "tokenizer"
                    ),
                    clean_up_tokenization_spaces=True,
                )
                tokenizer.save_pretrained(path)
            self.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file_name}",
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
        except Exception:
            # NOTE(review): file_name would be unbound here if os.path.basename
            # itself raised — in practice it cannot, but worth confirming.
            self.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file_name}",
                    "state": ModelStatusTypesEnum.error,
                },
            )

    def _load_model_and_utils(self) -> None:
        """Load the tokenizer and AXEngine runner once; later calls no-op."""
        if self.runner is None:
            if self.downloader:
                self.downloader.wait_for_download()
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_source,
                cache_dir=os.path.join(MODEL_CACHE_DIR, self.model_name, "tokenizer"),
                trust_remote_code=True,
                clean_up_tokenization_spaces=True,
            )
            self.runner = AXClipRunner(
                os.path.join(self.download_path, "image_encoder.axmodel"),
                os.path.join(self.download_path, "text_encoder.axmodel"),
            )

    def _preprocess_image(self, image_data: bytes | Image.Image) -> np.ndarray:
        """
        Manually preprocess a single image from bytes or PIL.Image to (3, 512, 512).
        """
        if isinstance(image_data, bytes):
            image = Image.open(io.BytesIO(image_data))
        else:
            image = image_data
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = image.resize((512, 512), Image.Resampling.LANCZOS)
        # Convert to numpy array, normalize to [0, 1], and transpose to (channels, height, width)
        image_array = np.array(image, dtype=np.float32) / 255.0
        # Normalize using mean and std
        image_array = (image_array - self.mean) / self.std
        image_array = np.transpose(image_array, (2, 0, 1))  # (H, W, C) -> (C, H, W)
        return image_array

    def _preprocess_inputs(self, raw_inputs):
        """
        Preprocess inputs into a list of real input tensors (no dummies).
        - For text: Returns list of input_ids.
        - For vision: Returns list of pixel_values.
        """
        if not isinstance(raw_inputs, list):
            raw_inputs = [raw_inputs]
        processed = []
        if self.embedding_type == "text":
            for text in raw_inputs:
                # fixed-length (1, 50) int32 ids to match the text encoder input
                input_ids = self.tokenizer(
                    [text], return_tensors="np", padding="max_length", max_length=50
                )["input_ids"]
                input_ids = input_ids.astype(np.int32)
                processed.append(input_ids)
        elif self.embedding_type == "vision":
            for img in raw_inputs:
                pixel_values = self._preprocess_image(img)
                processed.append(
                    pixel_values[np.newaxis, ...]
                )  # Add batch dim: (1, 3, 512, 512)
        else:
            raise ValueError(
                f"Invalid embedding_type: {self.embedding_type}. Must be 'text' or 'vision'."
            )
        return processed

    def _postprocess_outputs(self, outputs):
        """
        Process model outputs, truncating each embedding in the array to truncate_dim.
        - outputs: NumPy array of embeddings.
        - Returns: truncated embeddings array.
        """
        # size of vector in database
        truncate_dim = 768
        # jina v2 defaults to 1024 and uses Matryoshka representation, so
        # truncating only causes an extremely minor decrease in retrieval accuracy
        if outputs.shape[-1] > truncate_dim:
            outputs = outputs[..., :truncate_dim]
        return outputs

    def __call__(
        self, inputs: list[str] | list[Image.Image], embedding_type=None
    ):
        """Embed a batch of texts or images; returns a list of embeddings.

        Raises ValueError when embedding_type is neither 'text' nor 'vision'.
        """
        # Lock the entire call to prevent race conditions when text and vision
        # embeddings are called concurrently from different threads
        with self._call_lock:
            self.embedding_type = embedding_type
            if not self.embedding_type:
                raise ValueError(
                    "embedding_type must be specified either in __init__ or __call__"
                )
            self._load_model_and_utils()
            processed = self._preprocess_inputs(inputs)
            # Prepare ONNX inputs with matching batch sizes
            onnx_inputs = {}
            if self.embedding_type == "text":
                onnx_inputs["input_ids"] = np.stack([x[0] for x in processed])
            elif self.embedding_type == "vision":
                onnx_inputs["pixel_values"] = np.stack([x[0] for x in processed])
            else:
                raise ValueError("Invalid embedding type")
            # Run inference
            text_embeddings, image_embeddings = self.runner.run(onnx_inputs)
            if self.embedding_type == "text":
                embeddings = text_embeddings  # text embeddings
            elif self.embedding_type == "vision":
                embeddings = image_embeddings  # image embeddings
            else:
                raise ValueError("Invalid embedding type")
            embeddings = self._postprocess_outputs(embeddings)
            return [embedding for embedding in embeddings]

View File

@ -0,0 +1,190 @@
"""AXEngine model loading utility for Frigate."""
import logging
import os
import time
from pathlib import Path
from frigate.comms.inter_process import InterProcessRequestor
from frigate.const import UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader
from frigate.util.file import FileLock
logger = logging.getLogger(__name__)
AXENGINE_JINA_V2_MODEL = "jina_v2"
AXENGINE_JINA_V2_REPO = "AXERA-TECH/jina-clip-v2"
def get_axengine_model_type(model_path: str) -> str | None:
if "jina-clip-v2" in str(model_path):
return AXENGINE_JINA_V2_MODEL
return None
def is_axengine_compatible(
model_path: str, device: str | None, model_type: str | None = None
) -> bool:
if (device or "").upper() != "AXENGINE":
return False
if not model_type:
model_type = get_axengine_model_type(model_path)
return model_type == AXENGINE_JINA_V2_MODEL
def wait_for_download_completion(
    image_model_path: Path,
    text_model_path: Path,
    lock_path: Path,
    timeout: int = 300,
) -> bool:
    """Poll until both AXEngine encoder files exist or *timeout* elapses.

    Another process (holding *lock_path*) is expected to be downloading the
    files; when the lock file disappears the download is considered finished
    and a final existence check decides the result.

    Args:
        image_model_path: expected path of ``image_encoder.axmodel``.
        text_model_path: expected path of ``text_encoder.axmodel``.
        lock_path: the downloader's cross-process lock file.
        timeout: maximum seconds to wait.

    Returns:
        True when both files are present, False on timeout or when the lock
        was released without producing both files.
    """
    # time.monotonic is immune to wall-clock adjustments, unlike time.time
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if image_model_path.exists() and text_model_path.exists():
            return True
        # Downloader released (or never held) the lock: trust the final state.
        if not lock_path.exists():
            return image_model_path.exists() and text_model_path.exists()
        time.sleep(1)
    logger.warning("Timeout waiting for AXEngine model files: %s", image_model_path)
    return False
def auto_convert_model(model_path: str, model_type: str | None = None) -> str | None:
    """Prepare AXEngine model files and return the image encoder path.

    Ensures ``image_encoder.axmodel`` and ``text_encoder.axmodel`` exist next
    to *model_path*, downloading them from HuggingFace under a cross-process
    file lock so only one Frigate process performs the download.  Progress is
    reported to the UI via UPDATE_MODEL_STATE messages.  Returns None when the
    model is not AXEngine compatible or preparation fails.
    """
    if not is_axengine_compatible(model_path, "AXENGINE", model_type):
        return None
    model_dir = Path(model_path).parent
    # UI state keys: reuse the jinaai/* names so the existing frontend
    # download indicators keep working for the AXEngine variant.
    ui_model_key = f"jinaai/jina-clip-v2-{Path(model_path).name}"
    ui_preprocessor_key = "jinaai/jina-clip-v2-preprocessor_config.json"
    image_model_path = model_dir / "image_encoder.axmodel"
    text_model_path = model_dir / "text_encoder.axmodel"
    model_repo = os.environ.get("AXENGINE_JINA_V2_REPO", AXENGINE_JINA_V2_REPO)
    hf_endpoint = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
    requestor = InterProcessRequestor()
    download_targets = {
        "image_encoder.axmodel": f"{hf_endpoint}/{model_repo}/resolve/main/image_encoder.axmodel",
        "text_encoder.axmodel": f"{hf_endpoint}/{model_repo}/resolve/main/text_encoder.axmodel",
    }
    # Fast path: both encoders already on disk — just mark everything done.
    if image_model_path.exists() and text_model_path.exists():
        requestor.send_data(
            UPDATE_MODEL_STATE,
            {
                "model": ui_preprocessor_key,
                "state": ModelStatusTypesEnum.downloaded,
            },
        )
        requestor.send_data(
            UPDATE_MODEL_STATE,
            {
                "model": ui_model_key,
                "state": ModelStatusTypesEnum.downloaded,
            },
        )
        requestor.stop()
        return str(image_model_path)
    lock_path = model_dir / ".axengine.download.lock"
    lock = FileLock(lock_path, timeout=300, cleanup_stale_on_init=True)
    if lock.acquire():
        # We own the lock: download any missing files ourselves.
        try:
            # no separate preprocessor config is needed on AXEngine,
            # so report it as already downloaded
            requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": ui_preprocessor_key,
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
            requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": ui_model_key,
                    "state": ModelStatusTypesEnum.downloading,
                },
            )
            for file_name, url in download_targets.items():
                target_path = model_dir / file_name
                if target_path.exists():
                    continue
                target_path.parent.mkdir(parents=True, exist_ok=True)
                ModelDownloader.download_from_url(url, str(target_path))
            requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": ui_model_key,
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
            return str(image_model_path)
        except Exception:
            requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": ui_model_key,
                    "state": ModelStatusTypesEnum.error,
                },
            )
            logger.exception(
                "Failed to prepare AXEngine model files for %s", model_repo
            )
            return None
        finally:
            # always release IPC and the cross-process lock, success or not
            requestor.stop()
            lock.release()
    # Lock held by another process: wait for its download to finish.
    logger.info("Another process is preparing AXEngine models, waiting for completion")
    requestor.send_data(
        UPDATE_MODEL_STATE,
        {
            "model": ui_preprocessor_key,
            "state": ModelStatusTypesEnum.downloaded,
        },
    )
    requestor.send_data(
        UPDATE_MODEL_STATE,
        {
            "model": ui_model_key,
            "state": ModelStatusTypesEnum.downloading,
        },
    )
    requestor.stop()
    if wait_for_download_completion(image_model_path, text_model_path, lock_path):
        if image_model_path.exists() and text_model_path.exists():
            # fresh requestor: the previous one was stopped before waiting
            requestor = InterProcessRequestor()
            requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": ui_model_key,
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
            requestor.stop()
            return str(image_model_path)
    logger.error("Timeout waiting for AXEngine model download lock for %s", model_dir)
    requestor = InterProcessRequestor()
    requestor.send_data(
        UPDATE_MODEL_STATE,
        {
            "model": ui_model_key,
            "state": ModelStatusTypesEnum.error,
        },
    )
    requestor.stop()
    return None

View File

@ -292,46 +292,32 @@ export default function Explore() {
const modelVersion = config?.semantic_search.model || "jinav1"; const modelVersion = config?.semantic_search.model || "jinav1";
const modelSize = config?.semantic_search.model_size || "small"; const modelSize = config?.semantic_search.model_size || "small";
const isAxJinaV2 = useMemo(
() =>
modelVersion === "jinav2" &&
Object.values(
(config?.detectors ?? {}) as Record<string, { type?: string }>,
).some((detector) => detector?.type === "axengine"),
[modelVersion, config?.detectors],
);
// Text model state // Text model state
const { payload: textModelState } = useModelState( const { payload: textModelState } = useModelState(
isAxJinaV2 modelVersion === "jinav1"
? "AXERA-TECH/jina-clip-v2-text_encoder.axmodel" ? "jinaai/jina-clip-v1-text_model_fp16.onnx"
: modelVersion === "jinav1" : modelSize === "large"
? "jinaai/jina-clip-v1-text_model_fp16.onnx" ? "jinaai/jina-clip-v2-model_fp16.onnx"
: modelSize === "large" : "jinaai/jina-clip-v2-model_quantized.onnx",
? "jinaai/jina-clip-v2-model_fp16.onnx"
: "jinaai/jina-clip-v2-model_quantized.onnx",
); );
// Tokenizer state // Tokenizer state
const { payload: textTokenizerState } = useModelState( const { payload: textTokenizerState } = useModelState(
isAxJinaV2 modelVersion === "jinav1"
? "AXERA-TECH/jina-clip-v2-tokenizer" ? "jinaai/jina-clip-v1-tokenizer"
: modelVersion === "jinav1" : "jinaai/jina-clip-v2-tokenizer",
? "jinaai/jina-clip-v1-tokenizer"
: "jinaai/jina-clip-v2-tokenizer",
); );
// Vision model state (same as text model for jinav2) // Vision model state (same as text model for jinav2)
const visionModelFile = const visionModelFile =
isAxJinaV2 modelVersion === "jinav1"
? "AXERA-TECH/jina-clip-v2-image_encoder.axmodel" ? modelSize === "large"
: modelVersion === "jinav1" ? "jinaai/jina-clip-v1-vision_model_fp16.onnx"
? modelSize === "large" : "jinaai/jina-clip-v1-vision_model_quantized.onnx"
? "jinaai/jina-clip-v1-vision_model_fp16.onnx" : modelSize === "large"
: "jinaai/jina-clip-v1-vision_model_quantized.onnx" ? "jinaai/jina-clip-v2-model_fp16.onnx"
: modelSize === "large" : "jinaai/jina-clip-v2-model_quantized.onnx";
? "jinaai/jina-clip-v2-model_fp16.onnx"
: "jinaai/jina-clip-v2-model_quantized.onnx";
const { payload: visionModelState } = useModelState(visionModelFile); const { payload: visionModelState } = useModelState(visionModelFile);
// Preprocessor/feature extractor state // Preprocessor/feature extractor state
@ -346,10 +332,9 @@ export default function Explore() {
textModelState === "downloaded" && textModelState === "downloaded" &&
textTokenizerState === "downloaded" && textTokenizerState === "downloaded" &&
visionModelState === "downloaded" && visionModelState === "downloaded" &&
(isAxJinaV2 || visionFeatureExtractorState === "downloaded") visionFeatureExtractorState === "downloaded"
); );
}, [ }, [
isAxJinaV2,
textModelState, textModelState,
textTokenizerState, textTokenizerState,
visionModelState, visionModelState,
@ -376,7 +361,7 @@ export default function Explore() {
!textModelState || !textModelState ||
!textTokenizerState || !textTokenizerState ||
!visionModelState || !visionModelState ||
(!isAxJinaV2 && !visionFeatureExtractorState))) !visionFeatureExtractorState))
) { ) {
return ( return (
<ActivityIndicator className="absolute left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2" /> <ActivityIndicator className="absolute left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2" />
@ -466,14 +451,12 @@ export default function Explore() {
"exploreIsUnavailable.downloadingModels.setup.visionModel", "exploreIsUnavailable.downloadingModels.setup.visionModel",
)} )}
</div> </div>
{!isAxJinaV2 && ( <div className="flex flex-row items-center justify-center gap-2">
<div className="flex flex-row items-center justify-center gap-2"> {renderModelStateIcon(visionFeatureExtractorState)}
{renderModelStateIcon(visionFeatureExtractorState)} {t(
{t( "exploreIsUnavailable.downloadingModels.setup.visionModelFeatureExtractor",
"exploreIsUnavailable.downloadingModels.setup.visionModelFeatureExtractor", )}
)} </div>
</div>
)}
<div className="flex flex-row items-center justify-center gap-2"> <div className="flex flex-row items-center justify-center gap-2">
{renderModelStateIcon(textModelState)} {renderModelStateIcon(textModelState)}
{t( {t(
@ -490,7 +473,7 @@ export default function Explore() {
{(textModelState === "error" || {(textModelState === "error" ||
textTokenizerState === "error" || textTokenizerState === "error" ||
visionModelState === "error" || visionModelState === "error" ||
(!isAxJinaV2 && visionFeatureExtractorState === "error")) && ( visionFeatureExtractorState === "error") && (
<div className="my-3 max-w-96 text-center text-danger"> <div className="my-3 max-w-96 text-center text-danger">
{t("exploreIsUnavailable.downloadingModels.error")} {t("exploreIsUnavailable.downloadingModels.error")}
</div> </div>