mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-02-03 17:55:21 +03:00
move create_detector and DetectorTypeEnum
This commit is contained in:
parent
8c55966de4
commit
3fa174eca4
@ -1,8 +1,35 @@
|
||||
import os
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
from .detector_type import DetectorTypeEnum
|
||||
from .detection_api import DetectionApi
|
||||
from .cpu_tfl import CpuTfl
|
||||
from .edgetpu_tfl import EdgeTpuTfl
|
||||
from .openvino import OvDetector
|
||||
from .tensorrt import TensorRT
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DetectorTypeEnum(str, Enum):
|
||||
edgetpu = "edgetpu"
|
||||
openvino = "openvino"
|
||||
cpu = "cpu"
|
||||
tensorrt = "tensorrt"
|
||||
|
||||
|
||||
def create_detector(det_type: DetectorTypeEnum, **kwargs):
|
||||
_api_types = {
|
||||
DetectorTypeEnum.cpu: CpuTfl,
|
||||
DetectorTypeEnum.edgetpu: EdgeTpuTfl,
|
||||
DetectorTypeEnum.openvino: OvDetector
|
||||
}
|
||||
|
||||
if det_type == DetectorTypeEnum.cpu:
|
||||
logger.warning(
|
||||
"CPU detectors are not recommended and should only be used for testing or for trial purposes."
|
||||
)
|
||||
|
||||
api = _api_types.get(det_type)
|
||||
if not api:
|
||||
raise ValueError(det_type)
|
||||
return api(**kwargs)
|
||||
|
||||
@ -2,7 +2,6 @@ import logging
|
||||
import numpy as np
|
||||
|
||||
from .detection_api import DetectionApi
|
||||
from .detector_type import DetectorTypeEnum
|
||||
import tflite_runtime.interpreter as tflite
|
||||
|
||||
|
||||
@ -46,6 +45,3 @@ class CpuTfl(DetectionApi):
|
||||
]
|
||||
|
||||
return detections
|
||||
|
||||
|
||||
DetectionApi.register_api(DetectorTypeEnum.cpu, CpuTfl)
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
import logging
|
||||
|
||||
from .detector_type import DetectorTypeEnum
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict
|
||||
|
||||
@ -10,8 +8,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DetectionApi(ABC):
|
||||
_api_types = {}
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, det_device=None, model_config=None):
|
||||
pass
|
||||
@ -19,14 +15,3 @@ class DetectionApi(ABC):
|
||||
@abstractmethod
|
||||
def detect_raw(self, tensor_input):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def register_api(det_type: DetectorTypeEnum, det_api):
|
||||
DetectionApi._api_types[det_type] = det_api
|
||||
|
||||
@staticmethod
|
||||
def create(det_type: DetectorTypeEnum, **kwargs):
|
||||
api = DetectionApi._api_types.get(det_type)
|
||||
if not api:
|
||||
raise ValueError(det_type)
|
||||
return api(**kwargs)
|
||||
|
||||
@ -1,8 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class DetectorTypeEnum(str, Enum):
|
||||
edgetpu = "edgetpu"
|
||||
openvino = "openvino"
|
||||
cpu = "cpu"
|
||||
tensorrt = "tensorrt"
|
||||
@ -2,7 +2,6 @@ import logging
|
||||
import numpy as np
|
||||
|
||||
from .detection_api import DetectionApi
|
||||
from .detector_type import DetectorTypeEnum
|
||||
import tflite_runtime.interpreter as tflite
|
||||
from tflite_runtime.interpreter import load_delegate
|
||||
|
||||
@ -63,6 +62,3 @@ class EdgeTpuTfl(DetectionApi):
|
||||
]
|
||||
|
||||
return detections
|
||||
|
||||
|
||||
DetectionApi.register_api(DetectorTypeEnum.edgetpu, EdgeTpuTfl)
|
||||
|
||||
@ -3,7 +3,6 @@ import numpy as np
|
||||
import openvino.runtime as ov
|
||||
|
||||
from .detection_api import DetectionApi
|
||||
from .detector_type import DetectorTypeEnum
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -53,6 +52,3 @@ class OvDetector(DetectionApi):
|
||||
i += 1
|
||||
|
||||
return detections
|
||||
|
||||
|
||||
DetectionApi.register_api(DetectorTypeEnum.openvino, OvDetector)
|
||||
|
||||
@ -1,19 +0,0 @@
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from .detection_api import DetectionApi
|
||||
from .detector_type import DetectorTypeEnum
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TensorRT(DetectionApi):
|
||||
def __init__(self, det_device=None, model_config=None, num_threads=1, **kwargs):
|
||||
raise NotImplementedError("TensorRT engine is not yet functional!")
|
||||
|
||||
def detect_raw(self, tensor_input):
|
||||
raise NotImplementedError("TensorRT engine is not yet functional!")
|
||||
|
||||
|
||||
DetectionApi.register_api(DetectorTypeEnum.tensorrt, TensorRT)
|
||||
@ -11,7 +11,7 @@ import numpy as np
|
||||
from setproctitle import setproctitle
|
||||
|
||||
from frigate.config import InputTensorEnum
|
||||
from frigate.detectors import DetectionApi, DetectorTypeEnum
|
||||
from frigate.detectors import create_detector, DetectorTypeEnum
|
||||
|
||||
from frigate.util import EventsPerSecond, SharedMemoryFrameManager, listen, load_labels
|
||||
|
||||
@ -52,12 +52,7 @@ class LocalObjectDetector(ObjectDetector):
|
||||
else:
|
||||
self.input_transform = None
|
||||
|
||||
if det_type == DetectorTypeEnum.cpu:
|
||||
logger.warning(
|
||||
"CPU detectors are not recommended and should only be used for testing or for trial purposes."
|
||||
)
|
||||
|
||||
self.detect_api = DetectionApi.create(det_type, det_device=det_device, model_config=model_config, num_threads=num_threads)
|
||||
self.detect_api = create_detector(det_type, det_device=det_device, model_config=model_config, num_threads=num_threads)
|
||||
|
||||
def detect(self, tensor_input, threshold=0.4):
|
||||
detections = []
|
||||
|
||||
@ -7,8 +7,8 @@ import frigate.object_detection
|
||||
|
||||
|
||||
class TestLocalObjectDetector(unittest.TestCase):
|
||||
@patch("frigate.detectors.EdgeTpuTfl")
|
||||
@patch("frigate.detectors.CpuTfl")
|
||||
@patch("frigate.detectors.edgetpu_tfl.EdgeTpuTfl")
|
||||
@patch("frigate.detectors.cpu_tfl.CpuTfl")
|
||||
def test_localdetectorprocess_given_type_cpu_should_call_cputfl_init(
|
||||
self, mock_cputfl, mock_edgetputfl
|
||||
):
|
||||
@ -22,8 +22,8 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
mock_edgetputfl.assert_not_called()
|
||||
mock_cputfl.assert_called_once_with(model_config=test_cfg, num_threads=6)
|
||||
|
||||
@patch("frigate.detectors.EdgeTpuTfl")
|
||||
@patch("frigate.detectors.CpuTfl")
|
||||
@patch("frigate.detectors.edgetpu_tfl.EdgeTpuTfl")
|
||||
@patch("frigate.detectors.cpu_tfl.CpuTfl")
|
||||
def test_localdetectorprocess_given_type_edgtpu_should_call_edgtpu_init(
|
||||
self, mock_cputfl, mock_edgetputfl
|
||||
):
|
||||
@ -40,7 +40,7 @@ class TestLocalObjectDetector(unittest.TestCase):
|
||||
mock_cputfl.assert_not_called()
|
||||
mock_edgetputfl.assert_called_once_with(det_device="usb", model_config=test_cfg)
|
||||
|
||||
@patch("frigate.detectors.CpuTfl")
|
||||
@patch("frigate.detectors.cpu_tfl.CpuTfl")
|
||||
def test_detect_raw_given_tensor_input_should_return_api_detect_raw_result(
|
||||
self, mock_cputfl
|
||||
):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user