Update DeepStack plugin to recognize 'truck' as 'car' for label indexing

This commit is contained in:
Sergey Krashevich 2023-04-23 17:21:45 +03:00
parent 445315e9e3
commit f02a9d6238
No known key found for this signature in database
GPG Key ID: 625171324E7D3856

View File

@ -55,10 +55,12 @@ class DeepStack(DetectionApi):
return labels return labels
def get_label_index(self, label_value): def get_label_index(self, label_value):
if label_value.lower() == 'truck':
label_value = 'car'
for index, value in self.labels.items(): for index, value in self.labels.items():
if value == label_value: if value == label_value.lower():
return index return index
return None return -1
def detect_raw(self, tensor_input): def detect_raw(self, tensor_input):
image_data = np.squeeze(tensor_input).astype(np.uint8) image_data = np.squeeze(tensor_input).astype(np.uint8)
@ -74,8 +76,11 @@ class DeepStack(DetectionApi):
for i, detection in enumerate(response_json["predictions"]): for i, detection in enumerate(response_json["predictions"]):
if detection["confidence"] < 0.4: if detection["confidence"] < 0.4:
break break
label = self.get_label_index(detection["label"])
if label < 0:
break
detections[i] = [ detections[i] = [
int(self.get_label_index(detection["label"])), label,
float(detection["confidence"]), float(detection["confidence"]),
detection["y_min"] / self.h, detection["y_min"] / self.h,
detection["x_min"] / self.w, detection["x_min"] / self.w,