add locks to jina v1 embeddings

protect tokenizer and feature extractor in jina_v1_embedding with per-instance thread lock to avoid the "Already borrowed" RuntimeError during concurrent tokenization
This commit is contained in:
Josh Hawkins 2026-01-14 16:26:25 -06:00
parent 7bc2ef731b
commit 2c9f7a5275

View File

@ -2,6 +2,7 @@
import logging import logging
import os import os
import threading
import warnings import warnings
from transformers import AutoFeatureExtractor, AutoTokenizer from transformers import AutoFeatureExtractor, AutoTokenizer
@ -54,6 +55,7 @@ class JinaV1TextEmbedding(BaseEmbedding):
self.tokenizer = None self.tokenizer = None
self.feature_extractor = None self.feature_extractor = None
self.runner = None self.runner = None
self._lock = threading.Lock()
files_names = list(self.download_urls.keys()) + [self.tokenizer_file] files_names = list(self.download_urls.keys()) + [self.tokenizer_file]
if not all( if not all(
@ -134,17 +136,18 @@ class JinaV1TextEmbedding(BaseEmbedding):
) )
def _preprocess_inputs(self, raw_inputs): def _preprocess_inputs(self, raw_inputs):
max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs) with self._lock:
return [ max_length = max(len(self.tokenizer.encode(text)) for text in raw_inputs)
self.tokenizer( return [
text, self.tokenizer(
padding="max_length", text,
truncation=True, padding="max_length",
max_length=max_length, truncation=True,
return_tensors="np", max_length=max_length,
) return_tensors="np",
for text in raw_inputs )
] for text in raw_inputs
]
class JinaV1ImageEmbedding(BaseEmbedding): class JinaV1ImageEmbedding(BaseEmbedding):
@ -174,6 +177,7 @@ class JinaV1ImageEmbedding(BaseEmbedding):
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name) self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
self.feature_extractor = None self.feature_extractor = None
self.runner: BaseModelRunner | None = None self.runner: BaseModelRunner | None = None
self._lock = threading.Lock()
files_names = list(self.download_urls.keys()) files_names = list(self.download_urls.keys())
if not all( if not all(
os.path.exists(os.path.join(self.download_path, n)) for n in files_names os.path.exists(os.path.join(self.download_path, n)) for n in files_names
@ -216,8 +220,9 @@ class JinaV1ImageEmbedding(BaseEmbedding):
) )
def _preprocess_inputs(self, raw_inputs): def _preprocess_inputs(self, raw_inputs):
processed_images = [self._process_image(img) for img in raw_inputs] with self._lock:
return [ processed_images = [self._process_image(img) for img in raw_inputs]
self.feature_extractor(images=image, return_tensors="np") return [
for image in processed_images self.feature_extractor(images=image, return_tensors="np")
] for image in processed_images
]