move create_detector and DetectorTypeEnum

This commit is contained in:
Dennis George 2022-12-08 11:25:46 -06:00
parent 8c55966de4
commit 3fa174eca4
9 changed files with 37 additions and 69 deletions

View File

@ -1,8 +1,35 @@
import os import logging
from enum import Enum
from .detector_type import DetectorTypeEnum
from .detection_api import DetectionApi from .detection_api import DetectionApi
from .cpu_tfl import CpuTfl from .cpu_tfl import CpuTfl
from .edgetpu_tfl import EdgeTpuTfl from .edgetpu_tfl import EdgeTpuTfl
from .openvino import OvDetector from .openvino import OvDetector
from .tensorrt import TensorRT
logger = logging.getLogger(__name__)
class DetectorTypeEnum(str, Enum):
edgetpu = "edgetpu"
openvino = "openvino"
cpu = "cpu"
tensorrt = "tensorrt"
def create_detector(det_type: DetectorTypeEnum, **kwargs):
_api_types = {
DetectorTypeEnum.cpu: CpuTfl,
DetectorTypeEnum.edgetpu: EdgeTpuTfl,
DetectorTypeEnum.openvino: OvDetector
}
if det_type == DetectorTypeEnum.cpu:
logger.warning(
"CPU detectors are not recommended and should only be used for testing or for trial purposes."
)
api = _api_types.get(det_type)
if not api:
raise ValueError(det_type)
return api(**kwargs)

View File

@ -2,7 +2,6 @@ import logging
import numpy as np import numpy as np
from .detection_api import DetectionApi from .detection_api import DetectionApi
from .detector_type import DetectorTypeEnum
import tflite_runtime.interpreter as tflite import tflite_runtime.interpreter as tflite
@ -46,6 +45,3 @@ class CpuTfl(DetectionApi):
] ]
return detections return detections
DetectionApi.register_api(DetectorTypeEnum.cpu, CpuTfl)

View File

@ -1,7 +1,5 @@
import logging import logging
from .detector_type import DetectorTypeEnum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict from typing import Dict
@ -10,8 +8,6 @@ logger = logging.getLogger(__name__)
class DetectionApi(ABC): class DetectionApi(ABC):
_api_types = {}
@abstractmethod @abstractmethod
def __init__(self, det_device=None, model_config=None): def __init__(self, det_device=None, model_config=None):
pass pass
@ -19,14 +15,3 @@ class DetectionApi(ABC):
@abstractmethod @abstractmethod
def detect_raw(self, tensor_input): def detect_raw(self, tensor_input):
pass pass
@staticmethod
def register_api(det_type: DetectorTypeEnum, det_api):
DetectionApi._api_types[det_type] = det_api
@staticmethod
def create(det_type: DetectorTypeEnum, **kwargs):
api = DetectionApi._api_types.get(det_type)
if not api:
raise ValueError(det_type)
return api(**kwargs)

View File

@ -1,8 +0,0 @@
from enum import Enum
class DetectorTypeEnum(str, Enum):
edgetpu = "edgetpu"
openvino = "openvino"
cpu = "cpu"
tensorrt = "tensorrt"

View File

@ -2,7 +2,6 @@ import logging
import numpy as np import numpy as np
from .detection_api import DetectionApi from .detection_api import DetectionApi
from .detector_type import DetectorTypeEnum
import tflite_runtime.interpreter as tflite import tflite_runtime.interpreter as tflite
from tflite_runtime.interpreter import load_delegate from tflite_runtime.interpreter import load_delegate
@ -63,6 +62,3 @@ class EdgeTpuTfl(DetectionApi):
] ]
return detections return detections
DetectionApi.register_api(DetectorTypeEnum.edgetpu, EdgeTpuTfl)

View File

@ -3,7 +3,6 @@ import numpy as np
import openvino.runtime as ov import openvino.runtime as ov
from .detection_api import DetectionApi from .detection_api import DetectionApi
from .detector_type import DetectorTypeEnum
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -53,6 +52,3 @@ class OvDetector(DetectionApi):
i += 1 i += 1
return detections return detections
DetectionApi.register_api(DetectorTypeEnum.openvino, OvDetector)

View File

@ -1,19 +0,0 @@
import logging
import numpy as np
from .detection_api import DetectionApi
from .detector_type import DetectorTypeEnum
logger = logging.getLogger(__name__)
class TensorRT(DetectionApi):
def __init__(self, det_device=None, model_config=None, num_threads=1, **kwargs):
raise NotImplementedError("TensorRT engine is not yet functional!")
def detect_raw(self, tensor_input):
raise NotImplementedError("TensorRT engine is not yet functional!")
DetectionApi.register_api(DetectorTypeEnum.tensorrt, TensorRT)

View File

@ -11,7 +11,7 @@ import numpy as np
from setproctitle import setproctitle from setproctitle import setproctitle
from frigate.config import InputTensorEnum from frigate.config import InputTensorEnum
from frigate.detectors import DetectionApi, DetectorTypeEnum from frigate.detectors import create_detector, DetectorTypeEnum
from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen, load_labels from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen, load_labels
@ -52,12 +52,7 @@ class LocalObjectDetector(ObjectDetector):
else: else:
self.input_transform = None self.input_transform = None
if det_type == DetectorTypeEnum.cpu: self.detect_api = create_detector(det_type, det_device=det_device, model_config=model_config, num_threads=num_threads)
logger.warning(
"CPU detectors are not recommended and should only be used for testing or for trial purposes."
)
self.detect_api = DetectionApi.create(det_type, det_device=det_device, model_config=model_config, num_threads=num_threads)
def detect(self, tensor_input, threshold=0.4): def detect(self, tensor_input, threshold=0.4):
detections = [] detections = []

View File

@ -7,8 +7,8 @@ import frigate.object_detection
class TestLocalObjectDetector(unittest.TestCase): class TestLocalObjectDetector(unittest.TestCase):
@patch("frigate.detectors.EdgeTpuTfl") @patch("frigate.detectors.edgetpu_tfl.EdgeTpuTfl")
@patch("frigate.detectors.CpuTfl") @patch("frigate.detectors.cpu_tfl.CpuTfl")
def test_localdetectorprocess_given_type_cpu_should_call_cputfl_init( def test_localdetectorprocess_given_type_cpu_should_call_cputfl_init(
self, mock_cputfl, mock_edgetputfl self, mock_cputfl, mock_edgetputfl
): ):
@ -22,8 +22,8 @@ class TestLocalObjectDetector(unittest.TestCase):
mock_edgetputfl.assert_not_called() mock_edgetputfl.assert_not_called()
mock_cputfl.assert_called_once_with(model_config=test_cfg, num_threads=6) mock_cputfl.assert_called_once_with(model_config=test_cfg, num_threads=6)
@patch("frigate.detectors.EdgeTpuTfl") @patch("frigate.detectors.edgetpu_tfl.EdgeTpuTfl")
@patch("frigate.detectors.CpuTfl") @patch("frigate.detectors.cpu_tfl.CpuTfl")
def test_localdetectorprocess_given_type_edgtpu_should_call_edgtpu_init( def test_localdetectorprocess_given_type_edgtpu_should_call_edgtpu_init(
self, mock_cputfl, mock_edgetputfl self, mock_cputfl, mock_edgetputfl
): ):
@ -40,7 +40,7 @@ class TestLocalObjectDetector(unittest.TestCase):
mock_cputfl.assert_not_called() mock_cputfl.assert_not_called()
mock_edgetputfl.assert_called_once_with(det_device="usb", model_config=test_cfg) mock_edgetputfl.assert_called_once_with(det_device="usb", model_config=test_cfg)
@patch("frigate.detectors.CpuTfl") @patch("frigate.detectors.cpu_tfl.CpuTfl")
def test_detect_raw_given_tensor_input_should_return_api_detect_raw_result( def test_detect_raw_given_tensor_input_should_return_api_detect_raw_result(
self, mock_cputfl self, mock_cputfl
): ):