Load labelmap correctly

This commit is contained in:
Nick Mowen 2023-06-18 13:04:41 -06:00
parent b6bb1cd185
commit bf4cb22118
2 changed files with 20 additions and 7 deletions

View File

@ -6,6 +6,7 @@ from typing_extensions import Literal
from frigate.detectors.detection_api import DetectionApi
from frigate.detectors.detector_config import BaseDetectorConfig
from frigate.object_detection import load_labels
try:
from tflite_runtime.interpreter import Interpreter
@ -18,17 +19,14 @@ logger = logging.getLogger(__name__)
DETECTOR_KEY = "audio"
class AudioDetectorConfig(BaseDetectorConfig):
    """Detector configuration whose ``type`` field only accepts ``DETECTOR_KEY``."""

    # Restricts the field to the single literal value stored in DETECTOR_KEY ("audio").
    type: Literal[DETECTOR_KEY]
class AudioTfl(DetectionApi):
type_key = DETECTOR_KEY
def __init__(self, detector_config: AudioDetectorConfig):
def __init__(self, labels):
self.labels = load_labels("/audio-labelmap.txt")
self.interpreter = Interpreter(
model_path="/cpu_audio_model.tflite",
num_threads=3,
num_threads=2,
)
self.interpreter.allocate_tensors()
@ -36,7 +34,7 @@ class AudioTfl(DetectionApi):
self.tensor_input_details = self.interpreter.get_input_details()
self.tensor_output_details = self.interpreter.get_output_details()
def detect_raw(self, tensor_input):
def _detect_raw(self, tensor_input):
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], tensor_input)
self.interpreter.invoke()
detections = np.zeros((20, 6), np.float32)
@ -63,3 +61,16 @@ class AudioTfl(DetectionApi):
]
return detections
def detect(self, tensor_input, threshold=0.8):
    """Run the raw detector and return labelled results at or above ``threshold``.

    Each row produced by ``self._detect_raw(tensor_input)`` is read as
    ``[label_index, score, v2, v3, v4, v5]``. Rows are consumed in order and
    processing stops at the first row whose score is below ``threshold``;
    later rows are never inspected.

    Returns a list of ``(label, score, (v2, v3, v4, v5))`` tuples, where
    ``label`` is looked up in ``self.labels`` by the integer label index and
    ``score`` is converted to a built-in ``float``.
    """
    results = []

    for row in self._detect_raw(tensor_input):
        score = row[1]

        # Stop at the first low-score row rather than skipping it.
        # NOTE(review): this presumes rows are ordered by descending score - confirm
        # against _detect_raw.
        if score < threshold:
            break

        label = self.labels[int(row[0])]
        results.append((label, float(score), (row[2], row[3], row[4], row[5])))

    return results

View File

@ -17,6 +17,7 @@ from frigate.const import (
AUDIO_SAMPLE_RATE,
CACHE_DIR,
)
from frigate.detectors.plugins.audio_tfl import AudioTfl
from frigate.util import listen
logger = logging.getLogger(__name__)
@ -50,6 +51,7 @@ class AudioEventMaintainer(threading.Thread):
self.name = f"{camera.name}_audio_event_processor"
self.config = camera
self.stop_event = stop_event
self.detector = AudioTfl()
self.shape = (int(round(AUDIO_DURATION * AUDIO_SAMPLE_RATE)),)
def detect_audio(self) -> None: