code formatiing

This commit is contained in:
Sergey Krashevich 2023-05-08 05:14:03 +03:00
parent a0264103e4
commit f39c2baa94
No known key found for this signature in database
GPG Key ID: 625171324E7D3856

View File

@ -16,24 +16,28 @@ logger = logging.getLogger(__name__)
DETECTOR_KEY = "mediapipe" DETECTOR_KEY = "mediapipe"
class MediapipeDetectorConfig(BaseDetectorConfig): class MediapipeDetectorConfig(BaseDetectorConfig):
type: Literal[DETECTOR_KEY] type: Literal[DETECTOR_KEY]
#num_threads: int = Field(default=3, title="Number of detection threads")
class Mediapipe(DetectionApi): class Mediapipe(DetectionApi):
type_key = DETECTOR_KEY type_key = DETECTOR_KEY
def __init__(self, detector_config: MediapipeDetectorConfig): def __init__(self, detector_config: MediapipeDetectorConfig):
base_options = python.BaseOptions(model_asset_path=detector_config.model.path) base_options = python.BaseOptions(model_asset_path=detector_config.model.path)
options = vision.ObjectDetectorOptions(base_options=base_options, options = vision.ObjectDetectorOptions(
score_threshold=0.5) base_options=base_options, score_threshold=0.5
)
self.detector = vision.ObjectDetector.create_from_options(options) self.detector = vision.ObjectDetector.create_from_options(options)
self.labels = detector_config.model.merged_labelmap self.labels = detector_config.model.merged_labelmap
self.h = detector_config.model.height self.h = detector_config.model.height
self.w = detector_config.model.width self.w = detector_config.model.width
logger.debug(f"Detection started. Model: {detector_config.model.path}, h: {self.h}, w: {self.w}") logger.debug(
f"Detection started. Model: {detector_config.model.path}, h: {self.h}, w: {self.w}"
)
def get_label_index(self, label_value): def get_label_index(self, label_value):
if label_value.lower() == "truck": if label_value.lower() == "truck":
label_value = "car" label_value = "car"
@ -41,7 +45,7 @@ class Mediapipe(DetectionApi):
if value == label_value.lower(): if value == label_value.lower():
return index return index
return -1 return -1
def detect_raw(self, tensor_input): def detect_raw(self, tensor_input):
# STEP 3: Load the input image. # STEP 3: Load the input image.
image_data = np.squeeze(tensor_input).astype(np.uint8) image_data = np.squeeze(tensor_input).astype(np.uint8)
@ -56,8 +60,10 @@ class Mediapipe(DetectionApi):
for i, detection in enumerate(results.detections): for i, detection in enumerate(results.detections):
if i == 20: if i == 20:
break break
bbox = detection.bounding_box # left, upper, right, and lower pixel coordinate. bbox = (
detection.bounding_box
) # left, upper, right, and lower pixel coordinate.
category = detection.categories[0] category = detection.categories[0]
label = self.get_label_index(category.category_name) label = self.get_label_index(category.category_name)
if label < 0: if label < 0:
@ -68,10 +74,9 @@ class Mediapipe(DetectionApi):
round(category.score, 2), round(category.score, 2),
bbox.origin_y / self.h, bbox.origin_y / self.h,
bbox.origin_x / self.w, bbox.origin_x / self.w,
(bbox.origin_y+bbox.height) / self.h, (bbox.origin_y + bbox.height) / self.h,
(bbox.origin_x+bbox.width) / self.w, (bbox.origin_x + bbox.width) / self.w,
] ]
logger.debug(f"Detection: raw: {detection} result: {detections[i]}") logger.debug(f"Detection: raw: {detection} result: {detections[i]}")
return detections
return detections