Initial audio classification model implementation

Jason Hunter 2023-01-07 01:02:35 -05:00
parent ec7aaa18ab
commit 68ade5063d
10 changed files with 402 additions and 71 deletions

View File

@ -12,7 +12,7 @@ FROM debian:11-slim AS slim-base
FROM slim-base AS wget
ARG DEBIAN_FRONTEND
RUN apt-get update \
&& apt-get install -y wget xz-utils \
&& apt-get install -y wget xz-utils unzip \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /rootfs
@ -93,7 +93,10 @@ COPY labelmap.txt .
COPY --from=ov-converter /models/public/ssdlite_mobilenet_v2/FP16 openvino-model
RUN wget -q https://github.com/openvinotoolkit/open_model_zoo/raw/master/data/dataset_classes/coco_91cl_bkgr.txt -O openvino-model/coco_91cl_bkgr.txt && \
sed -i 's/truck/car/g' openvino-model/coco_91cl_bkgr.txt
# Get Audio Model and labels
RUN wget -qO edgetpu_audio_model.tflite https://tfhub.dev/google/coral-model/yamnet/classification/coral/1?coral-format=tflite
RUN wget -qO cpu_audio_model.tflite https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1?lite-format=tflite
RUN unzip -q edgetpu_audio_model.tflite yamnet_label_list.txt && chmod +r yamnet_label_list.txt
FROM wget AS s6-overlay
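
The two wget lines above fetch the Coral (EdgeTPU) and CPU variants of Google's YAMNet classification model, and the unzip step pulls yamnet_label_list.txt out of the Coral archive. A minimal sketch (not part of this commit) for sanity-checking the CPU model after download, assuming tflite_runtime is available and the file sits at /cpu_audio_model.tflite as above:

import numpy as np
import tflite_runtime.interpreter as tflite

interpreter = tflite.Interpreter(model_path="/cpu_audio_model.tflite")
interpreter.allocate_tensors()

in_details = interpreter.get_input_details()[0]
out_details = interpreter.get_output_details()[0]
# the classification variant of YAMNet takes a 0.975 s, 16 kHz float32 waveform
print("input:", in_details["shape"], in_details["dtype"])
print("output:", out_details["shape"], out_details["dtype"])

# run one inference on silence just to confirm the interpreter invokes cleanly
interpreter.set_tensor(in_details["index"], np.zeros(in_details["shape"], dtype=np.float32))
interpreter.invoke()
scores = interpreter.get_tensor(out_details["index"])
print("argmax class id:", int(np.argmax(scores)))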

View File

@ -13,6 +13,7 @@ from peewee_migrate import Router
from playhouse.sqlite_ext import SqliteExtDatabase
from playhouse.sqliteq import SqliteQueueDatabase
from frigate.audio import capture_audio, process_audio
from frigate.comms.dispatcher import Communicator, Dispatcher
from frigate.comms.mqtt import MqttClient
from frigate.comms.ws import WebSocketClient
@ -42,6 +43,7 @@ class FrigateApp:
def __init__(self) -> None:
self.stop_event: MpEvent = mp.Event()
self.detection_queue: Queue = mp.Queue()
self.audio_detection_queue: Queue = mp.Queue()
self.detectors: dict[str, ObjectDetectProcess] = {}
self.detection_out_events: dict[str, MpEvent] = {}
self.detection_shms: list[mp.shared_memory.SharedMemory] = []
@ -104,6 +106,7 @@ class FrigateApp:
"read_start": mp.Value("d", 0.0),
"ffmpeg_pid": mp.Value("i", 0),
"frame_queue": mp.Queue(maxsize=2),
"audio_queue": mp.Queue(maxsize=2),
"capture_process": None,
"process": None,
}
@ -182,7 +185,7 @@ class FrigateApp:
self.dispatcher = Dispatcher(self.config, self.camera_metrics, comms)
def start_detectors(self) -> None:
for name in self.config.cameras.keys():
for name, camera_config in self.config.cameras.items():
self.detection_out_events[name] = mp.Event()
try:
@ -190,6 +193,7 @@ class FrigateApp:
[
det.model.height * det.model.width * 3
for (name, det) in self.config.detectors.items()
if det.model.type == "object"
]
)
shm_in = mp.shared_memory.SharedMemory(
@ -210,10 +214,43 @@ class FrigateApp:
self.detection_shms.append(shm_in)
self.detection_shms.append(shm_out)
if any(
["detect_audio" in input.roles for input in camera_config.ffmpeg.inputs]
):
self.detection_out_events[f"{name}-audio"] = mp.Event()
try:
shm_in_audio = mp.shared_memory.SharedMemory(
name=f"{name}-audio",
create=True,
size=int(
round(
self.config.audio_model.duration
* self.config.audio_model.sample_rate
)
)
* 4, # stored as float32, so 4 bytes per sample
)
except FileExistsError:
shm_in_audio = mp.shared_memory.SharedMemory(name=f"{name}-audio")
try:
shm_out_audio = mp.shared_memory.SharedMemory(
name=f"out-{name}-audio", create=True, size=20 * 6 * 4
)
except FileExistsError:
shm_out_audio = mp.shared_memory.SharedMemory(
name=f"out-{name}-audio"
)
self.detection_shms.append(shm_in_audio)
self.detection_shms.append(shm_out_audio)
for name, detector_config in self.config.detectors.items():
self.detectors[name] = ObjectDetectProcess(
name,
self.detection_queue,
self.audio_detection_queue
if detector_config.model.type == "audio"
else self.detection_queue,
self.detection_out_events,
detector_config,
)
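
The audio shared-memory blocks allocated above are sized from the audio model settings: the input holds one float32 waveform window, and the output holds up to 20 detections of 6 floats each (label id, score, and four box values, unused for audio). A quick check of the arithmetic with the defaults introduced later in this commit (duration 0.975 s, sample rate 16000 Hz):

duration, sample_rate = 0.975, 16000
samples = int(round(duration * sample_rate))   # 15600 samples per inference window
shm_in_bytes = samples * 4                     # 62400 bytes, float32 waveform
shm_out_bytes = 20 * 6 * 4                     # 480 bytes, 20 detections x 6 float32 values
print(samples, shm_in_bytes, shm_out_bytes)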
@ -245,6 +282,54 @@ class FrigateApp:
output_processor.start()
logger.info(f"Output process started: {output_processor.pid}")
def start_audio_processors(self) -> None:
# Make sure we have audio detectors
if not any(
[det.model.type == "audio" for det in self.config.detectors.values()]
):
return
for name, config in self.config.cameras.items():
if not any(
["detect_audio" in inputs.roles for inputs in config.ffmpeg.inputs]
):
continue
if not config.enabled:
logger.info(f"Audio processor not started for disabled camera {name}")
continue
audio_capture = mp.Process(
target=capture_audio,
name=f"audio_capture:{name}",
args=(
name,
self.config.audio_model,
self.camera_metrics[name],
),
)
audio_capture.daemon = True
self.camera_metrics[name]["audio_capture"] = audio_capture
audio_capture.start()
logger.info(f"Audio capture started for {name}: {audio_capture.pid}")
audio_process = mp.Process(
target=process_audio,
name=f"audio_process:{name}",
args=(
name,
config,
self.config.audio_model,
self.config.audio_model.merged_labelmap,
self.audio_detection_queue,
self.detection_out_events[f"{name}-audio"],
self.camera_metrics[name],
),
)
audio_process.daemon = True
self.camera_metrics[name]["audio_process"] = audio_process
audio_process.start()
logger.info(f"Audio processor started for {name}: {audio_process.pid}")
def start_camera_processors(self) -> None:
for name, config in self.config.cameras.items():
if not self.config.cameras[name].enabled:
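
start_audio_processors() above spawns two daemon processes per eligible camera: capture_audio reads raw PCM from the camera's audio FIFO into shared-memory frames, and process_audio runs those frames through the detector. A rough sketch of a config that would exercise this path, built only from fields visible in this diff (camera name and stream URL are placeholders, and required fields not touched by this commit are omitted):

config = {
    "detectors": {
        # a detector whose model type is "audio" is routed to audio_detection_queue
        "audio": {"type": "cpu", "model": {"type": "audio"}},
    },
    "cameras": {
        "front": {
            "ffmpeg": {
                "inputs": [
                    {
                        "path": "rtsp://example/stream",
                        # "detect_audio" is the new role added in this commit
                        "roles": ["detect", "detect_audio"],
                    }
                ]
            },
        }
    },
}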
@ -364,6 +449,7 @@ class FrigateApp:
self.start_detectors()
self.start_video_output_processor()
self.start_detected_frames_processor()
self.start_audio_processors()
self.start_camera_processors()
self.start_camera_capture_processes()
self.start_storage_maintainer()

frigate/audio.py (new file, 126 lines)
View File

@ -0,0 +1,126 @@
import datetime
import logging
import multiprocessing as mp
import queue
import random
import signal
import string
import threading
import numpy as np
from setproctitle import setproctitle
from frigate.config import CameraConfig, AudioModelConfig
from frigate.object_detection import RemoteObjectDetector
from frigate.util import listen, SharedMemoryFrameManager
logger = logging.getLogger(__name__)
def capture_audio(
name: str,
model_config: AudioModelConfig,
process_info,
):
stop_event = mp.Event()
def receiveSignal(signalNumber, frame):
stop_event.set()
signal.signal(signal.SIGTERM, receiveSignal)
signal.signal(signal.SIGINT, receiveSignal)
threading.current_thread().name = f"capture:{name}"
setproctitle(f"frigate.capture:{name}")
listen()
chunk_size = int(round(model_config.duration * model_config.sample_rate * 2))
key = f"{name}-audio"
audio_queue = process_info["audio_queue"]
frame_manager = SharedMemoryFrameManager()
current_frame = mp.Value("d", 0.0)
pipe = open(f"/tmp/{key}", "rb")
while not stop_event.is_set():
current_frame.value = datetime.datetime.now().timestamp()
frame_name = f"{key}{current_frame.value}"
frame_buffer = frame_manager.create(frame_name, chunk_size)
try:
frame_buffer[:] = pipe.read(chunk_size)
except Exception as e:
continue
# if the queue is full, skip this frame
if audio_queue.full():
frame_manager.delete(frame_name)
continue
# close the frame
frame_manager.close(frame_name)
# add to the queue
audio_queue.put(current_frame.value)
def process_audio(
name: str,
camera_config: CameraConfig,
model_config: AudioModelConfig,
labelmap,
detection_queue: mp.Queue,
result_connection,
process_info,
):
stop_event = mp.Event()
def receiveSignal(signalNumber, frame):
stop_event.set()
signal.signal(signal.SIGTERM, receiveSignal)
signal.signal(signal.SIGINT, receiveSignal)
threading.current_thread().name = f"process:{name}"
setproctitle(f"frigate.process:{name}")
listen()
shape = (int(round(model_config.duration * model_config.sample_rate)),)
key = f"{name}-audio"
audio_queue = process_info["audio_queue"]
frame_manager = SharedMemoryFrameManager()
detector = RemoteObjectDetector(
key, labelmap, detection_queue, result_connection, model_config
)
while not stop_event.is_set():
try:
frame_time = audio_queue.get(True, 10)
except queue.Empty:
continue
audio = frame_manager.get(f"{key}{frame_time}", shape, dtype=np.int16)
if audio is None:
logger.info(f"{key}: audio {frame_time} is not in memory store.")
continue
waveform = (audio / 32768.0).astype(np.float32)
model_detections = detector.detect(waveform)
for label, score, _ in model_detections:
if label not in camera_config.objects.track:
continue
filters = camera_config.objects.filters.get(label)
if filters:
if score < filters.min_score:
continue
logger.info(f"{label}: {score}")
frame_manager.close(f"{key}{frame_time}")
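
capture_audio reads fixed-size chunks of signed 16-bit little-endian PCM from the ffmpeg FIFO, and process_audio reinterprets each chunk as int16 and scales it into the [-1.0, 1.0) float32 range YAMNet expects. A standalone sketch of that conversion, with synthetic samples standing in for pipe.read():

import numpy as np

sample_rate, duration = 16000, 0.975
samples = int(round(duration * sample_rate))                                # 15600
raw = np.random.randint(-32768, 32767, samples, dtype=np.int16).tobytes()  # stand-in for pipe.read(chunk_size)

audio = np.frombuffer(raw, dtype=np.int16)                  # same view process_audio gets from the frame manager
waveform = (audio / 32768.0).astype(np.float32)             # identical normalization to process_audio()
assert waveform.shape == (samples,) and waveform.dtype == np.float32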

View File

@ -36,8 +36,10 @@ from frigate.ffmpeg_presets import (
from frigate.detectors import (
PixelFormatEnum,
InputTensorEnum,
ModelConfig,
DetectorConfig,
ModelConfig,
AudioModelConfig,
ObjectModelConfig,
)
from frigate.version import VERSION
@ -51,7 +53,7 @@ DEFAULT_TIME_FORMAT = "%m/%d/%Y %H:%M:%S"
FRIGATE_ENV_VARS = {k: v for k, v in os.environ.items() if k.startswith("FRIGATE_")}
DEFAULT_TRACKED_OBJECTS = ["person"]
DEFAULT_TRACKED_OBJECTS = ["person", "Speech"]
DEFAULT_DETECTORS = {"cpu": {"type": "cpu"}}
@ -358,6 +360,7 @@ class BirdseyeCameraConfig(BaseModel):
FFMPEG_GLOBAL_ARGS_DEFAULT = ["-hide_banner", "-loglevel", "warning"]
FFMPEG_INPUT_ARGS_DEFAULT = "preset-rtsp-generic"
DETECT_FFMPEG_OUTPUT_ARGS_DEFAULT = ["-f", "rawvideo", "-pix_fmt", "yuv420p"]
DETECT_AUDIO_FFMPEG_OUTPUT_ARGS_DEFAULT = ["-f", "s16le", "-ar", "16000", "-ac", "1"]
RTMP_FFMPEG_OUTPUT_ARGS_DEFAULT = "preset-rtmp-generic"
RECORD_FFMPEG_OUTPUT_ARGS_DEFAULT = "preset-record-generic"
@ -367,6 +370,10 @@ class FfmpegOutputArgsConfig(FrigateBaseModel):
default=DETECT_FFMPEG_OUTPUT_ARGS_DEFAULT,
title="Detect role FFmpeg output arguments.",
)
detect_audio: Union[str, List[str]] = Field(
default=DETECT_AUDIO_FFMPEG_OUTPUT_ARGS_DEFAULT,
title="Detect role FFmpeg output arguments.",
)
record: Union[str, List[str]] = Field(
default=RECORD_FFMPEG_OUTPUT_ARGS_DEFAULT,
title="Record role FFmpeg output arguments.",
@ -398,6 +405,7 @@ class CameraRoleEnum(str, Enum):
restream = "restream"
rtmp = "rtmp"
detect = "detect"
detect_audio = "detect_audio"
class CameraInput(FrigateBaseModel):
@ -597,6 +605,7 @@ class CameraConfig(FrigateBaseModel):
# add roles to the input if there is only one
if len(config["ffmpeg"]["inputs"]) == 1:
has_rtmp = "rtmp" in config["ffmpeg"]["inputs"][0].get("roles", [])
has_audio = "detect_audio" in config["ffmpeg"]["inputs"][0].get("roles", [])
config["ffmpeg"]["inputs"][0]["roles"] = [
"record",
@ -606,6 +615,8 @@ class CameraConfig(FrigateBaseModel):
if has_rtmp:
config["ffmpeg"]["inputs"][0]["roles"].append("rtmp")
if has_audio:
config["ffmpeg"]["inputs"][0]["roles"].append("detect_audio")
super().__init__(**config)
@ -646,6 +657,15 @@ class CameraConfig(FrigateBaseModel):
)
ffmpeg_output_args = scale_detect_args + ffmpeg_output_args + ["pipe:"]
if "detect_audio" in ffmpeg_input.roles:
detect_args = get_ffmpeg_arg_list(self.ffmpeg.output_args.detect_audio)
pipe = f"/tmp/{self.name}-audio"
try:
os.mkfifo(pipe)
except FileExistsError:
pass
ffmpeg_output_args = detect_args + ["-y", pipe] + ffmpeg_output_args
if "rtmp" in ffmpeg_input.roles and self.rtmp.enabled:
rtmp_args = get_ffmpeg_arg_list(
parse_preset_output_rtmp(self.ffmpeg.output_args.rtmp)
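
For inputs carrying the detect_audio role, the block above prepends an s16le, 16 kHz, mono output that writes to a per-camera named pipe created with os.mkfifo (FileExistsError is ignored so restarts reuse the pipe). With the defaults, and a hypothetical camera named "front", the added fragment of output arguments would be:

# prepended to the remaining detect/record/rtmp output arguments
detect_audio_args = ["-f", "s16le", "-ar", "16000", "-ac", "1", "-y", "/tmp/front-audio"]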
@ -815,8 +835,11 @@ class FrigateConfig(FrigateBaseModel):
default_factory=dict, title="Frigate environment variables."
)
ui: UIConfig = Field(default_factory=UIConfig, title="UI configuration.")
model: ModelConfig = Field(
default_factory=ModelConfig, title="Detection model configuration."
audio_model: AudioModelConfig = Field(
default_factory=AudioModelConfig, title="Audio model configuration."
)
model: ObjectModelConfig = Field(
default_factory=ObjectModelConfig, title="Detection model configuration."
)
detectors: Dict[str, DetectorConfig] = Field(
default=DEFAULT_DETECTORS,
@ -975,25 +998,21 @@ class FrigateConfig(FrigateBaseModel):
if detector_config.model is None:
detector_config.model = config.model
else:
model = detector_config.model
schema = ModelConfig.schema()["properties"]
if (
model.width != schema["width"]["default"]
or model.height != schema["height"]["default"]
or model.labelmap_path is not None
or model.labelmap is not {}
or model.input_tensor != schema["input_tensor"]["default"]
or model.input_pixel_format
!= schema["input_pixel_format"]["default"]
):
detector_model = detector_config.model.dict(exclude_unset=True)
# If any keys are set in the detector_model other than type or path, warn
if any(key not in ["type", "path"] for key in detector_model.keys()):
logger.warning(
"Customizing more than a detector model path is unsupported."
"Customizing more than a detector model type or path is unsupported."
)
merged_model = deep_merge(
detector_config.model.dict(exclude_unset=True),
config.model.dict(exclude_unset=True),
)
detector_config.model = ModelConfig.parse_obj(merged_model)
merged_model = deep_merge(
detector_model,
config.model.dict(exclude_unset=True)
if detector_config.model.type == "object"
else config.audio_model.dict(exclude_unset=True),
)
detector_config.model = parse_obj_as(
ModelConfig, {"type": detector_config.model.type, **merged_model}
)
config.detectors[key] = detector_config
return config

View File

@ -2,17 +2,23 @@ import logging
from .detection_api import DetectionApi
from .detector_config import (
AudioModelConfig,
PixelFormatEnum,
InputTensorEnum,
ModelConfig,
ObjectModelConfig,
)
from .detector_types import (
DetectorTypeEnum,
api_types,
DetectorConfig,
)
from .detector_types import DetectorTypeEnum, api_types, DetectorConfig
logger = logging.getLogger(__name__)
def create_detector(detector_config):
def create_detector(detector_config: DetectorConfig):
if detector_config.type == DetectorTypeEnum.cpu:
logger.warning(
"CPU detectors are not recommended and should only be used for testing or for trial purposes."

View File

@ -1,6 +1,7 @@
import logging
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union, Literal
from typing_extensions import Annotated
import matplotlib.pyplot as plt
from pydantic import BaseModel, Extra, Field, validator
@ -12,6 +13,11 @@ from frigate.util import load_labels
logger = logging.getLogger(__name__)
class ModelTypeEnum(str, Enum):
object = "object"
audio = "audio"
class PixelFormatEnum(str, Enum):
rgb = "rgb"
bgr = "bgr"
@ -23,20 +29,13 @@ class InputTensorEnum(str, Enum):
nhwc = "nhwc"
class ModelConfig(BaseModel):
path: Optional[str] = Field(title="Custom Object detection model path.")
labelmap_path: Optional[str] = Field(title="Label map for custom object detector.")
width: int = Field(default=320, title="Object detection model input width.")
height: int = Field(default=320, title="Object detection model input height.")
class BaseModelConfig(BaseModel):
type: str = Field(default="object", title="Model Type")
path: Optional[str] = Field(title="Custom model path.")
labelmap_path: Optional[str] = Field(title="Label map for custom model.")
labelmap: Dict[int, str] = Field(
default_factory=dict, title="Labelmap customization."
)
input_tensor: InputTensorEnum = Field(
default=InputTensorEnum.nhwc, title="Model Input Tensor Shape"
)
input_pixel_format: PixelFormatEnum = Field(
default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
)
_merged_labelmap: Optional[Dict[int, str]] = PrivateAttr()
_colormap: Dict[int, Tuple[int, int, int]] = PrivateAttr()
@ -65,15 +64,48 @@ class ModelConfig(BaseModel):
self._colormap[val] = tuple(int(round(255 * c)) for c in cmap(key)[:3])
class Config:
extra = Extra.forbid
extra = Extra.allow
arbitrary_types_allowed = True
class ObjectModelConfig(BaseModelConfig):
type: Literal["object"] = "object"
width: int = Field(default=320, title="Object detection model input width.")
height: int = Field(default=320, title="Object detection model input height.")
input_tensor: InputTensorEnum = Field(
default=InputTensorEnum.nhwc, title="Model Input Tensor Shape"
)
input_pixel_format: PixelFormatEnum = Field(
default=PixelFormatEnum.rgb, title="Model Input Pixel Color Format"
)
class AudioModelConfig(BaseModelConfig):
type: Literal["audio"] = "audio"
duration: float = Field(default=0.975, title="Model Input Audio Duration")
format: str = Field(default="s16le", title="Model Input Audio Format")
sample_rate: int = Field(default=16000, title="Model Input Sample Rate")
channels: int = Field(default=1, title="Model Input Number of Channels")
def __init__(self, **config):
super().__init__(**config)
self._merged_labelmap = {
**load_labels(config.get("labelmap_path", "/yamnet_label_list.txt")),
**config.get("labelmap", {}),
}
ModelConfig = Annotated[
Union[tuple(BaseModelConfig.__subclasses__())],
Field(discriminator="type"),
]
class BaseDetectorConfig(BaseModel):
# the type field must be defined in all subclasses
type: str = Field(default="cpu", title="Detector Type")
model: ModelConfig = Field(
default=None, title="Detector specific model configuration."
)
model: Optional[ModelConfig]
class Config:
extra = Extra.allow
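
ModelConfig is now a pydantic discriminated union keyed on type, so parsing a plain dict yields the matching concrete class. A sketch of that behavior, assuming pydantic v1 (parse_obj_as) and that it runs inside the Frigate container where the default label files referenced by the model configs exist:

from pydantic import parse_obj_as
from frigate.detectors import AudioModelConfig, ModelConfig, ObjectModelConfig

obj_model = parse_obj_as(ModelConfig, {"type": "object", "width": 416, "height": 416})
audio_model = parse_obj_as(ModelConfig, {"type": "audio"})

assert isinstance(obj_model, ObjectModelConfig) and obj_model.width == 416
assert isinstance(audio_model, AudioModelConfig) and audio_model.sample_rate == 16000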

View File

@ -22,8 +22,12 @@ class CpuTfl(DetectionApi):
type_key = DETECTOR_KEY
def __init__(self, detector_config: CpuDetectorConfig):
self.is_audio = detector_config.model.type == "audio"
default_model = (
"/cpu_model.tflite" if not self.is_audio else "/cpu_audio_model.tflite"
)
self.interpreter = tflite.Interpreter(
model_path=detector_config.model.path or "/cpu_model.tflite",
model_path=detector_config.model.path or default_model,
num_threads=detector_config.num_threads or 3,
)
@ -36,15 +40,29 @@ class CpuTfl(DetectionApi):
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], tensor_input)
self.interpreter.invoke()
boxes = self.interpreter.tensor(self.tensor_output_details[0]["index"])()[0]
class_ids = self.interpreter.tensor(self.tensor_output_details[1]["index"])()[0]
scores = self.interpreter.tensor(self.tensor_output_details[2]["index"])()[0]
count = int(
self.interpreter.tensor(self.tensor_output_details[3]["index"])()[0]
)
detections = np.zeros((20, 6), np.float32)
if self.is_audio:
res = self.interpreter.get_tensor(self.tensor_output_details[0]["index"])[0]
non_zero_indices = res > 0
class_ids = np.argpartition(-res, 20)[:20]
class_ids = class_ids[np.argsort(-res[class_ids])]
class_ids = class_ids[non_zero_indices[class_ids]]
scores = res[class_ids]
boxes = np.full((scores.shape[0], 4), -1, np.float32)
count = len(scores)
else:
boxes = self.interpreter.tensor(self.tensor_output_details[0]["index"])()[0]
class_ids = self.interpreter.tensor(
self.tensor_output_details[1]["index"]
)()[0]
scores = self.interpreter.tensor(self.tensor_output_details[2]["index"])()[
0
]
count = int(
self.interpreter.tensor(self.tensor_output_details[3]["index"])()[0]
)
for i in range(count):
if scores[i] < 0.4 or i == 20:
break
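
In audio mode there are no box or count tensors to read; the single output is the per-class score vector, from which the top 20 non-zero scores are kept in descending order and the box columns are filled with -1. The same selection logic on a synthetic score vector:

import numpy as np

res = np.random.rand(521).astype(np.float32)        # stand-in for the 521 YAMNet class scores
res[res < 0.3] = 0.0                                # most classes score at or near zero

non_zero_indices = res > 0
class_ids = np.argpartition(-res, 20)[:20]          # unordered top-20 candidates
class_ids = class_ids[np.argsort(-res[class_ids])]  # sort those 20 by descending score
class_ids = class_ids[non_zero_indices[class_ids]]  # drop zero-score entries
scores = res[class_ids]
boxes = np.full((scores.shape[0], 4), -1, np.float32)
print(len(scores), scores[:3])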

View File

@ -23,6 +23,7 @@ class EdgeTpuTfl(DetectionApi):
type_key = DETECTOR_KEY
def __init__(self, detector_config: EdgeTpuDetectorConfig):
self.is_audio = detector_config.model.type == "audio"
device_config = {"device": "usb"}
if detector_config.device is not None:
device_config = {"device": detector_config.device}
@ -33,8 +34,13 @@ class EdgeTpuTfl(DetectionApi):
logger.info(f"Attempting to load TPU as {device_config['device']}")
edge_tpu_delegate = load_delegate("libedgetpu.so.1.0", device_config)
logger.info("TPU found")
default_model = (
"/edgetpu_model.tflite"
if not self.is_audio
else "/edgetpu_audio_model.tflite"
)
self.interpreter = tflite.Interpreter(
model_path=detector_config.model.path or "/edgetpu_model.tflite",
model_path=detector_config.model.path or default_model,
experimental_delegates=[edge_tpu_delegate],
)
except ValueError:
@ -52,15 +58,29 @@ class EdgeTpuTfl(DetectionApi):
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], tensor_input)
self.interpreter.invoke()
boxes = self.interpreter.tensor(self.tensor_output_details[0]["index"])()[0]
class_ids = self.interpreter.tensor(self.tensor_output_details[1]["index"])()[0]
scores = self.interpreter.tensor(self.tensor_output_details[2]["index"])()[0]
count = int(
self.interpreter.tensor(self.tensor_output_details[3]["index"])()[0]
)
detections = np.zeros((20, 6), np.float32)
if self.is_audio:
res = self.interpreter.get_tensor(self.tensor_output_details[0]["index"])[0]
non_zero_indices = res > 0
class_ids = np.argpartition(-res, 20)[:20]
class_ids = class_ids[np.argsort(-res[class_ids])]
class_ids = class_ids[non_zero_indices[class_ids]]
scores = res[class_ids]
boxes = np.full((scores.shape[0], 4), -1, np.float32)
count = len(scores)
else:
boxes = self.interpreter.tensor(self.tensor_output_details[0]["index"])()[0]
class_ids = self.interpreter.tensor(
self.tensor_output_details[1]["index"]
)()[0]
scores = self.interpreter.tensor(self.tensor_output_details[2]["index"])()[
0
]
count = int(
self.interpreter.tensor(self.tensor_output_details[3]["index"])()[0]
)
for i in range(count):
if scores[i] < 0.4 or i == 20:
break

View File

@ -44,7 +44,7 @@ class LocalObjectDetector(ObjectDetector):
else:
self.labels = load_labels(labels)
if detector_config:
if detector_config.model.type == "object":
self.input_transform = tensor_transform(detector_config.model.input_tensor)
else:
self.input_transform = None
@ -107,10 +107,24 @@ def run_detector(
connection_id = detection_queue.get(timeout=5)
except queue.Empty:
continue
input_frame = frame_manager.get(
connection_id,
(1, detector_config.model.height, detector_config.model.width, 3),
)
if detector_config.model.type == "audio":
input_frame = frame_manager.get(
connection_id,
(
int(
round(
detector_config.model.duration
* detector_config.model.sample_rate
)
),
),
dtype=np.float32,
)
else:
input_frame = frame_manager.get(
connection_id,
(1, detector_config.model.height, detector_config.model.width, 3),
)
if input_frame is None:
continue
@ -180,11 +194,18 @@ class RemoteObjectDetector:
self.detection_queue = detection_queue
self.event = event
self.shm = mp.shared_memory.SharedMemory(name=self.name, create=False)
self.np_shm = np.ndarray(
(1, model_config.height, model_config.width, 3),
dtype=np.uint8,
buffer=self.shm.buf,
)
if model_config.type == "audio":
self.np_shm = np.ndarray(
(int(round(model_config.duration * model_config.sample_rate)),),
dtype=np.float32,
buffer=self.shm.buf,
)
else:
self.np_shm = np.ndarray(
(1, model_config.height, model_config.width, 3),
dtype=np.uint8,
buffer=self.shm.buf,
)
self.out_shm = mp.shared_memory.SharedMemory(
name=f"out-{self.name}", create=False
)
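
The detector-side shared memory is the same byte buffer either way; only the NumPy view changes: (1, height, width, 3) uint8 for video versus a flat (duration * sample_rate,) float32 waveform for audio. A minimal stand-in showing two views over one buffer, using a bytearray where Frigate uses SharedMemory.buf:

import numpy as np

samples = 15600                                            # round(0.975 * 16000), as above
buf = bytearray(samples * 4)                               # sized like the "{name}-audio" segment
producer = np.ndarray((samples,), dtype=np.float32, buffer=buf)
producer[:] = 0.25                                         # what RemoteObjectDetector.detect(...) copies in
consumer = np.ndarray((samples,), dtype=np.float32, buffer=buf)
assert np.array_equal(producer, consumer)                  # both processes see the same waveform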

View File

@ -915,7 +915,7 @@ class FrameManager(ABC):
pass
@abstractmethod
def get(self, name, timeout_ms=0):
def get(self, name):
pass
@abstractmethod
@ -956,13 +956,13 @@ class SharedMemoryFrameManager(FrameManager):
self.shm_store[name] = shm
return shm.buf
def get(self, name, shape):
def get(self, name, shape, dtype=np.uint8):
if name in self.shm_store:
shm = self.shm_store[name]
else:
shm = shared_memory.SharedMemory(name=name)
self.shm_store[name] = shm
return np.ndarray(shape, dtype=np.uint8, buffer=shm.buf)
return np.ndarray(shape, dtype=dtype, buffer=shm.buf)
def close(self, name):
if name in self.shm_store: