Further refinements

This commit is contained in:
Anil Ozyalcin 2023-01-28 22:52:39 -08:00
parent e237003372
commit bf934a7651

View File

@ -25,18 +25,16 @@ class OvDetector(DetectionApi):
self.ov_model = self.ov_core.read_model(detector_config.model.path)
self.ov_model_type = detector_config.model.model_type
#self.num_classes = 80 # TODO
self.h = detector_config.model.height # 416
self.w = detector_config.model.width # 416
if(self.ov_model_type == ModelTypeEnum.yolox):
self.set_strides_grids()
self.h = detector_config.model.height
self.w = detector_config.model.width
self.interpreter = self.ov_core.compile_model(
model=self.ov_model, device_name=detector_config.device
)
logger.info(f"Model Input Shape: {self.interpreter.input(0).shape}")
self.output_indexes = 0
while True:
try:
tensor_shape = self.interpreter.output(self.output_indexes).shape
@ -46,9 +44,11 @@ class OvDetector(DetectionApi):
except:
logger.info(f"Model has {self.output_indexes} Output Tensors")
break
if(self.ov_model_type == ModelTypeEnum.yolox):
self.num_classes = tensor_shape[2]-5
logger.info(f"YOLOX model has {self.num_classes} classes")
self.set_strides_grids()
def set_strides_grids(self):
grids = []