Further refinements

This commit is contained in:
Anil Ozyalcin 2023-01-28 22:52:39 -08:00
parent e237003372
commit bf934a7651

View File

@ -25,18 +25,16 @@ class OvDetector(DetectionApi):
self.ov_model = self.ov_core.read_model(detector_config.model.path) self.ov_model = self.ov_core.read_model(detector_config.model.path)
self.ov_model_type = detector_config.model.model_type self.ov_model_type = detector_config.model.model_type
#self.num_classes = 80 # TODO self.h = detector_config.model.height
self.h = detector_config.model.height # 416 self.w = detector_config.model.width
self.w = detector_config.model.width # 416
if(self.ov_model_type == ModelTypeEnum.yolox):
self.set_strides_grids()
self.interpreter = self.ov_core.compile_model( self.interpreter = self.ov_core.compile_model(
model=self.ov_model, device_name=detector_config.device model=self.ov_model, device_name=detector_config.device
) )
logger.info(f"Model Input Shape: {self.interpreter.input(0).shape}") logger.info(f"Model Input Shape: {self.interpreter.input(0).shape}")
self.output_indexes = 0 self.output_indexes = 0
while True: while True:
try: try:
tensor_shape = self.interpreter.output(self.output_indexes).shape tensor_shape = self.interpreter.output(self.output_indexes).shape
@ -46,9 +44,11 @@ class OvDetector(DetectionApi):
except: except:
logger.info(f"Model has {self.output_indexes} Output Tensors") logger.info(f"Model has {self.output_indexes} Output Tensors")
break break
if(self.ov_model_type == ModelTypeEnum.yolox): if(self.ov_model_type == ModelTypeEnum.yolox):
self.num_classes = tensor_shape[2]-5 self.num_classes = tensor_shape[2]-5
logger.info(f"YOLOX model has {self.num_classes} classes") logger.info(f"YOLOX model has {self.num_classes} classes")
self.set_strides_grids()
def set_strides_grids(self): def set_strides_grids(self):
grids = [] grids = []