Fix selecting the correct detection device type from the config

This commit is contained in:
Nate Meyer 2022-08-08 20:43:52 -04:00
parent c359c7c9bd
commit b0b2a28296
7 changed files with 28 additions and 16 deletions

View File

@ -3,6 +3,7 @@ from statistics import mean
import multiprocessing as mp import multiprocessing as mp
import numpy as np import numpy as np
import datetime import datetime
from frigate.config import DetectorTypeEnum
from frigate.object_detection import ( from frigate.object_detection import (
LocalObjectDetector, LocalObjectDetector,
ObjectDetectProcess, ObjectDetectProcess,
@ -81,8 +82,12 @@ events = {}
for x in range(0, 10): for x in range(0, 10):
events[str(x)] = mp.Event() events[str(x)] = mp.Event()
detection_queue = mp.Queue() detection_queue = mp.Queue()
edgetpu_process_1 = ObjectDetectProcess(detection_queue, events, "usb:0") edgetpu_process_1 = ObjectDetectProcess(
edgetpu_process_2 = ObjectDetectProcess(detection_queue, events, "usb:1") detection_queue, events, DetectorTypeEnum.edgetpu, "usb:0"
)
edgetpu_process_2 = ObjectDetectProcess(
detection_queue, events, DetectorTypeEnum.edgetpu, "usb:1"
)
for x in range(0, 10): for x in range(0, 10):
camera_process = mp.Process( camera_process = mp.Process(

View File

@ -205,6 +205,7 @@ class FrigateApp:
self.detection_out_events, self.detection_out_events,
model_path, model_path,
model_shape, model_shape,
detector.type,
"cpu", "cpu",
detector.num_threads, detector.num_threads,
) )
@ -215,6 +216,7 @@ class FrigateApp:
self.detection_out_events, self.detection_out_events,
model_path, model_path,
model_shape, model_shape,
detector.type,
detector.device, detector.device,
detector.num_threads, detector.num_threads,
) )

View File

View File

@ -8,9 +8,9 @@ logger = logging.getLogger(__name__)
class CpuTfl(DetectionApi): class CpuTfl(DetectionApi):
def __init__(self, tf_device=None, model_path=None, num_threads=3): def __init__(self, det_device=None, model_path=None, num_threads=3):
self.interpreter = tflite.Interpreter( self.interpreter = tflite.Interpreter(
model_path=model_path or "/cpu_model.tflite", num_threads=3 model_path=model_path or "/cpu_model.tflite", num_threads=num_threads
) )
self.interpreter.allocate_tensors() self.interpreter.allocate_tensors()

View File

@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
class DetectionApi(ABC): class DetectionApi(ABC):
@abstractmethod @abstractmethod
def __init__(self, tf_device=None, model_path=None): def __init__(self, det_device=None, model_path=None):
pass pass
@abstractmethod @abstractmethod

View File

@ -9,10 +9,10 @@ logger = logging.getLogger(__name__)
class EdgeTpuTfl(DetectionApi): class EdgeTpuTfl(DetectionApi):
def __init__(self, tf_device=None, model_path=None): def __init__(self, det_device=None, model_path=None):
device_config = {"device": "usb"} device_config = {"device": "usb"}
if not tf_device is None: if not det_device is None:
device_config = {"device": tf_device} device_config = {"device": det_device}
edge_tpu_delegate = None edge_tpu_delegate = None

View File

@ -40,14 +40,12 @@ class LocalObjectDetector(ObjectDetector):
self.labels = load_labels(labels) self.labels = load_labels(labels)
if det_type == DetectorTypeEnum.edgetpu: if det_type == DetectorTypeEnum.edgetpu:
self.detectApi = EdgeTpuTfl(tf_device=det_device, model_path=model_path) self.detectApi = EdgeTpuTfl(det_device=det_device, model_path=model_path)
else: else:
logger.warning( logger.warning(
"CPU detectors are not recommended and should only be used for testing or for trial purposes." "CPU detectors are not recommended and should only be used for testing or for trial purposes."
) )
self.detectApi = CpuTfl( self.detectApi = CpuTfl(model_path=model_path, num_threads=num_threads)
tf_device=det_device, model_path=model_path, num_threads=num_threads
)
def detect(self, tensor_input, threshold=0.4): def detect(self, tensor_input, threshold=0.4):
detections = [] detections = []
@ -75,6 +73,7 @@ def run_detector(
start, start,
model_path, model_path,
model_shape, model_shape,
det_type,
det_device, det_device,
num_threads, num_threads,
): ):
@ -94,7 +93,10 @@ def run_detector(
frame_manager = SharedMemoryFrameManager() frame_manager = SharedMemoryFrameManager()
object_detector = LocalObjectDetector( object_detector = LocalObjectDetector(
det_device=det_device, model_path=model_path, num_threads=num_threads det_type=det_type,
det_device=det_device,
model_path=model_path,
num_threads=num_threads,
) )
outputs = {} outputs = {}
@ -134,7 +136,8 @@ class ObjectDetectProcess:
out_events, out_events,
model_path, model_path,
model_shape, model_shape,
tf_device=None, det_type=None,
det_device=None,
num_threads=3, num_threads=3,
): ):
self.name = name self.name = name
@ -145,7 +148,8 @@ class ObjectDetectProcess:
self.detect_process = None self.detect_process = None
self.model_path = model_path self.model_path = model_path
self.model_shape = model_shape self.model_shape = model_shape
self.tf_device = tf_device self.det_type = det_type
self.det_device = det_device
self.num_threads = num_threads self.num_threads = num_threads
self.start_or_restart() self.start_or_restart()
@ -173,7 +177,8 @@ class ObjectDetectProcess:
self.detection_start, self.detection_start,
self.model_path, self.model_path,
self.model_shape, self.model_shape,
self.tf_device, self.det_type,
self.det_device,
self.num_threads, self.num_threads,
), ),
) )