Check complex model

This commit is contained in:
Nicolas Mowen 2025-09-17 14:36:07 -06:00
parent 26178444f3
commit a153418934

View File

@ -156,7 +156,7 @@ class CudaGraphRunner(BaseModelRunner):
class OpenVINOModelRunner(BaseModelRunner):
"""OpenVINO model runner that handles inference efficiently."""
def __init__(self, model_path: str, device: str, **kwargs):
def __init__(self, model_path: str, device: str, complex_model: bool, **kwargs):
self.model_path = model_path
self.device = device
@ -180,14 +180,16 @@ class OpenVINOModelRunner(BaseModelRunner):
# Create reusable inference request
self.infer_request = self.compiled_model.create_infer_request()
self.input_tensor: ov.Tensor | None = None
try:
input_shape = self.compiled_model.inputs[0].get_shape()
input_element_type = self.compiled_model.inputs[0].get_element_type()
self.input_tensor = ov.Tensor(input_element_type, input_shape)
except RuntimeError:
# model is complex and has dynamic shape
self.input_tensor = None
if not complex_model:
try:
input_shape = self.compiled_model.inputs[0].get_shape()
input_element_type = self.compiled_model.inputs[0].get_element_type()
self.input_tensor = ov.Tensor(input_element_type, input_shape)
except RuntimeError:
# model is complex and has dynamic shape
raise
def get_input_names(self) -> list[str]:
"""Get input names for the model."""