refactor detectors

This commit is contained in:
Dennis George 2022-12-07 15:55:04 -06:00
parent b1ec56de29
commit 8c55966de4
10 changed files with 83 additions and 33 deletions

View File

@ -33,6 +33,7 @@ from frigate.ffmpeg_presets import (
parse_preset_output_rtmp, parse_preset_output_rtmp,
) )
from frigate.version import VERSION from frigate.version import VERSION
from frigate.detectors import DetectorTypeEnum
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -52,12 +53,6 @@ class FrigateBaseModel(BaseModel):
extra = Extra.forbid extra = Extra.forbid
class DetectorTypeEnum(str, Enum):
edgetpu = "edgetpu"
openvino = "openvino"
cpu = "cpu"
class DetectorConfig(FrigateBaseModel): class DetectorConfig(FrigateBaseModel):
type: DetectorTypeEnum = Field(default=DetectorTypeEnum.cpu, title="Detector Type") type: DetectorTypeEnum = Field(default=DetectorTypeEnum.cpu, title="Detector Type")
device: str = Field(default="usb", title="Device Type") device: str = Field(default="usb", title="Device Type")

View File

@ -0,0 +1,8 @@
import os
from .detector_type import DetectorTypeEnum
from .detection_api import DetectionApi
from .cpu_tfl import CpuTfl
from .edgetpu_tfl import EdgeTpuTfl
from .openvino import OvDetector
from .tensorrt import TensorRT

View File

@ -1,14 +1,16 @@
import logging import logging
import numpy as np import numpy as np
from frigate.detectors.detection_api import DetectionApi from .detection_api import DetectionApi
from .detector_type import DetectorTypeEnum
import tflite_runtime.interpreter as tflite import tflite_runtime.interpreter as tflite
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CpuTfl(DetectionApi): class CpuTfl(DetectionApi):
def __init__(self, det_device=None, model_config=None, num_threads=3): def __init__(self, det_device=None, model_config=None, num_threads=3, **kwargs):
self.interpreter = tflite.Interpreter( self.interpreter = tflite.Interpreter(
model_path=model_config.path or "/cpu_model.tflite", num_threads=num_threads model_path=model_config.path or "/cpu_model.tflite", num_threads=num_threads
) )
@ -44,3 +46,6 @@ class CpuTfl(DetectionApi):
] ]
return detections return detections
DetectionApi.register_api(DetectorTypeEnum.cpu, CpuTfl)

View File

@ -1,5 +1,7 @@
import logging import logging
from .detector_type import DetectorTypeEnum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict from typing import Dict
@ -8,6 +10,8 @@ logger = logging.getLogger(__name__)
class DetectionApi(ABC): class DetectionApi(ABC):
_api_types = {}
@abstractmethod @abstractmethod
def __init__(self, det_device=None, model_config=None): def __init__(self, det_device=None, model_config=None):
pass pass
@ -15,3 +19,14 @@ class DetectionApi(ABC):
@abstractmethod @abstractmethod
def detect_raw(self, tensor_input): def detect_raw(self, tensor_input):
pass pass
@staticmethod
def register_api(det_type: DetectorTypeEnum, det_api):
DetectionApi._api_types[det_type] = det_api
@staticmethod
def create(det_type: DetectorTypeEnum, **kwargs):
api = DetectionApi._api_types.get(det_type)
if not api:
raise ValueError(det_type)
return api(**kwargs)

View File

@ -0,0 +1,8 @@
from enum import Enum
class DetectorTypeEnum(str, Enum):
edgetpu = "edgetpu"
openvino = "openvino"
cpu = "cpu"
tensorrt = "tensorrt"

View File

@ -1,15 +1,17 @@
import logging import logging
import numpy as np import numpy as np
from frigate.detectors.detection_api import DetectionApi from .detection_api import DetectionApi
from .detector_type import DetectorTypeEnum
import tflite_runtime.interpreter as tflite import tflite_runtime.interpreter as tflite
from tflite_runtime.interpreter import load_delegate from tflite_runtime.interpreter import load_delegate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EdgeTpuTfl(DetectionApi): class EdgeTpuTfl(DetectionApi):
def __init__(self, det_device=None, model_config=None): def __init__(self, det_device=None, model_config=None, **kwargs):
device_config = {"device": "usb"} device_config = {"device": "usb"}
if not det_device is None: if not det_device is None:
device_config = {"device": det_device} device_config = {"device": det_device}
@ -61,3 +63,6 @@ class EdgeTpuTfl(DetectionApi):
] ]
return detections return detections
DetectionApi.register_api(DetectorTypeEnum.edgetpu, EdgeTpuTfl)

View File

@ -2,14 +2,15 @@ import logging
import numpy as np import numpy as np
import openvino.runtime as ov import openvino.runtime as ov
from frigate.detectors.detection_api import DetectionApi from .detection_api import DetectionApi
from .detector_type import DetectorTypeEnum
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OvDetector(DetectionApi): class OvDetector(DetectionApi):
def __init__(self, det_device=None, model_config=None, num_threads=1): def __init__(self, det_device=None, model_config=None, num_threads=1, **kwargs):
self.ov_core = ov.Core() self.ov_core = ov.Core()
self.ov_model = self.ov_core.read_model(model_config.path) self.ov_model = self.ov_core.read_model(model_config.path)
@ -52,3 +53,6 @@ class OvDetector(DetectionApi):
i += 1 i += 1
return detections return detections
DetectionApi.register_api(DetectorTypeEnum.openvino, OvDetector)

View File

@ -0,0 +1,19 @@
import logging
import numpy as np
from .detection_api import DetectionApi
from .detector_type import DetectorTypeEnum
logger = logging.getLogger(__name__)
class TensorRT(DetectionApi):
def __init__(self, det_device=None, model_config=None, num_threads=1, **kwargs):
raise NotImplementedError("TensorRT engine is not yet functional!")
def detect_raw(self, tensor_input):
raise NotImplementedError("TensorRT engine is not yet functional!")
DetectionApi.register_api(DetectorTypeEnum.tensorrt, TensorRT)

View File

@ -10,10 +10,8 @@ from abc import ABC, abstractmethod
import numpy as np import numpy as np
from setproctitle import setproctitle from setproctitle import setproctitle
from frigate.config import DetectorTypeEnum, InputTensorEnum from frigate.config import InputTensorEnum
from frigate.detectors.edgetpu_tfl import EdgeTpuTfl from frigate.detectors import DetectionApi, DetectorTypeEnum
from frigate.detectors.openvino import OvDetector
from frigate.detectors.cpu_tfl import CpuTfl
from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen, load_labels from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen, load_labels
@ -54,19 +52,12 @@ class LocalObjectDetector(ObjectDetector):
else: else:
self.input_transform = None self.input_transform = None
if det_type == DetectorTypeEnum.edgetpu: if det_type == DetectorTypeEnum.cpu:
self.detect_api = EdgeTpuTfl(
det_device=det_device, model_config=model_config
)
elif det_type == DetectorTypeEnum.openvino:
self.detect_api = OvDetector(
det_device=det_device, model_config=model_config
)
else:
logger.warning( logger.warning(
"CPU detectors are not recommended and should only be used for testing or for trial purposes." "CPU detectors are not recommended and should only be used for testing or for trial purposes."
) )
self.detect_api = CpuTfl(model_config=model_config, num_threads=num_threads)
self.detect_api = DetectionApi.create(det_type, det_device=det_device, model_config=model_config, num_threads=num_threads)
def detect(self, tensor_input, threshold=0.4): def detect(self, tensor_input, threshold=0.4):
detections = [] detections = []

View File

@ -7,8 +7,8 @@ import frigate.object_detection
class TestLocalObjectDetector(unittest.TestCase): class TestLocalObjectDetector(unittest.TestCase):
@patch("frigate.object_detection.EdgeTpuTfl") @patch("frigate.detectors.EdgeTpuTfl")
@patch("frigate.object_detection.CpuTfl") @patch("frigate.detectors.CpuTfl")
def test_localdetectorprocess_given_type_cpu_should_call_cputfl_init( def test_localdetectorprocess_given_type_cpu_should_call_cputfl_init(
self, mock_cputfl, mock_edgetputfl self, mock_cputfl, mock_edgetputfl
): ):
@ -22,8 +22,8 @@ class TestLocalObjectDetector(unittest.TestCase):
mock_edgetputfl.assert_not_called() mock_edgetputfl.assert_not_called()
mock_cputfl.assert_called_once_with(model_config=test_cfg, num_threads=6) mock_cputfl.assert_called_once_with(model_config=test_cfg, num_threads=6)
@patch("frigate.object_detection.EdgeTpuTfl") @patch("frigate.detectors.EdgeTpuTfl")
@patch("frigate.object_detection.CpuTfl") @patch("frigate.detectors.CpuTfl")
def test_localdetectorprocess_given_type_edgtpu_should_call_edgtpu_init( def test_localdetectorprocess_given_type_edgtpu_should_call_edgtpu_init(
self, mock_cputfl, mock_edgetputfl self, mock_cputfl, mock_edgetputfl
): ):
@ -40,7 +40,7 @@ class TestLocalObjectDetector(unittest.TestCase):
mock_cputfl.assert_not_called() mock_cputfl.assert_not_called()
mock_edgetputfl.assert_called_once_with(det_device="usb", model_config=test_cfg) mock_edgetputfl.assert_called_once_with(det_device="usb", model_config=test_cfg)
@patch("frigate.object_detection.CpuTfl") @patch("frigate.detectors.CpuTfl")
def test_detect_raw_given_tensor_input_should_return_api_detect_raw_result( def test_detect_raw_given_tensor_input_should_return_api_detect_raw_result(
self, mock_cputfl self, mock_cputfl
): ):
@ -58,7 +58,7 @@ class TestLocalObjectDetector(unittest.TestCase):
mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA) mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA)
assert test_result is mock_det_api.detect_raw.return_value assert test_result is mock_det_api.detect_raw.return_value
@patch("frigate.object_detection.CpuTfl") @patch("frigate.detectors.CpuTfl")
def test_detect_raw_given_tensor_input_should_call_api_detect_raw_with_transposed_tensor( def test_detect_raw_given_tensor_input_should_call_api_detect_raw_with_transposed_tensor(
self, mock_cputfl self, mock_cputfl
): ):
@ -85,7 +85,7 @@ class TestLocalObjectDetector(unittest.TestCase):
assert test_result is mock_det_api.detect_raw.return_value assert test_result is mock_det_api.detect_raw.return_value
@patch("frigate.object_detection.CpuTfl") @patch("frigate.detectors.CpuTfl")
@patch("frigate.object_detection.load_labels") @patch("frigate.object_detection.load_labels")
def test_detect_given_tensor_input_should_return_lfiltered_detections( def test_detect_given_tensor_input_should_return_lfiltered_detections(
self, mock_load_labels, mock_cputfl self, mock_load_labels, mock_cputfl