mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-03 09:45:22 +03:00
refactor detectors
This commit is contained in:
parent
b1ec56de29
commit
8c55966de4
@ -33,6 +33,7 @@ from frigate.ffmpeg_presets import (
|
|||||||
parse_preset_output_rtmp,
|
parse_preset_output_rtmp,
|
||||||
)
|
)
|
||||||
from frigate.version import VERSION
|
from frigate.version import VERSION
|
||||||
|
from frigate.detectors import DetectorTypeEnum
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -52,12 +53,6 @@ class FrigateBaseModel(BaseModel):
|
|||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
|
|
||||||
|
|
||||||
class DetectorTypeEnum(str, Enum):
|
|
||||||
edgetpu = "edgetpu"
|
|
||||||
openvino = "openvino"
|
|
||||||
cpu = "cpu"
|
|
||||||
|
|
||||||
|
|
||||||
class DetectorConfig(FrigateBaseModel):
|
class DetectorConfig(FrigateBaseModel):
|
||||||
type: DetectorTypeEnum = Field(default=DetectorTypeEnum.cpu, title="Detector Type")
|
type: DetectorTypeEnum = Field(default=DetectorTypeEnum.cpu, title="Detector Type")
|
||||||
device: str = Field(default="usb", title="Device Type")
|
device: str = Field(default="usb", title="Device Type")
|
||||||
|
|||||||
@ -0,0 +1,8 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from .detector_type import DetectorTypeEnum
|
||||||
|
from .detection_api import DetectionApi
|
||||||
|
from .cpu_tfl import CpuTfl
|
||||||
|
from .edgetpu_tfl import EdgeTpuTfl
|
||||||
|
from .openvino import OvDetector
|
||||||
|
from .tensorrt import TensorRT
|
||||||
@ -1,14 +1,16 @@
|
|||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from frigate.detectors.detection_api import DetectionApi
|
from .detection_api import DetectionApi
|
||||||
|
from .detector_type import DetectorTypeEnum
|
||||||
import tflite_runtime.interpreter as tflite
|
import tflite_runtime.interpreter as tflite
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CpuTfl(DetectionApi):
|
class CpuTfl(DetectionApi):
|
||||||
def __init__(self, det_device=None, model_config=None, num_threads=3):
|
def __init__(self, det_device=None, model_config=None, num_threads=3, **kwargs):
|
||||||
self.interpreter = tflite.Interpreter(
|
self.interpreter = tflite.Interpreter(
|
||||||
model_path=model_config.path or "/cpu_model.tflite", num_threads=num_threads
|
model_path=model_config.path or "/cpu_model.tflite", num_threads=num_threads
|
||||||
)
|
)
|
||||||
@ -44,3 +46,6 @@ class CpuTfl(DetectionApi):
|
|||||||
]
|
]
|
||||||
|
|
||||||
return detections
|
return detections
|
||||||
|
|
||||||
|
|
||||||
|
DetectionApi.register_api(DetectorTypeEnum.cpu, CpuTfl)
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from .detector_type import DetectorTypeEnum
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
@ -8,6 +10,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class DetectionApi(ABC):
|
class DetectionApi(ABC):
|
||||||
|
_api_types = {}
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(self, det_device=None, model_config=None):
|
def __init__(self, det_device=None, model_config=None):
|
||||||
pass
|
pass
|
||||||
@ -15,3 +19,14 @@ class DetectionApi(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def detect_raw(self, tensor_input):
|
def detect_raw(self, tensor_input):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def register_api(det_type: DetectorTypeEnum, det_api):
|
||||||
|
DetectionApi._api_types[det_type] = det_api
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create(det_type: DetectorTypeEnum, **kwargs):
|
||||||
|
api = DetectionApi._api_types.get(det_type)
|
||||||
|
if not api:
|
||||||
|
raise ValueError(det_type)
|
||||||
|
return api(**kwargs)
|
||||||
|
|||||||
8
frigate/detectors/detector_type.py
Normal file
8
frigate/detectors/detector_type.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class DetectorTypeEnum(str, Enum):
|
||||||
|
edgetpu = "edgetpu"
|
||||||
|
openvino = "openvino"
|
||||||
|
cpu = "cpu"
|
||||||
|
tensorrt = "tensorrt"
|
||||||
@ -1,15 +1,17 @@
|
|||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from frigate.detectors.detection_api import DetectionApi
|
from .detection_api import DetectionApi
|
||||||
|
from .detector_type import DetectorTypeEnum
|
||||||
import tflite_runtime.interpreter as tflite
|
import tflite_runtime.interpreter as tflite
|
||||||
from tflite_runtime.interpreter import load_delegate
|
from tflite_runtime.interpreter import load_delegate
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EdgeTpuTfl(DetectionApi):
|
class EdgeTpuTfl(DetectionApi):
|
||||||
def __init__(self, det_device=None, model_config=None):
|
def __init__(self, det_device=None, model_config=None, **kwargs):
|
||||||
device_config = {"device": "usb"}
|
device_config = {"device": "usb"}
|
||||||
if not det_device is None:
|
if not det_device is None:
|
||||||
device_config = {"device": det_device}
|
device_config = {"device": det_device}
|
||||||
@ -61,3 +63,6 @@ class EdgeTpuTfl(DetectionApi):
|
|||||||
]
|
]
|
||||||
|
|
||||||
return detections
|
return detections
|
||||||
|
|
||||||
|
|
||||||
|
DetectionApi.register_api(DetectorTypeEnum.edgetpu, EdgeTpuTfl)
|
||||||
|
|||||||
@ -2,14 +2,15 @@ import logging
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import openvino.runtime as ov
|
import openvino.runtime as ov
|
||||||
|
|
||||||
from frigate.detectors.detection_api import DetectionApi
|
from .detection_api import DetectionApi
|
||||||
|
from .detector_type import DetectorTypeEnum
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OvDetector(DetectionApi):
|
class OvDetector(DetectionApi):
|
||||||
def __init__(self, det_device=None, model_config=None, num_threads=1):
|
def __init__(self, det_device=None, model_config=None, num_threads=1, **kwargs):
|
||||||
self.ov_core = ov.Core()
|
self.ov_core = ov.Core()
|
||||||
self.ov_model = self.ov_core.read_model(model_config.path)
|
self.ov_model = self.ov_core.read_model(model_config.path)
|
||||||
|
|
||||||
@ -52,3 +53,6 @@ class OvDetector(DetectionApi):
|
|||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
return detections
|
return detections
|
||||||
|
|
||||||
|
|
||||||
|
DetectionApi.register_api(DetectorTypeEnum.openvino, OvDetector)
|
||||||
|
|||||||
19
frigate/detectors/tensorrt.py
Normal file
19
frigate/detectors/tensorrt.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .detection_api import DetectionApi
|
||||||
|
from .detector_type import DetectorTypeEnum
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TensorRT(DetectionApi):
|
||||||
|
def __init__(self, det_device=None, model_config=None, num_threads=1, **kwargs):
|
||||||
|
raise NotImplementedError("TensorRT engine is not yet functional!")
|
||||||
|
|
||||||
|
def detect_raw(self, tensor_input):
|
||||||
|
raise NotImplementedError("TensorRT engine is not yet functional!")
|
||||||
|
|
||||||
|
|
||||||
|
DetectionApi.register_api(DetectorTypeEnum.tensorrt, TensorRT)
|
||||||
@ -10,10 +10,8 @@ from abc import ABC, abstractmethod
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from setproctitle import setproctitle
|
from setproctitle import setproctitle
|
||||||
|
|
||||||
from frigate.config import DetectorTypeEnum, InputTensorEnum
|
from frigate.config import InputTensorEnum
|
||||||
from frigate.detectors.edgetpu_tfl import EdgeTpuTfl
|
from frigate.detectors import DetectionApi, DetectorTypeEnum
|
||||||
from frigate.detectors.openvino import OvDetector
|
|
||||||
from frigate.detectors.cpu_tfl import CpuTfl
|
|
||||||
|
|
||||||
from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen, load_labels
|
from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen, load_labels
|
||||||
|
|
||||||
@ -54,19 +52,12 @@ class LocalObjectDetector(ObjectDetector):
|
|||||||
else:
|
else:
|
||||||
self.input_transform = None
|
self.input_transform = None
|
||||||
|
|
||||||
if det_type == DetectorTypeEnum.edgetpu:
|
if det_type == DetectorTypeEnum.cpu:
|
||||||
self.detect_api = EdgeTpuTfl(
|
|
||||||
det_device=det_device, model_config=model_config
|
|
||||||
)
|
|
||||||
elif det_type == DetectorTypeEnum.openvino:
|
|
||||||
self.detect_api = OvDetector(
|
|
||||||
det_device=det_device, model_config=model_config
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"CPU detectors are not recommended and should only be used for testing or for trial purposes."
|
"CPU detectors are not recommended and should only be used for testing or for trial purposes."
|
||||||
)
|
)
|
||||||
self.detect_api = CpuTfl(model_config=model_config, num_threads=num_threads)
|
|
||||||
|
self.detect_api = DetectionApi.create(det_type, det_device=det_device, model_config=model_config, num_threads=num_threads)
|
||||||
|
|
||||||
def detect(self, tensor_input, threshold=0.4):
|
def detect(self, tensor_input, threshold=0.4):
|
||||||
detections = []
|
detections = []
|
||||||
|
|||||||
@ -7,8 +7,8 @@ import frigate.object_detection
|
|||||||
|
|
||||||
|
|
||||||
class TestLocalObjectDetector(unittest.TestCase):
|
class TestLocalObjectDetector(unittest.TestCase):
|
||||||
@patch("frigate.object_detection.EdgeTpuTfl")
|
@patch("frigate.detectors.EdgeTpuTfl")
|
||||||
@patch("frigate.object_detection.CpuTfl")
|
@patch("frigate.detectors.CpuTfl")
|
||||||
def test_localdetectorprocess_given_type_cpu_should_call_cputfl_init(
|
def test_localdetectorprocess_given_type_cpu_should_call_cputfl_init(
|
||||||
self, mock_cputfl, mock_edgetputfl
|
self, mock_cputfl, mock_edgetputfl
|
||||||
):
|
):
|
||||||
@ -22,8 +22,8 @@ class TestLocalObjectDetector(unittest.TestCase):
|
|||||||
mock_edgetputfl.assert_not_called()
|
mock_edgetputfl.assert_not_called()
|
||||||
mock_cputfl.assert_called_once_with(model_config=test_cfg, num_threads=6)
|
mock_cputfl.assert_called_once_with(model_config=test_cfg, num_threads=6)
|
||||||
|
|
||||||
@patch("frigate.object_detection.EdgeTpuTfl")
|
@patch("frigate.detectors.EdgeTpuTfl")
|
||||||
@patch("frigate.object_detection.CpuTfl")
|
@patch("frigate.detectors.CpuTfl")
|
||||||
def test_localdetectorprocess_given_type_edgtpu_should_call_edgtpu_init(
|
def test_localdetectorprocess_given_type_edgtpu_should_call_edgtpu_init(
|
||||||
self, mock_cputfl, mock_edgetputfl
|
self, mock_cputfl, mock_edgetputfl
|
||||||
):
|
):
|
||||||
@ -40,7 +40,7 @@ class TestLocalObjectDetector(unittest.TestCase):
|
|||||||
mock_cputfl.assert_not_called()
|
mock_cputfl.assert_not_called()
|
||||||
mock_edgetputfl.assert_called_once_with(det_device="usb", model_config=test_cfg)
|
mock_edgetputfl.assert_called_once_with(det_device="usb", model_config=test_cfg)
|
||||||
|
|
||||||
@patch("frigate.object_detection.CpuTfl")
|
@patch("frigate.detectors.CpuTfl")
|
||||||
def test_detect_raw_given_tensor_input_should_return_api_detect_raw_result(
|
def test_detect_raw_given_tensor_input_should_return_api_detect_raw_result(
|
||||||
self, mock_cputfl
|
self, mock_cputfl
|
||||||
):
|
):
|
||||||
@ -58,7 +58,7 @@ class TestLocalObjectDetector(unittest.TestCase):
|
|||||||
mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA)
|
mock_det_api.detect_raw.assert_called_once_with(tensor_input=TEST_DATA)
|
||||||
assert test_result is mock_det_api.detect_raw.return_value
|
assert test_result is mock_det_api.detect_raw.return_value
|
||||||
|
|
||||||
@patch("frigate.object_detection.CpuTfl")
|
@patch("frigate.detectors.CpuTfl")
|
||||||
def test_detect_raw_given_tensor_input_should_call_api_detect_raw_with_transposed_tensor(
|
def test_detect_raw_given_tensor_input_should_call_api_detect_raw_with_transposed_tensor(
|
||||||
self, mock_cputfl
|
self, mock_cputfl
|
||||||
):
|
):
|
||||||
@ -85,7 +85,7 @@ class TestLocalObjectDetector(unittest.TestCase):
|
|||||||
|
|
||||||
assert test_result is mock_det_api.detect_raw.return_value
|
assert test_result is mock_det_api.detect_raw.return_value
|
||||||
|
|
||||||
@patch("frigate.object_detection.CpuTfl")
|
@patch("frigate.detectors.CpuTfl")
|
||||||
@patch("frigate.object_detection.load_labels")
|
@patch("frigate.object_detection.load_labels")
|
||||||
def test_detect_given_tensor_input_should_return_lfiltered_detections(
|
def test_detect_given_tensor_input_should_return_lfiltered_detections(
|
||||||
self, mock_load_labels, mock_cputfl
|
self, mock_load_labels, mock_cputfl
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user