Use thread lock for JinaV2 in onnxruntime

This commit is contained in:
Nicolas Mowen 2025-12-25 06:08:26 -07:00
parent 735d400beb
commit 6a28b616cb

View File

@ -139,8 +139,31 @@ class ONNXModelRunner(BaseModelRunner):
ModelTypeEnum.dfine.value,
]
def __init__(self, ort: ort.InferenceSession):
@staticmethod
def is_concurrent_model(model_type: str | None) -> bool:
"""Check if model requires thread locking for concurrent inference.
Some models (like JinaV2) share one runner between text and vision embeddings
called from different threads, requiring thread synchronization.
"""
if not model_type:
return False
# Import here to avoid circular imports
from frigate.embeddings.types import EnrichmentModelTypeEnum
return model_type == EnrichmentModelTypeEnum.jina_v2.value
def __init__(self, ort: ort.InferenceSession, model_type: str | None = None):
    """Wrap an ONNX inference session, optionally guarding it with a lock.

    Args:
        ort: The inference session that this runner delegates to.
        model_type: Optional model type value; stored on the instance and
            used to decide whether inference calls need a lock.
    """
    self.ort = ort
    self.model_type = model_type

    # JinaV2 shares one runner between text and vision embeddings that are
    # called from different threads, so its inference calls are serialized
    # with a lock. Every other model type runs without one.
    needs_lock = self.is_concurrent_model(model_type)
    self._inference_lock = threading.Lock() if needs_lock else None
def get_input_names(self) -> list[str]:
    """Return the ``name`` of each input reported by the wrapped session, in order."""
    names: list[str] = []
    for node in self.ort.get_inputs():
        names.append(node.name)
    return names
@ -150,6 +173,10 @@ class ONNXModelRunner(BaseModelRunner):
return self.ort.get_inputs()[0].shape[3]
def run(self, input: dict[str, Any]) -> Any | None:
if self._inference_lock:
with self._inference_lock:
return self.ort.run(None, input)
return self.ort.run(None, input)
@ -576,5 +603,6 @@ def get_optimized_runner(
),
providers=providers,
provider_options=options,
)
),
model_type=model_type,
)