Use a thread lock around the JinaV2 call, since it sets multiple internal fields while executing

Nicolas Mowen 2025-12-27 06:28:39 -07:00
parent 3c5eb1aee5
commit 25b36a1a7a
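The race this commit guards against: __call__ writes the requested mode to self.embedding_type and then branches on that same field several times, so two threads sharing one JinaV2Embedding instance (one embedding text, one embedding images) can overwrite each other's mode mid-call. Below is a minimal sketch of the interleaving using a toy Embedder class, not the real one; the names and the sleep are illustrative only:

import threading
import time

class Embedder:
    """Toy stand-in: one shared mode field, no lock."""

    def __init__(self):
        self.embedding_type = None

    def __call__(self, inputs, embedding_type):
        self.embedding_type = embedding_type  # write to shared per-instance state
        time.sleep(0.01)  # stand-in for preprocessing work
        # Another thread may have overwritten self.embedding_type by now,
        # so a "text" request can be routed down the "vision" branch.
        if self.embedding_type == "text":
            return [("text", x) for x in inputs]
        return [("vision", x) for x in inputs]

embedder = Embedder()
results = {}

def worker(name, mode):
    results[name] = embedder(["a"], mode)

threads = [
    threading.Thread(target=worker, args=("t1", "text")),
    threading.Thread(target=worker, args=("t2", "vision")),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)  # "t1" usually comes back tagged "vision" -- the race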


@@ -3,6 +3,7 @@
 import io
 import logging
 import os
+import threading
 
 import numpy as np
 from PIL import Image
@@ -53,6 +54,11 @@ class JinaV2Embedding(BaseEmbedding):
         self.tokenizer = None
         self.image_processor = None
         self.runner = None
+
+        # Lock to prevent concurrent calls (text and vision share this instance)
+        self._call_lock = threading.Lock()
+
+        # download the model and tokenizer
         files_names = list(self.download_urls.keys()) + [self.tokenizer_file]
         if not all(
             os.path.exists(os.path.join(self.download_path, n)) for n in files_names
@@ -200,37 +206,40 @@ class JinaV2Embedding(BaseEmbedding):
     def __call__(
         self, inputs: list[str] | list[Image.Image] | list[str], embedding_type=None
     ) -> list[np.ndarray]:
-        self.embedding_type = embedding_type
-        if not self.embedding_type:
-            raise ValueError(
-                "embedding_type must be specified either in __init__ or __call__"
-            )
+        # Lock the entire call to prevent race conditions when text and vision
+        # embeddings are called concurrently from different threads
+        with self._call_lock:
+            self.embedding_type = embedding_type
+            if not self.embedding_type:
+                raise ValueError(
+                    "embedding_type must be specified either in __init__ or __call__"
+                )
 
-        self._load_model_and_utils()
-        processed = self._preprocess_inputs(inputs)
-        batch_size = len(processed)
+            self._load_model_and_utils()
+            processed = self._preprocess_inputs(inputs)
+            batch_size = len(processed)
 
-        # Prepare ONNX inputs with matching batch sizes
-        onnx_inputs = {}
-        if self.embedding_type == "text":
-            onnx_inputs["input_ids"] = np.stack([x[0] for x in processed])
-            onnx_inputs["pixel_values"] = np.zeros(
-                (batch_size, 3, 512, 512), dtype=np.float32
-            )
-        elif self.embedding_type == "vision":
-            onnx_inputs["input_ids"] = np.zeros((batch_size, 16), dtype=np.int64)
-            onnx_inputs["pixel_values"] = np.stack([x[0] for x in processed])
-        else:
-            raise ValueError("Invalid embedding type")
+            # Prepare ONNX inputs with matching batch sizes
+            onnx_inputs = {}
+            if self.embedding_type == "text":
+                onnx_inputs["input_ids"] = np.stack([x[0] for x in processed])
+                onnx_inputs["pixel_values"] = np.zeros(
+                    (batch_size, 3, 512, 512), dtype=np.float32
+                )
+            elif self.embedding_type == "vision":
+                onnx_inputs["input_ids"] = np.zeros((batch_size, 16), dtype=np.int64)
+                onnx_inputs["pixel_values"] = np.stack([x[0] for x in processed])
+            else:
+                raise ValueError("Invalid embedding type")
 
-        # Run inference
-        outputs = self.runner.run(onnx_inputs)
-        if self.embedding_type == "text":
-            embeddings = outputs[2]  # text embeddings
-        elif self.embedding_type == "vision":
-            embeddings = outputs[3]  # image embeddings
-        else:
-            raise ValueError("Invalid embedding type")
+            # Run inference
+            outputs = self.runner.run(onnx_inputs)
+            if self.embedding_type == "text":
+                embeddings = outputs[2]  # text embeddings
+            elif self.embedding_type == "vision":
+                embeddings = outputs[3]  # image embeddings
+            else:
+                raise ValueError("Invalid embedding type")
 
-        embeddings = self._postprocess_outputs(embeddings)
-        return [embedding for embedding in embeddings]
+            embeddings = self._postprocess_outputs(embeddings)
+            return [embedding for embedding in embeddings]
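Applied to the toy class from the sketch above, the fix corresponds to taking a per-instance threading.Lock for the duration of the call, so the write to the shared field and every later read of it are atomic with respect to other callers (sketch only, reusing the hypothetical Embedder):

class LockedEmbedder(Embedder):
    def __init__(self):
        super().__init__()
        self._call_lock = threading.Lock()

    def __call__(self, inputs, embedding_type):
        # Serialize the whole call: field write, preprocessing, branch, inference.
        with self._call_lock:
            return super().__call__(inputs, embedding_type)

Note the trade-off visible in the diff: holding the lock across self.runner.run() serializes inference itself, not just the bookkeeping. Since text and vision share one ONNX session here and the mode lives on self, the coarse lock is the simplest correct fix; passing embedding_type through as a local variable instead would remove the shared state but require a larger refactor.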