Add capability to dynamically determine number of classes in yolox model

This commit is contained in:
Anil Ozyalcin 2023-01-28 22:46:35 -08:00
parent 7884042709
commit e237003372

View File

@ -17,7 +17,6 @@ class OvDetectorConfig(BaseDetectorConfig):
type: Literal[DETECTOR_KEY]
device: str = Field(default=None, title="Device Type")
class OvDetector(DetectionApi):
type_key = DETECTOR_KEY
@ -26,10 +25,10 @@ class OvDetector(DetectionApi):
self.ov_model = self.ov_core.read_model(detector_config.model.path)
self.ov_model_type = detector_config.model.model_type
self.num_classes = 80 # TODO
#self.num_classes = 80 # TODO
self.h = detector_config.model.height # 416
self.w = detector_config.model.width # 416
logger.info(self.ov_model_type)
if(self.ov_model_type == ModelTypeEnum.yolox):
self.set_strides_grids()
@ -42,10 +41,14 @@ class OvDetector(DetectionApi):
try:
tensor_shape = self.interpreter.output(self.output_indexes).shape
logger.info(f"Model Output-{self.output_indexes} Shape: {tensor_shape}")
logger.info(f"Model Output-{self.output_indexes} Shape: {tensor_shape}")
self.output_indexes += 1
except:
logger.info(f"Model has {self.output_indexes} Output Tensors")
break
if(self.ov_model_type == ModelTypeEnum.yolox):
self.num_classes = tensor_shape[2]-5
logger.info(f"YOLOX model has {self.num_classes} classes")
def set_strides_grids(self):
grids = []