From bf4cb22118f627d35022dcf7396b105e20d83d33 Mon Sep 17 00:00:00 2001 From: Nick Mowen Date: Sun, 18 Jun 2023 13:04:41 -0600 Subject: [PATCH] Load labelmap correctly --- frigate/detectors/plugins/audio_tfl.py | 25 ++++++++++++++++++------- frigate/events/audio.py | 2 ++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/frigate/detectors/plugins/audio_tfl.py b/frigate/detectors/plugins/audio_tfl.py index 3f24b7790..a3750ebfa 100644 --- a/frigate/detectors/plugins/audio_tfl.py +++ b/frigate/detectors/plugins/audio_tfl.py @@ -6,6 +6,7 @@ from typing_extensions import Literal from frigate.detectors.detection_api import DetectionApi from frigate.detectors.detector_config import BaseDetectorConfig +from frigate.object_detection import load_labels try: from tflite_runtime.interpreter import Interpreter @@ -18,17 +19,14 @@ logger = logging.getLogger(__name__) DETECTOR_KEY = "audio" -class AudioDetectorConfig(BaseDetectorConfig): - type: Literal[DETECTOR_KEY] - - class AudioTfl(DetectionApi): type_key = DETECTOR_KEY - def __init__(self, detector_config: AudioDetectorConfig): + def __init__(self, labels): + self.labels = load_labels("/audio-labelmap.txt") self.interpreter = Interpreter( model_path="/cpu_audio_model.tflite", - num_threads=3, + num_threads=2, ) self.interpreter.allocate_tensors() @@ -36,7 +34,7 @@ class AudioTfl(DetectionApi): self.tensor_input_details = self.interpreter.get_input_details() self.tensor_output_details = self.interpreter.get_output_details() - def detect_raw(self, tensor_input): + def _detect_raw(self, tensor_input): self.interpreter.set_tensor(self.tensor_input_details[0]["index"], tensor_input) self.interpreter.invoke() detections = np.zeros((20, 6), np.float32) @@ -63,3 +61,16 @@ class AudioTfl(DetectionApi): ] return detections + + def detect(self, tensor_input, threshold=0.8): + detections = [] + + raw_detections = self._detect_raw(tensor_input) + + for d in raw_detections: + if d[1] < threshold: + break + detections.append( + (self.labels[int(d[0])], float(d[1]), (d[2], d[3], d[4], d[5])) + ) + return detections diff --git a/frigate/events/audio.py b/frigate/events/audio.py index e92288a99..96731f0da 100644 --- a/frigate/events/audio.py +++ b/frigate/events/audio.py @@ -17,6 +17,7 @@ from frigate.const import ( AUDIO_SAMPLE_RATE, CACHE_DIR, ) +from frigate.detectors.plugins.audio_tfl import AudioTfl from frigate.util import listen logger = logging.getLogger(__name__) @@ -50,6 +51,7 @@ class AudioEventMaintainer(threading.Thread): self.name = f"{camera.name}_audio_event_processor" self.config = camera self.stop_event = stop_event + self.detector = AudioTfl() self.shape = (int(round(AUDIO_DURATION * AUDIO_SAMPLE_RATE)),) def detect_audio(self) -> None: