diff --git a/frigate/config.py b/frigate/config.py index 44e85f9a4..8d8f9bd7e 100644 --- a/frigate/config.py +++ b/frigate/config.py @@ -33,6 +33,7 @@ from frigate.ffmpeg_presets import ( parse_preset_output_rtmp, ) from frigate.version import VERSION +from frigate.detectors import DetectorTypeEnum logger = logging.getLogger(__name__) @@ -52,12 +53,6 @@ class FrigateBaseModel(BaseModel): extra = Extra.forbid -class DetectorTypeEnum(str, Enum): - edgetpu = "edgetpu" - openvino = "openvino" - cpu = "cpu" - - class DetectorConfig(FrigateBaseModel): type: DetectorTypeEnum = Field(default=DetectorTypeEnum.cpu, title="Detector Type") device: str = Field(default="usb", title="Device Type") diff --git a/frigate/detectors/__init__.py b/frigate/detectors/__init__.py index e69de29bb..b8d1364e5 100644 --- a/frigate/detectors/__init__.py +++ b/frigate/detectors/__init__.py @@ -0,0 +1,8 @@ +import os + +from .detector_type import DetectorTypeEnum +from .detection_api import DetectionApi +from .cpu_tfl import CpuTfl +from .edgetpu_tfl import EdgeTpuTfl +from .openvino import OvDetector +from .tensorrt import TensorRT diff --git a/frigate/detectors/cpu_tfl.py b/frigate/detectors/cpu_tfl.py index ea1e4ddc2..0828bb38f 100644 --- a/frigate/detectors/cpu_tfl.py +++ b/frigate/detectors/cpu_tfl.py @@ -1,14 +1,16 @@ import logging import numpy as np -from frigate.detectors.detection_api import DetectionApi +from .detection_api import DetectionApi +from .detector_type import DetectorTypeEnum import tflite_runtime.interpreter as tflite + logger = logging.getLogger(__name__) class CpuTfl(DetectionApi): - def __init__(self, det_device=None, model_config=None, num_threads=3): + def __init__(self, det_device=None, model_config=None, num_threads=3, **kwargs): self.interpreter = tflite.Interpreter( model_path=model_config.path or "/cpu_model.tflite", num_threads=num_threads ) @@ -44,3 +46,6 @@ class CpuTfl(DetectionApi): ] return detections + + +DetectionApi.register_api(DetectorTypeEnum.cpu, CpuTfl) diff --git a/frigate/detectors/detection_api.py b/frigate/detectors/detection_api.py index 244195d46..11babed18 100644 --- a/frigate/detectors/detection_api.py +++ b/frigate/detectors/detection_api.py @@ -1,5 +1,7 @@ import logging +from .detector_type import DetectorTypeEnum + from abc import ABC, abstractmethod from typing import Dict @@ -8,6 +10,8 @@ logger = logging.getLogger(__name__) class DetectionApi(ABC): + _api_types = {} + @abstractmethod def __init__(self, det_device=None, model_config=None): pass @@ -15,3 +19,14 @@ class DetectionApi(ABC): @abstractmethod def detect_raw(self, tensor_input): pass + + @staticmethod + def register_api(det_type: DetectorTypeEnum, det_api): + DetectionApi._api_types[det_type] = det_api + + @staticmethod + def create(det_type: DetectorTypeEnum, **kwargs): + api = DetectionApi._api_types.get(det_type) + if not api: + raise ValueError(det_type) + return api(**kwargs) diff --git a/frigate/detectors/detector_type.py b/frigate/detectors/detector_type.py new file mode 100644 index 000000000..266ae13d2 --- /dev/null +++ b/frigate/detectors/detector_type.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class DetectorTypeEnum(str, Enum): + edgetpu = "edgetpu" + openvino = "openvino" + cpu = "cpu" + tensorrt = "tensorrt" diff --git a/frigate/detectors/edgetpu_tfl.py b/frigate/detectors/edgetpu_tfl.py index aa3abf70c..1592bccf2 100644 --- a/frigate/detectors/edgetpu_tfl.py +++ b/frigate/detectors/edgetpu_tfl.py @@ -1,15 +1,17 @@ import logging import numpy as np -from frigate.detectors.detection_api import DetectionApi +from .detection_api import DetectionApi +from .detector_type import DetectorTypeEnum import tflite_runtime.interpreter as tflite from tflite_runtime.interpreter import load_delegate + logger = logging.getLogger(__name__) class EdgeTpuTfl(DetectionApi): - def __init__(self, det_device=None, model_config=None): + def __init__(self, det_device=None, model_config=None, **kwargs): device_config = {"device": "usb"} if not det_device is None: device_config = {"device": det_device} @@ -61,3 +63,6 @@ class EdgeTpuTfl(DetectionApi): ] return detections + + +DetectionApi.register_api(DetectorTypeEnum.edgetpu, EdgeTpuTfl) diff --git a/frigate/detectors/openvino.py b/frigate/detectors/openvino.py index 02bfa1b42..659f4decc 100644 --- a/frigate/detectors/openvino.py +++ b/frigate/detectors/openvino.py @@ -2,14 +2,15 @@ import logging import numpy as np import openvino.runtime as ov -from frigate.detectors.detection_api import DetectionApi +from .detection_api import DetectionApi +from .detector_type import DetectorTypeEnum logger = logging.getLogger(__name__) class OvDetector(DetectionApi): - def __init__(self, det_device=None, model_config=None, num_threads=1): + def __init__(self, det_device=None, model_config=None, num_threads=1, **kwargs): self.ov_core = ov.Core() self.ov_model = self.ov_core.read_model(model_config.path) @@ -52,3 +53,6 @@ class OvDetector(DetectionApi): i += 1 return detections + + +DetectionApi.register_api(DetectorTypeEnum.openvino, OvDetector) diff --git a/frigate/detectors/tensorrt.py b/frigate/detectors/tensorrt.py new file mode 100644 index 000000000..5c6fbe54b --- /dev/null +++ b/frigate/detectors/tensorrt.py @@ -0,0 +1,19 @@ +import logging +import numpy as np + +from .detection_api import DetectionApi +from .detector_type import DetectorTypeEnum + + +logger = logging.getLogger(__name__) + + +class TensorRT(DetectionApi): + def __init__(self, det_device=None, model_config=None, num_threads=1, **kwargs): + raise NotImplementedError("TensorRT engine is not yet functional!") + + def detect_raw(self, tensor_input): + raise NotImplementedError("TensorRT engine is not yet functional!") + + +DetectionApi.register_api(DetectorTypeEnum.tensorrt, TensorRT) diff --git a/frigate/object_detection.py b/frigate/object_detection.py index 83efb4f6b..e79a94095 100644 --- a/frigate/object_detection.py +++ b/frigate/object_detection.py @@ -10,10 +10,8 @@ from abc import ABC, abstractmethod import numpy as np from setproctitle import setproctitle -from frigate.config import DetectorTypeEnum, InputTensorEnum -from frigate.detectors.edgetpu_tfl import EdgeTpuTfl -from frigate.detectors.openvino import OvDetector -from frigate.detectors.cpu_tfl import CpuTfl +from frigate.config import InputTensorEnum +from frigate.detectors import DetectionApi, DetectorTypeEnum from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen, load_labels @@ -54,19 +52,12 @@ class LocalObjectDetector(ObjectDetector): else: self.input_transform = None - if det_type == DetectorTypeEnum.edgetpu: - self.detect_api = EdgeTpuTfl( - det_device=det_device, model_config=model_config - ) - elif det_type == DetectorTypeEnum.openvino: - self.detect_api = OvDetector( - det_device=det_device, model_config=model_config - ) - else: + if det_type == DetectorTypeEnum.cpu: logger.warning( "CPU detectors are not recommended and should only be used for testing or for trial purposes." ) - self.detect_api = CpuTfl(model_config=model_config, num_threads=num_threads) + + self.detect_api = DetectionApi.create(det_type, det_device=det_device, model_config=model_config, num_threads=num_threads) def detect(self, tensor_input, threshold=0.4): detections = [] diff --git a/frigate/test/test_object_detector.py b/frigate/test/test_object_detector.py index f90f4d16c..ea2479284 100644 --- a/frigate/test/test_object_detector.py +++ b/frigate/test/test_object_detector.py @@ -7,8 +7,8 @@ import frigate.object_detection class TestLocalObjectDetector(unittest.TestCase): - @patch("frigate.object_detection.EdgeTpuTfl") - @patch("frigate.object_detection.CpuTfl") + @patch("frigate.detectors.EdgeTpuTfl") + @patch("frigate.detectors.CpuTfl") def test_localdetectorprocess_given_type_cpu_should_call_cputfl_init( self, mock_cputfl, mock_edgetputfl ): @@ -22,8 +22,8 @@ class TestLocalObjectDetector(unittest.TestCase): mock_edgetputfl.assert_not_called() mock_cputfl.assert_called_once_with(model_config=test_cfg, num_threads=6) - @patch("frigate.object_detection.EdgeTpuTfl") - @patch("frigate.object_detection.CpuTfl") + @patch("frigate.detectors.EdgeTpuTfl") + @patch("frigate.detectors.CpuTfl") def test_localdetectorprocess_given_type_edgtpu_should_call_edgtpu_init( self, mock_cputfl, mock_edgetputfl ): @@ -40,7 +40,7 @@ class TestLocalObjectDetector(unittest.TestCase): mock_cputfl.assert_not_called() mock_edgetputfl.assert_called_once_with(det_device="usb", model_config=test_cfg) - @patch("frigate.object_detection.CpuTfl") + @patch("frigate.detectors.CpuTfl") def test_detect_raw_given_tensor_input_should_return_api_detect_raw_result( self, mock_cputfl ): @@ -58,7 +58,7 @@ class TestLocalObjectDetector(unittest.TestCase): mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA) assert test_result is mock_det_api.detect_raw.return_value - @patch("frigate.object_detection.CpuTfl") + @patch("frigate.detectors.CpuTfl") def test_detect_raw_given_tensor_input_should_call_api_detect_raw_with_transposed_tensor( self, mock_cputfl ): @@ -85,7 +85,7 @@ class TestLocalObjectDetector(unittest.TestCase): assert test_result is mock_det_api.detect_raw.return_value - @patch("frigate.object_detection.CpuTfl") + @patch("frigate.detectors.CpuTfl") @patch("frigate.object_detection.load_labels") def test_detect_given_tensor_input_should_return_lfiltered_detections( self, mock_load_labels, mock_cputfl