mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-03 09:45:22 +03:00
refactor detectors
This commit is contained in:
parent
b1ec56de29
commit
8c55966de4
@ -33,6 +33,7 @@ from frigate.ffmpeg_presets import (
|
||||
parse_preset_output_rtmp,
|
||||
)
|
||||
from frigate.version import VERSION
|
||||
from frigate.detectors import DetectorTypeEnum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -52,12 +53,6 @@ class FrigateBaseModel(BaseModel):
|
||||
extra = Extra.forbid
|
||||
|
||||
|
||||
class DetectorTypeEnum(str, Enum):
|
||||
edgetpu = "edgetpu"
|
||||
openvino = "openvino"
|
||||
cpu = "cpu"
|
||||
|
||||
|
||||
class DetectorConfig(FrigateBaseModel):
|
||||
type: DetectorTypeEnum = Field(default=DetectorTypeEnum.cpu, title="Detector Type")
|
||||
device: str = Field(default="usb", title="Device Type")
|
||||
|
||||
@ -0,0 +1,8 @@
|
||||
import os
|
||||
|
||||
from .detector_type import DetectorTypeEnum
|
||||
from .detection_api import DetectionApi
|
||||
from .cpu_tfl import CpuTfl
|
||||
from .edgetpu_tfl import EdgeTpuTfl
|
||||
from .openvino import OvDetector
|
||||
from .tensorrt import TensorRT
|
||||
@ -1,14 +1,16 @@
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from frigate.detectors.detection_api import DetectionApi
|
||||
from .detection_api import DetectionApi
|
||||
from .detector_type import DetectorTypeEnum
|
||||
import tflite_runtime.interpreter as tflite
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CpuTfl(DetectionApi):
|
||||
def __init__(self, det_device=None, model_config=None, num_threads=3):
|
||||
def __init__(self, det_device=None, model_config=None, num_threads=3, **kwargs):
|
||||
self.interpreter = tflite.Interpreter(
|
||||
model_path=model_config.path or "/cpu_model.tflite", num_threads=num_threads
|
||||
)
|
||||
@ -44,3 +46,6 @@ class CpuTfl(DetectionApi):
|
||||
]
|
||||
|
||||
return detections
|
||||
|
||||
|
||||
DetectionApi.register_api(DetectorTypeEnum.cpu, CpuTfl)
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import logging
|
||||
|
||||
from .detector_type import DetectorTypeEnum
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict
|
||||
|
||||
@ -8,6 +10,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DetectionApi(ABC):
|
||||
_api_types = {}
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, det_device=None, model_config=None):
|
||||
pass
|
||||
@ -15,3 +19,14 @@ class DetectionApi(ABC):
|
||||
@abstractmethod
|
||||
def detect_raw(self, tensor_input):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def register_api(det_type: DetectorTypeEnum, det_api):
|
||||
DetectionApi._api_types[det_type] = det_api
|
||||
|
||||
@staticmethod
|
||||
def create(det_type: DetectorTypeEnum, **kwargs):
|
||||
api = DetectionApi._api_types.get(det_type)
|
||||
if not api:
|
||||
raise ValueError(det_type)
|
||||
return api(**kwargs)
|
||||
|
||||
8
frigate/detectors/detector_type.py
Normal file
8
frigate/detectors/detector_type.py
Normal file
@ -0,0 +1,8 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class DetectorTypeEnum(str, Enum):
|
||||
edgetpu = "edgetpu"
|
||||
openvino = "openvino"
|
||||
cpu = "cpu"
|
||||
tensorrt = "tensorrt"
|
||||
@ -1,15 +1,17 @@
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from frigate.detectors.detection_api import DetectionApi
|
||||
from .detection_api import DetectionApi
|
||||
from .detector_type import DetectorTypeEnum
|
||||
import tflite_runtime.interpreter as tflite
|
||||
from tflite_runtime.interpreter import load_delegate
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EdgeTpuTfl(DetectionApi):
|
||||
def __init__(self, det_device=None, model_config=None):
|
||||
def __init__(self, det_device=None, model_config=None, **kwargs):
|
||||
device_config = {"device": "usb"}
|
||||
if not det_device is None:
|
||||
device_config = {"device": det_device}
|
||||
@ -61,3 +63,6 @@ class EdgeTpuTfl(DetectionApi):
|
||||
]
|
||||
|
||||
return detections
|
||||
|
||||
|
||||
DetectionApi.register_api(DetectorTypeEnum.edgetpu, EdgeTpuTfl)
|
||||
|
||||
@ -2,14 +2,15 @@ import logging
|
||||
import numpy as np
|
||||
import openvino.runtime as ov
|
||||
|
||||
from frigate.detectors.detection_api import DetectionApi
|
||||
from .detection_api import DetectionApi
|
||||
from .detector_type import DetectorTypeEnum
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OvDetector(DetectionApi):
|
||||
def __init__(self, det_device=None, model_config=None, num_threads=1):
|
||||
def __init__(self, det_device=None, model_config=None, num_threads=1, **kwargs):
|
||||
self.ov_core = ov.Core()
|
||||
self.ov_model = self.ov_core.read_model(model_config.path)
|
||||
|
||||
@ -52,3 +53,6 @@ class OvDetector(DetectionApi):
|
||||
i += 1
|
||||
|
||||
return detections
|
||||
|
||||
|
||||
DetectionApi.register_api(DetectorTypeEnum.openvino, OvDetector)
|
||||
|
||||
19
frigate/detectors/tensorrt.py
Normal file
19
frigate/detectors/tensorrt.py
Normal file
@ -0,0 +1,19 @@
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from .detection_api import DetectionApi
|
||||
from .detector_type import DetectorTypeEnum
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TensorRT(DetectionApi):
|
||||
def __init__(self, det_device=None, model_config=None, num_threads=1, **kwargs):
|
||||
raise NotImplementedError("TensorRT engine is not yet functional!")
|
||||
|
||||
def detect_raw(self, tensor_input):
|
||||
raise NotImplementedError("TensorRT engine is not yet functional!")
|
||||
|
||||
|
||||
DetectionApi.register_api(DetectorTypeEnum.tensorrt, TensorRT)
|
||||
@ -10,10 +10,8 @@ from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
from setproctitle import setproctitle
|
||||
|
||||
from frigate.config import DetectorTypeEnum, InputTensorEnum
|
||||
from frigate.detectors.edgetpu_tfl import EdgeTpuTfl
|
||||
from frigate.detectors.openvino import OvDetector
|
||||
from frigate.detectors.cpu_tfl import CpuTfl
|
||||
from frigate.config import InputTensorEnum
|
||||
from frigate.detectors import DetectionApi, DetectorTypeEnum
|
||||
|
||||
from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen, load_labels
|
||||
|
||||
@ -54,19 +52,12 @@ class LocalObjectDetector(ObjectDetector):
|
||||
else:
|
||||
self.input_transform = None
|
||||
|
||||
if det_type == DetectorTypeEnum.edgetpu:
|
||||
self.detect_api = EdgeTpuTfl(
|
||||
det_device=det_device, model_config=model_config
|
||||
)
|
||||
elif det_type == DetectorTypeEnum.openvino:
|
||||
self.detect_api = OvDetector(
|
||||
det_device=det_device, model_config=model_config
|
||||
)
|
||||
else:
|
||||
if det_type == DetectorTypeEnum.cpu:
|
||||
logger.warning(
|
||||
"CPU detectors are not recommended and should only be used for testing or for trial purposes."
|
||||
)
|
||||
self.detect_api = CpuTfl(model_config=model_config, num_threads=num_threads)
|
||||
|
||||
self.detect_api = DetectionApi.create(det_type, det_device=det_device, model_config=model_config, num_threads=num_threads)
|
||||
|
||||
def detect(self, tensor_input, threshold=0.4):
|
||||
detections = []
|
||||
|
||||
@ -7,8 +7,8 @@ import frigate.object_detection
|
||||
|
||||
|
||||
class TestLocalObjectDetector(unittest.TestCase):
|
||||
@patch("frigate.object_detection.EdgeTpuTfl")
|
||||
@patch("frigate.object_detection.CpuTfl")
|
||||
@patch("frigate.detectors.EdgeTpuTfl")
|
||||
@patch("frigate.detectors.CpuTfl")
|
||||
def test_localdetectorprocess_given_type_cpu_should_call_cputfl_init(
|
||||
self, mock_cputfl, mock_edgetputfl
|
||||
):
|
||||
@ -22,8 +22,8 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
mock_edgetputfl.assert_not_called()
|
||||
mock_cputfl.assert_called_once_with(model_config=test_cfg, num_threads=6)
|
||||
|
||||
@patch("frigate.object_detection.EdgeTpuTfl")
|
||||
@patch("frigate.object_detection.CpuTfl")
|
||||
@patch("frigate.detectors.EdgeTpuTfl")
|
||||
@patch("frigate.detectors.CpuTfl")
|
||||
def test_localdetectorprocess_given_type_edgtpu_should_call_edgtpu_init(
|
||||
self, mock_cputfl, mock_edgetputfl
|
||||
):
|
||||
@ -40,7 +40,7 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
mock_cputfl.assert_not_called()
|
||||
mock_edgetputfl.assert_called_once_with(det_device="usb", model_config=test_cfg)
|
||||
|
||||
@patch("frigate.object_detection.CpuTfl")
|
||||
@patch("frigate.detectors.CpuTfl")
|
||||
def test_detect_raw_given_tensor_input_should_return_api_detect_raw_result(
|
||||
self, mock_cputfl
|
||||
):
|
||||
@ -58,7 +58,7 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA)
|
||||
assert test_result is mock_det_api.detect_raw.return_value
|
||||
|
||||
@patch("frigate.object_detection.CpuTfl")
|
||||
@patch("frigate.detectors.CpuTfl")
|
||||
def test_detect_raw_given_tensor_input_should_call_api_detect_raw_with_transposed_tensor(
|
||||
self, mock_cputfl
|
||||
):
|
||||
@ -85,7 +85,7 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
|
||||
assert test_result is mock_det_api.detect_raw.return_value
|
||||
|
||||
@patch("frigate.object_detection.CpuTfl")
|
||||
@patch("frigate.detectors.CpuTfl")
|
||||
@patch("frigate.object_detection.load_labels")
|
||||
def test_detect_given_tensor_input_should_return_lfiltered_detections(
|
||||
self, mock_load_labels, mock_cputfl
|
||||
|
||||
Loading…
Reference in New Issue
Block a user