diff --git a/benchmark.py b/benchmark.py index 79d6f3f4b..3d0cacd87 100755 --- a/benchmark.py +++ b/benchmark.py @@ -3,6 +3,7 @@ from statistics import mean import multiprocessing as mp import numpy as np import datetime +from frigate.config import DetectorTypeEnum from frigate.object_detection import ( LocalObjectDetector, ObjectDetectProcess, @@ -81,8 +82,12 @@ events = {} for x in range(0, 10): events[str(x)] = mp.Event() detection_queue = mp.Queue() -edgetpu_process_1 = ObjectDetectProcess(detection_queue, events, "usb:0") -edgetpu_process_2 = ObjectDetectProcess(detection_queue, events, "usb:1") +edgetpu_process_1 = ObjectDetectProcess( + detection_queue, events, DetectorTypeEnum.edgetpu, "usb:0" +) +edgetpu_process_2 = ObjectDetectProcess( + detection_queue, events, DetectorTypeEnum.edgetpu, "usb:1" +) for x in range(0, 10): camera_process = mp.Process( diff --git a/frigate/app.py b/frigate/app.py index 3e219b35c..cc4023efb 100644 --- a/frigate/app.py +++ b/frigate/app.py @@ -205,6 +205,7 @@ class FrigateApp: self.detection_out_events, model_path, model_shape, + detector.type, "cpu", detector.num_threads, ) @@ -215,6 +216,7 @@ class FrigateApp: self.detection_out_events, model_path, model_shape, + detector.type, detector.device, detector.num_threads, ) diff --git a/frigate/detector.py b/frigate/detector.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/frigate/detectors/cpu_tfl.py b/frigate/detectors/cpu_tfl.py index c250777bb..b0c160783 100644 --- a/frigate/detectors/cpu_tfl.py +++ b/frigate/detectors/cpu_tfl.py @@ -8,9 +8,9 @@ logger = logging.getLogger(__name__) class CpuTfl(DetectionApi): - def __init__(self, tf_device=None, model_path=None, num_threads=3): + def __init__(self, det_device=None, model_path=None, num_threads=3): self.interpreter = tflite.Interpreter( - model_path=model_path or "/cpu_model.tflite", num_threads=3 + model_path=model_path or "/cpu_model.tflite", num_threads=num_threads ) self.interpreter.allocate_tensors() diff --git a/frigate/detectors/detection_api.py b/frigate/detectors/detection_api.py index e96ca7eca..b8830334b 100644 --- a/frigate/detectors/detection_api.py +++ b/frigate/detectors/detection_api.py @@ -9,7 +9,7 @@ logger = logging.getLogger(__name__) class DetectionApi(ABC): @abstractmethod - def __init__(self, tf_device=None, model_path=None): + def __init__(self, det_device=None, model_path=None): pass @abstractmethod diff --git a/frigate/detectors/edgetpu_tfl.py b/frigate/detectors/edgetpu_tfl.py index e6d85ca3d..f917f95e0 100644 --- a/frigate/detectors/edgetpu_tfl.py +++ b/frigate/detectors/edgetpu_tfl.py @@ -9,10 +9,10 @@ logger = logging.getLogger(__name__) class EdgeTpuTfl(DetectionApi): - def __init__(self, tf_device=None, model_path=None): + def __init__(self, det_device=None, model_path=None): device_config = {"device": "usb"} - if not tf_device is None: - device_config = {"device": tf_device} + if not det_device is None: + device_config = {"device": det_device} edge_tpu_delegate = None diff --git a/frigate/object_detection.py b/frigate/object_detection.py index 6d51bd8ae..048462fa5 100644 --- a/frigate/object_detection.py +++ b/frigate/object_detection.py @@ -40,14 +40,12 @@ class LocalObjectDetector(ObjectDetector): self.labels = load_labels(labels) if det_type == DetectorTypeEnum.edgetpu: - self.detectApi = EdgeTpuTfl(tf_device=det_device, model_path=model_path) + self.detectApi = EdgeTpuTfl(det_device=det_device, model_path=model_path) else: logger.warning( "CPU detectors are not recommended and should only be used for testing or for trial purposes." ) - self.detectApi = CpuTfl( - tf_device=det_device, model_path=model_path, num_threads=num_threads - ) + self.detectApi = CpuTfl(model_path=model_path, num_threads=num_threads) def detect(self, tensor_input, threshold=0.4): detections = [] @@ -75,6 +73,7 @@ def run_detector( start, model_path, model_shape, + det_type, det_device, num_threads, ): @@ -94,7 +93,10 @@ def run_detector( frame_manager = SharedMemoryFrameManager() object_detector = LocalObjectDetector( - det_device=det_device, model_path=model_path, num_threads=num_threads + det_type=det_type, + det_device=det_device, + model_path=model_path, + num_threads=num_threads, ) outputs = {} @@ -134,7 +136,8 @@ class ObjectDetectProcess: out_events, model_path, model_shape, - tf_device=None, + det_type=None, + det_device=None, num_threads=3, ): self.name = name @@ -145,7 +148,8 @@ class ObjectDetectProcess: self.detect_process = None self.model_path = model_path self.model_shape = model_shape - self.tf_device = tf_device + self.det_type = det_type + self.det_device = det_device self.num_threads = num_threads self.start_or_restart() @@ -173,7 +177,8 @@ class ObjectDetectProcess: self.detection_start, self.model_path, self.model_shape, - self.tf_device, + self.det_type, + self.det_device, self.num_threads, ), )