This commit is contained in:
alec-groff 2024-07-23 16:11:48 -03:00 committed by GitHub
commit 03837bf1e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 14 additions and 8 deletions

View File

@ -51,6 +51,9 @@ class ModelConfig(BaseModel):
model_type: ModelTypeEnum = Field(
default=ModelTypeEnum.ssd, title="Object Detection Model Type"
)
tfl_detector_output_tensor_order: list[int] = Field(
default=[0,1,2,3], title="Order Output Tensors of TFL models [0=boxes,1=scores,2=class_ids,3=count]"
)
_merged_labelmap: Optional[Dict[int, str]] = PrivateAttr()
_colormap: Dict[int, Tuple[int, int, int]] = PrivateAttr()
_model_hash: str = PrivateAttr()

View File

@ -37,15 +37,17 @@ class CpuTfl(DetectionApi):
self.tensor_input_details = self.interpreter.get_input_details()
self.tensor_output_details = self.interpreter.get_output_details()
self.tfl_detector_output_tensor_order = detector_config.model.tfl_detector_output_tensor_order
def detect_raw(self, tensor_input):
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], tensor_input)
self.interpreter.invoke()
boxes = self.interpreter.tensor(self.tensor_output_details[0]["index"])()[0]
class_ids = self.interpreter.tensor(self.tensor_output_details[1]["index"])()[0]
scores = self.interpreter.tensor(self.tensor_output_details[2]["index"])()[0]
boxes = self.interpreter.tensor(self.tensor_output_details[self.tfl_detector_output_tensor_order[0]]["index"])()[0]
class_ids = self.interpreter.tensor(self.tensor_output_details[self.tfl_detector_output_tensor_order[1]]["index"])()[0]
scores = self.interpreter.tensor(self.tensor_output_details[self.tfl_detector_output_tensor_order[2]]["index"])()[0]
count = int(
self.interpreter.tensor(self.tensor_output_details[3]["index"])()[0]
self.interpreter.tensor(self.tensor_output_details[self.tfl_detector_output_tensor_order[3]]["index"])()[0]
)
detections = np.zeros((20, 6), np.float32)

View File

@ -44,6 +44,7 @@ class EdgeTpuTfl(DetectionApi):
model_path=detector_config.model.path,
experimental_delegates=[edge_tpu_delegate],
)
self.tfl_detector_output_tensor_order = detector_config.model.tfl_detector_output_tensor_order
except ValueError:
logger.error(
"No EdgeTPU was detected. If you do not have a Coral device yet, you must configure CPU detectors."
@ -60,11 +61,11 @@ class EdgeTpuTfl(DetectionApi):
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], tensor_input)
self.interpreter.invoke()
boxes = self.interpreter.tensor(self.tensor_output_details[0]["index"])()[0]
class_ids = self.interpreter.tensor(self.tensor_output_details[1]["index"])()[0]
scores = self.interpreter.tensor(self.tensor_output_details[2]["index"])()[0]
boxes = self.interpreter.tensor(self.tensor_output_details[self.tfl_detector_output_tensor_order[0]]["index"])()[0]
class_ids = self.interpreter.tensor(self.tensor_output_details[self.tfl_detector_output_tensor_order[1]]["index"])()[0]
scores = self.interpreter.tensor(self.tensor_output_details[self.tfl_detector_output_tensor_order[2]]["index"])()[0]
count = int(
self.interpreter.tensor(self.tensor_output_details[3]["index"])()[0]
self.interpreter.tensor(self.tensor_output_details[self.tfl_detector_output_tensor_order[3]]["index"])()[0]
)
detections = np.zeros((20, 6), np.float32)