From 68ade5063d654b802ba95ef43dcfe21cd0afe593 Mon Sep 17 00:00:00 2001 From: Jason Hunter Date: Sat, 7 Jan 2023 01:02:35 -0500 Subject: [PATCH] Initial audio classification model implementation --- Dockerfile | 7 +- frigate/app.py | 90 +++++++++++++++- frigate/audio.py | 126 +++++++++++++++++++++++ frigate/config.py | 61 +++++++---- frigate/detectors/__init__.py | 10 +- frigate/detectors/detector_config.py | 62 ++++++++--- frigate/detectors/plugins/cpu_tfl.py | 34 ++++-- frigate/detectors/plugins/edgetpu_tfl.py | 36 +++++-- frigate/object_detection.py | 41 ++++++-- frigate/util.py | 6 +- 10 files changed, 402 insertions(+), 71 deletions(-) create mode 100644 frigate/audio.py diff --git a/Dockerfile b/Dockerfile index 6a804491a..ffde307cb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,7 +12,7 @@ FROM debian:11-slim AS slim-base FROM slim-base AS wget ARG DEBIAN_FRONTEND RUN apt-get update \ - && apt-get install -y wget xz-utils \ + && apt-get install -y wget xz-utils unzip \ && rm -rf /var/lib/apt/lists/* WORKDIR /rootfs @@ -93,7 +93,10 @@ COPY labelmap.txt . COPY --from=ov-converter /models/public/ssdlite_mobilenet_v2/FP16 openvino-model RUN wget -q https://github.com/openvinotoolkit/open_model_zoo/raw/master/data/dataset_classes/coco_91cl_bkgr.txt -O openvino-model/coco_91cl_bkgr.txt && \ sed -i 's/truck/car/g' openvino-model/coco_91cl_bkgr.txt - +# Get Audio Model and labels +RUN wget -qO edgetpu_audio_model.tflite https://tfhub.dev/google/coral-model/yamnet/classification/coral/1?coral-format=tflite +RUN wget -qO cpu_audio_model.tflite https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1?lite-format=tflite +RUN unzip -q edgetpu_audio_model.tflite yamnet_label_list.txt && chmod +r yamnet_label_list.txt FROM wget AS s6-overlay diff --git a/frigate/app.py b/frigate/app.py index 5ffa3d77d..386a6aef4 100644 --- a/frigate/app.py +++ b/frigate/app.py @@ -13,6 +13,7 @@ from peewee_migrate import Router from playhouse.sqlite_ext import SqliteExtDatabase from playhouse.sqliteq import SqliteQueueDatabase +from frigate.audio import capture_audio, process_audio from frigate.comms.dispatcher import Communicator, Dispatcher from frigate.comms.mqtt import MqttClient from frigate.comms.ws import WebSocketClient @@ -42,6 +43,7 @@ class FrigateApp: def __init__(self) -> None: self.stop_event: MpEvent = mp.Event() self.detection_queue: Queue = mp.Queue() + self.audio_detection_queue: Queue = mp.Queue() self.detectors: dict[str, ObjectDetectProcess] = {} self.detection_out_events: dict[str, MpEvent] = {} self.detection_shms: list[mp.shared_memory.SharedMemory] = [] @@ -104,6 +106,7 @@ class FrigateApp: "read_start": mp.Value("d", 0.0), "ffmpeg_pid": mp.Value("i", 0), "frame_queue": mp.Queue(maxsize=2), + "audio_queue": mp.Queue(maxsize=2), "capture_process": None, "process": None, } @@ -182,7 +185,7 @@ class FrigateApp: self.dispatcher = Dispatcher(self.config, self.camera_metrics, comms) def start_detectors(self) -> None: - for name in self.config.cameras.keys(): + for name, camera_config in self.config.cameras.items(): self.detection_out_events[name] = mp.Event() try: @@ -190,6 +193,7 @@ class FrigateApp: [ det.model.height * det.model.width * 3 for (name, det) in self.config.detectors.items() + if det.model.type == "object" ] ) shm_in = mp.shared_memory.SharedMemory( @@ -210,10 +214,43 @@ class FrigateApp: self.detection_shms.append(shm_in) self.detection_shms.append(shm_out) + if any( + ["detect_audio" in input.roles for input in camera_config.ffmpeg.inputs] + ): + 
self.detection_out_events[f"{name}-audio"] = mp.Event()
+
+                try:
+                    shm_in_audio = mp.shared_memory.SharedMemory(
+                        name=f"{name}-audio",
+                        create=True,
+                        size=int(
+                            round(
+                                self.config.audio_model.duration
+                                * self.config.audio_model.sample_rate
+                            )
+                        )
+                        * 4,  # stored as float32, so 4 bytes per sample
+                    )
+                except FileExistsError:
+                    shm_in_audio = mp.shared_memory.SharedMemory(name=f"{name}-audio")
+
+                try:
+                    shm_out_audio = mp.shared_memory.SharedMemory(
+                        name=f"out-{name}-audio", create=True, size=20 * 6 * 4
+                    )
+                except FileExistsError:
+                    shm_out_audio = mp.shared_memory.SharedMemory(
+                        name=f"out-{name}-audio"
+                    )
+
+                self.detection_shms.append(shm_in_audio)
+                self.detection_shms.append(shm_out_audio)
+
         for name, detector_config in self.config.detectors.items():
             self.detectors[name] = ObjectDetectProcess(
                 name,
-                self.detection_queue,
+                self.audio_detection_queue
+                if detector_config.model.type == "audio"
+                else self.detection_queue,
                 self.detection_out_events,
                 detector_config,
             )
@@ -245,6 +282,54 @@ class FrigateApp:
             output_processor.start()
             logger.info(f"Output process started: {output_processor.pid}")
 
+    def start_audio_processors(self) -> None:
+        # Make sure we have audio detectors
+        if not any(
+            [det.model.type == "audio" for det in self.config.detectors.values()]
+        ):
+            return
+
+        for name, config in self.config.cameras.items():
+            if not any(
+                ["detect_audio" in input.roles for input in config.ffmpeg.inputs]
+            ):
+                continue
+            if not config.enabled:
+                logger.info(f"Audio processor not started for disabled camera {name}")
+                continue
+
+            audio_capture = mp.Process(
+                target=capture_audio,
+                name=f"audio_capture:{name}",
+                args=(
+                    name,
+                    self.config.audio_model,
+                    self.camera_metrics[name],
+                ),
+            )
+            audio_capture.daemon = True
+            self.camera_metrics[name]["audio_capture"] = audio_capture
+            audio_capture.start()
+            logger.info(f"Audio capture started for {name}: {audio_capture.pid}")
+
+            audio_process = mp.Process(
+                target=process_audio,
+                name=f"audio_process:{name}",
+                args=(
+                    name,
+                    config,
+                    self.config.audio_model,
+                    self.config.audio_model.merged_labelmap,
+                    self.audio_detection_queue,
+                    self.detection_out_events[f"{name}-audio"],
+                    self.camera_metrics[name],
+                ),
+            )
+            audio_process.daemon = True
+            self.camera_metrics[name]["audio_process"] = audio_process
+            audio_process.start()
+            logger.info(f"Audio processor started for {name}: {audio_process.pid}")
+
     def start_camera_processors(self) -> None:
         for name, config in self.config.cameras.items():
             if not self.config.cameras[name].enabled:
@@ -364,6 +449,7 @@ class FrigateApp:
         self.start_detectors()
         self.start_video_output_processor()
         self.start_detected_frames_processor()
+        self.start_audio_processors()
         self.start_camera_processors()
         self.start_camera_capture_processes()
         self.start_storage_maintainer()
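For reference, the buffer sizes above fall out of the default audio model settings: the input holds one float32 sample per point of a 0.975 s window at 16 kHz, and the output mirrors the object pipeline's fixed result block of 20 detections with 6 float32 values each. A standalone sketch of the arithmetic (the shared-memory names here are placeholders):

    from multiprocessing import shared_memory

    DURATION = 0.975     # seconds of audio per inference window
    SAMPLE_RATE = 16000  # Hz
    SAMPLES = int(round(DURATION * SAMPLE_RATE))  # 15600 samples

    # input: one float32 (4 bytes) per sample
    shm_in = shared_memory.SharedMemory(name="cam-audio", create=True, size=SAMPLES * 4)
    # output: up to 20 detections x 6 float32 values (label, score, box coordinates)
    shm_out = shared_memory.SharedMemory(name="out-cam-audio", create=True, size=20 * 6 * 4)

    shm_in.close()
    shm_in.unlink()
    shm_out.close()
    shm_out.unlink()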
diff --git a/frigate/audio.py b/frigate/audio.py
new file mode 100644
index 000000000..ebefcd5ce
--- /dev/null
+++ b/frigate/audio.py
@@ -0,0 +1,135 @@
+import datetime
+import logging
+import multiprocessing as mp
+import queue
+import signal
+import threading
+
+import numpy as np
+from setproctitle import setproctitle
+
+from frigate.config import CameraConfig, AudioModelConfig
+from frigate.object_detection import RemoteObjectDetector
+from frigate.util import listen, SharedMemoryFrameManager
+
+
+logger = logging.getLogger(__name__)
+
+
+def capture_audio(
+    name: str,
+    model_config: AudioModelConfig,
+    process_info,
+):
+    stop_event = mp.Event()
+
+    def receiveSignal(signalNumber, frame):
+        stop_event.set()
+
+    signal.signal(signal.SIGTERM, receiveSignal)
+    signal.signal(signal.SIGINT, receiveSignal)
+
+    threading.current_thread().name = f"capture:{name}"
+    setproctitle(f"frigate.capture:{name}")
+    listen()
+
+    # s16le audio, so 2 bytes per sample
+    chunk_size = int(round(model_config.duration * model_config.sample_rate * 2))
+
+    key = f"{name}-audio"
+
+    audio_queue = process_info["audio_queue"]
+    frame_manager = SharedMemoryFrameManager()
+    current_frame = mp.Value("d", 0.0)
+
+    pipe = open(f"/tmp/{key}", "rb")
+
+    while not stop_event.is_set():
+        current_frame.value = datetime.datetime.now().timestamp()
+        frame_name = f"{key}{current_frame.value}"
+        frame_buffer = frame_manager.create(frame_name, chunk_size)
+
+        try:
+            data = pipe.read(chunk_size)
+        except Exception:
+            frame_manager.delete(frame_name)
+            continue
+
+        # drop short reads (ffmpeg restarting or the pipe closing) so a
+        # partial chunk is never handed to the detector
+        if len(data) < chunk_size:
+            frame_manager.delete(frame_name)
+            continue
+
+        frame_buffer[:] = data
+
+        # if the queue is full, skip this frame
+        if audio_queue.full():
+            frame_manager.delete(frame_name)
+            continue
+
+        # close the frame
+        frame_manager.close(frame_name)
+
+        # add to the queue
+        audio_queue.put(current_frame.value)
+
+
+def process_audio(
+    name: str,
+    camera_config: CameraConfig,
+    model_config: AudioModelConfig,
+    labelmap,
+    detection_queue: mp.Queue,
+    result_connection,
+    process_info,
+):
+    stop_event = mp.Event()
+
+    def receiveSignal(signalNumber, frame):
+        stop_event.set()
+
+    signal.signal(signal.SIGTERM, receiveSignal)
+    signal.signal(signal.SIGINT, receiveSignal)
+
+    threading.current_thread().name = f"process:{name}"
+    setproctitle(f"frigate.process:{name}")
+    listen()
+
+    shape = (int(round(model_config.duration * model_config.sample_rate)),)
+
+    key = f"{name}-audio"
+
+    audio_queue = process_info["audio_queue"]
+    frame_manager = SharedMemoryFrameManager()
+
+    detector = RemoteObjectDetector(
+        key, labelmap, detection_queue, result_connection, model_config
+    )
+
+    while not stop_event.is_set():
+        try:
+            frame_time = audio_queue.get(True, 10)
+        except queue.Empty:
+            continue
+
+        audio = frame_manager.get(f"{key}{frame_time}", shape, dtype=np.int16)
+
+        if audio is None:
+            logger.info(f"{key}: audio {frame_time} is not in memory store.")
+            continue
+
+        # rescale int16 PCM to the [-1.0, 1.0] float32 range the model expects
+        waveform = (audio / 32768.0).astype(np.float32)
+        model_detections = detector.detect(waveform)
+
+        for label, score, _ in model_detections:
+            if label not in camera_config.objects.track:
+                continue
+            filters = camera_config.objects.filters.get(label)
+            if filters:
+                if score < filters.min_score:
+                    continue
+            logger.info(f"{label}: {score}")
+
+        frame_manager.close(f"{key}{frame_time}")
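The capture and process halves of frigate/audio.py hand fixed-size PCM chunks through shared memory; the core transform is small enough to show on its own. A toy version of one chunk's journey, assuming the default 0.975 s window at 16 kHz and a placeholder fifo path:

    import numpy as np

    SAMPLE_RATE = 16000
    CHUNK_SAMPLES = int(round(0.975 * SAMPLE_RATE))
    CHUNK_BYTES = CHUNK_SAMPLES * 2  # s16le is 2 bytes per sample

    # ffmpeg writes raw signed 16-bit little-endian mono PCM into the fifo
    with open("/tmp/cam-audio", "rb") as pipe:
        data = pipe.read(CHUNK_BYTES)

    audio = np.frombuffer(data, dtype=np.int16)
    waveform = (audio / 32768.0).astype(np.float32)  # rescale to [-1.0, 1.0]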
diff --git a/frigate/config.py b/frigate/config.py
index e55db040e..c9893ffe3 100644
--- a/frigate/config.py
+++ b/frigate/config.py
@@ -36,8 +36,10 @@ from frigate.ffmpeg_presets import (
 from frigate.detectors import (
     PixelFormatEnum,
     InputTensorEnum,
-    ModelConfig,
     DetectorConfig,
+    ModelConfig,
+    AudioModelConfig,
+    ObjectModelConfig,
 )
 from frigate.version import VERSION
 
@@ -51,7 +53,7 @@ DEFAULT_TIME_FORMAT = "%m/%d/%Y %H:%M:%S"
 
 FRIGATE_ENV_VARS = {k: v for k, v in os.environ.items() if k.startswith("FRIGATE_")}
 
-DEFAULT_TRACKED_OBJECTS = ["person"]
+DEFAULT_TRACKED_OBJECTS = ["person", "Speech"]
 
 DEFAULT_DETECTORS = {"cpu": {"type": "cpu"}}
 
@@ -358,6 +360,7 @@ class BirdseyeCameraConfig(BaseModel):
 FFMPEG_GLOBAL_ARGS_DEFAULT = ["-hide_banner", "-loglevel", "warning"]
 FFMPEG_INPUT_ARGS_DEFAULT = "preset-rtsp-generic"
 DETECT_FFMPEG_OUTPUT_ARGS_DEFAULT = ["-f", "rawvideo", "-pix_fmt", "yuv420p"]
+DETECT_AUDIO_FFMPEG_OUTPUT_ARGS_DEFAULT = ["-f", "s16le", "-ar", "16000", "-ac", "1"]
 RTMP_FFMPEG_OUTPUT_ARGS_DEFAULT = "preset-rtmp-generic"
 RECORD_FFMPEG_OUTPUT_ARGS_DEFAULT = "preset-record-generic"
 
@@ -367,6 +370,10 @@ class FfmpegOutputArgsConfig(FrigateBaseModel):
         default=DETECT_FFMPEG_OUTPUT_ARGS_DEFAULT,
         title="Detect role FFmpeg output arguments.",
     )
+    detect_audio: Union[str, List[str]] = Field(
+        default=DETECT_AUDIO_FFMPEG_OUTPUT_ARGS_DEFAULT,
+        title="Detect audio role FFmpeg output arguments.",
+    )
     record: Union[str, List[str]] = Field(
         default=RECORD_FFMPEG_OUTPUT_ARGS_DEFAULT,
         title="Record role FFmpeg output arguments.",
@@ -398,6 +405,7 @@ class CameraRoleEnum(str, Enum):
     restream = "restream"
     rtmp = "rtmp"
     detect = "detect"
+    detect_audio = "detect_audio"
 
 
 class CameraInput(FrigateBaseModel):
@@ -597,6 +605,7 @@ class CameraConfig(FrigateBaseModel):
         # add roles to the input if there is only one
         if len(config["ffmpeg"]["inputs"]) == 1:
             has_rtmp = "rtmp" in config["ffmpeg"]["inputs"][0].get("roles", [])
+            has_audio = "detect_audio" in config["ffmpeg"]["inputs"][0].get("roles", [])
 
             config["ffmpeg"]["inputs"][0]["roles"] = [
                 "record",
@@ -606,6 +615,8 @@ class CameraConfig(FrigateBaseModel):
 
             if has_rtmp:
                 config["ffmpeg"]["inputs"][0]["roles"].append("rtmp")
+            if has_audio:
+                config["ffmpeg"]["inputs"][0]["roles"].append("detect_audio")
 
         super().__init__(**config)
 
@@ -646,6 +657,16 @@ class CameraConfig(FrigateBaseModel):
             )
             ffmpeg_output_args = scale_detect_args + ffmpeg_output_args + ["pipe:"]
 
+        if "detect_audio" in ffmpeg_input.roles:
+            detect_args = get_ffmpeg_arg_list(self.ffmpeg.output_args.detect_audio)
+
+            # ffmpeg streams raw PCM into this fifo; capture_audio reads from it
+            pipe = f"/tmp/{self.name}-audio"
+            try:
+                os.mkfifo(pipe)
+            except FileExistsError:
+                pass
+            ffmpeg_output_args = detect_args + ["-y", pipe] + ffmpeg_output_args
         if "rtmp" in ffmpeg_input.roles and self.rtmp.enabled:
             rtmp_args = get_ffmpeg_arg_list(
                 parse_preset_output_rtmp(self.ffmpeg.output_args.rtmp)
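With the detect_audio role wired into the output arguments above, a single ffmpeg process per camera feeds the audio fifo alongside its video outputs. A hypothetical reconstruction of just the audio leg of that command (the stream URL and camera name are placeholders, and the real command also carries the video outputs):

    import subprocess

    subprocess.run(
        [
            "ffmpeg", "-hide_banner", "-loglevel", "warning",
            "-i", "rtsp://camera/stream",
            # detect_audio output: raw 16 kHz mono signed 16-bit PCM
            "-f", "s16le", "-ar", "16000", "-ac", "1",
            "-y", "/tmp/cam-audio",  # the fifo created with os.mkfifo above
        ],
        check=False,
    )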
@@ -815,8 +835,11 @@ class FrigateConfig(FrigateBaseModel):
         default_factory=dict, title="Frigate environment variables."
     )
     ui: UIConfig = Field(default_factory=UIConfig, title="UI configuration.")
-    model: ModelConfig = Field(
-        default_factory=ModelConfig, title="Detection model configuration."
+    audio_model: AudioModelConfig = Field(
+        default_factory=AudioModelConfig, title="Audio model configuration."
+    )
+    model: ObjectModelConfig = Field(
+        default_factory=ObjectModelConfig, title="Detection model configuration."
     )
     detectors: Dict[str, DetectorConfig] = Field(
         default=DEFAULT_DETECTORS,
@@ -975,25 +998,27 @@ class FrigateConfig(FrigateBaseModel):
             if detector_config.model is None:
                 detector_config.model = config.model
             else:
-                model = detector_config.model
-                schema = ModelConfig.schema()["properties"]
-                if (
-                    model.width != schema["width"]["default"]
-                    or model.height != schema["height"]["default"]
-                    or model.labelmap_path is not None
-                    or model.labelmap is not {}
-                    or model.input_tensor != schema["input_tensor"]["default"]
-                    or model.input_pixel_format
-                    != schema["input_pixel_format"]["default"]
-                ):
+                detector_model = detector_config.model.dict(exclude_unset=True)
+                # warn if any keys other than type or path are set on the detector model
+                if any(key not in ["type", "path"] for key in detector_model.keys()):
                     logger.warning(
-                        "Customizing more than a detector model path is unsupported."
+                        "Customizing more than a detector model type or path is unsupported."
                     )
-                merged_model = deep_merge(
-                    detector_config.model.dict(exclude_unset=True),
-                    config.model.dict(exclude_unset=True),
-                )
-                detector_config.model = ModelConfig.parse_obj(merged_model)
+                merged_model = deep_merge(
+                    detector_model,
+                    config.model.dict(exclude_unset=True)
+                    if detector_config.model.type == "object"
+                    else config.audio_model.dict(exclude_unset=True),
+                )
+                # pick the concrete class for the model type; both are imported
+                # from frigate.detectors at the top of this module
+                model_cls = (
+                    ObjectModelConfig
+                    if detector_config.model.type == "object"
+                    else AudioModelConfig
+                )
+                detector_config.model = model_cls.parse_obj(merged_model)
 
             config.detectors[key] = detector_config
 
         return config
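The merge above leans on the discriminated union defined in detector_config.py below: the "type" field decides which concrete model config class pydantic instantiates. A quick illustration, assuming pydantic 1.9+ and the YAMNet label file from the Dockerfile being present at /yamnet_label_list.txt:

    from pydantic import parse_obj_as

    from frigate.detectors import AudioModelConfig, ModelConfig, ObjectModelConfig

    # plain dicts resolve to the right class via the "type" discriminator
    obj = parse_obj_as(ModelConfig, {"type": "object", "width": 416, "height": 416})
    aud = parse_obj_as(ModelConfig, {"type": "audio", "duration": 0.975})
    assert isinstance(obj, ObjectModelConfig)
    assert isinstance(aud, AudioModelConfig)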
diff --git a/frigate/detectors/__init__.py b/frigate/detectors/__init__.py
index 7cbd82f08..a1fdef4ac 100644
--- a/frigate/detectors/__init__.py
+++ b/frigate/detectors/__init__.py
@@ -2,17 +2,23 @@ import logging
 
 from .detection_api import DetectionApi
 from .detector_config import (
+    AudioModelConfig,
     PixelFormatEnum,
     InputTensorEnum,
     ModelConfig,
+    ObjectModelConfig,
+)
+from .detector_types import (
+    DetectorTypeEnum,
+    api_types,
+    DetectorConfig,
 )
-from .detector_types import DetectorTypeEnum, api_types, DetectorConfig
 
 
 logger = logging.getLogger(__name__)
 
 
-def create_detector(detector_config):
+def create_detector(detector_config: DetectorConfig):
     if detector_config.type == DetectorTypeEnum.cpu:
         logger.warning(
             "CPU detectors are not recommended and should only be used for testing or for trial purposes."
diff --git a/frigate/detectors/detector_config.py b/frigate/detectors/detector_config.py
index 747a12de4..f3d3bb37c 100644
--- a/frigate/detectors/detector_config.py
+++ b/frigate/detectors/detector_config.py
@@ -1,6 +1,7 @@
 import logging
 from enum import Enum
 from typing import Dict, List, Optional, Tuple, Union, Literal
+from typing_extensions import Annotated
 
 import matplotlib.pyplot as plt
 from pydantic import BaseModel, Extra, Field, validator
@@ -12,6 +13,11 @@ from frigate.util import load_labels
 logger = logging.getLogger(__name__)
 
 
+class ModelTypeEnum(str, Enum):
+    object = "object"
+    audio = "audio"
+
+
 class PixelFormatEnum(str, Enum):
     rgb = "rgb"
     bgr = "bgr"
@@ -23,20 +29,13 @@ class InputTensorEnum(str, Enum):
     nhwc = "nhwc"
 
 
-class ModelConfig(BaseModel):
-    path: Optional[str] = Field(title="Custom Object detection model path.")
-    labelmap_path: Optional[str] = Field(title="Label map for custom object detector.")
-    width: int = Field(default=320, title="Object detection model input width.")
-    height: int = Field(default=320, title="Object detection model input height.")
+class BaseModelConfig(BaseModel):
+    type: str = Field(default="object", title="Model Type")
+    path: Optional[str] = Field(title="Custom model path.")
+    labelmap_path: Optional[str] = Field(title="Label map for custom model.")
     labelmap: Dict[int, str] = Field(
         default_factory=dict, title="Labelmap customization."
     )
-    input_tensor: InputTensorEnum = Field(
-        default=InputTensorEnum.nhwc, title="Model Input Tensor Shape"
-    )
-    input_pixel_format: PixelFormatEnum = Field(
-        default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
-    )
 
     _merged_labelmap: Optional[Dict[int, str]] = PrivateAttr()
     _colormap: Dict[int, Tuple[int, int, int]] = PrivateAttr()
@@ -65,15 +64,51 @@ class ModelConfig(BaseModel):
             self._colormap[val] = tuple(int(round(255 * c)) for c in cmap(key)[:3])
 
     class Config:
-        extra = Extra.forbid
+        extra = Extra.allow
+        arbitrary_types_allowed = True
+
+
+class ObjectModelConfig(BaseModelConfig):
+    type: Literal["object"] = "object"
+    width: int = Field(default=320, title="Object detection model input width.")
+    height: int = Field(default=320, title="Object detection model input height.")
+    input_tensor: InputTensorEnum = Field(
+        default=InputTensorEnum.nhwc, title="Model Input Tensor Shape"
+    )
+    input_pixel_format: PixelFormatEnum = Field(
+        default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
+    )
+
+
+class AudioModelConfig(BaseModelConfig):
+    # defaults match the YAMNet models fetched in the Dockerfile:
+    # 0.975 s of 16 kHz mono s16le audio per inference
+    type: Literal["audio"] = "audio"
+    duration: float = Field(default=0.975, title="Model Input Audio Duration")
+    format: str = Field(default="s16le", title="Model Input Audio Format")
+    sample_rate: int = Field(default=16000, title="Model Input Sample Rate")
+    channels: int = Field(default=1, title="Model Input Number of Channels")
+
+    def __init__(self, **config):
+        super().__init__(**config)
+
+        self._merged_labelmap = {
+            **load_labels(config.get("labelmap_path", "/yamnet_label_list.txt")),
+            **config.get("labelmap", {}),
+        }
+
+
+# discriminated union: the "type" field selects the concrete config class
+ModelConfig = Annotated[
+    Union[tuple(BaseModelConfig.__subclasses__())],
+    Field(discriminator="type"),
+]
 
 
 class BaseDetectorConfig(BaseModel):
     # the type field must be defined in all subclasses
     type: str = Field(default="cpu", title="Detector Type")
-    model: ModelConfig = Field(
-        default=None, title="Detector specific model configuration."
-    )
+    model: Optional[ModelConfig]
 
     class Config:
         extra = Extra.allow
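Both tflite plugins below share the same post-processing idea for audio: the classifier emits one score per class (521 for YAMNet), and the top 20 are reshaped into the detection layout the rest of the pipeline expects. A self-contained numpy sketch with fake scores:

    import numpy as np

    rng = np.random.default_rng(0)
    res = rng.random(521).astype(np.float32)  # stand-in for per-class scores

    class_ids = np.argpartition(-res, 20)[:20]          # unordered top 20
    class_ids = class_ids[np.argsort(-res[class_ids])]  # best first
    class_ids = class_ids[res[class_ids] > 0]           # drop zero scores
    scores = res[class_ids]
    boxes = np.full((scores.shape[0], 4), -1, np.float32)  # audio has no boxes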
diff --git a/frigate/detectors/plugins/cpu_tfl.py b/frigate/detectors/plugins/cpu_tfl.py
index 9e24cb1f4..b5a2a0a8c 100644
--- a/frigate/detectors/plugins/cpu_tfl.py
+++ b/frigate/detectors/plugins/cpu_tfl.py
@@ -22,8 +22,12 @@ class CpuTfl(DetectionApi):
     type_key = DETECTOR_KEY
 
     def __init__(self, detector_config: CpuDetectorConfig):
+        self.is_audio = detector_config.model.type == "audio"
+        default_model = (
+            "/cpu_model.tflite" if not self.is_audio else "/cpu_audio_model.tflite"
+        )
         self.interpreter = tflite.Interpreter(
-            model_path=detector_config.model.path or "/cpu_model.tflite",
+            model_path=detector_config.model.path or default_model,
            num_threads=detector_config.num_threads or 3,
         )
 
@@ -36,15 +40,33 @@ class CpuTfl(DetectionApi):
         self.interpreter.set_tensor(self.tensor_input_details[0]["index"], tensor_input)
         self.interpreter.invoke()
 
-        boxes = self.interpreter.tensor(self.tensor_output_details[0]["index"])()[0]
-        class_ids = self.interpreter.tensor(self.tensor_output_details[1]["index"])()[0]
-        scores = self.interpreter.tensor(self.tensor_output_details[2]["index"])()[0]
-        count = int(
-            self.interpreter.tensor(self.tensor_output_details[3]["index"])()[0]
-        )
-
         detections = np.zeros((20, 6), np.float32)
 
+        if self.is_audio:
+            # classification head: a single tensor with one score per class
+            res = self.interpreter.get_tensor(self.tensor_output_details[0]["index"])[0]
+            non_zero_indices = res > 0
+            # keep the 20 best classes, ordered by score, and drop zero scores
+            class_ids = np.argpartition(-res, 20)[:20]
+            class_ids = class_ids[np.argsort(-res[class_ids])]
+            class_ids = class_ids[non_zero_indices[class_ids]]
+            scores = res[class_ids]
+            # audio detections carry no spatial information, so fill the box
+            # coordinates with -1
+            boxes = np.full((scores.shape[0], 4), -1, np.float32)
+            count = len(scores)
+        else:
+            boxes = self.interpreter.tensor(self.tensor_output_details[0]["index"])()[0]
+            class_ids = self.interpreter.tensor(
+                self.tensor_output_details[1]["index"]
+            )()[0]
+            scores = self.interpreter.tensor(self.tensor_output_details[2]["index"])()[
+                0
+            ]
+            count = int(
+                self.interpreter.tensor(self.tensor_output_details[3]["index"])()[0]
+            )
+
         for i in range(count):
             if scores[i] < 0.4 or i == 20:
                 break
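One EdgeTPU-specific wrinkle handled in the next file: compiled models typically emit quantized uint8 scores, which should be mapped back to floats before sorting or thresholding (negating a uint8 array wraps around instead of negating). A minimal sketch, assuming the standard tflite quantization metadata:

    import numpy as np

    def dequantize(raw, scale, zero_point):
        # scale == 0.0 marks an unquantized (float) output tensor
        if scale == 0.0:
            return raw
        return (raw.astype(np.float32) - zero_point) * scale

    # e.g. scale, zero_point = interpreter.get_output_details()[0]["quantization"]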
diff --git a/frigate/detectors/plugins/edgetpu_tfl.py b/frigate/detectors/plugins/edgetpu_tfl.py
index 024e6574b..7f0d7f901 100644
--- a/frigate/detectors/plugins/edgetpu_tfl.py
+++ b/frigate/detectors/plugins/edgetpu_tfl.py
@@ -23,6 +23,7 @@ class EdgeTpuTfl(DetectionApi):
     type_key = DETECTOR_KEY
 
     def __init__(self, detector_config: EdgeTpuDetectorConfig):
+        self.is_audio = detector_config.model.type == "audio"
         device_config = {"device": "usb"}
         if detector_config.device is not None:
             device_config = {"device": detector_config.device}
@@ -33,8 +34,13 @@ class EdgeTpuTfl(DetectionApi):
             logger.info(f"Attempting to load TPU as {device_config['device']}")
             edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config)
             logger.info("TPU found")
+            default_model = (
+                "/edgetpu_model.tflite"
+                if not self.is_audio
+                else "/edgetpu_audio_model.tflite"
+            )
             self.interpreter = tflite.Interpreter(
-                model_path=detector_config.model.path or "/edgetpu_model.tflite",
+                model_path=detector_config.model.path or default_model,
                 experimental_delegates=[edge_tpu_delegate],
             )
         except ValueError:
@@ -52,15 +58,35 @@ class EdgeTpuTfl(DetectionApi):
         self.interpreter.set_tensor(self.tensor_input_details[0]["index"], tensor_input)
         self.interpreter.invoke()
 
-        boxes = self.interpreter.tensor(self.tensor_output_details[0]["index"])()[0]
-        class_ids = self.interpreter.tensor(self.tensor_output_details[1]["index"])()[0]
-        scores = self.interpreter.tensor(self.tensor_output_details[2]["index"])()[0]
-        count = int(
-            self.interpreter.tensor(self.tensor_output_details[3]["index"])()[0]
-        )
-
         detections = np.zeros((20, 6), np.float32)
 
+        if self.is_audio:
+            res = self.interpreter.get_tensor(self.tensor_output_details[0]["index"])[0]
+            # the compiled audio model emits quantized uint8 scores; dequantize
+            # before sorting so unsigned negation cannot wrap around and the
+            # 0.4 threshold below compares against real confidences
+            scale, zero_point = self.tensor_output_details[0]["quantization"]
+            if scale:
+                res = (res.astype(np.float32) - zero_point) * scale
+            non_zero_indices = res > 0
+            class_ids = np.argpartition(-res, 20)[:20]
+            class_ids = class_ids[np.argsort(-res[class_ids])]
+            class_ids = class_ids[non_zero_indices[class_ids]]
+            scores = res[class_ids]
+            boxes = np.full((scores.shape[0], 4), -1, np.float32)
+            count = len(scores)
+        else:
+            boxes = self.interpreter.tensor(self.tensor_output_details[0]["index"])()[0]
+            class_ids = self.interpreter.tensor(
+                self.tensor_output_details[1]["index"]
+            )()[0]
+            scores = self.interpreter.tensor(self.tensor_output_details[2]["index"])()[
+                0
+            ]
+            count = int(
+                self.interpreter.tensor(self.tensor_output_details[3]["index"])()[0]
+            )
+
         for i in range(count):
             if scores[i] < 0.4 or i == 20:
                 break
diff --git a/frigate/object_detection.py b/frigate/object_detection.py
index 2fc080329..4dccccc59 100644
--- a/frigate/object_detection.py
+++ b/frigate/object_detection.py
@@ -44,7 +44,7 @@ class LocalObjectDetector(ObjectDetector):
         else:
             self.labels = load_labels(labels)
 
-        if detector_config:
+        if detector_config and detector_config.model.type == "object":
             self.input_transform = tensor_transform(detector_config.model.input_tensor)
         else:
             self.input_transform = None
@@ -107,10 +107,25 @@ def run_detector(
             connection_id = detection_queue.get(timeout=5)
         except queue.Empty:
             continue
-        input_frame = frame_manager.get(
-            connection_id,
-            (1, detector_config.model.height, detector_config.model.width, 3),
-        )
+        if detector_config.model.type == "audio":
+            # audio models take a flat float32 waveform instead of an image tensor
+            input_frame = frame_manager.get(
+                connection_id,
+                (
+                    int(
+                        round(
+                            detector_config.model.duration
+                            * detector_config.model.sample_rate
+                        )
+                    ),
+                ),
+                dtype=np.float32,
+            )
+        else:
+            input_frame = frame_manager.get(
+                connection_id,
+                (1, detector_config.model.height, detector_config.model.width, 3),
+            )
 
         if input_frame is None:
             continue
@@ -180,11 +194,18 @@ class RemoteObjectDetector:
         self.detection_queue = detection_queue
         self.event = event
         self.shm = mp.shared_memory.SharedMemory(name=self.name, create=False)
-        self.np_shm = np.ndarray(
-            (1, model_config.height, model_config.width, 3),
-            dtype=np.uint8,
-            buffer=self.shm.buf,
-        )
+        if model_config.type == "audio":
+            self.np_shm = np.ndarray(
+                (int(round(model_config.duration * model_config.sample_rate)),),
+                dtype=np.float32,
+                buffer=self.shm.buf,
+            )
+        else:
+            self.np_shm = np.ndarray(
+                (1, model_config.height, model_config.width, 3),
+                dtype=np.uint8,
+                buffer=self.shm.buf,
+            )
         self.out_shm = mp.shared_memory.SharedMemory(
             name=f"out-{self.name}", create=False
         )
diff --git a/frigate/util.py b/frigate/util.py
index d8cdef0d9..d35b7131a 100755
--- a/frigate/util.py
+++ b/frigate/util.py
@@ -915,7 +915,7 @@ class FrameManager(ABC):
         pass
 
     @abstractmethod
-    def get(self, name, timeout_ms=0):
+    def get(self, name, shape, dtype=np.uint8):
         pass
 
     @abstractmethod
@@ -956,13 +956,13 @@ class SharedMemoryFrameManager(FrameManager):
             self.shm_store[name] = shm
             return shm.buf
 
-    def get(self, name, shape):
+    def get(self, name, shape, dtype=np.uint8):
         if name in self.shm_store:
             shm = self.shm_store[name]
         else:
             shm = shared_memory.SharedMemory(name=name)
             self.shm_store[name] = shm
-        return np.ndarray(shape, dtype=np.uint8, buffer=shm.buf)
+        return np.ndarray(shape, dtype=dtype, buffer=shm.buf)
 
     def close(self, name):
         if name in self.shm_store:
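The dtype parameter added to SharedMemoryFrameManager.get is what lets the audio path reuse the existing frame plumbing unchanged. A toy round trip (the buffer name is illustrative):

    import numpy as np

    from frigate.util import SharedMemoryFrameManager

    samples = int(round(0.975 * 16000))  # 15600 samples per window
    frame_manager = SharedMemoryFrameManager()

    buf = frame_manager.create("cam-audio-demo", samples * 2)  # s16le bytes
    buf[:] = b"\x00" * (samples * 2)

    # view the same bytes as int16 instead of the default uint8
    audio = frame_manager.get("cam-audio-demo", (samples,), dtype=np.int16)
    assert audio.shape == (samples,) and audio.dtype == np.int16

    frame_manager.delete("cam-audio-demo")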