refactor detectors

This commit is contained in:
Dennis George 2022-12-07 15:55:04 -06:00
parent b1ec56de29
commit 8c55966de4
10 changed files with 83 additions and 33 deletions

View File

@ -33,6 +33,7 @@ from frigate.ffmpeg_presets import (
parse_preset_output_rtmp,
)
from frigate.version import VERSION
from frigate.detectors import DetectorTypeEnum
logger = logging.getLogger(__name__)
@ -52,12 +53,6 @@ class FrigateBaseModel(BaseModel):
extra = Extra.forbid
class DetectorTypeEnum(str, Enum):
edgetpu = "edgetpu"
openvino = "openvino"
cpu = "cpu"
class DetectorConfig(FrigateBaseModel):
type: DetectorTypeEnum = Field(default=DetectorTypeEnum.cpu, title="Detector Type")
device: str = Field(default="usb", title="Device Type")

View File

@ -0,0 +1,8 @@
import os
from .detector_type import DetectorTypeEnum
from .detection_api import DetectionApi
from .cpu_tfl import CpuTfl
from .edgetpu_tfl import EdgeTpuTfl
from .openvino import OvDetector
from .tensorrt import TensorRT

View File

@ -1,14 +1,16 @@
import logging
import numpy as np
from frigate.detectors.detection_api import DetectionApi
from .detection_api import DetectionApi
from .detector_type import DetectorTypeEnum
import tflite_runtime.interpreter as tflite
logger = logging.getLogger(__name__)
class CpuTfl(DetectionApi):
def __init__(self, det_device=None, model_config=None, num_threads=3):
def __init__(self, det_device=None, model_config=None, num_threads=3, **kwargs):
self.interpreter = tflite.Interpreter(
model_path=model_config.path or "/cpu_model.tflite", num_threads=num_threads
)
@ -44,3 +46,6 @@ class CpuTfl(DetectionApi):
]
return detections
DetectionApi.register_api(DetectorTypeEnum.cpu, CpuTfl)

View File

@ -1,5 +1,7 @@
import logging
from .detector_type import DetectorTypeEnum
from abc import ABC, abstractmethod
from typing import Dict
@ -8,6 +10,8 @@ logger = logging.getLogger(__name__)
class DetectionApi(ABC):
_api_types = {}
@abstractmethod
def __init__(self, det_device=None, model_config=None):
pass
@ -15,3 +19,14 @@ class DetectionApi(ABC):
@abstractmethod
def detect_raw(self, tensor_input):
pass
@staticmethod
def register_api(det_type: DetectorTypeEnum, det_api):
DetectionApi._api_types[det_type] = det_api
@staticmethod
def create(det_type: DetectorTypeEnum, **kwargs):
api = DetectionApi._api_types.get(det_type)
if not api:
raise ValueError(det_type)
return api(**kwargs)

View File

@ -0,0 +1,8 @@
from enum import Enum
class DetectorTypeEnum(str, Enum):
edgetpu = "edgetpu"
openvino = "openvino"
cpu = "cpu"
tensorrt = "tensorrt"

View File

@ -1,15 +1,17 @@
import logging
import numpy as np
from frigate.detectors.detection_api import DetectionApi
from .detection_api import DetectionApi
from .detector_type import DetectorTypeEnum
import tflite_runtime.interpreter as tflite
from tflite_runtime.interpreter import load_delegate
logger = logging.getLogger(__name__)
class EdgeTpuTfl(DetectionApi):
def __init__(self, det_device=None, model_config=None):
def __init__(self, det_device=None, model_config=None, **kwargs):
device_config = {"device": "usb"}
if not det_device is None:
device_config = {"device": det_device}
@ -61,3 +63,6 @@ class EdgeTpuTfl(DetectionApi):
]
return detections
DetectionApi.register_api(DetectorTypeEnum.edgetpu, EdgeTpuTfl)

View File

@ -2,14 +2,15 @@ import logging
import numpy as np
import openvino.runtime as ov
from frigate.detectors.detection_api import DetectionApi
from .detection_api import DetectionApi
from .detector_type import DetectorTypeEnum
logger = logging.getLogger(__name__)
class OvDetector(DetectionApi):
def __init__(self, det_device=None, model_config=None, num_threads=1):
def __init__(self, det_device=None, model_config=None, num_threads=1, **kwargs):
self.ov_core = ov.Core()
self.ov_model = self.ov_core.read_model(model_config.path)
@ -52,3 +53,6 @@ class OvDetector(DetectionApi):
i += 1
return detections
DetectionApi.register_api(DetectorTypeEnum.openvino, OvDetector)

View File

@ -0,0 +1,19 @@
import logging
import numpy as np
from .detection_api import DetectionApi
from .detector_type import DetectorTypeEnum
logger = logging.getLogger(__name__)
class TensorRT(DetectionApi):
def __init__(self, det_device=None, model_config=None, num_threads=1, **kwargs):
raise NotImplementedError("TensorRT engine is not yet functional!")
def detect_raw(self, tensor_input):
raise NotImplementedError("TensorRT engine is not yet functional!")
DetectionApi.register_api(DetectorTypeEnum.tensorrt, TensorRT)

View File

@ -10,10 +10,8 @@ from abc import ABC, abstractmethod
import numpy as np
from setproctitle import setproctitle
from frigate.config import DetectorTypeEnum, InputTensorEnum
from frigate.detectors.edgetpu_tfl import EdgeTpuTfl
from frigate.detectors.openvino import OvDetector
from frigate.detectors.cpu_tfl import CpuTfl
from frigate.config import InputTensorEnum
from frigate.detectors import DetectionApi, DetectorTypeEnum
from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen, load_labels
@ -54,19 +52,12 @@ class LocalObjectDetector(ObjectDetector):
else:
self.input_transform = None
if det_type == DetectorTypeEnum.edgetpu:
self.detect_api = EdgeTpuTfl(
det_device=det_device, model_config=model_config
)
elif det_type == DetectorTypeEnum.openvino:
self.detect_api = OvDetector(
det_device=det_device, model_config=model_config
)
else:
if det_type == DetectorTypeEnum.cpu:
logger.warning(
"CPU detectors are not recommended and should only be used for testing or for trial purposes."
)
self.detect_api = CpuTfl(model_config=model_config, num_threads=num_threads)
self.detect_api = DetectionApi.create(det_type, det_device=det_device, model_config=model_config, num_threads=num_threads)
def detect(self, tensor_input, threshold=0.4):
detections = []

View File

@ -7,8 +7,8 @@ import frigate.object_detection
class TestLocalObjectDetector(unittest.TestCase):
@patch("frigate.object_detection.EdgeTpuTfl")
@patch("frigate.object_detection.CpuTfl")
@patch("frigate.detectors.EdgeTpuTfl")
@patch("frigate.detectors.CpuTfl")
def test_localdetectorprocess_given_type_cpu_should_call_cputfl_init(
self, mock_cputfl, mock_edgetputfl
):
@ -22,8 +22,8 @@ class TestLocalObjectDetector(unittest.TestCase):
mock_edgetputfl.assert_not_called()
mock_cputfl.assert_called_once_with(model_config=test_cfg, num_threads=6)
@patch("frigate.object_detection.EdgeTpuTfl")
@patch("frigate.object_detection.CpuTfl")
@patch("frigate.detectors.EdgeTpuTfl")
@patch("frigate.detectors.CpuTfl")
def test_localdetectorprocess_given_type_edgtpu_should_call_edgtpu_init(
self, mock_cputfl, mock_edgetputfl
):
@ -40,7 +40,7 @@ class TestLocalObjectDetector(unittest.TestCase):
mock_cputfl.assert_not_called()
mock_edgetputfl.assert_called_once_with(det_device="usb", model_config=test_cfg)
@patch("frigate.object_detection.CpuTfl")
@patch("frigate.detectors.CpuTfl")
def test_detect_raw_given_tensor_input_should_return_api_detect_raw_result(
self, mock_cputfl
):
@ -58,7 +58,7 @@ class TestLocalObjectDetector(unittest.TestCase):
mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA)
assert test_result is mock_det_api.detect_raw.return_value
@patch("frigate.object_detection.CpuTfl")
@patch("frigate.detectors.CpuTfl")
def test_detect_raw_given_tensor_input_should_call_api_detect_raw_with_transposed_tensor(
self, mock_cputfl
):
@ -85,7 +85,7 @@ class TestLocalObjectDetector(unittest.TestCase):
assert test_result is mock_det_api.detect_raw.return_value
@patch("frigate.object_detection.CpuTfl")
@patch("frigate.detectors.CpuTfl")
@patch("frigate.object_detection.load_labels")
def test_detect_given_tensor_input_should_return_lfiltered_detections(
self, mock_load_labels, mock_cputfl