mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-24 07:44:30 +03:00
Improve openvino width detection
This commit is contained in:
parent
c0c9099616
commit
7b24e01509
@ -133,8 +133,12 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
|||||||
x:x2,
|
x:x2,
|
||||||
]
|
]
|
||||||
|
|
||||||
if frame.shape != (224, 224):
|
if input.shape != (224, 224):
|
||||||
frame = cv2.resize(frame, (224, 224))
|
try:
|
||||||
|
input = cv2.resize(input, (224, 224))
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to resize image for state classification")
|
||||||
|
return
|
||||||
|
|
||||||
input = np.expand_dims(frame, axis=0)
|
input = np.expand_dims(frame, axis=0)
|
||||||
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
|
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
|
||||||
@ -254,8 +258,12 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
|||||||
x:x2,
|
x:x2,
|
||||||
]
|
]
|
||||||
|
|
||||||
if crop.shape != (224, 224):
|
if input.shape != (224, 224):
|
||||||
crop = cv2.resize(crop, (224, 224))
|
try:
|
||||||
|
input = cv2.resize(input, (224, 224))
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to resize image for object classification")
|
||||||
|
return
|
||||||
|
|
||||||
input = np.expand_dims(crop, axis=0)
|
input = np.expand_dims(crop, axis=0)
|
||||||
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
|
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
|
||||||
|
|||||||
@ -195,9 +195,24 @@ class OpenVINOModelRunner(BaseModelRunner):
|
|||||||
|
|
||||||
def get_input_width(self) -> int:
|
def get_input_width(self) -> int:
|
||||||
"""Get the input width of the model."""
|
"""Get the input width of the model."""
|
||||||
input_shape = self.compiled_model.inputs[0].get_shape()
|
input_info = self.compiled_model.inputs
|
||||||
# Assuming NCHW format, width is the last dimension
|
first_input = input_info[0]
|
||||||
return int(input_shape[-1])
|
|
||||||
|
try:
|
||||||
|
partial_shape = first_input.get_partial_shape()
|
||||||
|
# width dimension
|
||||||
|
if len(partial_shape) >= 4 and partial_shape[3].is_static:
|
||||||
|
return partial_shape[3].get_length()
|
||||||
|
|
||||||
|
# If width is dynamic or we can't determine it
|
||||||
|
return -1
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
# gemini says some ov versions might still allow this
|
||||||
|
input_shape = first_input.shape
|
||||||
|
return input_shape[3] if len(input_shape) >= 4 else -1
|
||||||
|
except Exception:
|
||||||
|
return -1
|
||||||
|
|
||||||
def run(self, inputs: dict[str, Any]) -> list[np.ndarray]:
|
def run(self, inputs: dict[str, Any]) -> list[np.ndarray]:
|
||||||
"""Run inference with the model.
|
"""Run inference with the model.
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user