code formatiing

This commit is contained in:
Sergey Krashevich 2023-05-08 05:14:03 +03:00
parent a0264103e4
commit f39c2baa94
No known key found for this signature in database
GPG Key ID: 625171324E7D3856

View File

@ -16,23 +16,27 @@ logger = logging.getLogger(__name__)
DETECTOR_KEY = "mediapipe" DETECTOR_KEY = "mediapipe"
class MediapipeDetectorConfig(BaseDetectorConfig): class MediapipeDetectorConfig(BaseDetectorConfig):
type: Literal[DETECTOR_KEY] type: Literal[DETECTOR_KEY]
#num_threads: int = Field(default=3, title="Number of detection threads")
class Mediapipe(DetectionApi): class Mediapipe(DetectionApi):
type_key = DETECTOR_KEY type_key = DETECTOR_KEY
def __init__(self, detector_config: MediapipeDetectorConfig): def __init__(self, detector_config: MediapipeDetectorConfig):
base_options = python.BaseOptions(model_asset_path=detector_config.model.path) base_options = python.BaseOptions(model_asset_path=detector_config.model.path)
options = vision.ObjectDetectorOptions(base_options=base_options, options = vision.ObjectDetectorOptions(
score_threshold=0.5) base_options=base_options, score_threshold=0.5
)
self.detector = vision.ObjectDetector.create_from_options(options) self.detector = vision.ObjectDetector.create_from_options(options)
self.labels = detector_config.model.merged_labelmap self.labels = detector_config.model.merged_labelmap
self.h = detector_config.model.height self.h = detector_config.model.height
self.w = detector_config.model.width self.w = detector_config.model.width
logger.debug(f"Detection started. Model: {detector_config.model.path}, h: {self.h}, w: {self.w}") logger.debug(
f"Detection started. Model: {detector_config.model.path}, h: {self.h}, w: {self.w}"
)
def get_label_index(self, label_value): def get_label_index(self, label_value):
if label_value.lower() == "truck": if label_value.lower() == "truck":
@ -57,7 +61,9 @@ class Mediapipe(DetectionApi):
if i == 20: if i == 20:
break break
bbox = detection.bounding_box # left, upper, right, and lower pixel coordinate. bbox = (
detection.bounding_box
) # left, upper, right, and lower pixel coordinate.
category = detection.categories[0] category = detection.categories[0]
label = self.get_label_index(category.category_name) label = self.get_label_index(category.category_name)
if label < 0: if label < 0:
@ -73,5 +79,4 @@ class Mediapipe(DetectionApi):
] ]
logger.debug(f"Detection: raw: {detection} result: {detections[i]}") logger.debug(f"Detection: raw: {detection} result: {detections[i]}")
return detections return detections