Merge pull request #7 from ivanshi1108/dev-ax

Refactor: Replace the ax_jinav2 model type with the axengine detector…
This commit is contained in:
GuoQing Liu 2026-03-07 02:29:35 +08:00 committed by GitHub
commit 148221b9eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 340 additions and 356 deletions

View File

@ -19,7 +19,6 @@ __all__ = [
class SemanticSearchModelEnum(str, Enum):
jinav1 = "jinav1"
jinav2 = "jinav2"
ax_jinav2 = "ax_jinav2"
class EnrichmentsDeviceEnum(str, Enum):

View File

@ -10,6 +10,10 @@ from typing import Any
import numpy as np
import onnxruntime as ort
from frigate.util.axengine_converter import (
auto_convert_model as auto_load_axengine_model,
)
from frigate.util.axengine_converter import is_axengine_compatible
from frigate.util.model import get_ort_providers
from frigate.util.rknn_converter import auto_convert_model, is_rknn_compatible
@ -548,12 +552,135 @@ class RKNNModelRunner(BaseModelRunner):
pass
class AXEngineModelRunner(BaseModelRunner):
    """Run AXEngine (Axera NPU) models for CLIP-style embeddings.

    Expects ``image_encoder.axmodel``, ``text_encoder.axmodel`` and a
    ``tokenizer`` directory to live next to the supplied model path.  A lock
    serializes inference because text and vision requests share this runner.
    """

    # CLIP-style normalization stats, shaped (1, 3, 1, 1) for NCHW broadcasting.
    _mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32).reshape(
        1, 3, 1, 1
    )
    _std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32).reshape(
        1, 3, 1, 1
    )

    # Fixed token-sequence length the text encoder expects.
    _TEXT_SEQ_LEN = 50
    # Callers pass an all-zero input_ids tensor of this width as a "no text"
    # placeholder — presumably matching the ONNX preprocessing; TODO confirm.
    _DUMMY_TEXT_LEN = 16

    def __init__(self, model_path: str, model_type: str | None = None):
        self.model_path = model_path
        self.model_type = model_type
        self._inference_lock = threading.Lock()
        self.image_session = None
        self.text_session = None
        # Fallback pad id used when the tokenizer cannot be loaded.
        self.text_pad_token_id = 0
        self._load_model()

    def _load_model(self) -> None:
        """Open both AXEngine sessions and read the tokenizer's pad token id.

        Raises:
            ImportError: if the axengine runtime is not installed.
        """
        try:
            import axengine as axe
            from transformers import AutoTokenizer
        except ImportError:
            logger.error("AXEngine is not available")
            raise ImportError("AXEngine is not available")

        model_dir = os.path.dirname(self.model_path)
        image_model_path = os.path.join(model_dir, "image_encoder.axmodel")
        text_model_path = os.path.join(model_dir, "text_encoder.axmodel")
        tokenizer_path = os.path.join(model_dir, "tokenizer")
        self.image_session = axe.InferenceSession(image_model_path)
        self.text_session = axe.InferenceSession(text_model_path)

        # Loading the tokenizer is best-effort: inference still works with
        # the default pad id of 0.
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_path,
                trust_remote_code=True,
                clean_up_tokenization_spaces=True,
            )
            if tokenizer.pad_token_id is not None:
                self.text_pad_token_id = int(tokenizer.pad_token_id)
        except Exception:
            logger.warning(
                "Failed to load tokenizer from %s for AXEngine padding, using 0",
                tokenizer_path,
            )

    def get_input_names(self) -> list[str]:
        return ["input_ids", "pixel_values"]

    def get_input_width(self) -> int:
        # Jina CLIP v2 AXEngine image encoder takes 512x512 input.
        return 512

    @staticmethod
    def _has_real_text_inputs(inputs: dict[str, Any]) -> bool:
        """Return True when input_ids carry actual tokens (not the zero dummy)."""
        input_ids = inputs.get("input_ids")
        if input_ids is None:
            return False
        if input_ids.ndim < 2:
            return False
        # A (batch, _DUMMY_TEXT_LEN) all-zero tensor is the "no text" marker.
        # bool() ensures a Python bool rather than np.bool_ from np.any().
        return bool(
            input_ids.shape[-1] != AXEngineModelRunner._DUMMY_TEXT_LEN
            or np.any(input_ids)
        )

    @staticmethod
    def _has_real_image_inputs(inputs: dict[str, Any]) -> bool:
        """Return True when pixel_values are present and not all-zero."""
        pixel_values = inputs.get("pixel_values")
        # bool() ensures a Python bool rather than np.bool_ from np.any().
        return bool(pixel_values is not None and np.any(pixel_values))

    def _prepare_text_inputs(self, input_ids: np.ndarray) -> np.ndarray:
        """Pad/truncate token ids to the encoder's fixed (1, 50) int32 shape."""
        seq_len = self._TEXT_SEQ_LEN
        padded_input_ids = np.full((1, seq_len), self.text_pad_token_id, dtype=np.int32)
        truncated_input_ids = input_ids.reshape(1, -1)[:, :seq_len].astype(np.int32)
        padded_input_ids[:, : truncated_input_ids.shape[1]] = truncated_input_ids
        return padded_input_ids

    @classmethod
    def _prepare_pixel_values(cls, pixel_values: np.ndarray) -> np.ndarray:
        """Add a batch dim if needed and apply CLIP mean/std normalization."""
        if len(pixel_values.shape) == 3:
            pixel_values = pixel_values[None, ...]
        pixel_values = pixel_values.astype(np.float32)
        return (pixel_values - cls._mean) / cls._std

    def run(self, inputs: dict[str, Any]) -> list[np.ndarray | None]:
        """Run the text and/or image encoder over the batched inputs.

        Returns a 4-slot list with text embeddings in slot 2 and image
        embeddings in slot 3 (unused slots stay None) — presumably matching
        the ONNX runner's output ordering; verify against callers.
        """
        outputs: list[np.ndarray | None] = [None, None, None, None]
        with self._inference_lock:
            if self._has_real_text_inputs(inputs):
                text_embeddings = []
                for input_ids in inputs["input_ids"]:
                    text_embeddings.append(
                        self.text_session.run(
                            None,
                            # "inputs_id" (sic) is the axmodel's input name.
                            {"inputs_id": self._prepare_text_inputs(input_ids)},
                        )[0][0]
                    )
                outputs[2] = np.array(text_embeddings)
            if self._has_real_image_inputs(inputs):
                image_embeddings = []
                for pixel_values in inputs["pixel_values"]:
                    image_embeddings.append(
                        self.image_session.run(
                            None,
                            {"pixel_values": self._prepare_pixel_values(pixel_values)},
                        )[0][0]
                    )
                outputs[3] = np.array(image_embeddings)
        return outputs
def get_optimized_runner(
model_path: str, device: str | None, model_type: str, **kwargs
) -> BaseModelRunner:
"""Get an optimized runner for the hardware."""
device = device or "AUTO"
if is_axengine_compatible(model_path, device, model_type):
axmodel_path = auto_load_axengine_model(model_path, model_type)
if axmodel_path:
return AXEngineModelRunner(axmodel_path, model_type)
if device != "CPU" and is_rknn_compatible(model_path):
rknn_path = auto_convert_model(model_path)

View File

@ -30,7 +30,6 @@ from frigate.util.file import get_event_thumbnail_bytes
from .onnx.jina_v1_embedding import JinaV1ImageEmbedding, JinaV1TextEmbedding
from .onnx.jina_v2_embedding import JinaV2Embedding
from .onnx.jina_v2_embedding_ax import AXJinaV2Embedding
logger = logging.getLogger(__name__)
@ -119,18 +118,6 @@ class Embeddings:
self.vision_embedding = lambda input_data: self.embedding(
input_data, embedding_type="vision"
)
elif self.config.semantic_search.model == SemanticSearchModelEnum.ax_jinav2:
# AXJinaV2Embedding instance for both text and vision
self.embedding = AXJinaV2Embedding(
model_size=self.config.semantic_search.model_size,
requestor=self.requestor,
)
self.text_embedding = lambda input_data: self.embedding(
input_data, embedding_type="text"
)
self.vision_embedding = lambda input_data: self.embedding(
input_data, embedding_type="vision"
)
else: # Default to jinav1
self.text_embedding = JinaV1TextEmbedding(
model_size=config.semantic_search.model_size,

View File

@ -37,13 +37,18 @@ class JinaV2Embedding(BaseEmbedding):
"model_fp16.onnx" if model_size == "large" else "model_quantized.onnx"
)
HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
use_axengine = (device or "").upper() == "AXENGINE"
super().__init__(
model_name="jinaai/jina-clip-v2",
model_file=model_file,
download_urls={
model_file: f"{HF_ENDPOINT}/jinaai/jina-clip-v2/resolve/main/onnx/{model_file}",
"preprocessor_config.json": f"{HF_ENDPOINT}/jinaai/jina-clip-v2/resolve/main/preprocessor_config.json",
},
download_urls=(
{}
if use_axengine
else {
model_file: f"{HF_ENDPOINT}/jinaai/jina-clip-v2/resolve/main/onnx/{model_file}",
"preprocessor_config.json": f"{HF_ENDPOINT}/jinaai/jina-clip-v2/resolve/main/preprocessor_config.json",
}
),
)
self.tokenizer_file = "tokenizer"
self.embedding_type = embedding_type
@ -59,7 +64,11 @@ class JinaV2Embedding(BaseEmbedding):
self._call_lock = threading.Lock()
# download the model and tokenizer
files_names = list(self.download_urls.keys()) + [self.tokenizer_file]
files_names = (
[self.tokenizer_file]
if use_axengine
else list(self.download_urls.keys()) + [self.tokenizer_file]
)
if not all(
os.path.exists(os.path.join(self.download_path, n)) for n in files_names
):

View File

@ -1,281 +0,0 @@
"""AX JinaV2 Embeddings."""
import io
import logging
import os
import threading
from typing import Any
import numpy as np
from PIL import Image
from transformers import AutoTokenizer
from transformers.utils.logging import disable_progress_bar, set_verbosity_error
from frigate.const import MODEL_CACHE_DIR
from frigate.embeddings.onnx.base_embedding import BaseEmbedding
from frigate.comms.inter_process import InterProcessRequestor
from frigate.util.downloader import ModelDownloader
from frigate.types import ModelStatusTypesEnum
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
import axengine as axe
# disables the progress bar and download logging for downloading tokenizers and image processors
disable_progress_bar()
set_verbosity_error()
logger = logging.getLogger(__name__)
class AXClipRunner:
    """Thin wrapper around the AXEngine image/text encoder session pair.

    Opens both .axmodel sessions eagerly and logs their I/O specs to help
    diagnose shape mismatches.
    """

    def __init__(self, image_encoder_path: str, text_encoder_path: str):
        self.image_encoder_path = image_encoder_path
        self.text_encoder_path = text_encoder_path
        self.image_encoder_runner = axe.InferenceSession(image_encoder_path)
        self.text_encoder_runner = axe.InferenceSession(text_encoder_path)
        # Log model I/O signatures (loop vars renamed from `input`/`output`
        # to avoid shadowing the builtins).
        for spec in self.image_encoder_runner.get_inputs():
            logger.info(f"{spec.name} {spec.shape} {spec.dtype}")
        for spec in self.image_encoder_runner.get_outputs():
            logger.info(f"{spec.name} {spec.shape} {spec.dtype}")
        for spec in self.text_encoder_runner.get_inputs():
            logger.info(f"{spec.name} {spec.shape} {spec.dtype}")
        for spec in self.text_encoder_runner.get_outputs():
            logger.info(f"{spec.name} {spec.shape} {spec.dtype}")

    def run(self, onnx_inputs):
        """Run whichever encoders have inputs present.

        Args:
            onnx_inputs: dict optionally containing "input_ids" (batch of
                token-id rows) and/or "pixel_values" (batch of CHW images).

        Returns:
            Tuple of (text_embeddings, image_embeddings) numpy arrays; an
            encoder that received no inputs yields an empty array.
        """
        text_embeddings = []
        image_embeddings = []
        if "input_ids" in onnx_inputs:
            for input_ids in onnx_inputs["input_ids"]:
                input_ids = input_ids.reshape(1, -1)
                text_embeddings.append(
                    # "inputs_id" (sic) is the axmodel's declared input name.
                    self.text_encoder_runner.run(None, {"inputs_id": input_ids})[0][0]
                )
        if "pixel_values" in onnx_inputs:
            for pixel_values in onnx_inputs["pixel_values"]:
                if len(pixel_values.shape) == 3:
                    # Single image: add batch dim -> (1, C, H, W).
                    pixel_values = pixel_values[None, ...]
                image_embeddings.append(
                    self.image_encoder_runner.run(None, {"pixel_values": pixel_values})[
                        0
                    ][0]
                )
        return np.array(text_embeddings), np.array(image_embeddings)
class AXJinaV2Embedding(BaseEmbedding):
    """Jina CLIP v2 embeddings running on Axera AXEngine NPUs.

    Downloads the pre-compiled .axmodel encoders from the AXERA-TECH hub
    repo (and the tokenizer from the original jinaai repo), then serves both
    text and vision embeddings from a single shared instance.
    """

    def __init__(
        self,
        model_size: str,
        requestor: InterProcessRequestor,
        device: str = "AUTO",
        embedding_type: str | None = None,
    ):
        HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
        super().__init__(
            model_name="AXERA-TECH/jina-clip-v2",
            model_file=None,
            download_urls={
                "image_encoder.axmodel": f"{HF_ENDPOINT}/AXERA-TECH/jina-clip-v2/resolve/main/image_encoder.axmodel",
                "text_encoder.axmodel": f"{HF_ENDPOINT}/AXERA-TECH/jina-clip-v2/resolve/main/text_encoder.axmodel",
            },
        )
        # Tokenizer comes from the original jinaai repo, not the AXERA one.
        self.tokenizer_source = "jinaai/jina-clip-v2"
        self.tokenizer_file = "tokenizer"
        self.embedding_type = embedding_type
        self.requestor = requestor
        self.model_size = model_size
        self.device = device
        self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
        self.tokenizer = None
        self.image_processor = None
        self.runner = None
        # CLIP normalization stats; broadcast over the channel axis of an
        # (H, W, C) float image in _preprocess_image.
        self.mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
        self.std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)
        # Lock to prevent concurrent calls (text and vision share this instance)
        self._call_lock = threading.Lock()
        # download the model and tokenizer
        files_names = list(self.download_urls.keys()) + [self.tokenizer_file]
        if not all(
            os.path.exists(os.path.join(self.download_path, n)) for n in files_names
        ):
            logger.debug(f"starting model download for {self.model_name}")
            self.downloader = ModelDownloader(
                model_name=self.model_name,
                download_path=self.download_path,
                file_names=files_names,
                download_func=self._download_model,
            )
            self.downloader.ensure_model_files()
            # Avoid lazy loading in worker threads: block until downloads complete
            # and load the model on the main thread during initialization.
            self._load_model_and_utils()
        else:
            self.downloader = None
            # Everything already on disk: publish "downloaded" state for the UI.
            ModelDownloader.mark_files_state(
                self.requestor,
                self.model_name,
                files_names,
                ModelStatusTypesEnum.downloaded,
            )
            self._load_model_and_utils()
            logger.debug(f"models are already downloaded for {self.model_name}")

    def _download_model(self, path: str):
        """Fetch one model file (or save the tokenizer) and publish its state.

        NOTE(review): if an exception fired before ``file_name`` was bound,
        the except handler would raise UnboundLocalError — presumably
        ``os.path.basename`` never fails in practice; confirm.
        """
        try:
            file_name = os.path.basename(path)
            if file_name in self.download_urls:
                ModelDownloader.download_from_url(self.download_urls[file_name], path)
            elif file_name == self.tokenizer_file:
                tokenizer = AutoTokenizer.from_pretrained(
                    self.tokenizer_source,
                    trust_remote_code=True,
                    cache_dir=os.path.join(
                        MODEL_CACHE_DIR, self.model_name, "tokenizer"
                    ),
                    clean_up_tokenization_spaces=True,
                )
                tokenizer.save_pretrained(path)
            self.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file_name}",
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
        except Exception:
            self.requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": f"{self.model_name}-{file_name}",
                    "state": ModelStatusTypesEnum.error,
                },
            )

    def _load_model_and_utils(self):
        """Load tokenizer and AXEngine runner once; no-op on later calls."""
        if self.runner is None:
            if self.downloader:
                # Block until the background download finishes.
                self.downloader.wait_for_download()
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_source,
                cache_dir=os.path.join(MODEL_CACHE_DIR, self.model_name, "tokenizer"),
                trust_remote_code=True,
                clean_up_tokenization_spaces=True,
            )
            self.runner = AXClipRunner(
                os.path.join(self.download_path, "image_encoder.axmodel"),
                os.path.join(self.download_path, "text_encoder.axmodel"),
            )

    def _preprocess_image(self, image_data: bytes | Image.Image):
        """
        Manually preprocess a single image from bytes or PIL.Image to (3, 512, 512).
        """
        if isinstance(image_data, bytes):
            image = Image.open(io.BytesIO(image_data))
        else:
            image = image_data
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = image.resize((512, 512), Image.Resampling.LANCZOS)
        # Convert to numpy array, normalize to [0, 1], and transpose to (channels, height, width)
        image_array = np.array(image, dtype=np.float32) / 255.0
        # Normalize using mean and std
        image_array = (image_array - self.mean) / self.std
        image_array = np.transpose(image_array, (2, 0, 1))  # (H, W, C) -> (C, H, W)
        return image_array

    def _preprocess_inputs(self, raw_inputs):
        """
        Preprocess inputs into a list of real input tensors (no dummies).
        - For text: Returns list of input_ids.
        - For vision: Returns list of pixel_values.

        Raises:
            ValueError: if embedding_type is neither 'text' nor 'vision'.
        """
        if not isinstance(raw_inputs, list):
            raw_inputs = [raw_inputs]
        processed = []
        if self.embedding_type == "text":
            for text in raw_inputs:
                # Fixed-length padding to 50 tokens, int32 for the axmodel.
                input_ids = self.tokenizer(
                    [text], return_tensors="np", padding="max_length", max_length=50
                )["input_ids"]
                input_ids = input_ids.astype(np.int32)
                processed.append(input_ids)
        elif self.embedding_type == "vision":
            for img in raw_inputs:
                pixel_values = self._preprocess_image(img)
                processed.append(
                    pixel_values[np.newaxis, ...]
                )  # Add batch dim: (1, 3, 512, 512)
        else:
            raise ValueError(
                f"Invalid embedding_type: {self.embedding_type}. Must be 'text' or 'vision'."
            )
        return processed

    def _postprocess_outputs(self, outputs):
        """
        Process ONNX model outputs, truncating each embedding in the array to truncate_dim.
        - outputs: NumPy array of embeddings.
        - Returns: List of truncated embeddings.
        """
        # size of vector in database
        truncate_dim = 768
        # jina v2 defaults to 1024 and uses Matryoshka representation, so
        # truncating only causes an extremely minor decrease in retrieval accuracy
        if outputs.shape[-1] > truncate_dim:
            outputs = outputs[..., :truncate_dim]
        return outputs

    def __call__(
        self, inputs: list[str] | list[Image.Image], embedding_type=None
    ):
        """Embed a batch of texts or images; returns a list of vectors.

        Raises:
            ValueError: if no embedding_type was set here or in __init__.
        """
        # Lock the entire call to prevent race conditions when text and vision
        # embeddings are called concurrently from different threads
        with self._call_lock:
            self.embedding_type = embedding_type
            if not self.embedding_type:
                raise ValueError(
                    "embedding_type must be specified either in __init__ or __call__"
                )
            self._load_model_and_utils()
            processed = self._preprocess_inputs(inputs)
            # Prepare ONNX inputs with matching batch sizes
            onnx_inputs = {}
            if self.embedding_type == "text":
                onnx_inputs["input_ids"] = np.stack([x[0] for x in processed])
            elif self.embedding_type == "vision":
                onnx_inputs["pixel_values"] = np.stack([x[0] for x in processed])
            else:
                raise ValueError("Invalid embedding type")
            # Run inference
            text_embeddings, image_embeddings = self.runner.run(onnx_inputs)
            if self.embedding_type == "text":
                embeddings = text_embeddings  # text embeddings
            elif self.embedding_type == "vision":
                embeddings = image_embeddings  # image embeddings
            else:
                raise ValueError("Invalid embedding type")
            embeddings = self._postprocess_outputs(embeddings)
            return [embedding for embedding in embeddings]

View File

@ -0,0 +1,190 @@
"""AXEngine model loading utility for Frigate."""
import logging
import os
import time
from pathlib import Path
from frigate.comms.inter_process import InterProcessRequestor
from frigate.const import UPDATE_MODEL_STATE
from frigate.types import ModelStatusTypesEnum
from frigate.util.downloader import ModelDownloader
from frigate.util.file import FileLock
logger = logging.getLogger(__name__)
AXENGINE_JINA_V2_MODEL = "jina_v2"
AXENGINE_JINA_V2_REPO = "AXERA-TECH/jina-clip-v2"
def get_axengine_model_type(model_path: str) -> str | None:
if "jina-clip-v2" in str(model_path):
return AXENGINE_JINA_V2_MODEL
return None
def is_axengine_compatible(
model_path: str, device: str | None, model_type: str | None = None
) -> bool:
if (device or "").upper() != "AXENGINE":
return False
if not model_type:
model_type = get_axengine_model_type(model_path)
return model_type == AXENGINE_JINA_V2_MODEL
def wait_for_download_completion(
    image_model_path: Path,
    text_model_path: Path,
    lock_path: Path,
    timeout: int = 300,
) -> bool:
    """Poll until both encoder files exist, the lock vanishes, or timeout.

    Returns True when both files are present; False on timeout or when the
    downloading process released the lock without producing both files.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        if image_model_path.exists() and text_model_path.exists():
            return True
        if not lock_path.exists():
            # Downloader released (or never held) the lock; report whatever
            # is on disk right now.
            return image_model_path.exists() and text_model_path.exists()
        time.sleep(1)
    logger.warning("Timeout waiting for AXEngine model files: %s", image_model_path)
    return False
def auto_convert_model(model_path: str, model_type: str | None = None) -> str | None:
    """Prepare AXEngine model files and return the image encoder path.

    Downloads image_encoder.axmodel and text_encoder.axmodel into the
    directory containing ``model_path``, guarded by a file lock so only one
    process downloads while others wait.  Progress is published over IPC.
    Returns the image encoder path on success, otherwise None.
    """
    if not is_axengine_compatible(model_path, "AXENGINE", model_type):
        return None
    model_dir = Path(model_path).parent
    # UI state keys reuse the ONNX jina-clip-v2 names — presumably so progress
    # appears in the web UI's existing model-state slots; verify against UI.
    ui_model_key = f"jinaai/jina-clip-v2-{Path(model_path).name}"
    ui_preprocessor_key = "jinaai/jina-clip-v2-preprocessor_config.json"
    image_model_path = model_dir / "image_encoder.axmodel"
    text_model_path = model_dir / "text_encoder.axmodel"
    # Repo and endpoint are overridable via environment (e.g. for mirrors).
    model_repo = os.environ.get("AXENGINE_JINA_V2_REPO", AXENGINE_JINA_V2_REPO)
    hf_endpoint = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
    requestor = InterProcessRequestor()
    download_targets = {
        "image_encoder.axmodel": f"{hf_endpoint}/{model_repo}/resolve/main/image_encoder.axmodel",
        "text_encoder.axmodel": f"{hf_endpoint}/{model_repo}/resolve/main/text_encoder.axmodel",
    }
    # Fast path: both encoders already on disk — just report and return.
    if image_model_path.exists() and text_model_path.exists():
        requestor.send_data(
            UPDATE_MODEL_STATE,
            {
                "model": ui_preprocessor_key,
                "state": ModelStatusTypesEnum.downloaded,
            },
        )
        requestor.send_data(
            UPDATE_MODEL_STATE,
            {
                "model": ui_model_key,
                "state": ModelStatusTypesEnum.downloaded,
            },
        )
        requestor.stop()
        return str(image_model_path)
    lock_path = model_dir / ".axengine.download.lock"
    lock = FileLock(lock_path, timeout=300, cleanup_stale_on_init=True)
    if lock.acquire():
        # We won the lock: download any missing files ourselves.
        try:
            requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": ui_preprocessor_key,
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
            requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": ui_model_key,
                    "state": ModelStatusTypesEnum.downloading,
                },
            )
            for file_name, url in download_targets.items():
                target_path = model_dir / file_name
                if target_path.exists():
                    # Keep files a previous (partial) run already fetched.
                    continue
                target_path.parent.mkdir(parents=True, exist_ok=True)
                ModelDownloader.download_from_url(url, str(target_path))
            requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": ui_model_key,
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
            return str(image_model_path)
        except Exception:
            requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": ui_model_key,
                    "state": ModelStatusTypesEnum.error,
                },
            )
            logger.exception(
                "Failed to prepare AXEngine model files for %s", model_repo
            )
            return None
        finally:
            # Runs on both the success return and the error path: always stop
            # the IPC client and release the download lock.
            requestor.stop()
            lock.release()
    # Lock held by another process: mirror the downloading state, then poll.
    logger.info("Another process is preparing AXEngine models, waiting for completion")
    requestor.send_data(
        UPDATE_MODEL_STATE,
        {
            "model": ui_preprocessor_key,
            "state": ModelStatusTypesEnum.downloaded,
        },
    )
    requestor.send_data(
        UPDATE_MODEL_STATE,
        {
            "model": ui_model_key,
            "state": ModelStatusTypesEnum.downloading,
        },
    )
    # Stop before the (potentially long) blocking wait below.
    requestor.stop()
    if wait_for_download_completion(image_model_path, text_model_path, lock_path):
        if image_model_path.exists() and text_model_path.exists():
            # Re-check before reporting; a fresh requestor is needed since
            # the previous one was stopped ahead of the wait.
            requestor = InterProcessRequestor()
            requestor.send_data(
                UPDATE_MODEL_STATE,
                {
                    "model": ui_model_key,
                    "state": ModelStatusTypesEnum.downloaded,
                },
            )
            requestor.stop()
        return str(image_model_path)
    logger.error("Timeout waiting for AXEngine model download lock for %s", model_dir)
    requestor = InterProcessRequestor()
    requestor.send_data(
        UPDATE_MODEL_STATE,
        {
            "model": ui_model_key,
            "state": ModelStatusTypesEnum.error,
        },
    )
    requestor.stop()
    return None

View File

@ -292,13 +292,10 @@ export default function Explore() {
const modelVersion = config?.semantic_search.model || "jinav1";
const modelSize = config?.semantic_search.model_size || "small";
const isAxJinaV2 = modelVersion === "ax_jinav2";
// Text model state
const { payload: textModelState } = useModelState(
isAxJinaV2
? "AXERA-TECH/jina-clip-v2-text_encoder.axmodel"
: modelVersion === "jinav1"
modelVersion === "jinav1"
? "jinaai/jina-clip-v1-text_model_fp16.onnx"
: modelSize === "large"
? "jinaai/jina-clip-v2-model_fp16.onnx"
@ -307,18 +304,14 @@ export default function Explore() {
// Tokenizer state
const { payload: textTokenizerState } = useModelState(
isAxJinaV2
? "AXERA-TECH/jina-clip-v2-tokenizer"
: modelVersion === "jinav1"
modelVersion === "jinav1"
? "jinaai/jina-clip-v1-tokenizer"
: "jinaai/jina-clip-v2-tokenizer",
);
// Vision model state (same as text model for jinav2)
const visionModelFile =
isAxJinaV2
? "AXERA-TECH/jina-clip-v2-image_encoder.axmodel"
: modelVersion === "jinav1"
modelVersion === "jinav1"
? modelSize === "large"
? "jinaai/jina-clip-v1-vision_model_fp16.onnx"
: "jinaai/jina-clip-v1-vision_model_quantized.onnx"
@ -328,49 +321,13 @@ export default function Explore() {
const { payload: visionModelState } = useModelState(visionModelFile);
// Preprocessor/feature extractor state
const { payload: visionFeatureExtractorStateRaw } = useModelState(
const { payload: visionFeatureExtractorState } = useModelState(
modelVersion === "jinav1"
? "jinaai/jina-clip-v1-preprocessor_config.json"
: "jinaai/jina-clip-v2-preprocessor_config.json",
);
const visionFeatureExtractorState = useMemo(() => {
if (isAxJinaV2) {
return visionModelState ?? "downloading";
}
return visionFeatureExtractorStateRaw;
}, [isAxJinaV2, visionModelState, visionFeatureExtractorStateRaw]);
const effectiveTextModelState = useMemo<ModelState | undefined>(() => {
if (isAxJinaV2) {
return textModelState ?? "downloading";
}
return textModelState;
}, [isAxJinaV2, textModelState]);
const effectiveTextTokenizerState = useMemo<ModelState | undefined>(() => {
if (isAxJinaV2) {
return textTokenizerState ?? "downloading";
}
return textTokenizerState;
}, [isAxJinaV2, textTokenizerState]);
const effectiveVisionModelState = useMemo<ModelState | undefined>(() => {
if (isAxJinaV2) {
return visionModelState ?? "downloading";
}
return visionModelState;
}, [isAxJinaV2, visionModelState]);
const allModelsLoaded = useMemo(() => {
if (isAxJinaV2) {
return (
effectiveTextModelState === "downloaded" &&
effectiveTextTokenizerState === "downloaded" &&
effectiveVisionModelState === "downloaded"
);
}
return (
textModelState === "downloaded" &&
textTokenizerState === "downloaded" &&
@ -378,10 +335,6 @@ export default function Explore() {
visionFeatureExtractorState === "downloaded"
);
}, [
isAxJinaV2,
effectiveTextModelState,
effectiveTextTokenizerState,
effectiveVisionModelState,
textModelState,
textTokenizerState,
visionModelState,
@ -405,10 +358,10 @@ export default function Explore() {
!defaultViewLoaded ||
(config?.semantic_search.enabled &&
(!reindexState ||
!(isAxJinaV2 ? effectiveTextModelState : textModelState) ||
!(isAxJinaV2 ? effectiveTextTokenizerState : textTokenizerState) ||
!(isAxJinaV2 ? effectiveVisionModelState : visionModelState) ||
(!isAxJinaV2 && !visionFeatureExtractorState)))
!textModelState ||
!textTokenizerState ||
!visionModelState ||
!visionFeatureExtractorState))
) {
return (
<ActivityIndicator className="absolute left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2" />

View File

@ -28,7 +28,7 @@ export interface FaceRecognitionConfig {
recognition_threshold: number;
}
export type SearchModel = "jinav1" | "jinav2" | "ax_jinav2";
export type SearchModel = "jinav1" | "jinav2";
export type SearchModelSize = "small" | "large";
export interface CameraConfig {