Improve openvino width detection

This commit is contained in:
Nicolas Mowen 2025-09-16 15:00:50 -06:00
parent c0c9099616
commit 7b24e01509
2 changed files with 30 additions and 7 deletions

View File

@ -133,8 +133,12 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
x:x2,
]
if frame.shape != (224, 224):
frame = cv2.resize(frame, (224, 224))
if input.shape != (224, 224):
try:
input = cv2.resize(input, (224, 224))
except Exception:
logger.warning("Failed to resize image for state classification")
return
input = np.expand_dims(frame, axis=0)
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
@ -254,8 +258,12 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
x:x2,
]
if crop.shape != (224, 224):
crop = cv2.resize(crop, (224, 224))
if input.shape != (224, 224):
try:
input = cv2.resize(input, (224, 224))
except Exception:
logger.warning("Failed to resize image for object classification")
return
input = np.expand_dims(crop, axis=0)
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)

View File

@ -195,9 +195,24 @@ class OpenVINOModelRunner(BaseModelRunner):
def get_input_width(self) -> int:
"""Get the input width of the model."""
input_shape = self.compiled_model.inputs[0].get_shape()
# Assuming NCHW format, width is the last dimension
return int(input_shape[-1])
input_info = self.compiled_model.inputs
first_input = input_info[0]
try:
partial_shape = first_input.get_partial_shape()
# width dimension
if len(partial_shape) >= 4 and partial_shape[3].is_static:
return partial_shape[3].get_length()
# If width is dynamic or we can't determine it
return -1
except Exception:
try:
# gemini says some ov versions might still allow this
input_shape = first_input.shape
return input_shape[3] if len(input_shape) >= 4 else -1
except Exception:
return -1
def run(self, inputs: dict[str, Any]) -> list[np.ndarray]:
"""Run inference with the model.