mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-05 10:45:21 +03:00
Add mediapipe detector plugin and update requirements-wheels.txt with mediapipe library
This commit is contained in:
parent
e3b9998879
commit
a0264103e4
@ -88,6 +88,7 @@ FROM wget AS models
|
||||
# Get model and labels
|
||||
RUN wget -qO edgetpu_model.tflite https://github.com/google-coral/test_data/raw/release-frogfish/ssdlite_mobiledet_coco_qat_postprocess_edgetpu.tflite
|
||||
RUN wget -qO cpu_model.tflite https://github.com/google-coral/test_data/raw/release-frogfish/ssdlite_mobiledet_coco_qat_postprocess.tflite
|
||||
RUN wget -qO efficientdet_lite2_fp32.tflite https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite2_fp32.tflite
|
||||
COPY labelmap.txt .
|
||||
# Copy OpenVino model
|
||||
COPY --from=ov-converter /models/public/ssdlite_mobilenet_v2/FP16 openvino-model
|
||||
@ -126,6 +127,7 @@ RUN apt-get -qq update \
|
||||
gfortran openexr libatlas-base-dev libssl-dev\
|
||||
libtbb2 libtbb-dev libdc1394-22-dev libopenexr-dev \
|
||||
libgstreamer-plugins-base1.0-dev libgstreamer1.0-dev \
|
||||
libgl1 libgl1-dev \
|
||||
# scipy dependencies
|
||||
gcc gfortran libopenblas-dev liblapack-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
77
frigate/detectors/plugins/mediapipe.py
Normal file
77
frigate/detectors/plugins/mediapipe.py
Normal file
@ -0,0 +1,77 @@
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from frigate.detectors.detection_api import DetectionApi
|
||||
from frigate.detectors.detector_config import BaseDetectorConfig
|
||||
from typing import Literal
|
||||
from pydantic import Extra, Field
|
||||
|
||||
import numpy as np
|
||||
import mediapipe as mp
|
||||
from mediapipe.tasks import python
|
||||
from mediapipe.tasks.python import vision
|
||||
import cv2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DETECTOR_KEY = "mediapipe"
|
||||
|
||||
class MediapipeDetectorConfig(BaseDetectorConfig):
|
||||
type: Literal[DETECTOR_KEY]
|
||||
#num_threads: int = Field(default=3, title="Number of detection threads")
|
||||
|
||||
class Mediapipe(DetectionApi):
|
||||
type_key = DETECTOR_KEY
|
||||
|
||||
def __init__(self, detector_config: MediapipeDetectorConfig):
|
||||
base_options = python.BaseOptions(model_asset_path=detector_config.model.path)
|
||||
options = vision.ObjectDetectorOptions(base_options=base_options,
|
||||
score_threshold=0.5)
|
||||
self.detector = vision.ObjectDetector.create_from_options(options)
|
||||
self.labels = detector_config.model.merged_labelmap
|
||||
|
||||
self.h = detector_config.model.height
|
||||
self.w = detector_config.model.width
|
||||
logger.debug(f"Detection started. Model: {detector_config.model.path}, h: {self.h}, w: {self.w}")
|
||||
|
||||
def get_label_index(self, label_value):
|
||||
if label_value.lower() == "truck":
|
||||
label_value = "car"
|
||||
for index, value in self.labels.items():
|
||||
if value == label_value.lower():
|
||||
return index
|
||||
return -1
|
||||
|
||||
def detect_raw(self, tensor_input):
|
||||
# STEP 3: Load the input image.
|
||||
image_data = np.squeeze(tensor_input).astype(np.uint8)
|
||||
image_data_resized = cv2.resize(image_data, (self.w, self.h))
|
||||
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_data_resized)
|
||||
|
||||
# STEP 4: Detect objects in the input image.
|
||||
results = self.detector.detect(image)
|
||||
|
||||
detections = np.zeros((20, 6), np.float32)
|
||||
|
||||
for i, detection in enumerate(results.detections):
|
||||
if i == 20:
|
||||
break
|
||||
|
||||
bbox = detection.bounding_box # left, upper, right, and lower pixel coordinate.
|
||||
category = detection.categories[0]
|
||||
label = self.get_label_index(category.category_name)
|
||||
if label < 0:
|
||||
logger.debug(f"Break due to unknown label")
|
||||
break
|
||||
detections[i] = [
|
||||
label,
|
||||
round(category.score, 2),
|
||||
bbox.origin_y / self.h,
|
||||
bbox.origin_x / self.w,
|
||||
(bbox.origin_y+bbox.height) / self.h,
|
||||
(bbox.origin_x+bbox.width) / self.w,
|
||||
]
|
||||
logger.debug(f"Detection: raw: {detection} result: {detections[i]}")
|
||||
|
||||
|
||||
return detections
|
||||
@ -2,6 +2,7 @@ click == 8.1.*
|
||||
Flask == 2.3.*
|
||||
imutils == 0.5.*
|
||||
matplotlib == 3.7.*
|
||||
mediapipe == 0.9.*
|
||||
mypy == 0.942
|
||||
numpy == 1.23.*
|
||||
onvif_zeep == 0.2.12
|
||||
|
||||
Loading…
Reference in New Issue
Block a user