Add mediapipe detector plugin and update requirements-wheels.txt with mediapipe library

This commit is contained in:
Sergey Krashevich 2023-05-08 03:47:49 +03:00
parent e3b9998879
commit a0264103e4
No known key found for this signature in database
GPG Key ID: 625171324E7D3856
3 changed files with 80 additions and 0 deletions

View File

@ -88,6 +88,7 @@ FROM wget AS models
# Get model and labels
RUN wget -qO edgetpu_model.tflite https://github.com/google-coral/test_data/raw/release-frogfish/ssdlite_mobiledet_coco_qat_postprocess_edgetpu.tflite
RUN wget -qO cpu_model.tflite https://github.com/google-coral/test_data/raw/release-frogfish/ssdlite_mobiledet_coco_qat_postprocess.tflite
RUN wget -qO efficientdet_lite2_fp32.tflite https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite2_fp32.tflite
COPY labelmap.txt .
# Copy OpenVino model
COPY --from=ov-converter /models/public/ssdlite_mobilenet_v2/FP16 openvino-model
@ -126,6 +127,7 @@ RUN apt-get -qq update \
gfortran openexr libatlas-base-dev libssl-dev\
libtbb2 libtbb-dev libdc1394-22-dev libopenexr-dev \
libgstreamer-plugins-base1.0-dev libgstreamer1.0-dev \
libgl1 libgl1-dev \
# scipy dependencies
gcc gfortran libopenblas-dev liblapack-dev && \
rm -rf /var/lib/apt/lists/*

View File

@ -0,0 +1,77 @@
import logging
import numpy as np
from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detector_config import BaseDetectorConfig
from typing import Literal
from pydantic import Extra, Field
import numpy as np
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2
# Module-level logger, named after this module per frigate convention.
logger = logging.getLogger(__name__)

# Key under which this plugin registers itself; users select it in the
# frigate config with `detectors: ... type: mediapipe`.
DETECTOR_KEY = "mediapipe"
class MediapipeDetectorConfig(BaseDetectorConfig):
    """Configuration schema for the MediaPipe object-detector plugin.

    Inherits all common detector settings (model path, dimensions, labelmap)
    from BaseDetectorConfig; only the discriminating ``type`` field is added.
    """

    # Discriminator field: must equal DETECTOR_KEY ("mediapipe") for this
    # plugin to be selected.
    type: Literal[DETECTOR_KEY]
    # num_threads: int = Field(default=3, title="Number of detection threads")
class Mediapipe(DetectionApi):
    """frigate detection backend backed by the MediaPipe Tasks ObjectDetector."""

    type_key = DETECTOR_KEY

    def __init__(self, detector_config: MediapipeDetectorConfig):
        """Build the MediaPipe detector from the configured model file.

        Caches the merged labelmap and the model's expected input height and
        width for use at inference time.
        """
        options = vision.ObjectDetectorOptions(
            base_options=python.BaseOptions(
                model_asset_path=detector_config.model.path
            ),
            # Hard-coded minimum confidence; detections below 0.5 are
            # discarded by MediaPipe before we ever see them.
            score_threshold=0.5,
        )
        self.detector = vision.ObjectDetector.create_from_options(options)
        self.labels = detector_config.model.merged_labelmap
        self.h = detector_config.model.height
        self.w = detector_config.model.width
        logger.debug(f"Detection started. Model: {detector_config.model.path}, h: {self.h}, w: {self.w}")

    def get_label_index(self, label_value):
        """Return the labelmap index for a MediaPipe category name, or -1.

        Matching is case-insensitive; "truck" is folded into "car" because
        the labelmap distinguishes them while this model does not.
        """
        wanted = label_value.lower()
        if wanted == "truck":
            wanted = "car"
        for index, name in self.labels.items():
            if name == wanted:
                return index
        return -1

    def detect_raw(self, tensor_input):
        """Run one inference and return a (20, 6) float32 array.

        Each row is [label_index, score, ymin, xmin, ymax, xmax] with the box
        coordinates normalized to [0, 1]; unused rows stay zero.
        """
        # Drop the batch dimension and resize to the model's input size.
        frame = np.squeeze(tensor_input).astype(np.uint8)
        frame = cv2.resize(frame, (self.w, self.h))
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)

        results = self.detector.detect(mp_image)

        detections = np.zeros((20, 6), np.float32)
        for i, detection in enumerate(results.detections):
            # The output array holds at most 20 rows.
            if i == 20:
                break
            top_category = detection.categories[0]
            label = self.get_label_index(top_category.category_name)
            if label < 0:
                # NOTE(review): this abandons ALL remaining detections, not
                # just the unrecognized one — confirm `break` (vs `continue`)
                # is intentional.
                logger.debug(f"Break due to unknown label")
                break
            # bounding_box carries origin_x/origin_y plus width/height,
            # in pixels of the resized input frame.
            box = detection.bounding_box
            detections[i] = [
                label,
                round(top_category.score, 2),
                box.origin_y / self.h,
                box.origin_x / self.w,
                (box.origin_y + box.height) / self.h,
                (box.origin_x + box.width) / self.w,
            ]
            logger.debug(f"Detection: raw: {detection} result: {detections[i]}")
        return detections

View File

@ -2,6 +2,7 @@ click == 8.1.*
Flask == 2.3.*
imutils == 0.5.*
matplotlib == 3.7.*
mediapipe == 0.9.*
mypy == 0.942
numpy == 1.23.*
onvif_zeep == 0.2.12